diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs
index 36bfea0b..bfb0181d 100644
--- a/atoma-inference/src/main.rs
+++ b/atoma-inference/src/main.rs
@@ -1,9 +1,11 @@
 use std::time::Duration;
 
+use ed25519_consensus::SigningKey as PrivateKey;
 use hf_hub::api::sync::Api;
 use inference::{
     models::{
         candle::mamba::MambaModel,
+        config::ModelConfig,
         types::{TextRequest, TextResponse},
     },
     service::{ModelService, ModelServiceError},
@@ -16,9 +18,17 @@ async fn main() -> Result<(), ModelServiceError> {
     let (req_sender, req_receiver) = tokio::sync::mpsc::channel::<TextRequest>(32);
     let (resp_sender, mut resp_receiver) = tokio::sync::mpsc::channel::<TextResponse>(32);
 
+    let model_config = ModelConfig::from_file_path("../inference.toml".parse().unwrap());
+    let private_key_bytes =
+        std::fs::read("../private_key").map_err(ModelServiceError::PrivateKeyError)?;
+    let private_key_bytes: [u8; 32] = private_key_bytes
+        .try_into()
+        .expect("Incorrect private key bytes length");
+
+    let private_key = PrivateKey::from(private_key_bytes);
     let mut service = ModelService::start::<MambaModel, Api>(
-        "../inference.toml".parse().unwrap(),
-        "../private_key".parse().unwrap(),
+        model_config,
+        private_key,
         req_receiver,
         resp_sender,
     )
diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs
index 8dce6923..1d13024f 100644
--- a/atoma-inference/src/service.rs
+++ b/atoma-inference/src/service.rs
@@ -34,8 +34,8 @@ where
     Resp: std::fmt::Debug + Response,
 {
     pub fn start<M, F>(
-        config_file_path: PathBuf,
-        private_key_path: PathBuf,
+        model_config: ModelConfig,
+        private_key: PrivateKey,
         request_receiver: Receiver<Req>,
         response_sender: Sender<Resp>,
     ) -> Result<Self, ModelServiceError>
@@ -43,15 +43,7 @@ where
         M: ModelTrait + Send + 'static,
         F: ApiTrait + Send + Sync + 'static,
     {
-        let private_key_bytes =
-            std::fs::read(private_key_path).map_err(ModelServiceError::PrivateKeyError)?;
-        let private_key_bytes: [u8; 32] = private_key_bytes
-            .try_into()
-            .expect("Incorrect private key bytes length");
-
-        let private_key = PrivateKey::from(private_key_bytes);
         let public_key = private_key.verification_key();
-        let model_config = ModelConfig::from_file_path(config_file_path);
         let flush_storage = model_config.flush_storage();
         let storage_path = model_config.storage_path();
 
@@ -232,10 +224,8 @@ mod tests {
     #[tokio::test]
     async fn test_inference_service_initialization() {
         const CONFIG_FILE_PATH: &str = "./inference.toml";
-        const PRIVATE_KEY_FILE_PATH: &str = "./private_key";
 
         let private_key = PrivateKey::new(OsRng);
-        std::fs::write(PRIVATE_KEY_FILE_PATH, private_key.to_bytes()).unwrap();
 
         let config_data = Value::Table(toml! {
             api_key = "your_api_key"
@@ -255,15 +245,16 @@ mod tests {
         let (_, req_receiver) = tokio::sync::mpsc::channel::<()>(1);
         let (resp_sender, _) = tokio::sync::mpsc::channel::<()>(1);
 
+        let config = ModelConfig::from_file_path(CONFIG_FILE_PATH.parse().unwrap());
+
         let _ = ModelService::<(), ()>::start::(
-            PathBuf::from(CONFIG_FILE_PATH),
-            PathBuf::from(PRIVATE_KEY_FILE_PATH),
+            config,
+            private_key,
             req_receiver,
             resp_sender,
         )
         .unwrap();
 
         std::fs::remove_file(CONFIG_FILE_PATH).unwrap();
-        std::fs::remove_file(PRIVATE_KEY_FILE_PATH).unwrap();
     }
 }
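
Note on the call-site change: `ModelService::start` no longer reads the config or key files itself. Callers now construct the `ModelConfig` and the `ed25519_consensus` signing key up front, so the constructor does no file I/O, which is also why the test above no longer has to write and delete a key file on disk. A minimal sketch of the new caller-side setup, assuming the same 32-byte raw key file layout that `main.rs` reads (`load_private_key` is a hypothetical helper, not part of this diff):

```rust
use ed25519_consensus::SigningKey as PrivateKey;

// Hypothetical helper mirroring the key-loading code in the main.rs hunk above:
// read a raw key file, check it is exactly 32 bytes, and build the signing key.
fn load_private_key(path: &str) -> std::io::Result<PrivateKey> {
    let bytes = std::fs::read(path)?;
    let bytes: [u8; 32] = bytes
        .try_into()
        .expect("Incorrect private key bytes length");
    Ok(PrivateKey::from(bytes))
}
```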