import { z } from "zod"

import { reasoningEffortsSchema, modelInfoSchema } from "./model"
import { codebaseIndexProviderSchema } from "./codebase-index"

/**
 * ProviderName
 */

/**
 * Identifiers of every supported API provider.
 *
 * Declared `as const` so each string keeps its literal type, which lets
 * {@link providerNamesSchema} and {@link ProviderName} be derived from this
 * single list instead of being maintained separately.
 */
export const providerNames = [
	"anthropic",
	"claude-code",
	"glama",
	"openrouter",
	"bedrock",
	"vertex",
	"openai",
	"ollama",
	"vscode-lm",
	"lmstudio",
	"gemini",
	"gemini-cli",
	"openai-native",
	"mistral",
	"moonshot",
	"deepseek",
	"unbound",
	"requesty",
	"human-relay",
	"fake-ai",
	"xai",
	"groq",
	"chutes",
	"litellm",
	"huggingface",
	"kilocode",
	"fireworks",
	"cerebras",
] as const

/** Zod enum that accepts exactly the strings listed in {@link providerNames}. */
export const providerNamesSchema = z.enum(providerNames)

/** Union of all supported provider identifier strings. */
export type ProviderName = z.infer<typeof providerNamesSchema>

/**
 * ProviderSettingsEntry
 */

/**
 * Lightweight descriptor of a provider settings profile: an `id`, a display
 * `name`, and optionally the {@link ProviderName} it targets.
 */
export const providerSettingsEntrySchema = z.object({
	id: z.string(),
	name: z.string(),
	apiProvider: providerNamesSchema.optional(),
})

/** Static type of a value validated by {@link providerSettingsEntrySchema}. */
export type ProviderSettingsEntry = z.infer<typeof providerSettingsEntrySchema>

/**
 * ProviderSettings
 */

/**
 * Default value for consecutive mistake limit
 */
// NOTE(review): not referenced by the schemas here; presumably consumers fall
// back to it when `consecutiveMistakeLimit` is unset — confirm at call sites.
export const DEFAULT_CONSECUTIVE_MISTAKE_LIMIT = 3

/**
 * Settings common to every provider. Each provider-specific schema in this
 * module extends this object (directly or via a shared intermediate schema).
 * Every field is optional.
 */
const baseProviderSettingsSchema = z.object({
	includeMaxTokens: z.boolean().optional(),
	diffEnabled: z.boolean().optional(),
	todoListEnabled: z.boolean().optional(),
	fuzzyMatchThreshold: z.number().optional(),
	// `nullish` (not just `optional`) so an explicit `null` is also accepted.
	modelTemperature: z.number().nullish(),
	rateLimitSeconds: z.number().optional(),
	// Must be non-negative when present.
	consecutiveMistakeLimit: z.number().min(0).optional(),

	// Model reasoning.
	enableReasoningEffort: z.boolean().optional(),
	reasoningEffort: reasoningEffortsSchema.optional(),
	modelMaxTokens: z.number().optional(),
	modelMaxThinkingTokens: z.number().optional(),
})

// Several of the providers share common model config properties.
/** Base settings plus the generic `apiModelId` field used by many providers. */
const apiModelIdProviderModelSchema = baseProviderSettingsSchema.extend({
	apiModelId: z.string().optional(),
})

const anthropicSchema = apiModelIdProviderModelSchema.extend({
	apiKey: z.string().optional(),
	anthropicBaseUrl: z.string().optional(),
	anthropicUseAuthToken: z.boolean().optional(),
})

const claudeCodeSchema = apiModelIdProviderModelSchema.extend({
	claudeCodePath: z.string().optional(),
	// Integer in the inclusive range [1, 200000] when present.
	claudeCodeMaxOutputTokens: z.number().int().min(1).max(200000).optional(),
})

const glamaSchema = baseProviderSettingsSchema.extend({
	glamaModelId: z.string().optional(),
	glamaApiKey: z.string().optional(),
})

const openRouterSchema = baseProviderSettingsSchema.extend({
	openRouterApiKey: z.string().optional(),
	openRouterModelId: z.string().optional(),
	openRouterBaseUrl: z.string().optional(),
	openRouterSpecificProvider: z.string().optional(),
	openRouterUseMiddleOutTransform: z.boolean().optional(),
})

const bedrockSchema = apiModelIdProviderModelSchema.extend({
	awsAccessKey: z.string().optional(),
	awsSecretKey: z.string().optional(),
	awsSessionToken: z.string().optional(),
	awsRegion: z.string().optional(),
	awsUseCrossRegionInference: z.boolean().optional(),
	awsUsePromptCache: z.boolean().optional(),
	awsProfile: z.string().optional(),
	awsUseProfile: z.boolean().optional(),
	awsApiKey: z.string().optional(),
	awsUseApiKey: z.boolean().optional(),
	awsCustomArn: z.string().optional(),
	awsModelContextWindow: z.number().optional(),
	awsBedrockEndpointEnabled: z.boolean().optional(),
	awsBedrockEndpoint: z.string().optional(),
})

const vertexSchema = apiModelIdProviderModelSchema.extend({
	vertexKeyFile: z.string().optional(),
	vertexJsonCredentials: z.string().optional(),
	vertexProjectId: z.string().optional(),
	vertexRegion: z.string().optional(),
})

const openAiSchema = baseProviderSettingsSchema.extend({
	openAiBaseUrl: z.string().optional(),
	openAiApiKey: z.string().optional(),
	openAiLegacyFormat: z.boolean().optional(),
	openAiR1FormatEnabled: z.boolean().optional(),
	openAiModelId: z.string().optional(),
	// `nullish` so an explicit `null` is also accepted.
	openAiCustomModelInfo: modelInfoSchema.nullish(),
	openAiUseAzure: z.boolean().optional(),
	azureApiVersion: z.string().optional(),
	openAiStreamingEnabled: z.boolean().optional(),
	openAiHostHeader: z.string().optional(), // Keep temporarily for backward compatibility during migration.
	// Arbitrary string-to-string header map.
	openAiHeaders: z.record(z.string(), z.string()).optional(),
})

const ollamaSchema = baseProviderSettingsSchema.extend({
	ollamaModelId: z.string().optional(),
	ollamaBaseUrl: z.string().optional(),
})

const vsCodeLmSchema = baseProviderSettingsSchema.extend({
	vsCodeLmModelSelector: z
		.object({
			vendor: z.string().optional(),
			family: z.string().optional(),
			version: z.string().optional(),
			id: z.string().optional(),
		})
		.optional(),
})

const lmStudioSchema = baseProviderSettingsSchema.extend({
	lmStudioModelId: z.string().optional(),
	lmStudioBaseUrl: z.string().optional(),
	lmStudioDraftModelId: z.string().optional(),
	lmStudioSpeculativeDecodingEnabled: z.boolean().optional(),
})

const geminiSchema = apiModelIdProviderModelSchema.extend({
	geminiApiKey: z.string().optional(),
	googleGeminiBaseUrl: z.string().optional(),
	enableUrlContext: z.boolean().optional(),
	enableGrounding: z.boolean().optional(),
})

const geminiCliSchema = apiModelIdProviderModelSchema.extend({
	geminiCliOAuthPath: z.string().optional(),
	geminiCliProjectId: z.string().optional(),
})

const openAiNativeSchema = apiModelIdProviderModelSchema.extend({
	openAiNativeApiKey: z.string().optional(),
	openAiNativeBaseUrl: z.string().optional(),
})

const mistralSchema = apiModelIdProviderModelSchema.extend({
	mistralApiKey: z.string().optional(),
	mistralCodestralUrl: z.string().optional(),
})

const deepSeekSchema = apiModelIdProviderModelSchema.extend({
	deepSeekBaseUrl: z.string().optional(),
	deepSeekApiKey: z.string().optional(),
})

const moonshotSchema = apiModelIdProviderModelSchema.extend({
	// Restricted to exactly these two endpoint URLs.
	moonshotBaseUrl: z
		.union([z.literal("https://api.moonshot.ai/v1"), z.literal("https://api.moonshot.cn/v1")])
		.optional(),
	moonshotApiKey: z.string().optional(),
})

const unboundSchema = baseProviderSettingsSchema.extend({
	unboundApiKey: z.string().optional(),
	unboundModelId: z.string().optional(),
})

const requestySchema = baseProviderSettingsSchema.extend({
	requestyApiKey: z.string().optional(),
	requestyModelId: z.string().optional(),
})

// Adds no fields of its own.
const humanRelaySchema = baseProviderSettingsSchema

const fakeAiSchema = baseProviderSettingsSchema.extend({
	fakeAi: z.unknown().optional(),
})

const xaiSchema = apiModelIdProviderModelSchema.extend({
	xaiApiKey: z.string().optional(),
})

const groqSchema = apiModelIdProviderModelSchema.extend({
	groqApiKey: z.string().optional(),
})

const huggingFaceSchema = baseProviderSettingsSchema.extend({
	huggingFaceApiKey: z.string().optional(),
	huggingFaceModelId: z.string().optional(),
	huggingFaceInferenceProvider: z.string().optional(),
})

const chutesSchema = apiModelIdProviderModelSchema.extend({
	chutesApiKey: z.string().optional(),
})

const litellmSchema = baseProviderSettingsSchema.extend({
	litellmBaseUrl: z.string().optional(),
	litellmApiKey: z.string().optional(),
	litellmModelId: z.string().optional(),
	litellmUsePromptCache: z.boolean().optional(),
})

const kilocodeSchema = baseProviderSettingsSchema.extend({
	kilocodeToken: z.string().optional(),
	kilocodeModel: z.string().optional(),
})

const fireworksSchema = apiModelIdProviderModelSchema.extend({
	fireworksApiKey: z.string().optional(),
})

const cerebrasSchema = apiModelIdProviderModelSchema.extend({
	cerebrasApiKey: z.string().optional(),
	cerebrasModelId: z.string().optional(),
})

/** Fallback variant matched when `apiProvider` is absent (undefined). */
const defaultSchema = z.object({
	apiProvider: z.undefined(),
})

/**
 * Strict per-provider validation: the `apiProvider` value selects exactly one
 * provider schema, so only that provider's fields (plus the shared base
 * fields) are recognised. When `apiProvider` is undefined, {@link defaultSchema}
 * applies.
 */
export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProvider", [
	anthropicSchema.merge(z.object({ apiProvider: z.literal("anthropic") })),
	claudeCodeSchema.merge(z.object({ apiProvider: z.literal("claude-code") })),
	glamaSchema.merge(z.object({ apiProvider: z.literal("glama") })),
	openRouterSchema.merge(z.object({ apiProvider: z.literal("openrouter") })),
	bedrockSchema.merge(z.object({ apiProvider: z.literal("bedrock") })),
	vertexSchema.merge(z.object({ apiProvider: z.literal("vertex") })),
	openAiSchema.merge(z.object({ apiProvider: z.literal("openai") })),
	ollamaSchema.merge(z.object({ apiProvider: z.literal("ollama") })),
	vsCodeLmSchema.merge(z.object({ apiProvider: z.literal("vscode-lm") })),
	lmStudioSchema.merge(z.object({ apiProvider: z.literal("lmstudio") })),
	geminiSchema.merge(z.object({ apiProvider: z.literal("gemini") })),
	geminiCliSchema.merge(z.object({ apiProvider: z.literal("gemini-cli") })),
	openAiNativeSchema.merge(z.object({ apiProvider: z.literal("openai-native") })),
	mistralSchema.merge(z.object({ apiProvider: z.literal("mistral") })),
	deepSeekSchema.merge(z.object({ apiProvider: z.literal("deepseek") })),
	moonshotSchema.merge(z.object({ apiProvider: z.literal("moonshot") })),
	unboundSchema.merge(z.object({ apiProvider: z.literal("unbound") })),
	requestySchema.merge(z.object({ apiProvider: z.literal("requesty") })),
	humanRelaySchema.merge(z.object({ apiProvider: z.literal("human-relay") })),
	fakeAiSchema.merge(z.object({ apiProvider: z.literal("fake-ai") })),
	xaiSchema.merge(z.object({ apiProvider: z.literal("xai") })),
	groqSchema.merge(z.object({ apiProvider: z.literal("groq") })),
	huggingFaceSchema.merge(z.object({ apiProvider: z.literal("huggingface") })),
	chutesSchema.merge(z.object({ apiProvider: z.literal("chutes") })),
	litellmSchema.merge(z.object({ apiProvider: z.literal("litellm") })),
	kilocodeSchema.merge(z.object({ apiProvider: z.literal("kilocode") })),
	fireworksSchema.merge(z.object({ apiProvider: z.literal("fireworks") })),
	cerebrasSchema.merge(z.object({ apiProvider: z.literal("cerebras") })),
	defaultSchema,
])

/**
 * Flat, permissive schema: the union of every provider's fields (plus the
 * codebase-index fields) in one object, with `apiProvider` optional.
 *
 * Shapes are spread in order, so when two providers declare the same key
 * (e.g. `apiModelId`) the later spread's definition wins.
 */
export const providerSettingsSchema = z.object({
	apiProvider: providerNamesSchema.optional(),
	...anthropicSchema.shape,
	...claudeCodeSchema.shape,
	...glamaSchema.shape,
	...openRouterSchema.shape,
	...bedrockSchema.shape,
	...vertexSchema.shape,
	...openAiSchema.shape,
	...ollamaSchema.shape,
	...vsCodeLmSchema.shape,
	...lmStudioSchema.shape,
	...geminiSchema.shape,
	...geminiCliSchema.shape,
	...openAiNativeSchema.shape,
	...mistralSchema.shape,
	...deepSeekSchema.shape,
	...moonshotSchema.shape,
	...unboundSchema.shape,
	...requestySchema.shape,
	...humanRelaySchema.shape,
	...fakeAiSchema.shape,
	...xaiSchema.shape,
	...groqSchema.shape,
	...huggingFaceSchema.shape,
	...chutesSchema.shape,
	...litellmSchema.shape,
	...kilocodeSchema.shape,
	...fireworksSchema.shape,
	...cerebrasSchema.shape,
	...codebaseIndexProviderSchema.shape,
})

/** Static type of a value validated by {@link providerSettingsSchema}. */
export type ProviderSettings = z.infer<typeof providerSettingsSchema>

/** Every key accepted by {@link providerSettingsSchema}. */
export const PROVIDER_SETTINGS_KEYS = providerSettingsSchema.keyof().options

/**
 * Settings keys that may hold a model identifier.
 *
 * Order matters: {@link getModelId} returns the value of the first key in this
 * list whose value is truthy.
 */
export const MODEL_ID_KEYS: (keyof ProviderSettings)[] = [
	"apiModelId",
	"glamaModelId",
	"openRouterModelId",
	"openAiModelId",
	"ollamaModelId",
	"lmStudioModelId",
	"lmStudioDraftModelId",
	"unboundModelId",
	"requestyModelId",
	"litellmModelId",
	"huggingFaceModelId",
	"kilocodeModel",
	"cerebrasModelId",
]

/**
 * Returns the first truthy model id found in `settings`, scanning
 * {@link MODEL_ID_KEYS} in order.
 *
 * @param settings - Provider settings to inspect.
 * @returns The model id, or `undefined` if none of the keys holds a truthy value.
 */
export const getModelId = (settings: ProviderSettings): string | undefined => {
	const modelIdKey = MODEL_ID_KEYS.find((key) => settings[key])
	// Every key in MODEL_ID_KEYS is declared as `z.string().optional()`, so the
	// truthy value found is a string; the cast narrows the broader
	// `ProviderSettings[keyof ProviderSettings]` type.
	return modelIdKey ? (settings[modelIdKey] as string) : undefined
}

/** Providers that use the Anthropic-style API protocol. */
export const ANTHROPIC_STYLE_PROVIDERS: ProviderName[] = ["anthropic", "claude-code", "bedrock"]

/**
 * Determines which API protocol a provider/model pair speaks.
 *
 * @param provider - The configured provider, if any.
 * @param modelId - Optional model id; only consulted for the `vertex` provider.
 * @returns `"anthropic"` for providers in {@link ANTHROPIC_STYLE_PROVIDERS},
 *   or for `vertex` when the model id contains "claude" (case-insensitive);
 *   otherwise `"openai"`.
 */
export const getApiProtocol = (provider: ProviderName | undefined, modelId?: string): "anthropic" | "openai" => {
	// First check if the provider is an Anthropic-style provider.
	if (provider && ANTHROPIC_STYLE_PROVIDERS.includes(provider)) {
		return "anthropic"
	}

	// Vertex hosts Claude models too; detect them by model id (case-insensitive).
	// The equality check already rules out `undefined`, so no separate guard is needed.
	if (provider === "vertex" && modelId && modelId.toLowerCase().includes("claude")) {
		return "anthropic"
	}

	// Default to OpenAI protocol.
	return "openai"
}