From 9e06df30b47074b3d93369de9e5d863fb8eac7a9 Mon Sep 17 00:00:00 2001
From: Meng Zhang
Date: Thu, 18 Apr 2024 12:27:34 -0700
Subject: [PATCH] feat(scheduler): switch code indexing implementation to text
 splitter. (#1868)

* feat(scheduler): switch code indexing implementation to text splitter.

* update

* update index
---
 Cargo.lock                             |  77 +++++++---
 crates/tabby-common/src/api/code.rs    |   2 -
 crates/tabby-common/src/index.rs       |  29 +---
 crates/tabby-common/src/lib.rs         |  12 +-
 crates/tabby-scheduler/Cargo.toml      |   1 +
 crates/tabby-scheduler/src/code/mod.rs |  11 ++
 crates/tabby-scheduler/src/index.rs    | 203 ++++---------------------
 crates/tabby-scheduler/src/lib.rs      |  14 +-
 crates/tabby/src/services/code.rs      |   2 -
 9 files changed, 124 insertions(+), 227 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 55485b5b999b..17d8f8a1d5c9 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -19,9 +19,9 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
 
 [[package]]
 name = "ahash"
-version = "0.8.7"
+version = "0.8.11"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "77c3a9648d43b9cd48db467b3f87fdd6e146bcc88ab0180006cef2179fe11d01"
+checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011"
 dependencies = [
  "cfg-if",
  "getrandom 0.2.11",
@@ -271,6 +271,18 @@ dependencies = [
  "rand 0.8.5",
 ]
 
+[[package]]
+name = "auto_enums"
+version = "0.8.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1899bfcfd9340ceea3533ea157360ba8fa864354eccbceab58e1006ecab35393"
+dependencies = [
+ "derive_utils 0.14.1",
+ "proc-macro2",
+ "quote",
+ "syn 2.0.52",
+]
+
 [[package]]
 name = "autocfg"
 version = "1.1.0"
@@ -1037,6 +1049,17 @@ dependencies = [
  "syn 1.0.109",
 ]
 
+[[package]]
+name = "derive_utils"
+version = "0.14.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "61bb5a1014ce6dfc2a378578509abe775a5aa06bff584a547555d9efdb81b926"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.52",
+]
+
 [[package]]
 name = "digest"
 version = "0.10.7"
@@ -1083,9 +1106,9 @@ checksum = "9ea835d29036a4087793836fa931b08837ad5e957da9e23886b29586fb9b6650"
 
 [[package]]
 name = "either"
-version = "1.8.1"
+version = "1.11.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91"
+checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2"
 dependencies = [
  "serde",
 ]
@@ -1326,7 +1349,7 @@ version = "0.1.17"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "3422d14de7903a52e9dbc10ae05a7e14445ec61890100e098754e120b2bd7b1e"
 dependencies = [
- "derive_utils",
+ "derive_utils 0.11.2",
  "quote",
  "syn 1.0.109",
 ]
@@ -1472,7 +1495,7 @@ dependencies = [
  "aho-corasick",
  "bstr",
  "log",
- "regex-automata 0.4.3",
+ "regex-automata 0.4.6",
  "regex-syntax 0.8.2",
 ]
@@ -1843,7 +1866,7 @@ dependencies = [
  "globset",
  "log",
  "memchr",
- "regex-automata 0.4.3",
+ "regex-automata 0.4.6",
  "same-file",
  "walkdir",
  "winapi-util",
@@ -1948,9 +1971,9 @@ dependencies = [
 
 [[package]]
 name = "itertools"
-version = "0.12.0"
+version = "0.12.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0"
+checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
 dependencies = [
  "either",
 ]
@@ -2684,9 +2707,9 @@ dependencies = [
 
 [[package]]
 name = "once_cell"
-version = "1.18.0"
+version = "1.19.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d"
+checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
 
 [[package]]
 name = "oneshot"
@@ -3325,13 +3348,13 @@ dependencies = [
 
 [[package]]
 name = "regex"
-version = "1.10.2"
+version = "1.10.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343"
+checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c"
 dependencies = [
  "aho-corasick",
  "memchr",
- "regex-automata 0.4.3",
+ "regex-automata 0.4.6",
  "regex-syntax 0.8.2",
 ]
@@ -3346,9 +3369,9 @@ dependencies = [
 
 [[package]]
 name = "regex-automata"
-version = "0.4.3"
+version = "0.4.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f"
+checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea"
 dependencies = [
  "aho-corasick",
  "memchr",
@@ -4108,7 +4131,7 @@ version = "0.2.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ce81b7bd7c4493975347ef60d8c7e8b742d4694f4c49f93e0a12ea263938176c"
 dependencies = [
- "itertools 0.12.0",
+ "itertools 0.12.1",
  "nom",
  "unicode_categories",
 ]
@@ -4610,6 +4633,7 @@ dependencies = [
  "tabby-common",
  "tantivy",
  "temp_testdir",
+ "text-splitter",
  "tokio",
  "tokio-cron-scheduler",
  "tracing",
@@ -4893,6 +4917,21 @@ dependencies = [
  "windows-sys 0.48.0",
 ]
 
+[[package]]
+name = "text-splitter"
+version = "0.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8d5315f993b00372fd909fcf8587535e65f03ac5fd9400f49dd72ce1f6be23cf"
+dependencies = [
+ "ahash",
+ "auto_enums",
+ "either",
+ "itertools 0.12.1",
+ "once_cell",
+ "regex",
+ "unicode-segmentation",
+]
+
 [[package]]
 name = "textdistance"
 version = "1.0.2"
@@ -5636,9 +5675,9 @@ dependencies = [
 
 [[package]]
 name = "unicode-segmentation"
-version = "1.10.1"
+version = "1.11.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36"
+checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202"
 
 [[package]]
 name = "unicode-width"
diff --git a/crates/tabby-common/src/api/code.rs b/crates/tabby-common/src/api/code.rs
index c2cc31f4bc79..4566c09ec825 100644
--- a/crates/tabby-common/src/api/code.rs
+++ b/crates/tabby-common/src/api/code.rs
@@ -21,9 +21,7 @@ pub struct HitDocument {
     pub body: String,
     pub filepath: String,
     pub git_url: String,
-    pub kind: String,
     pub language: String,
-    pub name: String,
 }
 
 #[derive(Error, Debug)]
diff --git a/crates/tabby-common/src/index.rs b/crates/tabby-common/src/index.rs
index f6bd711386b3..a82984366296 100644
--- a/crates/tabby-common/src/index.rs
+++ b/crates/tabby-common/src/index.rs
@@ -1,26 +1,18 @@
 use tantivy::{
     query::{TermQuery, TermSetQuery},
     schema::{Field, IndexRecordOption, Schema, TextFieldIndexing, TextOptions, STORED, STRING},
-    tokenizer::{NgramTokenizer, RegexTokenizer, RemoveLongFilter, TextAnalyzer},
+    tokenizer::{RegexTokenizer, RemoveLongFilter, TextAnalyzer},
     Index, Term,
 };
 
 static CODE_TOKENIZER: &str = "code";
-static IDENTIFIER_TOKENIZER: &str = "identifier";
 
 pub fn register_tokenizers(index: &Index) {
     let code_tokenizer = TextAnalyzer::builder(RegexTokenizer::new(r"(?:\w+)").unwrap())
-        .filter(RemoveLongFilter::limit(128))
+        .filter(RemoveLongFilter::limit(64))
         .build();
 
     index.tokenizers().register(CODE_TOKENIZER, code_tokenizer);
-
-    let identifier_tokenzier =
-        TextAnalyzer::builder(NgramTokenizer::prefix_only(2, 5).unwrap()).build();
-
-    index
-        .tokenizers()
-        .register(IDENTIFIER_TOKENIZER, identifier_tokenzier);
 }
 
 pub struct CodeSearchSchema {
@@ -28,8 +20,6 @@ pub struct CodeSearchSchema {
     pub field_git_url: Field,
     pub field_filepath: Field,
     pub field_language: Field,
-    pub field_name: Field,
-    pub field_kind: Field,
     pub field_body: Field,
 }
 
@@ -39,23 +29,14 @@ impl CodeSearchSchema {
 
         let code_indexing_options = TextFieldIndexing::default()
             .set_tokenizer(CODE_TOKENIZER)
-            .set_index_option(tantivy::schema::IndexRecordOption::WithFreqsAndPositions);
+            .set_index_option(tantivy::schema::IndexRecordOption::WithFreqs);
         let code_options = TextOptions::default()
             .set_indexing_options(code_indexing_options)
             .set_stored();
 
-        let name_indexing_options = TextFieldIndexing::default()
-            .set_tokenizer(IDENTIFIER_TOKENIZER)
-            .set_index_option(tantivy::schema::IndexRecordOption::WithFreqsAndPositions);
-        let name_options = TextOptions::default()
-            .set_indexing_options(name_indexing_options)
-            .set_stored();
-
         let field_git_url = builder.add_text_field("git_url", STRING | STORED);
         let field_filepath = builder.add_text_field("filepath", STRING | STORED);
         let field_language = builder.add_text_field("language", STRING | STORED);
-        let field_name = builder.add_text_field("name", name_options);
-        let field_kind = builder.add_text_field("kind", STRING | STORED);
         let field_body = builder.add_text_field("body", code_options);
         let schema = builder.build();
 
@@ -64,8 +45,6 @@ impl CodeSearchSchema {
             field_git_url,
             field_filepath,
             field_language,
-            field_name,
-            field_kind,
             field_body,
         }
     }
@@ -87,7 +66,7 @@ impl CodeSearchSchema {
         };
         Box::new(TermQuery::new(
             Term::from_field_text(self.field_language, language),
-            IndexRecordOption::WithFreqsAndPositions,
+            IndexRecordOption::Basic,
         ))
     }
diff --git a/crates/tabby-common/src/lib.rs b/crates/tabby-common/src/lib.rs
index b6abd249986a..948fdde75bba 100644
--- a/crates/tabby-common/src/lib.rs
+++ b/crates/tabby-common/src/lib.rs
@@ -14,7 +14,7 @@ use std::{
     fs::File,
     io::{BufReader, Error},
     ops::Range,
-    path::PathBuf,
+    path::{Path, PathBuf},
 };
 
 use path::dataset_dir;
@@ -47,6 +47,16 @@ impl SourceFile {
         });
         Ok(iter)
     }
+
+    pub fn read_content(&self) -> std::io::Result<String> {
+        let path = Path::new(&self.basedir).join(&self.filepath);
+        std::fs::read_to_string(path)
+    }
+
+    pub fn read_file_size(&self) -> usize {
+        let path = Path::new(&self.basedir).join(&self.filepath);
+        std::fs::metadata(path).map(|x| x.len()).unwrap_or_default() as usize
+    }
 }
 
 #[derive(Serialize, Deserialize, Clone, Debug)]
diff --git a/crates/tabby-scheduler/Cargo.toml b/crates/tabby-scheduler/Cargo.toml
index 53d0c61c6aea..fa2b6d7ab8bf 100644
--- a/crates/tabby-scheduler/Cargo.toml
+++ b/crates/tabby-scheduler/Cargo.toml
@@ -35,6 +35,7 @@ tokio = { workspace = true, features = ["process"] }
 package-lock-json-parser = "0.4.0"
 npm-package-json = "0.1.3"
 yarn-lock-parser = "0.7.0"
+text-splitter = "0.10.0"
 
 [dev-dependencies]
 temp_testdir = { workspace = true }
diff --git a/crates/tabby-scheduler/src/code/mod.rs b/crates/tabby-scheduler/src/code/mod.rs
index bbd28d2d8587..912df9f9d7c8 100644
--- a/crates/tabby-scheduler/src/code/mod.rs
+++ b/crates/tabby-scheduler/src/code/mod.rs
@@ -1,16 +1,19 @@
 use tabby_common::{Point, Tag};
+use text_splitter::{Characters, TextSplitter};
 use tree_sitter_tags::TagsContext;
 
 mod languages;
 
 pub struct CodeIntelligence {
     context: TagsContext,
+    splitter: TextSplitter<Characters>,
 }
 
 impl Default for CodeIntelligence {
     fn default() -> Self {
         Self {
             context: TagsContext::new(),
+            splitter: TextSplitter::default().with_trim_chunks(true),
         }
     }
 }
@@ -49,4 +52,12 @@ impl CodeIntelligence {
             })
             .collect()
     }
+
+    // FIXME(meng): implement with treesitter based CodeSplitter.
+    pub fn chunks<'splitter, 'text: 'splitter>(
+        &'splitter self,
+        text: &'text str,
+    ) -> impl Iterator<Item = &'text str> + 'splitter {
+        self.splitter.chunks(text, 192)
+    }
 }
diff --git a/crates/tabby-scheduler/src/index.rs b/crates/tabby-scheduler/src/index.rs
index c184441f679a..c03ee7273c78 100644
--- a/crates/tabby-scheduler/src/index.rs
+++ b/crates/tabby-scheduler/src/index.rs
@@ -1,4 +1,4 @@
-use std::{fs, io::IsTerminal, ops::Range, path::Path};
+use std::{fs, io::IsTerminal};
 
 use anyhow::Result;
 use kdam::BarExt;
@@ -11,12 +11,11 @@ use tabby_common::{
 use tantivy::{directory::MmapDirectory, doc, Index};
 use tracing::warn;
 
-use crate::utils::tqdm;
+use crate::{code::CodeIntelligence, utils::tqdm};
 
 // Magic numbers
 static MAX_LINE_LENGTH_THRESHOLD: usize = 300;
 static AVG_LINE_LENGTH_THRESHOLD: f32 = 150f32;
-static MAX_BODY_LINES_THRESHOLD: usize = 15;
 
 pub fn index_repositories(_config: &[RepositoryConfig]) -> Result<()> {
     let code = CodeSearchSchema::new();
@@ -30,30 +29,33 @@ pub fn index_repositories(_config: &[RepositoryConfig]) -> Result<()> {
     let mut writer = index.writer(150_000_000)?;
     writer.delete_all_documents()?;
 
+    let total_file_size: usize = SourceFile::all()?
+        .filter(is_valid_file)
+        .map(|x| x.read_file_size())
+        .sum();
+
     let mut pb = std::io::stdout()
         .is_terminal()
-        .then(SourceFile::all)
-        .transpose()?
-        .map(|iter| tqdm(iter.count()));
-    for file in SourceFile::all()? {
-        pb.as_mut().map(|b| b.update(1)).transpose()?;
-
-        if file.max_line_length > MAX_LINE_LENGTH_THRESHOLD {
-            continue;
-        }
+        .then(|| tqdm(total_file_size));
+
+    let intelligence = CodeIntelligence::default();
+    for file in SourceFile::all()?.filter(is_valid_file) {
+        let text = match file.read_content() {
+            Ok(content) => content,
+            Err(e) => {
+                warn!("Failed to read content of '{}': {}", file.filepath, e);
+                continue;
+            }
+        };
 
-        if file.avg_line_length > AVG_LINE_LENGTH_THRESHOLD {
-            continue;
-        }
+        for body in intelligence.chunks(&text) {
+            pb.as_mut().map(|b| b.update(body.len())).transpose()?;
 
-        for doc in from_source_file(file) {
             writer.add_document(doc!(
-                code.field_git_url => doc.git_url,
-                code.field_filepath => doc.filepath,
-                code.field_language => doc.language,
-                code.field_name => doc.name,
-                code.field_body => doc.body,
-                code.field_kind => doc.kind,
+                code.field_git_url => file.git_url.clone(),
+                code.field_filepath => file.filepath.clone(),
+                code.field_language => file.language.clone(),
+                code.field_body => body,
             ))?;
         }
     }
@@ -64,158 +66,7 @@ pub fn index_repositories(_config: &[RepositoryConfig]) -> Result<()> {
     Ok(())
 }
 
-/// Atomic repository document in index.
-struct IndexedDocument {
-    git_url: String,
-    filepath: String,
-    language: String,
-    name: String,
-    body: String,
-    kind: String,
-}
-
-fn read_range(filename: &str, content: &str, range: Range<usize>) -> Option<String> {
-    let Some(content) = content.get(range.clone()) else {
-        warn!("Failed to read content '{range:?}' from '{filename}'");
-        return None;
-    };
-    Some(content.to_string())
-}
-
-fn from_source_file(file: SourceFile) -> impl Iterator<Item = IndexedDocument> {
-    file.tags.into_iter().filter_map(move |tag| {
-        let path = Path::new(&file.basedir).join(&file.filepath);
-        let Ok(file_content) = std::fs::read_to_string(&path) else {
-            warn!("Failed to read file '{}'", path.display());
-            return None;
-        };
-        let name = read_range(&file.filepath, &file_content, tag.name_range)?;
-        let body = read_range(&file.filepath, &file_content, tag.range)?;
-
-        if body.lines().collect::<Vec<_>>().len() > MAX_BODY_LINES_THRESHOLD {
-            return None;
-        }
-
-        Some(IndexedDocument {
-            git_url: file.git_url.clone(),
-            filepath: file.filepath.clone(),
-            language: file.language.clone(),
-            name,
-            body,
-            kind: tag.syntax_type_name,
-        })
-    })
-}
-
-#[cfg(test)]
-mod tests {
-    use serde_json::{from_value, json};
-    use temp_testdir::TempDir;
-
-    use super::*;
-
-    fn test_source_file(basedir: &Path) -> SourceFile {
-        let filepath = "trainer.py";
-        let fullpath = basedir.join(filepath).display().to_string();
-        let file_content = "import os\nimport glob\nfrom dataclasses import dataclass, field\nfrom typing import List\n\nimport peft\nimport torch\nfrom transformers import (\n    AutoModelForCausalLM,\n    AutoTokenizer,\n    HfArgumentParser,\n    Trainer,\n    TrainingArguments,\n)\nfrom datasets import Dataset, load_dataset\n\n\nclass ConstantLengthDataset:\n    \"\"\"\n    Iterable dataset that returns constant length chunks of tokens from stream of text files.\n    Args:\n        tokenizer (Tokenizer): The processor used for proccessing the data.\n        dataset (dataset.Dataset): Dataset with text files.\n        infinite (bool): If True the iterator is reset after dataset reaches end else stops.\n        seq_length (int): Length of token sequences to return.\n        num_of_sequences (int): Number of token sequences to keep in buffer.\n        chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.\n    \"\"\"\n\n    def __init__(\n        self,\n        tokenizer,\n        dataset,\n        infinite=False,\n        seq_length=1024,\n        num_of_sequences=1024,\n        chars_per_token=3.6,\n        content_field=\"content\",\n    ):\n        self.tokenizer = tokenizer\n        self.concat_token_id = tokenizer.eos_token_id\n        self.dataset = dataset\n        self.seq_length = seq_length\n        self.infinite = infinite\n        self.current_size = 0\n        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences\n        self.content_field = content_field\n\n    def __call__(self):\n        def gen():\n            for x in self:\n                yield x\n\n        return gen()\n\n    def __iter__(self):\n        for buffer in self._read_dataset_into_buffer():\n            yield from self._tokenize(buffer)\n\n    def _tokenize(self, buffer):\n        tokenized_inputs = self.tokenizer(buffer, truncation=False)[\"input_ids\"]\n\n        all_token_ids = []\n        for tokenized_input in tokenized_inputs:\n            all_token_ids.extend(tokenized_input + [self.concat_token_id])\n\n        for i in range(0, len(all_token_ids), self.seq_length):\n            input_ids = all_token_ids[i : i + self.seq_length]\n\n            if len(input_ids) < self.seq_length:\n                input_ids = all_token_ids[-self.seq_length :]\n\n            if len(input_ids) == self.seq_length:\n                self.current_size += 1\n                yield dict(input_ids=input_ids, labels=input_ids)\n\n    def _read_dataset_into_buffer(self):\n        iterator = iter(self.dataset)\n        more_examples = True\n        while more_examples:\n            buffer, buffer_len = [], 0\n            while True:\n                if buffer_len >= self.max_buffer_size:\n                    break\n                try:\n                    buffer.append(next(iterator)[self.content_field])\n                    buffer_len += len(buffer[-1])\n                except StopIteration:\n                    if self.infinite:\n                        iterator = iter(self.dataset)\n                    else:\n                        more_examples = False\n                        break\n            yield buffer\n\n\n";
-        std::fs::write(fullpath, file_content).expect("Failed to write test file");
-        from_value(json!(
-            {
-                "git_url": "https://fake.com/tabbyml.git",
-                "basedir": basedir.display().to_string(),
-                "filepath": filepath,
-                "language": "python",
-                "max_line_length": 115,
-                "avg_line_length": 32.388393,
-                "alphanum_fraction": 0.6066319,
-                "tags": [
-                    {
-                        "range": {
-                            "start": 290,
-                            "end": 320
-                        },
-                        "name_range": {
-                            "start": 296,
-                            "end": 317
-                        },
-                        "utf16_column_range": {
-                            "start": 6,
-                            "end": 27
-                        },
-                        "span": {
-                            "start": {
-                                "row": 17,
-                                "column": 6
-                            },
-                            "end": {
-                                "row": 17,
-                                "column": 27
-                            }
-                        },
-                        "line_range": {
-                            "start": 290,
-                            "end": 318
-                        },
-                        "is_definition": true,
-                        "syntax_type_name": "class"
-                    },
-                    {
-                        "range": {
-                            "start": 953,
-                            "end": 970
-                        },
-                        "name_range": {
-                            "start": 957,
-                            "end": 965
-                        },
-                        "utf16_column_range": {
-                            "start": 8,
-                            "end": 16
-                        },
-                        "span": {
-                            "start": {
-                                "row": 29,
-                                "column": 8
-                            },
-                            "end": {
-                                "row": 29,
-                                "column": 16
-                            }
-                        },
-                        "line_range": {
-                            "start": 953,
-                            "end": 966
-                        },
-                        "is_definition": true,
-                        "syntax_type_name": "function"
-                    },
-                ]
-            }))
-        .expect("JSON is valid SourceFile")
-    }
-
-    #[test]
-    fn it_create_documents() {
-        let root = TempDir::default();
-
-        let source_file: SourceFile = test_source_file(&root);
-        let docs: Vec<_> = from_source_file(source_file).collect();
-        assert_eq!(docs.len(), 2);
-
-        assert_eq!(docs[0].name, "ConstantLengthDataset");
-        assert_eq!(docs[0].kind, "class");
-        assert!(
-            docs[0].body.starts_with("class ConstantLengthDataset"),
-            "body: {:?}",
-            docs[0].body
-        );
-
-        assert_eq!(docs[1].name, "__init__");
-        assert_eq!(docs[1].kind, "function");
-        assert!(
-            docs[1].body.starts_with("def __init__"),
-            "body: {:?}",
-            docs[1].body
-        );
-    }
+fn is_valid_file(file: &SourceFile) -> bool {
+    file.max_line_length <= MAX_LINE_LENGTH_THRESHOLD
+        && file.avg_line_length <= AVG_LINE_LENGTH_THRESHOLD
 }
diff --git a/crates/tabby-scheduler/src/lib.rs b/crates/tabby-scheduler/src/lib.rs
index f50ab34ff0cf..02ebbb90f8d7 100644
--- a/crates/tabby-scheduler/src/lib.rs
+++ b/crates/tabby-scheduler/src/lib.rs
@@ -6,10 +6,13 @@ mod index;
 mod repository;
 mod utils;
 
-use std::sync::Arc;
+use std::{fs, sync::Arc};
 
 use anyhow::Result;
-use tabby_common::config::{RepositoryAccess, RepositoryConfig};
+use tabby_common::{
+    config::{RepositoryAccess, RepositoryConfig},
+    path,
+};
 use tokio_cron_scheduler::{Job, JobScheduler};
 use tracing::{error, info, warn};
 
@@ -62,6 +65,13 @@ fn job_index(repositories: &[RepositoryConfig]) -> Result<()> {
     println!("Indexing repositories...");
     let ret = index::index_repositories(repositories);
     if let Err(err) = ret {
+        let index_dir = path::index_dir();
+        warn!(
+            "Failed to index repositories: {}, removing index directory '{}'...",
+            err,
+            index_dir.display()
+        );
+        fs::remove_dir_all(index_dir)?;
         return Err(err.context("Failed to index repositories"));
     }
     Ok(())
diff --git a/crates/tabby/src/services/code.rs b/crates/tabby/src/services/code.rs
index 91846c0eabb7..e05dd373c303 100644
--- a/crates/tabby/src/services/code.rs
+++ b/crates/tabby/src/services/code.rs
@@ -76,8 +76,6 @@ impl CodeSearchImpl {
                     body: get_field(&doc, self.schema.field_body),
                     filepath: get_field(&doc, self.schema.field_filepath),
                     git_url: get_field(&doc, self.schema.field_git_url),
-                    kind: get_field(&doc, self.schema.field_kind),
-                    name: get_field(&doc, self.schema.field_name),
                     language: get_field(&doc, self.schema.field_language),
                 },
                 id: doc_address.doc_id,
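
Usage note: a minimal standalone sketch of the chunking behavior this patch adopts, assuming the text-splitter 0.10 API pinned above. The sample text and the printing loop are illustrative only; the trim setting and the 192-character capacity mirror CodeIntelligence::chunks in the diff.

    // Split a source string the way the scheduler now indexes it:
    // character-counted chunks of at most 192 characters, trimmed of
    // surrounding whitespace (the same setup as CodeIntelligence::default).
    use text_splitter::{Characters, TextSplitter};

    fn main() {
        let splitter: TextSplitter<Characters> =
            TextSplitter::default().with_trim_chunks(true);
        let text = "fn add(a: i32, b: i32) -> i32 {\n    a + b\n}\n";
        // Each yielded chunk would become one Tantivy document in the indexer.
        for (i, chunk) in splitter.chunks(text, 192).enumerate() {
            println!("chunk {i}: {chunk:?}");
        }
    }

Because every chunk is indexed as its own document carrying only git_url, filepath, language, and body, the tag-derived name and kind fields are dropped from CodeSearchSchema, HitDocument, and the search service in the hunks above.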