// continue/core/llm/llms/Bedrock.ts
//
// 216 lines
// 6.2 KiB
// TypeScript

import {
BedrockRuntimeClient,
InvokeModelWithResponseStreamCommand,
} from "@aws-sdk/client-bedrock-runtime";
import { fromIni } from "@aws-sdk/credential-providers";
import {
ChatMessage,
CompletionOptions,
LLMOptions,
MessageContent,
ModelProvider,
} from "../../index.js";
import { stripImages } from "../images.js";
import { BaseLLM } from "../index.js";
/**
 * LLM provider that streams completions from Amazon Bedrock using
 * `InvokeModelWithResponseStreamCommand`.
 *
 * Request/response shapes differ per model family, so the actual payload
 * construction and chunk decoding are delegated to a `BedrockModelToolkit`
 * selected from the model id (see `_getToolkit`).
 */
class Bedrock extends BaseLLM {
  /** Named AWS profile tried first when loading credentials. */
  private static PROFILE_NAME: string = "bedrock";
  static providerName: ModelProvider = "bedrock";
  static defaultOptions: Partial<LLMOptions> = {
    region: "us-east-1",
    model: "claude-3-sonnet-20240229",
    contextLength: 200_000,
  };

  /**
   * @param options Provider options. When `apiBase` is not supplied it is
   *   derived from the region.
   */
  constructor(options: LLMOptions) {
    super(options);
    if (!options.apiBase) {
      // `options.region` may be omitted by the caller; fall back to the
      // instance's region and finally the class default so the URL never
      // contains the literal string "undefined".
      const region =
        options.region ?? this.region ?? Bedrock.defaultOptions.region;
      this.apiBase = `https://bedrock-runtime.${region}.amazonaws.com`;
    }
  }

  /**
   * Streams a plain-text completion by wrapping the prompt in a single user
   * message and forwarding to `_streamChat`.
   *
   * @param prompt Text sent as the only user message.
   * @param options Completion options passed through unchanged.
   * @returns Text chunks, with image parts removed by `stripImages`.
   */
  protected async *_streamComplete(
    prompt: string,
    options: CompletionOptions,
  ): AsyncGenerator<string> {
    const messages = [{ role: "user" as const, content: prompt }];
    for await (const update of this._streamChat(messages, options)) {
      yield stripImages(update.content);
    }
  }

  /**
   * Sends the conversation to Bedrock and yields assistant messages as the
   * response stream produces text.
   *
   * @param messages Conversation to send.
   * @param options Completion options; `options.model` selects the toolkit.
   * @returns One assistant `ChatMessage` per non-empty text chunk.
   * @throws Error when the model is not supported (from `_getToolkit`), plus
   *   whatever the AWS SDK throws for credential or request failures.
   */
  protected async *_streamChat(
    messages: ChatMessage[],
    options: CompletionOptions,
  ): AsyncGenerator<ChatMessage> {
    const credentials = await this._getCredentials();
    const client = new BedrockRuntimeClient({
      region: this.region,
      credentials: {
        accessKeyId: credentials.accessKeyId,
        secretAccessKey: credentials.secretAccessKey,
        sessionToken: credentials.sessionToken || "",
      },
    });

    try {
      const toolkit = this._getToolkit(options.model);
      const command = toolkit.generateCommand(messages, options);
      const response = await client.send(command);
      if (!response.body) {
        return;
      }
      for await (const value of response.body) {
        const text = toolkit.unwrapResponseChunk(value);
        // Chunks that carry no text are skipped.
        if (text) {
          yield { role: "assistant", content: text };
        }
      }
    } finally {
      // A client is created per call; release its underlying resources once
      // the stream is finished, fails, or the consumer stops iterating.
      client.destroy();
    }
  }

  /**
   * Loads AWS credentials from the shared ini files, preferring the
   * `PROFILE_NAME` profile and falling back to the default profile.
   *
   * Any failure loading the named profile triggers the fallback; a failure
   * loading the default profile propagates to the caller.
   */
  private async _getCredentials() {
    try {
      return await fromIni({
        profile: Bedrock.PROFILE_NAME,
      })();
    } catch (e) {
      console.warn(
        `AWS profile with name ${Bedrock.PROFILE_NAME} not found in ~/.aws/credentials, using default profile`,
      );
      return await fromIni()();
    }
  }

  /**
   * Chooses the payload/response adapter for a model id by substring match.
   *
   * @param model Bedrock model id.
   * @throws Error when the id contains neither "claude-3" nor "llama".
   */
  private _getToolkit(model: string): BedrockModelToolkit {
    if (model.includes("claude-3")) {
      return new AnthropicClaude3Toolkit(this);
    }
    if (model.includes("llama")) {
      return new Llama3Toolkit(this);
    }
    throw new Error(
      `Model ${model} is currently not supported in Continue for Bedrock`,
    );
  }
}
/**
 * Adapter for one model family hosted on Bedrock. Implementations in this
 * file build the streaming invoke command for a conversation and extract
 * text from each raw item of the response stream.
 */
interface BedrockModelToolkit {
/**
 * Builds the streaming invoke command for the given conversation.
 *
 * @param messages Conversation to encode into the request body.
 * @param options Completion options (model id, sampling parameters, ...).
 */
generateCommand(
messages: ChatMessage[],
options: CompletionOptions,
): InvokeModelWithResponseStreamCommand;
/**
 * Extracts the text carried by one raw item of the response stream.
 *
 * @param rawValue One item yielded while iterating the response body.
 * @returns The text of the chunk; callers skip falsy results.
 */
unwrapResponseChunk(rawValue: any): string;
}
/**
 * Toolkit for Anthropic Claude 3 models on Bedrock: builds the
 * "bedrock-2023-05-31" messages payload and decodes streamed text deltas.
 */
class AnthropicClaude3Toolkit implements BedrockModelToolkit {
  /** Media type used when an image URL does not declare one. */
  private static readonly DEFAULT_IMAGE_MEDIA_TYPE = "image/jpeg";

  /**
   * @param bedrock Owning provider; its `systemMessage` becomes the request's
   *   `system` field.
   */
  constructor(private bedrock: Bedrock) {}

  /**
   * Builds the streaming invoke command for a Claude 3 conversation.
   *
   * @param messages Conversation; messages with role "system" are dropped
   *   from the `messages` array.
   * @param options Completion options mapped onto the payload fields.
   * @returns Command whose body is the UTF-8 encoded JSON payload.
   */
  generateCommand(
    messages: ChatMessage[],
    options: CompletionOptions,
  ): InvokeModelWithResponseStreamCommand {
    const payload = {
      anthropic_version: "bedrock-2023-05-31",
      max_tokens: options.maxTokens,
      system: this.bedrock.systemMessage,
      messages: this._convertMessages(messages),
      temperature: options.temperature,
      top_p: options.topP,
      top_k: options.topK,
      // each stop sequence must contain non-whitespace
      stop_sequences: options.stop?.filter((stop) => stop.trim() !== ""),
    };
    return new InvokeModelWithResponseStreamCommand({
      body: new TextEncoder().encode(JSON.stringify(payload)),
      contentType: "application/json",
      modelId: options.model,
    });
  }

  /** Converts every non-system message to the payload's message shape. */
  private _convertMessages(msgs: ChatMessage[]): any[] {
    return msgs
      .filter((m) => m.role !== "system")
      .map((message) => this._convertMessage(message));
  }

  /** Converts one message, keeping its role and translating its content. */
  private _convertMessage(message: ChatMessage): any {
    return {
      role: message.role,
      content: this._convertMessageContent(message.content),
    };
  }

  /**
   * Translates message content into the payload's content shape.
   *
   * Strings and text parts pass through unchanged. Any other part is emitted
   * as a base64 image block whose data is the portion of the image URL after
   * the first comma and whose media type is read from the data URL prefix.
   */
  private _convertMessageContent(messageContent: MessageContent): any {
    if (typeof messageContent === "string") {
      return messageContent;
    }
    return messageContent.map((part) => {
      if (part.type === "text") {
        return part;
      }
      const url = part.imageUrl?.url;
      return {
        type: "image",
        source: {
          type: "base64",
          media_type: this._mediaTypeFromDataUrl(url),
          data: url?.split(",")[1],
        },
      };
    });
  }

  /**
   * Reads the media type from a `data:<type>;base64,...` URL.
   *
   * @returns The declared media type, or "image/jpeg" when the URL is missing
   *   or does not start with a `data:<type>` prefix.
   */
  private _mediaTypeFromDataUrl(url: string | undefined): string {
    const declared = url ? /^data:([^;,]+)/.exec(url) : null;
    return declared?.[1] ?? AnthropicClaude3Toolkit.DEFAULT_IMAGE_MEDIA_TYPE;
  }

  /**
   * Extracts the text delta from one raw item of the response stream.
   *
   * @param rawValue Stream item; text is expected as JSON in `chunk.bytes`.
   * @returns `delta.text` from the decoded JSON, or "" when the item has no
   *   bytes or the decoded event carries no text delta.
   * @throws SyntaxError when non-empty bytes are not valid JSON.
   */
  unwrapResponseChunk(rawValue: any): string {
    const binaryChunk = rawValue?.chunk?.bytes;
    // Decoding missing/empty bytes yields "", which JSON.parse rejects; treat
    // such items as "no text" instead of crashing the stream.
    if (!binaryChunk || binaryChunk.length === 0) {
      return "";
    }
    const textChunk = new TextDecoder().decode(binaryChunk);
    return JSON.parse(textChunk).delta?.text ?? "";
  }
}
/**
 * Toolkit for Meta Llama 3 models on Bedrock: renders the conversation into
 * a single Llama 3 prompt string and decodes streamed `generation` text.
 */
class Llama3Toolkit implements BedrockModelToolkit {
  /**
   * @param bedrock Owning provider; its `systemMessage`, when set, is emitted
   *   as the first turn of the prompt.
   */
  constructor(private bedrock: Bedrock) {}

  /**
   * Builds the streaming invoke command for a Llama 3 conversation.
   *
   * @param messages Conversation rendered into the prompt.
   * @param options Completion options; `maxTokens`, `temperature` and `topP`
   *   map to `max_gen_len`, `temperature` and `top_p`.
   * @returns Command whose body is the UTF-8 encoded JSON payload.
   */
  generateCommand(
    messages: ChatMessage[],
    options: CompletionOptions,
  ): InvokeModelWithResponseStreamCommand {
    const payload = {
      prompt: this._buildPrompt(messages),
      max_gen_len: options.maxTokens,
      temperature: options.temperature,
      top_p: options.topP,
    };
    return new InvokeModelWithResponseStreamCommand({
      body: new TextEncoder().encode(JSON.stringify(payload)),
      contentType: "application/json",
      modelId: options.model,
    });
  }

  /**
   * Renders the system message and conversation as a Llama 3 prompt that ends
   * with an open assistant header. Messages whose text is empty are omitted.
   */
  private _buildPrompt(messages: ChatMessage[]): string {
    // The Llama 3 template puts a blank line between a role header and the
    // turn's content.
    const header = (role: string) =>
      `<|start_header_id|>${role}<|end_header_id|>\n\n`;

    let prompt = "<|begin_of_text|>";
    if (this.bedrock.systemMessage) {
      prompt += `${header("system")}${this.bedrock.systemMessage}<|eot_id|>`;
    }
    for (const message of messages) {
      const content = this._textOf(message.content);
      if (content) {
        prompt += `${header(message.role)}${content}<|eot_id|>`;
      }
    }
    return prompt + header("assistant");
  }

  /**
   * Concatenates the text parts of a message; non-text parts are skipped
   * with a console warning.
   */
  private _textOf(content: MessageContent): string {
    if (typeof content === "string") {
      return content;
    }
    let text = "";
    for (const part of content) {
      if (part.type === "text") {
        text += part.text || "";
      } else {
        console.warn("Skipping non-text message part", part);
      }
    }
    return text;
  }

  /**
   * Extracts generated text from one raw item of the response stream.
   *
   * @param rawValue Stream item; text is expected as JSON in `chunk.bytes`.
   * @returns The `generation` field of the decoded JSON, or "" when the item
   *   has no bytes or the field is absent.
   * @throws SyntaxError when non-empty bytes are not valid JSON.
   */
  unwrapResponseChunk(rawValue: any): string {
    const binaryChunk = rawValue?.chunk?.bytes;
    // Decoding missing/empty bytes yields "", which JSON.parse rejects; treat
    // such items as "no text" instead of crashing the stream.
    if (!binaryChunk || binaryChunk.length === 0) {
      return "";
    }
    const textChunk = new TextDecoder().decode(binaryChunk);
    return JSON.parse(textChunk).generation ?? "";
  }
}
// The Bedrock provider class is this module's default export.
export default Bedrock;