/* eslint-disable @typescript-eslint/no-loop-func */
import { TRANSACTION_ROLLBACK_ERROR_PREFIX } from "../constants.js";
import { transactionContext } from "../contexts/transactionContext.js";
import type { ConnectionPoolClient } from "../factories/createConnectionPool.js";
import { getPoolClientState } from "../state.js";
import type {
  ClientConfiguration,
  Interceptor,
  Logger,
  Query,
  QueryContext,
  QueryId,
  QueryResult,
  QueryResultRow,
  StreamResult,
} from "../types.js";
import { getStackTrace } from "./getStackTrace.js";
import { SpanStatusCode, trace } from "@opentelemetry/api";
import type { DriverNotice } from "@slonik/driver";
import {
  BackendTerminatedError,
  BackendTerminatedUnexpectedlyError,
  DataIntegrityError,
  InvalidInputError,
  NotFoundError,
  SlonikError,
  TupleMovedToAnotherPartitionError,
  UnexpectedForeignConnectionError,
  UnexpectedStateError,
} from "@slonik/errors";
import type { PrimitiveValueExpression, QuerySqlToken } from "@slonik/sql-tag";
import { defer, generateUid } from "@slonik/utilities";
import pLimit from "p-limit";
import { serializeError } from "serialize-error";

/**
 * Row/column-count contract that a caller can ask `executeQuery` to enforce on a `QueryResult`.
 */
export type IntegrityValidation = {
  validationType:
    | "MANY_COLUMNS"
    | "MANY_ROWS"
    | "MANY_ROWS_ONE_COLUMN"
    | "MAYBE_MANY_ROWS_ONE_COLUMN"
    | "MAYBE_ONE_COLUMN"
    | "MAYBE_ONE_ROW"
    | "ONE_COLUMN"
    | "ONE_ROW";
};

const tracer = trace.getTracer("slonik.interceptors");

/**
 * Upper bound on how many `transformRowAsync` calls run concurrently for a single result set.
 */
const TRANSFORM_ROW_ASYNC_CONCURRENCY = 10;

/**
 * Routine that sends one query over `connection` and resolves with its (query or stream) result.
 */
export type ExecutionRoutine = (
  connection: ConnectionPoolClient,
  sql: string,
  values: readonly PrimitiveValueExpression[],
  queryContext: QueryContext,
  query: Query,
) => Promise<GenericQueryResult>;

type GenericQueryResult = QueryResult<QueryResultRow> | StreamResult;

/**
 * Everything needed to re-run a query from `retryQuery`.
 */
type TransactionQuery = {
  readonly executionContext: QueryContext;
  readonly executionRoutine: ExecutionRoutine;
  readonly name?: string;
  readonly sql: string;
  readonly values: readonly PrimitiveValueExpression[];
};

type StackCrumb = {
  columnNumber: null | number;
  fileName: null | string;
  functionName: null | string;
  lineNumber: null | number;
};

type SpanAttributes = Record<string, number | string>;

/**
 * Runs a synchronous `routine` inside an active OpenTelemetry span called `name`.
 *
 * If `routine` throws, the error is recorded on the span (exception + ERROR status) and re-thrown.
 * The span is always ended before this function returns or throws.
 */
const traceSync = <T>(name: string, attributes: SpanAttributes, routine: () => T): T => {
  return tracer.startActiveSpan(name, { attributes }, (span) => {
    try {
      return routine();
    } catch (error) {
      span.recordException(error);
      span.setStatus({
        code: SpanStatusCode.ERROR,
        message: String(error),
      });

      throw error;
    } finally {
      span.end();
    }
  });
};

/**
 * Asynchronous counterpart of `traceSync`.
 *
 * The span is ended only after the promise returned by `routine` settles, so the span covers the
 * whole asynchronous operation. A rejection is recorded on the span and re-thrown.
 */
const traceAsync = <T>(
  name: string,
  attributes: SpanAttributes,
  routine: () => Promise<T>,
): Promise<T> => {
  return tracer.startActiveSpan(name, { attributes }, async (span): Promise<T> => {
    try {
      // `await` (rather than returning the promise) keeps the span open until `routine` settles.
      return await routine();
    } catch (error) {
      span.recordException(error);
      span.setStatus({
        code: SpanStatusCode.ERROR,
        message: String(error),
      });

      throw error;
    } finally {
      span.end();
    }
  });
};

/**
 * Whether `error` carries a string `code` starting with `TRANSACTION_ROLLBACK_ERROR_PREFIX`.
 * Such errors are the only ones eligible for a query retry.
 */
const isTransactionRollbackError = (error: unknown): boolean => {
  if (typeof error !== "object" || error === null) {
    return false;
  }

  const { code } = error as { code?: unknown };

  return typeof code === "string" && code.startsWith(TRANSACTION_ROLLBACK_ERROR_PREFIX);
};

/**
 * Re-runs `query` on `connection` after it failed with a transaction rollback error.
 *
 * Makes at most `retryLimit` attempts:
 * - A rollback error on any attempt except the last triggers another attempt.
 * - Any other error, or a rollback error on the last attempt, is re-thrown.
 *
 * @throws UnexpectedStateError when `retryLimit` allows no attempt at all (callers only invoke this
 *   with a positive limit).
 */
const retryQuery = async (
  connectionLogger: Logger,
  connection: ConnectionPoolClient,
  query: TransactionQuery,
  retryLimit: number,
): Promise<GenericQueryResult> => {
  let remainingRetries = retryLimit;
  let attempt = 0;

  // @todo Provide information about the queries being retried to the logger.
  while (remainingRetries-- > 0) {
    attempt++;

    try {
      connectionLogger.trace(
        {
          attempt,
          queryId: query.executionContext.queryId,
        },
        "retrying query",
      );

      return await query.executionRoutine(
        connection,
        query.sql,
        query.values,
        // @todo Refresh execution context to reflect that the query has been re-tried.
        // This (probably) requires changing `queryId` and `queryInputTime`.
        // It should be needed only for the last query (because other queries will not be processed by the middlewares).
        query.executionContext,
        {
          name: query.name,
          sql: query.sql,
          values: query.values,
        },
      );
    } catch (error) {
      if (isTransactionRollbackError(error) && remainingRetries > 0) {
        continue;
      }

      throw error;
    }
  }

  throw new UnexpectedStateError("Expected retryLimit to allow at least one attempt.");
};

/**
 * Enforces the row/column-count contract named by `validationType` against `rows`.
 *
 * @throws NotFoundError when the contract requires at least one row and none were returned.
 * @throws DataIntegrityError when more rows, or more/fewer columns, were returned than allowed.
 */
const assertResultIntegrity = (
  validationType: IntegrityValidation["validationType"],
  rows: readonly QueryResultRow[],
  query: Query,
): void => {
  const rowCount = rows.length;

  // Only the first row's column count is inspected. All rows of a result set are assumed to share
  // one shape, which avoids calling Object.keys() for every row.
  const firstRowColumnCount = rowCount > 0 ? Object.keys(rows[0]).length : 0;

  switch (validationType) {
    case "ONE_ROW": {
      if (rowCount === 0) {
        throw new NotFoundError("Query returned no rows.", query);
      }

      if (rowCount > 1) {
        throw new DataIntegrityError("Query returned multiple rows.", query);
      }

      break;
    }

    case "MAYBE_ONE_ROW": {
      if (rowCount > 1) {
        throw new DataIntegrityError("Query returned multiple rows.", query);
      }

      break;
    }

    case "ONE_COLUMN": {
      if (rowCount === 0) {
        throw new NotFoundError("Query returned no rows.", query);
      }

      if (rowCount !== 1) {
        throw new DataIntegrityError("Query returned multiple rows.", query);
      }

      if (firstRowColumnCount !== 1) {
        throw new DataIntegrityError("Query returned rows with multiple columns.", query);
      }

      break;
    }

    case "MAYBE_ONE_COLUMN": {
      if (rowCount > 1) {
        throw new DataIntegrityError("Query returned multiple rows.", query);
      }

      if (rowCount === 1 && firstRowColumnCount !== 1) {
        throw new DataIntegrityError("Query returned rows with multiple columns.", query);
      }

      break;
    }

    case "MANY_ROWS": {
      if (rowCount === 0) {
        throw new NotFoundError("Query returned no rows.", query);
      }

      break;
    }

    case "MANY_ROWS_ONE_COLUMN": {
      if (rowCount === 0) {
        throw new NotFoundError("Query returned no rows.", query);
      }

      if (firstRowColumnCount !== 1) {
        throw new DataIntegrityError("Query returned rows with multiple columns.", query);
      }

      break;
    }

    case "MAYBE_MANY_ROWS_ONE_COLUMN": {
      if (rowCount > 0 && firstRowColumnCount !== 1) {
        throw new DataIntegrityError("Query returned rows with multiple columns.", query);
      }

      break;
    }

    case "MANY_COLUMNS": {
      // No row or column constraint is enforced for this validation type.
      break;
    }
  }
};

/**
 * Runs `query` through the interceptor pipeline and the execution routine.
 *
 * Pipeline, in order:
 * 1. Input/state guards: terminated connection, empty SQL, foreign connection inside a transaction.
 * 2. `beforeTransformQuery`, then `transformQuery`, interceptors.
 * 3. Non-stream queries only: `beforeQueryExecution`, which may short-circuit with a result.
 * 4. Execution, serialized per connection, with rollback retries outside of transactions.
 * 5. For `QueryResult` results: integrity validation, then `afterQueryExecution`, `transformRow`,
 *    `transformRowAsync` and `beforeQueryResult` interceptors.
 */
// oxlint-disable-next-line complexity
const executeQueryInternal = async (
  connectionLogger: Logger,
  connection: ConnectionPoolClient,
  clientConfiguration: ClientConfiguration,
  query: QuerySqlToken,
  inheritedQueryId: QueryId | undefined,
  executionRoutine: ExecutionRoutine,
  integrityValidation?: IntegrityValidation,
  stream: boolean = false,
): Promise<
  QueryResult<Record<string, PrimitiveValueExpression>> | StreamResult
  // eslint-disable-next-line complexity
> => {
  const poolClientState = getPoolClientState(connection);

  if (poolClientState.terminated) {
    throw new BackendTerminatedError(poolClientState.terminated);
  }

  if (query.sql.trim() === "") {
    throw new InvalidInputError("Unexpected SQL input. Query cannot be empty.");
  }

  if (query.sql.trim() === "$1") {
    throw new InvalidInputError(
      "Unexpected SQL input. Query cannot be empty. Found only value binding.",
    );
  }

  const transactionStore = transactionContext.getStore();

  // Refuse to run on a connection other than the one owning the surrounding transaction.
  if (
    clientConfiguration.dangerouslyAllowForeignConnections !== true &&
    transactionStore?.transactionId &&
    transactionStore.transactionId !== poolClientState.transactionId
  ) {
    throw new UnexpectedForeignConnectionError();
  }

  const queryInputTime = process.hrtime.bigint();

  let stackTrace: null | StackCrumb[] = null;

  if (clientConfiguration.captureStackTrace) {
    stackTrace = getStackTrace();
  }

  const queryId = inheritedQueryId ?? generateUid();

  const log = connectionLogger.child({
    queryId,
  });

  const originalQuery = {
    // Include statement name for prepared statements if provided
    name: query.name,
    // See comments in `formatSlonikPlaceholder` for more information.
    sql: query.sql.replaceAll("$slonik_", "$"),
    values: query.values,
  };

  let actualQuery: Query = {
    name: originalQuery.name,
    sql: originalQuery.sql,
    values: originalQuery.values,
  };

  const executionContext: QueryContext = {
    connectionId: poolClientState.connectionId,
    log,
    originalQuery,
    poolId: poolClientState.poolId,
    queryId,
    queryInputTime,
    resultParser: query.parser,
    sandbox: {},
    stackTrace,
    transactionId: poolClientState.transactionId,
  };

  for (const interceptor of clientConfiguration.interceptors) {
    const beforeTransformQuery = interceptor.beforeTransformQuery;

    if (beforeTransformQuery) {
      await traceAsync(
        "slonik.interceptor.beforeTransformQuery",
        {
          "interceptor.name": interceptor.name,
        },
        async () => {
          await beforeTransformQuery(executionContext, actualQuery);
        },
      );
    }
  }

  for (const interceptor of clientConfiguration.interceptors) {
    const transformQuery = interceptor.transformQuery;

    if (transformQuery) {
      // Each interceptor receives the query as rewritten by the previous one.
      actualQuery = traceSync(
        "slonik.interceptor.transformQuery",
        {
          "interceptor.name": interceptor.name,
        },
        () => transformQuery(executionContext, actualQuery),
      );
    }
  }

  let result: GenericQueryResult | null;

  if (!stream) {
    for (const interceptor of clientConfiguration.interceptors) {
      const beforeQueryExecution = interceptor.beforeQueryExecution;

      if (beforeQueryExecution) {
        const interceptedResult = await traceAsync(
          "slonik.interceptor.beforeQueryExecution",
          {
            "interceptor.name": interceptor.name,
          },
          async () => {
            return await beforeQueryExecution(executionContext, actualQuery);
          },
        );

        if (interceptedResult) {
          log.info(
            "beforeQueryExecution interceptor produced a result; short-circuiting query execution using beforeQueryExecution result",
          );

          return interceptedResult;
        }
      }
    }
  }

  const notices: DriverNotice[] = [];

  const noticeListener = (notice: DriverNotice) => {
    notices.push(notice);
  };

  // Serialize queries on this connection. Register this query as the active one, then wait for
  // the previously active query (if any) to finish before executing.
  const activeQuery = defer<null>();

  const blockingPromise = poolClientState.activeQuery?.promise ?? null;

  poolClientState.activeQuery = activeQuery;

  await blockingPromise;

  connection.on("notice", noticeListener);

  const queryWithContext = {
    executionContext,
    executionRoutine,
    name: actualQuery.name,
    sql: actualQuery.sql,
    values: actualQuery.values,
  };

  try {
    try {
      try {
        // eslint-disable-next-line require-atomic-updates
        result = await executionRoutine(
          connection,
          actualQuery.sql,
          actualQuery.values,
          executionContext,
          actualQuery,
        );
      } catch (error) {
        const shouldRetry =
          isTransactionRollbackError(error) && clientConfiguration.queryRetryLimit > 0;

        // Transaction errors in queries that are part of a transaction are handled by the
        // transaction/nestedTransaction functions, so retry only outside of a transaction.
        if (shouldRetry && !poolClientState.transactionId) {
          // eslint-disable-next-line require-atomic-updates
          result = await retryQuery(
            connectionLogger,
            connection,
            queryWithContext,
            clientConfiguration.queryRetryLimit,
          );
        } else {
          throw error;
        }
      }
    } catch (error) {
      // The driver is responsible for throwing an appropriately wrapped error.
      if (error instanceof BackendTerminatedError) {
        // Remembered so that later queries on this connection are rejected up front.
        poolClientState.terminated = error;
      }

      // If the error has been already handled by the driver, then we should not wrap it again.
      if (!(error instanceof SlonikError)) {
        if (error.message === "Connection terminated unexpectedly") {
          throw new BackendTerminatedUnexpectedlyError(error);
        }

        if (
          error.message.includes(
            "tuple to be locked was already moved to another partition due to concurrent update",
          )
        ) {
          throw new TupleMovedToAnotherPartitionError(error);
        }
      }

      error.notices = notices;

      throw error;
    } finally {
      connection.off("notice", noticeListener);

      // Release the next query waiting on this connection, whether or not this one succeeded.
      activeQuery.resolve(null);
    }
  } catch (error) {
    log.error(
      {
        error: serializeError(error),
      },
      "execution routine produced an error",
    );

    for (const interceptor of clientConfiguration.interceptors) {
      if (interceptor.queryExecutionError) {
        await interceptor.queryExecutionError(executionContext, actualQuery, error, notices);
      }
    }

    error.notices = notices;

    throw error;
  }

  if (!result) {
    throw new UnexpectedStateError("Expected query result to be returned.");
  }

  // @ts-expect-error -- We want to keep notices as readonly for consumer, but write to it here.
  result.notices = notices;

  const interceptors: Interceptor[] = clientConfiguration.interceptors.slice();

  // Stream results bypass integrity validation and all row-level post-processing.
  if (result.type !== "QueryResult") {
    return result;
  }

  // A `const` keeps the narrowed type visible inside the interceptor closures below.
  const queryResult: QueryResult<QueryResultRow> = result;

  if (integrityValidation) {
    try {
      assertResultIntegrity(integrityValidation.validationType, queryResult.rows, actualQuery);
    } catch (error) {
      for (const interceptor of clientConfiguration.interceptors) {
        if (interceptor.dataIntegrityError) {
          await interceptor.dataIntegrityError(executionContext, actualQuery, error, queryResult);
        }
      }

      throw error;
    }
  }

  for (const interceptor of interceptors) {
    const afterQueryExecution = interceptor.afterQueryExecution;

    if (afterQueryExecution) {
      await traceAsync(
        "slonik.interceptor.afterQueryExecution",
        {
          "interceptor.name": interceptor.name,
        },
        async () => {
          await afterQueryExecution(executionContext, actualQuery, queryResult);
        },
      );
    }
  }

  for (const interceptor of interceptors) {
    const transformRow = interceptor.transformRow;

    if (transformRow) {
      const { fields, rows } = queryResult;

      const transformedRows: QueryResultRow[] = traceSync(
        "slonik.interceptor.transformRow",
        {
          "interceptor.name": interceptor.name,
          "rows.length": rows.length,
        },
        () => {
          return rows.map((row) => {
            return transformRow(executionContext, actualQuery, row, fields);
          });
        },
      );

      // Assign in place instead of spreading the result object, to avoid copying it.
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      (queryResult as any).rows = transformedRows;
    }
  }

  for (const interceptor of interceptors) {
    const transformRowAsync = interceptor.transformRowAsync;

    if (transformRowAsync) {
      const { fields, rows } = queryResult;

      const transformedRows: QueryResultRow[] = await traceAsync(
        "slonik.interceptor.transformRowAsync",
        {
          "interceptor.name": interceptor.name,
          "rows.length": rows.length,
        },
        async () => {
          const limit = pLimit(TRANSFORM_ROW_ASYNC_CONCURRENCY);

          // Promise.all preserves the original row order regardless of completion order.
          return await Promise.all(
            rows.map((row) => {
              return limit(() => transformRowAsync(executionContext, actualQuery, row, fields));
            }),
          );
        },
      );

      // Assign in place instead of spreading the result object, to avoid copying it.
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      (queryResult as any).rows = transformedRows;
    }
  }

  for (const interceptor of interceptors) {
    const beforeQueryResult = interceptor.beforeQueryResult;

    if (beforeQueryResult) {
      await traceAsync(
        "slonik.interceptor.beforeQueryResult",
        {
          "interceptor.name": interceptor.name,
        },
        async () => {
          await beforeQueryResult(executionContext, actualQuery, queryResult);
        },
      );
    }
  }

  return queryResult;
};

/**
 * Executes `query` on `connection` through the configured interceptor pipeline.
 *
 * The call is wrapped in a `slonik.executeQuery` OpenTelemetry span whose `sql` attribute holds the
 * query text. Errors are recorded on that span and re-thrown.
 *
 * @param stream - When true, `beforeQueryExecution` interceptors are skipped.
 * @param integrityValidation - Optional row/column-count contract enforced on a `QueryResult`.
 *   Results whose `type` is not `"QueryResult"` are returned without validation or row
 *   post-processing.
 */
export const executeQuery = async (
  connectionLogger: Logger,
  connection: ConnectionPoolClient,
  clientConfiguration: ClientConfiguration,
  query: QuerySqlToken,
  inheritedQueryId: QueryId | undefined,
  executionRoutine: ExecutionRoutine,
  stream: boolean,
  integrityValidation?: IntegrityValidation,
): Promise<QueryResult<Record<string, PrimitiveValueExpression>> | StreamResult> => {
  return await traceAsync(
    "slonik.executeQuery",
    {
      sql: query.sql,
    },
    () => {
      return executeQueryInternal(
        connectionLogger,
        connection,
        clientConfiguration,
        query,
        inheritedQueryId,
        executionRoutine,
        integrityValidation,
        stream,
      );
    },
  );
};