import { Logger } from './Logger';
import { TokenAccountMetadata } from '../swap_api_utils';
import BN from 'bn.js';
import { SYSTEM_PROGRAM_ADDRESS } from '@solana-program/system';
import { TOKEN_PROGRAM_ADDRESS } from '@solana-program/token';
import { getTokenDecoder, TOKEN_2022_PROGRAM_ADDRESS } from '@solana-program/token-2022';
import {
  AccountInfoBase,
  AccountInfoWithBase64EncodedData,
  Address,
  Base64EncodedWireTransaction,
  getBase64EncodedWireTransaction,
  Rpc,
  SimulateTransactionApi,
  SolanaError,
  SolanaRpcApi,
  Transaction,
} from '@solana/kit';
import { getAssociatedTokenAddress } from './utils';

/** Token amount held by an account, tagged with the mint it belongs to. */
export interface TokenBalance {
  mint: Address;
  amount: BN;
}

/** Before/after snapshot of a single account's balance; `change` is `afterAmount - beforeAmount`. */
export interface BalanceChange {
  address: Address;
  beforeAmount: BN;
  afterAmount: BN;
  change: BN;
}

// Extract the response type from SimulateTransactionApi.simulateTransaction

// 1) Turn an overloaded function type into a union of call signatures
type OverloadToUnion<F> = F extends {
  (a: infer A1, b: infer B1): infer R1;
  (a: infer A2, b: infer B2): infer R2;
  (a: infer A3, b: infer B3): infer R3;
  (a: infer A4, b: infer B4): infer R4;
  (a: infer A5, b: infer B5): infer R5;
  (a: infer A6, b: infer B6): infer R6;
  (a: infer A7, b: infer B7): infer R7;
  (a: infer A8, b: infer B8): infer R8;
  (a: infer A9, b: infer B9): infer R9;
  (a: infer A10, b: infer B10): infer R10;
}
  ?
      | ((a: A1, b: B1) => R1)
      | ((a: A2, b: B2) => R2)
      | ((a: A3, b: B3) => R3)
      | ((a: A4, b: B4) => R4)
      | ((a: A5, b: B5) => R5)
      | ((a: A6, b: B6) => R6)
      | ((a: A7, b: B7) => R7)
      | ((a: A8, b: B8) => R8)
      | ((a: A9, b: B9) => R9)
      | ((a: A10, b: B10) => R10)
  : F; // non-overloaded fallback

// 2) Keep only the base64 + base64-accounts overloads
type Base64AccountsOverload = Extract<
  OverloadToUnion<SimulateTransactionApi['simulateTransaction']>,
  (a: Base64EncodedWireTransaction, b: { encoding: 'base64'; accounts: { encoding: 'base64'; addresses: any } }) => any
>;

// 3) From SolanaRpcResponse<...> get the inner payload type
type RpcPayload<F> = F extends (...args: any[]) => infer R ? R : never;

// 4) Final public type
export type SimulateTransactionResponse = RpcPayload<Base64AccountsOverload>;

/** Outcome of simulating one transaction while tracking arbitrary token and lamports accounts. */
export interface GenericSimulationResult {
  successful: boolean;
  tokenBalanceChanges: BalanceChange[];
  lamportsBalanceChanges: BalanceChange[];
  isSlippageError: boolean;
  simulationError: SolanaError | undefined;
  simulationResponse: SimulateTransactionResponse | undefined;
  simulationUrl: string;
  simulationTimestamp?: number;
}

/** Swap-specific view of a simulation: input/output token changes plus the wallet's lamports delta. */
export interface SwapSimulationResult {
  successful: boolean;
  inputToken: BalanceChange | null;
  outputToken: BalanceChange | null;
  nativeChangeAmount: BN;
  isSlippageError: boolean;
  simulationError: SolanaError | undefined;
  simulationResponse: SimulateTransactionResponse | undefined;
  simulationUrl: string;
  simulationTimestamp?: number;
}

/**
 * Builds a successful {@link GenericSimulationResult} from the post-simulation account states.
 *
 * `trackedAccounts` is read positionally: the first `tokenAccounts.length` entries are matched
 * to `tokenAccounts`, and the entries after them to `lamportsAccounts`.
 *
 * @param params.tokenBalancesBefore - pre-simulation token balances keyed by token account address
 * @param params.lamportsBalancesBefore - pre-simulation lamports keyed by account address
 * @returns a result with `successful: true`; token accounts for which neither a before nor an
 *          after balance could be determined are omitted from `tokenBalanceChanges`
 */
export function buildGenericSimulationResultFromTrackedAccounts(params: {
  tokenAccounts: TokenAccountMetadata[];
  lamportsAccounts: Address[];
  tokenBalancesBefore: Map<Address, TokenBalance>;
  lamportsBalancesBefore: Map<Address, BN>;
  trackedAccounts: ((AccountInfoBase & AccountInfoWithBase64EncodedData) | null)[];
  simulationResponse?: SimulateTransactionResponse | undefined;
  simulationUrl: string;
  simulationTimestamp?: number;
}): GenericSimulationResult {
  const {
    tokenAccounts,
    lamportsAccounts,
    tokenBalancesBefore,
    lamportsBalancesBefore,
    trackedAccounts,
    simulationResponse,
    simulationUrl,
    simulationTimestamp,
  } = params;

  const tokenBalanceChanges: BalanceChange[] = [];
  tokenAccounts.forEach((tokenAccount, i) => {
    const beforeBalance = tokenBalancesBefore.get(tokenAccount.address) ?? null;
    const afterBalance = extractTokenBalance(trackedAccounts[i] ?? null, tokenAccount.mint);
    const change = calculateBalanceChange(tokenAccount.address, beforeBalance, afterBalance);
    if (change) {
      tokenBalanceChanges.push(change);
    }
  });

  // Lamports accounts follow the token accounts in the tracked-accounts array.
  const lamportsOffset = tokenAccounts.length;
  const lamportsBalanceChanges: BalanceChange[] = lamportsAccounts.map((address, i) => {
    const beforeAmount = lamportsBalancesBefore.get(address) ?? new BN(0);
    // A missing (null/absent) account is treated as holding zero lamports.
    const afterAmount = new BN(trackedAccounts[lamportsOffset + i]?.lamports.toString() ?? 0);
    return {
      address,
      beforeAmount,
      afterAmount,
      change: afterAmount.sub(beforeAmount),
    };
  });

  return {
    successful: true,
    tokenBalanceChanges,
    lamportsBalanceChanges,
    isSlippageError: false,
    simulationError: undefined,
    simulationResponse,
    simulationUrl,
    simulationTimestamp,
  };
}

/**
 * Converts a generic simulation result into a swap result by picking out the input/output token
 * accounts and the wallet's lamports change.
 */
export function mapGenericSimulationResultToSwapSimulationResult(
  walletAddress: Address,
  result: GenericSimulationResult,
  inputTokenAta: Address,
  outputTokenAta: Address,
): SwapSimulationResult {
  return processSwapWithATABalances(walletAddress, result, inputTokenAta, outputTokenAta);
}

/**
 * Simulates every transaction and reports the token/lamports balance changes of the tracked accounts.
 *
 * The current balances are fetched once and the simulations are issued alongside that fetch.
 * A simulation whose request rejects, or whose response carries an `err`, yields a result with
 * `successful: false` and empty change lists instead of throwing.
 *
 * @param transactionLabels - optional labels, index-aligned with `transactions`, used in error logs
 * @returns one result per transaction, in the same order as `transactions`
 * @throws if fetching the initial balances fails
 */
export async function simulateTransactionsWithBalanceTracking(
  rpc: Rpc<SolanaRpcApi>,
  transactions: Transaction[],
  tokenAccounts: TokenAccountMetadata[],
  lamportsAccounts: Address[],
  logger: Logger = console,
  transactionLabels?: string[],
): Promise<GenericSimulationResult[]> {
  // Both promises are handed to Promise.all directly. Awaiting the allSettled call inside the
  // array literal would leave the balance fetch without a rejection handler until every
  // simulation had settled, surfacing as an unhandled rejection if the fetch failed meanwhile.
  const [{ tokenBalancesBefore, lamportsBalancesBefore }, responses] = await Promise.all([
    getInitialTokenAccountBalances(rpc, tokenAccounts, lamportsAccounts),
    Promise.allSettled(
      transactions.map((tx) => simulateTransactionWithAccountTracking(rpc, tx, tokenAccounts, lamportsAccounts)),
    ),
  ]);

  return transactions.map((transaction, i): GenericSimulationResult => {
    const res = responses[i];
    const simulationUrl = getSimulationUrl(transaction);

    // The RPC request itself failed (or is unexpectedly missing).
    if (res === undefined || res.status === 'rejected') {
      const label = transactionLabels?.[i];
      logger.error(`Error simulating transaction${label ? ` (${label})` : ''}`, res?.reason);
      return {
        successful: false,
        tokenBalanceChanges: [],
        lamportsBalanceChanges: [],
        isSlippageError: false,
        simulationError: res?.reason,
        simulationResponse: undefined,
        simulationUrl,
      };
    }

    const { response: simulationResponse, simulationTimestamp } = res.value;

    // The simulation ran but the transaction failed; classify slippage failures from the logs.
    if (simulationResponse.value.err) {
      return {
        successful: false,
        tokenBalanceChanges: [],
        lamportsBalanceChanges: [],
        isSlippageError: isSimulationSlippageError(simulationResponse.value.logs),
        simulationError: undefined,
        simulationResponse,
        simulationUrl,
        simulationTimestamp,
      };
    }

    return buildGenericSimulationResultFromTrackedAccounts({
      tokenAccounts,
      lamportsAccounts,
      tokenBalancesBefore,
      lamportsBalancesBefore,
      trackedAccounts: simulationResponse.value.accounts || [],
      simulationResponse,
      simulationUrl,
      simulationTimestamp,
    });
  });
}

/**
 * Simulates swap transactions and reports the change of the wallet's input/output token accounts
 * and of the wallet's lamports.
 *
 * @param accountOverrides - token account addresses to track instead of the token addresses
 *                           obtained from `getAssociatedTokenAddress`
 * @returns one swap result per transaction, in the same order as `transactions`
 */
export async function simulateSwapsWithATABalances(
  connection: Rpc<SolanaRpcApi>,
  transactions: Transaction[],
  walletAddress: Address,
  inputMint: Address,
  outputMint: Address,
  inputMintProgramId: Address,
  outputMintProgramId: Address,
  logger: Logger = console,
  transactionLabels?: string[],
  accountOverrides?: {
    inputTokenAccount?: Address;
    outputTokenAccount?: Address;
  },
): Promise<SwapSimulationResult[]> {
  const [inputTokenAta, outputTokenAta] = await Promise.all([
    accountOverrides?.inputTokenAccount || getAssociatedTokenAddress(walletAddress, inputMint, inputMintProgramId),
    accountOverrides?.outputTokenAccount || getAssociatedTokenAddress(walletAddress, outputMint, outputMintProgramId),
  ]);

  const tokenAccounts: TokenAccountMetadata[] = [
    { address: inputTokenAta, mint: inputMint, programAddress: inputMintProgramId },
    { address: outputTokenAta, mint: outputMint, programAddress: outputMintProgramId },
  ];
  const lamportsAccounts = [walletAddress];

  // Use the generic function
  const responses = await simulateTransactionsWithBalanceTracking(
    connection,
    transactions,
    tokenAccounts,
    lamportsAccounts,
    logger,
    transactionLabels,
  );

  return responses.map((res) =>
    mapGenericSimulationResultToSwapSimulationResult(walletAddress, res, inputTokenAta, outputTokenAta),
  );
}

/**
 * Reads the token amount out of a base64-encoded account.
 *
 * @returns
 * - a zero balance when the account is missing (`null`) or owned by the System program;
 * - `null` when the account is owned by any program other than the two token programs, when its
 *   data cannot be decoded as a token account, or when the decoded mint differs from `mint`;
 * - the decoded amount otherwise.
 */
export function extractTokenBalance(
  accountInfo: (AccountInfoBase & AccountInfoWithBase64EncodedData) | null,
  mint: Address,
): TokenBalance | null {
  if (!accountInfo) {
    return { mint, amount: new BN(0) };
  }

  // Ownership is checked before decoding: if the data of a System-owned account does not decode
  // as a token account, decoding first would throw and report `null` instead of the zero balance.
  if (accountInfo.owner === SYSTEM_PROGRAM_ADDRESS) {
    return { mint, amount: new BN(0) };
  }
  if (accountInfo.owner !== TOKEN_PROGRAM_ADDRESS && accountInfo.owner !== TOKEN_2022_PROGRAM_ADDRESS) {
    return null;
  }

  try {
    const token = getTokenDecoder().decode(Buffer.from(accountInfo.data[0], accountInfo.data[1]));
    if (token.mint !== mint) {
      return null;
    }
    return { mint, amount: new BN(token.amount.toString()) };
  } catch {
    // Best effort: undecodable data is reported as "no balance information".
    return null;
  }
}

/**
 * Returns `true` when any simulation log line contains "slippage" (case-insensitive).
 * Non-array input and non-string entries never match.
 */
export function isSimulationSlippageError(simulationLogs: string[] | null): boolean {
  return (
    Array.isArray(simulationLogs) &&
    simulationLogs.some((log) => typeof log === 'string' && log.toLowerCase().includes('slippage'))
  );
}

/**
 * Computes the balance delta for one token account.
 *
 * @returns `null` when neither balance is known; otherwise a change where a missing side counts as zero
 */
function calculateBalanceChange(
  address: Address,
  beforeBalance: TokenBalance | null,
  afterBalance: TokenBalance | null,
): BalanceChange | null {
  if (!beforeBalance && !afterBalance) {
    return null;
  }

  const before = beforeBalance?.amount ?? new BN(0);
  const after = afterBalance?.amount ?? new BN(0);

  return {
    address,
    beforeAmount: before,
    afterAmount: after,
    change: after.sub(before),
  };
}

/**
 * Fetches the current state of all tracked accounts with a single `getMultipleAccounts` call and
 * splits the result into token balances and lamports balances keyed by address.
 *
 * Token accounts for which `extractTokenBalance` returns `null` are left out of
 * `tokenBalancesBefore`; lamports accounts that do not exist are recorded as zero.
 */
async function getInitialTokenAccountBalances(
  rpc: Rpc<SolanaRpcApi>,
  tokenAccounts: TokenAccountMetadata[],
  lamportsAccounts: Address[],
): Promise<{ tokenBalancesBefore: Map<Address, TokenBalance>; lamportsBalancesBefore: Map<Address, BN> }> {
  const allAccounts = [...tokenAccounts.map((ta) => ta.address), ...lamportsAccounts];
  const accountInfos = await rpc.getMultipleAccounts(allAccounts).send();

  // Process token accounts
  const tokenBalancesBefore = new Map<Address, TokenBalance>();
  tokenAccounts.forEach((tokenAccount, i) => {
    const balance = extractTokenBalance(accountInfos.value[i] ?? null, tokenAccount.mint);
    if (balance) {
      tokenBalancesBefore.set(tokenAccount.address, balance);
    }
  });

  // Process lamports accounts (they follow the token accounts in the response)
  const lamportsBalancesBefore = new Map<Address, BN>();
  lamportsAccounts.forEach((address, i) => {
    const accountInfo = accountInfos.value[tokenAccounts.length + i];
    lamportsBalancesBefore.set(address, new BN(accountInfo?.lamports.toString() ?? 0));
  });

  return { tokenBalancesBefore, lamportsBalancesBefore };
}

/**
 * Simulates one transaction, asking the RPC to return the post-simulation state of the tracked
 * accounts (token accounts first, then lamports accounts) as base64.
 *
 * @returns the raw RPC response and the local `Date.now()` taken right after the response arrived
 */
async function simulateTransactionWithAccountTracking(
  rpc: Rpc<SolanaRpcApi>,
  transaction: Transaction,
  tokenAccounts: TokenAccountMetadata[],
  lamportsAccounts: Address[],
): Promise<{ response: SimulateTransactionResponse; simulationTimestamp: number }> {
  const simulationResponse = await rpc
    .simulateTransaction(getBase64EncodedWireTransaction(transaction), {
      sigVerify: false,
      commitment: 'confirmed',
      replaceRecentBlockhash: true,
      encoding: 'base64',
      accounts: {
        encoding: 'base64',
        addresses: [...tokenAccounts.map((ta) => ta.address), ...lamportsAccounts],
      },
    })
    .send();

  const simulationTimestamp = Date.now();
  return { response: simulationResponse, simulationTimestamp };
}

/**
 * Projects a generic result onto the swap shape. For a failed simulation the token changes are
 * `null` and the native change is zero; the error/diagnostic fields are copied through either way.
 */
function processSwapWithATABalances(
  walletAddress: Address,
  result: GenericSimulationResult,
  inputTokenAta: Address,
  outputTokenAta: Address,
): SwapSimulationResult {
  if (!result.successful) {
    return {
      successful: false,
      inputToken: null,
      outputToken: null,
      nativeChangeAmount: new BN(0),
      isSlippageError: result.isSlippageError,
      simulationError: result.simulationError,
      simulationResponse: result.simulationResponse,
      simulationUrl: result.simulationUrl,
      simulationTimestamp: result.simulationTimestamp,
    };
  }

  // Extract specific input/output token changes
  const inputTokenChange = result.tokenBalanceChanges.find((change) => change.address === inputTokenAta) ?? null;
  const outputTokenChange = result.tokenBalanceChanges.find((change) => change.address === outputTokenAta) ?? null;

  // Extract wallet lamports change
  const walletLamportsChange = result.lamportsBalanceChanges.find((change) => change.address === walletAddress);
  const nativeChangeAmount = walletLamportsChange ? walletLamportsChange.change : new BN(0);

  return {
    successful: true,
    inputToken: inputTokenChange,
    outputToken: outputTokenChange,
    nativeChangeAmount,
    isSlippageError: result.isSlippageError,
    simulationError: result.simulationError,
    simulationResponse: result.simulationResponse,
    simulationUrl: result.simulationUrl,
    simulationTimestamp: result.simulationTimestamp,
  };
}

/**
 * Builds a Solana Explorer transaction-inspector URL from the base64-encoded message bytes.
 * When the transaction has at least one entry in `signatures`, the first key of that map is
 * appended as the `signatures` query parameter.
 */
function getSimulationUrl(tx: Transaction): string {
  const encodedTxMessage = Buffer.from(tx.messageBytes).toString('base64');
  const sigs = Object.keys(tx.signatures);
  const payer: string | undefined = sigs.length > 0 ? sigs[0] : undefined;
  const payerUriQuery = payer ? `&signatures=${encodeURIComponent(`[${payer}]`)}` : '';
  return `https://explorer.solana.com/tx/inspector?message=${encodeURIComponent(encodedTxMessage)}${payerUriQuery}`;
}