From 28cf4ffcb01c074a3170182f4b0a6e76f4dd97ef Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sat, 13 Apr 2024 18:21:02 +0100 Subject: [PATCH] first commit --- atoma-inference/src/lib.rs | 2 + atoma-inference/src/model_thread.rs | 10 +-- atoma-inference/src/tests/mod.rs | 108 ++++++++++++++++++++++++++++ 3 files changed, 115 insertions(+), 5 deletions(-) create mode 100644 atoma-inference/src/tests/mod.rs diff --git a/atoma-inference/src/lib.rs b/atoma-inference/src/lib.rs index 39a31749..f4bfd3c3 100644 --- a/atoma-inference/src/lib.rs +++ b/atoma-inference/src/lib.rs @@ -4,3 +4,5 @@ pub mod model_thread; pub mod models; pub mod service; pub mod specs; +#[cfg(test)] +pub mod tests; diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 3339c9a2..ab5c3aca 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -21,8 +21,8 @@ use crate::{ }; pub struct ModelThreadCommand { - request: serde_json::Value, - response_sender: oneshot::Sender, + pub(crate) request: serde_json::Value, + pub(crate) response_sender: oneshot::Sender, } #[derive(Debug, Error)] @@ -94,7 +94,7 @@ where } pub struct ModelThreadDispatcher { - model_senders: HashMap>, + pub(crate) model_senders: HashMap>, } impl ModelThreadDispatcher { @@ -172,7 +172,7 @@ impl ModelThreadDispatcher { } } -fn dispatch_model_thread( +pub(crate) fn dispatch_model_thread( api_key: String, cache_dir: PathBuf, model_name: String, @@ -231,7 +231,7 @@ fn dispatch_model_thread( } } -fn spawn_model_thread( +pub(crate) fn spawn_model_thread( model_name: String, api_key: String, cache_dir: PathBuf, diff --git a/atoma-inference/src/tests/mod.rs b/atoma-inference/src/tests/mod.rs new file mode 100644 index 00000000..78145b90 --- /dev/null +++ b/atoma-inference/src/tests/mod.rs @@ -0,0 +1,108 @@ +use crate::models::{config::ModelConfig, types::ModelType, ModelError, ModelTrait}; +use ed25519_consensus::SigningKey as PrivateKey; +use std::{path::PathBuf, time::Duration}; + +#[cfg(test)] +mod tests { + use std::{collections::HashMap, sync::mpsc}; + + use rand::rngs::OsRng; + use tokio::sync::oneshot; + + use crate::model_thread::{spawn_model_thread, ModelThreadCommand, ModelThreadDispatcher}; + + use super::*; + + const DURATION_1_SECS: Duration = Duration::from_secs(1); + const DURATION_2_SECS: Duration = Duration::from_secs(2); + const DURATION_5_SECS: Duration = Duration::from_secs(5); + const DURATION_10_SECS: Duration = Duration::from_secs(10); + + struct TestModel { + duration: Duration, + } + + impl ModelTrait for TestModel { + type Input = (); + type Output = (); + type LoadData = (); + + fn fetch( + api_key: String, + cache_dir: PathBuf, + config: ModelConfig, + ) -> Result { + Ok(()) + } + + fn load(load_data: Self::LoadData) -> Result + where + Self: Sized, + { + Ok(Self { + duration: DURATION_1_SECS, + }) + } + + fn model_type(&self) -> ModelType { + todo!() + } + + fn run(&mut self, input: Self::Input) -> Result { + std::thread::sleep(self.duration); + Ok(()) + } + } + + impl ModelThreadDispatcher { + fn test_start() -> Self { + let mut model_senders = HashMap::with_capacity(4); + + for duration in [ + DURATION_1_SECS, + DURATION_2_SECS, + DURATION_5_SECS, + DURATION_10_SECS, + ] { + let model_name = format!("test_model_{:?}", duration); + + let (model_sender, model_receiver) = mpsc::channel::(); + model_senders.insert(model_name.clone(), model_sender.clone()); + + let api_key = "".to_string(); + let cache_dir = "./".parse().unwrap(); + let model_config = + ModelConfig::new(model_name.clone(), "".to_string(), "".to_string(), 0, false); + + let private_key = PrivateKey::new(OsRng); + let public_key = private_key.verification_key(); + + let _join_handle = spawn_model_thread::( + model_name, + api_key, + cache_dir, + model_config, + public_key, + model_receiver, + ); + } + Self { model_senders } + } + } + + #[tokio::test] + async fn test_model_thread() { + let model_thread_dispatcher = ModelThreadDispatcher::test_start(); + + for _ in 0..10 { + for sender in model_thread_dispatcher.model_senders.values() { + let (response_sender, response_request) = oneshot::channel(); + let command = ModelThreadCommand { + request: serde_json::Value::Null, + response_sender, + }; + sender.send(command).expect("Failed to send command"); + } + } + } +}