Skip to content

Commit

Permalink
refactor(openai): mark suffix support legacy, use prompt as default
Browse files Browse the repository at this point in the history
  • Loading branch information
zwpaper committed Oct 29, 2024
1 parent a18e217 commit 638b6cd
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 16 deletions.
11 changes: 4 additions & 7 deletions crates/http-api-bindings/src/completion/mistral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use reqwest_eventsource::{Event, EventSource};
use serde::{Deserialize, Serialize};
use tabby_inference::{CompletionOptions, CompletionStream};

use super::FIM_TOKEN;
use super::split_fim_prompt;

pub struct MistralFIMEngine {
client: reqwest::Client,
Expand Down Expand Up @@ -68,13 +68,10 @@ struct FIMResponseDelta {
#[async_trait]
impl CompletionStream for MistralFIMEngine {
async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
let parts = prompt.splitn(2, FIM_TOKEN).collect::<Vec<_>>();
let (prompt, suffix) = split_fim_prompt(prompt, true);

Check warning on line 71 in crates/http-api-bindings/src/completion/mistral.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/mistral.rs#L71

Added line #L71 was not covered by tests
let request = FIMRequest {
prompt: parts[0].to_owned(),
suffix: parts
.get(1)
.map(|x| x.to_string())
.filter(|x| !x.is_empty()),
prompt: prompt.to_owned(),
suffix: suffix.map(|x| x.to_owned()),

Check warning on line 74 in crates/http-api-bindings/src/completion/mistral.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/mistral.rs#L73-L74

Added lines #L73 - L74 were not covered by tests
model: self.model_name.clone(),
max_tokens: options.max_decoding_tokens,
temperature: options.sampling_temperature,
Expand Down
61 changes: 60 additions & 1 deletion crates/http-api-bindings/src/completion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,27 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
);
Arc::new(engine)
}
"openai/completion" => {
"openai/completion" | "openai/legacy_completion" => {

Check warning on line 34 in crates/http-api-bindings/src/completion/mod.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/mod.rs#L34

Added line #L34 was not covered by tests
let engine = OpenAICompletionEngine::create(
model.model_name.clone(),
model
.api_endpoint
.as_deref()
.expect("api_endpoint is required"),
model.api_key.clone(),
true,
);
Arc::new(engine)

Check warning on line 44 in crates/http-api-bindings/src/completion/mod.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/mod.rs#L42-L44

Added lines #L42 - L44 were not covered by tests
}
"openai/legacy_completion_no_fim" => {
let engine = OpenAICompletionEngine::create(
model.model_name.clone(),
model
.api_endpoint
.as_deref()
.expect("api_endpoint is required"),
model.api_key.clone(),
false,

Check warning on line 54 in crates/http-api-bindings/src/completion/mod.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/mod.rs#L46-L54

Added lines #L46 - L54 were not covered by tests
);
Arc::new(engine)
}
Expand All @@ -59,3 +72,49 @@ pub fn build_completion_prompt(model: &HttpModelConfig) -> (Option<String>, Opti
(model.prompt_template.clone(), model.chat_template.clone())
}
}

/// Split a completion prompt on the FIM token into `(prefix, suffix)`.
///
/// When `support_fim` is true the backend understands the FIM token
/// natively, so the prompt is returned untouched and no suffix is
/// extracted. Otherwise the prompt is split on the first occurrence of
/// `FIM_TOKEN`: the text before the token becomes the prompt and the
/// text after it (possibly empty) becomes the suffix. A prompt that
/// contains no FIM token yields `None` for the suffix.
fn split_fim_prompt(prompt: &str, support_fim: bool) -> (&str, Option<&str>) {
    if support_fim {
        // The backend handles the FIM token itself; pass the prompt through.
        return (prompt, None);
    }

    // Consume the splitn iterator directly instead of collecting into a Vec.
    let mut parts = prompt.splitn(2, FIM_TOKEN);
    let prefix = parts
        .next()
        .expect("splitn always yields at least one element");
    (prefix, parts.next())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// With FIM support the prompt must pass through untouched and no
    /// suffix may be extracted, whether or not it contains the FIM token.
    #[test]
    fn test_split_fim_prompt() {
        let support_fim = [
            "prefix<|FIM|>suffix",
            "prefix<|FIM|>",
            "<|FIM|>suffix",
            "<|FIM|>",
            "prefix",
        ];
        for input in support_fim {
            let (prompt, suffix) = split_fim_prompt(input, true);
            assert_eq!(prompt, input);
            assert!(suffix.is_none());
        }
    }

    /// Without FIM support the prompt is split on the first FIM token.
    /// Note that an empty trailing part is preserved as `Some("")`, and a
    /// prompt without the token yields `None`.
    #[test]
    fn test_split_fim_prompt_no_fim() {
        let no_fim = [
            ("prefix<|FIM|>suffix", ("prefix", Some("suffix"))),
            ("prefix<|FIM|>", ("prefix", Some(""))),
            ("<|FIM|>suffix", ("", Some("suffix"))),
            ("<|FIM|>", ("", Some(""))),
            ("prefix", ("prefix", None)),
        ];
        for (input, expected) in no_fim {
            assert_eq!(split_fim_prompt(input, false), expected);
        }
    }
}
24 changes: 16 additions & 8 deletions crates/http-api-bindings/src/completion/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,27 @@ use reqwest_eventsource::{Event, EventSource};
use serde::{Deserialize, Serialize};
use tabby_inference::{CompletionOptions, CompletionStream};

use super::FIM_TOKEN;
use super::split_fim_prompt;

pub struct OpenAICompletionEngine {
client: reqwest::Client,
model_name: String,
api_endpoint: String,
api_key: Option<String>,

    /// The OpenAI Completion API only uses the request's `suffix` field when the
    /// model does not support FIM natively. `support_fim` marks whether FIM is
    /// supported; the `openai/legacy_completion_no_fim` backend sets it to false
    /// so that the `suffix` field is used instead.
support_fim: bool,
}

impl OpenAICompletionEngine {
pub fn create(model_name: Option<String>, api_endpoint: &str, api_key: Option<String>) -> Self {
pub fn create(
model_name: Option<String>,
api_endpoint: &str,
api_key: Option<String>,
support_fim: bool,
) -> Self {

Check warning on line 28 in crates/http-api-bindings/src/completion/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/openai.rs#L23-L28

Added lines #L23 - L28 were not covered by tests
let model_name = model_name.expect("model_name is required for openai/completion");
let client = reqwest::Client::new();

Expand All @@ -24,6 +34,7 @@ impl OpenAICompletionEngine {
model_name,
api_endpoint: format!("{}/completions", api_endpoint),
api_key,
support_fim,

Check warning on line 37 in crates/http-api-bindings/src/completion/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/openai.rs#L37

Added line #L37 was not covered by tests
}
}
}
Expand Down Expand Up @@ -53,14 +64,11 @@ struct CompletionResponseChoice {
#[async_trait]
impl CompletionStream for OpenAICompletionEngine {
async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
let parts = prompt.splitn(2, FIM_TOKEN).collect::<Vec<_>>();
let (prompt, suffix) = split_fim_prompt(prompt, self.support_fim);

Check warning on line 67 in crates/http-api-bindings/src/completion/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/openai.rs#L67

Added line #L67 was not covered by tests
let request = CompletionRequest {
model: self.model_name.clone(),
prompt: parts[0].to_owned(),
suffix: parts
.get(1)
.map(|x| x.to_string())
.filter(|x| !x.is_empty()),
prompt: prompt.to_owned(),
suffix: suffix.map(|x| x.to_owned()),

Check warning on line 71 in crates/http-api-bindings/src/completion/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/openai.rs#L70-L71

Added lines #L70 - L71 were not covered by tests
max_tokens: options.max_decoding_tokens,
temperature: options.sampling_temperature,
stream: true,
Expand Down

0 comments on commit 638b6cd

Please sign in to comment.