continue/core/llm/index.ts

import { findLlmInfo } from "@continuedev/llm-info";
import Handlebars from "handlebars";
import {
  ChatMessage,
  ChatMessageRole,
  CompletionOptions,
  ILLM,
  LLMFullCompletionOptions,
  LLMOptions,
  ModelName,
  ModelProvider,
  PromptLog,
  PromptTemplate,
  RequestOptions,
  TemplateType,
} from "../index.js";
import { logDevData } from "../util/devdata.js";
import { DevDataSqliteDb } from "../util/devdataSqlite.js";
import { fetchwithRequestOptions } from "../util/fetchWithOptions.js";
import mergeJson from "../util/merge.js";
import { Telemetry } from "../util/posthog.js";
import { withExponentialBackoff } from "../util/withExponentialBackoff.js";
import {
  autodetectPromptTemplates,
  autodetectTemplateFunction,
  autodetectTemplateType,
  modelSupportsImages,
} from "./autodetect.js";
import {
  CONTEXT_LENGTH_FOR_MODEL,
  DEFAULT_ARGS,
  DEFAULT_CONTEXT_LENGTH,
  DEFAULT_MAX_TOKENS,
} from "./constants.js";
import {
  compileChatMessages,
  countTokens,
  pruneRawPromptFromTop,
} from "./countTokens.js";
import { stripImages } from "./images.js";
import CompletionOptionsForModels from "./templates/options.js";
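
/**
 * Shared base class for all LLM providers in Continue. It normalizes
 * provider options, autodetects prompt templates, counts tokens, logs
 * prompts and completions, and exposes the `complete`, `streamComplete`,
 * `chat`, `streamChat`, and `streamFim` entry points that the rest of the
 * codebase consumes through the `ILLM` interface.
 *
 * A minimal usage sketch (`SomeProvider` is a hypothetical subclass):
 *
 *   const llm = new SomeProvider({ model: "gpt-4", apiKey: "..." });
 *   const answer = await llm.complete("Say hello");
 */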
export abstract class BaseLLM implements ILLM {
  static providerName: ModelProvider;
  static defaultOptions: Partial<LLMOptions> | undefined = undefined;

  get providerName(): ModelProvider {
    return (this.constructor as typeof BaseLLM).providerName;
  }

  supportsFim(): boolean {
    return false;
  }

  supportsImages(): boolean {
    return modelSupportsImages(this.providerName, this.model, this.title);
  }

  supportsCompletions(): boolean {
    if (this.providerName === "openai") {
      if (
        this.apiBase?.includes("api.groq.com") ||
        this.apiBase?.includes("api.mistral.ai") ||
        this.apiBase?.includes(":1337") ||
        this._llmOptions.useLegacyCompletionsEndpoint?.valueOf() === false
      ) {
        // Jan + Groq + Mistral don't support completions : (
        // Seems to be going out of style...
        return false;
      }
    }
    if (["groq", "mistral"].includes(this.providerName)) {
      return false;
    }
    return true;
  }

  supportsPrefill(): boolean {
    return ["ollama", "anthropic", "mistral"].includes(this.providerName);
  }

  uniqueId: string;
  model: string;
  title?: string;
  systemMessage?: string;
  contextLength: number;
  completionOptions: CompletionOptions;
  requestOptions?: RequestOptions;
  template?: TemplateType;
  promptTemplates?: Record<string, PromptTemplate>;
  templateMessages?: (messages: ChatMessage[]) => string;
  writeLog?: (str: string) => Promise<void>;
  llmRequestHook?: (model: string, prompt: string) => any;
  apiKey?: string;
  apiBase?: string;
  engine?: string;
  apiVersion?: string;
  apiType?: string;
  region?: string;
  projectId?: string;
  accountId?: string;
  aiGatewaySlug?: string;

  private _llmOptions: LLMOptions;
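
  /**
   * Merges options in increasing precedence: the provider name as a default
   * title, the subclass's static `defaultOptions`, then the user-supplied
   * options. Context length falls back to `@continuedev/llm-info` metadata
   * and finally `DEFAULT_CONTEXT_LENGTH`, and a trailing slash is appended
   * to `apiBase` when it is missing.
   */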
  constructor(_options: LLMOptions) {
    this._llmOptions = _options;

    // Set default options
    const options = {
      title: (this.constructor as typeof BaseLLM).providerName,
      ...(this.constructor as typeof BaseLLM).defaultOptions,
      ..._options,
    };

    this.model = options.model;
    const llmInfo = findLlmInfo(this.model);

    const templateType =
      options.template ?? autodetectTemplateType(options.model);

    this.title = options.title;
    this.uniqueId = options.uniqueId ?? "None";
    this.systemMessage = options.systemMessage;
    this.contextLength =
      options.contextLength ?? llmInfo?.contextLength ?? DEFAULT_CONTEXT_LENGTH;
    this.completionOptions = {
      ...options.completionOptions,
      model: options.model || "gpt-4",
      maxTokens: options.completionOptions?.maxTokens ?? DEFAULT_MAX_TOKENS,
    };
    if (CompletionOptionsForModels[options.model as ModelName]) {
      this.completionOptions = mergeJson(
        this.completionOptions,
        CompletionOptionsForModels[options.model as ModelName] ?? {},
      );
    }
    this.requestOptions = options.requestOptions;
    this.promptTemplates = {
      ...autodetectPromptTemplates(options.model, templateType),
      ...options.promptTemplates,
    };
    this.templateMessages =
      options.templateMessages ??
      autodetectTemplateFunction(
        options.model,
        this.providerName,
        options.template,
      );
    this.writeLog = options.writeLog;
    this.llmRequestHook = options.llmRequestHook;
    this.apiKey = options.apiKey;
    this.aiGatewaySlug = options.aiGatewaySlug;
    this.apiBase = options.apiBase;
    if (this.apiBase && !this.apiBase.endsWith("/")) {
      this.apiBase = `${this.apiBase}/`;
    }
    this.accountId = options.accountId;
    this.engine = options.engine;
    this.apiVersion = options.apiVersion;
    this.apiType = options.apiType;
    this.region = options.region;
    this.projectId = options.projectId;
  }

  listModels(): Promise<string[]> {
    return Promise.resolve([]);
  }
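
  /**
   * Prunes and orders chat messages to fit the model's context window while
   * leaving room for `maxTokens` of output. If the request targets a
   * different model than this instance was configured with, the context
   * length is looked up in `CONTEXT_LENGTH_FOR_MODEL` instead.
   */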
  private _compileChatMessages(
    options: CompletionOptions,
    messages: ChatMessage[],
    functions?: any[],
  ) {
    let contextLength = this.contextLength;
    if (
      options.model !== this.model &&
      options.model in CONTEXT_LENGTH_FOR_MODEL
    ) {
      contextLength =
        CONTEXT_LENGTH_FOR_MODEL[options.model] || DEFAULT_CONTEXT_LENGTH;
    }
    return compileChatMessages(
      options.model,
      messages,
      contextLength,
      options.maxTokens ?? DEFAULT_MAX_TOKENS,
      this.supportsImages(),
      undefined,
      functions,
      this.systemMessage,
    );
  }

  private _getSystemMessage(): string | undefined {
    // TODO: Merge with config system message
    return this.systemMessage;
  }

  private _templatePromptLikeMessages(prompt: string): string {
    if (!this.templateMessages) {
      return prompt;
    }

    const msgs: ChatMessage[] = [{ role: "user", content: prompt }];

    const systemMessage = this._getSystemMessage();
    if (systemMessage) {
      msgs.unshift({ role: "system", content: systemMessage });
    }

    return this.templateMessages(msgs);
  }

  private _compileLogMessage(
    prompt: string,
    completionOptions: CompletionOptions,
  ): string {
    const dict = { contextLength: this.contextLength, ...completionOptions };
    const settings = Object.entries(dict)
      .map(([key, value]) => `${key}: ${value}`)
      .join("\n");
    return `Settings:
${settings}
############################################
${prompt}`;
  }

  private _logTokensGenerated(
    model: string,
    prompt: string,
    completion: string,
  ) {
    const promptTokens = this.countTokens(prompt);
    const generatedTokens = this.countTokens(completion);

    Telemetry.capture(
      "tokens_generated",
      {
        model: model,
        provider: this.providerName,
        promptTokens: promptTokens,
        generatedTokens: generatedTokens,
      },
      true,
    );

    DevDataSqliteDb.logTokensGenerated(
      model,
      this.providerName,
      promptTokens,
      generatedTokens,
    );

    logDevData("tokens_generated", {
      model: model,
      provider: this.providerName,
      promptTokens: promptTokens,
      generatedTokens: generatedTokens,
    });
  }
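
  /**
   * Fetch wrapper used for all provider requests: it applies
   * `this.requestOptions` via `fetchwithRequestOptions`, rewrites common
   * failure modes (Ollama model-not-found, a missing `/v1` suffix on
   * `apiBase`, OpenAI 404s) into actionable messages, and retries with
   * `withExponentialBackoff` (up to 5 attempts, 0.5s initial interval).
   *
   * A minimal sketch (the URL is illustrative only):
   *
   *   const resp = await llm.fetch("http://localhost:11434/api/tags");
   *   const body = await resp.json();
   */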
  fetch(url: RequestInfo | URL, init?: RequestInit): Promise<Response> {
    // Custom Node.js fetch
    const customFetch = async (input: URL | RequestInfo, init: any) => {
      try {
        const resp = await fetchwithRequestOptions(
          new URL(input as any),
          { ...init },
          { ...this.requestOptions },
        );

        // Error mapping to be more helpful
        if (!resp.ok) {
          let text = await resp.text();
          if (resp.status === 404 && !resp.url.includes("/v1")) {
            if (text.includes("try pulling it first")) {
              const model = JSON.parse(text).error.split(" ")[1].slice(1, -1);
              text = `The model "${model}" was not found. To download it, run \`ollama run ${model}\`.`;
            } else if (text.includes("/api/chat")) {
              text =
                "The /api/chat endpoint was not found. This may mean that you are using an older version of Ollama that does not support /api/chat. Upgrading to the latest version will solve the issue.";
            } else {
              text =
                "This may mean that you forgot to add '/v1' to the end of your 'apiBase' in config.json.";
            }
          } else if (
            resp.status === 404 &&
            resp.url.includes("api.openai.com")
          ) {
            text =
              "You may need to add pre-paid credits before using the OpenAI API.";
          }
          throw new Error(
            `HTTP ${resp.status} ${resp.statusText} from ${resp.url}\n\n${text}`,
          );
        }

        return resp;
      } catch (e: any) {
        // Errors to ignore
        if (!e.message.includes("/api/show")) {
          console.warn(
            `${e.message}\n\nCode: ${e.code}\nError number: ${e.errno}\nSyscall: ${e.erroredSysCall}\nType: ${e.type}\n\n${e.stack}`,
          );

          if (
            e.code === "ECONNREFUSED" &&
            e.message.includes("http://127.0.0.1:11434")
          ) {
            throw new Error(
              "Failed to connect to local Ollama instance. To start Ollama, first download it at https://ollama.ai.",
            );
          }
        }
        throw new Error(e.message);
      }
    };

    return withExponentialBackoff<Response>(
      () => customFetch(url, init) as any,
      5,
      0.5,
    );
  }
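
  /**
   * Splits per-call options into the merged `CompletionOptions` sent to the
   * provider plus the `log` and `raw` flags, which default to true and false
   * respectively. `options.log` is unset before merging so it does not leak
   * into the provider payload.
   */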
  private _parseCompletionOptions(options: LLMFullCompletionOptions) {
    const log = options.log ?? true;
    const raw = options.raw ?? false;
    options.log = undefined;

    const completionOptions: CompletionOptions = mergeJson(
      this.completionOptions,
      options,
    );

    return { completionOptions, log, raw };
  }

  private _formatChatMessages(messages: ChatMessage[]): string {
    const msgsCopy = messages ? messages.map((msg) => ({ ...msg })) : [];
    let formatted = "";
    for (const msg of msgsCopy) {
      if ("content" in msg && Array.isArray(msg.content)) {
        const content = stripImages(msg.content);
        msg.content = content;
      }
      formatted += `<${msg.role}>\n${msg.content || ""}\n\n`;
    }
    return formatted;
  }

  async *_streamFim(
    prefix: string,
    suffix: string,
    options: CompletionOptions,
  ): AsyncGenerator<string, PromptLog> {
    throw new Error("Not implemented");
  }
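
  /**
   * Streams a fill-in-the-middle completion between `prefix` and `suffix`.
   * Only meaningful for providers that override `_streamFim` (check
   * `supportsFim()` first); the base `_streamFim` throws. The
   * `${prefix}<FIM>${suffix}` string exists only for logging and token
   * accounting, not as the wire format.
   *
   * A sketch with made-up inputs:
   *
   *   for await (const chunk of llm.streamFim("function add(", ") {}")) {
   *     process.stdout.write(chunk);
   *   }
   */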
  async *streamFim(
    prefix: string,
    suffix: string,
    options: LLMFullCompletionOptions = {},
  ): AsyncGenerator<string> {
    const { completionOptions, log } = this._parseCompletionOptions(options);

    const madeUpFimPrompt = `${prefix}<FIM>${suffix}`;
    if (log) {
      if (this.writeLog) {
        await this.writeLog(
          this._compileLogMessage(madeUpFimPrompt, completionOptions),
        );
      }
      if (this.llmRequestHook) {
        this.llmRequestHook(completionOptions.model, madeUpFimPrompt);
      }
    }

    let completion = "";
    for await (const chunk of this._streamFim(
      prefix,
      suffix,
      completionOptions,
    )) {
      completion += chunk;
      yield chunk;
    }

    this._logTokensGenerated(
      completionOptions.model,
      madeUpFimPrompt,
      completion,
    );

    if (log && this.writeLog) {
      await this.writeLog(`Completion:\n\n${completion}\n\n`);
    }

    return {
      prompt: madeUpFimPrompt,
      completion,
      completionOptions,
    };
  }
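
  /**
   * Streams a raw text completion. The prompt is pruned from the top to fit
   * the context window and, unless `raw` is set, wrapped in the autodetected
   * chat template via `_templatePromptLikeMessages`.
   *
   * A sketch with a made-up prompt:
   *
   *   let out = "";
   *   for await (const chunk of llm.streamComplete("// fibonacci\n")) {
   *     out += chunk;
   *   }
   */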
  async *streamComplete(
    _prompt: string,
    options: LLMFullCompletionOptions = {},
  ) {
    const { completionOptions, log, raw } =
      this._parseCompletionOptions(options);

    let prompt = pruneRawPromptFromTop(
      completionOptions.model,
      this.contextLength,
      _prompt,
      completionOptions.maxTokens ?? DEFAULT_MAX_TOKENS,
    );

    if (!raw) {
      prompt = this._templatePromptLikeMessages(prompt);
    }

    if (log) {
      if (this.writeLog) {
        await this.writeLog(this._compileLogMessage(prompt, completionOptions));
      }
      if (this.llmRequestHook) {
        this.llmRequestHook(completionOptions.model, prompt);
      }
    }

    let completion = "";
    for await (const chunk of this._streamComplete(prompt, completionOptions)) {
      completion += chunk;
      yield chunk;
    }

    this._logTokensGenerated(completionOptions.model, prompt, completion);

    if (log && this.writeLog) {
      await this.writeLog(`Completion:\n\n${completion}\n\n`);
    }

    return { prompt, completion, completionOptions };
  }

  async complete(_prompt: string, options: LLMFullCompletionOptions = {}) {
    const { completionOptions, log, raw } =
      this._parseCompletionOptions(options);

    let prompt = pruneRawPromptFromTop(
      completionOptions.model,
      this.contextLength,
      _prompt,
      completionOptions.maxTokens ?? DEFAULT_MAX_TOKENS,
    );

    if (!raw) {
      prompt = this._templatePromptLikeMessages(prompt);
    }

    if (log) {
      if (this.writeLog) {
        await this.writeLog(this._compileLogMessage(prompt, completionOptions));
      }
      if (this.llmRequestHook) {
        this.llmRequestHook(completionOptions.model, prompt);
      }
    }

    const completion = await this._complete(prompt, completionOptions);

    this._logTokensGenerated(completionOptions.model, prompt, completion);

    if (log && this.writeLog) {
      await this.writeLog(`Completion:\n\n${completion}\n\n`);
    }

    return completion;
  }
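
  /**
   * Non-streaming chat: drains `streamChat` and returns a single assistant
   * message.
   *
   * A sketch with a made-up message:
   *
   *   const reply = await llm.chat([{ role: "user", content: "Hi!" }]);
   *   console.log(reply.content);
   */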
  async chat(messages: ChatMessage[], options: LLMFullCompletionOptions = {}) {
    let completion = "";
    for await (const chunk of this.streamChat(messages, options)) {
      completion += chunk.content;
    }
    return { role: "assistant" as ChatMessageRole, content: completion };
  }
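
  /**
   * Streams a chat completion. Messages are first compiled to fit the
   * context window; if the provider only exposes a raw completions API
   * (`templateMessages` is set), the chat is rendered to a single prompt
   * and routed through `_streamComplete`, otherwise `_streamChat` is used
   * directly.
   */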
  async *streamChat(
    _messages: ChatMessage[],
    options: LLMFullCompletionOptions = {},
  ): AsyncGenerator<ChatMessage, PromptLog> {
    const { completionOptions, log, raw } =
      this._parseCompletionOptions(options);

    const messages = this._compileChatMessages(completionOptions, _messages);

    const prompt = this.templateMessages
      ? this.templateMessages(messages)
      : this._formatChatMessages(messages);

    if (log) {
      if (this.writeLog) {
        await this.writeLog(this._compileLogMessage(prompt, completionOptions));
      }
      if (this.llmRequestHook) {
        this.llmRequestHook(completionOptions.model, prompt);
      }
    }

    let completion = "";

    try {
      if (this.templateMessages) {
        for await (const chunk of this._streamComplete(
          prompt,
          completionOptions,
        )) {
          completion += chunk;
          yield { role: "assistant", content: chunk };
        }
      } else {
        for await (const chunk of this._streamChat(
          messages,
          completionOptions,
        )) {
          completion += chunk.content;
          yield chunk;
        }
      }
    } catch (error) {
      console.log(error);
      throw error;
    }

    this._logTokensGenerated(completionOptions.model, prompt, completion);

    if (log && this.writeLog) {
      await this.writeLog(`Completion:\n\n${completion}\n\n`);
    }

    return {
      prompt,
      completion,
      completionOptions,
    };
  }
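
  /**
   * Subclasses must override `_streamComplete`, `_streamChat`, or supply a
   * `templateMessages` function; everything else in this class is built on
   * top of these. A minimal provider sketch (the endpoint and response
   * handling below are assumptions, not a real API):
   *
   *   class MyProvider extends BaseLLM {
   *     static providerName: ModelProvider = "openai";
   *     protected async *_streamComplete(
   *       prompt: string,
   *       options: CompletionOptions,
   *     ): AsyncGenerator<string> {
   *       const resp = await this.fetch(new URL("completions", this.apiBase), {
   *         method: "POST",
   *         body: JSON.stringify(this.collectArgs(options)),
   *       });
   *       yield await resp.text(); // a real provider would parse a stream
   *     }
   *   }
   */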
  // biome-ignore lint/correctness/useYield: Purposefully not implemented
  protected async *_streamComplete(
    prompt: string,
    options: CompletionOptions,
  ): AsyncGenerator<string> {
    throw new Error("Not implemented");
  }

  protected async *_streamChat(
    messages: ChatMessage[],
    options: CompletionOptions,
  ): AsyncGenerator<ChatMessage> {
    if (!this.templateMessages) {
      throw new Error(
        "You must either implement templateMessages or _streamChat",
      );
    }

    for await (const chunk of this._streamComplete(
      this.templateMessages(messages),
      options,
    )) {
      yield { role: "assistant", content: chunk };
    }
  }

  protected async _complete(prompt: string, options: CompletionOptions) {
    let completion = "";
    for await (const chunk of this._streamComplete(prompt, options)) {
      completion += chunk;
    }
    return completion;
  }

  countTokens(text: string): number {
    return countTokens(text, this.model);
  }

  protected collectArgs(options: CompletionOptions): any {
    return {
      ...DEFAULT_ARGS,
      // model: this.model,
      ...options,
    };
  }
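
  /**
   * Renders a `PromptTemplate`, which is either a Handlebars string (given
   * `history`, a `system_message` pulled from a leading system message, and
   * `otherData`) or a function of the history. Function templates ending in
   * an assistant message are re-rendered through the model's raw completion
   * template when the provider can't prefill the model's response.
   *
   * A sketch with a hypothetical string template:
   *
   *   const prompt = llm.renderPromptTemplate(
   *     "{{{system_message}}}\n\nRewrite this code:\n{{{input}}}",
   *     [],
   *     { input: "const x = 1" },
   *   );
   */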
  public renderPromptTemplate(
    template: PromptTemplate,
    history: ChatMessage[],
    otherData: Record<string, string>,
    canPutWordsInModelsMouth = false,
  ): string | ChatMessage[] {
    if (typeof template === "string") {
      const data: any = {
        history: history,
        ...otherData,
      };
      if (history.length > 0 && history[0].role === "system") {
        data.system_message = history.shift()!.content;
      }
      const compiledTemplate = Handlebars.compile(template);
      return compiledTemplate(data);
    }

    const rendered = template(history, {
      ...otherData,
      supportsCompletions: this.supportsCompletions() ? "true" : "false",
      supportsPrefill: this.supportsPrefill() ? "true" : "false",
    });

    if (
      typeof rendered !== "string" &&
      rendered[rendered.length - 1]?.role === "assistant" &&
      !canPutWordsInModelsMouth
    ) {
      // Some providers don't allow you to put words in the model's mouth
      // So we have to manually compile the prompt template and use
      // raw /completions, not /chat/completions
      const templateMessages = autodetectTemplateFunction(
        this.model,
        this.providerName,
        autodetectTemplateType(this.model),
      );
      return templateMessages(rendered);
    }

    return rendered;
  }
}