Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(backend): adding Azure AI Foundry OpenAPI support #3683

Merged
merged 6 commits into from
Jan 15, 2025
Merged
67 changes: 42 additions & 25 deletions crates/http-api-bindings/src/chat/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,48 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
.api_endpoint
.as_deref()
.expect("api_endpoint is required");
let config = OpenAIConfig::default()
.with_api_base(api_endpoint)
.with_api_key(model.api_key.clone().unwrap_or_default());

let mut builder = ExtendedOpenAIConfig::builder();

builder
.base(config)
.supported_models(model.supported_models.clone())
.model_name(model.model_name.as_deref().expect("Model name is required"));

if model.kind == "openai/chat" {
// Do nothing
} else if model.kind == "mistral/chat" {
builder.fields_to_remove(ExtendedOpenAIConfig::mistral_fields_to_remove());
} else {
panic!("Unsupported model kind: {}", model.kind);
}

let config = builder.build().expect("Failed to build config");

let engine = Box::new(
async_openai_alt::Client::with_config(config)
.with_http_client(create_reqwest_client(api_endpoint)),
);

let engine: Box<dyn ChatCompletionStream> = match model.kind.as_str() {
"azure/chat" => {
let config = async_openai_alt::config::AzureConfig::new()
.with_api_base(api_endpoint)
.with_api_key(model.api_key.clone().unwrap_or_default())
.with_api_version(
model
.api_version
.clone()
.unwrap_or("2024-05-01-preview".to_string()),
)
.with_deployment_id(model.model_name.as_deref().expect("Model name is required"));
Box::new(
async_openai_alt::Client::with_config(config)
.with_http_client(create_reqwest_client(api_endpoint)),
)
}
"openai/chat" | "mistral/chat" => {
let config = OpenAIConfig::default()
.with_api_base(api_endpoint)
.with_api_key(model.api_key.clone().unwrap_or_default());

let mut builder = ExtendedOpenAIConfig::builder();
builder
.base(config)
.supported_models(model.supported_models.clone())
.model_name(model.model_name.as_deref().expect("Model name is required"));

if model.kind == "mistral/chat" {
builder.fields_to_remove(ExtendedOpenAIConfig::mistral_fields_to_remove());
}

Box::new(
async_openai_alt::Client::with_config(
builder.build().expect("Failed to build config"),
)
.with_http_client(create_reqwest_client(api_endpoint)),
)
}
_ => panic!("Unsupported model kind: {}", model.kind),
};

Arc::new(rate_limit::new_chat(
engine,
Expand Down
5 changes: 5 additions & 0 deletions crates/tabby-common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,11 @@ pub struct HttpModelConfig {

#[builder(default)]
pub additional_stop_words: Option<Vec<String>>,

/// Used for the Azure API to specify the API version.
#[builder(default)]
#[serde(default)]
pub api_version: Option<String>,
Sma1lboy marked this conversation as resolved.
Show resolved Hide resolved
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
Expand Down
17 changes: 17 additions & 0 deletions crates/tabby-inference/src/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,20 @@ impl ChatCompletionStream for async_openai_alt::Client<ExtendedOpenAIConfig> {
self.chat().create_stream(request).await
}
}

/// `ChatCompletionStream` implementation for the Azure-configured
/// `async_openai_alt` client, so an Azure OpenAI deployment can be used
/// anywhere a `ChatCompletionStream` trait object is expected.
///
/// Both methods are thin delegations to the underlying client's chat API;
/// all Azure-specific behavior (endpoint, api-version, deployment id) is
/// carried by the `AzureConfig` the client was constructed with.
#[async_trait]
impl ChatCompletionStream for async_openai_alt::Client<async_openai_alt::config::AzureConfig> {
/// Issues a single (non-streaming) chat completion request and returns
/// the full response.
async fn chat(
&self,
request: CreateChatCompletionRequest,
) -> Result<CreateChatCompletionResponse, OpenAIError> {
self.chat().create(request).await
}

/// Issues a streaming chat completion request, returning a stream of
/// incremental response chunks.
async fn chat_stream(
&self,
request: CreateChatCompletionRequest,
) -> Result<ChatCompletionResponseStream, OpenAIError> {
self.chat().create_stream(request).await
}
}
Loading