Skip to content

Commit

Permalink
chore(completion): replace CRLF with LF for code completion LLM reque…
Browse files Browse the repository at this point in the history
…sts (#3303)

* feat: replease crlf to lf in completions

* [autofix.ci] apply automated fixes

* chore: crlf logic should under service

* [autofix.ci] apply automated fixes

* chore: replace CRLF for prompt instead of segments

* [autofix.ci] apply automated fixes

* chore: use functions and ut for replacing crlf

* chore: rename generated to generated_text

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
zwpaper and autofix-ci[bot] authored Oct 28, 2024
1 parent 7656965 commit 0e75cb2
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 6 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/tabby/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ color-eyre = { version = "0.6.3" }
reqwest.workspace = true
async-openai.workspace = true
spinners = "4.1.1"
regex.workspace = true

[dependencies.openssl]
optional = true
Expand Down
146 changes: 140 additions & 6 deletions crates/tabby/src/services/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod completion_prompt;

use std::sync::Arc;

use regex::Regex;
use serde::{Deserialize, Serialize};
use tabby_common::{
api::{
Expand Down Expand Up @@ -321,9 +322,14 @@ impl CompletionService {
self.config.max_decoding_tokens,
);

let mut use_crlf = false;
let (prompt, segments, snippets) = if let Some(prompt) = request.raw_prompt() {
(prompt, None, vec![])
} else if let Some(segments) = request.segments.as_ref() {
if contains_crlf(segments) {
use_crlf = true;
}

let snippets = self
.build_snippets(
&language,
Expand All @@ -335,24 +341,25 @@ impl CompletionService {
let prompt = self
.prompt_builder
.build(&language, segments.clone(), &snippets);
(prompt, Some(segments), snippets)

(override_prompt(prompt, use_crlf), Some(segments), snippets)
} else {
return Err(CompletionError::EmptyPrompt);
};

let text = self.engine.generate(&prompt, options).await;
let segments = segments.cloned().map(|s| s.into());
let generated_text =
override_generated_text(self.engine.generate(&prompt, options).await, use_crlf);

self.logger.log(
request.user.clone(),
Event::Completion {
completion_id: completion_id.clone(),
language,
prompt: prompt.clone(),
segments,
segments: segments.cloned().map(|x| x.into()),
choices: vec![api::event::Choice {
index: 0,
text: text.clone(),
text: generated_text.clone(),
}],
user_agent: user_agent.map(|x| x.to_owned()),
},
Expand All @@ -368,12 +375,47 @@ impl CompletionService {

Ok(CompletionResponse::new(
completion_id,
vec![Choice::new(text)],
vec![Choice::new(generated_text)],
debug_data,
))
}
}

fn contains_crlf(segments: &Segments) -> bool {
if segments.prefix.contains("\r\n") {
return true;
}
if let Some(suffix) = &segments.suffix {
if suffix.contains("\r\n") {
return true;
}
}

false
}

fn override_prompt(prompt: String, use_crlf: bool) -> String {
if use_crlf {
prompt.replace("\r\n", "\n")
} else {
prompt
}
}

/// override_generated_text replaces \n with \r\n in the generated text if use_crlf is true.
/// This is used to ensure that the generated text has the same line endings as the prompt.
///
/// Because there might be \r\n in the text, which also has a `\n` and should not be replaced,
/// we can not simply replace \n with \r\n.
fn override_generated_text(generated: String, use_crlf: bool) -> String {
if use_crlf {
let re = Regex::new(r"([^\r])\n").unwrap(); // Match \n that is preceded by anything except \r
re.replace_all(&generated, "$1\r\n").to_string() // Replace with captured character and \r\n
} else {
generated
}
}

pub async fn create_completion_service_and_chat(
config: &CompletionConfig,
code: Arc<dyn CodeSearch>,
Expand Down Expand Up @@ -490,4 +532,96 @@ mod tests {
.build("rust", segment.clone(), &[]);
assert_eq!(prompt, "<pre>fn hello_world() -> &'static str {<mid>}<end>");
}

#[test]
fn test_contains_crlf() {
let contained_crlf = vec![
Segments {
prefix: "fn hello_world() -> &'static str {\r\n".into(),
suffix: Some("}".into()),
filepath: None,
git_url: None,
declarations: None,
relevant_snippets_from_changed_files: None,
relevant_snippets_from_recently_opened_files: None,
clipboard: None,
},
Segments {
prefix: "fn hello_world() -> &'static str {".into(),
suffix: Some("}\r\n".into()),
filepath: None,
git_url: None,
declarations: None,
relevant_snippets_from_changed_files: None,
relevant_snippets_from_recently_opened_files: None,
clipboard: None,
},
Segments {
prefix: "fn hello_world() -> &'static str {\r\n".into(),
suffix: Some("}\r\n".into()),
filepath: None,
git_url: None,
declarations: None,
relevant_snippets_from_changed_files: None,
relevant_snippets_from_recently_opened_files: None,
clipboard: None,
},
];
for segments in contained_crlf {
assert!(contains_crlf(&segments));
}

let not_contained_crlf = vec![Segments {
prefix: "fn hello_world() -> &'static str {\r".into(),
suffix: Some("}\n".into()),
filepath: None,
git_url: None,
declarations: None,
relevant_snippets_from_changed_files: None,
relevant_snippets_from_recently_opened_files: None,
clipboard: None,
}];
for segments in not_contained_crlf {
assert!(!contains_crlf(&segments));
}
}

#[test]
fn test_override_prompt() {
let prompt = "fn hello_world() -> &'static str {\r\n".to_string();
let use_crlf = true;
assert_eq!(
override_prompt(prompt.clone(), use_crlf),
"fn hello_world() -> &'static str {\n"
);

let use_crlf = false;
assert_eq!(override_prompt(prompt.clone(), use_crlf), prompt);
}

#[test]
fn test_override_generated() {
let cases = vec![
(
"fn hello_world() -> &'static str {\r\n".to_string(),
"fn hello_world() -> &'static str {\r\n".to_string(),
),
(
"fn hello_world() -> &'static str {\n".to_string(),
"fn hello_world() -> &'static str {\r\n".to_string(),
),
(
"fn hello_world() -> &'static str {\r".to_string(),
"fn hello_world() -> &'static str {\r".to_string(),
),
(
"fn hello_world() -> &'static str {".to_string(),
"fn hello_world() -> &'static str {".to_string(),
),
];

for (generated, expected) in cases {
assert_eq!(override_generated_text(generated, true), expected);
}
}
}

0 comments on commit 0e75cb2

Please sign in to comment.