From 28cf4ffcb01c074a3170182f4b0a6e76f4dd97ef Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sat, 13 Apr 2024 18:21:02 +0100 Subject: [PATCH 1/8] first commit --- atoma-inference/src/lib.rs | 2 + atoma-inference/src/model_thread.rs | 10 +-- atoma-inference/src/tests/mod.rs | 108 ++++++++++++++++++++++++++++ 3 files changed, 115 insertions(+), 5 deletions(-) create mode 100644 atoma-inference/src/tests/mod.rs diff --git a/atoma-inference/src/lib.rs b/atoma-inference/src/lib.rs index 39a31749..f4bfd3c3 100644 --- a/atoma-inference/src/lib.rs +++ b/atoma-inference/src/lib.rs @@ -4,3 +4,5 @@ pub mod model_thread; pub mod models; pub mod service; pub mod specs; +#[cfg(test)] +pub mod tests; diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 3339c9a2..ab5c3aca 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -21,8 +21,8 @@ use crate::{ }; pub struct ModelThreadCommand { - request: serde_json::Value, - response_sender: oneshot::Sender, + pub(crate) request: serde_json::Value, + pub(crate) response_sender: oneshot::Sender, } #[derive(Debug, Error)] @@ -94,7 +94,7 @@ where } pub struct ModelThreadDispatcher { - model_senders: HashMap>, + pub(crate) model_senders: HashMap>, } impl ModelThreadDispatcher { @@ -172,7 +172,7 @@ impl ModelThreadDispatcher { } } -fn dispatch_model_thread( +pub(crate) fn dispatch_model_thread( api_key: String, cache_dir: PathBuf, model_name: String, @@ -231,7 +231,7 @@ fn dispatch_model_thread( } } -fn spawn_model_thread( +pub(crate) fn spawn_model_thread( model_name: String, api_key: String, cache_dir: PathBuf, diff --git a/atoma-inference/src/tests/mod.rs b/atoma-inference/src/tests/mod.rs new file mode 100644 index 00000000..78145b90 --- /dev/null +++ b/atoma-inference/src/tests/mod.rs @@ -0,0 +1,108 @@ +use crate::models::{config::ModelConfig, types::ModelType, ModelError, ModelTrait}; +use ed25519_consensus::SigningKey as PrivateKey; +use std::{path::PathBuf, time::Duration}; + +#[cfg(test)] +mod tests { + use std::{collections::HashMap, sync::mpsc}; + + use rand::rngs::OsRng; + use tokio::sync::oneshot; + + use crate::model_thread::{spawn_model_thread, ModelThreadCommand, ModelThreadDispatcher}; + + use super::*; + + const DURATION_1_SECS: Duration = Duration::from_secs(1); + const DURATION_2_SECS: Duration = Duration::from_secs(2); + const DURATION_5_SECS: Duration = Duration::from_secs(5); + const DURATION_10_SECS: Duration = Duration::from_secs(10); + + struct TestModel { + duration: Duration, + } + + impl ModelTrait for TestModel { + type Input = (); + type Output = (); + type LoadData = (); + + fn fetch( + api_key: String, + cache_dir: PathBuf, + config: ModelConfig, + ) -> Result { + Ok(()) + } + + fn load(load_data: Self::LoadData) -> Result + where + Self: Sized, + { + Ok(Self { + duration: DURATION_1_SECS, + }) + } + + fn model_type(&self) -> ModelType { + todo!() + } + + fn run(&mut self, input: Self::Input) -> Result { + std::thread::sleep(self.duration); + Ok(()) + } + } + + impl ModelThreadDispatcher { + fn test_start() -> Self { + let mut model_senders = HashMap::with_capacity(4); + + for duration in [ + DURATION_1_SECS, + DURATION_2_SECS, + DURATION_5_SECS, + DURATION_10_SECS, + ] { + let model_name = format!("test_model_{:?}", duration); + + let (model_sender, model_receiver) = mpsc::channel::(); + model_senders.insert(model_name.clone(), model_sender.clone()); + + let api_key = "".to_string(); + let cache_dir = "./".parse().unwrap(); + let model_config = + 
ModelConfig::new(model_name.clone(), "".to_string(), "".to_string(), 0, false); + + let private_key = PrivateKey::new(OsRng); + let public_key = private_key.verification_key(); + + let _join_handle = spawn_model_thread::( + model_name, + api_key, + cache_dir, + model_config, + public_key, + model_receiver, + ); + } + Self { model_senders } + } + } + + #[tokio::test] + async fn test_model_thread() { + let model_thread_dispatcher = ModelThreadDispatcher::test_start(); + + for _ in 0..10 { + for sender in model_thread_dispatcher.model_senders.values() { + let (response_sender, response_request) = oneshot::channel(); + let command = ModelThreadCommand { + request: serde_json::Value::Null, + response_sender, + }; + sender.send(command).expect("Failed to send command"); + } + } + } +} From a96c0e7eb9572fa1089dc68276979d6394b841cc Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sun, 14 Apr 2024 22:21:26 +0100 Subject: [PATCH 2/8] add more logic to test_inference_service --- atoma-inference/Cargo.toml | 1 + atoma-inference/src/models/config.rs | 4 +- atoma-inference/src/service.rs | 11 +- atoma-inference/src/tests/mod.rs | 174 ++++++++++++++++++++++----- atoma-inference/src/tests/prompts.rs | 56 +++++++++ 5 files changed, 211 insertions(+), 35 deletions(-) create mode 100644 atoma-inference/src/tests/prompts.rs diff --git a/atoma-inference/Cargo.toml b/atoma-inference/Cargo.toml index 3b7de154..bbfba1b0 100644 --- a/atoma-inference/Cargo.toml +++ b/atoma-inference/Cargo.toml @@ -27,6 +27,7 @@ tracing-subscriber.workspace = true [dev-dependencies] rand.workspace = true toml.workspace = true +reqwest = { workspace = true, features = ["json"] } [features] diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs index 76477f4b..82c93017 100644 --- a/atoma-inference/src/models/config.rs +++ b/atoma-inference/src/models/config.rs @@ -55,7 +55,7 @@ impl ModelConfig { } } -#[derive(Debug, Deserialize, Serialize)] +#[derive(Clone, Debug, Deserialize, Serialize)] pub struct ModelsConfig { api_key: String, cache_dir: PathBuf, @@ -182,7 +182,7 @@ pub mod tests { ); let toml_str = toml::to_string(&config).unwrap(); - let should_be_toml_str = "api_key = \"my_key\"\ncache_dir = \"/\"\nflush_storage = true\ntracing = true\n\n[[models]]\ndevice_id = 0\ndtype = \"Llama2_7b\"\nmodel_id = \"F16\"\nrevision = \"\"\nuse_flash_attention = true\n"; + let should_be_toml_str = "api_key = \"my_key\"\ncache_dir = \"/\"\nflush_storage = true\ntracing = true\njrpc_port = 18001\n\n[[models]]\ndevice_id = 0\ndtype = \"Llama2_7b\"\nmodel_id = \"F16\"\nrevision = \"\"\nuse_flash_attention = true\n"; assert_eq!(toml_str, should_be_toml_str); } } diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index acbe4eac..09a68182 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -188,11 +188,18 @@ mod tests { let config_data = Value::Table(toml! 
{ api_key = "your_api_key" - models = [[0, "f32", "mamba_370m", "", false, 0]] cache_dir = "./cache_dir/" - tokenizer_file_path = "./tokenizer_file_path/" flush_storage = true + models = [ + [ + 0, + "bf16", + "mamba_370m", + "", + false + ]] tracing = true + jrpc_port = 3000 }); let toml_string = toml::to_string_pretty(&config_data).expect("Failed to serialize to TOML"); diff --git a/atoma-inference/src/tests/mod.rs b/atoma-inference/src/tests/mod.rs index 78145b90..23406a82 100644 --- a/atoma-inference/src/tests/mod.rs +++ b/atoma-inference/src/tests/mod.rs @@ -2,46 +2,51 @@ use crate::models::{config::ModelConfig, types::ModelType, ModelError, ModelTrai use ed25519_consensus::SigningKey as PrivateKey; use std::{path::PathBuf, time::Duration}; +mod prompts; +use prompts::PROMPTS; + #[cfg(test)] mod tests { use std::{collections::HashMap, sync::mpsc}; + use futures::{stream::FuturesUnordered, StreamExt}; use rand::rngs::OsRng; + use reqwest::Client; + use serde_json::json; use tokio::sync::oneshot; + use serde_json::Value; - use crate::model_thread::{spawn_model_thread, ModelThreadCommand, ModelThreadDispatcher}; + use crate::{ + jrpc_server, + model_thread::{spawn_model_thread, ModelThreadCommand, ModelThreadDispatcher}, + models::config::ModelsConfig, + service::ModelService, + }; use super::*; - const DURATION_1_SECS: Duration = Duration::from_secs(1); - const DURATION_2_SECS: Duration = Duration::from_secs(2); - const DURATION_5_SECS: Duration = Duration::from_secs(5); - const DURATION_10_SECS: Duration = Duration::from_secs(10); - struct TestModel { duration: Duration, } impl ModelTrait for TestModel { - type Input = (); - type Output = (); - type LoadData = (); + type Input = Value; + type Output = Value; + type LoadData = Duration; fn fetch( - api_key: String, - cache_dir: PathBuf, - config: ModelConfig, + duration: String, + _cache_dir: PathBuf, + _config: ModelConfig, ) -> Result { - Ok(()) + Ok(Duration::from_secs(duration.parse().unwrap())) } - fn load(load_data: Self::LoadData) -> Result + fn load(duration: Self::LoadData) -> Result where Self: Sized, { - Ok(Self { - duration: DURATION_1_SECS, - }) + Ok(Self { duration }) } fn model_type(&self) -> ModelType { @@ -50,26 +55,26 @@ mod tests { fn run(&mut self, input: Self::Input) -> Result { std::thread::sleep(self.duration); - Ok(()) + println!( + "Finished waiting time for {:?} and input = {}", + self.duration, input + ); + Ok(input) } } impl ModelThreadDispatcher { fn test_start() -> Self { + let duration_in_secs = vec![1, 2, 5, 10]; let mut model_senders = HashMap::with_capacity(4); - for duration in [ - DURATION_1_SECS, - DURATION_2_SECS, - DURATION_5_SECS, - DURATION_10_SECS, - ] { - let model_name = format!("test_model_{:?}", duration); + for i in duration_in_secs { + let model_name = format!("test_model_{:?}", i); let (model_sender, model_receiver) = mpsc::channel::(); model_senders.insert(model_name.clone(), model_sender.clone()); - let api_key = "".to_string(); + let duration = format!("{i}"); let cache_dir = "./".parse().unwrap(); let model_config = ModelConfig::new(model_name.clone(), "".to_string(), "".to_string(), 0, false); @@ -79,7 +84,7 @@ mod tests { let _join_handle = spawn_model_thread::( model_name, - api_key, + duration, cache_dir, model_config, public_key, @@ -91,18 +96,125 @@ mod tests { } #[tokio::test] - async fn test_model_thread() { + async fn test_mock_model_thread() { + const NUM_REQUESTS: usize = 16; + let model_thread_dispatcher = ModelThreadDispatcher::test_start(); + let mut responses = 
FuturesUnordered::new(); - for _ in 0..10 { + let mut should_be_received_responses = vec![]; + for i in 0..NUM_REQUESTS { for sender in model_thread_dispatcher.model_senders.values() { - let (response_sender, response_request) = oneshot::channel(); + let (response_sender, response_receiver) = oneshot::channel(); + let request = json!(i); let command = ModelThreadCommand { - request: serde_json::Value::Null, + request: request.clone(), response_sender, }; sender.send(command).expect("Failed to send command"); + responses.push(response_receiver); + should_be_received_responses.push(request.as_u64().unwrap()); + } + } + + let mut received_responses = vec![]; + while let Some(response) = responses.next().await { + if let Ok(value) = response { + received_responses.push(value.as_u64().unwrap()); } } + + assert_eq!( + received_responses.sort(), + should_be_received_responses.sort() + ); + } + + #[tokio::test] + async fn test_inference_service() { + const CHANNEL_BUFFER: usize = 32; + const JRPC_PORT: u64 = 3000; + + let private_key = PrivateKey::new(OsRng); + let model_configs = vec![ + ModelConfig::new( + "mamba_130m".to_string(), + "bf16".to_string(), + "refs/pr/1".to_string(), + 0, + false, + ), + ModelConfig::new( + "mamba_370m".to_string(), + "bf16".to_string(), + "refs/pr/1".to_string(), + 0, + false, + ), + ModelConfig::new( + "llama_tiny_llama_1_1b_chat".to_string(), + "bf16".to_string(), + "main".to_string(), + 0, + false, + ), + ]; + let config = ModelsConfig::new( + "".to_string(), + "./cache_dir".parse().unwrap(), + true, + model_configs, + true, + JRPC_PORT, + ); + + let (req_sender, req_receiver) = tokio::sync::mpsc::channel(CHANNEL_BUFFER); + + println!("Starting model service.."); + let mut service = + ModelService::start(config.clone(), private_key.clone(), req_receiver).unwrap(); + + let _service_join_handle = + tokio::spawn(async move { service.run().await.expect("Failed to run service") }); + let _jrpc_server_join_handle = + tokio::spawn(async move { jrpc_server::run(req_sender.clone(), JRPC_PORT).await }); + + let client = Client::new(); + + for prompt in PROMPTS { + println!("FLAG: {prompt}"); + tokio::time::sleep(Duration::from_secs(1)).await; + + let params = json!({ + "prompt": prompt, + "temperature": 0.5, + "random_seed": 42, + "repeat_penalty": 1.0, + "repeat_last_n": 64, + "max_tokens": 32, + "_top_k": 10, + "top_p": 1.0 + }); + + let request = json!({ + "jsonrpc": "2.0", + "method": "/", + "params": params, + "id": 1 // You can use a unique identifier for each request + }); + + let response = client + .post(format!("http://localhost:{}/", JRPC_PORT)) + .json(&request) + .send() + .await + .expect("Failed to receive response from JRPCs server"); + + let response_json: Value = response + .json() + .await + .expect("Failed to parse response to JSON"); + println!("{}", response_json); + } } } diff --git a/atoma-inference/src/tests/prompts.rs b/atoma-inference/src/tests/prompts.rs new file mode 100644 index 00000000..5aca6a55 --- /dev/null +++ b/atoma-inference/src/tests/prompts.rs @@ -0,0 +1,56 @@ +pub(crate) const PROMPTS: &[&str] = &[ + "The sun set behind the mountains, painting the sky in shades of orange and purple.", + "She walked through the forest, listening to the rustling of leaves under her feet.", + "The old man sat on the park bench, feeding breadcrumbs to the pigeons.", + "As the train pulled into the station, he felt a sense of excitement for the journey ahead.", + "The city streets were bustling with activity, as people hurried to their destinations.", + 
"She looked out the window and watched the raindrops fall gently onto the pavement.", + "In the quiet of the night, he could hear the distant sound of crickets chirping.", + "The smell of freshly baked bread filled the air as she entered the bakery.", + "He sat by the fireplace, lost in thought as the flames danced before him.", + "The waves crashed against the shore, their rhythmic sound soothing her troubled mind.", + "She gazed up at the stars, feeling small yet connected to the universe.", + "The smell of coffee wafted through the air, inviting her to take a sip.", + "He laughed as he chased after the playful puppy, enjoying the simple pleasure of companionship.", + "She closed her eyes and let the music wash over her, transporting her to another world.", + "The smell of flowers filled the garden, attracting bees and butterflies alike.", + "He took a deep breath and plunged into the icy waters, feeling alive and invigorated.", + "As she reached the mountaintop, she was greeted by a breathtaking view of the valley below.", + "The sound of children's laughter echoed through the playground, bringing a smile to her face.", + "He savored the taste of the freshly picked strawberries, their sweetness exploding on his tongue.", + "She felt a sense of peace wash over her as she practiced yoga in the park.", + "The sound of church bells ringing in the distance signaled the start of a new day.", + "He watched in awe as the fireworks lit up the night sky, painting it in vibrant colors.", + "She felt a sense of accomplishment as she crossed the finish line, completing her first marathon.", + "The smell of rain on hot pavement filled the air, bringing relief from the summer heat.", + "He marveled at the intricate patterns of the snowflakes as they fell softly to the ground.", + "She sat on the swing, feeling the wind in her hair and the sun on her face.", + "The sound of thunder rumbled in the distance, signaling an approaching storm.", + "He listened to the sound of waves crashing against the rocks, feeling at peace with the world.", + "She watched as the leaves changed colors, ushering in the beauty of autumn.", + "The smell of barbecue filled the air as he fired up the grill for a summer cookout.", + "He felt a sense of nostalgia as he flipped through old photographs, reliving cherished memories.", + "She closed her eyes and listened to the sound of her own heartbeat, feeling alive and present in the moment.", + "The taste of hot chocolate warmed her from the inside out on a cold winter's day.", + "He watched as the clouds drifted lazily across the sky, their shapes morphing into fantastical creatures.", + "She felt a sense of wonder as she explored the hidden nooks and crannies of an old bookstore.", + "The smell of freshly cut grass reminded him of carefree childhood days spent playing in the park.", + "He sat by the window and watched the world go by, lost in his own thoughts.", + "She felt a surge of adrenaline as she jumped out of the plane, skydiving for the first time.", + "The sound of birds chirping in the morning signaled the start of a new day.", + "He watched in awe as the full moon cast its silver glow over the landscape.", + "She felt a sense of pride as she watched her garden flourish, blooming with colorful flowers.", + "The taste of ripe watermelon brought back memories of lazy summer afternoons spent with friends.", + "He listened to the sound of his own footsteps echoing through the empty streets, feeling a sense of solitude.", + "She felt a sense of belonging as she sat around 
the campfire, sharing stories with friends.", + "The smell of cinnamon and spices filled the kitchen as she baked a batch of cookies.", + "He felt a sense of accomplishment as he reached the summit of the mountain, conquering his fears.", + "She watched as the first snowflakes of winter fell gently to the ground, blanketing the world in white.", + "The sound of a crackling fire filled the cabin, warming her on a chilly winter's night.", + "He felt a sense of awe as he looked up at the towering skyscrapers, marveling at human ingenuity.", + "She closed her eyes and listened to the sound of the waves crashing against the shore, feeling at peace.", + "The taste of freshly squeezed lemonade cooled her down on a hot summer day.", + "He watched as the leaves danced in the wind, their colors swirling in a mesmerizing display.", + "She felt a sense of freedom as she rode her bike through the countryside, the wind in her hair.", + "The smell of pine trees filled the air as he hiked through the forest, reconnecting with nature." +]; From 2d52cbd89b3fce745aacc9ea13a6b1d54e8b7e67 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sun, 14 Apr 2024 22:23:33 +0100 Subject: [PATCH 3/8] add more logic to test_inference_service --- atoma-inference/src/tests/mod.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/atoma-inference/src/tests/mod.rs b/atoma-inference/src/tests/mod.rs index 23406a82..b18ca788 100644 --- a/atoma-inference/src/tests/mod.rs +++ b/atoma-inference/src/tests/mod.rs @@ -180,11 +180,9 @@ mod tests { tokio::spawn(async move { jrpc_server::run(req_sender.clone(), JRPC_PORT).await }); let client = Client::new(); + tokio::time::sleep(Duration::from_secs(60)).await; for prompt in PROMPTS { - println!("FLAG: {prompt}"); - tokio::time::sleep(Duration::from_secs(1)).await; - let params = json!({ "prompt": prompt, "temperature": 0.5, From 289291cbbd906dc60742f326aa14ddd443412354 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Mon, 15 Apr 2024 08:21:46 +0100 Subject: [PATCH 4/8] add more logic to test --- atoma-inference/src/tests/mod.rs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/atoma-inference/src/tests/mod.rs b/atoma-inference/src/tests/mod.rs index b18ca788..ae7b33b9 100644 --- a/atoma-inference/src/tests/mod.rs +++ b/atoma-inference/src/tests/mod.rs @@ -180,11 +180,15 @@ mod tests { tokio::spawn(async move { jrpc_server::run(req_sender.clone(), JRPC_PORT).await }); let client = Client::new(); - tokio::time::sleep(Duration::from_secs(60)).await; + tokio::time::sleep(Duration::from_secs(4 * 60)).await; - for prompt in PROMPTS { + let mut responses = vec![]; + for (idx, prompt) in PROMPTS.iter().enumerate() { let params = json!({ - "prompt": prompt, + "request_id": idx, + "prompt": prompt.to_string(), + "model":, + "sampled_nodes": vec![], "temperature": 0.5, "random_seed": 42, "repeat_penalty": 1.0, @@ -213,6 +217,8 @@ mod tests { .await .expect("Failed to parse response to JSON"); println!("{}", response_json); + responses.push(response_json); } + assert_eq!(responses.len(), PROMPTS.len()); } } From beaa15daa389fc40fbacc7faea19473140170154 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Mon, 15 Apr 2024 09:02:51 +0100 Subject: [PATCH 5/8] add json healthz --- atoma-inference/src/tests/mod.rs | 419 ++++++++++++++++--------------- 1 file changed, 222 insertions(+), 197 deletions(-) diff --git a/atoma-inference/src/tests/mod.rs b/atoma-inference/src/tests/mod.rs index ae7b33b9..4e2b130c 100644 --- a/atoma-inference/src/tests/mod.rs 
+++ b/atoma-inference/src/tests/mod.rs @@ -5,220 +5,245 @@ use std::{path::PathBuf, time::Duration}; mod prompts; use prompts::PROMPTS; -#[cfg(test)] -mod tests { - use std::{collections::HashMap, sync::mpsc}; - - use futures::{stream::FuturesUnordered, StreamExt}; - use rand::rngs::OsRng; - use reqwest::Client; - use serde_json::json; - use tokio::sync::oneshot; - use serde_json::Value; - - use crate::{ - jrpc_server, - model_thread::{spawn_model_thread, ModelThreadCommand, ModelThreadDispatcher}, - models::config::ModelsConfig, - service::ModelService, - }; - - use super::*; - - struct TestModel { - duration: Duration, +use std::{collections::HashMap, sync::mpsc}; + +use axum::http::HeaderMap; +use futures::{stream::FuturesUnordered, StreamExt}; +use rand::rngs::OsRng; +use reqwest::Client; +use serde_json::json; +use serde_json::Value; +use tokio::sync::oneshot; + +use crate::{ + jrpc_server, + model_thread::{spawn_model_thread, ModelThreadCommand, ModelThreadDispatcher}, + models::config::ModelsConfig, + service::ModelService, +}; + + +struct TestModel { + duration: Duration, +} + +impl ModelTrait for TestModel { + type Input = Value; + type Output = Value; + type LoadData = Duration; + + fn fetch( + duration: String, + _cache_dir: PathBuf, + _config: ModelConfig, + ) -> Result { + Ok(Duration::from_secs(duration.parse().unwrap())) } - impl ModelTrait for TestModel { - type Input = Value; - type Output = Value; - type LoadData = Duration; - - fn fetch( - duration: String, - _cache_dir: PathBuf, - _config: ModelConfig, - ) -> Result { - Ok(Duration::from_secs(duration.parse().unwrap())) - } + fn load(duration: Self::LoadData) -> Result + where + Self: Sized, + { + Ok(Self { duration }) + } - fn load(duration: Self::LoadData) -> Result - where - Self: Sized, - { - Ok(Self { duration }) - } + fn model_type(&self) -> ModelType { + todo!() + } - fn model_type(&self) -> ModelType { - todo!() - } + fn run(&mut self, input: Self::Input) -> Result { + std::thread::sleep(self.duration); + println!( + "Finished waiting time for {:?} and input = {}", + self.duration, input + ); + Ok(input) + } +} + +impl ModelThreadDispatcher { + fn test_start() -> Self { + let duration_in_secs = vec![1, 2, 5, 10]; + let mut model_senders = HashMap::with_capacity(4); + + for i in duration_in_secs { + let model_name = format!("test_model_{:?}", i); + + let (model_sender, model_receiver) = mpsc::channel::(); + model_senders.insert(model_name.clone(), model_sender.clone()); + + let duration = format!("{i}"); + let cache_dir = "./".parse().unwrap(); + let model_config = + ModelConfig::new(model_name.clone(), "".to_string(), "".to_string(), 0, false); - fn run(&mut self, input: Self::Input) -> Result { - std::thread::sleep(self.duration); - println!( - "Finished waiting time for {:?} and input = {}", - self.duration, input + let private_key = PrivateKey::new(OsRng); + let public_key = private_key.verification_key(); + + let _join_handle = spawn_model_thread::( + model_name, + duration, + cache_dir, + model_config, + public_key, + model_receiver, ); - Ok(input) } + Self { model_senders } } +} - impl ModelThreadDispatcher { - fn test_start() -> Self { - let duration_in_secs = vec![1, 2, 5, 10]; - let mut model_senders = HashMap::with_capacity(4); - - for i in duration_in_secs { - let model_name = format!("test_model_{:?}", i); - - let (model_sender, model_receiver) = mpsc::channel::(); - model_senders.insert(model_name.clone(), model_sender.clone()); - - let duration = format!("{i}"); - let cache_dir = 
"./".parse().unwrap(); - let model_config = - ModelConfig::new(model_name.clone(), "".to_string(), "".to_string(), 0, false); - - let private_key = PrivateKey::new(OsRng); - let public_key = private_key.verification_key(); - - let _join_handle = spawn_model_thread::( - model_name, - duration, - cache_dir, - model_config, - public_key, - model_receiver, - ); - } - Self { model_senders } +#[tokio::test] +async fn test_mock_model_thread() { + const NUM_REQUESTS: usize = 16; + + let model_thread_dispatcher = ModelThreadDispatcher::test_start(); + let mut responses = FuturesUnordered::new(); + + let mut should_be_received_responses = vec![]; + for i in 0..NUM_REQUESTS { + for sender in model_thread_dispatcher.model_senders.values() { + let (response_sender, response_receiver) = oneshot::channel(); + let request = json!(i); + let command = ModelThreadCommand { + request: request.clone(), + response_sender, + }; + sender.send(command).expect("Failed to send command"); + responses.push(response_receiver); + should_be_received_responses.push(request.as_u64().unwrap()); } } - #[tokio::test] - async fn test_mock_model_thread() { - const NUM_REQUESTS: usize = 16; - - let model_thread_dispatcher = ModelThreadDispatcher::test_start(); - let mut responses = FuturesUnordered::new(); - - let mut should_be_received_responses = vec![]; - for i in 0..NUM_REQUESTS { - for sender in model_thread_dispatcher.model_senders.values() { - let (response_sender, response_receiver) = oneshot::channel(); - let request = json!(i); - let command = ModelThreadCommand { - request: request.clone(), - response_sender, - }; - sender.send(command).expect("Failed to send command"); - responses.push(response_receiver); - should_be_received_responses.push(request.as_u64().unwrap()); - } + let mut received_responses = vec![]; + while let Some(response) = responses.next().await { + if let Ok(value) = response { + received_responses.push(value.as_u64().unwrap()); } + } + + received_responses.sort(); + should_be_received_responses.sort(); + + assert_eq!( + received_responses, + should_be_received_responses + ); +} - let mut received_responses = vec![]; - while let Some(response) = responses.next().await { - if let Ok(value) = response { - received_responses.push(value.as_u64().unwrap()); +#[tokio::test] +async fn test_inference_service() { + const CHANNEL_BUFFER: usize = 32; + const JRPC_PORT: u64 = 3000; + + let private_key = PrivateKey::new(OsRng); + let model_ids = ["mamba_130m", "mamba_370m", "llama_tiny_llama_1_1b_chat"]; + let model_configs = vec![ + ModelConfig::new( + "mamba_130m".to_string(), + "bf16".to_string(), + "refs/pr/1".to_string(), + 0, + false, + ), + ModelConfig::new( + "mamba_370m".to_string(), + "bf16".to_string(), + "refs/pr/1".to_string(), + 0, + false, + ), + ModelConfig::new( + "llama_tiny_llama_1_1b_chat".to_string(), + "bf16".to_string(), + "main".to_string(), + 0, + false, + ), + ]; + let config = ModelsConfig::new( + "".to_string(), + "./cache_dir".parse().unwrap(), + true, + model_configs, + true, + JRPC_PORT, + ); + + let (req_sender, req_receiver) = tokio::sync::mpsc::channel(CHANNEL_BUFFER); + + println!("Starting model service.."); + let mut service = + ModelService::start(config.clone(), private_key.clone(), req_receiver).unwrap(); + + let _service_join_handle = + tokio::spawn(async move { service.run().await.expect("Failed to run service") }); + let _jrpc_server_join_handle = + tokio::spawn(async move { jrpc_server::run(req_sender.clone(), JRPC_PORT).await }); + + let client = Client::new(); + let 
mut headers = HeaderMap::new(); + headers.insert("Content-Type", "application/json".parse().unwrap()); + + std::thread::sleep(Duration::from_secs(50)); + loop { + match client + .post(format!("http://localhost:{}/healthz", JRPC_PORT)) + .headers(headers.clone()) + .send() + .await + { + Ok(response) => { + let response_json: Value = response.json().await.unwrap(); + println!("DEBUG: response_json = {}", response_json); + if response_json != Value::Null { + break; + } + std::thread::sleep(Duration::from_secs(1)); + } + Err(_) => { + std::thread::sleep(Duration::from_secs(1)); } } - - assert_eq!( - received_responses.sort(), - should_be_received_responses.sort() - ); } - #[tokio::test] - async fn test_inference_service() { - const CHANNEL_BUFFER: usize = 32; - const JRPC_PORT: u64 = 3000; - - let private_key = PrivateKey::new(OsRng); - let model_configs = vec![ - ModelConfig::new( - "mamba_130m".to_string(), - "bf16".to_string(), - "refs/pr/1".to_string(), - 0, - false, - ), - ModelConfig::new( - "mamba_370m".to_string(), - "bf16".to_string(), - "refs/pr/1".to_string(), - 0, - false, - ), - ModelConfig::new( - "llama_tiny_llama_1_1b_chat".to_string(), - "bf16".to_string(), - "main".to_string(), - 0, - false, - ), - ]; - let config = ModelsConfig::new( - "".to_string(), - "./cache_dir".parse().unwrap(), - true, - model_configs, - true, - JRPC_PORT, - ); - - let (req_sender, req_receiver) = tokio::sync::mpsc::channel(CHANNEL_BUFFER); - - println!("Starting model service.."); - let mut service = - ModelService::start(config.clone(), private_key.clone(), req_receiver).unwrap(); - - let _service_join_handle = - tokio::spawn(async move { service.run().await.expect("Failed to run service") }); - let _jrpc_server_join_handle = - tokio::spawn(async move { jrpc_server::run(req_sender.clone(), JRPC_PORT).await }); - - let client = Client::new(); - tokio::time::sleep(Duration::from_secs(4 * 60)).await; - - let mut responses = vec![]; - for (idx, prompt) in PROMPTS.iter().enumerate() { - let params = json!({ - "request_id": idx, - "prompt": prompt.to_string(), - "model":, - "sampled_nodes": vec![], - "temperature": 0.5, - "random_seed": 42, - "repeat_penalty": 1.0, - "repeat_last_n": 64, - "max_tokens": 32, - "_top_k": 10, - "top_p": 1.0 - }); - - let request = json!({ - "jsonrpc": "2.0", - "method": "/", - "params": params, - "id": 1 // You can use a unique identifier for each request - }); - - let response = client - .post(format!("http://localhost:{}/", JRPC_PORT)) - .json(&request) - .send() - .await - .expect("Failed to receive response from JRPCs server"); - - let response_json: Value = response - .json() - .await - .expect("Failed to parse response to JSON"); - println!("{}", response_json); - responses.push(response_json); - } - assert_eq!(responses.len(), PROMPTS.len()); + let mut responses = vec![]; + for (idx, prompt) in PROMPTS.iter().enumerate() { + let model_id = model_ids[idx % 3]; + let params = json!({ + "request_id": idx, + "prompt": prompt.to_string(), + "model": model_id.to_string(), + "sampled_nodes": private_key.verification_key(), + "temperature": 0.5, + "random_seed": 42, + "repeat_penalty": 1.0, + "repeat_last_n": 64, + "max_tokens": 32, + "_top_k": 10, + "top_p": 1.0 + }); + + let request = json!({ + "jsonrpc": "2.0", + "method": "/", + "params": params, + "id": 1 // You can use a unique identifier for each request + }); + + let response = client + .post(format!("http://localhost:{}/", JRPC_PORT)) + .json(&request) + .send() + .await + .expect("Failed to receive response from 
JRPCs server"); + + let response_json: Value = response + .json() + .await + .expect("Failed to parse response to JSON"); + println!("{}", response_json); + responses.push(response_json); } + assert_eq!(responses.len(), PROMPTS.len()); } From b712c57287fb630e29fa3d7cb8728e6b22e0216a Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Mon, 15 Apr 2024 09:54:25 +0100 Subject: [PATCH 6/8] resolve bug in test --- atoma-inference/src/tests/mod.rs | 54 +++++++++----------------------- 1 file changed, 14 insertions(+), 40 deletions(-) diff --git a/atoma-inference/src/tests/mod.rs b/atoma-inference/src/tests/mod.rs index 4e2b130c..91720724 100644 --- a/atoma-inference/src/tests/mod.rs +++ b/atoma-inference/src/tests/mod.rs @@ -7,7 +7,6 @@ use prompts::PROMPTS; use std::{collections::HashMap, sync::mpsc}; -use axum::http::HeaderMap; use futures::{stream::FuturesUnordered, StreamExt}; use rand::rngs::OsRng; use reqwest::Client; @@ -22,7 +21,6 @@ use crate::{ service::ModelService, }; - struct TestModel { duration: Duration, } @@ -125,10 +123,7 @@ async fn test_mock_model_thread() { received_responses.sort(); should_be_received_responses.sort(); - assert_eq!( - received_responses, - should_be_received_responses - ); + assert_eq!(received_responses, should_be_received_responses); } #[tokio::test] @@ -141,21 +136,21 @@ async fn test_inference_service() { let model_configs = vec![ ModelConfig::new( "mamba_130m".to_string(), - "bf16".to_string(), + "f32".to_string(), "refs/pr/1".to_string(), 0, false, ), ModelConfig::new( "mamba_370m".to_string(), - "bf16".to_string(), + "f32".to_string(), "refs/pr/1".to_string(), 0, false, ), ModelConfig::new( "llama_tiny_llama_1_1b_chat".to_string(), - "bf16".to_string(), + "f32".to_string(), "main".to_string(), 0, false, @@ -175,42 +170,22 @@ async fn test_inference_service() { println!("Starting model service.."); let mut service = ModelService::start(config.clone(), private_key.clone(), req_receiver).unwrap(); - - let _service_join_handle = - tokio::spawn(async move { service.run().await.expect("Failed to run service") }); + + let _service_join_handle = tokio::spawn(async move { + service.run().await.expect("Failed to run service"); + }); let _jrpc_server_join_handle = tokio::spawn(async move { jrpc_server::run(req_sender.clone(), JRPC_PORT).await }); let client = Client::new(); - let mut headers = HeaderMap::new(); - headers.insert("Content-Type", "application/json".parse().unwrap()); - - std::thread::sleep(Duration::from_secs(50)); - loop { - match client - .post(format!("http://localhost:{}/healthz", JRPC_PORT)) - .headers(headers.clone()) - .send() - .await - { - Ok(response) => { - let response_json: Value = response.json().await.unwrap(); - println!("DEBUG: response_json = {}", response_json); - if response_json != Value::Null { - break; - } - std::thread::sleep(Duration::from_secs(1)); - } - Err(_) => { - std::thread::sleep(Duration::from_secs(1)); - } - } - } + + std::thread::sleep(Duration::from_secs(100)); let mut responses = vec![]; for (idx, prompt) in PROMPTS.iter().enumerate() { let model_id = model_ids[idx % 3]; - let params = json!({ + println!("model_id = {model_id}"); + let request = json!({ "request_id": idx, "prompt": prompt.to_string(), "model": model_id.to_string(), @@ -226,9 +201,8 @@ async fn test_inference_service() { let request = json!({ "jsonrpc": "2.0", - "method": "/", - "params": params, - "id": 1 // You can use a unique identifier for each request + "request": request, + "id": idx }); let response = client From 
107c41719471970851ca8a25a254334be47950a1 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Mon, 15 Apr 2024 10:01:23 +0100 Subject: [PATCH 7/8] remove waiting --- atoma-inference/src/tests/mod.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/atoma-inference/src/tests/mod.rs b/atoma-inference/src/tests/mod.rs index 91720724..3a103baa 100644 --- a/atoma-inference/src/tests/mod.rs +++ b/atoma-inference/src/tests/mod.rs @@ -179,8 +179,6 @@ async fn test_inference_service() { let client = Client::new(); - std::thread::sleep(Duration::from_secs(100)); - let mut responses = vec![]; for (idx, prompt) in PROMPTS.iter().enumerate() { let model_id = model_ids[idx % 3]; From 965af99c8478106710f94e8237fb1186226e2540 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Mon, 15 Apr 2024 10:02:28 +0100 Subject: [PATCH 8/8] fmt --- atoma-inference/src/tests/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/atoma-inference/src/tests/mod.rs b/atoma-inference/src/tests/mod.rs index 3a103baa..07ebe761 100644 --- a/atoma-inference/src/tests/mod.rs +++ b/atoma-inference/src/tests/mod.rs @@ -170,7 +170,7 @@ async fn test_inference_service() { println!("Starting model service.."); let mut service = ModelService::start(config.clone(), private_key.clone(), req_receiver).unwrap(); - + let _service_join_handle = tokio::spawn(async move { service.run().await.expect("Failed to run service"); });
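
Usage note for the tests added in this series: test_inference_service drives the node through the JSON-RPC server started by jrpc_server::run. The standalone snippet below mirrors the request shape that test builds and posts a single prompt with reqwest. It is a minimal sketch, not part of the patches: it assumes a service from this series is already listening on localhost:3000, and the bare main wrapper and the empty "sampled_nodes" array are illustrative placeholders (the test itself passes a verification key there).

use reqwest::Client;
use serde_json::{json, Value};

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let client = Client::new();

    // Body mirroring the fields assembled in test_inference_service;
    // "sampled_nodes" is left empty here purely for illustration.
    let request = json!({
        "jsonrpc": "2.0",
        "request": {
            "request_id": 0,
            "prompt": "The sun set behind the mountains, painting the sky in shades of orange and purple.",
            "model": "mamba_130m",
            "sampled_nodes": [],
            "temperature": 0.5,
            "random_seed": 42,
            "repeat_penalty": 1.0,
            "repeat_last_n": 64,
            "max_tokens": 32,
            "_top_k": 10,
            "top_p": 1.0
        },
        "id": 0
    });

    // POST to the JSON-RPC endpoint exposed on JRPC_PORT (3000 in the test config).
    let response: Value = client
        .post("http://localhost:3000/")
        .json(&request)
        .send()
        .await?
        .json()
        .await?;

    println!("{response}");
    Ok(())
}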