feat/bedrock-prompt-caching

This commit is contained in:
chezsmithy 2025-04-02 00:38:56 -07:00
parent d6b9dcb4d3
commit d04b7eecb9
7 changed files with 671 additions and 602 deletions

View File

@ -44,9 +44,9 @@
"version": "1.1.0",
"license": "Apache-2.0",
"dependencies": {
"@aws-sdk/client-bedrock-runtime": "^3.758.0",
"@aws-sdk/client-sagemaker-runtime": "^3.758.0",
"@aws-sdk/credential-providers": "^3.758.0",
"@aws-sdk/client-bedrock-runtime": "^3.779.0",
"@aws-sdk/client-sagemaker-runtime": "^3.777.0",
"@aws-sdk/credential-providers": "^3.778.0",
"@continuedev/config-types": "^1.0.13",
"@continuedev/config-yaml": "^1.0.67",
"@continuedev/fetch": "^1.0.4",

View File

@ -2,6 +2,7 @@ import {
BedrockRuntimeClient,
ContentBlock,
ConverseStreamCommand,
ConverseStreamCommandOutput,
InvokeModelCommand,
Message
} from "@aws-sdk/client-bedrock-runtime";
@ -31,6 +32,14 @@ interface ToolUseState {
input: string;
}
/**
 * Token counters associated with prompt caching.
 *
 * Both fields are plain numeric counts.
 */
interface PromptCachingMetrics {
  /** Presumably the input tokens written into the cache — TODO confirm against the usage metadata. */
  cacheWriteInputTokens: number;
  /** Presumably the input tokens served from the cache — TODO confirm against the usage metadata. */
  cacheReadInputTokens: number;
}
class Bedrock extends BaseLLM {
static providerName = "bedrock";
static defaultOptions: Partial<LLMOptions> = {
@ -41,6 +50,10 @@ class Bedrock extends BaseLLM {
};
private _currentToolResponse: Partial<ToolUseState> | null = null;
private _promptCachingMetrics: PromptCachingMetrics = {
cacheReadInputTokens: 0,
cacheWriteInputTokens: 0
};
public requestOptions: { region?: string; credentials?: any; headers?: Record<string, string> };
@ -111,9 +124,9 @@ class Bedrock extends BaseLLM {
const input = this._generateConverseInput(messages, { ...options, stream: true });
const command = new ConverseStreamCommand(input);
let response;
let response: ConverseStreamCommandOutput;
try {
response = await client.send(command, { abortSignal: signal });
response = await client.send(command, { abortSignal: signal }) as ConverseStreamCommandOutput;
} catch (error: unknown) {
console.error(error);
const message = error instanceof Error ? error.message : "Unknown error";
@ -124,8 +137,17 @@ class Bedrock extends BaseLLM {
throw new Error("No stream received from Bedrock API");
}
// Reset cache metrics for new request
this._promptCachingMetrics = {
cacheReadInputTokens: 0,
cacheWriteInputTokens: 0
};
try {
for await (const chunk of response.stream) {
if (chunk.metadata?.usage) {
console.log(`${JSON.stringify(chunk.metadata.usage)}`);
}
if (chunk.contentBlockDelta?.delta) {
@ -228,11 +250,30 @@ class Bedrock extends BaseLLM {
): any {
const convertedMessages = this._convertMessages(messages);
const shouldCacheSystemMessage =
!!this.systemMessage && this.cacheBehavior?.cacheSystemMessage;
const enablePromptCaching = shouldCacheSystemMessage || this.cacheBehavior?.cacheConversation;
// Add header for prompt caching
if (enablePromptCaching) {
this.requestOptions.headers = {
...this.requestOptions.headers,
"x-amzn-bedrock-enablepromptcaching": "true"
};
}
const supportsTools = PROVIDER_TOOL_SUPPORT.bedrock?.(options.model || "") ?? false;
return {
modelId: options.model,
messages: convertedMessages,
system: this.systemMessage ? [{ text: this.systemMessage }] : undefined,
system: this.systemMessage ? (
shouldCacheSystemMessage ?
[
{ text: this.systemMessage },
{ cachePoint: { type: "default" } }
] :
[{ text: this.systemMessage }]
) : undefined,
toolConfig: supportsTools && options.tools ? {
tools: options.tools.map(tool => ({
toolSpec: {
@ -269,12 +310,14 @@ class Bedrock extends BaseLLM {
};
}
private _convertMessage(message: ChatMessage): Message | null {
private _convertMessage(message: ChatMessage, addCaching: boolean = false): Message | null {
// Handle system messages explicitly
if (message.role === "system") {
return null;
}
const cachePoint = addCaching ? { cachePoint: { type: "default" } } : undefined;
// Tool response handling
if (message.role === "tool") {
return {
@ -338,58 +381,82 @@ class Bedrock extends BaseLLM {
// Standard text message
if (typeof message.content === "string") {
const content: any[] = [{ text: message.content }];
if (addCaching) {
content.push({ cachePoint: { type: "default" } });
}
return {
role: message.role,
content: [{ text: message.content }]
content
};
}
// Improved multimodal content handling
if (Array.isArray(message.content)) {
const content: any[] = [];
// Process all parts first
message.content.forEach(part => {
if (part.type === "text") {
content.push({ text: part.text });
} else if (part.type === "imageUrl" && part.imageUrl) {
try {
const [mimeType, base64Data] = part.imageUrl.url.split(",");
const format = mimeType.split("/")[1]?.split(";")[0] || "jpeg";
content.push({
image: {
format,
source: {
bytes: Buffer.from(base64Data, "base64")
}
}
});
} catch (error) {
console.warn(`Failed to process image: ${error}`);
}
}
});
// Add cache point as a separate block at the end if needed
if (addCaching && content.length > 0) {
content.push({ cachePoint: { type: "default" } });
}
return {
role: message.role,
content: message.content.map(part => {
if (part.type === "text") {
return { text: part.text };
}
if (part.type === "imageUrl" && part.imageUrl) {
try {
const [mimeType, base64Data] = part.imageUrl.url.split(",");
const format = mimeType.split("/")[1]?.split(";")[0] || "jpeg";
return {
image: {
format,
source: {
bytes: Buffer.from(base64Data, "base64")
}
}
};
} catch (error) {
console.warn(`Failed to process image: ${error}`);
return null;
}
}
return null;
}).filter(Boolean)
content
} as Message;
}
return null;
}
private _convertMessages(messages: ChatMessage[]): any[] {
const filteredmessages = messages.filter(
(m) => m.role !== "system" && !!m.content,
);
const lastTwoUserMsgIndices = filteredmessages
.map((msg, index) => (msg.role === "user" ? index : -1))
.filter((index) => index !== -1)
.slice(-2);
const converted = filteredmessages.map((message, filteredMsgIdx) => {
// Add a cachePoint block to the last two user messages:
// the second-to-last so that potentially already cached content can be retrieved,
// and the last so that the conversation so far is cached for later retrieval.
// See: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
const addCaching =
this.cacheBehavior?.cacheConversation &&
lastTwoUserMsgIndices.includes(filteredMsgIdx);
try {
return this._convertMessage(message, addCaching);
} catch (error) {
console.error(`Failed to convert message: ${error}`);
return null;
}
}).filter(Boolean);
const converted = messages
.map((message) => {
try {
return this._convertMessage(message);
} catch (error) {
console.error(`Failed to convert message: ${error}`);
return null;
}
})
.filter(Boolean);
return converted;
}
private async _getCredentials() {

1100
core/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@ -43,9 +43,9 @@
"typescript": "^5.6.3"
},
"dependencies": {
"@aws-sdk/client-bedrock-runtime": "^3.758.0",
"@aws-sdk/client-sagemaker-runtime": "^3.758.0",
"@aws-sdk/credential-providers": "^3.758.0",
"@aws-sdk/client-bedrock-runtime": "^3.779.0",
"@aws-sdk/client-sagemaker-runtime": "^3.777.0",
"@aws-sdk/credential-providers": "^3.778.0",
"@continuedev/config-types": "^1.0.13",
"@continuedev/config-yaml": "^1.0.67",
"@continuedev/fetch": "^1.0.4",

View File

@ -102,9 +102,9 @@
"version": "1.1.0",
"license": "Apache-2.0",
"dependencies": {
"@aws-sdk/client-bedrock-runtime": "^3.758.0",
"@aws-sdk/client-sagemaker-runtime": "^3.758.0",
"@aws-sdk/credential-providers": "^3.758.0",
"@aws-sdk/client-bedrock-runtime": "^3.779.0",
"@aws-sdk/client-sagemaker-runtime": "^3.777.0",
"@aws-sdk/credential-providers": "^3.778.0",
"@continuedev/config-types": "^1.0.13",
"@continuedev/config-yaml": "^1.0.67",
"@continuedev/fetch": "^1.0.4",

6
gui/package-lock.json generated
View File

@ -107,9 +107,9 @@
"version": "1.1.0",
"license": "Apache-2.0",
"dependencies": {
"@aws-sdk/client-bedrock-runtime": "^3.758.0",
"@aws-sdk/client-sagemaker-runtime": "^3.758.0",
"@aws-sdk/credential-providers": "^3.758.0",
"@aws-sdk/client-bedrock-runtime": "^3.779.0",
"@aws-sdk/client-sagemaker-runtime": "^3.777.0",
"@aws-sdk/credential-providers": "^3.778.0",
"@continuedev/config-types": "^1.0.13",
"@continuedev/config-yaml": "^1.0.67",
"@continuedev/fetch": "^1.0.4",

4
package-lock.json generated
View File

@ -8,7 +8,8 @@
"@typescript-eslint/parser": "^7.8.0",
"eslint-plugin-import": "^2.29.1",
"prettier": "^3.3.3",
"prettier-plugin-tailwindcss": "^0.6.8"
"prettier-plugin-tailwindcss": "^0.6.8",
"typescript": "^5.6.3"
}
},
"node_modules/@eslint-community/eslint-utils": {
@ -2976,7 +2977,6 @@
"resolved": "https://registry.npmjs.org/typescript/-/typescript-5.6.3.tgz",
"integrity": "sha512-hjcS1mhfuyi4WW8IWtjP7brDrG2cuDZukyrYrSauoXGNgx0S7zceP07adYkJycEr56BOUTNPzbInooiN3fn1qw==",
"dev": true,
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"