diff --git a/atoma-inference/src/models/candle/falcon.rs b/atoma-inference/src/models/candle/falcon.rs
index 46f4ffb7..a7deb45c 100644
--- a/atoma-inference/src/models/candle/falcon.rs
+++ b/atoma-inference/src/models/candle/falcon.rs
@@ -195,3 +195,171 @@ impl ModelTrait for FalconModel {
         Ok(output)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    #[test]
+    #[cfg(feature = "metal")]
+    fn test_falcon_model_interface_with_metal() {
+        use super::*;
+
+        let api_key = "".to_string();
+        let cache_dir: PathBuf = "./test_falcon_cache_dir/".try_into().unwrap();
+        let model_id = "falcon_7b".to_string();
+        let dtype = "f32".to_string();
+        let revision = "refs/pr/43".to_string();
+        let device_id = 0;
+        let use_flash_attention = false;
+        let config = ModelConfig::new(
+            model_id,
+            dtype.clone(),
+            revision,
+            device_id,
+            use_flash_attention,
+        );
+        let load_data = FalconModel::fetch(api_key, cache_dir.clone(), config)
+            .expect("Failed to fetch falcon model");
+
+        println!("model device = {:?}", load_data.device);
+        let should_be_device = device(device_id).unwrap();
+        if should_be_device.is_cpu() {
+            assert!(load_data.device.is_cpu());
+        } else if should_be_device.is_cuda() {
+            assert!(load_data.device.is_cuda());
+        } else if should_be_device.is_metal() {
+            assert!(load_data.device.is_metal());
+        } else {
+            panic!("Invalid device")
+        }
+
+        assert_eq!(load_data.file_paths.len(), 4);
+        assert_eq!(load_data.use_flash_attention, use_flash_attention);
+        assert_eq!(load_data.model_type, ModelType::Falcon7b);
+
+        let should_be_dtype = DType::from_str(&dtype).unwrap();
+        assert_eq!(load_data.dtype, should_be_dtype);
+        let mut model = FalconModel::load(load_data).expect("Failed to load model");
+
+        if should_be_device.is_cpu() {
+            assert!(model.device.is_cpu());
+        } else if should_be_device.is_cuda() {
+            assert!(model.device.is_cuda());
+        } else if should_be_device.is_metal() {
+            assert!(model.device.is_metal());
+        } else {
+            panic!("Invalid device")
+        }
+
+        assert_eq!(model.dtype, should_be_dtype);
+        assert_eq!(model.model_type, ModelType::Falcon7b);
+
+        let prompt = "Write a hello world rust program: ".to_string();
+        let temperature = 0.6;
+        let random_seed = 42;
+        let repeat_penalty = 1.0;
+        let repeat_last_n = 20;
+        let max_tokens = 1;
+        let top_k = 10;
+        let top_p = 0.6;
+
+        let input = TextModelInput::new(
+            prompt.clone(),
+            temperature,
+            random_seed,
+            repeat_penalty,
+            repeat_last_n,
+            max_tokens,
+            top_k,
+            top_p,
+        );
+        let output = model.run(input).expect("Failed to run inference");
+
+        assert!(output.len() >= 1);
+        assert!(output.split(" ").collect::<Vec<_>>().len() <= max_tokens);
+
+        std::fs::remove_dir_all(cache_dir).unwrap();
+    }
+
+    #[test]
+    #[cfg(feature = "cuda")]
+    fn test_falcon_model_interface_with_cuda() {
+        use super::*;
+
+        let api_key = "".to_string();
+        let cache_dir: PathBuf = "./test_falcon_cache_dir/".try_into().unwrap();
+        let model_id = "falcon_7b".to_string();
+        let dtype = "f32".to_string();
+        let revision = "refs/pr/43".to_string();
+        let device_id = 0;
+        let use_flash_attention = false;
+        let config = ModelConfig::new(
+            model_id,
+            dtype.clone(),
+            revision,
+            device_id,
+            use_flash_attention,
+        );
+        let load_data = FalconModel::fetch(api_key, cache_dir.clone(), config)
+            .expect("Failed to fetch falcon model");
+
+        println!("model device = {:?}", load_data.device);
+        let should_be_device = device(device_id).unwrap();
+        if should_be_device.is_cpu() {
+            assert!(load_data.device.is_cpu());
+        } else if should_be_device.is_cuda() {
+            assert!(load_data.device.is_cuda());
+        } else if should_be_device.is_metal() {
+            assert!(load_data.device.is_metal());
+        } else {
+            panic!("Invalid device")
+        }
+
+        assert_eq!(load_data.file_paths.len(), 4);
+        assert_eq!(load_data.use_flash_attention, use_flash_attention);
+        assert_eq!(load_data.model_type, ModelType::Falcon7b);
+
+        let should_be_dtype = DType::from_str(&dtype).unwrap();
+        assert_eq!(load_data.dtype, should_be_dtype);
+        let mut model = FalconModel::load(load_data).expect("Failed to load model");
+
+        if should_be_device.is_cpu() {
+            assert!(model.device.is_cpu());
+        } else if should_be_device.is_cuda() {
+            assert!(model.device.is_cuda());
+        } else if should_be_device.is_metal() {
+            assert!(model.device.is_metal());
+        } else {
+            panic!("Invalid device")
+        }
+
+        assert_eq!(model.dtype, should_be_dtype);
+        assert_eq!(model.model_type, ModelType::Falcon7b);
+
+        let prompt = "Write a hello world rust program: ".to_string();
+        let temperature = 0.6;
+        let random_seed = 42;
+        let repeat_penalty = 1.0;
+        let repeat_last_n = 20;
+        let max_tokens = 1;
+        let top_k = 10;
+        let top_p = 0.6;
+
+        let input = TextModelInput::new(
+            prompt.clone(),
+            temperature,
+            random_seed,
+            repeat_penalty,
+            repeat_last_n,
+            max_tokens,
+            top_k,
+            top_p,
+        );
+        let output = model.run(input).expect("Failed to run inference");
+        println!("{output}");
+
+        assert!(output.len() >= 1);
+        assert!(output.split(" ").collect::<Vec<_>>().len() <= max_tokens);
+
+        std::fs::remove_dir_all(cache_dir).unwrap();
+    }
+}
diff --git a/atoma-inference/src/models/candle/llama.rs b/atoma-inference/src/models/candle/llama.rs
index eaed4304..02005805 100644
--- a/atoma-inference/src/models/candle/llama.rs
+++ b/atoma-inference/src/models/candle/llama.rs
@@ -195,3 +195,88 @@ impl ModelTrait for LlamaModel {
         Ok(res)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_llama_model_interface() {
+        let api_key = "".to_string();
+        let cache_dir: PathBuf = "./test_llama_cache_dir/".try_into().unwrap();
+        let model_id = "llama_tiny_llama_1_1b_chat".to_string();
+        let dtype = "f32".to_string();
+        let revision = "main".to_string();
+        let device_id = 0;
+        let use_flash_attention = false;
+        let config = ModelConfig::new(
+            model_id,
+            dtype.clone(),
+            revision,
+            device_id,
+            use_flash_attention,
+        );
+        let load_data = LlamaModel::fetch(api_key, cache_dir.clone(), config)
+            .expect("Failed to fetch llama model");
+
+        println!("model device = {:?}", load_data.device);
+        let should_be_device = device(device_id).unwrap();
+        if should_be_device.is_cpu() {
+            assert!(load_data.device.is_cpu());
+        } else if should_be_device.is_cuda() {
+            assert!(load_data.device.is_cuda());
+        } else if should_be_device.is_metal() {
+            assert!(load_data.device.is_metal());
+        } else {
+            panic!("Invalid device")
+        }
+
+        assert_eq!(load_data.file_paths.len(), 3);
+        assert_eq!(load_data.use_flash_attention, use_flash_attention);
+        assert_eq!(load_data.model_type, ModelType::LlamaTinyLlama1_1BChat);
+
+        let should_be_dtype = DType::from_str(&dtype).unwrap();
+        assert_eq!(load_data.dtype, should_be_dtype);
+        let mut model = LlamaModel::load(load_data).expect("Failed to load model");
+
+        if should_be_device.is_cpu() {
+            assert!(model.device.is_cpu());
+        } else if should_be_device.is_cuda() {
+            assert!(model.device.is_cuda());
+        } else if should_be_device.is_metal() {
+            assert!(model.device.is_metal());
+        } else {
+            panic!("Invalid device")
+        }
+
+        assert_eq!(model.cache.use_kv_cache, true);
+        assert_eq!(model.model_type, ModelType::LlamaTinyLlama1_1BChat);
+
+        let prompt = "Write a hello world rust program: ".to_string();
program: ".to_string(); + let temperature = 0.6; + let random_seed = 42; + let repeat_penalty = 1.0; + let repeat_last_n = 20; + let max_tokens = 128; + let top_k = 10; + let top_p = 0.6; + + let input = TextModelInput::new( + prompt.clone(), + temperature, + random_seed, + repeat_penalty, + repeat_last_n, + max_tokens, + top_k, + top_p, + ); + let output = model.run(input).expect("Failed to run inference"); + println!("output = {output}"); + + assert!(output.len() > 1); + assert!(output.split(" ").collect::>().len() <= max_tokens); + + std::fs::remove_dir_all(cache_dir).unwrap(); + } +} diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs index 0b1213ff..46cd5e4f 100644 --- a/atoma-inference/src/models/candle/mamba.rs +++ b/atoma-inference/src/models/candle/mamba.rs @@ -225,3 +225,88 @@ impl ModelTrait for MambaModel { Ok(output) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_mamba_model_interface() { + let api_key = "".to_string(); + let cache_dir: PathBuf = "./test_mamba_cache_dir/".try_into().unwrap(); + let model_id = "mamba_130m".to_string(); + let dtype = "f32".to_string(); + let revision = "refs/pr/1".to_string(); + let device_id = 0; + let use_flash_attention = false; + let config = ModelConfig::new( + model_id, + dtype.clone(), + revision, + device_id, + use_flash_attention, + ); + let load_data = MambaModel::fetch(api_key, cache_dir.clone(), config) + .expect("Failed to fetch mamba model"); + + println!("model device = {:?}", load_data.device); + let should_be_device = device(device_id).unwrap(); + if should_be_device.is_cpu() { + assert!(load_data.device.is_cpu()); + } else if should_be_device.is_cuda() { + assert!(load_data.device.is_cuda()); + } else if should_be_device.is_metal() { + assert!(load_data.device.is_metal()); + } else { + panic!("Invalid device") + } + + assert_eq!(load_data.file_paths.len(), 3); + assert_eq!(load_data.use_flash_attention, use_flash_attention); + assert_eq!(load_data.model_type, ModelType::Mamba130m); + + let should_be_dtype = DType::from_str(&dtype).unwrap(); + assert_eq!(load_data.dtype, should_be_dtype); + let mut model = MambaModel::load(load_data).expect("Failed to load model"); + + if should_be_device.is_cpu() { + assert!(model.device.is_cpu()); + } else if should_be_device.is_cuda() { + assert!(model.device.is_cuda()); + } else if should_be_device.is_metal() { + assert!(model.device.is_metal()); + } else { + panic!("Invalid device") + } + + assert_eq!(model.dtype, should_be_dtype); + assert_eq!(model.model_type, ModelType::Mamba130m); + + let prompt = "Write a hello world rust program: ".to_string(); + let temperature = 0.6; + let random_seed = 42; + let repeat_penalty = 1.0; + let repeat_last_n = 20; + let max_tokens = 128; + let top_k = 10; + let top_p = 0.6; + + let input = TextModelInput::new( + prompt.clone(), + temperature, + random_seed, + repeat_penalty, + repeat_last_n, + max_tokens, + top_k, + top_p, + ); + let output = model.run(input).expect("Failed to run inference"); + println!("{output}"); + + assert!(output.contains(&prompt)); + assert!(output.len() > prompt.len()); + + std::fs::remove_dir_all(cache_dir).unwrap(); + } +} diff --git a/atoma-inference/src/models/candle/stable_diffusion.rs b/atoma-inference/src/models/candle/stable_diffusion.rs index b98f658d..906b68f6 100644 --- a/atoma-inference/src/models/candle/stable_diffusion.rs +++ b/atoma-inference/src/models/candle/stable_diffusion.rs @@ -20,7 +20,7 @@ use crate::{ use super::{convert_to_image, 
-#[derive(Deserialize)]
+#[derive(Debug, Deserialize)]
 pub struct StableDiffusionInput {
     pub prompt: String,
     pub uncond_prompt: String,
@@ -615,3 +615,96 @@ impl StableDiffusion {
         Ok(img)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use std::time::Duration;
+
+    use super::*;
+
+    #[tokio::test]
+    async fn test_stable_diffusion_model_interface() {
+        let api_key = "".to_string();
+        let cache_dir: PathBuf = "./test_sd_cache_dir/".try_into().unwrap();
+        let model_id = "stable_diffusion_v1-5".to_string();
+        let dtype = "f32".to_string();
+        let revision = "".to_string();
+        let device_id = 0;
+        let use_flash_attention = false;
+        let config = ModelConfig::new(
+            model_id,
+            dtype.clone(),
+            revision,
+            device_id,
+            use_flash_attention,
+        );
+        let load_data = StableDiffusion::fetch(api_key, cache_dir.clone(), config)
+            .expect("Failed to fetch stable diffusion model");
+
+        println!("model device = {:?}", load_data.device);
+        let should_be_device = device(device_id).unwrap();
+        if should_be_device.is_cpu() {
+            assert!(load_data.device.is_cpu());
+        } else if should_be_device.is_cuda() {
+            assert!(load_data.device.is_cuda());
+        } else if should_be_device.is_metal() {
+            assert!(load_data.device.is_metal());
+        } else {
+            panic!("Invalid device")
+        }
+
+        assert_eq!(load_data.use_flash_attention, use_flash_attention);
+        assert_eq!(load_data.model_type, ModelType::StableDiffusionV1_5);
+
+        let should_be_dtype = DType::from_str(&dtype).unwrap();
+        assert_eq!(load_data.dtype, should_be_dtype);
+        let mut model = StableDiffusion::load(load_data).expect("Failed to load model");
+
+        if should_be_device.is_cpu() {
+            assert!(model.device.is_cpu());
+        } else if should_be_device.is_cuda() {
+            assert!(model.device.is_cuda());
+        } else if should_be_device.is_metal() {
+            assert!(model.device.is_metal());
+        } else {
+            panic!("Invalid device")
+        }
+
+        assert_eq!(model.dtype, should_be_dtype);
+        assert_eq!(model.model_type, ModelType::StableDiffusionV1_5);
+
+        let prompt = "A portrait of a flying cat: ".to_string();
+        let uncond_prompt = "".to_string();
+        let random_seed = 42;
+
+        let input = StableDiffusionInput {
+            prompt: prompt.clone(),
+            uncond_prompt,
+            height: None,
+            width: None,
+            random_seed: Some(random_seed),
+            n_steps: None,
+            num_samples: 1,
+            model: ModelType::StableDiffusionV1_5.to_string(),
+            guidance_scale: None,
+            img2img: None,
+            img2img_strength: 1.0,
+        };
+        println!("Running inference on input: {:?}", input);
+        let output = model.run(input).expect("Failed to run inference");
+        println!("{:?}", output);
+
+        assert_eq!(output[0].1, 512);
+        assert_eq!(output[0].2, 512);
+
+        std::fs::remove_dir_all(cache_dir).unwrap();
+        std::fs::remove_file("tensor1").unwrap();
+        std::fs::remove_file("tensor2").unwrap();
+        std::fs::remove_file("tensor3").unwrap();
+        std::fs::remove_file("tensor4").unwrap();
+
+        tokio::time::sleep(Duration::from_secs(5)).await; // give 5 seconds to look at the generated image
+
+        std::fs::remove_file("./image.png").unwrap();
+    }
+}
diff --git a/atoma-inference/src/models/types.rs b/atoma-inference/src/models/types.rs
index e28486c7..7191aa27 100644
--- a/atoma-inference/src/models/types.rs
+++ b/atoma-inference/src/models/types.rs
@@ -19,7 +19,7 @@ pub struct LlmLoadData {
     pub use_flash_attention: bool,
 }
 
-#[derive(Clone, Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
 pub enum ModelType {
     Falcon7b,
     Falcon40b,
@@ -65,7 +65,7 @@ impl FromStr for ModelType {
             "stable_diffusion_xl" => Ok(Self::StableDiffusionXl),
"stable_diffusion_turbo" => Ok(Self::StableDiffusionTurbo), _ => Err(ModelError::InvalidModelType( - "Invalid string model type descryption".to_string(), + "Invalid string model type description".to_string(), )), } }