import { COPILOT_CLOUD_PUBLIC_API_KEY_HEADER } from "@copilotkit/shared";
import { useCopilotContext } from "@copilotkit/react-core";
import { useCallback } from "react";
import { AutosuggestionsBareFunction } from "../../types";
import { retry } from "../../lib/retry";
import { InsertionEditorState } from "../../types/base/autosuggestions-bare-function";
import { SuggestionsApiConfig } from "../../types/autosuggestions-config/suggestions-api-config";
import {
  Message,
  Role,
  TextMessage,
  convertGqlOutputToMessages,
  convertMessagesToGqlInput,
  filterAgentStateMessages,
  CopilotRequestType,
} from "@copilotkit/runtime-client-gql";

/**
 * Returns a memoized function that produces an autosuggestion for the user's input.
 *
 * The returned function takes the editor state (the text before and after the cursor)
 * and an abort signal. Inside a `retry` wrapper it assembles the message list — the
 * system message, the configured few-shot messages, and user messages carrying the
 * text around the cursor — and then concatenates the content of every text message
 * found in the response.
 *
 * NOTE: the runtime request is currently disabled (see the commented-out call in the
 * body) and replaced by an empty placeholder response, so no request is sent and the
 * response contributes no messages to the result.
 *
 * @param textareaPurpose - The purpose of the textarea. Passed to `apiConfig.makeSystemPrompt`
 *   to build the system message.
 * @param contextCategories - The categories of context strings to include; forwarded to
 *   `getContextString` from the Copilot context.
 * @param apiConfig - Suggestions API configuration. Supplies `makeSystemPrompt`,
 *   `fewShotMessages`, and the `maxTokens` / `stop` / `temperature` parameters
 *   (`temperature` falls back to 0 when not set).
 * @returns A memoized async function that resolves to the suggestion text.
 */
export function useMakeStandardAutosuggestionFunction(
  textareaPurpose: string,
  contextCategories: string[],
  apiConfig: SuggestionsApiConfig,
): AutosuggestionsBareFunction {
  // Placeholder client: the real request below is commented out, so this stub is never called.
  const runtimeClient = { generateCopilotResponse: (...args: any[]) => {} };
  const { getContextString, copilotApiConfig } = useCopilotContext();

  // NOTE(review): `url`, `credentials`, `properties` and `headers` are unused while the
  // request is stubbed out — presumably needed again once the real runtime client is
  // restored; confirm before removing.
  const {
    chatApiEndpoint: url,
    publicApiKey,
    credentials,
    properties,
  } = copilotApiConfig;
  const headers = {
    ...copilotApiConfig.headers,
    ...(publicApiKey ? { [COPILOT_CLOUD_PUBLIC_API_KEY_HEADER]: publicApiKey } : {}),
  };

  // Only referenced by the disabled request below.
  const { maxTokens, stop, temperature = 0 } = apiConfig;

  return useCallback(
    async (editorState: InsertionEditorState, abortSignal: AbortSignal) => {
      return await retry(async () => {
        // The optional message is added through a conditional spread, so the list never
        // contains a `null` placeholder and needs no type-error suppression.
        // NOTE(review): when the text after the cursor is non-empty it is sent twice
        // (once here and once in the unconditional message below) — confirm this is intended.
        const messages: Message[] = [
          new TextMessage({
            role: Role.System,
            content: apiConfig.makeSystemPrompt(
              textareaPurpose,
              getContextString([], contextCategories),
            ),
          }),
          ...apiConfig.fewShotMessages,
          ...(editorState.textAfterCursor !== ""
            ? [
                new TextMessage({
                  role: Role.User,
                  content: editorState.textAfterCursor,
                }),
              ]
            : []),
          new TextMessage({
            role: Role.User,
            content: `${editorState.textAfterCursor}`,
          }),
          new TextMessage({
            role: Role.User,
            content: `${editorState.textBeforeCursor}`,
          }),
        ];

        // const response = await runtimeClient
        //   .generateCopilotResponse({
        //     data: {
        //       frontend: {
        //         actions: [],
        //         url: window.location.href,
        //       },
        //       messages: convertMessagesToGqlInput(filterAgentStateMessages(messages)),
        //       metadata: {
        //         requestType: CopilotRequestType.TextareaCompletion,
        //       },
        //       forwardedParameters: {
        //         maxTokens,
        //         stop,
        //         temperature,
        //       },
        //     },
        //     properties,
        //     signal: abortSignal,
        //   })
        //   .toPromise();

        // Placeholder standing in for the disabled request above.
        const response: any = {};

        let result = "";
        for (const message of convertGqlOutputToMessages(
          response.data?.generateCopilotResponse?.messages ?? [],
        )) {
          // Stop accumulating as soon as the caller aborts; whatever was collected so far is returned.
          if (abortSignal.aborted) {
            break;
          }
          if (message.isTextMessage()) {
            result += message.content;
          }
        }

        return result;
      });
    },
    [apiConfig, getContextString, contextCategories, textareaPurpose],
  );
}