diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index b2a1fecd..388554ae 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -126,17 +126,22 @@ where let mut model_senders = HashMap::new(); for model_config in config.models() { - info!("Spawning new thread for model: {}", model_config.model_id); + info!("Spawning new thread for model: {}", model_config.model_id()); let api = api.clone(); let (model_sender, model_receiver) = mpsc::channel::>(); - let model_name = model_config.model_id.clone(); + let model_name = model_config.model_id().clone(); + model_senders.insert(model_name.clone(), model_sender.clone()); let join_handle = std::thread::spawn(move || { info!("Fetching files for model: {model_name}"); - let filenames = api.fetch(model_name, model_config.revision)?; + let filenames = api.fetch(model_name, model_config.revision())?; - let model = M::load(filenames, model_config.precision, model_config.device_id)?; + let model = M::load( + filenames, + model_config.precision(), + model_config.device_id(), + )?; let model_thread = ModelThread { model, receiver: model_receiver, @@ -155,7 +160,6 @@ where join_handle, sender: model_sender.clone(), }); - model_senders.insert(model_config.model_id, model_sender); } let model_dispatcher = ModelThreadDispatcher { diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs index bf7e8892..8606ebcb 100644 --- a/atoma-inference/src/models/config.rs +++ b/atoma-inference/src/models/config.rs @@ -10,10 +10,42 @@ type Revision = String; #[derive(Clone, Debug, Deserialize, Serialize)] pub struct ModelConfig { - pub model_id: ModelId, - pub precision: PrecisionBits, - pub revision: Revision, - pub device_id: usize, + model_id: ModelId, + precision: PrecisionBits, + revision: Revision, + device_id: usize, +} + +impl ModelConfig { + pub fn new( + model_id: ModelId, + precision: PrecisionBits, + revision: Revision, + device_id: usize, + ) -> Self { + Self { + model_id, + precision, + revision, + device_id, + } + } + + pub fn model_id(&self) -> &ModelId { + &self.model_id + } + + pub fn precision(&self) -> PrecisionBits { + self.precision + } + + pub fn revision(&self) -> Revision { + self.revision.clone() + } + + pub fn device_id(&self) -> usize { + self.device_id + } } #[derive(Debug, Deserialize, Serialize)] @@ -114,7 +146,12 @@ pub mod tests { let config = ModelsConfig::new( String::from("my_key"), true, - vec![("Llama2_7b".to_string(), PrecisionBits::F16, "".to_string())], + vec![ModelConfig::new( + "Llama2_7b".to_string(), + PrecisionBits::F16, + "".to_string(), + 0, + )], "storage_path".parse().unwrap(), true, );