Merge branch 'AddDocIndexingMaxDepth' of https://github.com/justinmilner1/continue into justinmilner1-AddDocIndexingMaxDepth

Nate Sesti 2024-06-04 16:11:41 -07:00
commit e4b12ad5f3
7 changed files with 140 additions and 62 deletions

View File

@@ -1,5 +1,10 @@
import { v4 as uuidv4 } from "uuid";
import { ContextItemId, IDE } from ".";
import {
ContextItemId,
IDE,
IndexingProgressUpdate,
SiteIndexingConfig,
} from ".";
import { CompletionProvider } from "./autocomplete/completionProvider";
import { ConfigHandler } from "./config/handler";
import {
@@ -25,7 +30,6 @@ import type { IMessenger, Message } from "./util/messenger";
import { editConfigJson, getConfigJsonPath } from "./util/paths";
import { Telemetry } from "./util/posthog";
import { streamDiffLines } from "./util/verticalEdit";
import { IndexingProgressUpdate } from ".";
export class Core {
// implements IMessenger<ToCoreProtocol, FromCoreProtocol>
@@ -33,7 +37,7 @@ export class Core {
codebaseIndexerPromise: Promise<CodebaseIndexer>;
completionProvider: CompletionProvider;
continueServerClientPromise: Promise<ContinueServerClient>;
indexingState: IndexingProgressUpdate
indexingState: IndexingProgressUpdate;
private abortedMessageIds: Set<string> = new Set();
@@ -61,7 +65,7 @@ export class Core {
private readonly ide: IDE,
private readonly onWrite: (text: string) => Promise<void> = async () => {},
) {
this.indexingState = { status:"loading", desc: 'loading', progress: 0 }
this.indexingState = { status: "loading", desc: "loading", progress: 0 };
const ideSettingsPromise = messenger.request("getIdeSettings", undefined);
this.configHandler = new ConfigHandler(
this.ide,
@@ -214,9 +218,15 @@ export class Core {
// Context providers
on("context/addDocs", async (msg) => {
const siteIndexingConfig: SiteIndexingConfig = {
startUrl: msg.data.startUrl,
rootUrl: msg.data.rootUrl,
title: msg.data.title,
maxDepth: msg.data.maxDepth,
};
for await (const _ of indexDocs(
msg.data.title,
new URL(msg.data.url),
siteIndexingConfig,
new TransformersJsEmbeddingsProvider(),
)) {
}
@@ -525,7 +535,7 @@ export class Core {
on("index/indexingProgressBarInitialized", async (msg) => {
// Triggered when progress bar is initialized.
// If a non-default state has been stored, update the indexing display to that state
if (this.indexingState.status != 'loading') {
if (this.indexingState.status != "loading") {
this.messenger.request("indexProgress", this.indexingState);
}
});
@@ -543,7 +553,7 @@ export class Core {
this.indexingCancellationController.signal,
)) {
this.messenger.request("indexProgress", update);
this.indexingState = update
this.indexingState = update;
}
}
}

core/index.d.ts vendored
View File

@@ -162,6 +162,13 @@ export interface ContextSubmenuItem {
description: string;
}
export interface SiteIndexingConfig {
startUrl: string;
rootUrl: string;
title: string;
maxDepth?: number;
}
export interface IContextProvider {
get description(): ContextProviderDescription;
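
For reference only (not part of this diff), a value of the new SiteIndexingConfig shape might look like the following; the URLs and title are placeholders:

const exampleSiteConfig: SiteIndexingConfig = {
  startUrl: "https://docs.example.com/intro", // page the crawl begins from
  rootUrl: "https://docs.example.com", // site root the crawl is scoped to
  title: "Example Docs",
  maxDepth: 3, // optional; limits how many links deep the crawler follows
};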

View File

@@ -18,6 +18,17 @@ const IGNORE_PATHS_ENDING_IN = [
const GITHUB_PATHS_TO_TRAVERSE = ["/blob/", "/tree/"];
async function getDefaultBranch(owner: string, repo: string): Promise<string> {
const octokit = new Octokit({ auth: undefined });
const repoInfo = await octokit.repos.get({
owner,
repo,
});
return repoInfo.data.default_branch;
}
async function crawlGithubRepo(baseUrl: URL) {
const octokit = new Octokit({
auth: undefined,
@@ -25,17 +36,15 @@ async function crawlGithubRepo(baseUrl: URL) {
const [_, owner, repo] = baseUrl.pathname.split("/");
const dirContentsConfig = {
owner: owner,
repo: repo,
};
const branch = await getDefaultBranch(owner, repo);
console.log("Github repo detected. Crawling", branch, "branch");
const tree = await octokit.request(
"GET /repos/{owner}/{repo}/git/trees/{tree_sha}",
{
owner,
repo,
tree_sha: "main",
tree_sha: branch,
headers: {
"X-GitHub-Api-Version": "2022-11-28",
},
@@ -54,6 +63,7 @@ async function getLinksFromUrl(url: string, path: string) {
const baseUrl = new URL(url);
const location = new URL(path, url);
let response;
try {
response = await fetch(location.toString());
} catch (error: unknown) {
@@ -128,46 +138,69 @@ export type PageData = {
html: string;
};
export async function* crawlPage(url: URL): AsyncGenerator<PageData> {
export async function* crawlPage(
url: URL,
maxDepth: number = 3,
): AsyncGenerator<PageData> {
console.log("Starting crawl from: ", url, " - Max Depth: ", maxDepth);
const { baseUrl, basePath } = splitUrl(url);
let paths: string[] = [basePath];
let paths: { path: string; depth: number }[] = [{ path: basePath, depth: 0 }];
if (url.hostname === "github.com") {
const githubLinks = await crawlGithubRepo(url);
paths = [...paths, ...githubLinks];
const githubLinkObjects = githubLinks.map((link) => ({
path: link,
depth: 0,
}));
paths = [...paths, ...githubLinkObjects];
}
let index = 0;
while (index < paths.length) {
const promises = paths
.slice(index, index + 50)
.map((path) => getLinksFromUrl(baseUrl, path));
const batch = paths.slice(index, index + 50);
const results = await Promise.all(promises);
try {
const promises = batch.map(({ path, depth }) =>
getLinksFromUrl(baseUrl, path).then((links) => ({
links,
path,
depth,
})),
); // Adjust for depth tracking
for (const { html, links } of results) {
if (html !== "") {
yield {
url: url.toString(),
path: paths[index],
html: html,
};
}
const results = await Promise.all(promises);
for (const {
links: { html, links: linksArray },
path,
depth,
} of results) {
if (html !== "" && depth <= maxDepth) {
// Check depth
yield {
url: url.toString(),
path,
html,
};
}
for (const link of links) {
if (!paths.includes(link)) {
paths.push(link);
// Ensure we only add links if within depth limit
if (depth < maxDepth) {
for (let link of linksArray) {
if (!paths.some((p) => p.path === link)) {
paths.push({ path: link, depth: depth + 1 }); // Increment depth for new paths
}
}
}
}
index++;
} catch (e) {
if (e instanceof TypeError) {
console.warn("Error while crawling page: ", e); // Likely an invalid url, continue with process
} else {
console.error("Error while crawling page: ", e);
}
}
paths = paths.filter((path) =>
results.some(
(result) => result.html !== "" && result.links.includes(path),
),
);
index += batch.length; // Proceed to next batch
}
console.log("Crawl completed");
}
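
As an illustration only, not part of the commit, a caller might drive the updated crawlPage generator like this; the helper function and its return shape are hypothetical:

// Hypothetical helper: crawl a docs site up to a given link depth.
async function collectPages(startUrl: string, maxDepth = 3): Promise<PageData[]> {
  const pages: PageData[] = [];
  for await (const page of crawlPage(new URL(startUrl), maxDepth)) {
    pages.push(page); // each PageData carries the resolved url, path, and raw HTML
  }
  return pages;
}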

View File

@@ -4,16 +4,18 @@ import {
IndexingProgressUpdate,
} from "../../index.js";
import { SiteIndexingConfig } from "../../index.js";
import { Article, chunkArticle, pageToArticle } from "./article.js";
import { crawlPage } from "./crawl.js";
import { addDocs, hasDoc } from "./db.js";
export async function* indexDocs(
title: string,
baseUrl: URL,
siteIndexingConfig: SiteIndexingConfig,
embeddingsProvider: EmbeddingsProvider,
): AsyncGenerator<IndexingProgressUpdate> {
if (await hasDoc(baseUrl.toString())) {
const startUrl = new URL(siteIndexingConfig.startUrl);
if (await hasDoc(siteIndexingConfig.startUrl.toString())) {
yield {
progress: 1,
desc: "Already indexed",
@@ -30,12 +32,12 @@ export async function* indexDocs(
const articles: Article[] = [];
for await (const page of crawlPage(baseUrl)) {
// Crawl pages and retrieve info as articles
for await (const page of crawlPage(startUrl, siteIndexingConfig.maxDepth)) {
const article = pageToArticle(page);
if (!article) {
continue;
}
articles.push(article);
yield {
@@ -48,6 +50,8 @@ export async function* indexDocs(
const chunks: Chunk[] = [];
const embeddings: number[][] = [];
// Create embeddings of retrieved articles
console.log("Creating Embeddings for ", articles.length, " articles");
for (const article of articles) {
yield {
progress: Math.max(1, Math.floor(100 / (articles.length + 1))),
@@ -55,18 +59,24 @@ export async function* indexDocs(
status: "indexing",
};
const subpathEmbeddings = await embeddingsProvider.embed(
chunkArticle(article).map((chunk) => {
chunks.push(chunk);
try {
const subpathEmbeddings = await embeddingsProvider.embed(
chunkArticle(article).map((chunk) => {
chunks.push(chunk);
return chunk.content;
}),
);
return chunk.content;
}),
);
embeddings.push(...subpathEmbeddings);
embeddings.push(...subpathEmbeddings);
} catch (e) {
console.warn("Error chunking article: ", e);
}
}
await addDocs(title, baseUrl, chunks, embeddings);
// Add docs to databases
console.log("Adding ", embeddings.length, " embeddings to db");
await addDocs(siteIndexingConfig.title, startUrl, chunks, embeddings);
yield {
progress: 1,

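For illustration, not part of the diff, the new indexDocs signature could be exercised roughly as follows; the config values are placeholders and the embeddings provider matches the one used in core.ts above:

const config: SiteIndexingConfig = {
  startUrl: "https://docs.example.com",
  rootUrl: "https://docs.example.com",
  title: "Example Docs",
  maxDepth: 2,
};
// Stream progress updates while the site is crawled, chunked, and embedded.
for await (const update of indexDocs(config, new TransformersJsEmbeddingsProvider())) {
  console.log(update.status, update.desc, update.progress);
}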
View File

@@ -1,8 +1,4 @@
export interface SiteIndexingConfig {
startUrl: string;
title: string;
rootUrl: string;
}
import {SiteIndexingConfig} from "../../index.js";
const configs: SiteIndexingConfig[] = [
{

View File

@@ -11,6 +11,7 @@ import {
RangeInFile,
SerializedContinueConfig,
SessionInfo,
SiteIndexingConfig,
} from "..";
import { AutocompleteInput } from "../autocomplete/completionProvider";
import { IdeSettings } from "./ideWebview";
@@ -58,8 +59,8 @@ export type ToCoreFromIdeOrWebviewProtocol = {
ContextItemWithId[],
];
"context/loadSubmenuItems": [{ title: string }, ContextSubmenuItem[]];
"context/addDocs": [{ title: string; url: string }, void];
"autocomplete/complete": [AutocompleteInput, string[]];
"context/addDocs": [SiteIndexingConfig, void];
"autocomplete/cancel": [undefined, void];
"autocomplete/accept": [{ completionId: string }, void];
"command/run": [

View File

@@ -1,3 +1,4 @@
import { SiteIndexingConfig } from "core";
import { usePostHog } from "posthog-js/react";
import React, { useContext, useLayoutEffect } from "react";
import { useDispatch } from "react-redux";
@@ -15,9 +16,12 @@ const GridDiv = styled.div`
`;
function AddDocsDialog() {
const defaultMaxDepth = 3;
const [docsUrl, setDocsUrl] = React.useState("");
const [docsTitle, setDocsTitle] = React.useState("");
const [urlValid, setUrlValid] = React.useState(false);
const [maxDepth, setMaxDepth] = React.useState<number | string>(""); // empty string means "use the default depth"
const dispatch = useDispatch();
const ideMessenger = useContext(IdeMessengerContext);
@@ -61,17 +65,34 @@ function AddDocsDialog() {
value={docsTitle}
onChange={(e) => setDocsTitle(e.target.value)}
/>
<Input
type="text"
placeholder={`Optional: Max Depth (Default: ${defaultMaxDepth})`}
title="The maximum search tree depth - where your input url is the root node"
value={maxDepth}
onChange={(e) => {
const value = e.target.value;
if (value == "") {
setMaxDepth("");
} else if (!isNaN(+value) && Number(value) > 0) {
setMaxDepth(Number(value));
}
}}
/>
<Button
disabled={!docsUrl || !urlValid}
className="ml-auto"
onClick={() => {
ideMessenger.post("context/addDocs", {
url: docsUrl,
const siteIndexingConfig: SiteIndexingConfig = {
startUrl: docsUrl,
rootUrl: docsUrl,
title: docsTitle,
});
maxDepth: typeof maxDepth === "string" ? defaultMaxDepth : maxDepth, // Ensure maxDepth is a number
};
ideMessenger.post("context/addDocs", siteIndexingConfig);
setDocsTitle("");
setDocsUrl("");
setMaxDepth("");
dispatch(setShowDialog(false));
addItem("docs", {
id: docsUrl,