import * as pg from "pg";
import type CursorType from "pg-cursor";
import type { SQLOptions, SocketSQLStreamPacket, SocketSQLStreamServer } from "prostgles-types";
import { CHANNELS, omitKeys, pickKeys } from "prostgles-types";
import type { BasicCallback } from "../PubSubManager/PubSubManager";
import type { VoidFunction } from "../SchemaWatch/SchemaWatch";
import type { DB } from "../initProstgles";
import type { DboBuilder } from "./DboBuilder";
import type { PRGLIOSocket } from "./DboBuilderTypes";
import { getErrorAsObject, getSerializedClientErrorFromPGError } from "./dboBuilderUtils";
import { getDetailedFieldInfo } from "./runSql/runSqlUtils";
const Cursor = require("pg-cursor") as typeof CursorType;

/** What a socket client supplies when it asks for a streamed SQL query. */
type ClientStreamedRequest = {
  socket: PRGLIOSocket;
  query: string;
  options: SQLOptions | undefined;
  persistConnection?: boolean;
};

/** Registry entry kept for every stream created through `QueryStreamer.create`. */
type StreamedQuery = ClientStreamedRequest & {
  id: number;
  cursor: CursorType | undefined;
  client: pg.Client | undefined;
  stop?: VoidFunction;
  onError: (error: any) => void;
};

/** Result metadata picked from the cursor result (or from the `runSQL` fallback result). */
type Info = {
  command: string | undefined;
  fields: any[];
  rowCount: number;
  duration: number;
};

const socketIdToLastQueryId: Map<string, number> = new Map();

/**
 * Returns the next query id for a socket.
 * Counters are keyed by the first three characters of the socket id, so sockets that
 * share that prefix also share (and advance) the same counter.
 */
const getSetShortSocketId = (socketId: string) => {
  const shortId = socketId.slice(0, 3);
  const currId = socketIdToLastQueryId.get(shortId) ?? 0;
  const newId = currId + 1;
  socketIdToLastQueryId.set(shortId, newId);
  return newId;
};

/**
 * Streams SQL query results to a socket in batches read through a `pg-cursor`.
 *
 * Every stream runs on its own `pg.Client`; a separate admin client is used to
 * cancel or terminate the backend process of a running stream.
 */
export class QueryStreamer {
  db: DB;
  dboBuilder: DboBuilder;
  /** socket id -> (query id -> stream entry) */
  socketQueries: Map<string, Map<number, StreamedQuery>> = new Map();
  /** Connection used to run `pg_cancel_backend` / `pg_terminate_backend`. */
  adminClient: pg.Client;

  /**
   * @param dboBuilder provides the connection details (`db.$cn`), the db types cache,
   * `runSQL` and the prostgles instance used while streaming
   */
  constructor(dboBuilder: DboBuilder) {
    this.dboBuilder = dboBuilder;
    this.db = dboBuilder.db;

    /**
     * Creates and connects an admin client.
     * The same error handler is attached to every client created here, so a reconnected
     * client also reconnects when it fails (previously only the first client did).
     */
    const connectAdminClient = (): pg.Client => {
      const client = this.getConnection(
        (error) => {
          /* A client that was already replaced must not trigger another reconnect */
          if (this.adminClient !== client) return;
          if (error.message?.includes("database") && error.message?.includes("does not exist")) {
            return;
          }
          console.log("Admin client error. Reconnecting...", error);
          this.adminClient = connectAdminClient();
          /* Release the socket of the failed client */
          void client.end();
        },
        { keepAlive: true },
      );
      /* A failed connect rejects the promise; log it instead of leaving the rejection unhandled */
      client.connect().catch((connectError: unknown) => {
        console.error("Admin client failed to connect", connectError);
      });
      return client;
    };

    this.adminClient = connectAdminClient();
  }

  /**
   * Builds a (not yet connected) `pg.Client` from `db.$cn`.
   *
   * @param onError invoked with every "error" event emitted by the client
   * @param extraOptions client config merged over the connection details
   */
  getConnection = (onError: ((err: any) => void) | undefined, extraOptions?: pg.ClientConfig) => {
    const connectionInfo =
      typeof this.db.$cn === "string" ? { connectionString: this.db.$cn } : (this.db.$cn as any);
    const client = new pg.Client({ ...connectionInfo, ...extraOptions });
    client.on("error", (err) => {
      onError?.(err);
    });
    return client;
  };

  /** Stops every stream registered for the socket, ends their clients and forgets the socket. */
  onDisconnect = (socketId: string) => {
    const socketQueries = this.socketQueries.get(socketId);
    if (!socketQueries) return;
    Array.from(socketQueries.values()).forEach(({ client, stop }) => {
      stop?.();
      /** end does not stop active query?! */
      void client?.end();
    });
    this.socketQueries.delete(socketId);
  };

  /**
   * Registers a streamed query for a socket and wires up its socket listeners.
   *
   * The stream starts when the client emits on the returned `channel` and is stopped when
   * the client emits on `unsubChannel`. If the stream has not been started within 5 seconds
   * the listeners and the registry entry are removed.
   *
   * @param query the socket, the SQL text and the stream options
   * @returns the channel names the client must use
   * @throws a string when a query with the same id is already registered and
   * `persistConnection` is not set
   */
  create = (query: ClientStreamedRequest): SocketSQLStreamServer => {
    const { socket, persistConnection } = query;
    const socketId = socket.id;
    const id = getSetShortSocketId(socketId);
    const channel = `${CHANNELS.SQL_STREAM}__${socketId}_${id}`;
    const unsubChannel = `${channel}.unsubscribe`;
    if (this.socketQueries.get(socketId)?.get(id) && !persistConnection) {
      throw `Must stop existing query ${id} first`;
    }

    /**
     * Latch that ensures a single error packet per run: a failure may be reported both by
     * the client "error" event and by the rejected cursor read. Reset when a run starts.
     */
    let errored = false;
    const socketQuery: StreamedQuery = {
      ...query,
      id,
      client: undefined,
      cursor: undefined,
      onError: (rawError: any) => {
        if (errored) return;
        errored = true;
        const errorWithoutQuery = getSerializedClientErrorFromPGError(rawError, {
          type: "sql",
          localParams: { clientReq: { socket } },
          prostgles: this.dboBuilder.prostgles,
        });
        // For some reason query is not present on the error object from sql stream mode
        const error = { ...errorWithoutQuery, query: query.query };
        socket.emit(channel, {
          type: "error",
          error,
        } satisfies SocketSQLStreamPacket);
      },
    };
    const socketQueries = this.socketQueries.get(socketId) ?? new Map<number, StreamedQuery>();
    this.socketQueries.set(socketId, socketQueries.set(id, socketQuery));

    let processID = -1;
    let streamState: "started" | "ended" | "errored" | undefined;

    /**
     * Runs one query through a cursor and emits its rows in batches.
     *
     * @param client an already connected client to reuse; a new one is created when undefined
     * @param query the request to run (its `query` text may differ between runs)
     */
    const startStream = async (client: pg.Client | undefined, query: ClientStreamedRequest) => {
      await this.dboBuilder.cacheDBTypes();
      const socketQuery = this.socketQueries.get(socketId)?.get(id);
      if (!socketQuery) {
        throw "socket query not found";
      }

      /** Only send fields on first request */
      let fieldsWereSent = false;
      const emit = ({
        reachedEnd,
        rows,
        info,
      }:
        | { reachedEnd: true; rows: any[]; info: Info }
        | { reachedEnd: false; rows: any[]; info: Omit<Info, "command"> }) => {
        if (!(info as any).fields) throw "No fields";
        const fields = getDetailedFieldInfo(this.dboBuilder.dbTypesCache!, info.fields);
        const packet: SocketSQLStreamPacket = {
          type: "data",
          rows,
          fields: fieldsWereSent ? undefined : fields,
          info: reachedEnd ? info : undefined,
          ended: reachedEnd,
          processId: processID,
        };
        socket.emit(channel, packet);
        if (reachedEnd) {
          this.dboBuilder.prostgles.schemaWatch?.onSchemaChangeFallback?.({
            command: info.command,
            query: query.query,
          });
        }
        fieldsWereSent = true;
      };

      const currentClient: pg.Client =
        client ??
        this.getConnection((err) => {
          socketQuery.onError(err);
          void currentClient.end();
        });
      socketQuery.client = currentClient;

      /**
       * Drops the registry entry and ends the dedicated connection, unless the caller asked
       * to keep the connection open for follow-up queries.
       */
      const releaseConnection = () => {
        if (query.options?.persistStreamConnection) return;
        this.socketQueries.get(socketId)?.delete(id);
        void currentClient.end();
      };

      try {
        if (!client) {
          await currentClient.connect();
        }
        processID = (currentClient as any).processID;

        if (
          query.options?.streamLimit &&
          (!Number.isInteger(query.options.streamLimit) || query.options.streamLimit < 0)
        ) {
          throw "streamLimit must be a positive integer";
        }
        /* Batches hold at most 1000 rows, fewer when streamLimit is smaller */
        const batchSize =
          query.options?.streamLimit ? Math.min(1e3, query.options.streamLimit) : 1e3;

        const cursor = currentClient.query(
          new Cursor(query.query, undefined, { rowMode: "array" }),
        );
        socketQuery.cursor = cursor;
        let streamLimitReached = false;
        let reachedEnd = false;

        /* Rows are read in the background so that startStream resolves once the stream is set up */
        void (async () => {
          try {
            let rowChunk: any[] = [];
            let rowsSent = 0;
            do {
              rowChunk = await cursor.read(batchSize);
              const info = pickKeys((cursor as any)._result, [
                "fields",
                "rowCount",
                "command",
                "duration",
              ]) as Info;
              rowsSent += rowChunk.length;
              streamLimitReached = Boolean(
                query.options?.streamLimit && rowsSent >= query.options.streamLimit,
              );
              /* A batch shorter than requested means the cursor has no more rows */
              reachedEnd = rowChunk.length < batchSize;
              emit({
                info,
                rows: rowChunk,
                reachedEnd: reachedEnd || streamLimitReached,
              });
            } while (!reachedEnd && !streamLimitReached);

            streamState = "ended";
            releaseConnection();
            void cursor.close();
          } catch (error: any) {
            streamState = "errored";
            /*
             * Release the dedicated connection on failure too. Previously it stayed open until
             * the socket disconnected, or indefinitely once stop() had removed the registry entry.
             */
            releaseConnection();
            if (error.message === "cannot insert multiple commands into a prepared statement") {
              /* A cursor cannot run multi-statement SQL: run it through runSQL and send everything at once */
              this.dboBuilder
                .runSQL(
                  query.query,
                  {},
                  { returnType: "arrayMode", hasParams: false },
                  {
                    clientReq: { socket: query.socket },
                  },
                )
                .then((res) => {
                  emit({
                    info: omitKeys(res, ["rows"]),
                    reachedEnd: true,
                    rows: res.rows,
                  });
                })
                .catch((newError) => {
                  socketQuery.onError(newError);
                });
            } else {
              socketQuery.onError(error);
            }
          }
        })();
      } catch (err) {
        socketQuery.onError(err);
        await currentClient.end();
      }
    };

    /** Removes the socket listeners of this stream and its registry entry. */
    const cleanup = () => {
      socket.removeAllListeners(unsubChannel);
      socket.removeAllListeners(channel);
      this.socketQueries.get(socketId)?.delete(id);
    };

    /**
     * Cancels (or, with `terminate`, terminates) the backend process of the stream through
     * the admin client and reports the outcome through `cb`.
     */
    const stop = async (opts: { terminate?: boolean } | undefined, cb: BasicCallback) => {
      const { client: queryClient } = this.socketQueries.get(socketId)?.get(id) ?? {};
      if (!queryClient) {
        cb(null, "No active query client found");
        return;
      }
      if (opts?.terminate) {
        /* End the query client 4 seconds after terminate was requested */
        setTimeout(() => {
          void queryClient.end();
        }, 4e3);
      }
      try {
        const stopFunction = opts?.terminate ? "pg_terminate_backend" : "pg_cancel_backend";
        const rows = await this.adminClient.query(
          `SELECT ${stopFunction}(pid), pid, state, query FROM pg_stat_activity WHERE pid = $1`,
          [processID],
        );
        cleanup();
        cb({ processID, info: rows.rows[0] });
      } catch (error) {
        cb(null, error);
      }
    };

    socketQuery.stop = () =>
      stop({ terminate: true }, () => {
        /* Empty */
      });

    socket.removeAllListeners(unsubChannel);
    socket.once(unsubChannel, stop);

    let runCount = 0;
    socket.removeAllListeners(channel);
    socket.on(
      channel,
      async (_data: { query: string; params: any } | undefined, cb: BasicCallback) => {
        if (streamState === "started") {
          return cb(processID, "Already started");
        }
        streamState = "started";
        /* Each run reports its own first error; without this reset a failure in any run
           after the first failed one would never reach the client */
        errored = false;
        try {
          /* Persisted connection query */
          if (runCount) {
            const persistedClient = this.socketQueries.get(socketId)?.get(id);
            if (!persistedClient) throw "Persisted query client not found";

            await startStream(persistedClient.client, {
              ...query,
              query: _data!.query,
            });
          } else {
            await startStream(undefined, query);
          }
          cb(processID);
        } catch (err) {
          console.error(err);
          cb(processID, getErrorAsObject(err));
        }
        runCount++;
      },
    );

    /** If not started within 5 seconds then assume it will never happen */
    setTimeout(() => {
      if (streamState) return;
      cleanup();
    }, 5e3);

    return {
      channel,
      unsubChannel,
    };
  };
}