Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(embedding): improve logging and simplify voyage embedding im… #3600

Merged
merged 2 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions crates/http-api-bindings/src/embedding/llama.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tabby_inference::Embedding;
use tracing::Instrument;

use crate::create_reqwest_client;
use crate::{create_reqwest_client, embedding_info_span};

pub struct LlamaCppEngine {
client: reqwest::Client,
Expand Down Expand Up @@ -44,7 +45,10 @@
request = request.bearer_auth(api_key);
}

let response = request.send().await?;
let response = request
.send()
.instrument(embedding_info_span!("llamacpp"))
.await?;

Check warning on line 51 in crates/http-api-bindings/src/embedding/llama.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/embedding/llama.rs#L48-L51

Added lines #L48 - L51 were not covered by tests
if response.status().is_server_error() {
let error = response.text().await?;
return Err(anyhow::anyhow!(
Expand Down
22 changes: 14 additions & 8 deletions crates/http-api-bindings/src/embedding/mod.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
mod llama;
mod openai;
mod voyage;

use core::panic;
use std::sync::Arc;

use llama::LlamaCppEngine;
use openai::OpenAIEmbeddingEngine;
use tabby_common::config::HttpModelConfig;
use tabby_inference::Embedding;

use self::{openai::OpenAIEmbeddingEngine, voyage::VoyageEmbeddingEngine};
use super::rate_limit;

pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
Expand All @@ -30,16 +29,16 @@
config.api_key.as_deref(),
),
"ollama/embedding" => ollama_api_bindings::create_embedding(config).await,
"voyage/embedding" => VoyageEmbeddingEngine::create(
config.api_endpoint.as_deref(),
"voyage/embedding" => OpenAIEmbeddingEngine::create(
config
.api_endpoint
.as_deref()
.unwrap_or("https://api.voyageai.com/v1"),

Check warning on line 36 in crates/http-api-bindings/src/embedding/mod.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/embedding/mod.rs#L32-L36

Added lines #L32 - L36 were not covered by tests
config
.model_name
.as_deref()
.expect("model_name must be set for voyage/embedding"),
config
.api_key
.clone()
.expect("api_key must be set for voyage/embedding"),
config.api_key.as_deref(),

Check warning on line 41 in crates/http-api-bindings/src/embedding/mod.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/embedding/mod.rs#L41

Added line #L41 was not covered by tests
),
unsupported_kind => panic!(
"Unsupported kind for http embedding model: {}",
Expand All @@ -52,3 +51,10 @@
config.rate_limit.request_per_minute,
))
}

#[macro_export]
macro_rules! embedding_info_span {
($kind:expr) => {
tracing::info_span!("embedding", kind = $kind)
};
}
91 changes: 67 additions & 24 deletions crates/http-api-bindings/src/embedding/openai.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use anyhow::Context;
use async_openai::{
config::OpenAIConfig,
types::{CreateEmbeddingRequest, EmbeddingInput},
};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tabby_inference::Embedding;
use tracing::{info_span, Instrument};
use tracing::Instrument;

use crate::embedding_info_span;

pub struct OpenAIEmbeddingEngine {
client: async_openai::Client<OpenAIConfig>,
client: Client,
api_endpoint: String,
api_key: String,
model_name: String,
}

Expand All @@ -18,41 +20,69 @@
model_name: &str,
api_key: Option<&str>,
) -> Box<dyn Embedding> {
let config = OpenAIConfig::default()
.with_api_base(api_endpoint)
.with_api_key(api_key.unwrap_or_default());

let client = async_openai::Client::with_config(config);

let client = Client::new();

Check warning on line 23 in crates/http-api-bindings/src/embedding/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/embedding/openai.rs#L23

Added line #L23 was not covered by tests
Box::new(Self {
client,
api_endpoint: format!("{}/embeddings", api_endpoint),
api_key: api_key.unwrap_or_default().to_owned(),

Check warning on line 27 in crates/http-api-bindings/src/embedding/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/embedding/openai.rs#L26-L27

Added lines #L26 - L27 were not covered by tests
model_name: model_name.to_owned(),
})
}
}

#[derive(Debug, Serialize)]
struct EmbeddingRequest {
input: Vec<String>,
model: String,
}

#[derive(Debug, Deserialize)]

Check warning on line 39 in crates/http-api-bindings/src/embedding/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/embedding/openai.rs#L39

Added line #L39 was not covered by tests
struct EmbeddingResponse {
data: Vec<EmbeddingData>,
}

#[derive(Debug, Deserialize)]

Check warning on line 44 in crates/http-api-bindings/src/embedding/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/embedding/openai.rs#L44

Added line #L44 was not covered by tests
struct EmbeddingData {
embedding: Vec<f32>,
}

#[async_trait]
impl Embedding for OpenAIEmbeddingEngine {
async fn embed(&self, prompt: &str) -> anyhow::Result<Vec<f32>> {
let request = CreateEmbeddingRequest {
let request = EmbeddingRequest {
input: vec![prompt.to_owned()],

Check warning on line 53 in crates/http-api-bindings/src/embedding/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/embedding/openai.rs#L52-L53

Added lines #L52 - L53 were not covered by tests
model: self.model_name.clone(),
input: EmbeddingInput::String(prompt.to_owned()),
encoding_format: None,
user: None,
dimensions: None,
};
let resp = self

let request_builder = self

Check warning on line 57 in crates/http-api-bindings/src/embedding/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/embedding/openai.rs#L56-L57

Added lines #L56 - L57 were not covered by tests
.client
.embeddings()
.create(request)
.instrument(info_span!("embedding", kind = "openai"))
.post(&self.api_endpoint)
.json(&request)
.header("content-type", "application/json")
.bearer_auth(&self.api_key);

Check warning on line 62 in crates/http-api-bindings/src/embedding/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/embedding/openai.rs#L59-L62

Added lines #L59 - L62 were not covered by tests

let response = request_builder
.send()
.instrument(embedding_info_span!("openai"))

Check warning on line 66 in crates/http-api-bindings/src/embedding/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/embedding/openai.rs#L64-L66

Added lines #L64 - L66 were not covered by tests
.await?;
let data = resp

if !response.status().is_success() {
let status = response.status();
let error = response.text().await?;
return Err(anyhow::anyhow!("Error {}: {}", status.as_u16(), error));
}

Check warning on line 73 in crates/http-api-bindings/src/embedding/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/embedding/openai.rs#L69-L73

Added lines #L69 - L73 were not covered by tests

let response_body = response
.json::<EmbeddingResponse>()
.await
.context("Failed to parse response body")?;

Check warning on line 78 in crates/http-api-bindings/src/embedding/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/embedding/openai.rs#L75-L78

Added lines #L75 - L78 were not covered by tests

response_body

Check warning on line 80 in crates/http-api-bindings/src/embedding/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/embedding/openai.rs#L80

Added line #L80 was not covered by tests
.data
.into_iter()
.next()
.context("Failed to get embedding")?;
Ok(data.embedding)
.map(|data| data.embedding)
.ok_or_else(|| anyhow::anyhow!("No embedding data found"))

Check warning on line 85 in crates/http-api-bindings/src/embedding/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/embedding/openai.rs#L84-L85

Added lines #L84 - L85 were not covered by tests
}
}

Expand All @@ -73,4 +103,17 @@
let embedding = engine.embed("Hello, world!").await.unwrap();
assert_eq!(embedding.len(), 768);
}

#[tokio::test]
#[ignore]
async fn test_voyage_embedding() {
let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY must be set");
let engine = OpenAIEmbeddingEngine::create(
"https://api.voyageai.com/v1",
"voyage-code-2",
Some(&api_key),
);
let embedding = engine.embed("Hello, world!").await.unwrap();
assert_eq!(embedding.len(), 1536);
}

Check warning on line 118 in crates/http-api-bindings/src/embedding/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/embedding/openai.rs#L109-L118

Added lines #L109 - L118 were not covered by tests
}
98 changes: 0 additions & 98 deletions crates/http-api-bindings/src/embedding/voyage.rs

This file was deleted.

Loading