From 6f3db0a7ba23b39532fbddb7799a62ae4acb33cd Mon Sep 17 00:00:00 2001 From: Anton Kushakov <57725022+SpeedCrash100@users.noreply.github.com> Date: Sat, 16 Mar 2024 12:29:59 +0300 Subject: [PATCH] feat(tabby): Added --chat-device arg for serve subcommand (#1685) * feat(tabby): Added --chat-device arg for serve subcommand * chore(changelog): Added info about --chat-device flag --- CHANGELOG.md | 1 + crates/tabby/src/serve.rs | 12 +++++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ff1a8f260735..920d9b62fe37 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ # v0.10.0 [UNRELEASED] ## Features +* Added the --chat-device flag to override a device used to run chat model. ## Fixes and Improvements diff --git a/crates/tabby/src/serve.rs b/crates/tabby/src/serve.rs index 8f1f9c161eff..231d742687d1 100644 --- a/crates/tabby/src/serve.rs +++ b/crates/tabby/src/serve.rs @@ -97,6 +97,10 @@ pub struct ServeArgs { #[clap(long, default_value_t=Device::Cpu)] device: Device, + /// Device to run chat model [default equals --device arg] + #[clap(long, requires("chat_model"))] + chat_device: Option<Device>, + /// Parallelism for model serving - increasing this number will have a significant impact on the /// memory requirement e.g., GPU vRAM. #[clap(long, default_value_t = 1)] @@ -187,7 +191,13 @@ async fn api_router( let chat_state = if let Some(chat_model) = &args.chat_model { Some(Arc::new( - create_chat_service(logger.clone(), chat_model, &args.device, args.parallelism).await, + create_chat_service( + logger.clone(), + chat_model, + args.chat_device.as_ref().unwrap_or(&args.device), + args.parallelism, + ) + .await, )) } else { None