import {
  Configuration,
  OpenAIApi,
  CreateEmbeddingRequest,
  ConfigurationParameters,
} from "openai";
import type { AxiosRequestConfig } from "axios";
import { getEnvironmentVariable, isNode } from "../util/env.js";
import { AzureOpenAIInput } from "../types/openai-types.js";
import fetchAdapter from "../util/axios-fetch-adapter.js";
import { chunkArray } from "../util/chunk.js";
import { Embeddings, EmbeddingsParams } from "./base.js";
import { getEndpoint, OpenAIEndpointConfig } from "../util/azure.js";

/**
 * Options accepted by {@link OpenAIEmbeddings} in addition to the base
 * embeddings parameters.
 */
export interface OpenAIEmbeddingsParams extends EmbeddingsParams {
  /** Model name to use */
  modelName: string;

  /**
   * Timeout to use when making requests to OpenAI.
   */
  timeout?: number;

  /**
   * The maximum number of documents to embed in a single request. This is
   * limited by the OpenAI API to a maximum of 2048.
   */
  batchSize?: number;

  /**
   * Whether to strip new lines from the input text. This is recommended by
   * OpenAI, but may not be suitable for all use cases.
   */
  stripNewLines?: boolean;
}

/**
 * Embeddings implementation backed by the OpenAI embeddings endpoint, with
 * optional Azure OpenAI support.
 *
 * Azure mode is selected whenever an Azure API key is available (from the
 * constructor fields or the `AZURE_OPENAI_API_KEY` environment variable); in
 * that mode every request carries the `api-key` header and the `api-version`
 * query parameter.
 */
export class OpenAIEmbeddings
  extends Embeddings
  implements OpenAIEmbeddingsParams, AzureOpenAIInput
{
  modelName = "text-embedding-ada-002";

  batchSize = 512;

  stripNewLines = true;

  timeout?: number;

  azureOpenAIApiVersion?: string;

  azureOpenAIApiKey?: string;

  azureOpenAIApiInstanceName?: string;

  azureOpenAIApiDeploymentName?: string;

  azureOpenAIBasePath?: string;

  /** Created lazily on the first request, see `embeddingWithRetry`. */
  private client: OpenAIApi;

  /** API key plus caller-supplied configuration, used to build `client`. */
  private clientConfig: ConfigurationParameters;

  /**
   * @param fields - Optional overrides for the instance defaults. Any value
   *   not supplied here falls back to the matching environment variable
   *   (`OPENAI_API_KEY`, `AZURE_OPENAI_API_KEY`,
   *   `AZURE_OPENAI_API_INSTANCE_NAME`,
   *   `AZURE_OPENAI_API_EMBEDDINGS_DEPLOYMENT_NAME` /
   *   `AZURE_OPENAI_API_DEPLOYMENT_NAME`, `AZURE_OPENAI_API_VERSION`,
   *   `AZURE_OPENAI_BASE_PATH`) or to the class default.
   * @param configuration - Extra OpenAI client configuration; its entries
   *   take precedence over the resolved `apiKey`.
   * @throws Error if neither an OpenAI nor an Azure OpenAI API key is found.
   * @throws Error if an Azure key is present but the instance name (or base
   *   path), the deployment name, or the API version is missing.
   */
  constructor(
    fields?: Partial<OpenAIEmbeddingsParams> &
      Partial<AzureOpenAIInput> & {
        verbose?: boolean;
        openAIApiKey?: string;
      },
    configuration?: ConfigurationParameters
  ) {
    super(fields ?? {});

    const apiKey =
      fields?.openAIApiKey ?? getEnvironmentVariable("OPENAI_API_KEY");

    const azureApiKey =
      fields?.azureOpenAIApiKey ??
      getEnvironmentVariable("AZURE_OPENAI_API_KEY");
    if (!azureApiKey && !apiKey) {
      throw new Error("OpenAI or Azure OpenAI API key not found");
    }

    const azureApiInstanceName =
      fields?.azureOpenAIApiInstanceName ??
      getEnvironmentVariable("AZURE_OPENAI_API_INSTANCE_NAME");

    // The embeddings-specific deployment name wins over the generic one, both
    // for explicit fields and for environment variables.
    const azureApiDeploymentName =
      (fields?.azureOpenAIApiEmbeddingsDeploymentName ||
        fields?.azureOpenAIApiDeploymentName) ??
      (getEnvironmentVariable("AZURE_OPENAI_API_EMBEDDINGS_DEPLOYMENT_NAME") ||
        getEnvironmentVariable("AZURE_OPENAI_API_DEPLOYMENT_NAME"));

    const azureApiVersion =
      fields?.azureOpenAIApiVersion ??
      getEnvironmentVariable("AZURE_OPENAI_API_VERSION");

    this.azureOpenAIBasePath =
      fields?.azureOpenAIBasePath ??
      getEnvironmentVariable("AZURE_OPENAI_BASE_PATH");

    this.modelName = fields?.modelName ?? this.modelName;
    // When an Azure key is in use the default batch size drops to 1; an
    // explicit `fields.batchSize` always takes precedence.
    this.batchSize = fields?.batchSize ?? (azureApiKey ? 1 : this.batchSize);
    this.stripNewLines = fields?.stripNewLines ?? this.stripNewLines;
    this.timeout = fields?.timeout;

    this.azureOpenAIApiVersion = azureApiVersion;
    this.azureOpenAIApiKey = azureApiKey;
    this.azureOpenAIApiInstanceName = azureApiInstanceName;
    this.azureOpenAIApiDeploymentName = azureApiDeploymentName;

    if (this.azureOpenAIApiKey) {
      if (!this.azureOpenAIApiInstanceName && !this.azureOpenAIBasePath) {
        throw new Error("Azure OpenAI API instance name not found");
      }
      if (!this.azureOpenAIApiDeploymentName) {
        throw new Error("Azure OpenAI API deployment name not found");
      }
      if (!this.azureOpenAIApiVersion) {
        throw new Error("Azure OpenAI API version not found");
      }
    }

    this.clientConfig = {
      apiKey,
      ...configuration,
    };
  }

  /**
   * Embeds a list of texts.
   *
   * The texts are split into chunks of at most `batchSize` entries and one
   * request is issued per chunk, one after another.
   *
   * @param texts - Texts to embed. Newlines are replaced with spaces first
   *   when `stripNewLines` is true.
   * @returns One embedding vector per input text, appended batch by batch.
   */
  async embedDocuments(texts: string[]): Promise<number[][]> {
    const subPrompts = chunkArray(
      this.stripNewLines ? texts.map((t) => t.replace(/\n/g, " ")) : texts,
      this.batchSize
    );

    const embeddings: number[][] = [];

    for (let i = 0; i < subPrompts.length; i += 1) {
      const input = subPrompts[i];
      const { data } = await this.embeddingWithRetry({
        model: this.modelName,
        input,
      });
      // Read exactly one result per submitted text, by position.
      for (let j = 0; j < input.length; j += 1) {
        embeddings.push(data.data[j].embedding);
      }
    }

    return embeddings;
  }

  /**
   * Embeds a single text.
   *
   * @param text - Text to embed. Newlines are replaced with spaces first when
   *   `stripNewLines` is true.
   * @returns The embedding vector of the first result in the response.
   */
  async embedQuery(text: string): Promise<number[]> {
    const { data } = await this.embeddingWithRetry({
      model: this.modelName,
      input: this.stripNewLines ? text.replace(/\n/g, " ") : text,
    });
    return data.data[0].embedding;
  }

  /**
   * Sends one embedding request through `this.caller`.
   *
   * The API client is built on first use and then reused for later calls.
   * NOTE(review): retry behaviour is presumably provided by the inherited
   * `caller`; confirm in the base class.
   *
   * @param request - The embedding request to forward to the API client.
   * @returns Whatever `createEmbedding` resolves to, via `this.caller.call`.
   */
  private async embeddingWithRetry(request: CreateEmbeddingRequest) {
    if (!this.client) {
      const openAIEndpointConfig: OpenAIEndpointConfig = {
        azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName,
        azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName,
        azureOpenAIApiKey: this.azureOpenAIApiKey,
        azureOpenAIBasePath: this.azureOpenAIBasePath,
        basePath: this.clientConfig.basePath,
      };

      const endpoint = getEndpoint(openAIEndpointConfig);

      const clientConfig = new Configuration({
        ...this.clientConfig,
        basePath: endpoint,
        baseOptions: {
          timeout: this.timeout,
          // Outside Node the custom fetch adapter is used; in Node the HTTP
          // client's default adapter is left in place.
          adapter: isNode() ? undefined : fetchAdapter,
          // Caller-supplied base options override the two defaults above.
          ...this.clientConfig.baseOptions,
        },
      });

      this.client = new OpenAIApi(clientConfig);
    }

    const axiosOptions: AxiosRequestConfig = {};
    if (this.azureOpenAIApiKey) {
      // Azure mode: send the key as a header and the API version as a query
      // parameter on every request.
      axiosOptions.headers = {
        "api-key": this.azureOpenAIApiKey,
      };
      axiosOptions.params = {
        "api-version": this.azureOpenAIApiVersion,
      };
    }

    return this.caller.call(
      this.client.createEmbedding.bind(this.client),
      request,
      axiosOptions
    );
  }
}