continue/core/indexing/embeddings/HuggingFaceTEIEmbeddingsPro...

107 lines
3.0 KiB
TypeScript

import { Response } from "node-fetch";
import { EmbeddingsProviderName, EmbedOptions, FetchFunction } from "../..";
import { withExponentialBackoff } from "../../util/withExponentialBackoff";
import BaseEmbeddingsProvider from "./BaseEmbeddingsProvider";
/**
 * Embeddings provider that talks to a HuggingFace Text Embeddings Inference
 * (TEI) server over HTTP.
 *
 * On construction it normalizes `options.apiBase` and fires a best-effort
 * `info` request; when that request succeeds, the server-reported model id
 * and client batch size replace the defaults. Until then (or if the request
 * fails) the defaults declared on this class are used.
 */
class HuggingFaceTEIEmbeddingsProvider extends BaseEmbeddingsProvider {
  static providerName: EmbeddingsProviderName = "huggingface-tei";

  /**
   * Maximum number of chunks sent in a single `embed` request. Starts at 32
   * and is overwritten with the server's `max_client_batch_size` once the
   * `info` request issued by the constructor resolves with a valid value.
   */
  maxBatchSize = 32;

  static defaultOptions: Partial<EmbedOptions> | undefined = {
    apiBase: "http://localhost:8080",
    model: "tei",
  };

  /**
   * @param options Provider options; `apiBase` is normalized to end with "/".
   * @param fetch Fetch implementation used for every HTTP request.
   */
  constructor(options: EmbedOptions, fetch: FetchFunction) {
    super(options, fetch);

    // Fall back to the declared default so that a missing apiBase does not get
    // string-concatenated into the literal "undefined/".
    const apiBase =
      this.options.apiBase ??
      HuggingFaceTEIEmbeddingsProvider.defaultOptions?.apiBase;
    if (apiBase !== undefined) {
      // Without the trailing slash the last portion of the path is dropped
      // when a relative URL is resolved with the node.js URL constructor.
      this.options.apiBase = apiBase.endsWith("/") ? apiBase : `${apiBase}/`;
    }

    // Fire-and-forget: a constructor cannot await. The rejection is handled
    // here so an unreachable server does not surface as an unhandled promise
    // rejection; the provider simply keeps its defaults.
    void this.doInfoRequest()
      .then((info) => {
        if (typeof info.model_id === "string" && info.model_id.length > 0) {
          this.options.model = info.model_id;
        }
        // Only accept a positive integer: anything else would make the
        // batching loop in `embed` step by NaN or 0.
        const serverBatchSize = info.max_client_batch_size;
        if (Number.isInteger(serverBatchSize) && serverBatchSize > 0) {
          this.maxBatchSize = serverBatchSize;
        }
      })
      .catch((error: unknown) => {
        const reason = error instanceof Error ? error.message : String(error);
        console.warn(
          `HuggingFace TEI info request failed; using default model and batch size: ${reason}`,
        );
      });
  }

  /**
   * Embeds `chunks` by splitting them into batches of at most `maxBatchSize`
   * and sending all batches concurrently.
   *
   * @param chunks Texts to embed.
   * @returns The per-batch results concatenated in batch order.
   * @throws Whatever `doEmbedRequest` throws for the first failing batch.
   */
  async embed(chunks: string[]): Promise<number[][]> {
    const batchSize = this.maxBatchSize;
    const pending: Promise<number[][]>[] = [];
    for (let start = 0; start < chunks.length; start += batchSize) {
      pending.push(this.doEmbedRequest(chunks.slice(start, start + batchSize)));
    }
    const perBatch = await Promise.all(pending);
    return perBatch.flat();
  }

  /**
   * POSTs one batch to the server's `embed` endpoint.
   *
   * @param batch Texts to embed in a single request.
   * @returns The parsed JSON response body.
   * @throws TEIEmbedError when the server answers with a non-OK status and a
   *   JSON body containing both `error` and `error_type`.
   * @throws Error carrying the raw response text for any other non-OK answer
   *   (including bodies that are not valid JSON).
   */
  async doEmbedRequest(batch: string[]): Promise<number[][]> {
    const resp = await withExponentialBackoff<Response>(() =>
      this.fetch(new URL("embed", this.options.apiBase), {
        method: "POST",
        body: JSON.stringify({
          inputs: batch,
        }),
        headers: {
          "Content-Type": "application/json",
        },
      }),
    );
    if (!resp.ok) {
      const text = await resp.text();
      let embedError: TEIEmbedErrorResponse | null = null;
      try {
        embedError = JSON.parse(text) as TEIEmbedErrorResponse | null;
      } catch {
        // Body is not JSON (e.g. an HTML error page from a proxy); report the
        // raw text below instead of masking it with a SyntaxError.
      }
      if (embedError == null || !embedError.error_type || !embedError.error) {
        throw new Error(text);
      }
      throw new TEIEmbedError(embedError);
    }
    return (await resp.json()) as number[][];
  }

  /**
   * GETs the server's `info` endpoint.
   *
   * @returns The parsed JSON response body.
   * @throws Error carrying the raw response text when the status is not OK.
   */
  async doInfoRequest(): Promise<TEIInfoResponse> {
    const resp = await withExponentialBackoff<Response>(() =>
      this.fetch(new URL("info", this.options.apiBase), {
        method: "GET",
      }),
    );
    if (!resp.ok) {
      throw new Error(await resp.text());
    }
    return (await resp.json()) as TEIInfoResponse;
  }
}
/**
 * Error raised for a structured error body returned by the TEI server.
 *
 * The message is the JSON serialization of the payload passed to the
 * constructor; the payload itself is kept on `response` so callers can
 * inspect it without re-parsing the message.
 */
class TEIEmbedError extends Error {
  /** The structured error payload this error was created from. */
  readonly response: TEIEmbedErrorResponse;

  /**
   * @param teiResponse Structured error payload to wrap.
   */
  constructor(teiResponse: TEIEmbedErrorResponse) {
    super(JSON.stringify(teiResponse));
    // Without an explicit name the error prints as a generic "Error" in logs
    // and stack traces.
    this.name = "TEIEmbedError";
    this.response = teiResponse;
  }
}
/**
 * Shape that a failed embed response body is parsed into.
 *
 * `doEmbedRequest` only wraps a body in `TEIEmbedError` when both fields are
 * present and truthy; otherwise the raw response text is thrown instead.
 */
interface TEIEmbedErrorResponse {
  /** Presumably the human-readable error message; confirm against the server's API docs. */
  error: string;
  /** Presumably the error category reported by the server; confirm against the server's API docs. */
  error_type: string;
}
/**
 * Shape that the JSON body returned by `doInfoRequest` (a GET on the `info`
 * endpoint) is cast to.
 *
 * Only `model_id` and `max_client_batch_size` are read by the provider class
 * in this file; the meaning of the remaining fields is defined by the server
 * and should be checked against its API documentation.
 */
interface TEIInfoResponse {
  /** Copied into the provider's `options.model` after construction. */
  model_id: string;
  model_sha: string;
  model_dtype: string;
  model_type: {
    embedding: {
      pooling: string;
    };
  };
  max_concurrent_requests: number;
  max_input_length: number;
  max_batch_tokens: number;
  max_batch_requests: number;
  /** Copied into the provider's `maxBatchSize` after construction. */
  max_client_batch_size: number;
  auto_truncate: boolean;
  tokenization_workers: number;
  version: string;
  sha: string;
  docker_label: string;
}
export default HuggingFaceTEIEmbeddingsProvider;