/**
 * This component will typically wrap your entire application (or a sub-tree of your
 * application where you want to have an AI assistant). It provides the AI context to all
 * other components and hooks.
 *
 * ## Example
 *
 * You can find more information about self-hosting VN SDK [here](/guides/self-hosting).
 *
 * ```tsx
 * import { AiProvider } from "@vn-sdk/react-core";
 *
 * <AiProvider runtimeUrl="<your-runtime-url>">
 *   // ... your app ...
 * </AiProvider>
 * ```
 */
import { useCallback, useEffect, useMemo, useRef, useState, SetStateAction } from "react";
import {
  AiContext,
  AiApiConfig,
  ChatComponentsCache,
  AgentSession,
  AuthState,
} from "../../context/ai-context";
import useTree from "../../hooks/use-tree";
import { AiChatSuggestionConfiguration, DocumentPointer } from "../../types";
import { flushSync } from "react-dom";
import {
  AI_CLOUD_CHAT_URL,
  AiCloudConfig,
  FunctionCallHandler,
  AI_CLOUD_PUBLIC_API_KEY_HEADER,
  randomUUID,
  ConfigurationError,
  MissingPublicApiKeyError,
  AiSDKError,
  AiErrorEvent,
  AiErrorHandler,
} from "@vn-sdk/shared";
import { FrontendAction } from "../../types/frontend-action";
import useFlatCategoryStore from "../../hooks/use-flat-category-store";
import { AiProps } from "./ai-props";
import { AiAgentStateRender } from "../../types/ai-agent-action";
import { AiAgentState } from "../../types/ai-agent-state";
import { AiMessages, MessagesTapProvider } from "./ai-messages";
import { ToastProvider } from "../toast/toast-provider";
import { getErrorActions, UsageBanner } from "../usage-banner";
import { useAiRuntimeClient } from "../../hooks/use-ai-runtime-client";
import { shouldShowDevConsole } from "../../utils";
import { AiErrorBoundary } from "../error-boundary/error-boundary";
import { Agent, ExtensionsInput } from "@vn-sdk/runtime-client-gql";
import {
  LangGraphInterruptAction,
  LangGraphInterruptActionSetterArgs,
} from "../../types/interrupt-action";
import { ConsoleTrigger } from "../dev-console/console-trigger";

/**
 * Public entry point of the provider.
 *
 * Decides whether the dev console is enabled (via `shouldShowDevConsole`) and which public
 * key to use, then renders `children`.
 *
 * NOTE(review): the JSX element tree returned here is missing from this copy of the file.
 * Only the `{children}` expression survives, so which wrappers use `enabled` / `publicApiKey`
 * cannot be seen. Restore it from version control.
 */
export function AiProvider({ children, ...props }: AiProps) {
  const enabled = shouldShowDevConsole(props.showDevConsole);
  // Use API key if provided, otherwise use the license key.
  // NOTE(review): AiProviderInternal resolves these in the opposite order
  // (publicLicenseKey first). If both props are set to different values, the two
  // components disagree. Confirm which precedence is intended.
  const publicApiKey = props.publicApiKey || props.publicLicenseKey;
  return ( {children} );
}

/**
 * Owns the provider's state and derived configuration.
 *
 * Validates the props (throws `ConfigurationError` / `MissingPublicApiKeyError` via
 * `validateProps`), then holds state for:
 * - frontend actions and agent state renders,
 * - the context tree and documents,
 * - auth states, extensions and error handlers,
 * - chat suggestions, available agents and agent states,
 * - the thread/run ids and LangGraph interrupt actions.
 *
 * These values are presumably published through `AiContext`, which is imported but whose
 * JSX usage is not visible in this copy — TODO confirm.
 *
 * NOTE(review): in this copy of the file, generic type arguments (e.g. `useState>>({})`)
 * and JSX elements appear to have been stripped. They are left untouched here and should be
 * restored from version control.
 */
export function AiProviderInternal(cpkProps: AiProps) {
  const { children, ...props } = cpkProps;

  /**
   * This will throw an error if the props are invalid.
   */
  validateProps(cpkProps);

  // Use license key as API key if provided, otherwise use the API key
  const publicApiKey = props.publicLicenseKey || props.publicApiKey;
  const chatApiEndpoint = props.runtimeUrl || AI_CLOUD_CHAT_URL;

  const [actions, setActions] = useState>>({});
  const [aiAgentStateRenders, setAiAgentStateRenders] = useState<
    Record>
  >({});

  const chatComponentsCache = useRef({
    actions: {},
    aiAgentStateRenders: {},
  });

  const { addElement, removeElement, printTree, getAllElements } = useTree();
  const [isLoading, setIsLoading] = useState(false);
  const [chatInstructions, setChatInstructions] = useState("");
  const [authStates, setAuthStates] = useState>({});
  const [extensions, setExtensions] = useState({});
  const [additionalInstructions, setAdditionalInstructions] = useState([]);

  const {
    addElement: addDocument,
    removeElement: removeDocument,
    allElements: allDocuments,
  } = useFlatCategoryStore();

  // Compute all the functions and properties that we need to pass

  const setAction = useCallback((id: string, action: FrontendAction) => {
    setActions((prevPoints) => {
      return {
        ...prevPoints,
        [id]: action,
      };
    });
  }, []);

  const removeAction = useCallback((id: string) => {
    setActions((prevPoints) => {
      const newPoints = { ...prevPoints };
      delete newPoints[id];
      return newPoints;
    });
  }, []);

  const setAiAgentStateRender = useCallback((id: string, stateRender: AiAgentStateRender) => {
    setAiAgentStateRenders((prevPoints) => {
      return {
        ...prevPoints,
        [id]: stateRender,
      };
    });
  }, []);

  const removeAiAgentStateRender = useCallback((id: string) => {
    setAiAgentStateRenders((prevPoints) => {
      const newPoints = { ...prevPoints };
      delete newPoints[id];
      return newPoints;
    });
  }, []);

  // Renders each document as "name (sourceApplication):\n<contents>", joins them with
  // blank lines, and appends the printed context tree for the given categories.
  const getContextString = useCallback(
    (documents: DocumentPointer[], categories: string[]) => {
      const documentsString = documents
        .map((document) => {
          return `${document.name} (${document.sourceApplication}):\n${document.getContents()}`;
        })
        .join("\n\n");

      const nonDocumentStrings = printTree(categories);

      return `${documentsString}\n\n${nonDocumentStrings}`;
    },
    [printTree],
  );

  const addContext = useCallback(
    (
      context: string,
      parentId?: string,
      categories: string[] = defaultAiContextCategories,
    ) => {
      return addElement(context, categories, parentId);
    },
    [addElement],
  );

  const removeContext = useCallback(
    (id: string) => {
      removeElement(id);
    },
    [removeElement],
  );

  const getAllContext = useCallback(() => {
    return getAllElements();
  }, [getAllElements]);

  // Builds a FunctionCallHandler over `customEntryPoints` when given, otherwise over the
  // currently registered actions.
  const getFunctionCallHandler = useCallback(
    (customEntryPoints?: Record>) => {
      return entryPointsToFunctionCallHandler(Object.values(customEntryPoints || actions));
    },
    [actions],
  );

  const getDocumentsContext = useCallback(
    (categories: string[]) => {
      return allDocuments(categories);
    },
    [allDocuments],
  );

  const addDocumentContext = useCallback(
    (documentPointer: DocumentPointer, categories: string[] = defaultAiContextCategories) => {
      return addDocument(documentPointer, categories);
    },
    [addDocument],
  );

  const removeDocumentContext = useCallback(
    (documentId: string) => {
      removeDocument(documentId);
    },
    [removeDocument],
  );

  // Get the appropriate AiApiConfig from the props. The `cloud` section (topic guardrails)
  // is only attached when a public key is available.
  const aiApiConfig: AiApiConfig = useMemo(() => {
    let cloud: AiCloudConfig | undefined = undefined;
    if (publicApiKey) {
      cloud = {
        guardrails: {
          input: {
            restrictToTopic: {
              enabled: Boolean(props.guardrails_c),
              validTopics: props.guardrails_c?.validTopics || [],
              invalidTopics: props.guardrails_c?.invalidTopics || [],
            },
          },
        },
      };
    }

    return {
      publicApiKey: publicApiKey,
      ...(cloud ? { cloud } : {}),
      chatApiEndpoint: chatApiEndpoint,
      headers: props.headers || {},
      properties: props.properties || {},
      transcribeAudioUrl: props.transcribeAudioUrl,
      textToSpeechUrl: props.textToSpeechUrl,
      credentials: props.credentials,
    };
    // `chatApiEndpoint` must be a dependency: it is read above, and omitting it left a stale
    // endpoint in the config when `runtimeUrl` changed.
  }, [
    publicApiKey,
    chatApiEndpoint,
    props.headers,
    props.properties,
    props.transcribeAudioUrl,
    props.textToSpeechUrl,
    props.credentials,
    props.guardrails_c,
  ]);

  // Request headers are built from three sources:
  // - the configured headers;
  // - the public-key header, when a key is set;
  // - the auth headers of every authenticated auth state. Each auth-header name is
  //   prefixed with "X-Custom-" unless it already starts with it.
  const headers = useMemo(() => {
    const authHeaders = Object.values(authStates || {}).reduce((acc, state) => {
      if (state.status === "authenticated" && state.authHeaders) {
        return {
          ...acc,
          ...Object.entries(state.authHeaders).reduce(
            (prefixedHeaders, [key, value]) => ({
              ...prefixedHeaders,
              [key.startsWith("X-Custom-") ? key : `X-Custom-${key}`]: value,
            }),
            {},
          ),
        };
      }
      return acc;
    }, {});

    return {
      ...(aiApiConfig.headers || {}),
      ...(aiApiConfig.publicApiKey
        ? { [AI_CLOUD_PUBLIC_API_KEY_HEADER]: aiApiConfig.publicApiKey }
        : {}),
      ...authHeaders,
    };
  }, [aiApiConfig.headers, aiApiConfig.publicApiKey, authStates]);

  const [internalErrorHandlers, _setInternalErrorHandler] = useState<
    Record
  >({});

  // Merges the given keyed handlers into the registry (same key replaces the old handler).
  const setInternalErrorHandler = useCallback((handler: Record) => {
    _setInternalErrorHandler((prev: Record) => ({
      ...prev,
      ...handler,
    }));
  }, []);

  const removeInternalErrorHandler = useCallback((key: string) => {
    _setInternalErrorHandler((prev) => {
      const { [key]: _removed, ...rest } = prev;
      return rest;
    });
  }, []);

  // Keep latest values in refs so `handleErrors` can stay referentially stable.
  const onErrorRef = useRef(props.onError);
  useEffect(() => {
    onErrorRef.current = props.onError;
  }, [props.onError]);

  const internalHandlersRef = useRef>({});
  useEffect(() => {
    internalHandlersRef.current = internalErrorHandlers;
  }, [internalErrorHandlers]);

  // Fans an error event out to the public `onError` prop (only when a public key is
  // configured) and to every registered internal handler. A failing handler is logged and
  // never stops the others or rejects this function.
  const handleErrors = useCallback(
    async (error: AiErrorEvent) => {
      if (aiApiConfig.publicApiKey && onErrorRef.current) {
        try {
          await onErrorRef.current(error);
        } catch (e) {
          console.error("Error in public onError handler:", e);
        }
      }
      const handlers = Object.values(internalHandlersRef.current);
      await Promise.all(
        // An async wrapper also catches handlers that throw *synchronously*. The former
        // `Promise.resolve(h(error)).catch(...)` let such a throw escape `map`, which
        // skipped the remaining handlers.
        handlers.map(async (h) => {
          try {
            await h(error);
          } catch (e) {
            console.error("Error in internal error handler:", e);
          }
        }),
      );
    },
    [aiApiConfig.publicApiKey],
  );

  const runtimeClient = useAiRuntimeClient({
    url: aiApiConfig.chatApiEndpoint,
    publicApiKey: publicApiKey,
    headers,
    credentials: aiApiConfig.credentials,
    showDevConsole: shouldShowDevConsole(props.showDevConsole),
    onError: handleErrors,
  });

  const [chatSuggestionConfiguration, setChatSuggestionConfiguration] = useState<{
    [key: string]: AiChatSuggestionConfiguration;
  }>({});

  const addChatSuggestionConfiguration = useCallback(
    (id: string, suggestion: AiChatSuggestionConfiguration) => {
      setChatSuggestionConfiguration((prev) => ({ ...prev, [id]: suggestion }));
    },
    [setChatSuggestionConfiguration],
  );

  const removeChatSuggestionConfiguration = useCallback(
    (id: string) => {
      setChatSuggestionConfiguration((prev) => {
        const { [id]: _, ...rest } = prev;
        return rest;
      });
    },
    [setChatSuggestionConfiguration],
  );

  const [availableAgents, setAvailableAgents] = useState([]);
  const [aiAgentStates, setAiAgentStates] = useState>({});
  const aiAgentStatesRef = useRef>({});

  // Updates agent state and mirrors it into a ref, so the latest value is readable
  // synchronously (functional updates are resolved against the ref, not React state).
  const setAiAgentStatesWithRef = useCallback(
    (
      value:
        | Record
        | ((prev: Record) => Record),
    ) => {
      const newValue = typeof value === "function" ? value(aiAgentStatesRef.current) : value;
      aiAgentStatesRef.current = newValue;
      setAiAgentStates(newValue);
    },
    [],
  );

  // Fetch the list of available agents once per mounted provider.
  const hasLoadedAgents = useRef(false);
  useEffect(() => {
    if (hasLoadedAgents.current) return;

    const fetchData = async () => {
      try {
        const result = await runtimeClient.availableAgents();
        if (result.data?.availableAgents) {
          setAvailableAgents(result.data.availableAgents.agents);
        }
        hasLoadedAgents.current = true;
      } catch (e) {
        // Previously an unhandled rejection (the promise is discarded with `void`).
        // Leaving `hasLoadedAgents` false preserves the old "retry on next effect run".
        console.error("Failed to load available agents:", e);
      }
    };
    void fetchData();
  }, []);

  let initialAgentSession: AgentSession | null = null;
  if (props.agent) {
    initialAgentSession = {
      agentName: props.agent,
    };
  }

  const [agentSession, setAgentSession] = useState(initialAgentSession);

  // Update agentSession when props.agent changes
  useEffect(() => {
    if (props.agent) {
      setAgentSession({
        agentName: props.agent,
      });
    } else {
      setAgentSession(null);
    }
  }, [props.agent]);

  // Lazy initializer: generate a random UUID only on the first render instead of on
  // every render (the value was discarded after the first one anyway).
  const [internalThreadId, setInternalThreadId] = useState(() => props.threadId || randomUUID());

  // Throws when the thread id is controlled via `props.threadId`.
  const setThreadId = useCallback(
    (value: SetStateAction) => {
      if (props.threadId) {
        throw new Error("Cannot call setThreadId() when threadId is provided via props.");
      }
      setInternalThreadId(value);
    },
    [props.threadId],
  );

  // update the internal threadId if the props.threadId changes
  useEffect(() => {
    if (props.threadId !== undefined) {
      setInternalThreadId(props.threadId);
    }
  }, [props.threadId]);

  const [runId, setRunId] = useState(null);

  const chatAbortControllerRef = useRef(null);

  const showDevConsole = shouldShowDevConsole(props.showDevConsole);

  const [langGraphInterruptActions, _setLangGraphInterruptAction] = useState<
    Record
  >({});

  // Sets (or with a null/undefined `action`, clears) the interrupt action for a thread.
  // A new `event` is shallow-merged over the thread's previous event; other fields of
  // `action` are shallow-merged over the previous action.
  const setLangGraphInterruptAction = useCallback(
    (threadId: string, action: LangGraphInterruptActionSetterArgs) => {
      _setLangGraphInterruptAction((prev) => {
        if (action == null)
          return {
            ...prev,
            [threadId]: null,
          };
        let event = prev[threadId]?.event;
        if (action.event) {
          // @ts-ignore
          event = { ...(prev[threadId]?.event || {}), ...action.event };
        }
        return {
          ...prev,
          [threadId]: { ...(prev[threadId] ?? {}), ...action, event } as LangGraphInterruptAction,
        };
      });
    },
    [],
  );

  const removeLangGraphInterruptAction = useCallback(
    (threadId: string): void => {
      setLangGraphInterruptAction(threadId, null);
    },
    [setLangGraphInterruptAction],
  );

  const memoizedChildren = useMemo(() => children, [children]);
  const [bannerError, setBannerError] = useState(null);

  const agentLock = useMemo(() => props.agent ?? null, [props.agent]);

  const forwardedParameters = useMemo(
    () => props.forwardedParameters ?? {},
    [props.forwardedParameters],
  );

  // Applies an extensions update, keeping the previous object when the result is
  // shallow-equal (same keys, same values by ===) to avoid a needless re-render.
  const updateExtensions = useCallback(
    (newExtensions: SetStateAction) => {
      setExtensions((prev: ExtensionsInput) => {
        const resolved = typeof newExtensions === "function" ? newExtensions(prev) : newExtensions;
        const isSameLength = Object.keys(resolved).length === Object.keys(prev).length;
        const isEqual =
          isSameLength &&
          // @ts-ignore
          Object.entries(resolved).every(([key, value]) => prev[key] === value);

        return isEqual ? prev : resolved;
      });
    },
    [setExtensions],
  );

  // Same shallow-equality guard as `updateExtensions`, for auth states.
  const updateAuthStates = useCallback(
    (newAuthStates: SetStateAction>) => {
      setAuthStates((prev) => {
        const resolved = typeof newAuthStates === "function" ? newAuthStates(prev) : newAuthStates;
        const isSameLength = Object.keys(resolved).length === Object.keys(prev).length;
        const isEqual =
          isSameLength &&
          // @ts-ignore
          Object.entries(resolved).every(([key, value]) => prev[key] === value);

        return isEqual ? prev : resolved;
      });
    },
    [setAuthStates],
  );

  // NOTE(review): the JSX element tree returned here is missing from this copy of the file
  // (only embedded `{...}` expressions and attribute fragments survive). Restore it from
  // version control.
  return ( {memoizedChildren} {showDevConsole && } {bannerError && showDevConsole && ( setBannerError(null)} actions={getErrorActions(bannerError)} /> )} );
}

/** Category assigned to context and documents when the caller does not specify one. */
export const defaultAiContextCategories = ["global"];

/**
 * Builds a `FunctionCallHandler` that dispatches a call to the frontend action whose `name`
 * matches the requested function name.
 *
 * When an action matches:
 * - its `handler` (if any) is run with the call's `args` inside `flushSync`;
 * - the returned promise then waits a further 20 ms;
 * - it resolves with the handler's result.
 *
 * When no action matches, it resolves with `undefined` immediately. Errors thrown by the
 * handler are propagated to the caller.
 *
 * @param actions - Actions to dispatch to. If several share a name, the last one wins.
 */
function entryPointsToFunctionCallHandler(actions: FrontendAction[]): FunctionCallHandler {
  // Index once, when the handler is created, instead of on every call. A Map (rather than
  // a plain object) guarantees that names such as "toString" or "__proto__" never resolve
  // to Object.prototype members.
  const actionsByFunctionName = new Map(actions.map((entry) => [entry.name, entry] as const));

  return async ({ name, args }: { name: string; args: Record }) => {
    const action = actionsByFunctionName.get(name);

    let result: any = undefined;
    if (action) {
      await new Promise((resolve, reject) => {
        // flushSync ignores the promise its (async) callback returns. Completion and
        // failure are therefore forwarded through resolve/reject.
        flushSync(async () => {
          try {
            result = await action.handler?.(args);
            resolve();
          } catch (error) {
            reject(error);
          }
        });
      });
      // NOTE(review): fixed 20 ms pause after the handler settles. Presumably this gives
      // React time to commit follow-up renders before the result is returned — TODO confirm.
      await new Promise((resolve) => setTimeout(resolve, 20));
    }
    return result;
  };
}

/**
 * Turns a cloud-feature prop name into a human-readable label. The trailing `_c` is
 * dropped and each `_`-separated word is capitalised. For example, `guardrails_c` becomes
 * `Guardrails`.
 */
function formatFeatureName(featureName: string): string {
  return featureName
    .replace(/_c$/, "")
    .split("_")
    .map((word) => word.charAt(0).toUpperCase() + word.slice(1).toLowerCase())
    .join(" ");
}

/**
 * Validates the provider props, throwing on misconfiguration.
 *
 * @throws ConfigurationError if neither `runtimeUrl` nor a public key (`publicApiKey` /
 *   `publicLicenseKey`) is provided.
 * @throws MissingPublicApiKeyError if a cloud-feature prop (any prop whose name ends in
 *   `_c`) is set without a public key.
 */
function validateProps(props: AiProps): never | void {
  // Only count cloud-feature props that actually carry a value. A prop passed explicitly as
  // `undefined` (e.g. `guardrails_c={cond ? cfg : undefined}`) does not enable the feature
  // (the provider uses `Boolean(props.guardrails_c)`), so it must not demand a key either.
  const cloudFeatures = Object.keys(props).filter(
    (key) => key.endsWith("_c") && props[key as keyof AiProps] !== undefined,
  );

  // Check if we have either a runtimeUrl or one of the API keys
  const hasApiKey = props.publicApiKey || props.publicLicenseKey;

  if (!props.runtimeUrl && !hasApiKey) {
    throw new ConfigurationError(
      "Missing required prop: 'runtimeUrl' or 'publicApiKey' or 'publicLicenseKey'",
    );
  }

  if (cloudFeatures.length > 0 && !hasApiKey) {
    throw new MissingPublicApiKeyError(
      `Missing required prop: 'publicApiKey' or 'publicLicenseKey' to use cloud features: ${cloudFeatures
        .map(formatFeatureName)
        .join(", ")}`,
    );
  }
}