Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: validate model capability before download #3565

Merged
merged 1 commit into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions crates/tabby-download/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use std::{fs, io};

use aim_downloader::{bar::WrappedBar, error::DownloadError, hash::HashChecker, https};
use anyhow::{bail, Result};
use anyhow::{bail, Context, Result};
use tabby_common::registry::{parse_model_id, ModelInfo, ModelRegistry};
use tokio_retry::{
strategy::{jitter, ExponentialBackoff},
Expand Down Expand Up @@ -187,17 +187,58 @@ async fn download_file(
Ok(())
}

/// The capability a model is expected to provide.
///
/// Used to validate registry metadata before a download starts, so users
/// get an early, clear error instead of a broken runtime configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelKind {
    /// Embedding model; no template requirements are enforced.
    Embedding,
    /// Code-completion model; requires a prompt template in its metadata.
    Completion,
    /// Chat model; requires a chat template in its metadata.
    Chat,
}

/// Downloads the model identified by `model_id` into the local cache.
///
/// When `kind` is provided, the model's registry metadata is validated
/// against that capability before any bytes are fetched. This is a
/// CLI-facing entry point: validation or download failure aborts the
/// process with a panic rather than returning an error.
pub async fn download_model(model_id: &str, prefer_local_file: bool, kind: Option<ModelKind>) {
    let (registry_name, model_name) = parse_model_id(model_id);
    let registry = ModelRegistry::new(registry_name).await;

    if let Some(kind) = kind {
        // Fail fast with a pointer to the registry docs when the model
        // lacks the template required for the requested capability.
        validate_model_kind(kind, registry.get_model_info(model_name))
            .context(
                "Model validation has failed. For TabbyML models, please consult https://github.com/tabbyml/registry-tabby to locate the appropriate models.",
            )
            .unwrap();
    }

    if let Err(err) = download_model_impl(&registry, model_name, prefer_local_file).await {
        panic!("Failed to fetch model '{}' due to '{}'", model_id, err);
    }
}

fn validate_model_kind(kind: ModelKind, info: &ModelInfo) -> Result<()> {
match kind {
ModelKind::Embedding => Ok(()),
ModelKind::Completion => info
.prompt_template
.as_ref()
.ok_or_else(|| {
anyhow::anyhow!(
"Model '{}' is not a completion model; it does not have a prompt template.",
info.name
)
})
.map(|_| ()),
ModelKind::Chat => info
.chat_template
.as_ref()
.ok_or_else(|| {
anyhow::anyhow!(
"Model '{}' is not a chat model, it does not have a chat template",
info.name
)
})
.map(|_| ()),
}
}

#[cfg(test)]
mod tests {
// filter_download_address tests should be serial because they rely on environment variables
Expand Down
2 changes: 1 addition & 1 deletion crates/tabby/src/download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ pub struct DownloadArgs {
}

/// Entry point for the `download` subcommand: fetches the named model and
/// logs readiness. No capability validation is requested (`None`) because
/// the user explicitly chose this model.
pub async fn main(args: &DownloadArgs) {
    download_model(&args.model, args.prefer_local_file, None).await;
    info!("model '{}' is ready", args.model);
}
7 changes: 4 additions & 3 deletions crates/tabby/src/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use tabby_common::{
config::{Config, ModelConfig},
usage,
};
use tabby_download::ModelKind;
use tabby_inference::ChatCompletionStream;
use tokio::{sync::oneshot::Sender, time::sleep};
use tower_http::timeout::TimeoutLayer;
Expand Down Expand Up @@ -212,15 +213,15 @@ pub async fn main(config: &Config, args: &ServeArgs) {

async fn load_model(config: &Config) {
if let Some(ModelConfig::Local(ref model)) = config.model.completion {
download_model_if_needed(&model.model_id).await;
download_model_if_needed(&model.model_id, ModelKind::Completion).await;
}

if let Some(ModelConfig::Local(ref model)) = config.model.chat {
download_model_if_needed(&model.model_id).await;
download_model_if_needed(&model.model_id, ModelKind::Chat).await;
}

if let ModelConfig::Local(ref model) = config.model.embedding {
download_model_if_needed(&model.model_id).await;
download_model_if_needed(&model.model_id, ModelKind::Embedding).await;
}
}

Expand Down
6 changes: 3 additions & 3 deletions crates/tabby/src/services/model/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{fs, sync::Arc};

pub use llama_cpp_server::PromptInfo;
use tabby_common::config::ModelConfig;
use tabby_download::download_model;
use tabby_download::{download_model, ModelKind};
use tabby_inference::{ChatCompletionStream, CodeGeneration, CompletionStream, Embedding};
use tracing::info;

Expand Down Expand Up @@ -80,10 +80,10 @@ async fn load_completion_and_chat(
(completion, prompt, chat)
}

pub async fn download_model_if_needed(model: &str) {
pub async fn download_model_if_needed(model: &str, kind: ModelKind) {
if fs::metadata(model).is_ok() {
info!("Loading model from local path {}", model);
} else {
download_model(model, true).await;
download_model(model, true, Some(kind)).await;
}
}
Loading