import type { QueryClient } from '@tanstack/react-query';
import type { TRPCClient } from '@trpc/client';
import type { AnyTRPCRouter } from '@trpc/server';
import * as React from 'react';
import type { TRPCOptionsProxy } from './createOptionsProxy';
import { createTRPCOptionsProxy } from './createOptionsProxy';
import type {
  DefaultFeatureFlags,
  FeatureFlags,
  KeyPrefixOptions,
} from './types';

/**
 * Component type of the provider returned by {@link createTRPCContext}.
 *
 * It receives the React Query client, the tRPC client and the key-prefix
 * options selected by `TFeatureFlags`, and renders `children`.
 */
type TRPCProviderType<
  TRouter extends AnyTRPCRouter,
  TFeatureFlags extends FeatureFlags = DefaultFeatureFlags,
> = React.FC<
  {
    children: React.ReactNode;
    queryClient: QueryClient;
    trpcClient: TRPCClient<TRouter>;
  } & KeyPrefixOptions<TFeatureFlags>
>;

/**
 * The provider component and consumer hooks produced by
 * {@link createTRPCContext}.
 */
export interface CreateTRPCContextResult<
  TRouter extends AnyTRPCRouter,
  TFeatureFlags extends FeatureFlags = DefaultFeatureFlags,
> {
  /** Makes the tRPC client and the options proxy available to descendants. */
  TRPCProvider: TRPCProviderType<TRouter, TFeatureFlags>;
  /** Returns the options proxy; throws when rendered outside `TRPCProvider`. */
  useTRPC: () => TRPCOptionsProxy<TRouter, TFeatureFlags>;
  /** Returns the tRPC client; throws when rendered outside `TRPCProvider`. */
  useTRPCClient: () => TRPCClient<TRouter>;
}

/**
 * Create a set of type-safe provider-consumers
 *
 * Each call creates two fresh React contexts (one for the tRPC client, one
 * for the options proxy), a provider component that fills both, and hooks
 * that read them.
 *
 * @see https://trpc.io/docs/client/tanstack-react-query/setup#3a-setup-the-trpc-context-provider
 */
export function createTRPCContext<
  TRouter extends AnyTRPCRouter,
  TFeatureFlags extends FeatureFlags = DefaultFeatureFlags,
>(): CreateTRPCContextResult<TRouter, TFeatureFlags> {
  // `null` is the "no provider mounted" sentinel checked by the hooks below.
  const TRPCClientContext = React.createContext<TRPCClient<TRouter> | null>(
    null,
  );
  const TRPCContext = React.createContext<TRPCOptionsProxy<
    TRouter,
    TFeatureFlags
  > | null>(null);

  const TRPCProvider: TRPCProviderType<TRouter, TFeatureFlags> = (props) => {
    // Rebuild the proxy only when one of its inputs changes, so consumers
    // get a stable reference across re-renders.
    const value = React.useMemo(
      () =>
        createTRPCOptionsProxy({
          client: props.trpcClient,
          queryClient: props.queryClient,
          // NOTE(review): `as any` bypasses checking of `keyPrefix` against
          // createTRPCOptionsProxy's option type; confirm whether a narrower
          // type can be used here.
          // eslint-disable-next-line @typescript-eslint/no-explicit-any
          keyPrefix: props.keyPrefix as any,
        }),
      [props.trpcClient, props.queryClient, props.keyPrefix],
    );

    // Outer context: the raw client. Inner context: the options proxy.
    // Built with createElement so this module does not depend on JSX
    // compilation settings.
    return React.createElement(
      TRPCClientContext.Provider,
      { value: props.trpcClient },
      React.createElement(TRPCContext.Provider, { value }, props.children),
    );
  };
  TRPCProvider.displayName = 'TRPCProvider';

  /**
   * Read the options proxy supplied by the nearest `TRPCProvider`.
   *
   * @throws Error when no `TRPCProvider` is mounted above the caller.
   */
  function useTRPC(): TRPCOptionsProxy<TRouter, TFeatureFlags> {
    const utils = React.useContext(TRPCContext);
    if (!utils) {
      throw new Error('useTRPC() can only be used inside of a <TRPCProvider>');
    }
    return utils;
  }

  /**
   * Read the tRPC client supplied by the nearest `TRPCProvider`.
   *
   * @throws Error when no `TRPCProvider` is mounted above the caller.
   */
  function useTRPCClient(): TRPCClient<TRouter> {
    const client = React.useContext(TRPCClientContext);
    if (!client) {
      throw new Error(
        'useTRPCClient() can only be used inside of a <TRPCProvider>',
      );
    }
    return client;
  }

  return { TRPCProvider, useTRPC, useTRPCClient };
}