import type { NodeHTTPResponse } from '@trpc/server/adapters/node-http';
import isPlainObject from 'lodash/isPlainObject';
import isArray from 'lodash/isArray';
import { v4 as uuidv4 } from 'uuid';
import type { TransformedFunction, MappedPromiseItem, RequestData } from './trpc';
import { trpc } from './trpc';
import { logger } from './logger';
import { fnMap } from './fnMap';
import { promiseMap } from './promiseMap';
import { getClientFunctionCallEventSubject, clientFunctionReturnEventSubject } from './subjects';

/** HTTP response object that additionally carries the ID of the websocket it belongs to. */
interface TResponse extends NodeHTTPResponse {
  socketId: number;
}

/**
 * Subscribe to client function return events.
 * Each event is matched (by `socketId` and `reqId`) against the pending
 * promises stored in `promiseMap`; the matching promise is removed from the
 * map and then rejected when the payload is an `Error`, resolved otherwise.
 * Events without a matching pending promise are ignored.
 */
export const subscribeClientFunctionCalls = () => {
  clientFunctionReturnEventSubject.subscribe({
    // eslint-disable-next-line @typescript-eslint/no-explicit-any -- event shape is declared outside this file
    next(data: any) {
      const requests = promiseMap.get(data.socketId);
      if (!requests) {
        return;
      }
      const pending: MappedPromiseItem | undefined = requests.get(data.reqId);
      if (!pending) {
        return;
      }
      // Remove the entry before settling so the request cannot be settled twice.
      requests.delete(data.reqId);
      if (data.payload instanceof Error) {
        pending.reject(data.payload);
      } else {
        pending.resolve(data.payload);
      }
    }
  });
};

/**
 * Call client functions.
 * There are two situations where this function needs to be called:
 * One is for the server to call the client function and retrieve data from the client;
 * Another is for the server to send event data to the client (without the need to retrieve data from the client).
 * @param data The data received from the client, including function mapping relationship.
 * @param socketId The ID of websocket.
 * @param restArgs The data will be sent back to the client and only has a value when the client function is an event callback.
 * @returns A promise settled by `subscribeClientFunctionCalls` when called without `restArgs`;
 * `undefined` when called as an event callback.
 */
const callClientFunction = (
  data: Record<string, unknown>,
  socketId: number,
  ...restArgs: unknown[]
): Promise<unknown> | undefined => {
  const isEventCallback = restArgs.length > 0;

  if (isEventCallback) {
    // Fire-and-forget: no reqId is attached, so no response is awaited.
    const eventReq = { ...data, payloads: restArgs } as RequestData;
    getClientFunctionCallEventSubject(socketId).next(eventReq);
    return undefined;
  }

  const req = { ...data, reqId: uuidv4() } as RequestData;
  let requests = promiseMap.get(socketId);
  if (!requests) {
    requests = new Map();
    promiseMap.set(socketId, requests);
  }
  const pendingRequests = requests;

  return new Promise<unknown>((resolve, reject) => {
    // Register the settle callbacks before emitting the request so that a
    // response can always find its pending entry.
    pendingRequests.set(req.reqId, { resolve, reject });
    getClientFunctionCallEventSubject(socketId).next(req);
  });
};

/**
 * Return the server side function mapped to the descriptor's `mappingId` for
 * the given socket, creating and caching it in `fnMap` on first use.
 */
function getOrCreateClientFunction(
  descriptor: Record<string, unknown>,
  socketId: number
): TransformedFunction {
  let functions = fnMap.get(socketId);
  if (!functions) {
    functions = new Map();
    fnMap.set(socketId, functions);
  }

  const mappingId = descriptor.mappingId as string;
  const cached = functions.get(mappingId);
  if (cached) {
    return cached;
  }

  const created = ((...args: unknown[]) =>
    callClientFunction(descriptor, socketId, ...args)) as TransformedFunction;
  functions.set(mappingId, created);
  return created;
}

/**
 * Transform a single value.
 * Function descriptors and Error descriptors yield a replacement value;
 * arrays and other plain objects are transformed in place and returned as-is;
 * every other value is returned untouched.
 */
function transformParam(param: unknown, socketId: number): unknown {
  if (isArray(param)) {
    transformArrayParams(param, socketId);
    return param;
  }
  if (!isPlainObject(param)) {
    return param;
  }

  const descriptor = param as Record<string, unknown>;
  if (descriptor.__type__ === 'Function' && descriptor.mappingId) {
    return getOrCreateClientFunction(descriptor, socketId);
  }
  if (descriptor.__type__ === 'Error') {
    return new Error(descriptor.stack as string);
  }

  // Ordinary nested object: descend so descriptors at any depth are handled.
  transformPlainObjectParams(descriptor, socketId);
  return param;
}

/**
 * Transform parameters received from the rpc client, in place.
 * Dispatches to the array or plain object variant depending on the input.
 * @param params parameters from rpc client.
 * @param socketId The ID of websocket.
 */
export function transformParams(
  params: Record<string, unknown> | unknown[],
  socketId: number
): void {
  if (isArray(params)) {
    transformArrayParams(params, socketId);
  } else {
    transformPlainObjectParams(params, socketId);
  }
}

/**
 * Transform Parameters.
 * Create a server side function with mappingId from client if type is 'Function', and save it in a map;
 * Create an Error instance if type is 'Error';
 * Replace field value with created data above.
 * Nested arrays and nested plain objects are transformed recursively.
 * @param params parameters from rpc client.
 * @param socketId The ID of websocket.
 */
export function transformPlainObjectParams(
  params: Record<string, unknown>,
  socketId: number
): void {
  for (const key in params) {
    const original = params[key];
    const transformed = transformParam(original, socketId);
    // Only write back when the value was actually replaced.
    if (transformed !== original) {
      params[key] = transformed;
    }
  }
}

/**
 * Array counterpart of `transformPlainObjectParams`: applies the same
 * transformation to every element of `params`, in place.
 * @param params array of parameters from rpc client.
 * @param socketId The ID of websocket.
 */
function transformArrayParams(params: unknown[], socketId: number): void {
  for (let i = 0; i < params.length; i++) {
    const original = params[i];
    const transformed = transformParam(original, socketId);
    if (transformed !== original) {
      params[i] = transformed;
    }
  }
}

export const { middleware } = trpc;

/**
 * When the rpc client wants to send a request to the rpc server,
 * it transforms the parameters first, mainly the functions.
 * The rpc client creates an id for each function and sends
 * the id to the rpc server.
 * When the rpc server receives the parameters from the rpc client,
 * it transforms them too, mainly the parameters which include `__type__: 'Function'`
 * and a mappingId:
 * it creates a new function and maps it to the mappingId from the rpc client.
 * When the rpc server wants to call a function in the rpc client, it sends the mappingId and data back to the rpc client.
 * The rpc client receives the mappingId and the data, gets the mapped function from its cache and calls it with the data.
 * If the function call is an event callback, the rpc client will not send the function result back to the rpc server;
 * otherwise, the rpc client will send the function result back to the rpc server.
 */
export const publicProcedure = trpc.procedure.use(
  trpc.middleware(({ rawInput, next, ctx, type, path }) => {
    const { socketId } = ctx.res as unknown as TResponse;
    logger.debug(`socketId: ${socketId}`);
    logger.debug(`trpc.middleware ${path}.${type} rawInput: ${JSON.stringify(rawInput)}`);
    if (socketId) {
      transformParams(rawInput as Record<string, unknown> | unknown[], socketId);
      ctx.socketId = socketId;
    }
    return next();
  })
);

export const { router } = trpc;