// NOTE: vectordb requirement must be listed in extensions/vscode to avoid error
import { v4 as uuidv4 } from "uuid";
import { Table } from "vectordb";

import { IContinueServerClient } from "../continueServer/interface.js";
import {
  BranchAndDir,
  Chunk,
  EmbeddingsProvider,
  IndexTag,
  IndexingProgressUpdate,
} from "../index.js";
import { getBasename } from "../util/index.js";
import { getLanceDbPath, migrate } from "../util/paths.js";
import { chunkDocument } from "./chunk/chunk.js";
import { DatabaseConnection, SqliteDb, tagToString } from "./refreshIndex.js";
import {
  CodebaseIndex,
  IndexResultType,
  PathAndCacheKey,
  RefreshIndexResults,
} from "./types.js";

// LanceDB converts to lowercase, so names must all be lowercase
interface LanceDbRow {
  uuid: string;
  path: string;
  cachekey: string;
  vector: number[];
  [key: string]: any;
}

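/**
 * Indexes codebase chunks into LanceDB for vector search, using a SQLite
 * table (lance_db_cache) as a persistent cache of computed embeddings so
 * that re-tagging a file does not require re-embedding it.
 */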
export class LanceDbIndex implements CodebaseIndex {
  relativeExpectedTime: number = 13;

  get artifactId(): string {
    return `vectordb::${this.embeddingsProvider.id}`;
  }

  constructor(
    private readonly embeddingsProvider: EmbeddingsProvider,
    private readonly readFile: (filepath: string) => Promise<string>,
    private readonly continueServerClient?: IContinueServerClient,
  ) {}

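  // Sanitize the tag into a valid table name: strip any character outside
  // [A-Za-z0-9_\-.], since LanceDB table names must be filename-safe.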
  private tableNameForTag(tag: IndexTag) {
    return tagToString(tag).replace(/[^\w-_.]/g, "");
  }

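  // Create the SQLite cache table if it doesn't exist, then run a one-time
  // migration that adds the artifact_id column to databases created before
  // the column was introduced.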
  private async createSqliteCacheTable(db: DatabaseConnection) {
    await db.exec(`CREATE TABLE IF NOT EXISTS lance_db_cache (
      uuid TEXT PRIMARY KEY,
      cacheKey TEXT NOT NULL,
      path TEXT NOT NULL,
      artifact_id TEXT NOT NULL,
      vector TEXT NOT NULL,
      startLine INTEGER NOT NULL,
      endLine INTEGER NOT NULL,
      contents TEXT NOT NULL
    )`);

    await new Promise((resolve) =>
      migrate(
        "lancedb_sqlite_artifact_id_column",
        async () => {
          try {
            await db.exec(
              "ALTER TABLE lance_db_cache ADD COLUMN artifact_id TEXT NOT NULL DEFAULT 'UNDEFINED'",
            );
          } finally {
            resolve(undefined);
          }
        },
        () => resolve(undefined),
      ),
    );
  }

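  // Chunk and embed each file. For every chunk this yields a tuple of
  // [progress, LanceDbRow, { startLine, endLine, contents }, description];
  // once a file's chunks are exhausted it yields the file's PathAndCacheKey
  // so the caller knows the file is complete. Files that chunk badly, exceed
  // 20 chunks, or fail embedding are skipped entirely.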
  private async *computeChunks(
    items: PathAndCacheKey[],
  ): AsyncGenerator<
    | [
        number,
        LanceDbRow,
        { startLine: number; endLine: number; contents: string },
        string,
      ]
    | PathAndCacheKey
  > {
    const contents = await Promise.all(
      items.map(({ path }) => this.readFile(path)),
    );

    for (let i = 0; i < items.length; i++) {
      // Break the file into chunks
      const content = contents[i];
      const chunks: Chunk[] = [];

      let hasEmptyChunks = false;

      for await (const chunk of chunkDocument({
        filepath: items[i].path,
        contents: content,
        maxChunkSize: this.embeddingsProvider.maxChunkSize,
        digest: items[i].cacheKey,
      })) {
        if (chunk.content.length === 0) {
          hasEmptyChunks = true;
          break;
        }
        chunks.push(chunk);
      }

      if (hasEmptyChunks) {
        // The file did not chunk properly, so skip it
        continue;
      }

      if (chunks.length > 20) {
        // Too many chunks to index; probably a larger file than we want to include
        continue;
      }

      let embeddings: number[][];
      try {
        // Calculate embeddings
        embeddings = await this.embeddingsProvider.embed(
          chunks.map((c) => c.content),
        );
      } catch (e) {
        // Rather than failing the entire indexing process, skip this file
        // so that it may be picked up on the next indexing attempt
        console.warn(
          `Failed to generate embedding for ${chunks[0]?.filepath} with provider: ${this.embeddingsProvider.id}: ${e}`,
        );
        continue;
      }

      if (embeddings.some((emb) => emb === undefined)) {
        throw new Error(
          `Failed to generate embedding for ${chunks[0]?.filepath} with provider: ${this.embeddingsProvider.id}`,
        );
      }

      // Create row format
      for (let j = 0; j < chunks.length; j++) {
        const progress = (i + j / chunks.length) / items.length;
        const row = {
          vector: embeddings[j],
          path: items[i].path,
          cachekey: items[i].cacheKey,
          uuid: uuidv4(),
        };
        const chunk = chunks[j];
        yield [
          progress,
          row,
          {
            contents: chunk.content,
            startLine: chunk.startLine,
            endLine: chunk.endLine,
          },
          `Indexing ${getBasename(chunks[j].filepath)}`,
        ];
      }

      yield items[i];
    }
  }

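  // Bring the index up to date for the given tag. Proceeds in phases:
  // 1. compute: embed new/changed files (consulting the remote cache first)
  //    and write rows to both LanceDB and the SQLite cache;
  // 2. addTag: re-attach cached rows to this tag's table without re-embedding;
  // 3. removeTag/del: delete rows from the LanceDB table;
  // 4. del: also purge deleted files from the SQLite cache.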
  async *update(
    tag: IndexTag,
    results: RefreshIndexResults,
    markComplete: (
      items: PathAndCacheKey[],
      resultType: IndexResultType,
    ) => void,
    repoName: string | undefined,
  ): AsyncGenerator<IndexingProgressUpdate> {
    const lancedb = await import("vectordb");
    const tableName = this.tableNameForTag(tag);
    const db = await lancedb.connect(getLanceDbPath());

    const sqlite = await SqliteDb.get();
    await this.createSqliteCacheTable(sqlite);

    // Compute
    let table: Table<number[]> | undefined = undefined;
    const existingTables = await db.tableNames();
    let needToCreateTable = !existingTables.includes(tableName);

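    // Open (or create) the LanceDB table lazily: creation is deferred until
    // we actually have rows, since the table schema is inferred from data.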
    const addComputedLanceDbRows = async (
      pathAndCacheKey: PathAndCacheKey,
      computedRows: LanceDbRow[],
    ) => {
      // Create table if needed, add computed rows
      if (table) {
        if (computedRows.length > 0) {
          await table.add(computedRows);
        }
      } else if (existingTables.includes(tableName)) {
        table = await db.openTable(tableName);
        needToCreateTable = false;
        if (computedRows.length > 0) {
          await table.add(computedRows);
        }
      } else if (computedRows.length > 0) {
        table = await db.createTable(tableName, computedRows);
        needToCreateTable = false;
      }

      // Mark item complete
      markComplete([pathAndCacheKey], IndexResultType.Compute);
    };

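    // Check the remote cache: if connected to a Continue server, fetch any
    // precomputed embeddings by cache key so those files can skip local
    // embedding below.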
    if (this.continueServerClient?.connected) {
      try {
        const keys = results.compute.map(({ cacheKey }) => cacheKey);
        const resp = await this.continueServerClient.getFromIndexCache(
          keys,
          "embeddings",
          repoName,
        );
        for (const [cacheKey, chunks] of Object.entries(resp.files)) {
          // Get the path for this cacheKey
          const path = results.compute.find(
            (item) => item.cacheKey === cacheKey,
          )?.path;
          if (!path) {
            console.warn(
              "Continue server sent a cacheKey that wasn't requested",
              cacheKey,
            );
            continue;
          }

          // Build LanceDbRow objects
          const rows: LanceDbRow[] = [];
          for (const chunk of chunks) {
            const row = {
              path,
              cachekey: cacheKey,
              uuid: uuidv4(),
              vector: chunk.vector,
            };
            rows.push(row);

            await sqlite.run(
              "INSERT INTO lance_db_cache (uuid, cacheKey, path, artifact_id, vector, startLine, endLine, contents) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
              row.uuid,
              row.cachekey,
              row.path,
              this.artifactId,
              JSON.stringify(row.vector),
              chunk.startLine,
              chunk.endLine,
              chunk.contents,
            );
          }

          await addComputedLanceDbRows({ cacheKey, path }, rows);
        }

        // Remove items that don't need to be recomputed
        results.compute = results.compute.filter(
          (item) => !resp.files[item.cacheKey],
        );
      } catch (e) {
        console.log("Error checking remote cache: ", e);
      }
    }

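    // Progress accounting: the compute phase scales its own progress into
    // [0, 1 - progressReservedForTagging], reserving the remainder for the
    // tagging and deletion phases below.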
    const progressReservedForTagging = 0.1;
    let accumulatedProgress = 0;

    let computedRows: LanceDbRow[] = [];
    for await (const update of this.computeChunks(results.compute)) {
      if (Array.isArray(update)) {
        const [progress, row, data, desc] = update;
        computedRows.push(row);

        // Add the computed row to the cache
        await sqlite.run(
          "INSERT INTO lance_db_cache (uuid, cacheKey, path, artifact_id, vector, startLine, endLine, contents) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
          row.uuid,
          row.cachekey,
          row.path,
          this.artifactId,
          JSON.stringify(row.vector),
          data.startLine,
          data.endLine,
          data.contents,
        );

        accumulatedProgress = progress * (1 - progressReservedForTagging);
        yield {
          progress: accumulatedProgress,
          desc,
          status: "indexing",
        };
      } else {
        await addComputedLanceDbRows(update, computedRows);
        computedRows = [];
      }
    }

    // Add tag: retrieve the previously computed rows from the SQLite cache
    for (const { path, cacheKey } of results.addTag) {
      const stmt = await sqlite.prepare(
        "SELECT * FROM lance_db_cache WHERE cacheKey = ? AND path = ? AND artifact_id = ?",
        cacheKey,
        path,
        this.artifactId,
      );
      const cachedItems = await stmt.all();

      const lanceRows: LanceDbRow[] = cachedItems.map((item) => {
        return {
          path,
          cachekey: cacheKey,
          uuid: item.uuid,
          vector: JSON.parse(item.vector),
        };
      });

      if (lanceRows.length > 0) {
        if (needToCreateTable) {
          table = await db.createTable(tableName, lanceRows);
          needToCreateTable = false;
        } else if (!table) {
          table = await db.openTable(tableName);
          needToCreateTable = false;
          await table.add(lanceRows);
        } else {
          await table.add(lanceRows);
        }
      }

      markComplete([{ path, cacheKey }], IndexResultType.AddTag);
      accumulatedProgress += 1 / results.addTag.length / 3;
      yield {
        progress: accumulatedProgress,
        desc: `Indexing ${getBasename(path)}`,
        status: "indexing",
      };
    }

    // Delete or remove tag: remove from the LanceDB table
    if (!needToCreateTable) {
      const toDel = [...results.removeTag, ...results.del];
      for (const { path, cacheKey } of toDel) {
        // This is where the aforementioned lowercase conversion problem shows:
        // the column must be referenced as "cachekey", not "cacheKey"
        await table?.delete(`cachekey = '${cacheKey}' AND path = '${path}'`);

        accumulatedProgress += 1 / toDel.length / 3;
        yield {
          progress: accumulatedProgress,
          desc: `Stashing ${getBasename(path)}`,
          status: "indexing",
        };
      }
    }
    markComplete(results.removeTag, IndexResultType.RemoveTag);

    // Delete: also remove from the SQLite cache
    for (const { path, cacheKey } of results.del) {
      await sqlite.run(
        "DELETE FROM lance_db_cache WHERE cacheKey = ? AND path = ? AND artifact_id = ?",
        cacheKey,
        path,
        this.artifactId,
      );
      accumulatedProgress += 1 / results.del.length / 3;
      yield {
        progress: accumulatedProgress,
        desc: `Removing ${getBasename(path)}`,
        status: "indexing",
      };
    }

    markComplete(results.del, IndexResultType.Delete);
    yield {
      progress: 1,
      desc: "Completed Calculating Embeddings",
      status: "done",
    };
  }

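  // Run a vector search against the table for a single tag. When a directory
  // filter is supplied, over-fetch and slice afterward, since LanceDB appears
  // to apply the WHERE clause as a post-filter on the ANN results.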
  private async _retrieveForTag(
    tag: IndexTag,
    n: number,
    directory: string | undefined,
    vector: number[],
    db: any, /// lancedb.Connection
  ): Promise<LanceDbRow[]> {
    const tableName = this.tableNameForTag(tag);
    const tableNames = await db.tableNames();
    if (!tableNames.includes(tableName)) {
      console.warn("Table not found in LanceDB", tableName);
      return [];
    }

    const table = await db.openTable(tableName);
    let query = table.search(vector);
    if (directory) {
      // LanceDB seems to only post-filter, so we have to request many more
      // results than needed and slice afterward
      query = query.where(`path LIKE '${directory}%'`).limit(300);
    } else {
      query = query.limit(n);
    }
    const results = await query.execute();
    return results.slice(0, n) as any;
  }

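  // Embed the query string, search every tag's table, merge the results by
  // ascending distance, and hydrate the chunk contents from the SQLite cache
  // (LanceDB rows only store the vector and identifying metadata).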
  async retrieve(
    query: string,
    n: number,
    tags: BranchAndDir[],
    filterDirectory: string | undefined,
  ): Promise<Chunk[]> {
    const lancedb = await import("vectordb");
    if (!lancedb.connect) {
      throw new Error("LanceDB failed to load a native module");
    }
    const [vector] = await this.embeddingsProvider.embed([query]);
    const db = await lancedb.connect(getLanceDbPath());

    let allResults = [];
    for (const tag of tags) {
      const results = await this._retrieveForTag(
        { ...tag, artifactId: this.artifactId },
        n,
        filterDirectory,
        vector,
        db,
      );
      allResults.push(...results);
    }

    allResults = allResults
      .sort((a, b) => a._distance - b._distance)
      .slice(0, n);

    const sqliteDb = await SqliteDb.get();
    const data = await sqliteDb.all(
      `SELECT * FROM lance_db_cache WHERE uuid in (${allResults
        .map((r) => `'${r.uuid}'`)
        .join(",")})`,
    );

    return data.map((d) => {
      return {
        digest: d.cacheKey,
        filepath: d.path,
        startLine: d.startLine,
        endLine: d.endLine,
        index: 0,
        content: d.contents,
      };
    });
  }
}
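
// Example usage (illustrative sketch only; `myEmbeddingsProvider` and `tags`
// are hypothetical and not defined in this module):
//
//   const index = new LanceDbIndex(myEmbeddingsProvider, (filepath) =>
//     fs.promises.readFile(filepath, "utf8"),
//   );
//   const chunks = await index.retrieve("where is auth handled?", 10, tags, undefined);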