From d9c18bb5ffe65103b370ce8e866094ef1852eee0 Mon Sep 17 00:00:00 2001
From: Martin Stefcek
Date: Thu, 4 Apr 2024 12:39:30 +0400
Subject: [PATCH] feat: add the option to set device_id for each model

---
 atoma-inference/src/main.rs                   |  4 +--
 atoma-inference/src/model_thread.rs           | 21 +++++++--------
 atoma-inference/src/models/candle/mamba.rs    | 27 +++++++++----------
 atoma-inference/src/models/candle/mod.rs      |  6 ++---
 .../src/models/candle/stable_diffusion.rs     |  9 ++++---
 atoma-inference/src/models/config.rs          | 20 +++++++++-----
 atoma-inference/src/models/mod.rs             |  6 ++++-
 atoma-inference/src/service.rs                |  6 ++---
 8 files changed, 56 insertions(+), 43 deletions(-)

diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs
index bfb0181d..80d1db18 100644
--- a/atoma-inference/src/main.rs
+++ b/atoma-inference/src/main.rs
@@ -5,7 +5,7 @@ use hf_hub::api::sync::Api;
 use inference::{
     models::{
         candle::mamba::MambaModel,
-        config::ModelConfig,
+        config::ModelsConfig,
         types::{TextRequest, TextResponse},
     },
     service::{ModelService, ModelServiceError},
@@ -18,7 +18,7 @@ async fn main() -> Result<(), ModelServiceError> {
     let (req_sender, req_receiver) = tokio::sync::mpsc::channel::<TextRequest>(32);
     let (resp_sender, mut resp_receiver) = tokio::sync::mpsc::channel::<TextResponse>(32);
 
-    let model_config = ModelConfig::from_file_path("../inference.toml".parse().unwrap());
+    let model_config = ModelsConfig::from_file_path("../inference.toml".parse().unwrap());
     let private_key_bytes =
         std::fs::read("../private_key").map_err(ModelServiceError::PrivateKeyError)?;
     let private_key_bytes: [u8; 32] = private_key_bytes
diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs
index 287190ec..b2a1fecd 100644
--- a/atoma-inference/src/model_thread.rs
+++ b/atoma-inference/src/model_thread.rs
@@ -11,7 +11,7 @@ use tracing::{debug, error, info, warn};
 
 use crate::{
     apis::{ApiError, ApiTrait},
-    models::{config::ModelConfig, ModelError, ModelId, ModelTrait, Request, Response},
+    models::{config::ModelsConfig, ModelError, ModelId, ModelTrait, Request, Response},
 };
 
 pub struct ModelThreadCommand<Req, Resp>(Req, oneshot::Sender<Resp>)
 where
@@ -111,33 +111,32 @@ where
     Resp: Response,
 {
     pub(crate) fn start<F, M>(
-        config: ModelConfig,
+        config: ModelsConfig,
         public_key: PublicKey,
     ) -> Result<(Self, Vec<ModelThreadHandle<Req, Resp>>), ModelThreadError>
     where
         F: ApiTrait + Send + Sync + 'static,
         M: ModelTrait + Send + 'static,
     {
-        let model_ids = config.model_ids();
         let api_key = config.api_key();
         let storage_path = config.storage_path();
         let api = Arc::new(F::create(api_key, storage_path)?);
 
-        let mut handles = Vec::with_capacity(model_ids.len());
-        let mut model_senders = HashMap::with_capacity(model_ids.len());
+        let mut handles = Vec::new();
+        let mut model_senders = HashMap::new();
 
-        for (model_id, precision, revision) in model_ids {
-            info!("Spawning new thread for model: {model_id}");
+        for model_config in config.models() {
+            info!("Spawning new thread for model: {}", model_config.model_id);
             let api = api.clone();
             let (model_sender, model_receiver) = mpsc::channel::<ModelThreadCommand<Req, Resp>>();
-            let model_name = model_id.clone();
+            let model_name = model_config.model_id.clone();
 
             let join_handle = std::thread::spawn(move || {
                 info!("Fetching files for model: {model_name}");
-                let filenames = api.fetch(model_name, revision)?;
+                let filenames = api.fetch(model_name, model_config.revision)?;
 
-                let model = M::load(filenames, precision)?;
+                let model = M::load(filenames, model_config.precision, model_config.device_id)?;
                 let model_thread = ModelThread {
                     model,
                     receiver: model_receiver,
                 };
@@ -156,7 +155,7 @@ where
                 join_handle,
                 sender: model_sender.clone(),
             });
-            model_senders.insert(model_id, model_sender);
+            model_senders.insert(model_config.model_id, model_sender);
         }
 
         let model_dispatcher = ModelThreadDispatcher {
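For illustration, a minimal sketch (not part of this patch) of how the per-model configuration consumed by the new spawn loop above might be built in code. It uses the ModelConfig/ModelsConfig types introduced in models/config.rs further down; the model ids, revision strings, API key and storage path are placeholder values.

use std::path::PathBuf;

use inference::models::{
    config::{ModelConfig, ModelsConfig},
    types::PrecisionBits,
};

fn main() {
    // Each entry now carries its own device_id, so different models can be
    // pinned to different accelerator indices.
    let models = vec![
        ModelConfig {
            model_id: "mamba_130m".to_string(), // placeholder model id
            precision: PrecisionBits::F16,
            revision: "main".to_string(), // placeholder revision
            device_id: 0,
        },
        ModelConfig {
            model_id: "stable_diffusion_v1_5".to_string(), // placeholder model id
            precision: PrecisionBits::F16,
            revision: "main".to_string(),
            device_id: 1,
        },
    ];

    // Same constructor as before, except `models` is now Vec<ModelConfig>
    // rather than Vec<(ModelId, PrecisionBits, Revision)>.
    let config = ModelsConfig::new(
        String::from("my_key"),    // placeholder API key
        false,                     // flush_storage
        models,
        PathBuf::from("storage/"), // placeholder storage path
        true,                      // tracing
    );

    // ModelThreadDispatcher::start iterates config.models() and spawns one
    // thread per entry, loading each model on its configured device.
    assert_eq!(config.models().len(), 2);
}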
diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs
index bdaa7e2d..3552cb6b 100644
--- a/atoma-inference/src/models/candle/mamba.rs
+++ b/atoma-inference/src/models/candle/mamba.rs
@@ -1,9 +1,6 @@
 use std::{path::PathBuf, time::Instant};
 
-use candle::{
-    utils::{cuda_is_available, metal_is_available},
-    DType, Device, Tensor,
-};
+use candle::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::{
     generation::LogitsProcessor,
@@ -15,8 +12,12 @@ use tracing::info;
 
 use crate::{
     bail,
-    models::types::{PrecisionBits, TextModelInput},
-    models::{token_output_stream::TokenOutputStream, ModelError, ModelId, ModelTrait},
+    models::{
+        candle::device,
+        token_output_stream::TokenOutputStream,
+        types::{PrecisionBits, TextModelInput},
+        ModelError, ModelId, ModelTrait,
+    },
 };
 
 pub struct MambaModel {
@@ -53,7 +54,11 @@ impl ModelTrait for MambaModel {
     type Input = TextModelInput;
     type Output = String;
 
-    fn load(filenames: Vec<PathBuf>, precision: PrecisionBits) -> Result<Self, ModelError>
+    fn load(
+        filenames: Vec<PathBuf>,
+        precision: PrecisionBits,
+        device_id: usize,
+    ) -> Result<Self, ModelError>
     where
         Self: Sized,
     {
@@ -70,13 +75,7 @@ impl ModelTrait for MambaModel {
         let config: Config =
             serde_json::from_slice(&std::fs::read(config_filename).map_err(ModelError::IoError)?)
                 .map_err(ModelError::DeserializeError)?;
-        let device = if cuda_is_available() {
-            Device::new_cuda(0).map_err(ModelError::CandleError)?
-        } else if metal_is_available() {
-            Device::new_metal(0).map_err(ModelError::CandleError)?
-        } else {
-            Device::Cpu
-        };
+        let device = device(device_id)?;
         let dtype = precision.into_dtype();
 
         info!("Loading model weights..");
diff --git a/atoma-inference/src/models/candle/mod.rs b/atoma-inference/src/models/candle/mod.rs
index a10f6f6f..008e929b 100644
--- a/atoma-inference/src/models/candle/mod.rs
+++ b/atoma-inference/src/models/candle/mod.rs
@@ -13,13 +13,13 @@ use super::ModelError;
 pub mod mamba;
 pub mod stable_diffusion;
 
-pub fn device() -> Result {
+pub fn device(device_id: usize) -> Result {
     if cuda_is_available() {
         info!("Using CUDA");
-        Device::new_cuda(0)
+        Device::new_cuda(device_id)
     } else if metal_is_available() {
         info!("Using Metal");
-        Device::new_metal(0)
+        Device::new_metal(device_id)
     } else {
         info!("Using Cpu");
         Ok(Device::Cpu)
diff --git a/atoma-inference/src/models/candle/stable_diffusion.rs b/atoma-inference/src/models/candle/stable_diffusion.rs
index 3c3ab7b9..a04b8e50 100644
--- a/atoma-inference/src/models/candle/stable_diffusion.rs
+++ b/atoma-inference/src/models/candle/stable_diffusion.rs
@@ -97,7 +97,9 @@ impl From<&Input> for Fetch {
         }
     }
 }
-pub struct StableDiffusion {}
+pub struct StableDiffusion {
+    device_id: usize,
+}
 
 pub struct Fetch {
     tokenizer: Option,
@@ -116,11 +118,12 @@ impl ModelTrait for StableDiffusion {
     fn load(
         _filenames: Vec<PathBuf>,
         _precision: PrecisionBits,
+        device_id: usize,
     ) -> Result<Self, ModelError>
     where
         Self: Sized,
     {
-        Ok(Self {})
+        Ok(Self { device_id })
     }
 
     fn fetch(fetch: &Self::Fetch) -> Result<(), ModelError> {
@@ -202,7 +205,7 @@ impl ModelTrait for StableDiffusion {
         };
 
         let scheduler = sd_config.build_scheduler(n_steps)?;
-        let device = device()?;
+        let device = device(self.device_id)?;
         if let Some(seed) = input.seed {
             device.set_seed(seed)?;
         }
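Also for illustration (not part of this patch): the reworked device helper in models/candle/mod.rs picks the backend by availability, so device_id selects a card only when CUDA or Metal is present; on a CPU-only host any id falls through to Device::Cpu, and an out-of-range id is reported by Device::new_cuda / Device::new_metal. Below is a standalone sketch of that behaviour, written against the same candle API the patch already uses; pick_device is a stand-in that mirrors the helper, not a function from the codebase.

use candle::utils::{cuda_is_available, metal_is_available};
use candle::{Device, Result};

// Mirrors the fallback order of the patched `device` helper:
// CUDA first, then Metal, then CPU.
fn pick_device(device_id: usize) -> Result<Device> {
    if cuda_is_available() {
        Device::new_cuda(device_id)
    } else if metal_is_available() {
        Device::new_metal(device_id)
    } else {
        Ok(Device::Cpu)
    }
}

fn main() -> Result<()> {
    // With two GPUs, ids 0 and 1 select different cards; on a CPU-only
    // machine both calls return Device::Cpu and the id is ignored.
    let first = pick_device(0)?;
    let second = pick_device(1)?;
    println!("{first:?} / {second:?}");
    Ok(())
}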
diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs
index e5790163..bf7e8892 100644
--- a/atoma-inference/src/models/config.rs
+++ b/atoma-inference/src/models/config.rs
@@ -8,20 +8,28 @@ use crate::{models::types::PrecisionBits, models::ModelId};
 
 type Revision = String;
 
-#[derive(Debug, Deserialize, Serialize)]
+#[derive(Clone, Debug, Deserialize, Serialize)]
 pub struct ModelConfig {
+    pub model_id: ModelId,
+    pub precision: PrecisionBits,
+    pub revision: Revision,
+    pub device_id: usize,
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct ModelsConfig {
     api_key: String,
     flush_storage: bool,
-    models: Vec<(ModelId, PrecisionBits, Revision)>,
+    models: Vec<ModelConfig>,
     storage_path: PathBuf,
     tracing: bool,
 }
 
-impl ModelConfig {
+impl ModelsConfig {
     pub fn new(
         api_key: String,
         flush_storage: bool,
-        models: Vec<(ModelId, PrecisionBits, Revision)>,
+        models: Vec<ModelConfig>,
         storage_path: PathBuf,
         tracing: bool,
     ) -> Self {
@@ -42,7 +50,7 @@ impl ModelConfig {
         self.flush_storage
     }
 
-    pub fn model_ids(&self) -> Vec<(ModelId, PrecisionBits, Revision)> {
+    pub fn models(&self) -> Vec<ModelConfig> {
         self.models.clone()
     }
 
@@ -103,7 +111,7 @@ pub mod tests {
 
     #[test]
     fn test_config() {
-        let config = ModelConfig::new(
+        let config = ModelsConfig::new(
             String::from("my_key"),
             true,
             vec![("Llama2_7b".to_string(), PrecisionBits::F16, "".to_string())],
diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs
index 467b2398..f12a7408 100644
--- a/atoma-inference/src/models/mod.rs
+++ b/atoma-inference/src/models/mod.rs
@@ -21,7 +21,11 @@ pub trait ModelTrait {
     fn fetch(_fetch: &Self::Fetch) -> Result<(), ModelError> {
         Ok(())
     }
-    fn load(filenames: Vec<PathBuf>, precision: PrecisionBits) -> Result<Self, ModelError>
+    fn load(
+        filenames: Vec<PathBuf>,
+        precision: PrecisionBits,
+        device_id: usize,
+    ) -> Result<Self, ModelError>
     where
         Self: Sized;
     fn model_id(&self) -> ModelId;
diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs
index 60cfdf9a..33a88b34 100644
--- a/atoma-inference/src/service.rs
+++ b/atoma-inference/src/service.rs
@@ -10,7 +10,7 @@ use thiserror::Error;
 use crate::{
     apis::{ApiError, ApiTrait},
     model_thread::{ModelThreadDispatcher, ModelThreadError, ModelThreadHandle},
-    models::{config::ModelConfig, ModelTrait, Request, Response},
+    models::{config::ModelsConfig, ModelTrait, Request, Response},
 };
 
 pub struct ModelService<Req, Resp>
@@ -34,7 +34,7 @@ where
     Resp: std::fmt::Debug + Response,
 {
     pub fn start(
-        model_config: ModelConfig,
+        model_config: ModelsConfig,
         private_key: PrivateKey,
         request_receiver: Receiver<Req>,
         response_sender: Sender<Resp>,
     )
@@ -250,7 +250,7 @@ mod tests {
         let (_, req_receiver) = tokio::sync::mpsc::channel::<()>(1);
         let (resp_sender, _) = tokio::sync::mpsc::channel::<()>(1);
 
-        let config = ModelConfig::from_file_path(CONFIG_FILE_PATH.parse().unwrap());
+        let config = ModelsConfig::from_file_path(CONFIG_FILE_PATH.parse().unwrap());
         let _ = ModelService::<(), ()>::start::(
             config,