chore: pass presence_penalty field to completion options (#2280)
wsxiaoys authored May 29, 2024
1 parent 4d43209 commit 11bf840
Showing 4 changed files with 8 additions and 0 deletions.
2 changes: 2 additions & 0 deletions crates/http-api-bindings/src/completion/llama.rs
@@ -30,6 +30,7 @@ struct CompletionRequest {
     temperature: f32,
     stream: bool,
     penalty_last_n: i32,
+    presence_penalty: f32,
 }
 
 #[derive(Deserialize)]
@@ -47,6 +48,7 @@ impl CompletionStream for LlamaCppEngine {
             temperature: options.sampling_temperature,
             stream: true,
             penalty_last_n: 0,
+            presence_penalty: options.presence_penalty,
         };
 
         let mut request = self.client.post(&self.api_endpoint).json(&request);
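
For context, a minimal sketch (assuming serde and serde_json as dependencies; the field set is abbreviated and the prompt field is assumed, this is not Tabby's code) of how a struct like CompletionRequest serializes into the JSON body sent to llama.cpp's /completion endpoint:

use serde::Serialize;

// Abbreviated stand-in for the CompletionRequest in the diff above; only the
// fields visible in this diff plus an assumed `prompt` field are included.
#[derive(Serialize)]
struct CompletionRequest {
    prompt: String,
    temperature: f32,
    stream: bool,
    penalty_last_n: i32,
    presence_penalty: f32,
}

fn main() {
    let request = CompletionRequest {
        prompt: "fn main() {".into(),
        temperature: 0.1,
        stream: true,
        penalty_last_n: 0,
        presence_penalty: 0.5,
    };
    // Prints: {"prompt":"fn main() {","temperature":0.1,"stream":true,"penalty_last_n":0,"presence_penalty":0.5}
    println!("{}", serde_json::to_string(&request).unwrap());
}
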
2 changes: 2 additions & 0 deletions crates/ollama-api-bindings/src/completion.rs
@@ -23,10 +23,12 @@ pub struct OllamaCompletion {
 #[async_trait]
 impl CompletionStream for OllamaCompletion {
     async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
+        // FIXME: options.presence_penalty is not used
         let ollama_options = GenerationOptions::default()
             .num_ctx(options.max_input_length as u32)
             .num_predict(options.max_decoding_tokens)
             .seed(options.seed as i32)
+            .repeat_last_n(0)
             .temperature(options.sampling_temperature);
         let request = GenerationRequest::new(self.model.to_owned(), prompt.to_owned())
             .template("{{ .Prompt }}".to_string())
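
A note on the FIXME above: the Ollama path drops options.presence_penalty, presumably because ollama-rs's GenerationOptions exposes no presence-penalty setter at this version, while repeat_last_n(0) mirrors the penalty_last_n: 0 sent to llama.cpp. A hypothetical sketch of what forwarding would look like if such a setter existed (the commented-out method below is assumed, not a real ollama-rs API):

let ollama_options = GenerationOptions::default()
    .num_ctx(options.max_input_length as u32)
    .num_predict(options.max_decoding_tokens)
    .seed(options.seed as i32)
    .repeat_last_n(0)
    // .presence_penalty(options.presence_penalty) // hypothetical setter, hence the FIXME
    .temperature(options.sampling_temperature);
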
3 changes: 3 additions & 0 deletions crates/tabby-inference/src/completion.rs
@@ -11,6 +11,9 @@ pub struct CompletionOptions {
     pub sampling_temperature: f32,
 
     pub seed: u64,
+
+    #[builder(default = "0.0")]
+    pub presence_penalty: f32,
 }
 
 #[async_trait]
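
The #[builder(default = "0.0")] attribute comes from the derive_builder crate: the string is parsed as a Rust expression and used as the field's value whenever the builder's setter is never called, so existing call sites keep compiling and get no penalty by default. A standalone sketch (field set abbreviated relative to the real struct):

use derive_builder::Builder;

#[derive(Builder)]
pub struct CompletionOptions {
    pub sampling_temperature: f32,
    pub seed: u64,
    // Falls back to 0.0 when .presence_penalty(...) is never called.
    #[builder(default = "0.0")]
    pub presence_penalty: f32,
}

fn main() {
    let opts = CompletionOptionsBuilder::default()
        .sampling_temperature(0.1)
        .seed(42)
        .build()
        .unwrap();
    assert_eq!(opts.presence_penalty, 0.0);
}
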
1 change: 1 addition & 0 deletions crates/tabby/src/services/model/chat.rs
@@ -53,6 +53,7 @@ impl ChatCompletionStream for ChatCompletionImpl {
             .seed(options.seed)
             .max_decoding_tokens(options.max_decoding_tokens)
             .sampling_temperature(options.sampling_temperature)
+            .presence_penalty(options.presence_penalty)
             .build()?;
 
         let prompt = self.prompt_builder.build(messages)?;
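
Background on the option itself: presence penalty, in the OpenAI-style definition that llama.cpp also implements, subtracts a flat constant from the logit of every token that has already appeared in the output, regardless of how often it appeared (frequency penalty is the count-scaled variant). A minimal sketch of that adjustment, independent of any Tabby code:

use std::collections::HashSet;

// Subtract `presence_penalty` from the logit of every token id that has
// already been generated, making repeated tokens less likely to be sampled.
fn apply_presence_penalty(logits: &mut [f32], generated: &[u32], presence_penalty: f32) {
    let seen: HashSet<u32> = generated.iter().copied().collect();
    for &token in &seen {
        if let Some(logit) = logits.get_mut(token as usize) {
            *logit -= presence_penalty;
        }
    }
}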
