diff --git a/atoma-inference/src/tests/mod.rs b/atoma-inference/src/tests/mod.rs
index 4e2b130c..91720724 100644
--- a/atoma-inference/src/tests/mod.rs
+++ b/atoma-inference/src/tests/mod.rs
@@ -7,7 +7,6 @@
 use prompts::PROMPTS;
 use std::{collections::HashMap, sync::mpsc};
 
-use axum::http::HeaderMap;
 use futures::{stream::FuturesUnordered, StreamExt};
 use rand::rngs::OsRng;
 use reqwest::Client;
@@ -22,7 +21,6 @@ use crate::{
     service::ModelService,
 };
 
-
 struct TestModel {
     duration: Duration,
 }
@@ -125,10 +123,7 @@ async fn test_mock_model_thread() {
     received_responses.sort();
     should_be_received_responses.sort();
 
-    assert_eq!(
-        received_responses,
-        should_be_received_responses
-    );
+    assert_eq!(received_responses, should_be_received_responses);
 }
 
 #[tokio::test]
@@ -141,21 +136,21 @@ async fn test_inference_service() {
     let model_configs = vec![
         ModelConfig::new(
             "mamba_130m".to_string(),
-            "bf16".to_string(),
+            "f32".to_string(),
             "refs/pr/1".to_string(),
             0,
             false,
         ),
         ModelConfig::new(
             "mamba_370m".to_string(),
-            "bf16".to_string(),
+            "f32".to_string(),
             "refs/pr/1".to_string(),
             0,
             false,
         ),
         ModelConfig::new(
             "llama_tiny_llama_1_1b_chat".to_string(),
-            "bf16".to_string(),
+            "f32".to_string(),
             "main".to_string(),
             0,
             false,
@@ -175,42 +170,22 @@ async fn test_inference_service() {
     println!("Starting model service..");
     let mut service =
         ModelService::start(config.clone(), private_key.clone(), req_receiver).unwrap();
-
-    let _service_join_handle =
-        tokio::spawn(async move { service.run().await.expect("Failed to run service") });
+
+    let _service_join_handle = tokio::spawn(async move {
+        service.run().await.expect("Failed to run service");
+    });
 
     let _jrpc_server_join_handle =
         tokio::spawn(async move { jrpc_server::run(req_sender.clone(), JRPC_PORT).await });
 
     let client = Client::new();
-    let mut headers = HeaderMap::new();
-    headers.insert("Content-Type", "application/json".parse().unwrap());
-
-    std::thread::sleep(Duration::from_secs(50));
-    loop {
-        match client
-            .post(format!("http://localhost:{}/healthz", JRPC_PORT))
-            .headers(headers.clone())
-            .send()
-            .await
-        {
-            Ok(response) => {
-                let response_json: Value = response.json().await.unwrap();
-                println!("DEBUG: response_json = {}", response_json);
-                if response_json != Value::Null {
-                    break;
-                }
-                std::thread::sleep(Duration::from_secs(1));
-            }
-            Err(_) => {
-                std::thread::sleep(Duration::from_secs(1));
-            }
-        }
-    }
+
+    std::thread::sleep(Duration::from_secs(100));
 
     let mut responses = vec![];
     for (idx, prompt) in PROMPTS.iter().enumerate() {
         let model_id = model_ids[idx % 3];
-        let params = json!({
+        println!("model_id = {model_id}");
+        let request = json!({
            "request_id": idx,
            "prompt": prompt.to_string(),
            "model": model_id.to_string(),
@@ -226,9 +201,8 @@ async fn test_inference_service() {
 
         let request = json!({
             "jsonrpc": "2.0",
-            "method": "/",
-            "params": params,
-            "id": 1 // You can use a unique identifier for each request
+            "request": request,
+            "id": idx
         });
 
         let response = client