diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs
index 6feadac8..1a614417 100644
--- a/atoma-inference/src/main.rs
+++ b/atoma-inference/src/main.rs
@@ -1,5 +1,5 @@
 // use hf_hub::api::sync::Api;
-// use inference::service::InferenceService;
+// use inference::service::ModelService;
 
 #[tokio::main]
 async fn main() {
@@ -7,7 +7,7 @@ async fn main() {
 
     // let (_, receiver) = tokio::sync::mpsc::channel(32);
 
-    // let _ = InferenceService::start::(
+    // let _ = ModelService::start::(
     //     "../inference.toml".parse().unwrap(),
     //     "../private_key".parse().unwrap(),
     //     receiver,
diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs
index f572a33d..ef8b351d 100644
--- a/atoma-inference/src/service.rs
+++ b/atoma-inference/src/service.rs
@@ -13,7 +13,7 @@ use crate::{
     models::{config::ModelConfig, ModelTrait, Request, Response},
 };
 
-pub struct InferenceService<T, U>
+pub struct ModelService<T, U>
 where
     T: Request,
     U: Response,
@@ -24,7 +24,7 @@ where
     request_receiver: Receiver<T>,
 }
 
-impl<T, U> InferenceService<T, U>
+impl<T, U> ModelService<T, U>
 where
     T: Clone + Request,
     U: std::fmt::Debug + Response,
@@ -33,7 +33,7 @@ where
         config_file_path: PathBuf,
         private_key_path: PathBuf,
         request_receiver: Receiver<T>,
-    ) -> Result<Self, InferenceServiceError>
+    ) -> Result<Self, ModelServiceError>
     where
         M: ModelTrait
             + Send
@@ -41,7 +41,7 @@ where
         F: ApiTrait,
     {
         let private_key_bytes =
-            std::fs::read(private_key_path).map_err(InferenceServiceError::PrivateKeyError)?;
+            std::fs::read(private_key_path).map_err(ModelServiceError::PrivateKeyError)?;
         let private_key_bytes: [u8; 32] = private_key_bytes
             .try_into()
             .expect("Incorrect private key bytes length");
@@ -56,7 +56,7 @@ where
 
         let (dispatcher, model_thread_handle) =
             ModelThreadDispatcher::start::<M, F>(api, model_config, public_key)
-                .map_err(InferenceServiceError::ModelThreadError)?;
+                .map_err(ModelServiceError::ModelThreadError)?;
         let start_time = Instant::now();
 
         Ok(Self {
@@ -67,7 +67,7 @@ where
         })
     }
 
-    pub async fn run(&mut self) -> Result<(), InferenceServiceError> {
+    pub async fn run(&mut self) -> Result<(), ModelServiceError> {
         loop {
             tokio::select! {
                 message = self.request_receiver.recv() => {
@@ -92,7 +92,7 @@ where
     }
 }
 
-impl<T, U> InferenceService<T, U>
+impl<T, U> ModelService<T, U>
 where
     T: Request,
     U: Response,
@@ -112,7 +112,7 @@ where
 }
 
 #[derive(Debug, Error)]
-pub enum InferenceServiceError {
+pub enum ModelServiceError {
     #[error("Failed to connect to API: `{0}`")]
     FailedApiConnection(ApiError),
     #[error("Failed to run inference: `{0}`")]
@@ -133,13 +133,13 @@ pub enum InferenceServiceError {
     CandleError(CandleError),
 }
 
-impl From<ApiError> for InferenceServiceError {
+impl From<ApiError> for ModelServiceError {
     fn from(error: ApiError) -> Self {
         Self::ApiError(error)
     }
 }
 
-impl From<CandleError> for InferenceServiceError {
+impl From<CandleError> for ModelServiceError {
     fn from(error: CandleError) -> Self {
         Self::CandleError(error)
     }
 }
@@ -249,7 +249,7 @@ mod tests {
 
         let (_, receiver) = tokio::sync::mpsc::channel::<()>(1);
 
-        let _ = InferenceService::<(), ()>::start::(
+        let _ = ModelService::<(), ()>::start::(
             PathBuf::try_from(CONFIG_FILE_PATH).unwrap(),
             PathBuf::try_from(PRIVATE_KEY_FILE_PATH).unwrap(),
             receiver,