From 90d53de5ec3b17d29088fdda512aaf8c7c04aa3c Mon Sep 17 00:00:00 2001 From: Wei Zhang Date: Sat, 14 Dec 2024 12:21:24 +0800 Subject: [PATCH] feat: validate model capability before download Signed-off-by: Wei Zhang --- crates/tabby-download/src/lib.rs | 45 ++++++++++++++++++++++++-- crates/tabby/src/download.rs | 2 +- crates/tabby/src/serve.rs | 7 ++-- crates/tabby/src/services/model/mod.rs | 6 ++-- 4 files changed, 51 insertions(+), 9 deletions(-) diff --git a/crates/tabby-download/src/lib.rs b/crates/tabby-download/src/lib.rs index 0c1cc942fae7..18a0754f5dcc 100644 --- a/crates/tabby-download/src/lib.rs +++ b/crates/tabby-download/src/lib.rs @@ -2,7 +2,7 @@ use std::{fs, io}; use aim_downloader::{bar::WrappedBar, error::DownloadError, hash::HashChecker, https}; -use anyhow::{bail, Result}; +use anyhow::{bail, Context, Result}; use tabby_common::registry::{parse_model_id, ModelInfo, ModelRegistry}; use tokio_retry::{ strategy::{jitter, ExponentialBackoff}, @@ -187,17 +187,58 @@ async fn download_file( Ok(()) } -pub async fn download_model(model_id: &str, prefer_local_file: bool) { +pub enum ModelKind { + Embedding, + Completion, + Chat, +} + +pub async fn download_model(model_id: &str, prefer_local_file: bool, kind: Option<ModelKind>) { let (registry, name) = parse_model_id(model_id); let registry = ModelRegistry::new(registry).await; + if let Some(kind) = kind { + let model_info = registry.get_model_info(name); + validate_model_kind(kind, model_info) + .context( + "Model validation has failed. 
For TabbyML models, please consult https://github.com/tabbyml/registry-tabby to locate the appropriate models.", ) .unwrap(); } + let handler = |err| panic!("Failed to fetch model '{}' due to '{}'", model_id, err); download_model_impl(&registry, name, prefer_local_file) .await .unwrap_or_else(handler) } +fn validate_model_kind(kind: ModelKind, info: &ModelInfo) -> Result<()> { + match kind { + ModelKind::Embedding => Ok(()), + ModelKind::Completion => info + .prompt_template + .as_ref() + .ok_or_else(|| { + anyhow::anyhow!( + "Model '{}' is not a completion model; it does not have a prompt template.", + info.name + ) + }) + .map(|_| ()), + ModelKind::Chat => info + .chat_template + .as_ref() + .ok_or_else(|| { + anyhow::anyhow!( + "Model '{}' is not a chat model, it does not have a chat template", + info.name + ) + }) + .map(|_| ()), + } +} + #[cfg(test)] mod tests { // filter_download_address tests should be serial because they rely on environment variables diff --git a/crates/tabby/src/download.rs b/crates/tabby/src/download.rs index dcc36c1d116e..47ca391a21f8 100644 --- a/crates/tabby/src/download.rs +++ b/crates/tabby/src/download.rs @@ -14,6 +14,6 @@ pub struct DownloadArgs { } pub async fn main(args: &DownloadArgs) { - download_model(&args.model, args.prefer_local_file).await; + download_model(&args.model, args.prefer_local_file, None).await; info!("model '{}' is ready", args.model); } diff --git a/crates/tabby/src/serve.rs b/crates/tabby/src/serve.rs index 23a6cf5fe9b5..b5ca96109d8e 100644 --- a/crates/tabby/src/serve.rs +++ b/crates/tabby/src/serve.rs @@ -10,6 +10,7 @@ use tabby_common::{ config::{Config, ModelConfig}, usage, }; +use tabby_download::ModelKind; use tabby_inference::ChatCompletionStream; use tokio::{sync::oneshot::Sender, time::sleep}; use tower_http::timeout::TimeoutLayer; @@ -212,15 +213,15 @@ pub async fn main(config: &Config, args: &ServeArgs) { async fn load_model(config: &Config) { if let Some(ModelConfig::Local(ref model)) = 
config.model.completion { - download_model_if_needed(&model.model_id).await; + download_model_if_needed(&model.model_id, ModelKind::Completion).await; } if let Some(ModelConfig::Local(ref model)) = config.model.chat { - download_model_if_needed(&model.model_id).await; + download_model_if_needed(&model.model_id, ModelKind::Chat).await; } if let ModelConfig::Local(ref model) = config.model.embedding { - download_model_if_needed(&model.model_id).await; + download_model_if_needed(&model.model_id, ModelKind::Embedding).await; } } diff --git a/crates/tabby/src/services/model/mod.rs b/crates/tabby/src/services/model/mod.rs index d433a75b1e94..11d206f59425 100644 --- a/crates/tabby/src/services/model/mod.rs +++ b/crates/tabby/src/services/model/mod.rs @@ -2,7 +2,7 @@ use std::{fs, sync::Arc}; pub use llama_cpp_server::PromptInfo; use tabby_common::config::ModelConfig; -use tabby_download::download_model; +use tabby_download::{download_model, ModelKind}; use tabby_inference::{ChatCompletionStream, CodeGeneration, CompletionStream, Embedding}; use tracing::info; @@ -80,10 +80,10 @@ async fn load_completion_and_chat( (completion, prompt, chat) } -pub async fn download_model_if_needed(model: &str) { +pub async fn download_model_if_needed(model: &str, kind: ModelKind) { if fs::metadata(model).is_ok() { info!("Loading model from local path {}", model); } else { - download_model(model, true).await; + download_model(model, true, Some(kind)).await; } }