diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs
index 8e1ee6bb..5b623aaa 100644
--- a/atoma-inference/src/model_thread.rs
+++ b/atoma-inference/src/model_thread.rs
@@ -108,7 +108,6 @@ where
     U: Response,
 {
     pub(crate) fn start(
-        api: F,
         config: ModelConfig,
         public_key: PublicKey,
     ) -> Result<(Self, Vec>), ModelThreadError>
@@ -119,6 +118,10 @@ where
             + 'static,
     {
         let model_ids = config.model_ids();
+        let api_key = config.api_key();
+        let storage_path = config.storage_path();
+        let api = F::create(api_key, storage_path)?;
+
         let mut handles = Vec::with_capacity(model_ids.len());
         let mut model_senders = HashMap::with_capacity(model_ids.len());
 
diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs
index 08217d91..ff5bf9a6 100644
--- a/atoma-inference/src/models/config.rs
+++ b/atoma-inference/src/models/config.rs
@@ -9,7 +9,7 @@ use crate::models::ModelId;
 pub struct ModelConfig {
     api_key: String,
     models: Vec,
-    storage_folder: PathBuf,
+    storage_path: PathBuf,
     tracing: bool,
 }
@@ -17,13 +17,13 @@ impl ModelConfig {
     pub fn new(
         api_key: String,
         models: Vec,
-        storage_folder: PathBuf,
+        storage_path: PathBuf,
         tracing: bool,
     ) -> Self {
         Self {
             api_key,
             models,
-            storage_folder,
+            storage_path,
             tracing,
         }
     }
@@ -36,8 +36,8 @@ impl ModelConfig {
         self.models.clone()
     }
 
-    pub fn storage_folder(&self) -> PathBuf {
-        self.storage_folder.clone()
+    pub fn storage_path(&self) -> PathBuf {
+        self.storage_path.clone()
     }
 
     pub fn tracing(&self) -> bool {
@@ -66,12 +66,12 @@ pub mod tests {
         let config = ModelConfig::new(
             String::from("my_key"),
             vec!["Llama2_7b".to_string()],
-            "storage_folder".parse().unwrap(),
+            "storage_path".parse().unwrap(),
             true,
         );
 
         let toml_str = toml::to_string(&config).unwrap();
-        let should_be_toml_str = "api_key = \"my_key\"\nmodels = [\"Llama2_7b\"]\nstorage_folder = \"storage_folder\"\ntracing = true\n";
+        let should_be_toml_str = "api_key = \"my_key\"\nmodels = [\"Llama2_7b\"]\nstorage_path = \"storage_path\"\ntracing = true\n";
         assert_eq!(toml_str, should_be_toml_str);
     }
 }
diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs
index ef8b351d..076966aa 100644
--- a/atoma-inference/src/service.rs
+++ b/atoma-inference/src/service.rs
@@ -49,13 +49,9 @@ where
         let private_key = PrivateKey::from(private_key_bytes);
         let public_key = private_key.verification_key();
         let model_config = ModelConfig::from_file_path(config_file_path);
-        let api_key = model_config.api_key();
-        let storage_folder = model_config.storage_folder();
-
-        let api = F::create(api_key, storage_folder)?;
 
         let (dispatcher, model_thread_handle) =
-            ModelThreadDispatcher::start::(api, model_config, public_key)
+            ModelThreadDispatcher::start::(model_config, public_key)
                 .map_err(ModelServiceError::ModelThreadError)?;
 
         let start_time = Instant::now();
@@ -236,7 +232,7 @@ mod tests {
         let config_data = Value::Table(toml! {
             api_key = "your_api_key"
             models = ["Mamba3b"]
-            storage_folder = "./storage_folder/"
+            storage_path = "./storage_path/"
             tokenizer_file_path = "./tokenizer_file_path/"
             tracing = true
         });