continue/core/llm/autodetect.ts

import { ModelProvider, TemplateType } from "../index.js";
import {
  anthropicTemplateMessages,
  chatmlTemplateMessages,
  codeLlama70bTemplateMessages,
  deepseekTemplateMessages,
  gemmaTemplateMessage,
  llama2TemplateMessages,
  llama3TemplateMessages,
  llavaTemplateMessages,
  neuralChatTemplateMessages,
  openchatTemplateMessages,
  phi2TemplateMessages,
  phindTemplateMessages,
  templateAlpacaMessages,
  xWinCoderTemplateMessages,
  zephyrTemplateMessages,
} from "./templates/chat.js";
import {
  alpacaEditPrompt,
  claudeEditPrompt,
  codeLlama70bEditPrompt,
  deepseekEditPrompt,
  gemmaEditPrompt,
  gptEditPrompt,
  llama3EditPrompt,
  mistralEditPrompt,
  neuralChatEditPrompt,
  openchatEditPrompt,
  osModelsEditPrompt,
  phindEditPrompt,
  simplifiedEditPrompt,
  xWinCoderEditPrompt,
  zephyrEditPrompt,
} from "./templates/edit.js";

const PROVIDER_HANDLES_TEMPLATING: ModelProvider[] = [
  "lmstudio",
  "openai",
  "ollama",
  "together",
  "msty",
  "anthropic",
  "bedrock",
  "continue-proxy",
  "mistral",
];

const PROVIDER_SUPPORTS_IMAGES: ModelProvider[] = [
  "openai",
  "ollama",
  "gemini",
  "free-trial",
  "msty",
  "anthropic",
  "bedrock",
  "continue-proxy",
];

const MODEL_SUPPORTS_IMAGES: string[] = [
  "llava",
  "gpt-4-turbo",
  "gpt-4o",
  "gpt-4o-mini",
  "gpt-4-vision",
  "claude-3",
  "gemini-ultra",
  "gemini-1.5-pro",
  "gemini-1.5-flash",
  "sonnet",
  "opus",
  "haiku",
];
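
/**
 * True when both the provider and the model are known to accept image input:
 * the provider must appear in PROVIDER_SUPPORTS_IMAGES and the lowercased
 * model name (or the title) must contain one of the MODEL_SUPPORTS_IMAGES
 * substrings. The model names in the example below are illustrative only.
 *
 * @example
 * modelSupportsImages("ollama", "llava:13b", undefined); // true
 * modelSupportsImages("ollama", "codellama:7b", undefined); // false
 */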
function modelSupportsImages(
  provider: ModelProvider,
  model: string,
  title: string | undefined,
): boolean {
  if (!PROVIDER_SUPPORTS_IMAGES.includes(provider)) {
    return false;
  }
  const lower = model.toLowerCase();
  if (
    MODEL_SUPPORTS_IMAGES.some(
      (modelName) => lower.includes(modelName) || title?.includes(modelName),
    )
  ) {
    return true;
  }
  return false;
}

const PARALLEL_PROVIDERS: ModelProvider[] = [
  "anthropic",
  "bedrock",
  "deepinfra",
  "gemini",
  "huggingface-inference-api",
  "huggingface-tgi",
  "mistral",
  "free-trial",
  "replicate",
  "together",
];
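
/**
 * Whether requests for this provider/model can safely be issued in parallel.
 * OpenAI is special-cased to GPT models; for every other provider the answer
 * is membership in PARALLEL_PROVIDERS. The model names in the example are
 * illustrative only.
 *
 * @example
 * llmCanGenerateInParallel("openai", "gpt-4o"); // true
 * llmCanGenerateInParallel("anthropic", "claude-3-5-sonnet"); // true
 * llmCanGenerateInParallel("ollama", "llama3"); // false
 */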
function llmCanGenerateInParallel(
  provider: ModelProvider,
  model: string,
): boolean {
  if (provider === "openai") {
    return model.includes("gpt");
  }
  return PARALLEL_PROVIDERS.includes(provider);
}
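
/**
 * Guesses a TemplateType from substrings of the model name. Returns undefined
 * for API-style chat models (GPT, command, chat-bison, pplx, Gemini) that
 * need no raw prompt template, "none" where templating should be skipped, and
 * falls back to "chatml" for anything unrecognized. The model names in the
 * example are illustrative only.
 *
 * @example
 * autodetectTemplateType("codellama-70b-instruct"); // "codellama-70b"
 * autodetectTemplateType("mistral-7b-instruct"); // "llama2"
 * autodetectTemplateType("gpt-4o"); // undefined
 */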
function autodetectTemplateType(model: string): TemplateType | undefined {
  const lower = model.toLowerCase();
  if (lower.includes("codellama") && lower.includes("70b")) {
    return "codellama-70b";
  }
  if (
    lower.includes("gpt") ||
    lower.includes("command") ||
    lower.includes("chat-bison") ||
    lower.includes("pplx") ||
    lower.includes("gemini")
  ) {
    return undefined;
  }
  if (lower.includes("llama3")) {
    return "llama3";
  }
  if (lower.includes("llava")) {
    return "llava";
  }
  if (lower.includes("tinyllama")) {
    return "zephyr";
  }
  if (lower.includes("xwin")) {
    return "xwin-coder";
  }
  if (lower.includes("dolphin")) {
    return "chatml";
  }
  if (lower.includes("gemma")) {
    return "gemma";
  }
  if (lower.includes("phi2")) {
    return "phi2";
  }
  if (lower.includes("phind")) {
    return "phind";
  }
  if (lower.includes("llama")) {
    return "llama2";
  }
  if (lower.includes("zephyr")) {
    return "zephyr";
  }
  // Claude requests are always sent through the Messages API, so no templating is necessary
  if (lower.includes("claude")) {
    return "none";
  }
  if (lower.includes("codestral")) {
    return "none";
  }
  if (lower.includes("alpaca") || lower.includes("wizard")) {
    return "alpaca";
  }
  if (lower.includes("mistral") || lower.includes("mixtral")) {
    return "llama2";
  }
  if (lower.includes("deepseek")) {
    return "deepseek";
  }
  if (lower.includes("ninja") || lower.includes("openchat")) {
    return "openchat";
  }
  if (lower.includes("neural-chat")) {
    return "neural-chat";
  }
  return "chatml";
}
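
/**
 * Returns the chat-templating function for a model, or null when the provider
 * templates messages itself or no template applies. An explicit TemplateType,
 * if given, overrides autodetection. The arguments in the example are
 * illustrative only.
 *
 * @example
 * autodetectTemplateFunction("zephyr-7b-beta", "replicate"); // zephyrTemplateMessages
 * autodetectTemplateFunction("zephyr-7b-beta", "ollama"); // null (Ollama handles templating)
 */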
function autodetectTemplateFunction(
  model: string,
  provider: ModelProvider,
  explicitTemplate: TemplateType | undefined = undefined,
) {
  if (
    explicitTemplate === undefined &&
    PROVIDER_HANDLES_TEMPLATING.includes(provider)
  ) {
    return null;
  }
  const templateType = explicitTemplate ?? autodetectTemplateType(model);
  if (templateType) {
    const mapping: Record<TemplateType, any> = {
      llama2: llama2TemplateMessages,
      alpaca: templateAlpacaMessages,
      phi2: phi2TemplateMessages,
      phind: phindTemplateMessages,
      zephyr: zephyrTemplateMessages,
      anthropic: anthropicTemplateMessages,
      chatml: chatmlTemplateMessages,
      deepseek: deepseekTemplateMessages,
      openchat: openchatTemplateMessages,
      "xwin-coder": xWinCoderTemplateMessages,
      "neural-chat": neuralChatTemplateMessages,
      llava: llavaTemplateMessages,
      "codellama-70b": codeLlama70bTemplateMessages,
      gemma: gemmaTemplateMessage,
      llama3: llama3TemplateMessages,
      none: null,
    };
    return mapping[templateType];
  }
  return null;
}

const USES_OS_MODELS_EDIT_PROMPT: TemplateType[] = [
  "alpaca",
  "chatml",
  // "codellama-70b", Doesn't respond well to this prompt
  "deepseek",
  "gemma",
  "llama2",
  "llava",
  "neural-chat",
  "openchat",
  "phi2",
  "phind",
  "xwin-coder",
  "zephyr",
  "llama3",
];
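
/**
 * Builds the prompt-template overrides (currently just `edit`) for a model,
 * based on its detected or explicitly given TemplateType. Most open-source
 * template types share osModelsEditPrompt; a few get a model-specific edit
 * prompt, and models with no template get no override. The model names in
 * the example are illustrative only.
 *
 * @example
 * autodetectPromptTemplates("deepseek-coder-6.7b"); // { edit: osModelsEditPrompt }
 * autodetectPromptTemplates("codellama-70b-instruct"); // { edit: codeLlama70bEditPrompt }
 * autodetectPromptTemplates("claude-3-5-sonnet"); // {}
 */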
function autodetectPromptTemplates(
  model: string,
  explicitTemplate: TemplateType | undefined = undefined,
) {
  const templateType = explicitTemplate ?? autodetectTemplateType(model);
  const templates: Record<string, any> = {};
  let editTemplate = null;
  if (templateType && USES_OS_MODELS_EDIT_PROMPT.includes(templateType)) {
    // This is overriding basically everything else
    // Will probably delete the rest later, but for now it's easy to revert
    editTemplate = osModelsEditPrompt;
  } else if (templateType === "phind") {
    editTemplate = phindEditPrompt;
  } else if (templateType === "phi2") {
    editTemplate = simplifiedEditPrompt;
  } else if (templateType === "zephyr") {
    editTemplate = zephyrEditPrompt;
  } else if (templateType === "llama2") {
    if (model.includes("mistral")) {
      editTemplate = mistralEditPrompt;
    } else {
      editTemplate = osModelsEditPrompt;
    }
  } else if (templateType === "alpaca") {
    editTemplate = alpacaEditPrompt;
  } else if (templateType === "deepseek") {
    editTemplate = deepseekEditPrompt;
  } else if (templateType === "openchat") {
    editTemplate = openchatEditPrompt;
  } else if (templateType === "xwin-coder") {
    editTemplate = xWinCoderEditPrompt;
  } else if (templateType === "neural-chat") {
    editTemplate = neuralChatEditPrompt;
  } else if (templateType === "codellama-70b") {
    editTemplate = codeLlama70bEditPrompt;
  } else if (templateType === "anthropic") {
    editTemplate = claudeEditPrompt;
  } else if (templateType === "gemma") {
    editTemplate = gemmaEditPrompt;
  } else if (templateType === "llama3") {
    editTemplate = llama3EditPrompt;
  } else if (templateType === "none") {
    editTemplate = null;
  } else if (templateType) {
    editTemplate = gptEditPrompt;
  } else if (model.includes("codestral")) {
    editTemplate = osModelsEditPrompt;
  }
  if (editTemplate !== null) {
    templates.edit = editTemplate;
  }
  return templates;
}

export {
  autodetectPromptTemplates,
  autodetectTemplateFunction,
  autodetectTemplateType,
  llmCanGenerateInParallel,
  modelSupportsImages,
};