import { ModelProvider, TemplateType } from "./typings";
import {
  anthropicTemplateMessages,
  chatmlTemplateMessages,
  codeLlama70bTemplateMessages,
  deepseekTemplateMessages,
  gemmaTemplateMessage,
  llama2TemplateMessages,
  llavaTemplateMessages,
  neuralChatTemplateMessages,
  openchatTemplateMessages,
  phi2TemplateMessages,
  phindTemplateMessages,
  templateAlpacaMessages,
  xWinCoderTemplateMessages,
  zephyrTemplateMessages,
} from "./templates/chat";
import {
  alpacaEditPrompt,
  claudeEditPrompt,
  codeLlama70bEditPrompt,
  deepseekEditPrompt,
  gemmaEditPrompt,
  gptEditPrompt,
  mistralEditPrompt,
  neuralChatEditPrompt,
  openchatEditPrompt,
  osModelsEditPrompt,
  phindEditPrompt,
  simplifiedEditPrompt,
  xWinCoderEditPrompt,
  zephyrEditPrompt,
} from "./templates/edit";

// Providers that apply chat templates server-side, so no client-side
// templating is needed unless a template is explicitly configured.
const PROVIDER_HANDLES_TEMPLATING: ModelProvider[] = [
  "lmstudio",
  "openai",
  "ollama",
  "together",
  "anthropic",
  "bedrock",
];

// Providers that support image inputs for at least some of their models.
const PROVIDER_SUPPORTS_IMAGES: ModelProvider[] = [
  "openai",
  "ollama",
  "google-palm",
  "free-trial",
  "anthropic",
  "bedrock",
];

function modelSupportsImages(provider: ModelProvider, model: string): boolean {
  if (!PROVIDER_SUPPORTS_IMAGES.includes(provider)) {
    return false;
  }

  if (model.includes("llava")) {
    return true;
  }

  if (model.includes("claude-3")) {
    return true;
  }

  if (["gpt-4-vision-preview"].includes(model)) {
    return true;
  }

  if (
    model === "gemini-ultra" &&
    (provider === "google-palm" || provider === "free-trial")
  ) {
    return true;
  }

  return false;
}

const PARALLEL_PROVIDERS: ModelProvider[] = [
  "anthropic",
  "bedrock",
  "deepinfra",
  "gemini",
  "google-palm",
  "huggingface-inference-api",
  "huggingface-tgi",
  "mistral",
  "free-trial",
  "replicate",
  "together",
];

function llmCanGenerateInParallel(
  provider: ModelProvider,
  model: string,
): boolean {
  if (provider === "openai") {
    return model.includes("gpt");
  }

  return PARALLEL_PROVIDERS.includes(provider);
}

// Guesses the chat template type from the model name. Returns undefined for
// model families whose providers handle templating themselves, and "none"
// for models that need no client-side formatting at all.
function autodetectTemplateType(model: string): TemplateType | undefined {
  const lower = model.toLowerCase();

  if (lower.includes("codellama") && lower.includes("70b")) {
    return "codellama-70b";
  }

  if (
    lower.includes("gpt") ||
    lower.includes("chat-bison") ||
    lower.includes("pplx") ||
    lower.includes("gemini")
  ) {
    return undefined;
  }

  if (lower.includes("llava")) {
    return "llava";
  }

  if (lower.includes("tinyllama")) {
    return "zephyr";
  }

  if (lower.includes("xwin")) {
    return "xwin-coder";
  }

  if (lower.includes("dolphin")) {
    return "chatml";
  }

  if (lower.includes("gemma")) {
    return "gemma";
  }

  if (lower.includes("phi2")) {
    return "phi2";
  }

  if (lower.includes("phind")) {
    return "phind";
  }

  if (lower.includes("llama")) {
    return "llama2";
  }

  if (lower.includes("zephyr")) {
    return "zephyr";
  }

  // Claude requests are always sent through the Messages API, so formatting is not necessary
  if (lower.includes("claude")) {
    return "none";
  }

  if (lower.includes("alpaca") || lower.includes("wizard")) {
    return "alpaca";
  }

  if (lower.includes("mistral") || lower.includes("mixtral")) {
    return "llama2";
  }

  if (lower.includes("deepseek")) {
    return "deepseek";
  }

  if (lower.includes("ninja") || lower.includes("openchat")) {
    return "openchat";
  }

  if (lower.includes("neural-chat")) {
    return "neural-chat";
  }

  return "chatml";
}

function autodetectTemplateFunction(
  model: string,
  provider: ModelProvider,
  explicitTemplate: TemplateType | undefined = undefined,
) {
  if (
    explicitTemplate === undefined &&
    PROVIDER_HANDLES_TEMPLATING.includes(provider)
  ) {
    return null;
  }

  const templateType = explicitTemplate ?? autodetectTemplateType(model);

  if (templateType) {
    const mapping: Record<TemplateType, any> = {
      llama2: llama2TemplateMessages,
      alpaca: templateAlpacaMessages,
      phi2: phi2TemplateMessages,
      phind: phindTemplateMessages,
      zephyr: zephyrTemplateMessages,
      anthropic: anthropicTemplateMessages,
      chatml: chatmlTemplateMessages,
      deepseek: deepseekTemplateMessages,
      openchat: openchatTemplateMessages,
      "xwin-coder": xWinCoderTemplateMessages,
      "neural-chat": neuralChatTemplateMessages,
      llava: llavaTemplateMessages,
      "codellama-70b": codeLlama70bTemplateMessages,
      gemma: gemmaTemplateMessage,
      none: null,
    };

    return mapping[templateType];
  }

  return null;
}

const USES_OS_MODELS_EDIT_PROMPT: TemplateType[] = [
  "alpaca",
  "chatml",
  // "codellama-70b", // Doesn't respond well to this prompt
  "deepseek",
  "gemma",
  "llama2",
  "llava",
  "neural-chat",
  "openchat",
  "phi2",
  "phind",
  "xwin-coder",
  "zephyr",
];

function autodetectPromptTemplates(
  model: string,
  explicitTemplate: TemplateType | undefined = undefined,
) {
  const templateType = explicitTemplate ?? autodetectTemplateType(model);
  const templates: Record<string, any> = {};

  let editTemplate = null;

  if (templateType && USES_OS_MODELS_EDIT_PROMPT.includes(templateType)) {
    // This is overriding basically everything else
    // Will probably delete the rest later, but for now it's easy to revert
    editTemplate = osModelsEditPrompt;
  } else if (templateType === "phind") {
    editTemplate = phindEditPrompt;
  } else if (templateType === "phi2") {
    editTemplate = simplifiedEditPrompt;
  } else if (templateType === "zephyr") {
    editTemplate = zephyrEditPrompt;
  } else if (templateType === "llama2") {
    if (model.includes("mistral")) {
      editTemplate = mistralEditPrompt;
    } else {
      editTemplate = osModelsEditPrompt;
    }
  } else if (templateType === "alpaca") {
    editTemplate = alpacaEditPrompt;
  } else if (templateType === "deepseek") {
    editTemplate = deepseekEditPrompt;
  } else if (templateType === "openchat") {
    editTemplate = openchatEditPrompt;
  } else if (templateType === "xwin-coder") {
    editTemplate = xWinCoderEditPrompt;
  } else if (templateType === "neural-chat") {
    editTemplate = neuralChatEditPrompt;
  } else if (templateType === "codellama-70b") {
    editTemplate = codeLlama70bEditPrompt;
  } else if (templateType === "anthropic") {
    editTemplate = claudeEditPrompt;
  } else if (templateType === "gemma") {
    editTemplate = gemmaEditPrompt;
  } else if (templateType) {
    editTemplate = gptEditPrompt;
  }

  if (editTemplate !== null) {
    templates["edit"] = editTemplate;
  }

  return templates;
}

export {
  autodetectPromptTemplates,
  autodetectTemplateFunction,
  autodetectTemplateType,
  llmCanGenerateInParallel,
  modelSupportsImages,
};
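
// Illustrative usage sketch (not part of the module's API surface): the model
// and provider strings below are assumptions chosen to exercise the detection
// rules above, not an exhaustive list of supported names.
//
//   autodetectTemplateType("mistral-7b-instruct");          // "llama2"
//   autodetectTemplateType("gpt-4");                        // undefined (provider handles templating)
//   autodetectTemplateType("claude-3-opus");                // "none" (Messages API)
//   modelSupportsImages("openai", "gpt-4-vision-preview");  // true
//   llmCanGenerateInParallel("openai", "gpt-4");            // true
//   autodetectPromptTemplates("deepseek-coder");            // { edit: osModelsEditPrompt }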