Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgeantonio21 committed Apr 13, 2024
1 parent 8df33b7 commit 28cf4ff
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 5 deletions.
2 changes: 2 additions & 0 deletions atoma-inference/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ pub mod model_thread;
pub mod models;
pub mod service;
pub mod specs;
#[cfg(test)]
pub mod tests;
10 changes: 5 additions & 5 deletions atoma-inference/src/model_thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ use crate::{
};

pub struct ModelThreadCommand {
request: serde_json::Value,
response_sender: oneshot::Sender<serde_json::Value>,
pub(crate) request: serde_json::Value,
pub(crate) response_sender: oneshot::Sender<serde_json::Value>,
}

#[derive(Debug, Error)]
Expand Down Expand Up @@ -94,7 +94,7 @@ where
}

pub struct ModelThreadDispatcher {
model_senders: HashMap<ModelId, mpsc::Sender<ModelThreadCommand>>,
pub(crate) model_senders: HashMap<ModelId, mpsc::Sender<ModelThreadCommand>>,
}

impl ModelThreadDispatcher {
Expand Down Expand Up @@ -172,7 +172,7 @@ impl ModelThreadDispatcher {
}
}

fn dispatch_model_thread(
pub(crate) fn dispatch_model_thread(
api_key: String,
cache_dir: PathBuf,
model_name: String,
Expand Down Expand Up @@ -231,7 +231,7 @@ fn dispatch_model_thread(
}
}

fn spawn_model_thread<M>(
pub(crate) fn spawn_model_thread<M>(
model_name: String,
api_key: String,
cache_dir: PathBuf,
Expand Down
108 changes: 108 additions & 0 deletions atoma-inference/src/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
use crate::models::{config::ModelConfig, types::ModelType, ModelError, ModelTrait};
use ed25519_consensus::SigningKey as PrivateKey;
use std::{path::PathBuf, time::Duration};

#[cfg(test)]
mod tests {
use std::{collections::HashMap, sync::mpsc};

use rand::rngs::OsRng;
use tokio::sync::oneshot;

use crate::model_thread::{spawn_model_thread, ModelThreadCommand, ModelThreadDispatcher};

use super::*;

const DURATION_1_SECS: Duration = Duration::from_secs(1);
const DURATION_2_SECS: Duration = Duration::from_secs(2);
const DURATION_5_SECS: Duration = Duration::from_secs(5);
const DURATION_10_SECS: Duration = Duration::from_secs(10);

struct TestModel {
duration: Duration,
}

impl ModelTrait for TestModel {
type Input = ();
type Output = ();
type LoadData = ();

fn fetch(
api_key: String,
cache_dir: PathBuf,
config: ModelConfig,
) -> Result<Self::LoadData, ModelError> {
Ok(())
}

fn load(load_data: Self::LoadData) -> Result<Self, ModelError>
where
Self: Sized,
{
Ok(Self {
duration: DURATION_1_SECS,
})
}

fn model_type(&self) -> ModelType {
todo!()
}

fn run(&mut self, input: Self::Input) -> Result<Self::Output, ModelError> {
std::thread::sleep(self.duration);
Ok(())
}
}

impl ModelThreadDispatcher {
fn test_start() -> Self {
let mut model_senders = HashMap::with_capacity(4);

for duration in [
DURATION_1_SECS,
DURATION_2_SECS,
DURATION_5_SECS,
DURATION_10_SECS,
] {
let model_name = format!("test_model_{:?}", duration);

let (model_sender, model_receiver) = mpsc::channel::<ModelThreadCommand>();
model_senders.insert(model_name.clone(), model_sender.clone());

let api_key = "".to_string();
let cache_dir = "./".parse().unwrap();
let model_config =
ModelConfig::new(model_name.clone(), "".to_string(), "".to_string(), 0, false);

let private_key = PrivateKey::new(OsRng);
let public_key = private_key.verification_key();

let _join_handle = spawn_model_thread::<TestModel>(
model_name,
api_key,
cache_dir,
model_config,
public_key,
model_receiver,
);
}
Self { model_senders }
}
}

#[tokio::test]
async fn test_model_thread() {
let model_thread_dispatcher = ModelThreadDispatcher::test_start();

for _ in 0..10 {
for sender in model_thread_dispatcher.model_senders.values() {
let (response_sender, response_request) = oneshot::channel();
let command = ModelThreadCommand {
request: serde_json::Value::Null,
response_sender,
};
sender.send(command).expect("Failed to send command");
}
}
}
}

0 comments on commit 28cf4ff

Please sign in to comment.