continue/core/llm/countTokens.ts

import { Tiktoken, encodingForModel as _encodingForModel } from "js-tiktoken";
import { ChatMessage, MessageContent, MessagePart } from "../index.js";
import {
  AsyncEncoder,
  GPTAsyncEncoder,
  LlamaAsyncEncoder,
} from "./asyncEncoder.js";
import { autodetectTemplateType } from "./autodetect.js";
import { TOKEN_BUFFER_FOR_SAFETY } from "./constants.js";
import { stripImages } from "./images.js";
import llamaTokenizer from "./llamaTokenizer.js";

interface Encoding {
  encode: Tiktoken["encode"];
  decode: Tiktoken["decode"];
}

class LlamaEncoding implements Encoding {
  encode(text: string): number[] {
    return llamaTokenizer.encode(text);
  }

  decode(tokens: number[]): string {
    return llamaTokenizer.decode(tokens);
  }
}

let gptEncoding: Encoding | null = null;
const gptAsyncEncoder = new GPTAsyncEncoder();
const llamaEncoding = new LlamaEncoding();
const llamaAsyncEncoder = new LlamaAsyncEncoder();
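
// Choose the worker-based (async) encoder for a model: the GPT/tiktoken
// encoder for OpenAI-style models, the llama tokenizer for everything else.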
function asyncEncoderForModel(modelName: string): AsyncEncoder {
  const modelType = autodetectTemplateType(modelName);
  if (!modelType || modelType === "none") {
    return gptAsyncEncoder;
  }
  // Temporary due to issues packaging the worker files
  return process.env.IS_BINARY ? gptAsyncEncoder : llamaAsyncEncoder;
}

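// Synchronous counterpart of asyncEncoderForModel: lazily creates the gpt-4
// tiktoken encoding for OpenAI-style models and falls back to the llama
// tokenizer otherwise.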
function encodingForModel(modelName: string): Encoding {
  const modelType = autodetectTemplateType(modelName);
  if (!modelType || modelType === "none") {
    if (!gptEncoding) {
      gptEncoding = _encodingForModel("gpt-4");
    }
    return gptEncoding;
  }
  return llamaEncoding;
}

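// Images are counted at a flat 85 tokens, OpenAI's base cost for a
// low-detail image.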
function countImageTokens(content: MessagePart): number {
  if (content.type === "imageUrl") {
    return 85;
  }
  throw new Error("Non-image content type");
}

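// Count tokens for string or multi-part content without blocking the main
// thread, summing text parts and flat per-image costs.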
async function countTokensAsync(
  content: MessageContent,
  // defaults to llama2 because the tokenizer tends to produce more tokens
  modelName = "llama2",
): Promise<number> {
  const encoding = asyncEncoderForModel(modelName);
  if (Array.isArray(content)) {
    const promises = content.map(async (part) => {
      if (part.type === "imageUrl") {
        return countImageTokens(part);
      }
      return (await encoding.encode(part.text ?? "")).length;
    });
    return (await Promise.all(promises)).reduce((sum, val) => sum + val, 0);
  }
  return (await encoding.encode(content ?? "")).length;
}

function countTokens(
  content: MessageContent,
  // defaults to llama2 because the tokenizer tends to produce more tokens
  modelName = "llama2",
): number {
  const encoding = encodingForModel(modelName);
  if (Array.isArray(content)) {
    return content.reduce((acc, part) => {
      return (
        acc +
        (part.type === "imageUrl"
          ? countImageTokens(part)
          : encoding.encode(part.text ?? "", "all", []).length)
      );
    }, 0);
  } else {
    return encoding.encode(content ?? "", "all", []).length;
  }
}

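// Merge consecutive messages that share a role into a single message, since
// some chat APIs require strictly alternating roles.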
function flattenMessages(msgs: ChatMessage[]): ChatMessage[] {
  const flattened: ChatMessage[] = [];
  for (let i = 0; i < msgs.length; i++) {
    const msg = msgs[i];
    if (
      flattened.length > 0 &&
      flattened[flattened.length - 1].role === msg.role
    ) {
      flattened[flattened.length - 1].content += `\n\n${msg.content || ""}`;
    } else {
      flattened.push(msg);
    }
  }
  return flattened;
}

function countChatMessageTokens(
  modelName: string,
  chatMessage: ChatMessage,
): number {
  // Doing a simpler, safer version of what is here:
  // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
  // every message follows <|im_start|>{role/name}\n{content}<|im_end|>\n
  const TOKENS_PER_MESSAGE: number = 4;
  return countTokens(chatMessage.content, modelName) + TOKENS_PER_MESSAGE;
}

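// Drop whole lines from the top of the prompt until its approximate token
// count fits within maxTokens.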
function pruneLinesFromTop(
  prompt: string,
  maxTokens: number,
  modelName: string,
): string {
  let totalTokens = countTokens(prompt, modelName);
  const lines = prompt.split("\n");
  while (totalTokens > maxTokens && lines.length > 0) {
    totalTokens -= countTokens(lines.shift()!, modelName);
  }
  return lines.join("\n");
}

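// Drop whole lines from the bottom of the prompt until its approximate token
// count fits within maxTokens.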
function pruneLinesFromBottom(
  prompt: string,
  maxTokens: number,
  modelName: string,
): string {
  let totalTokens = countTokens(prompt, modelName);
  const lines = prompt.split("\n");
  while (totalTokens > maxTokens && lines.length > 0) {
    totalTokens -= countTokens(lines.pop()!, modelName);
  }
  return lines.join("\n");
}

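// Token-level truncation: keep at most the first maxTokens tokens, cutting
// off the bottom of the prompt.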
function pruneStringFromBottom(
  modelName: string,
  maxTokens: number,
  prompt: string,
): string {
  const encoding = encodingForModel(modelName);
  const tokens = encoding.encode(prompt, "all", []);
  if (tokens.length <= maxTokens) {
    return prompt;
  }
  return encoding.decode(tokens.slice(0, maxTokens));
}

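// Token-level truncation: keep at most the last maxTokens tokens, cutting
// off the top of the prompt.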
function pruneStringFromTop(
  modelName: string,
  maxTokens: number,
  prompt: string,
): string {
  const encoding = encodingForModel(modelName);
  const tokens = encoding.encode(prompt, "all", []);
  if (tokens.length <= maxTokens) {
    return prompt;
  }
  return encoding.decode(tokens.slice(tokens.length - maxTokens));
}

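// Truncate a raw prompt from the top so that the prompt, the expected
// completion, and TOKEN_BUFFER_FOR_SAFETY together fit in the context length;
// pruneRawPromptFromBottom below is the mirror image.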
function pruneRawPromptFromTop(
  modelName: string,
  contextLength: number,
  prompt: string,
  tokensForCompletion: number,
): string {
  const maxTokens =
    contextLength - tokensForCompletion - TOKEN_BUFFER_FOR_SAFETY;
  return pruneStringFromTop(modelName, maxTokens, prompt);
}

function pruneRawPromptFromBottom(
  modelName: string,
  contextLength: number,
  prompt: string,
  tokensForCompletion: number,
): string {
  const maxTokens =
    contextLength - tokensForCompletion - TOKEN_BUFFER_FOR_SAFETY;
  return pruneStringFromBottom(modelName, maxTokens, prompt);
}

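// Cheap stand-in "summary": the first 100 characters of the message with
// images stripped.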
function summarize(message: MessageContent): string {
  if (Array.isArray(message)) {
    return `${stripImages(message).substring(0, 100)}...`;
  }
  return `${message.substring(0, 100)}...`;
}

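// Shrink the chat history until it (plus tokensForCompletion) fits in the
// context window, using progressively more aggressive steps: first trim
// messages longer than a third of the context, then summarize or drop older
// messages, and finally truncate the most recent message if needed.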
function pruneChatHistory(
  modelName: string,
  chatHistory: ChatMessage[],
  contextLength: number,
  tokensForCompletion: number,
): ChatMessage[] {
  let totalTokens =
    tokensForCompletion +
    chatHistory.reduce((acc, message) => {
      return acc + countChatMessageTokens(modelName, message);
    }, 0);

  // 0. Prune any messages that take up more than 1/3 of the context length
  const longestMessages = [...chatHistory];
  longestMessages.sort((a, b) => b.content.length - a.content.length);

  const longerThanOneThird = longestMessages.filter(
    (message: ChatMessage) =>
      countTokens(message.content, modelName) > contextLength / 3,
  );
  const distanceFromThird = longerThanOneThird.map(
    (message: ChatMessage) =>
      countTokens(message.content, modelName) - contextLength / 3,
  );

  for (let i = 0; i < longerThanOneThird.length; i++) {
    // Prune line-by-line from the top
    const message = longerThanOneThird[i];
    const content = stripImages(message.content);
    const deltaNeeded = totalTokens - contextLength;
    const delta = Math.min(deltaNeeded, distanceFromThird[i]);
    message.content = pruneStringFromTop(
      modelName,
      countTokens(message.content, modelName) - delta,
      content,
    );
    totalTokens -= delta;
  }

  // 1. Replace messages beyond the last 5 with a summary
  let i = 0;
  while (totalTokens > contextLength && i < chatHistory.length - 5) {
    const message = chatHistory[i];
    totalTokens -= countTokens(message.content, modelName);
    totalTokens += countTokens(summarize(message.content), modelName);
    message.content = summarize(message.content);
    i++;
  }

  // 2. Remove entire messages until the last 5
  while (
    chatHistory.length > 5 &&
    totalTokens > contextLength &&
    chatHistory.length > 0
  ) {
    const message = chatHistory.shift()!;
    totalTokens -= countTokens(message.content, modelName);
  }

  // 3. Truncate messages in the last 5, except the last one
  i = 0;
  while (
    totalTokens > contextLength &&
    chatHistory.length > 0 &&
    i < chatHistory.length - 1
  ) {
    const message = chatHistory[i];
    totalTokens -= countTokens(message.content, modelName);
    totalTokens += countTokens(summarize(message.content), modelName);
    message.content = summarize(message.content);
    i++;
  }

  // 4. Remove entire messages in the last 5, except the last one
  while (totalTokens > contextLength && chatHistory.length > 1) {
    const message = chatHistory.shift()!;
    totalTokens -= countTokens(message.content, modelName);
  }

  // 5. Truncate the last message
  if (totalTokens > contextLength && chatHistory.length > 0) {
    const message = chatHistory[0];
    message.content = pruneRawPromptFromTop(
      modelName,
      contextLength,
      stripImages(message.content),
      tokensForCompletion,
    );
    totalTokens = contextLength;
  }

  return chatHistory;
}

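// Build the final message list for a request: copy and filter the input
// messages, append the optional prompt and system message, prune everything
// to fit the context window (reserving maxTokens, function definitions, and a
// safety buffer), move the system message back to the top, and flatten
// consecutive same-role messages.
//
// Illustrative call (the argument values here are made up):
//   compileChatMessages("gpt-4", history, 8192, 1024, false, undefined,
//     undefined, "You are a helpful assistant.");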
function compileChatMessages(
  modelName: string,
  msgs: ChatMessage[] | undefined,
  contextLength: number,
  maxTokens: number,
  supportsImages: boolean,
  prompt: string | undefined = undefined,
  functions: any[] | undefined = undefined,
  systemMessage: string | undefined = undefined,
): ChatMessage[] {
  const msgsCopy = msgs
    ? msgs.map((msg) => ({ ...msg })).filter((msg) => msg.content !== "")
    : [];

  if (prompt) {
    const promptMsg: ChatMessage = {
      role: "user",
      content: prompt,
    };
    msgsCopy.push(promptMsg);
  }

  if (systemMessage && systemMessage.trim() !== "") {
    const systemChatMsg: ChatMessage = {
      role: "system",
      content: systemMessage,
    };
    // Insert as second to last
    // Later moved to top, but want second-priority to last user message
    msgsCopy.splice(-1, 0, systemChatMsg);
  }

  let functionTokens = 0;
  if (functions) {
    for (const func of functions) {
      functionTokens += countTokens(JSON.stringify(func), modelName);
    }
  }

  if (maxTokens + functionTokens + TOKEN_BUFFER_FOR_SAFETY >= contextLength) {
    throw new Error(
      `maxTokens (${maxTokens}) is too close to contextLength (${contextLength}), which doesn't leave room for a response. Try increasing the contextLength parameter of the model in your config.json.`,
    );
  }

  // If images are not supported, convert MessagePart[] to string
  if (!supportsImages) {
    for (const msg of msgsCopy) {
      if ("content" in msg && Array.isArray(msg.content)) {
        const content = stripImages(msg.content);
        msg.content = content;
      }
    }
  }

  const history = pruneChatHistory(
    modelName,
    msgsCopy,
    contextLength,
    functionTokens + maxTokens + TOKEN_BUFFER_FOR_SAFETY,
  );

  if (
    systemMessage &&
    history.length >= 2 &&
    history[history.length - 2].role === "system"
  ) {
    const movedSystemMessage = history.splice(-2, 1)[0];
    history.unshift(movedSystemMessage);
  }

  const flattenedHistory = flattenMessages(history);
  return flattenedHistory;
}

export {
  compileChatMessages,
  countTokens,
  countTokensAsync,
  pruneLinesFromBottom,
  pruneLinesFromTop,
  pruneRawPromptFromTop,
  pruneStringFromBottom,
  pruneStringFromTop,
};