125 lines
3.8 KiB
TypeScript
125 lines
3.8 KiB
TypeScript
import { Chunk } from "../../../index.js";
|
|
import { RETRIEVAL_PARAMS } from "../../../util/parameters.js";
|
|
import { recentlyEditedFilesCache } from "../recentlyEditedFilesCache.js";
|
|
import { deduplicateChunks } from "../util.js";
|
|
import BaseRetrievalPipeline from "./BaseRetrievalPipeline.js";
|
|
|
|
/**
 * Retrieval pipeline that gathers candidate chunks from several sources
 * (recently edited files, full-text search, embeddings), de-duplicates them,
 * and then uses the configured reranker to keep the `nFinal` best chunks.
 */
export default class RerankerRetrievalPipeline extends BaseRetrievalPipeline {
  /**
   * Collects the initial candidate set from every retrieval source.
   *
   * Recently edited files come first, then full-text-search hits, then
   * embeddings hits; the combined list is passed through `deduplicateChunks`.
   */
  private async _retrieveInitial(): Promise<Chunk[]> {
    const { input, nRetrieve } = this.options;

    const ftsChunks = await this.retrieveFts(input, nRetrieve);
    const embeddingsChunks = await this.retrieveEmbeddings(input, nRetrieve);
    const recentlyEditedFilesChunks =
      await this.retrieveAndChunkRecentlyEditedFiles(nRetrieve);

    return deduplicateChunks([
      ...recentlyEditedFilesChunks,
      ...ftsChunks,
      ...embeddingsChunks,
    ]);
  }

  /**
   * Scores `chunks` against `input` with the configured reranker and returns
   * the `nFinal` highest-scoring chunks, ordered by ascending score (the best
   * chunk is last).
   *
   * @param input - Query the chunks are scored against.
   * @param chunks - Candidate chunks; this array is not mutated.
   * @returns At most `this.options.nFinal` chunks, in ascending score order.
   * @throws Error if no reranker is configured.
   */
  private async _rerank(input: string, chunks: Chunk[]): Promise<Chunk[]> {
    const { reranker, nFinal } = this.options;
    if (!reranker) {
      throw new Error("No reranker provided");
    }

    const scores: number[] = await reranker.rerank(input, chunks);

    // Low-score filtering is currently disabled. To re-enable it, drop pairs
    // whose score is below RETRIEVAL_PARAMS.rerankThreshold before sorting.

    // Pair each chunk with its score up front instead of looking scores up via
    // `indexOf` inside the comparator. That lookup cost O(n) per comparison and
    // was only correct if the engine left the array untouched until sorting
    // finished. Sorting the pairs also leaves the caller's `chunks` unmutated.
    // A missing score (reranker returned too few) sorts to the bottom instead
    // of producing NaN comparisons.
    const ranked = chunks
      .map((chunk, i) => ({ chunk, score: scores[i] ?? -Infinity }))
      .sort((a, b) => a.score - b.score)
      .map(({ chunk }) => chunk);

    // Keep the last (highest-scoring) `nFinal` entries. `slice(-nFinal)` would
    // return everything when nFinal is 0, because -0 === 0.
    return ranked.slice(Math.max(0, ranked.length - nFinal));
  }

  /**
   * For each of the last `RETRIEVAL_PARAMS.nResultsToExpandWithEmbeddings`
   * chunks (the highest-ranked ones when `chunks` is in ascending-score order),
   * runs an embeddings search using the chunk's content as the query. All
   * hits are returned in one flat list.
   *
   * Only reachable through `_expandRankedResults`, which `run` does not
   * currently call.
   */
  private async _expandWithEmbeddings(chunks: Chunk[]): Promise<Chunk[]> {
    const nToExpand = RETRIEVAL_PARAMS.nResultsToExpandWithEmbeddings;
    // Same -0 guard as in `_rerank`: a count of 0 must select nothing.
    const topResults = chunks.slice(Math.max(0, chunks.length - nToExpand));

    const expanded = await Promise.all(
      topResults.map((chunk) =>
        this.retrieveEmbeddings(
          chunk.content,
          RETRIEVAL_PARAMS.nEmbeddingsExpandTo,
        ),
      ),
    );
    return expanded.flat();
  }

  /**
   * Expands ranked results with additional related chunks. Embeddings
   * similarity is currently the only expansion source.
   */
  private async _expandRankedResults(chunks: Chunk[]): Promise<Chunk[]> {
    return this._expandWithEmbeddings(chunks);
  }

  /**
   * Runs the pipeline: retrieve initial candidates, then rerank them and keep
   * the top `nFinal`.
   *
   * @returns Reranked chunks in ascending score order (best last).
   */
  async run(): Promise<Chunk[]> {
    const initialResults = await this._retrieveInitial();
    const rankedResults = await this._rerank(
      this.options.input,
      initialResults,
    );

    // Disabled follow-up steps, kept for reference:
    // // Expand top reranked results
    // let results = [...rankedResults];
    // results.push(...(await this._expandRankedResults(rankedResults)));
    // // De-duplicate
    // results = deduplicateChunks(results);
    // // Rerank again
    // results = await this._rerank(this.options.input, results);

    // TODO: stitch together results

    return rankedResults;
  }
}
|
|
|
|
// Source: expansion with code graph
|
|
// Consider doing this after reranking, or simply lowering the reranking threshold.
|
|
// This is VS Code only until we use PSI for JetBrains or build our own general solution
|
|
// TODO: Need to pass in the expandSnippet function as a function argument
|
|
// because this import causes `tsc` to fail
|
|
// if ((await extras.ide.getIdeInfo()).ideType === "vscode") {
|
|
// const { expandSnippet } = await import(
|
|
// "../../../extensions/vscode/src/util/expandSnippet"
|
|
// );
|
|
// let expansionResults = (
|
|
// await Promise.all(
|
|
// extras.selectedCode.map(async (rif) => {
|
|
// return expandSnippet(
|
|
// rif.filepath,
|
|
// rif.range.start.line,
|
|
// rif.range.end.line,
|
|
// extras.ide,
|
|
// );
|
|
// }),
|
|
// )
|
|
// ).flat() as Chunk[];
|
|
// retrievalResults.push(...expansionResults);
|
|
// }
|
|
|
|
// Source: Open file exact match
|
|
// Source: Class/function name exact match
|