feat/bedrock-prompt-caching
This commit is contained in:
parent
d6b9dcb4d3
commit
d04b7eecb9
|
@ -44,9 +44,9 @@
|
|||
"version": "1.1.0",
|
||||
"license": "Apache-2.0",
|
||||
"dependencies": {
|
||||
"@aws-sdk/client-bedrock-runtime": "^3.758.0",
|
||||
"@aws-sdk/client-sagemaker-runtime": "^3.758.0",
|
||||
"@aws-sdk/credential-providers": "^3.758.0",
|
||||
"@aws-sdk/client-bedrock-runtime": "^3.779.0",
|
||||
"@aws-sdk/client-sagemaker-runtime": "^3.777.0",
|
||||
"@aws-sdk/credential-providers": "^3.778.0",
|
||||
"@continuedev/config-types": "^1.0.13",
|
||||
"@continuedev/config-yaml": "^1.0.67",
|
||||
"@continuedev/fetch": "^1.0.4",
|
||||
|
|
|
@ -2,6 +2,7 @@ import {
|
|||
BedrockRuntimeClient,
|
||||
ContentBlock,
|
||||
ConverseStreamCommand,
|
||||
ConverseStreamCommandOutput,
|
||||
InvokeModelCommand,
|
||||
Message
|
||||
} from "@aws-sdk/client-bedrock-runtime";
|
||||
|
@ -31,6 +32,14 @@ interface ToolUseState {
|
|||
input: string;
|
||||
}
|
||||
|
||||
/**
 * Token counters for Bedrock prompt caching.
 *
 * The class keeps one instance of this shape and resets both counters to 0
 * at the start of each streamed request.
 */
interface PromptCachingMetrics {
  /**
   * Input tokens read from the prompt cache.
   * NOTE(review): presumably mirrors the same-named field of Bedrock's usage
   * metadata — confirm where this is populated.
   */
  cacheReadInputTokens: number;
  /**
   * Input tokens written to the prompt cache.
   * NOTE(review): presumably mirrors the same-named field of Bedrock's usage
   * metadata — confirm where this is populated.
   */
  cacheWriteInputTokens: number;
}
|
||||
|
||||
class Bedrock extends BaseLLM {
|
||||
static providerName = "bedrock";
|
||||
static defaultOptions: Partial<LLMOptions> = {
|
||||
|
@ -41,6 +50,10 @@ class Bedrock extends BaseLLM {
|
|||
};
|
||||
|
||||
private _currentToolResponse: Partial<ToolUseState> | null = null;
|
||||
private _promptCachingMetrics: PromptCachingMetrics = {
|
||||
cacheReadInputTokens: 0,
|
||||
cacheWriteInputTokens: 0
|
||||
};
|
||||
|
||||
public requestOptions: { region?: string; credentials?: any; headers?: Record<string, string> };
|
||||
|
||||
|
@ -111,9 +124,9 @@ class Bedrock extends BaseLLM {
|
|||
const input = this._generateConverseInput(messages, { ...options, stream: true });
|
||||
const command = new ConverseStreamCommand(input);
|
||||
|
||||
let response;
|
||||
let response: ConverseStreamCommandOutput;
|
||||
try {
|
||||
response = await client.send(command, { abortSignal: signal });
|
||||
response = await client.send(command, { abortSignal: signal }) as ConverseStreamCommandOutput;
|
||||
} catch (error: unknown) {
|
||||
console.error(error);
|
||||
const message = error instanceof Error ? error.message : "Unknown error";
|
||||
|
@ -124,8 +137,17 @@ class Bedrock extends BaseLLM {
|
|||
throw new Error("No stream received from Bedrock API");
|
||||
}
|
||||
|
||||
// Reset cache metrics for new request
|
||||
this._promptCachingMetrics = {
|
||||
cacheReadInputTokens: 0,
|
||||
cacheWriteInputTokens: 0
|
||||
};
|
||||
|
||||
try {
|
||||
for await (const chunk of response.stream) {
|
||||
if (chunk.metadata?.usage) {
|
||||
console.log(`${JSON.stringify(chunk.metadata.usage)}`);
|
||||
}
|
||||
|
||||
if (chunk.contentBlockDelta?.delta) {
|
||||
|
||||
|
@ -228,11 +250,30 @@ class Bedrock extends BaseLLM {
|
|||
): any {
|
||||
const convertedMessages = this._convertMessages(messages);
|
||||
|
||||
const shouldCacheSystemMessage =
|
||||
!!this.systemMessage && this.cacheBehavior?.cacheSystemMessage;
|
||||
const enablePromptCaching = shouldCacheSystemMessage || this.cacheBehavior?.cacheConversation;
|
||||
|
||||
// Add header for prompt caching
|
||||
if (enablePromptCaching) {
|
||||
this.requestOptions.headers = {
|
||||
...this.requestOptions.headers,
|
||||
"x-amzn-bedrock-enablepromptcaching": "true"
|
||||
};
|
||||
}
|
||||
|
||||
const supportsTools = PROVIDER_TOOL_SUPPORT.bedrock?.(options.model || "") ?? false;
|
||||
return {
|
||||
modelId: options.model,
|
||||
messages: convertedMessages,
|
||||
system: this.systemMessage ? [{ text: this.systemMessage }] : undefined,
|
||||
system: this.systemMessage ? (
|
||||
shouldCacheSystemMessage ?
|
||||
[
|
||||
{ text: this.systemMessage },
|
||||
{ cachePoint: { type: "default" } }
|
||||
] :
|
||||
[{ text: this.systemMessage }]
|
||||
) : undefined,
|
||||
toolConfig: supportsTools && options.tools ? {
|
||||
tools: options.tools.map(tool => ({
|
||||
toolSpec: {
|
||||
|
@ -269,12 +310,14 @@ class Bedrock extends BaseLLM {
|
|||
};
|
||||
}
|
||||
|
||||
private _convertMessage(message: ChatMessage): Message | null {
|
||||
private _convertMessage(message: ChatMessage, addCaching: boolean = false): Message | null {
|
||||
// Handle system messages explicitly
|
||||
if (message.role === "system") {
|
||||
return null;
|
||||
}
|
||||
|
||||
const cachePoint = addCaching ? { cachePoint: { type: "default" } } : undefined;
|
||||
|
||||
// Tool response handling
|
||||
if (message.role === "tool") {
|
||||
return {
|
||||
|
@ -338,58 +381,82 @@ class Bedrock extends BaseLLM {
|
|||
|
||||
// Standard text message
|
||||
if (typeof message.content === "string") {
|
||||
const content: any[] = [{ text: message.content }];
|
||||
if (addCaching) {
|
||||
content.push({ cachePoint: { type: "default" } });
|
||||
}
|
||||
return {
|
||||
role: message.role,
|
||||
content: [{ text: message.content }]
|
||||
content
|
||||
};
|
||||
}
|
||||
|
||||
// Improved multimodal content handling
|
||||
if (Array.isArray(message.content)) {
|
||||
const content: any[] = [];
|
||||
|
||||
// Process all parts first
|
||||
message.content.forEach(part => {
|
||||
if (part.type === "text") {
|
||||
content.push({ text: part.text });
|
||||
} else if (part.type === "imageUrl" && part.imageUrl) {
|
||||
try {
|
||||
const [mimeType, base64Data] = part.imageUrl.url.split(",");
|
||||
const format = mimeType.split("/")[1]?.split(";")[0] || "jpeg";
|
||||
content.push({
|
||||
image: {
|
||||
format,
|
||||
source: {
|
||||
bytes: Buffer.from(base64Data, "base64")
|
||||
}
|
||||
}
|
||||
});
|
||||
} catch (error) {
|
||||
console.warn(`Failed to process image: ${error}`);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Add cache point as a separate block at the end if needed
|
||||
if (addCaching && content.length > 0) {
|
||||
content.push({ cachePoint: { type: "default" } });
|
||||
}
|
||||
|
||||
return {
|
||||
role: message.role,
|
||||
content: message.content.map(part => {
|
||||
if (part.type === "text") {
|
||||
return { text: part.text };
|
||||
}
|
||||
if (part.type === "imageUrl" && part.imageUrl) {
|
||||
try {
|
||||
const [mimeType, base64Data] = part.imageUrl.url.split(",");
|
||||
const format = mimeType.split("/")[1]?.split(";")[0] || "jpeg";
|
||||
return {
|
||||
image: {
|
||||
format,
|
||||
source: {
|
||||
bytes: Buffer.from(base64Data, "base64")
|
||||
}
|
||||
}
|
||||
};
|
||||
} catch (error) {
|
||||
console.warn(`Failed to process image: ${error}`);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}).filter(Boolean)
|
||||
content
|
||||
} as Message;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
 * Converts chat messages into Bedrock Converse API messages.
 *
 * System messages and messages with falsy content are dropped before conversion.
 * When `cacheBehavior.cacheConversation` is enabled, the last two user messages
 * are converted with `addCaching` set, so `_convertMessage` appends a
 * `cachePoint` block to each of them. Messages that throw during conversion, or
 * that convert to `null`, are omitted from the result.
 *
 * @param messages - The chat history to convert.
 * @returns The converted messages, in their original relative order.
 */
private _convertMessages(messages: ChatMessage[]): any[] {
  // NOTE(review): `!!m.content` also drops messages whose content is an empty
  // string — confirm no message kind (e.g. an assistant tool call) relies on that.
  const filteredMessages = messages.filter(
    (m) => m.role !== "system" && !!m.content,
  );

  // Coerce once so `addCaching` below is always a strict boolean.
  const cacheConversation = Boolean(this.cacheBehavior?.cacheConversation);

  // Mark the last two user messages with a cache point:
  // the second-to-last so already cached content can be read back,
  // the last so the conversation up to it is cached for later retrieval.
  // See: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
  const cachedUserMsgIndices = cacheConversation
    ? filteredMessages
        .map((msg, index) => (msg.role === "user" ? index : -1))
        .filter((index) => index !== -1)
        .slice(-2)
    : [];

  return filteredMessages
    .map((message, index) => {
      const addCaching = cachedUserMsgIndices.includes(index);
      try {
        return this._convertMessage(message, addCaching);
      } catch (error) {
        // Best effort: one bad message must not abort the whole request.
        console.error(`Failed to convert message: ${error}`);
        return null;
      }
    })
    .filter(Boolean);
}
|
||||
|
||||
private async _getCredentials() {
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -43,9 +43,9 @@
|
|||
"typescript": "^5.6.3"
|
||||
},
|
||||
"dependencies": {
|
||||
"@aws-sdk/client-bedrock-runtime": "^3.758.0",
|
||||
"@aws-sdk/client-sagemaker-runtime": "^3.758.0",
|
||||
"@aws-sdk/credential-providers": "^3.758.0",
|
||||
"@aws-sdk/client-bedrock-runtime": "^3.779.0",
|
||||
"@aws-sdk/client-sagemaker-runtime": "^3.777.0",
|
||||
"@aws-sdk/credential-providers": "^3.778.0",
|
||||
"@continuedev/config-types": "^1.0.13",
|
||||
"@continuedev/config-yaml": "^1.0.67",
|
||||
"@continuedev/fetch": "^1.0.4",
|
||||
|
|
|
@ -102,9 +102,9 @@
|
|||
"version": "1.1.0",
|
||||
"license": "Apache-2.0",
|
||||
"dependencies": {
|
||||
"@aws-sdk/client-bedrock-runtime": "^3.758.0",
|
||||
"@aws-sdk/client-sagemaker-runtime": "^3.758.0",
|
||||
"@aws-sdk/credential-providers": "^3.758.0",
|
||||
"@aws-sdk/client-bedrock-runtime": "^3.779.0",
|
||||
"@aws-sdk/client-sagemaker-runtime": "^3.777.0",
|
||||
"@aws-sdk/credential-providers": "^3.778.0",
|
||||
"@continuedev/config-types": "^1.0.13",
|
||||
"@continuedev/config-yaml": "^1.0.67",
|
||||
"@continuedev/fetch": "^1.0.4",
|
||||
|
|
|
@ -107,9 +107,9 @@
|
|||
"version": "1.1.0",
|
||||
"license": "Apache-2.0",
|
||||
"dependencies": {
|
||||
"@aws-sdk/client-bedrock-runtime": "^3.758.0",
|
||||
"@aws-sdk/client-sagemaker-runtime": "^3.758.0",
|
||||
"@aws-sdk/credential-providers": "^3.758.0",
|
||||
"@aws-sdk/client-bedrock-runtime": "^3.779.0",
|
||||
"@aws-sdk/client-sagemaker-runtime": "^3.777.0",
|
||||
"@aws-sdk/credential-providers": "^3.778.0",
|
||||
"@continuedev/config-types": "^1.0.13",
|
||||
"@continuedev/config-yaml": "^1.0.67",
|
||||
"@continuedev/fetch": "^1.0.4",
|
||||
|
|
|
@ -8,7 +8,8 @@
|
|||
"@typescript-eslint/parser": "^7.8.0",
|
||||
"eslint-plugin-import": "^2.29.1",
|
||||
"prettier": "^3.3.3",
|
||||
"prettier-plugin-tailwindcss": "^0.6.8"
|
||||
"prettier-plugin-tailwindcss": "^0.6.8",
|
||||
"typescript": "^5.6.3"
|
||||
}
|
||||
},
|
||||
"node_modules/@eslint-community/eslint-utils": {
|
||||
|
@ -2976,7 +2977,6 @@
|
|||
"resolved": "https://registry.npmjs.org/typescript/-/typescript-5.6.3.tgz",
|
||||
"integrity": "sha512-hjcS1mhfuyi4WW8IWtjP7brDrG2cuDZukyrYrSauoXGNgx0S7zceP07adYkJycEr56BOUTNPzbInooiN3fn1qw==",
|
||||
"dev": true,
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
|
|
Loading…
Reference in New Issue