import {
    getSolanaErrorFromJsonRpcError,
    SOLANA_ERROR__INVARIANT_VIOLATION__DATA_PUBLISHER_CHANNEL_UNIMPLEMENTED,
    SOLANA_ERROR__RPC_SUBSCRIPTIONS__EXPECTED_SERVER_SUBSCRIPTION_ID,
    SolanaError,
} from '@solana/errors';
import { AbortController } from '@solana/event-target-impl';
import { safeRace } from '@solana/promises';
import { createRpcMessage, RpcRequest, RpcResponseData, RpcResponseTransformer } from '@solana/rpc-spec-types';
import { DataPublisher } from '@solana/subscribable';
import { demultiplexDataPublisher } from '@solana/subscribable';

import { RpcSubscriptionChannelEvents } from './rpc-subscriptions-channel';
import { RpcSubscriptionsChannel } from './rpc-subscriptions-channel';

type Config<TNotification> = Readonly<{
    channel: RpcSubscriptionsChannel<unknown, RpcNotification<TNotification> | RpcResponseData<RpcSubscriptionId>>;
    responseTransformer?: RpcResponseTransformer;
    signal: AbortSignal;
    subscribeRequest: RpcRequest;
    unsubscribeMethodName: string;
}>;

type RpcNotification<TNotification> = Readonly<{
    method: string;
    params: Readonly<{
        result: TNotification;
        subscription: number;
    }>;
}>;

type RpcSubscriptionId = number;

type RpcSubscriptionNotificationEvents<TNotification> = Omit<RpcSubscriptionChannelEvents<TNotification>, 'message'> & {
    notification: TNotification;
};

/**
 * For each channel, the number of local subscribers currently holding each server-assigned
 * subscription id. Keyed weakly by channel so that a channel's tallies do not keep it alive.
 */
const subscriberCountBySubscriptionIdByChannel = new WeakMap<WeakKey, Record<number, number>>();

/** Drops one hold on `subscriptionId`; returns the new count, or `undefined` when there is no id. */
function decrementSubscriberCountAndReturnNewCount(channel: WeakKey, subscriptionId?: number): number | undefined {
    return augmentSubscriberCountAndReturnNewCount(-1, channel, subscriptionId);
}

/** Adds one hold on `subscriptionId`; a no-op when there is no id. */
function incrementSubscriberCount(channel: WeakKey, subscriptionId?: number): void {
    augmentSubscriberCountAndReturnNewCount(1, channel, subscriptionId);
}

/** Returns the tally record for `channel`, creating an empty one on first use. */
function getSubscriberCountBySubscriptionIdForChannel(channel: WeakKey): Record<number, number> {
    let subscriberCountBySubscriptionId = subscriberCountBySubscriptionIdByChannel.get(channel);
    if (!subscriberCountBySubscriptionId) {
        subscriberCountBySubscriptionIdByChannel.set(channel, (subscriberCountBySubscriptionId = {}));
    }
    return subscriberCountBySubscriptionId;
}

/**
 * Adjusts the tally for `subscriptionId` on `channel` by `amount` and returns the new count.
 * Entries that drop to zero or below are removed. Returns `undefined` (and does nothing) when
 * `subscriptionId` is `undefined`.
 */
function augmentSubscriberCountAndReturnNewCount(
    amount: -1 | 1,
    channel: WeakKey,
    subscriptionId?: number,
): number | undefined {
    if (subscriptionId === undefined) {
        return;
    }
    const subscriberCountBySubscriptionId = getSubscriberCountBySubscriptionIdForChannel(channel);
    // A missing entry means zero holders. Defaulting here (rather than only when incrementing)
    // keeps a stray decrement from computing `-1 + undefined`, which would store `NaN`.
    const newCount = amount + (subscriberCountBySubscriptionId[subscriptionId] ?? 0);
    if (newCount <= 0) {
        delete subscriberCountBySubscriptionId[subscriptionId];
    } else {
        subscriberCountBySubscriptionId[subscriptionId] = newCount;
    }
    return newCount;
}

/**
 * channel -> (responseTransformer, or the channel itself when there is no transformer) ->
 * demultiplexed notification publisher.
 */
const cache = new WeakMap();

/**
 * Returns a publisher that re-emits every notification arriving on `channel` under the channel
 * name `notification:<subscription id>`, optionally passed through `responseTransformer`.
 * Non-notification messages (those without a `method`) are dropped. The publisher is memoized per
 * (channel, responseTransformer) pair so that concurrent subscriptions share one demultiplexer.
 */
function getMemoizedDemultiplexedNotificationPublisherFromChannelAndResponseTransformer<TNotification>(
    channel: RpcSubscriptionsChannel<unknown, RpcNotification<TNotification>>,
    subscribeRequest: RpcRequest,
    responseTransformer?: RpcResponseTransformer,
): DataPublisher<{
    [channelName: `notification:${number}`]: TNotification;
}> {
    let publisherByResponseTransformer = cache.get(channel);
    if (!publisherByResponseTransformer) {
        cache.set(channel, (publisherByResponseTransformer = new WeakMap()));
    }
    // WeakMap keys must be objects, so the channel stands in for "no transformer".
    const responseTransformerKey = responseTransformer ?? channel;
    let publisher = publisherByResponseTransformer.get(responseTransformerKey);
    if (!publisher) {
        publisherByResponseTransformer.set(
            responseTransformerKey,
            (publisher = demultiplexDataPublisher(channel, 'message', rawMessage => {
                const message = rawMessage as RpcNotification<unknown> | RpcResponseData<unknown>;
                // Only notifications carry `method`; plain responses (such as subscribe
                // acknowledgements) are not routed anywhere.
                if (!('method' in message)) {
                    return;
                }
                // NOTE(review): `subscribeRequest` is captured from whichever call first created this
                // memoized publisher. Later subscriptions sharing this channel and transformer but
                // made with a different request will still be transformed against that first
                // request. Confirm the transformers in use do not depend on the request.
                const transformedNotification = responseTransformer
                    ? responseTransformer(message.params.result, subscribeRequest)
                    : message.params.result;
                return [`notification:${message.params.subscription}`, transformedNotification];
            })),
        );
    }
    return publisher;
}

/**
 * Given a channel, executes the subscription plan required by the Solana JSON RPC
 * Subscriptions API:
 *
 * 1. Sends the `subscribeRequest` to the remote RPC.
 * 2. Waits for a response containing the server-assigned subscription id.
 * 3. Returns a {@link DataPublisher} that publishes the notifications related to that
 *    subscription id, filtering out all others.
 * 4. Sends the `unsubscribeMethodName` request to the remote RPC when the abort signal fires and
 *    this is the last local subscriber holding that subscription id.
 *
 * @param config - The channel to use, the subscribe request, the name of the matching unsubscribe
 * method, an optional transformer applied to each notification, and the signal that ends the
 * subscription.
 * @returns A publisher exposing `notification` and `error` channels; subscribing to any other
 * channel name throws a {@link SolanaError}.
 * @throws The abort signal's reason if the signal fires before the subscription is established;
 * the error built by `getSolanaErrorFromJsonRpcError` if the server answers the subscribe request
 * with an error; a {@link SolanaError} if the server answers without a subscription id; the
 * channel's error if the channel fails while awaiting the acknowledgement; or whatever
 * `channel.send` rejects with when sending the subscribe request.
 */
export async function executeRpcPubSubSubscriptionPlan<TNotification>({
    channel,
    responseTransformer,
    signal,
    subscribeRequest,
    unsubscribeMethodName,
}: Config<TNotification>): Promise<DataPublisher<RpcSubscriptionNotificationEvents<TNotification>>> {
    let subscriptionId: number | undefined;
    channel.on(
        'error',
        () => {
            // An error on the channel indicates that the subscriptions are dead.
            // There is no longer any sense hanging on to subscription ids.
            // Erasing it here will prevent the unsubscribe code from running.
            subscriptionId = undefined;
            subscriberCountBySubscriptionIdByChannel.delete(channel);
        },
        { signal },
    );

    /**
     * Drops this subscriber's hold on `subscriptionId` and, if it was the last holder, asks the
     * server to unsubscribe.
     *
     * Because of https://github.com/solana-labs/solana/pull/18943, two subscriptions for
     * materially the same notification will be coalesced on the server. This means they
     * will be assigned the same subscription id, and will occupy one subscription slot. We
     * must be careful not to send the unsubscribe message until the last subscriber aborts.
     */
    function releaseSubscription() {
        if (decrementSubscriberCountAndReturnNewCount(channel, subscriptionId) === 0) {
            const unsubscribePayload = createRpcMessage({
                methodName: unsubscribeMethodName,
                params: [subscriptionId],
            });
            subscriptionId = undefined;
            // Best effort: locally the subscription is over whether or not this send succeeds.
            channel.send(unsubscribePayload).catch(() => {});
        }
    }

    /**
     * STEP 1
     * Create a promise that rejects if this subscription is aborted and sends
     * the unsubscribe message if the subscription is active at that time.
     */
    const abortPromise = new Promise<never>((_, reject) => {
        function handleAbort(this: AbortSignal) {
            releaseSubscription();
            // eslint-disable-next-line @typescript-eslint/prefer-promise-reject-errors
            reject(this.reason);
        }
        if (signal.aborted) {
            handleAbort.call(signal);
        } else {
            signal.addEventListener('abort', handleAbort);
        }
    });
    // Nothing observes `abortPromise` until it is raced in STEP 3. Mark it handled now so that an
    // abort during (or a failure of) the send in STEP 2 does not surface as an unhandled
    // rejection. `safeRace` still observes the rejection through its own handlers.
    abortPromise.catch(() => {});

    /**
     * STEP 2
     * Send the subscription request. If we were aborted before starting, bail out first: the
     * acknowledgement listener below could never be torn down by an already-aborted signal, and
     * nothing would ever unsubscribe from the id the server assigned.
     */
    if (signal.aborted) {
        await abortPromise;
    }
    const subscribePayload = createRpcMessage(subscribeRequest);
    await channel.send(subscribePayload);

    /**
     * STEP 3
     * Wait for the acknowledgement from the server with the subscription id.
     *
     * NOTE(review): if the signal aborts after the subscribe request was sent but before the
     * acknowledgement arrives, these listeners are torn down, so no unsubscribe is ever sent for
     * the id the server assigns. Confirm whether that server-side slot needs reclaiming.
     */
    const subscriptionIdPromise = new Promise<RpcSubscriptionId>((resolve, reject) => {
        const abortController = new AbortController();
        signal.addEventListener('abort', abortController.abort.bind(abortController));
        const options = { signal: abortController.signal } as const;
        channel.on(
            'error',
            err => {
                abortController.abort();
                reject(err);
            },
            options,
        );
        channel.on(
            'message',
            message => {
                if (message && typeof message === 'object' && 'id' in message && message.id === subscribePayload.id) {
                    abortController.abort();
                    if ('error' in message) {
                        reject(getSolanaErrorFromJsonRpcError(message.error));
                    } else {
                        resolve(message.result);
                    }
                }
            },
            options,
        );
    });
    subscriptionId = await safeRace([abortPromise, subscriptionIdPromise]);
    if (subscriptionId == null) {
        throw new SolanaError(SOLANA_ERROR__RPC_SUBSCRIPTIONS__EXPECTED_SERVER_SUBSCRIPTION_ID);
    }
    incrementSubscriberCount(channel, subscriptionId);
    if (signal.aborted) {
        // The acknowledgement won the race, but the signal fired before this continuation ran.
        // The abort handler saw no subscription id then, so release the hold it could not.
        releaseSubscription();
        throw signal.reason;
    }

    /**
     * STEP 4
     * Filter out notifications unrelated to this subscription.
     */
    const notificationPublisher = getMemoizedDemultiplexedNotificationPublisherFromChannelAndResponseTransformer(
        channel,
        subscribeRequest,
        responseTransformer,
    );
    const notificationKey = `notification:${subscriptionId}` as const;
    return {
        on(type, listener, options) {
            switch (type) {
                case 'notification':
                    return notificationPublisher.on(
                        notificationKey,
                        listener as (data: RpcSubscriptionNotificationEvents<TNotification>['notification']) => void,
                        options,
                    );
                case 'error':
                    return channel.on(
                        'error',
                        listener as (data: RpcSubscriptionNotificationEvents<TNotification>['error']) => void,
                        options,
                    );
                default:
                    throw new SolanaError(SOLANA_ERROR__INVARIANT_VIOLATION__DATA_PUBLISHER_CHANNEL_UNIMPLEMENTED, {
                        channelName: type,
                        supportedChannelNames: ['notification', 'error'],
                    });
            }
        },
    };
}