/**
* This component typically wraps your entire application (or the sub-tree of your application where you want an AI assistant) and provides the AI context to all other components and hooks. It requires either a `runtimeUrl` or a `publicApiKey`/`publicLicenseKey` prop.
*
* ## Example
*
* You can find more information about self-hosting VN SDK [here](/guides/self-hosting).
*
* ```tsx
* import { AiProvider } from "@vn-sdk/react-core";
*
* <AiProvider runtimeUrl="<your-runtime-url>">
*   <YourApp />
* </AiProvider>
* ```
*/
import { useCallback, useEffect, useMemo, useRef, useState, SetStateAction } from "react";
import {
AiContext,
AiApiConfig,
ChatComponentsCache,
AgentSession,
AuthState,
} from "../../context/ai-context";
import useTree from "../../hooks/use-tree";
import { AiChatSuggestionConfiguration, DocumentPointer } from "../../types";
import { flushSync } from "react-dom";
import {
AI_CLOUD_CHAT_URL,
AiCloudConfig,
FunctionCallHandler,
AI_CLOUD_PUBLIC_API_KEY_HEADER,
randomUUID,
ConfigurationError,
MissingPublicApiKeyError,
AiSDKError,
AiErrorEvent,
AiErrorHandler,
} from "@vn-sdk/shared";
import { FrontendAction } from "../../types/frontend-action";
import useFlatCategoryStore from "../../hooks/use-flat-category-store";
import { AiProps } from "./ai-props";
import { AiAgentStateRender } from "../../types/ai-agent-action";
import { AiAgentState } from "../../types/ai-agent-state";
import { AiMessages, MessagesTapProvider } from "./ai-messages";
import { ToastProvider } from "../toast/toast-provider";
import { getErrorActions, UsageBanner } from "../usage-banner";
import { useAiRuntimeClient } from "../../hooks/use-ai-runtime-client";
import { shouldShowDevConsole } from "../../utils";
import { AiErrorBoundary } from "../error-boundary/error-boundary";
import { Agent, ExtensionsInput } from "@vn-sdk/runtime-client-gql";
import {
LangGraphInterruptAction,
LangGraphInterruptActionSetterArgs,
} from "../../types/interrupt-action";
import { ConsoleTrigger } from "../dev-console/console-trigger";
/**
 * Public provider component.
 *
 * The visible body computes `enabled` from `props.showDevConsole` (via
 * `shouldShowDevConsole`) and a `publicApiKey` that prefers
 * `props.publicApiKey` and falls back to `props.publicLicenseKey`, then returns
 * markup that includes `children`.
 *
 * NOTE(review): the JSX returned below is not fully visible in this view, so how
 * `enabled` and `publicApiKey` are consumed cannot be confirmed from this chunk.
 */
export function AiProvider({ children, ...props }: AiProps) {
const enabled = shouldShowDevConsole(props.showDevConsole);
// Use API key if provided, otherwise use the license key
// NOTE(review): AiProviderInternal resolves the key in the opposite order
// (`publicLicenseKey || publicApiKey`). When both props are set, the two
// components pick different keys. Confirm the intended precedence.
const publicApiKey = props.publicApiKey || props.publicLicenseKey;
return (
{children}
);
}
/**
 * Internal provider implementation.
 *
 * Validates the props, then owns the provider's React state:
 * - registered frontend actions and agent-state renders
 * - the context tree (`useTree`) and the document store (`useFlatCategoryStore`)
 * - chat suggestion configurations, auth states and extensions
 * - the agent session, thread/run ids and LangGraph interrupt actions
 * - internal error handlers
 * It also creates the runtime client through `useAiRuntimeClient`.
 *
 * @param cpkProps - provider props; `children` is rendered inside the returned tree.
 * @throws ConfigurationError when neither `runtimeUrl` nor a public API/license
 *   key is provided.
 * @throws MissingPublicApiKeyError when `*_c` cloud props are passed without a
 *   key. Both are thrown by `validateProps` during render.
 *
 * NOTE(review): the JSX returned at the end of this function is not fully
 * visible in this view, so the context value it provides cannot be confirmed
 * from this chunk.
 */
export function AiProviderInternal(cpkProps: AiProps) {
const { children, ...props } = cpkProps;
/**
* Throws (ConfigurationError / MissingPublicApiKeyError) if the props are invalid.
* Runs on every render, before any hook is called.
*/
validateProps(cpkProps);
// Use license key as API key if provided, otherwise use the API key
// NOTE(review): AiProvider uses the opposite precedence (publicApiKey first).
// Confirm which one is intended.
const publicApiKey = props.publicLicenseKey || props.publicApiKey;
// Falls back to the hosted cloud chat URL when no runtimeUrl is provided.
const chatApiEndpoint = props.runtimeUrl || AI_CLOUD_CHAT_URL;
const [actions, setActions] = useState>>({});
const [aiAgentStateRenders, setAiAgentStateRenders] = useState<
Record>
>({});
// Mutable cache kept across renders. It is never reassigned in this block.
const chatComponentsCache = useRef({
actions: {},
aiAgentStateRenders: {},
});
const { addElement, removeElement, printTree, getAllElements } = useTree();
const [isLoading, setIsLoading] = useState(false);
const [chatInstructions, setChatInstructions] = useState("");
const [authStates, setAuthStates] = useState>({});
const [extensions, setExtensions] = useState({});
const [additionalInstructions, setAdditionalInstructions] = useState([]);
const {
addElement: addDocument,
removeElement: removeDocument,
allElements: allDocuments,
} = useFlatCategoryStore();
// Compute all the functions and properties that we need to pass
// Registers (or replaces) the frontend action stored under `id`.
const setAction = useCallback((id: string, action: FrontendAction) => {
setActions((prevPoints) => {
return {
...prevPoints,
[id]: action,
};
});
}, []);
// Removes the frontend action stored under `id` (no-op if absent).
const removeAction = useCallback((id: string) => {
setActions((prevPoints) => {
const newPoints = { ...prevPoints };
delete newPoints[id];
return newPoints;
});
}, []);
// Registers (or replaces) the agent-state render stored under `id`.
const setAiAgentStateRender = useCallback((id: string, stateRender: AiAgentStateRender) => {
setAiAgentStateRenders((prevPoints) => {
return {
...prevPoints,
[id]: stateRender,
};
});
}, []);
// Removes the agent-state render stored under `id` (no-op if absent).
const removeAiAgentStateRender = useCallback((id: string) => {
setAiAgentStateRenders((prevPoints) => {
const newPoints = { ...prevPoints };
delete newPoints[id];
return newPoints;
});
}, []);
// Builds one context string from two parts, joined by a blank line:
// - each document rendered as "name (sourceApplication):\n<contents>",
//   separated from the next document by a blank line;
// - the printed context tree for `categories`.
const getContextString = useCallback(
(documents: DocumentPointer[], categories: string[]) => {
const documentsString = documents
.map((document) => {
return `${document.name} (${document.sourceApplication}):\n${document.getContents()}`;
})
.join("\n\n");
const nonDocumentStrings = printTree(categories);
return `${documentsString}\n\n${nonDocumentStrings}`;
},
[printTree],
);
// Adds a context entry to the tree and returns whatever `addElement` returns.
// Categories default to `defaultAiContextCategories`.
const addContext = useCallback(
(
context: string,
parentId?: string,
categories: string[] = defaultAiContextCategories,
) => {
return addElement(context, categories, parentId);
},
[addElement],
);
const removeContext = useCallback(
(id: string) => {
removeElement(id);
},
[removeElement],
);
const getAllContext = useCallback(() => {
return getAllElements();
}, [getAllElements]);
// Returns a dispatcher over `customEntryPoints` when given. Otherwise it
// dispatches over the currently registered actions.
const getFunctionCallHandler = useCallback(
(customEntryPoints?: Record>) => {
return entryPointsToFunctionCallHandler(Object.values(customEntryPoints || actions));
},
[actions],
);
const getDocumentsContext = useCallback(
(categories: string[]) => {
return allDocuments(categories);
},
[allDocuments],
);
const addDocumentContext = useCallback(
(documentPointer: DocumentPointer, categories: string[] = defaultAiContextCategories) => {
return addDocument(documentPointer, categories);
},
[addDocument],
);
const removeDocumentContext = useCallback(
(documentId: string) => {
removeDocument(documentId);
},
[removeDocument],
);
// get the appropriate AiApiConfig from the props
// A `cloud` section (topic guardrails) is only included when a public key is set.
// NOTE(review): the dependency list omits `chatApiEndpoint` (derived from
// props.runtimeUrl), so changing runtimeUrl after mount leaves
// aiApiConfig.chatApiEndpoint stale. `props.cloudRestrictToTopic` is listed
// but never read by the factory.
const aiApiConfig: AiApiConfig = useMemo(() => {
let cloud: AiCloudConfig | undefined = undefined;
if (publicApiKey) {
cloud = {
guardrails: {
input: {
restrictToTopic: {
enabled: Boolean(props.guardrails_c),
validTopics: props.guardrails_c?.validTopics || [],
invalidTopics: props.guardrails_c?.invalidTopics || [],
},
},
},
};
}
return {
publicApiKey: publicApiKey,
...(cloud ? { cloud } : {}),
chatApiEndpoint: chatApiEndpoint,
headers: props.headers || {},
properties: props.properties || {},
transcribeAudioUrl: props.transcribeAudioUrl,
textToSpeechUrl: props.textToSpeechUrl,
credentials: props.credentials,
};
}, [
publicApiKey,
props.headers,
props.properties,
props.transcribeAudioUrl,
props.textToSpeechUrl,
props.credentials,
props.cloudRestrictToTopic,
props.guardrails_c,
]);
// Request headers, merged in this order (later entries override earlier ones):
// 1. user-supplied headers;
// 2. the public API key header, only when a key is set;
// 3. auth headers from every "authenticated" auth state, with each key
//    prefixed by "X-Custom-" unless it already starts with it.
const headers = useMemo(() => {
const authHeaders = Object.values(authStates || {}).reduce((acc, state) => {
if (state.status === "authenticated" && state.authHeaders) {
return {
...acc,
...Object.entries(state.authHeaders).reduce(
(headers, [key, value]) => ({
...headers,
[key.startsWith("X-Custom-") ? key : `X-Custom-${key}`]: value,
}),
{},
),
};
}
return acc;
}, {});
return {
...(aiApiConfig.headers || {}),
...(aiApiConfig.publicApiKey
? { [AI_CLOUD_PUBLIC_API_KEY_HEADER]: aiApiConfig.publicApiKey }
: {}),
...authHeaders,
};
}, [aiApiConfig.headers, aiApiConfig.publicApiKey, authStates]);
// Internal error handlers, keyed so individual handlers can be removed.
const [internalErrorHandlers, _setInternalErrorHandler] = useState<
Record
>({});
// Merges the given handlers into the registry. Existing keys are overwritten.
const setInternalErrorHandler = useCallback((handler: Record) => {
_setInternalErrorHandler((prev: Record) => ({
...prev,
...handler,
}));
}, []);
const removeInternalErrorHandler = useCallback((key: string) => {
_setInternalErrorHandler((prev) => {
const { [key]: _removed, ...rest } = prev;
return rest;
});
}, []);
// Keep latest values in refs
// so handleErrors can read them without listing them as dependencies.
const onErrorRef = useRef(props.onError);
useEffect(() => {
onErrorRef.current = props.onError;
}, [props.onError]);
const internalHandlersRef = useRef>({});
useEffect(() => {
internalHandlersRef.current = internalErrorHandlers;
}, [internalErrorHandlers]);
// Error fan-out:
// - the public `onError` prop is called only when a public API key is set;
// - every internal handler then runs concurrently via Promise.all.
// Exceptions from any handler are logged with console.error, not rethrown.
const handleErrors = useCallback(
async (error: AiErrorEvent) => {
if (aiApiConfig.publicApiKey && onErrorRef.current) {
try {
await onErrorRef.current(error);
} catch (e) {
console.error("Error in public onError handler:", e);
}
}
const handlers = Object.values(internalHandlersRef.current);
await Promise.all(
handlers.map((h) =>
Promise.resolve(h(error)).catch((e) =>
console.error("Error in internal error handler:", e),
),
),
);
},
[aiApiConfig.publicApiKey],
);
// NOTE(review): `shouldShowDevConsole(props.showDevConsole)` is computed again
// further down as `showDevConsole`. One shared value would do.
const runtimeClient = useAiRuntimeClient({
url: aiApiConfig.chatApiEndpoint,
publicApiKey: publicApiKey,
headers,
credentials: aiApiConfig.credentials,
showDevConsole: shouldShowDevConsole(props.showDevConsole),
onError: handleErrors,
});
const [chatSuggestionConfiguration, setChatSuggestionConfiguration] = useState<{
[key: string]: AiChatSuggestionConfiguration;
}>({});
const addChatSuggestionConfiguration = useCallback(
(id: string, suggestion: AiChatSuggestionConfiguration) => {
setChatSuggestionConfiguration((prev) => ({ ...prev, [id]: suggestion }));
},
[setChatSuggestionConfiguration],
);
const removeChatSuggestionConfiguration = useCallback(
(id: string) => {
setChatSuggestionConfiguration((prev) => {
const { [id]: _, ...rest } = prev;
return rest;
});
},
[setChatSuggestionConfiguration],
);
const [availableAgents, setAvailableAgents] = useState([]);
const [aiAgentStates, setAiAgentStates] = useState>({});
const aiAgentStatesRef = useRef>({});
// Updates the agent states and the mirror ref together. Functional updates are
// resolved against the ref rather than React's queued state, so the ref is
// current immediately after the call. The `prev` argument below is unused.
const setAiAgentStatesWithRef = useCallback(
(
value:
| Record
| ((prev: Record) => Record),
) => {
const newValue = typeof value === "function" ? value(aiAgentStatesRef.current) : value;
aiAgentStatesRef.current = newValue;
setAiAgentStates((prev) => {
return newValue;
});
},
[],
);
// Fetches the available agents once, on mount (empty dependency list).
// NOTE(review):
// - There is no catch, so a rejection from availableAgents() goes unhandled
//   (`void fetchData()`).
// - hasLoadedAgents is set only after a successful await, so an effect re-run
//   before then would fetch again.
// - runtimeClient changes after mount are ignored.
const hasLoadedAgents = useRef(false);
useEffect(() => {
if (hasLoadedAgents.current) return;
const fetchData = async () => {
const result = await runtimeClient.availableAgents();
if (result.data?.availableAgents) {
setAvailableAgents(result.data.availableAgents.agents);
}
hasLoadedAgents.current = true;
};
void fetchData();
}, []);
// Initial agent session derived from props.agent. The effect below keeps it
// in sync when the prop changes.
let initialAgentSession: AgentSession | null = null;
if (props.agent) {
initialAgentSession = {
agentName: props.agent,
};
}
const [agentSession, setAgentSession] = useState(initialAgentSession);
// Update agentSession when props.agent changes
useEffect(() => {
if (props.agent) {
setAgentSession({
agentName: props.agent,
});
} else {
setAgentSession(null);
}
}, [props.agent]);
// NOTE(review): `randomUUID()` is evaluated on every render even though
// useState only uses the first value. A lazy initializer would avoid that:
// `useState(() => props.threadId || randomUUID())`.
const [internalThreadId, setInternalThreadId] = useState(props.threadId || randomUUID());
// Setting the thread id is only allowed when it is not controlled via props.
const setThreadId = useCallback(
(value: SetStateAction) => {
if (props.threadId) {
throw new Error("Cannot call setThreadId() when threadId is provided via props.");
}
setInternalThreadId(value);
},
[props.threadId],
);
// update the internal threadId if the props.threadId changes
// (if props.threadId becomes undefined, the last value is kept)
useEffect(() => {
if (props.threadId !== undefined) {
setInternalThreadId(props.threadId);
}
}, [props.threadId]);
const [runId, setRunId] = useState(null);
const chatAbortControllerRef = useRef(null);
const showDevConsole = shouldShowDevConsole(props.showDevConsole);
const [langGraphInterruptActions, _setLangGraphInterruptAction] = useState<
Record
>({});
// Per-thread interrupt action update:
// - a null/undefined `action` sets the thread's entry to null (the key is kept);
// - otherwise `action` is shallow-merged over the previous entry, and its
//   `event` is shallow-merged over the previous event.
const setLangGraphInterruptAction = useCallback(
(threadId: string, action: LangGraphInterruptActionSetterArgs) => {
_setLangGraphInterruptAction((prev) => {
if (action == null)
return {
...prev,
[threadId]: null,
};
let event = prev[threadId]?.event;
if (action.event) {
// @ts-ignore
event = { ...(prev[threadId]?.event || {}), ...action.event };
}
return {
...prev,
[threadId]: { ...(prev[threadId] ?? {}), ...action, event } as LangGraphInterruptAction,
};
});
},
[],
);
// The empty dependency list relies on setLangGraphInterruptAction being stable
// (it is itself a useCallback with no dependencies).
const removeLangGraphInterruptAction = useCallback((threadId: string): void => {
setLangGraphInterruptAction(threadId, null);
}, []);
// NOTE(review): this memo returns `children` unchanged, so it has no effect
// beyond using `children` directly.
const memoizedChildren = useMemo(() => children, [children]);
const [bannerError, setBannerError] = useState(null);
const agentLock = useMemo(() => props.agent ?? null, [props.agent]);
const forwardedParameters = useMemo(
() => props.forwardedParameters ?? {},
[props.forwardedParameters],
);
// Applies an extensions update. The previous object is kept when the new one
// has the same keys with strictly-equal values, so the state reference stays
// stable.
const updateExtensions = useCallback(
(newExtensions: SetStateAction) => {
setExtensions((prev: ExtensionsInput) => {
const resolved = typeof newExtensions === "function" ? newExtensions(prev) : newExtensions;
const isSameLength = Object.keys(resolved).length === Object.keys(prev).length;
const isEqual =
isSameLength &&
// @ts-ignore
Object.entries(resolved).every(([key, value]) => prev[key] === value);
return isEqual ? prev : resolved;
});
},
[setExtensions],
);
// Same shallow-equality guard as updateExtensions, applied to auth states.
const updateAuthStates = useCallback(
(newAuthStates: SetStateAction>) => {
setAuthStates((prev) => {
const resolved = typeof newAuthStates === "function" ? newAuthStates(prev) : newAuthStates;
const isSameLength = Object.keys(resolved).length === Object.keys(prev).length;
const isEqual =
isSameLength &&
// @ts-ignore
Object.entries(resolved).every(([key, value]) => prev[key] === value);
return isEqual ? prev : resolved;
});
},
[setAuthStates],
);
// Renders the children and, when the dev console is enabled, an extra console
// element. When there is also a bannerError, it renders an element whose close
// handler calls setBannerError(null) and whose actions come from
// getErrorActions(bannerError).
// NOTE(review): the element tags in this JSX are not visible in this view.
return (
{memoizedChildren}
{showDevConsole && }
{bannerError && showDevConsole && (
setBannerError(null)}
actions={getErrorActions(bannerError)}
/>
)}
);
}
/**
 * Categories used when a caller registers context or a document without
 * passing its own category list.
 */
export const defaultAiContextCategories: string[] = ["global"];
/**
 * Builds a function-call handler that dispatches each call to the frontend
 * action whose `name` matches the requested function name.
 *
 * The name → action lookup is built once, when the handler is created, instead
 * of on every invocation. A `Map` is used so that names such as
 * `"constructor"` or `"__proto__"` cannot collide with `Object.prototype`
 * members. When several actions share a name, the last one wins.
 *
 * @param actions - the frontend actions available for dispatch.
 * @returns an async handler. It resolves to the matched action's handler
 *   result, or `undefined` when no action matches or the matched action has no
 *   handler. An error thrown by the action handler rejects the returned promise.
 */
function entryPointsToFunctionCallHandler(actions: FrontendAction[]): FunctionCallHandler {
  const actionsByFunctionName = new Map<string, FrontendAction>();
  for (const action of actions) {
    actionsByFunctionName.set(action.name, action);
  }

  return async ({ name, args }: { name: string; args: Record<string, any> }) => {
    const action = actionsByFunctionName.get(name);
    if (!action) {
      return undefined;
    }

    let result: any = undefined;
    await new Promise<void>((resolve, reject) => {
      // flushSync asks React to flush updates made synchronously inside the
      // callback. Updates made after the handler's first `await` are not
      // covered by it.
      flushSync(async () => {
        try {
          result = await action.handler?.(args);
          resolve();
        } catch (error) {
          reject(error);
        }
      });
    });
    // NOTE(review): fixed short delay kept from the original, presumably to let
    // follow-up React updates commit before the result is consumed. Confirm.
    await new Promise((resolve) => setTimeout(resolve, 20));
    return result;
  };
}
/**
 * Converts a cloud-feature prop key (for example `guardrails_c`) into a
 * human-readable title for error messages. It removes a trailing `_c`, splits
 * on underscores, and upper-cases the first letter of each word while
 * lower-casing the rest.
 */
function formatFeatureName(featureName: string): string {
  const withoutSuffix = featureName.replace(/_c$/, "");
  const titled: string[] = [];
  for (const word of withoutSuffix.split("_")) {
    const head = word.charAt(0).toUpperCase();
    const tail = word.slice(1).toLowerCase();
    titled.push(head + tail);
  }
  return titled.join(" ");
}
/**
 * Validates provider props before any provider state is created.
 *
 * @param props - the props passed to the provider.
 * @throws ConfigurationError when neither `runtimeUrl` nor a public
 *   API/license key is set.
 * @throws MissingPublicApiKeyError when a cloud-only prop (a key ending in
 *   `_c`) is given a value but no public API/license key is set. Cloud props
 *   whose value is `undefined` or `null` are treated as not set.
 */
function validateProps(props: AiProps): void {
  const hasApiKey = Boolean(props.publicApiKey || props.publicLicenseKey);

  // A runtime URL or one of the keys is required to reach any backend.
  if (!props.runtimeUrl && !hasApiKey) {
    throw new ConfigurationError(
      "Missing required prop: 'runtimeUrl' or 'publicApiKey' or 'publicLicenseKey'",
    );
  }

  // Only cloud props that actually carry a value count as requested features,
  // so e.g. `guardrails_c={undefined}` does not demand an API key.
  const cloudFeatures = Object.entries(props)
    .filter(([key, value]) => key.endsWith("_c") && value != null)
    .map(([key]) => key);

  if (cloudFeatures.length > 0 && !hasApiKey) {
    throw new MissingPublicApiKeyError(
      `Missing required prop: 'publicApiKey' or 'publicLicenseKey' to use cloud features: ${cloudFeatures
        .map(formatFeatureName)
        .join(", ")}`,
    );
  }
}