diff --git a/crates/http-api-bindings/src/chat/openai_chat.rs b/crates/http-api-bindings/src/chat/openai_chat.rs index 989dd72525f6..769112dba7f9 100644 --- a/crates/http-api-bindings/src/chat/openai_chat.rs +++ b/crates/http-api-bindings/src/chat/openai_chat.rs @@ -52,6 +52,7 @@ impl ChatCompletionStream for OpenAIChatEngine { .max_tokens(options.max_decoding_tokens as u16) .model(&self.model_name) .temperature(options.sampling_temperature) + .presence_penalty(options.presence_penalty) .stream(true) .messages( serde_json::from_value::>(serde_json::to_value( diff --git a/crates/tabby-common/src/index/code/mod.rs b/crates/tabby-common/src/index/code/mod.rs index 0f8646880c9e..d6111c3ebf0c 100644 --- a/crates/tabby-common/src/index/code/mod.rs +++ b/crates/tabby-common/src/index/code/mod.rs @@ -3,7 +3,7 @@ mod document; use lazy_static::lazy_static; use regex::Regex; use tantivy::{ - query::{BooleanQuery, ConstScoreQuery, Occur, Query, TermQuery}, + query::{BooleanQuery, BoostQuery, ConstScoreQuery, Occur, Query, TermQuery}, schema::{Field, IndexRecordOption, Schema, TextFieldIndexing, TextOptions, STORED, STRING}, tokenizer::{RegexTokenizer, RemoveLongFilter, TextAnalyzer}, Index, Term, @@ -119,16 +119,21 @@ impl CodeSearchSchema { pub fn code_search_query(&self, query: &CodeSearchQuery) -> BooleanQuery { let language_query = self.language_query(&query.language); - let body_query = self.body_query(&CodeSearchSchema::tokenize_body(&query.content)); let git_url_query = self.git_url_query(&query.git_url); + // Create body query with a scoring normalized by the number of tokens. + let body_tokens = CodeSearchSchema::tokenize_body(&query.content); + let body_query = self.body_query(&body_tokens); + let normalized_score_body_query = + BoostQuery::new(body_query, 1.0 / body_tokens.len() as f32); + // language / git_url / filepath field shouldn't contribute to the score, mark them to 0.0. let mut subqueries: Vec<(Occur, Box)> = vec![ ( Occur::Must, Box::new(ConstScoreQuery::new(language_query, 0.0)), ), - (Occur::Must, body_query), + (Occur::Must, Box::new(normalized_score_body_query)), ( Occur::Must, Box::new(ConstScoreQuery::new(git_url_query, 0.0)), diff --git a/crates/tabby-inference/src/chat.rs b/crates/tabby-inference/src/chat.rs index 11ff922f7632..4006f5750fdd 100644 --- a/crates/tabby-inference/src/chat.rs +++ b/crates/tabby-inference/src/chat.rs @@ -14,6 +14,9 @@ pub struct ChatCompletionOptions { #[builder(default = "1920")] pub max_decoding_tokens: i32, + + #[builder(default = "0.0")] + pub presence_penalty: f32, } #[async_trait] diff --git a/crates/tabby/src/services/answer.rs b/crates/tabby/src/services/answer.rs index e898a3297f01..2fa60bdc6812 100644 --- a/crates/tabby/src/services/answer.rs +++ b/crates/tabby/src/services/answer.rs @@ -47,6 +47,10 @@ pub struct AnswerService { serper: Option>, } +// FIXME(meng): make this configurable. +const RELEVANT_CODE_THRESHOLD: f32 = 5.5; +const PRESENCE_PENALTY: f32 = 0.5; + impl AnswerService { fn new( chat: Arc, @@ -84,6 +88,7 @@ impl AnswerService { // 1. Collect relevant code if needed. let relevant_code = if let Some(code_query) = req.code_query { + self.override_query_with_code_query(query, &code_query).await; self.collect_relevant_code(code_query).await } else { vec![] @@ -121,6 +126,7 @@ impl AnswerService { // 5. Generate answer from the query let s = self.chat.clone().generate(ChatCompletionRequestBuilder::default() .messages(req.messages.clone()) + .presence_penalty(PRESENCE_PENALTY) .build() .expect("Failed to create ChatCompletionRequest")) .await; @@ -134,17 +140,22 @@ impl AnswerService { } async fn collect_relevant_code(&self, query: CodeSearchQuery) -> Vec { - match self.code.search_in_language(query, 5, 0).await { - Ok(resp) => resp.hits.into_iter().map(|hit| hit.doc).collect(), + let hits = match self.code.search_in_language(query, 5, 0).await { + Ok(docs) => docs.hits, Err(err) => { if let CodeSearchError::NotReady = err { - warn!("Failed to search code: {:?}", err); - } else { debug!("Code search is not ready yet"); + } else { + warn!("Failed to search code: {:?}", err); } vec![] } - } + }; + + hits.into_iter() + .filter(|hit| hit.score > RELEVANT_CODE_THRESHOLD) + .map(|hit| hit.doc) + .collect() } async fn collect_relevant_docs(&self, query: &str) -> Vec { @@ -233,6 +244,17 @@ Remember, based on the original question and related contexts, suggest three suc content.lines().map(remove_bullet_prefix).collect() } + async fn override_query_with_code_query( + &self, + query: &mut Message, + code_query: &CodeSearchQuery, + ) { + query.content = format!( + "{}\n\n```{}\n{}\n```", + query.content, code_query.language, code_query.content + ) + } + async fn generate_prompt( &self, relevant_code: &[CodeSearchDocument], @@ -271,6 +293,7 @@ Here are the set of contexts: {context} Remember, don't blindly repeat the contexts verbatim. When possible, give code snippet to demonstrate the answer. And here is the user question: + {question} "# ) diff --git a/crates/tabby/src/services/chat.rs b/crates/tabby/src/services/chat.rs index b1947863d4be..b48128c04cd0 100644 --- a/crates/tabby/src/services/chat.rs +++ b/crates/tabby/src/services/chat.rs @@ -34,6 +34,9 @@ pub struct ChatCompletionRequest { #[builder(default = "None")] seed: Option, + + #[builder(default = "0.0")] + presence_penalty: f32, } #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] @@ -104,6 +107,7 @@ impl ChatService { builder.seed(*x); }); builder + .presence_penalty(request.presence_penalty) .build() .expect("Failed to create ChatCompletionOptions") };