continue/core/indexing/embeddings/TransformersJsEmbeddingsPro...

83 lines
2.2 KiB
TypeScript

import path from "path";
import { EmbeddingsProviderName } from "../../index.js";
// @ts-ignore
// prettier-ignore
import { type PipelineType } from "../../vendor/modules/@xenova/transformers/src/transformers.js";
import BaseEmbeddingsProvider from "./BaseEmbeddingsProvider.js";
class EmbeddingsPipeline {
static task: PipelineType = "feature-extraction";
static model = "all-MiniLM-L6-v2";
static instance: any | null = null;
static async getInstance() {
if (EmbeddingsPipeline.instance === null) {
// @ts-ignore
// prettier-ignore
const { env, pipeline } = await import("../../vendor/modules/@xenova/transformers/src/transformers.js");
env.allowLocalModels = true;
env.allowRemoteModels = false;
env.localModelPath = path.join(
typeof __dirname === "undefined"
? // @ts-ignore
path.dirname(new URL(import.meta.url).pathname)
: __dirname,
"..",
"models",
);
EmbeddingsPipeline.instance = await pipeline(
EmbeddingsPipeline.task,
EmbeddingsPipeline.model,
);
}
return EmbeddingsPipeline.instance;
}
}
export class TransformersJsEmbeddingsProvider extends BaseEmbeddingsProvider {
static providerName: EmbeddingsProviderName = "transformers.js";
static maxGroupSize: number = 4;
static model: string = "all-MiniLM-L6-v2";
constructor() {
super({ model: TransformersJsEmbeddingsProvider.model }, () =>
Promise.resolve(null),
);
}
async embed(chunks: string[]) {
const extractor = await EmbeddingsPipeline.getInstance();
if (!extractor) {
throw new Error("TransformerJS embeddings pipeline is not initialized");
}
if (chunks.length === 0) {
return [];
}
const outputs = [];
for (
let i = 0;
i < chunks.length;
i += TransformersJsEmbeddingsProvider.maxGroupSize
) {
const chunkGroup = chunks.slice(
i,
i + TransformersJsEmbeddingsProvider.maxGroupSize,
);
const output = await extractor(chunkGroup, {
pooling: "mean",
normalize: true,
});
outputs.push(...output.tolist());
}
return outputs;
}
}
export default TransformersJsEmbeddingsProvider;