feat(backend): adding Azure AI Foundry OpenAPI support #3683

Merged · 6 commits · Jan 15, 2025
64 changes: 38 additions & 26 deletions crates/http-api-bindings/src/chat/mod.rs
@@ -5,38 +5,50 @@ use tabby_common::config::HttpModelConfig;
 use tabby_inference::{ChatCompletionStream, ExtendedOpenAIConfig};
 
 use super::rate_limit;
-use crate::create_reqwest_client;
+use crate::{create_reqwest_client, AZURE_API_VERSION};
 
 pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
     let api_endpoint = model
         .api_endpoint
         .as_deref()
         .expect("api_endpoint is required");
-    let config = OpenAIConfig::default()
-        .with_api_base(api_endpoint)
-        .with_api_key(model.api_key.clone().unwrap_or_default());
-
-    let mut builder = ExtendedOpenAIConfig::builder();
-
-    builder
-        .base(config)
-        .supported_models(model.supported_models.clone())
-        .model_name(model.model_name.as_deref().expect("Model name is required"));
-
-    if model.kind == "openai/chat" {
-        // Do nothing
-    } else if model.kind == "mistral/chat" {
-        builder.fields_to_remove(ExtendedOpenAIConfig::mistral_fields_to_remove());
-    } else {
-        panic!("Unsupported model kind: {}", model.kind);
-    }
-
-    let config = builder.build().expect("Failed to build config");
-
-    let engine = Box::new(
-        async_openai_alt::Client::with_config(config)
-            .with_http_client(create_reqwest_client(api_endpoint)),
-    );
+
+    let engine: Box<dyn ChatCompletionStream> = match model.kind.as_str() {
+        "azure/chat" => {
+            let config = async_openai_alt::config::AzureConfig::new()
+                .with_api_base(api_endpoint)
+                .with_api_key(model.api_key.clone().unwrap_or_default())
+                .with_api_version(AZURE_API_VERSION.to_string())
+                .with_deployment_id(model.model_name.as_deref().expect("Model name is required"));
+            Box::new(
+                async_openai_alt::Client::with_config(config)
+                    .with_http_client(create_reqwest_client(api_endpoint)),
+            )
+        }
+        "openai/chat" | "mistral/chat" => {
+            let config = OpenAIConfig::default()
+                .with_api_base(api_endpoint)
+                .with_api_key(model.api_key.clone().unwrap_or_default());
+
+            let mut builder = ExtendedOpenAIConfig::builder();
+            builder
+                .base(config)
+                .supported_models(model.supported_models.clone())
+                .model_name(model.model_name.as_deref().expect("Model name is required"));
+
+            if model.kind == "mistral/chat" {
+                builder.fields_to_remove(ExtendedOpenAIConfig::mistral_fields_to_remove());
+            }
+
+            Box::new(
+                async_openai_alt::Client::with_config(
+                    builder.build().expect("Failed to build config"),
+                )
+                .with_http_client(create_reqwest_client(api_endpoint)),
+            )
+        }
+        _ => panic!("Unsupported model kind: {}", model.kind),
+    };
 
     Arc::new(rate_limit::new_chat(
         engine,
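For reference, a minimal config sketch for the new `azure/chat` kind. The table name mirrors Tabby's existing HTTP model configuration, and every value below is a placeholder:

# ~/.tabby/config.toml — hypothetical values throughout
[model.chat.http]
kind = "azure/chat"
api_endpoint = "https://my-resource.openai.azure.com"  # Azure OpenAI resource endpoint
model_name = "gpt-4o-deployment"                       # used as the Azure deployment id
api_key = "<azure-api-key>"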
123 changes: 123 additions & 0 deletions crates/http-api-bindings/src/embedding/azure.rs
@@ -0,0 +1,123 @@
use std::sync::Arc;

use anyhow::Result;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tabby_inference::Embedding;

use crate::AZURE_API_VERSION;

/// `AzureEmbeddingEngine` is responsible for interacting with the Azure OpenAI
/// embeddings API.
///
/// **Note**: Currently, this implementation only supports Azure OpenAI
/// deployments, and the API version is pinned to `AZURE_API_VERSION`.
#[derive(Clone)]
pub struct AzureEmbeddingEngine {
    client: Arc<Client>,
    api_endpoint: String,
    api_key: String,
}

/// Structure representing the request body for embedding.
#[derive(Debug, Serialize)]
struct EmbeddingRequest {
    input: String,
}

/// Structure representing the response from the embedding API.
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
    data: Vec<Data>,
}

/// Structure representing individual embedding data.
#[derive(Debug, Deserialize)]
struct Data {
    embedding: Vec<f32>,
}

impl AzureEmbeddingEngine {
    /// Creates a new instance of `AzureEmbeddingEngine`.
    ///
    /// **Note**: Currently, this implementation only supports Azure OpenAI
    /// deployments, and the API version is pinned to `AZURE_API_VERSION`.
    ///
    /// # Parameters
    ///
    /// - `api_endpoint`: The base URL of the Azure OpenAI resource.
    /// - `model_name`: The name of the deployed model, used to construct the deployment ID.
    /// - `api_key`: Optional API key for authentication.
    ///
    /// # Returns
    ///
    /// A boxed instance that implements the `Embedding` trait.
    pub fn create(
        api_endpoint: &str,
        model_name: &str,
        api_key: Option<&str>,
    ) -> Box<dyn Embedding> {
        let client = Client::new();
        let deployment_id = model_name;
        // Construct the full endpoint URL for the Azure Embedding API
        let azure_endpoint = format!(
            "{}/openai/deployments/{}/embeddings",
            api_endpoint.trim_end_matches('/'),
            deployment_id
        );

        Box::new(Self {
            client: Arc::new(client),
            api_endpoint: azure_endpoint,
            api_key: api_key.unwrap_or_default().to_owned(),
        })
    }
}

#[async_trait]
impl Embedding for AzureEmbeddingEngine {
    /// Generates an embedding vector for the given prompt.
    ///
    /// # Parameters
    ///
    /// - `prompt`: The input text to generate embeddings for.
    ///
    /// # Returns
    ///
    /// A `Result` containing the embedding vector or an error.
    async fn embed(&self, prompt: &str) -> Result<Vec<f32>> {
        // Clone all necessary fields to ensure thread safety across await points
        let api_endpoint = self.api_endpoint.clone();
        let api_key = self.api_key.clone();
        let api_version = AZURE_API_VERSION.to_string();
        let request = EmbeddingRequest {
            input: prompt.to_owned(),
        };

        // Send a POST request to the Azure Embedding API
        let response = self
            .client
            .post(&api_endpoint)
            .query(&[("api-version", &api_version)])
            .header("api-key", &api_key)
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        // Surface non-2xx responses as errors with the response body attached
        if !response.status().is_success() {
            let error_text = response.text().await?;
            anyhow::bail!("Azure API error: {}", error_text);
        }

        // Deserialize the response body and return the first embedding vector
        let embedding_response: EmbeddingResponse = response.json().await?;
        embedding_response
            .data
            .first()
            .map(|data| data.embedding.clone())
            .ok_or_else(|| anyhow::anyhow!("No embedding data received"))
    }
}
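A minimal usage sketch of the engine above, assuming an enclosing async context that returns `anyhow::Result`; the endpoint and deployment id are placeholders:

// Sketch only — not part of this diff.
let engine = AzureEmbeddingEngine::create(
    "https://my-resource.openai.azure.com", // hypothetical Azure OpenAI resource
    "text-embedding-3-small",               // hypothetical deployment id
    Some("<azure-api-key>"),
);
// Resulting call: POST {endpoint}/openai/deployments/{id}/embeddings?api-version=2024-02-01
let vector = engine.embed("fn main() {}").await?;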
10 changes: 10 additions & 0 deletions crates/http-api-bindings/src/embedding/mod.rs
@@ -1,9 +1,11 @@
+mod azure;
 mod llama;
 mod openai;
 
 use core::panic;
 use std::sync::Arc;
 
+use azure::AzureEmbeddingEngine;
 use llama::LlamaCppEngine;
 use openai::OpenAIEmbeddingEngine;
 use tabby_common::config::HttpModelConfig;
@@ -40,6 +42,14 @@ pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
                 .expect("model_name must be set for voyage/embedding"),
             config.api_key.as_deref(),
         ),
+        "azure/embedding" => AzureEmbeddingEngine::create(
+            config
+                .api_endpoint
+                .as_deref()
+                .expect("api_endpoint is required for azure/embedding"),
+            config.model_name.as_deref().unwrap_or_default(), // model_name is optional; fall back to an empty deployment id
+            config.api_key.as_deref(),
+        ),
         unsupported_kind => panic!(
             "Unsupported kind for http embedding model: {}",
             unsupported_kind
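A matching config sketch for the embedding side, with the same caveats (table name assumed from Tabby's other HTTP embedding backends; values are placeholders):

[model.embedding.http]
kind = "azure/embedding"
api_endpoint = "https://my-resource.openai.azure.com"
model_name = "text-embedding-3-small"  # Azure deployment id
api_key = "<azure-api-key>"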
2 changes: 2 additions & 0 deletions crates/http-api-bindings/src/lib.rs
@@ -20,3 +20,5 @@ fn create_reqwest_client(api_endpoint: &str) -> reqwest::Client {
 
     builder.build().unwrap()
 }
+
+static AZURE_API_VERSION: &str = "2024-02-01";
17 changes: 17 additions & 0 deletions crates/tabby-inference/src/chat.rs
@@ -125,3 +125,20 @@ impl ChatCompletionStream for async_openai_alt::Client<ExtendedOpenAIConfig> {
         self.chat().create_stream(request).await
     }
 }
+
+#[async_trait]
+impl ChatCompletionStream for async_openai_alt::Client<async_openai_alt::config::AzureConfig> {
+    async fn chat(
+        &self,
+        request: CreateChatCompletionRequest,
+    ) -> Result<CreateChatCompletionResponse, OpenAIError> {
+        self.chat().create(request).await
+    }
+
+    async fn chat_stream(
+        &self,
+        request: CreateChatCompletionRequest,
+    ) -> Result<ChatCompletionResponseStream, OpenAIError> {
+        self.chat().create_stream(request).await
+    }
+}
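With this impl, an Azure-backed client can stand in anywhere a `Box<dyn ChatCompletionStream>` is expected. A minimal sketch using only calls that appear in this diff (the helper function itself is hypothetical):

// Hypothetical helper, mirroring the construction in chat/mod.rs above.
fn azure_chat_stream(endpoint: &str, key: &str, deployment: &str) -> Box<dyn ChatCompletionStream> {
    let config = async_openai_alt::config::AzureConfig::new()
        .with_api_base(endpoint)
        .with_api_key(key)
        .with_api_version("2024-02-01".to_string())
        .with_deployment_id(deployment);
    Box::new(async_openai_alt::Client::with_config(config))
}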