Skip to content

Commit

Permalink
merge main and resolve issues
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgeantonio21 committed Apr 15, 2024
2 parents 1e3ab00 + 66f5cf0 commit 1c9af48
Show file tree
Hide file tree
Showing 9 changed files with 297 additions and 8 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ edition = "2021"
[workspace.dependencies]
anyhow = "1.0.81"
async-trait = "0.1.78"
atoma-inference = { path = "../atoma-inference/" }
atoma-sui = { path = "../atoma-event-subscribe/sui/" }
axum = "0.7.5"
candle = { git = "https://github.com/huggingface/candle", package = "candle-core", branch = "main" }
candle-flash-attn = { git = "https://github.com/huggingface/candle", package = "candle-flash-attn", branch = "main" }
Expand Down
1 change: 1 addition & 0 deletions atoma-inference/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ tracing-subscriber.workspace = true
[dev-dependencies]
rand.workspace = true
toml.workspace = true
reqwest = { workspace = true, features = ["json"] }


[features]
Expand Down
2 changes: 2 additions & 0 deletions atoma-inference/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ pub mod service;
pub mod specs;

pub use ed25519_consensus::SigningKey as PrivateKey;
#[cfg(test)]
pub mod tests;
4 changes: 2 additions & 2 deletions atoma-inference/src/model_thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ impl ModelThreadDispatcher {
}
}

fn dispatch_model_thread(
pub(crate) fn dispatch_model_thread(
api_key: String,
cache_dir: PathBuf,
model_name: String,
Expand Down Expand Up @@ -233,7 +233,7 @@ fn dispatch_model_thread(
}
}

fn spawn_model_thread<M>(
pub(crate) fn spawn_model_thread<M>(
model_name: String,
api_key: String,
cache_dir: PathBuf,
Expand Down
4 changes: 2 additions & 2 deletions atoma-inference/src/models/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl ModelConfig {
}
}

#[derive(Debug, Deserialize, Serialize)]
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModelsConfig {
api_key: String,
cache_dir: PathBuf,
Expand Down Expand Up @@ -182,7 +182,7 @@ pub mod tests {
);

let toml_str = toml::to_string(&config).unwrap();
let should_be_toml_str = "api_key = \"my_key\"\ncache_dir = \"/\"\nflush_storage = true\ntracing = true\n\n[[models]]\ndevice_id = 0\ndtype = \"Llama2_7b\"\nmodel_id = \"F16\"\nrevision = \"\"\nuse_flash_attention = true\n";
let should_be_toml_str = "api_key = \"my_key\"\ncache_dir = \"/\"\nflush_storage = true\ntracing = true\njrpc_port = 18001\n\n[[models]]\ndevice_id = 0\ndtype = \"Llama2_7b\"\nmodel_id = \"F16\"\nrevision = \"\"\nuse_flash_attention = true\n";
assert_eq!(toml_str, should_be_toml_str);
}
}
11 changes: 9 additions & 2 deletions atoma-inference/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,18 @@ mod tests {

let config_data = Value::Table(toml! {
api_key = "your_api_key"
models = [[0, "f32", "mamba_370m", "", false, 0]]
cache_dir = "./cache_dir/"
tokenizer_file_path = "./tokenizer_file_path/"
flush_storage = true
models = [
[
0,
"bf16",
"mamba_370m",
"",
false
]]
tracing = true
jrpc_port = 3000
});
let toml_string =
toml::to_string_pretty(&config_data).expect("Failed to serialize to TOML");
Expand Down
221 changes: 221 additions & 0 deletions atoma-inference/src/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
use crate::models::{config::ModelConfig, types::ModelType, ModelError, ModelTrait};
use ed25519_consensus::SigningKey as PrivateKey;
use std::{path::PathBuf, time::Duration};

mod prompts;
use prompts::PROMPTS;

use std::{collections::HashMap, sync::mpsc};

use futures::{stream::FuturesUnordered, StreamExt};
use rand::rngs::OsRng;
use reqwest::Client;
use serde_json::json;
use serde_json::Value;
use tokio::sync::oneshot;

use crate::{
jrpc_server,
model_thread::{spawn_model_thread, ModelThreadCommand, ModelThreadDispatcher},
models::config::ModelsConfig,
service::ModelService,
};

/// Mock model used by the tests below: instead of performing real inference,
/// `run` sleeps for `duration` and echoes its input back (see the
/// `ModelTrait` impl that follows).
struct TestModel {
    // Simulated inference latency; `run` blocks for exactly this long.
    duration: Duration,
}

impl ModelTrait for TestModel {
type Input = Value;
type Output = Value;
type LoadData = Duration;

fn fetch(
duration: String,
_cache_dir: PathBuf,
_config: ModelConfig,
) -> Result<Self::LoadData, ModelError> {
Ok(Duration::from_secs(duration.parse().unwrap()))
}

fn load(duration: Self::LoadData) -> Result<Self, ModelError>
where
Self: Sized,
{
Ok(Self { duration })
}

fn model_type(&self) -> ModelType {
todo!()
}

fn run(&mut self, input: Self::Input) -> Result<Self::Output, ModelError> {
std::thread::sleep(self.duration);
println!(
"Finished waiting time for {:?} and input = {}",
self.duration, input
);
Ok(input)
}
}

impl ModelThreadDispatcher {
    /// Builds a dispatcher backed by mock `TestModel` threads whose simulated
    /// inference latencies are 1, 2, 5 and 10 seconds, keyed by name
    /// `test_model_{secs}`.
    fn test_start() -> Self {
        // An array suffices here; no heap allocation needed (clippy: useless_vec).
        let duration_in_secs = [1, 2, 5, 10];
        // Capacity derived from the list so the two can't drift apart.
        let mut model_senders = HashMap::with_capacity(duration_in_secs.len());

        for i in duration_in_secs {
            let model_name = format!("test_model_{:?}", i);

            let (model_sender, model_receiver) = mpsc::channel::<ModelThreadCommand>();
            // Move the sender into the map directly; the previous `.clone()`
            // was redundant since the original was dropped right after.
            model_senders.insert(model_name.clone(), model_sender);

            let duration = format!("{i}");
            let cache_dir = "./".parse().unwrap();
            let model_config =
                ModelConfig::new(model_name.clone(), "".to_string(), "".to_string(), 0, false);

            let private_key = PrivateKey::new(OsRng);
            let public_key = private_key.verification_key();

            // The handle is intentionally dropped: the thread lives as long
            // as its sender half in `model_senders` stays alive.
            let _join_handle = spawn_model_thread::<TestModel>(
                model_name,
                duration,
                cache_dir,
                model_config,
                public_key,
                model_receiver,
            );
        }
        Self { model_senders }
    }
}

#[tokio::test]
async fn test_mock_model_thread() {
    // Number of requests fanned out to EACH mock model thread.
    const NUM_REQUESTS: usize = 16;

    let model_thread_dispatcher = ModelThreadDispatcher::test_start();
    let mut responses = FuturesUnordered::new();

    // Fan out: send request payload `i` to every model thread and record the
    // value we expect to see echoed back.
    let mut should_be_received_responses = vec![];
    for i in 0..NUM_REQUESTS {
        for sender in model_thread_dispatcher.model_senders.values() {
            let (response_sender, response_receiver) = oneshot::channel();
            let request = json!(i);
            let command = ModelThreadCommand {
                request: request.clone(),
                response_sender,
            };
            sender.send(command).expect("Failed to send command");
            responses.push(response_receiver);
            should_be_received_responses.push(request.as_u64().unwrap());
        }
    }

    // Gather every echoed response; a dropped sender surfaces as `Err` and is
    // skipped — the final length check below catches any such loss.
    let mut received_responses = vec![];
    while let Some(response) = responses.next().await {
        if let Ok(value) = response {
            received_responses.push(value.as_u64().unwrap());
        }
    }

    // Threads sleep different durations, so responses arrive out of order;
    // compare as multisets. `sort_unstable` is preferred for primitives
    // (faster, non-allocating, stability irrelevant for u64).
    received_responses.sort_unstable();
    should_be_received_responses.sort_unstable();

    assert_eq!(received_responses, should_be_received_responses);
}

#[tokio::test]
async fn test_inference_service() {
    // Capacity of the request channel between the JSON-RPC server and the service.
    const CHANNEL_BUFFER: usize = 32;
    // Port the JSON-RPC server listens on for this test.
    const JRPC_PORT: u64 = 3000;

    let private_key = PrivateKey::new(OsRng);
    let model_ids = ["mamba_130m", "mamba_370m", "llama_tiny_llama_1_1b_chat"];
    let model_configs = vec![
        ModelConfig::new(
            "mamba_130m".to_string(),
            "f32".to_string(),
            "refs/pr/1".to_string(),
            0,
            false,
        ),
        ModelConfig::new(
            "mamba_370m".to_string(),
            "f32".to_string(),
            "refs/pr/1".to_string(),
            0,
            false,
        ),
        ModelConfig::new(
            "llama_tiny_llama_1_1b_chat".to_string(),
            "f32".to_string(),
            "main".to_string(),
            0,
            false,
        ),
    ];
    let config = ModelsConfig::new(
        "".to_string(),
        "./cache_dir".parse().unwrap(),
        true,
        model_configs,
        true,
        JRPC_PORT,
    );

    let (req_sender, req_receiver) = tokio::sync::mpsc::channel(CHANNEL_BUFFER);

    println!("Starting model service..");
    // `config` is not used after this point, so move it instead of cloning.
    let mut service = ModelService::start(config, private_key.clone(), req_receiver).unwrap();

    let _service_join_handle = tokio::spawn(async move {
        service.run().await.expect("Failed to run service");
    });
    // `req_sender` is moved into the task by `async move`; the former
    // `.clone()` was redundant.
    let _jrpc_server_join_handle =
        tokio::spawn(async move { jrpc_server::run(req_sender, JRPC_PORT).await });

    let client = Client::new();

    let mut responses = vec![];
    for (idx, prompt) in PROMPTS.iter().enumerate() {
        // Round-robin prompts across ALL configured models; using `len()`
        // instead of a hard-coded 3 keeps this correct if a model is added.
        let model_id = model_ids[idx % model_ids.len()];
        println!("model_id = {model_id}");
        let request = json!({
            "request_id": idx,
            "prompt": prompt.to_string(),
            "model": model_id.to_string(),
            "sampled_nodes": private_key.verification_key(),
            "temperature": 0.5,
            "random_seed": 42,
            "repeat_penalty": 1.0,
            "repeat_last_n": 64,
            "max_tokens": 32,
            "_top_k": 10,
            "top_p": 1.0
        });

        // Wrap the payload in a JSON-RPC 2.0 envelope (shadowing is idiomatic
        // for this kind of staged construction).
        let request = json!({
            "jsonrpc": "2.0",
            "request": request,
            "id": idx
        });

        let response = client
            .post(format!("http://localhost:{}/", JRPC_PORT))
            .json(&request)
            .send()
            .await
            .expect("Failed to receive response from JRPCs server");

        let response_json: Value = response
            .json()
            .await
            .expect("Failed to parse response to JSON");
        println!("{}", response_json);
        responses.push(response_json);
    }
    assert_eq!(responses.len(), PROMPTS.len());
}
56 changes: 56 additions & 0 deletions atoma-inference/src/tests/prompts.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/// Fixed corpus of short English prompts used by the integration tests in
/// `tests/mod.rs` to exercise the inference service end-to-end.
pub(crate) const PROMPTS: &[&str] = &[
    "The sun set behind the mountains, painting the sky in shades of orange and purple.",
    "She walked through the forest, listening to the rustling of leaves under her feet.",
    "The old man sat on the park bench, feeding breadcrumbs to the pigeons.",
    "As the train pulled into the station, he felt a sense of excitement for the journey ahead.",
    "The city streets were bustling with activity, as people hurried to their destinations.",
    "She looked out the window and watched the raindrops fall gently onto the pavement.",
    "In the quiet of the night, he could hear the distant sound of crickets chirping.",
    "The smell of freshly baked bread filled the air as she entered the bakery.",
    "He sat by the fireplace, lost in thought as the flames danced before him.",
    "The waves crashed against the shore, their rhythmic sound soothing her troubled mind.",
    "She gazed up at the stars, feeling small yet connected to the universe.",
    "The smell of coffee wafted through the air, inviting her to take a sip.",
    "He laughed as he chased after the playful puppy, enjoying the simple pleasure of companionship.",
    "She closed her eyes and let the music wash over her, transporting her to another world.",
    "The smell of flowers filled the garden, attracting bees and butterflies alike.",
    "He took a deep breath and plunged into the icy waters, feeling alive and invigorated.",
    "As she reached the mountaintop, she was greeted by a breathtaking view of the valley below.",
    "The sound of children's laughter echoed through the playground, bringing a smile to her face.",
    "He savored the taste of the freshly picked strawberries, their sweetness exploding on his tongue.",
    "She felt a sense of peace wash over her as she practiced yoga in the park.",
    "The sound of church bells ringing in the distance signaled the start of a new day.",
    "He watched in awe as the fireworks lit up the night sky, painting it in vibrant colors.",
    "She felt a sense of accomplishment as she crossed the finish line, completing her first marathon.",
    "The smell of rain on hot pavement filled the air, bringing relief from the summer heat.",
    "He marveled at the intricate patterns of the snowflakes as they fell softly to the ground.",
    "She sat on the swing, feeling the wind in her hair and the sun on her face.",
    "The sound of thunder rumbled in the distance, signaling an approaching storm.",
    "He listened to the sound of waves crashing against the rocks, feeling at peace with the world.",
    "She watched as the leaves changed colors, ushering in the beauty of autumn.",
    "The smell of barbecue filled the air as he fired up the grill for a summer cookout.",
    "He felt a sense of nostalgia as he flipped through old photographs, reliving cherished memories.",
    "She closed her eyes and listened to the sound of her own heartbeat, feeling alive and present in the moment.",
    "The taste of hot chocolate warmed her from the inside out on a cold winter's day.",
    "He watched as the clouds drifted lazily across the sky, their shapes morphing into fantastical creatures.",
    "She felt a sense of wonder as she explored the hidden nooks and crannies of an old bookstore.",
    "The smell of freshly cut grass reminded him of carefree childhood days spent playing in the park.",
    "He sat by the window and watched the world go by, lost in his own thoughts.",
    "She felt a surge of adrenaline as she jumped out of the plane, skydiving for the first time.",
    "The sound of birds chirping in the morning signaled the start of a new day.",
    "He watched in awe as the full moon cast its silver glow over the landscape.",
    "She felt a sense of pride as she watched her garden flourish, blooming with colorful flowers.",
    "The taste of ripe watermelon brought back memories of lazy summer afternoons spent with friends.",
    "He listened to the sound of his own footsteps echoing through the empty streets, feeling a sense of solitude.",
    "She felt a sense of belonging as she sat around the campfire, sharing stories with friends.",
    "The smell of cinnamon and spices filled the kitchen as she baked a batch of cookies.",
    "He felt a sense of accomplishment as he reached the summit of the mountain, conquering his fears.",
    "She watched as the first snowflakes of winter fell gently to the ground, blanketing the world in white.",
    "The sound of a crackling fire filled the cabin, warming her on a chilly winter's night.",
    "He felt a sense of awe as he looked up at the towering skyscrapers, marveling at human ingenuity.",
    "She closed her eyes and listened to the sound of the waves crashing against the shore, feeling at peace.",
    "The taste of freshly squeezed lemonade cooled her down on a hot summer day.",
    "He watched as the leaves danced in the wind, their colors swirling in a mesmerizing display.",
    "She felt a sense of freedom as she rode her bike through the countryside, the wind in her hair.",
    "The smell of pine trees filled the air as he hiked through the forest, reconnecting with nature."
];
4 changes: 2 additions & 2 deletions atoma-node/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ edition.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
atoma-sui = { path = "../atoma-event-subscribe/sui/" }
atoma-inference = { path = "../atoma-inference/" }
atoma-sui.workspace = true
atoma-inference.workspace = true
clap.workspace = true
serde_json.workspace = true
thiserror.workspace = true
Expand Down

0 comments on commit 1c9af48

Please sign in to comment.