Skip to content

Commit

Permalink
refactor: filter codesearch result by score, normalize score by number of tokens (#2209)
Browse files Browse the repository at this point in the history

* refactor: filter codesearch result by score, normalize score by number of tokens

* update

* [autofix.ci] apply automated fixes

* update

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
wsxiaoys and autofix-ci[bot] authored May 21, 2024
1 parent 1886385 commit 5655934
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 8 deletions.
1 change: 1 addition & 0 deletions crates/http-api-bindings/src/chat/openai_chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ impl ChatCompletionStream for OpenAIChatEngine {
.max_tokens(options.max_decoding_tokens as u16)
.model(&self.model_name)
.temperature(options.sampling_temperature)
.presence_penalty(options.presence_penalty)
.stream(true)
.messages(
serde_json::from_value::<Vec<ChatCompletionRequestMessage>>(serde_json::to_value(
Expand Down
11 changes: 8 additions & 3 deletions crates/tabby-common/src/index/code/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ mod document;
use lazy_static::lazy_static;
use regex::Regex;
use tantivy::{
query::{BooleanQuery, ConstScoreQuery, Occur, Query, TermQuery},
query::{BooleanQuery, BoostQuery, ConstScoreQuery, Occur, Query, TermQuery},
schema::{Field, IndexRecordOption, Schema, TextFieldIndexing, TextOptions, STORED, STRING},
tokenizer::{RegexTokenizer, RemoveLongFilter, TextAnalyzer},
Index, Term,
Expand Down Expand Up @@ -119,16 +119,21 @@ impl CodeSearchSchema {

pub fn code_search_query(&self, query: &CodeSearchQuery) -> BooleanQuery {
let language_query = self.language_query(&query.language);
let body_query = self.body_query(&CodeSearchSchema::tokenize_body(&query.content));
let git_url_query = self.git_url_query(&query.git_url);

// Create body query with a scoring normalized by the number of tokens.
let body_tokens = CodeSearchSchema::tokenize_body(&query.content);
let body_query = self.body_query(&body_tokens);
let normalized_score_body_query =
BoostQuery::new(body_query, 1.0 / body_tokens.len() as f32);

// language / git_url / filepath field shouldn't contribute to the score, mark them to 0.0.
let mut subqueries: Vec<(Occur, Box<dyn Query>)> = vec![
(
Occur::Must,
Box::new(ConstScoreQuery::new(language_query, 0.0)),
),
(Occur::Must, body_query),
(Occur::Must, Box::new(normalized_score_body_query)),
(
Occur::Must,
Box::new(ConstScoreQuery::new(git_url_query, 0.0)),
Expand Down
3 changes: 3 additions & 0 deletions crates/tabby-inference/src/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ pub struct ChatCompletionOptions {

#[builder(default = "1920")]
pub max_decoding_tokens: i32,

#[builder(default = "0.0")]
pub presence_penalty: f32,
}

#[async_trait]
Expand Down
33 changes: 28 additions & 5 deletions crates/tabby/src/services/answer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ pub struct AnswerService {
serper: Option<Box<dyn DocSearch>>,
}

// FIXME(meng): make this configurable.
const RELEVANT_CODE_THRESHOLD: f32 = 5.5;
const PRESENCE_PENALTY: f32 = 0.5;

impl AnswerService {
fn new(
chat: Arc<ChatService>,
Expand Down Expand Up @@ -84,6 +88,7 @@ impl AnswerService {

// 1. Collect relevant code if needed.
let relevant_code = if let Some(code_query) = req.code_query {
self.override_query_with_code_query(query, &code_query).await;
self.collect_relevant_code(code_query).await
} else {
vec![]
Expand Down Expand Up @@ -121,6 +126,7 @@ impl AnswerService {
// 5. Generate answer from the query
let s = self.chat.clone().generate(ChatCompletionRequestBuilder::default()
.messages(req.messages.clone())
.presence_penalty(PRESENCE_PENALTY)
.build()
.expect("Failed to create ChatCompletionRequest"))
.await;
Expand All @@ -134,17 +140,22 @@ impl AnswerService {
}

/// Searches the code index for snippets relevant to `query` and returns only
/// sufficiently relevant hits.
///
/// Fetches up to 5 hits, then keeps those whose score exceeds
/// `RELEVANT_CODE_THRESHOLD` (scores are normalized by token count on the
/// index side). On search failure the function degrades to an empty result:
/// `NotReady` is expected during startup and logged at debug level, while any
/// other error is logged as a warning.
async fn collect_relevant_code(&self, query: CodeSearchQuery) -> Vec<CodeSearchDocument> {
    let hits = match self.code.search_in_language(query, 5, 0).await {
        Ok(docs) => docs.hits,
        Err(err) => {
            if let CodeSearchError::NotReady = err {
                // Index not built yet — a normal transient state, not a failure.
                debug!("Code search is not ready yet");
            } else {
                warn!("Failed to search code: {:?}", err);
            }
            vec![]
        }
    };

    // Drop low-scoring hits so irrelevant snippets don't pollute the prompt.
    hits.into_iter()
        .filter(|hit| hit.score > RELEVANT_CODE_THRESHOLD)
        .map(|hit| hit.doc)
        .collect()
}

async fn collect_relevant_docs(&self, query: &str) -> Vec<DocSearchDocument> {
Expand Down Expand Up @@ -233,6 +244,17 @@ Remember, based on the original question and related contexts, suggest three suc
content.lines().map(remove_bullet_prefix).collect()
}

/// Appends the selected code snippet, fenced with its language tag, to the
/// user's question so the answer generation can reference the code directly.
async fn override_query_with_code_query(
    &self,
    query: &mut Message,
    code_query: &CodeSearchQuery,
) {
    // Build the fenced block first, then attach it after a blank line.
    let fenced_snippet = format!("```{}\n{}\n```", code_query.language, code_query.content);
    query.content = format!("{}\n\n{}", query.content, fenced_snippet);
}

async fn generate_prompt(
&self,
relevant_code: &[CodeSearchDocument],
Expand Down Expand Up @@ -271,6 +293,7 @@ Here are the set of contexts:
{context}
Remember, don't blindly repeat the contexts verbatim. When possible, give code snippet to demonstrate the answer. And here is the user question:
{question}
"#
)
Expand Down
4 changes: 4 additions & 0 deletions crates/tabby/src/services/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ pub struct ChatCompletionRequest {

#[builder(default = "None")]
seed: Option<u64>,

#[builder(default = "0.0")]
presence_penalty: f32,
}

#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
Expand Down Expand Up @@ -104,6 +107,7 @@ impl ChatService {
builder.seed(*x);
});
builder
.presence_penalty(request.presence_penalty)
.build()
.expect("Failed to create ChatCompletionOptions")
};
Expand Down

0 comments on commit 5655934

Please sign in to comment.