Skip to content

Commit

Permalink
rename InferenceService to ModelService
Browse files (browse the repository at this point in the history)
  • Loading branch information
jorgeantonio21 committed Mar 31, 2024
1 parent a19817e commit 4a12b71
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
4 changes: 2 additions & 2 deletions atoma-inference/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
// use hf_hub::api::sync::Api;
// use inference::service::InferenceService;
// use inference::service::ModelService;

#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();

// let (_, receiver) = tokio::sync::mpsc::channel(32);

// let _ = InferenceService::start::<Model, Api>(
// let _ = ModelService::start::<Model, Api>(
// "../inference.toml".parse().unwrap(),
// "../private_key".parse().unwrap(),
// receiver,
Expand Down
22 changes: 11 additions & 11 deletions atoma-inference/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
models::{config::ModelConfig, ModelTrait, Request, Response},
};

pub struct InferenceService<T, U>
pub struct ModelService<T, U>
where
T: Request,
U: Response,
Expand All @@ -24,7 +24,7 @@ where
request_receiver: Receiver<T>,
}

impl<T, U> InferenceService<T, U>
impl<T, U> ModelService<T, U>
where
T: Clone + Request,
U: std::fmt::Debug + Response,
Expand All @@ -33,15 +33,15 @@ where
config_file_path: PathBuf,
private_key_path: PathBuf,
request_receiver: Receiver<T>,
) -> Result<Self, InferenceServiceError>
) -> Result<Self, ModelServiceError>
where
M: ModelTrait<FetchApi = F, Input = T::ModelInput, Output = U::ModelOutput>
+ Send
+ 'static,
F: ApiTrait,
{
let private_key_bytes =
std::fs::read(private_key_path).map_err(InferenceServiceError::PrivateKeyError)?;
std::fs::read(private_key_path).map_err(ModelServiceError::PrivateKeyError)?;
let private_key_bytes: [u8; 32] = private_key_bytes
.try_into()
.expect("Incorrect private key bytes length");
Expand All @@ -56,7 +56,7 @@ where

let (dispatcher, model_thread_handle) =
ModelThreadDispatcher::start::<M, F>(api, model_config, public_key)
.map_err(InferenceServiceError::ModelThreadError)?;
.map_err(ModelServiceError::ModelThreadError)?;
let start_time = Instant::now();

Ok(Self {
Expand All @@ -67,7 +67,7 @@ where
})
}

pub async fn run(&mut self) -> Result<U, InferenceServiceError> {
pub async fn run(&mut self) -> Result<U, ModelServiceError> {
loop {
tokio::select! {
message = self.request_receiver.recv() => {
Expand All @@ -92,7 +92,7 @@ where
}
}

impl<T, U> InferenceService<T, U>
impl<T, U> ModelService<T, U>
where
T: Request,
U: Response,
Expand All @@ -112,7 +112,7 @@ where
}

#[derive(Debug, Error)]
pub enum InferenceServiceError {
pub enum ModelServiceError {
#[error("Failed to connect to API: `{0}`")]
FailedApiConnection(ApiError),
#[error("Failed to run inference: `{0}`")]
Expand All @@ -133,13 +133,13 @@ pub enum InferenceServiceError {
CandleError(CandleError),
}

impl From<ApiError> for InferenceServiceError {
impl From<ApiError> for ModelServiceError {
/// Converts an `ApiError` into the service error type by wrapping it in the
/// `ApiError` variant, so callers can propagate API failures with `?`.
fn from(error: ApiError) -> Self {
Self::ApiError(error)
}
}

impl From<CandleError> for InferenceServiceError {
impl From<CandleError> for ModelServiceError {
/// Converts a `CandleError` into the service error type by wrapping it in the
/// `CandleError` variant, so callers can propagate Candle failures with `?`.
fn from(error: CandleError) -> Self {
Self::CandleError(error)
}
Expand Down Expand Up @@ -249,7 +249,7 @@ mod tests {

let (_, receiver) = tokio::sync::mpsc::channel::<()>(1);

let _ = InferenceService::<(), ()>::start::<TestModelInstance, MockApi>(
let _ = ModelService::<(), ()>::start::<TestModelInstance, MockApi>(
PathBuf::try_from(CONFIG_FILE_PATH).unwrap(),
PathBuf::try_from(PRIVATE_KEY_FILE_PATH).unwrap(),
receiver,
Expand Down

0 comments on commit 4a12b71

Please sign in to comment.