Commit b712c57

resolve bug in test
jorgeantonio21 committed Apr 15, 2024
1 parent 331a6ee commit b712c57
Showing 1 changed file with 14 additions and 40 deletions.
54 changes: 14 additions & 40 deletions atoma-inference/src/tests/mod.rs
@@ -7,7 +7,6 @@ use prompts::PROMPTS;
 
 use std::{collections::HashMap, sync::mpsc};
 
-use axum::http::HeaderMap;
 use futures::{stream::FuturesUnordered, StreamExt};
 use rand::rngs::OsRng;
 use reqwest::Client;
@@ -22,7 +21,6 @@ use crate::{
     service::ModelService,
 };
 
-
 struct TestModel {
     duration: Duration,
 }
@@ -125,10 +123,7 @@ async fn test_mock_model_thread() {
     received_responses.sort();
     should_be_received_responses.sort();
 
-    assert_eq!(
-        received_responses,
-        should_be_received_responses
-    );
+    assert_eq!(received_responses, should_be_received_responses);
 }
 
 #[tokio::test]
@@ -141,21 +136,21 @@ async fn test_inference_service() {
     let model_configs = vec![
         ModelConfig::new(
             "mamba_130m".to_string(),
-            "bf16".to_string(),
+            "f32".to_string(),
             "refs/pr/1".to_string(),
             0,
             false,
         ),
         ModelConfig::new(
             "mamba_370m".to_string(),
-            "bf16".to_string(),
+            "f32".to_string(),
             "refs/pr/1".to_string(),
             0,
             false,
         ),
         ModelConfig::new(
             "llama_tiny_llama_1_1b_chat".to_string(),
-            "bf16".to_string(),
+            "f32".to_string(),
             "main".to_string(),
             0,
             false,
@@ -175,42 +170,22 @@ async fn test_inference_service() {
     println!("Starting model service..");
     let mut service =
         ModelService::start(config.clone(), private_key.clone(), req_receiver).unwrap();
 
-    let _service_join_handle =
-        tokio::spawn(async move { service.run().await.expect("Failed to run service") });
-
+    let _service_join_handle = tokio::spawn(async move {
+        service.run().await.expect("Failed to run service");
+    });
     let _jrpc_server_join_handle =
         tokio::spawn(async move { jrpc_server::run(req_sender.clone(), JRPC_PORT).await });
 
     let client = Client::new();
-    let mut headers = HeaderMap::new();
-    headers.insert("Content-Type", "application/json".parse().unwrap());
-
-    std::thread::sleep(Duration::from_secs(50));
-    loop {
-        match client
-            .post(format!("http://localhost:{}/healthz", JRPC_PORT))
-            .headers(headers.clone())
-            .send()
-            .await
-        {
-            Ok(response) => {
-                let response_json: Value = response.json().await.unwrap();
-                println!("DEBUG: response_json = {}", response_json);
-                if response_json != Value::Null {
-                    break;
-                }
-                std::thread::sleep(Duration::from_secs(1));
-            }
-            Err(_) => {
-                std::thread::sleep(Duration::from_secs(1));
-            }
-        }
-    }
+
+    std::thread::sleep(Duration::from_secs(100));
+
     let mut responses = vec![];
     for (idx, prompt) in PROMPTS.iter().enumerate() {
         let model_id = model_ids[idx % 3];
-        let params = json!({
+        println!("model_id = {model_id}");
+        let request = json!({
             "request_id": idx,
             "prompt": prompt.to_string(),
             "model": model_id.to_string(),
@@ -226,9 +201,8 @@
 
         let request = json!({
             "jsonrpc": "2.0",
-            "method": "/",
-            "params": params,
-            "id": 1 // You can use a unique identifier for each request
+            "request": request,
+            "id": idx
         });
 
         let response = client
