From 40af812c076615d6f40d98cb7d2e70f67e10484e Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sun, 7 Apr 2024 13:31:19 +0100 Subject: [PATCH 1/5] add both falcon and mamba test suite --- atoma-inference/src/models/candle/falcon.rs | 79 +++++++++++++++++++++ atoma-inference/src/models/candle/mamba.rs | 79 +++++++++++++++++++++ atoma-inference/src/models/types.rs | 2 +- 3 files changed, 159 insertions(+), 1 deletion(-) diff --git a/atoma-inference/src/models/candle/falcon.rs b/atoma-inference/src/models/candle/falcon.rs index 12c5164d..8a4d8804 100644 --- a/atoma-inference/src/models/candle/falcon.rs +++ b/atoma-inference/src/models/candle/falcon.rs @@ -193,3 +193,82 @@ impl ModelTrait for FalconModel { Ok(output) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_falcon_model_interface() { + let api_key = "".to_string(); + let cache_dir: PathBuf = "./test_cache_dir/".try_into().unwrap(); + let model_id = "falcon_7b".to_string(); + let dtype = "f32".to_string(); + let revision = "refs/pr/43".to_string(); + let device_id = 0; + let use_flash_attention = false; + let config = ModelConfig::new(model_id, dtype.clone(), revision, device_id, use_flash_attention); + let load_data = + FalconModel::fetch(api_key, cache_dir.clone(), config).expect("Failed to fetch falcon model"); + + println!("model device = {:?}", load_data.device); + let should_be_device = device(device_id).unwrap(); + if should_be_device.is_cpu() { + assert!(load_data.device.is_cpu()); + } else if should_be_device.is_cuda() { + assert!(load_data.device.is_cuda()); + } else if should_be_device.is_metal() { + assert!(load_data.device.is_metal()); + } else { + panic!("Invalid device") + } + + assert_eq!(load_data.file_paths.len(), 3); + assert_eq!(load_data.use_flash_attention, use_flash_attention); + assert_eq!(load_data.model_type, ModelType::Mamba130m); + + let should_be_dtype = DType::from_str(&dtype).unwrap(); + assert_eq!(load_data.dtype, should_be_dtype); + let mut model = 
FalconModel::load(load_data).expect("Failed to load model"); + + if should_be_device.is_cpu() { + assert!(model.device.is_cpu()); + } else if should_be_device.is_cuda() { + assert!(model.device.is_cuda()); + } else if should_be_device.is_metal() { + assert!(model.device.is_metal()); + } else { + panic!("Invalid device") + } + + assert_eq!(model.dtype, should_be_dtype); + assert_eq!(model.model_type, ModelType::Mamba130m); + + let prompt = "Write a hello world rust program: ".to_string(); + let temperature = 0.6; + let random_seed = 42; + let repeat_penalty = 1.0; + let repeat_last_n = 20; + let max_tokens = 1; + let top_k = 10; + let top_p = 0.6; + + let input = TextModelInput::new( + prompt.clone(), + temperature, + random_seed, + repeat_penalty, + repeat_last_n, + max_tokens, + top_k, + top_p, + ); + let output = model.run(input).expect("Failed to run inference"); + println!("{output}"); + + assert!(output.contains(&prompt)); + assert!(output.len() > prompt.len()); + + std::fs::remove_dir_all(cache_dir).unwrap(); + } +} diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs index c77c68f0..60c37c57 100644 --- a/atoma-inference/src/models/candle/mamba.rs +++ b/atoma-inference/src/models/candle/mamba.rs @@ -223,3 +223,82 @@ impl ModelTrait for MambaModel { Ok(output) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_mamba_model_interface() { + let api_key = "".to_string(); + let cache_dir: PathBuf = "./test_cache_dir/".try_into().unwrap(); + let model_id = "mamba_130m".to_string(); + let dtype = "f32".to_string(); + let revision = "refs/pr/1".to_string(); + let device_id = 0; + let use_flash_attention = false; + let config = ModelConfig::new(model_id, dtype.clone(), revision, device_id, use_flash_attention); + let load_data = + MambaModel::fetch(api_key, cache_dir.clone(), config).expect("Failed to fetch mamba model"); + + println!("model device = {:?}", load_data.device); + let should_be_device = 
device(device_id).unwrap(); + if should_be_device.is_cpu() { + assert!(load_data.device.is_cpu()); + } else if should_be_device.is_cuda() { + assert!(load_data.device.is_cuda()); + } else if should_be_device.is_metal() { + assert!(load_data.device.is_metal()); + } else { + panic!("Invalid device") + } + + assert_eq!(load_data.file_paths.len(), 3); + assert_eq!(load_data.use_flash_attention, use_flash_attention); + assert_eq!(load_data.model_type, ModelType::Mamba130m); + + let should_be_dtype = DType::from_str(&dtype).unwrap(); + assert_eq!(load_data.dtype, should_be_dtype); + let mut model = MambaModel::load(load_data).expect("Failed to load model"); + + if should_be_device.is_cpu() { + assert!(model.device.is_cpu()); + } else if should_be_device.is_cuda() { + assert!(model.device.is_cuda()); + } else if should_be_device.is_metal() { + assert!(model.device.is_metal()); + } else { + panic!("Invalid device") + } + + assert_eq!(model.dtype, should_be_dtype); + assert_eq!(model.model_type, ModelType::Mamba130m); + + let prompt = "Write a hello world rust program: ".to_string(); + let temperature = 0.6; + let random_seed = 42; + let repeat_penalty = 1.0; + let repeat_last_n = 20; + let max_tokens = 128; + let top_k = 10; + let top_p = 0.6; + + let input = TextModelInput::new( + prompt.clone(), + temperature, + random_seed, + repeat_penalty, + repeat_last_n, + max_tokens, + top_k, + top_p, + ); + let output = model.run(input).expect("Failed to run inference"); + println!("{output}"); + + assert!(output.contains(&prompt)); + assert!(output.len() > prompt.len()); + + std::fs::remove_dir_all(cache_dir).unwrap(); + } +} diff --git a/atoma-inference/src/models/types.rs b/atoma-inference/src/models/types.rs index dfb2001f..743323f5 100644 --- a/atoma-inference/src/models/types.rs +++ b/atoma-inference/src/models/types.rs @@ -19,7 +19,7 @@ pub struct LlmLoadData { pub use_flash_attention: bool, } -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Eq, 
PartialEq, Serialize, Deserialize)] pub enum ModelType { Falcon7b, Falcon40b, From 1b384bf5fda742e9c79b4d7a0505d9e656c39db9 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sun, 7 Apr 2024 13:41:17 +0100 Subject: [PATCH 2/5] add cuda and metal features for falcon testing, add llama testing suite --- atoma-inference/src/models/candle/falcon.rs | 82 ++++++++++++++++++++- atoma-inference/src/models/candle/llama.rs | 80 ++++++++++++++++++++ 2 files changed, 160 insertions(+), 2 deletions(-) diff --git a/atoma-inference/src/models/candle/falcon.rs b/atoma-inference/src/models/candle/falcon.rs index 8a4d8804..77324971 100644 --- a/atoma-inference/src/models/candle/falcon.rs +++ b/atoma-inference/src/models/candle/falcon.rs @@ -196,10 +196,88 @@ impl ModelTrait for FalconModel { #[cfg(test)] mod tests { - use super::*; + #[test] + #[cfg(feature = "metal")] + fn test_falcon_model_interface_with_metal() { + use super::*; + + let api_key = "".to_string(); + let cache_dir: PathBuf = "./test_cache_dir/".try_into().unwrap(); + let model_id = "falcon_7b".to_string(); + let dtype = "f32".to_string(); + let revision = "refs/pr/43".to_string(); + let device_id = 0; + let use_flash_attention = false; + let config = ModelConfig::new(model_id, dtype.clone(), revision, device_id, use_flash_attention); + let load_data = + FalconModel::fetch(api_key, cache_dir.clone(), config).expect("Failed to fetch falcon model"); + + println!("model device = {:?}", load_data.device); + let should_be_device = device(device_id).unwrap(); + if should_be_device.is_cpu() { + assert!(load_data.device.is_cpu()); + } else if should_be_device.is_cuda() { + assert!(load_data.device.is_cuda()); + } else if should_be_device.is_metal() { + assert!(load_data.device.is_metal()); + } else { + panic!("Invalid device") + } + + assert_eq!(load_data.file_paths.len(), 3); + assert_eq!(load_data.use_flash_attention, use_flash_attention); + assert_eq!(load_data.model_type, ModelType::Mamba130m); + + let should_be_dtype = 
DType::from_str(&dtype).unwrap(); + assert_eq!(load_data.dtype, should_be_dtype); + let mut model = FalconModel::load(load_data).expect("Failed to load model"); + + if should_be_device.is_cpu() { + assert!(model.device.is_cpu()); + } else if should_be_device.is_cuda() { + assert!(model.device.is_cuda()); + } else if should_be_device.is_metal() { + assert!(model.device.is_metal()); + } else { + panic!("Invalid device") + } + + assert_eq!(model.dtype, should_be_dtype); + assert_eq!(model.model_type, ModelType::Mamba130m); + + let prompt = "Write a hello world rust program: ".to_string(); + let temperature = 0.6; + let random_seed = 42; + let repeat_penalty = 1.0; + let repeat_last_n = 20; + let max_tokens = 1; + let top_k = 10; + let top_p = 0.6; + + let input = TextModelInput::new( + prompt.clone(), + temperature, + random_seed, + repeat_penalty, + repeat_last_n, + max_tokens, + top_k, + top_p, + ); + let output = model.run(input).expect("Failed to run inference"); + println!("{output}"); + + assert!(output.contains(&prompt)); + assert!(output.len() > prompt.len()); + + std::fs::remove_dir_all(cache_dir).unwrap(); + } #[test] - fn test_falcon_model_interface() { + #[cfg(feature = "cuda")] + fn test_falcon_model_interface_with_cuda() { + use super::*; + let api_key = "".to_string(); let cache_dir: PathBuf = "./test_cache_dir/".try_into().unwrap(); let model_id = "falcon_7b".to_string(); diff --git a/atoma-inference/src/models/candle/llama.rs b/atoma-inference/src/models/candle/llama.rs index 7e4a9026..375c5dcc 100644 --- a/atoma-inference/src/models/candle/llama.rs +++ b/atoma-inference/src/models/candle/llama.rs @@ -183,3 +183,83 @@ impl ModelTrait for LlamaModel { Ok(res) } } + + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_llama_model_interface() { + let api_key = "".to_string(); + let cache_dir: PathBuf = "./test_cache_dir/".try_into().unwrap(); + let model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(); + let dtype = "f32".to_string(); 
+ let revision = "main".to_string(); + let device_id = 0; + let use_flash_attention = false; + let config = ModelConfig::new(model_id, dtype.clone(), revision, device_id, use_flash_attention); + let load_data = + LlamaModel::fetch(api_key, cache_dir.clone(), config).expect("Failed to fetch llama model"); + + println!("model device = {:?}", load_data.device); + let should_be_device = device(device_id).unwrap(); + if should_be_device.is_cpu() { + assert!(load_data.device.is_cpu()); + } else if should_be_device.is_cuda() { + assert!(load_data.device.is_cuda()); + } else if should_be_device.is_metal() { + assert!(load_data.device.is_metal()); + } else { + panic!("Invalid device") + } + + assert_eq!(load_data.file_paths.len(), 3); + assert_eq!(load_data.use_flash_attention, use_flash_attention); + assert_eq!(load_data.model_type, ModelType::Mamba130m); + + let should_be_dtype = DType::from_str(&dtype).unwrap(); + assert_eq!(load_data.dtype, should_be_dtype); + let mut model = LlamaModel::load(load_data).expect("Failed to load model"); + + if should_be_device.is_cpu() { + assert!(model.device.is_cpu()); + } else if should_be_device.is_cuda() { + assert!(model.device.is_cuda()); + } else if should_be_device.is_metal() { + assert!(model.device.is_metal()); + } else { + panic!("Invalid device") + } + + assert_eq!(model.cache.use_kv_cache, true); + assert_eq!(model.model_type, ModelType::Mamba130m); + + let prompt = "Write a hello world rust program: ".to_string(); + let temperature = 0.6; + let random_seed = 42; + let repeat_penalty = 1.0; + let repeat_last_n = 20; + let max_tokens = 128; + let top_k = 10; + let top_p = 0.6; + + let input = TextModelInput::new( + prompt.clone(), + temperature, + random_seed, + repeat_penalty, + repeat_last_n, + max_tokens, + top_k, + top_p, + ); + let output = model.run(input).expect("Failed to run inference"); + println!("{output}"); + + assert!(output.contains(&prompt)); + assert!(output.len() > prompt.len()); + + 
std::fs::remove_dir_all(cache_dir).unwrap(); + } +} From 351879fc0b7e3d4f0bcbb482534e7c381f860ca0 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sun, 7 Apr 2024 14:58:50 +0100 Subject: [PATCH 3/5] add test for stable diffusion, refactor previous tests --- atoma-inference/src/models/candle/falcon.rs | 65 +++++++++------ atoma-inference/src/models/candle/llama.rs | 37 +++++---- atoma-inference/src/models/candle/mamba.rs | 24 ++++-- .../src/models/candle/stable_diffusion.rs | 82 +++++++++++++++++++ atoma-inference/src/models/types.rs | 2 +- 5 files changed, 157 insertions(+), 53 deletions(-) diff --git a/atoma-inference/src/models/candle/falcon.rs b/atoma-inference/src/models/candle/falcon.rs index 77324971..c2db83c0 100644 --- a/atoma-inference/src/models/candle/falcon.rs +++ b/atoma-inference/src/models/candle/falcon.rs @@ -208,9 +208,15 @@ mod tests { let revision = "refs/pr/43".to_string(); let device_id = 0; let use_flash_attention = false; - let config = ModelConfig::new(model_id, dtype.clone(), revision, device_id, use_flash_attention); - let load_data = - FalconModel::fetch(api_key, cache_dir.clone(), config).expect("Failed to fetch falcon model"); + let config = ModelConfig::new( + model_id, + dtype.clone(), + revision, + device_id, + use_flash_attention, + ); + let load_data = FalconModel::fetch(api_key, cache_dir.clone(), config) + .expect("Failed to fetch falcon model"); println!("model device = {:?}", load_data.device); let should_be_device = device(device_id).unwrap(); @@ -218,16 +224,16 @@ mod tests { assert!(load_data.device.is_cpu()); } else if should_be_device.is_cuda() { assert!(load_data.device.is_cuda()); - } else if should_be_device.is_metal() { + } else if should_be_device.is_metal() { assert!(load_data.device.is_metal()); - } else { + } else { panic!("Invalid device") } - assert_eq!(load_data.file_paths.len(), 3); + assert_eq!(load_data.file_paths.len(), 4); assert_eq!(load_data.use_flash_attention, use_flash_attention); - 
assert_eq!(load_data.model_type, ModelType::Mamba130m); - + assert_eq!(load_data.model_type, ModelType::Falcon7b); + let should_be_dtype = DType::from_str(&dtype).unwrap(); assert_eq!(load_data.dtype, should_be_dtype); let mut model = FalconModel::load(load_data).expect("Failed to load model"); @@ -236,21 +242,21 @@ mod tests { assert!(model.device.is_cpu()); } else if should_be_device.is_cuda() { assert!(model.device.is_cuda()); - } else if should_be_device.is_metal() { + } else if should_be_device.is_metal() { assert!(model.device.is_metal()); - } else { + } else { panic!("Invalid device") } assert_eq!(model.dtype, should_be_dtype); - assert_eq!(model.model_type, ModelType::Mamba130m); + assert_eq!(model.model_type, ModelType::Falcon7b); let prompt = "Write a hello world rust program: ".to_string(); let temperature = 0.6; let random_seed = 42; let repeat_penalty = 1.0; let repeat_last_n = 20; - let max_tokens = 1; + let max_tokens = 128; let top_k = 10; let top_p = 0.6; @@ -265,11 +271,10 @@ mod tests { top_p, ); let output = model.run(input).expect("Failed to run inference"); - println!("{output}"); - assert!(output.contains(&prompt)); - assert!(output.len() > prompt.len()); - + assert!(output.len() > 1); + assert!(output.split(" ").collect::<Vec<_>>().len() <= max_tokens); + std::fs::remove_dir_all(cache_dir).unwrap(); } @@ -285,9 +290,15 @@ mod tests { let revision = "refs/pr/43".to_string(); let device_id = 0; let use_flash_attention = false; - let config = ModelConfig::new(model_id, dtype.clone(), revision, device_id, use_flash_attention); - let load_data = - FalconModel::fetch(api_key, cache_dir.clone(), config).expect("Failed to fetch falcon model"); + let config = ModelConfig::new( + model_id, + dtype.clone(), + revision, + device_id, + use_flash_attention, + ); + let load_data = FalconModel::fetch(api_key, cache_dir.clone(), config) + .expect("Failed to fetch falcon model"); println!("model device = {:?}", load_data.device); let should_be_device = 
device(device_id).unwrap(); @@ -295,16 +306,16 @@ mod tests { assert!(load_data.device.is_cpu()); } else if should_be_device.is_cuda() { assert!(load_data.device.is_cuda()); - } else if should_be_device.is_metal() { + } else if should_be_device.is_metal() { assert!(load_data.device.is_metal()); - } else { + } else { panic!("Invalid device") } assert_eq!(load_data.file_paths.len(), 3); assert_eq!(load_data.use_flash_attention, use_flash_attention); assert_eq!(load_data.model_type, ModelType::Mamba130m); - + let should_be_dtype = DType::from_str(&dtype).unwrap(); assert_eq!(load_data.dtype, should_be_dtype); let mut model = FalconModel::load(load_data).expect("Failed to load model"); @@ -313,9 +324,9 @@ mod tests { assert!(model.device.is_cpu()); } else if should_be_device.is_cuda() { assert!(model.device.is_cuda()); - } else if should_be_device.is_metal() { + } else if should_be_device.is_metal() { assert!(model.device.is_metal()); - } else { + } else { panic!("Invalid device") } @@ -344,9 +355,9 @@ mod tests { let output = model.run(input).expect("Failed to run inference"); println!("{output}"); - assert!(output.contains(&prompt)); - assert!(output.len() > prompt.len()); - + assert!(output.len() > 1); + assert!(output.split(" ").collect::<Vec<_>>().len() <= max_tokens); + std::fs::remove_dir_all(cache_dir).unwrap(); } } diff --git a/atoma-inference/src/models/candle/llama.rs b/atoma-inference/src/models/candle/llama.rs index 375c5dcc..0e450125 100644 --- a/atoma-inference/src/models/candle/llama.rs +++ b/atoma-inference/src/models/candle/llama.rs @@ -184,7 +184,6 @@ impl ModelTrait for LlamaModel { } } - #[cfg(test)] mod tests { use super::*; @@ -193,14 +192,20 @@ mod tests { fn test_llama_model_interface() { let api_key = "".to_string(); let cache_dir: PathBuf = "./test_cache_dir/".try_into().unwrap(); - let model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(); + let model_id = "llama_tiny_llama_1_1b_chat".to_string(); let dtype = "f32".to_string(); let revision = 
"main".to_string(); let device_id = 0; let use_flash_attention = false; - let config = ModelConfig::new(model_id, dtype.clone(), revision, device_id, use_flash_attention); - let load_data = - LlamaModel::fetch(api_key, cache_dir.clone(), config).expect("Failed to fetch llama model"); + let config = ModelConfig::new( + model_id, + dtype.clone(), + revision, + device_id, + use_flash_attention, + ); + let load_data = LlamaModel::fetch(api_key, cache_dir.clone(), config) + .expect("Failed to fetch llama model"); println!("model device = {:?}", load_data.device); let should_be_device = device(device_id).unwrap(); @@ -208,16 +213,16 @@ mod tests { assert!(load_data.device.is_cpu()); } else if should_be_device.is_cuda() { assert!(load_data.device.is_cuda()); - } else if should_be_device.is_metal() { + } else if should_be_device.is_metal() { assert!(load_data.device.is_metal()); - } else { + } else { panic!("Invalid device") } assert_eq!(load_data.file_paths.len(), 3); assert_eq!(load_data.use_flash_attention, use_flash_attention); - assert_eq!(load_data.model_type, ModelType::Mamba130m); - + assert_eq!(load_data.model_type, ModelType::LlamaTinyLlama1_1BChat); + let should_be_dtype = DType::from_str(&dtype).unwrap(); assert_eq!(load_data.dtype, should_be_dtype); let mut model = LlamaModel::load(load_data).expect("Failed to load model"); @@ -226,14 +231,14 @@ mod tests { assert!(model.device.is_cpu()); } else if should_be_device.is_cuda() { assert!(model.device.is_cuda()); - } else if should_be_device.is_metal() { + } else if should_be_device.is_metal() { assert!(model.device.is_metal()); - } else { + } else { panic!("Invalid device") } assert_eq!(model.cache.use_kv_cache, true); - assert_eq!(model.model_type, ModelType::Mamba130m); + assert_eq!(model.model_type, ModelType::LlamaTinyLlama1_1BChat); let prompt = "Write a hello world rust program: ".to_string(); let temperature = 0.6; @@ -255,11 +260,11 @@ mod tests { top_p, ); let output = model.run(input).expect("Failed to 
run inference"); - println!("{output}"); + println!("output = {output}"); + + assert!(output.len() > 1); + assert!(output.split(" ").collect::<Vec<_>>().len() <= max_tokens); - assert!(output.contains(&prompt)); - assert!(output.len() > prompt.len()); - std::fs::remove_dir_all(cache_dir).unwrap(); } } diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs index 60c37c57..eaf34f4b 100644 --- a/atoma-inference/src/models/candle/mamba.rs +++ b/atoma-inference/src/models/candle/mamba.rs @@ -237,9 +237,15 @@ mod tests { let revision = "refs/pr/1".to_string(); let device_id = 0; let use_flash_attention = false; - let config = ModelConfig::new(model_id, dtype.clone(), revision, device_id, use_flash_attention); - let load_data = - MambaModel::fetch(api_key, cache_dir.clone(), config).expect("Failed to fetch mamba model"); + let config = ModelConfig::new( + model_id, + dtype.clone(), + revision, + device_id, + use_flash_attention, + ); + let load_data = MambaModel::fetch(api_key, cache_dir.clone(), config) + .expect("Failed to fetch mamba model"); println!("model device = {:?}", load_data.device); let should_be_device = device(device_id).unwrap(); @@ -247,16 +253,16 @@ mod tests { assert!(load_data.device.is_cpu()); } else if should_be_device.is_cuda() { assert!(load_data.device.is_cuda()); - } else if should_be_device.is_metal() { + } else if should_be_device.is_metal() { assert!(load_data.device.is_metal()); - } else { + } else { panic!("Invalid device") } assert_eq!(load_data.file_paths.len(), 3); assert_eq!(load_data.use_flash_attention, use_flash_attention); assert_eq!(load_data.model_type, ModelType::Mamba130m); - + let should_be_dtype = DType::from_str(&dtype).unwrap(); assert_eq!(load_data.dtype, should_be_dtype); let mut model = MambaModel::load(load_data).expect("Failed to load model"); @@ -265,9 +271,9 @@ mod tests { assert!(model.device.is_cpu()); } else if should_be_device.is_cuda() { assert!(model.device.is_cuda()); - } 
else if should_be_device.is_metal() { + } else if should_be_device.is_metal() { assert!(model.device.is_metal()); - } else { + } else { panic!("Invalid device") } @@ -298,7 +304,7 @@ mod tests { assert!(output.contains(&prompt)); assert!(output.len() > prompt.len()); - + std::fs::remove_dir_all(cache_dir).unwrap(); } } diff --git a/atoma-inference/src/models/candle/stable_diffusion.rs b/atoma-inference/src/models/candle/stable_diffusion.rs index ea42f8be..a0a43831 100644 --- a/atoma-inference/src/models/candle/stable_diffusion.rs +++ b/atoma-inference/src/models/candle/stable_diffusion.rs @@ -610,3 +610,85 @@ impl StableDiffusion { Ok(img) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_stable_diffusion_model_interface() { + let api_key = "".to_string(); + let cache_dir: PathBuf = "./test_cache_dir/".try_into().unwrap(); + let model_id = "runwayml/stable-diffusion-v1-5".to_string(); + let dtype = "f32".to_string(); + let revision = "".to_string(); + let device_id = 0; + let use_flash_attention = false; + let config = ModelConfig::new( + model_id, + dtype.clone(), + revision, + device_id, + use_flash_attention, + ); + let load_data = StableDiffusion::fetch(api_key, cache_dir.clone(), config) + .expect("Failed to fetch mamba model"); + + println!("model device = {:?}", load_data.device); + let should_be_device = device(device_id).unwrap(); + if should_be_device.is_cpu() { + assert!(load_data.device.is_cpu()); + } else if should_be_device.is_cuda() { + assert!(load_data.device.is_cuda()); + } else if should_be_device.is_metal() { + assert!(load_data.device.is_metal()); + } else { + panic!("Invalid device") + } + + assert_eq!(load_data.use_flash_attention, use_flash_attention); + assert_eq!(load_data.model_type, ModelType::StableDiffusionV1_5); + + let should_be_dtype = DType::from_str(&dtype).unwrap(); + assert_eq!(load_data.dtype, should_be_dtype); + let mut model = StableDiffusion::load(load_data).expect("Failed to load model"); + + if 
should_be_device.is_cpu() { + assert!(model.device.is_cpu()); + } else if should_be_device.is_cuda() { + assert!(model.device.is_cuda()); + } else if should_be_device.is_metal() { + assert!(model.device.is_metal()); + } else { + panic!("Invalid device") + } + + assert_eq!(model.dtype, should_be_dtype); + assert_eq!(model.model_type, ModelType::StableDiffusionV1_5); + + let prompt = "A portrait of a flying cat: ".to_string(); + let uncond_prompt = "".to_string(); + let random_seed = 42; + + let input = StableDiffusionInput { + prompt: prompt.clone(), + uncond_prompt, + height: None, + width: None, + random_seed: Some(random_seed), + n_steps: None, + num_samples: 1, + model: ModelType::StableDiffusionV1_5.to_string(), + guidance_scale: None, + img2img: None, + img2img_strength: 1.0, + }; + let output = model.run(input).expect("Failed to run inference"); + println!("{:?}", output); + + assert_eq!(output[0].1, 512); + assert_eq!(output[0].2, 512); + + std::fs::remove_dir_all(cache_dir).unwrap(); + } +} diff --git a/atoma-inference/src/models/types.rs b/atoma-inference/src/models/types.rs index 743323f5..20d0ac14 100644 --- a/atoma-inference/src/models/types.rs +++ b/atoma-inference/src/models/types.rs @@ -65,7 +65,7 @@ impl FromStr for ModelType { "stable_diffusion_xl" => Ok(Self::StableDiffusionXl), "stable_diffusion_turbo" => Ok(Self::StableDiffusionTurbo), _ => Err(ModelError::InvalidModelType( - "Invalid string model type descryption".to_string(), + "Invalid string model type description".to_string(), )), } } From 8db0fb6f0855d2de844f230ffbf703c88d66fb82 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sun, 7 Apr 2024 16:35:27 +0100 Subject: [PATCH 4/5] refactor small pieces and bits of tests --- atoma-inference/src/models/candle/falcon.rs | 4 ++-- atoma-inference/src/models/candle/llama.rs | 2 +- atoma-inference/src/models/candle/mamba.rs | 2 +- .../src/models/candle/stable_diffusion.rs | 20 ++++++++++++++----- 4 files changed, 19 insertions(+), 9 deletions(-) 
diff --git a/atoma-inference/src/models/candle/falcon.rs b/atoma-inference/src/models/candle/falcon.rs index c2db83c0..5c7e8094 100644 --- a/atoma-inference/src/models/candle/falcon.rs +++ b/atoma-inference/src/models/candle/falcon.rs @@ -202,7 +202,7 @@ mod tests { use super::*; let api_key = "".to_string(); - let cache_dir: PathBuf = "./test_cache_dir/".try_into().unwrap(); + let cache_dir: PathBuf = "./test_falcon_cache_dir/".try_into().unwrap(); let model_id = "falcon_7b".to_string(); let dtype = "f32".to_string(); let revision = "refs/pr/43".to_string(); @@ -284,7 +284,7 @@ mod tests { use super::*; let api_key = "".to_string(); - let cache_dir: PathBuf = "./test_cache_dir/".try_into().unwrap(); + let cache_dir: PathBuf = "./test_falcon_cache_dir/".try_into().unwrap(); let model_id = "falcon_7b".to_string(); let dtype = "f32".to_string(); let revision = "refs/pr/43".to_string(); diff --git a/atoma-inference/src/models/candle/llama.rs b/atoma-inference/src/models/candle/llama.rs index 0e450125..a1699949 100644 --- a/atoma-inference/src/models/candle/llama.rs +++ b/atoma-inference/src/models/candle/llama.rs @@ -191,7 +191,7 @@ mod tests { #[test] fn test_llama_model_interface() { let api_key = "".to_string(); - let cache_dir: PathBuf = "./test_cache_dir/".try_into().unwrap(); + let cache_dir: PathBuf = "./test_llama_cache_dir/".try_into().unwrap(); let model_id = "llama_tiny_llama_1_1b_chat".to_string(); let dtype = "f32".to_string(); let revision = "main".to_string(); diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs index eaf34f4b..6238ed96 100644 --- a/atoma-inference/src/models/candle/mamba.rs +++ b/atoma-inference/src/models/candle/mamba.rs @@ -231,7 +231,7 @@ mod tests { #[test] fn test_mamba_model_interface() { let api_key = "".to_string(); - let cache_dir: PathBuf = "./test_cache_dir/".try_into().unwrap(); + let cache_dir: PathBuf = "./test_mamba_cache_dir/".try_into().unwrap(); let model_id = 
"mamba_130m".to_string(); let dtype = "f32".to_string(); let revision = "refs/pr/1".to_string(); diff --git a/atoma-inference/src/models/candle/stable_diffusion.rs b/atoma-inference/src/models/candle/stable_diffusion.rs index a0a43831..52c0ca1a 100644 --- a/atoma-inference/src/models/candle/stable_diffusion.rs +++ b/atoma-inference/src/models/candle/stable_diffusion.rs @@ -612,14 +612,16 @@ impl StableDiffusion { } #[cfg(test)] -mod tests { +mod tests { + use std::time::Duration; + use super::*; - #[test] - fn test_stable_diffusion_model_interface() { + #[tokio::test] + async fn test_stable_diffusion_model_interface() { let api_key = "".to_string(); - let cache_dir: PathBuf = "./test_cache_dir/".try_into().unwrap(); - let model_id = "runwayml/stable-diffusion-v1-5".to_string(); + let cache_dir: PathBuf = "./test_sd_cache_dir/".try_into().unwrap(); + let model_id = "stable_diffusion_v1-5".to_string(); let dtype = "f32".to_string(); let revision = "".to_string(); let device_id = 0; @@ -690,5 +692,13 @@ mod tests { assert_eq!(output[0].2, 512); std::fs::remove_dir_all(cache_dir).unwrap(); + std::fs::remove_file("tensor1").unwrap(); + std::fs::remove_file("tensor2").unwrap(); + std::fs::remove_file("tensor3").unwrap(); + std::fs::remove_file("tensor4").unwrap(); + + tokio::time::sleep(Duration::from_secs(5)).await; // give 5 seconds to look at the generated image + + std::fs::remove_file("./image.png").unwrap(); } } From 9b6486d2c003d96854c051a4fcd2d9e838f9f0e2 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sun, 7 Apr 2024 17:03:15 +0100 Subject: [PATCH 5/5] refactor tests --- atoma-inference/src/models/candle/falcon.rs | 6 +++--- atoma-inference/src/models/candle/stable_diffusion.rs | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/atoma-inference/src/models/candle/falcon.rs b/atoma-inference/src/models/candle/falcon.rs index 5c7e8094..fe6ec082 100644 --- a/atoma-inference/src/models/candle/falcon.rs +++ 
b/atoma-inference/src/models/candle/falcon.rs @@ -256,7 +256,7 @@ mod tests { let random_seed = 42; let repeat_penalty = 1.0; let repeat_last_n = 20; - let max_tokens = 128; + let max_tokens = 1; let top_k = 10; let top_p = 0.6; @@ -272,7 +272,7 @@ mod tests { ); let output = model.run(input).expect("Failed to run inference"); - assert!(output.len() > 1); + assert!(output.len() >= 1); assert!(output.split(" ").collect::<Vec<_>>().len() <= max_tokens); std::fs::remove_dir_all(cache_dir).unwrap(); @@ -355,7 +355,7 @@ mod tests { let output = model.run(input).expect("Failed to run inference"); println!("{output}"); - assert!(output.len() > 1); + assert!(output.len() >= 1); assert!(output.split(" ").collect::<Vec<_>>().len() <= max_tokens); std::fs::remove_dir_all(cache_dir).unwrap(); diff --git a/atoma-inference/src/models/candle/stable_diffusion.rs b/atoma-inference/src/models/candle/stable_diffusion.rs index 52c0ca1a..da008e5c 100644 --- a/atoma-inference/src/models/candle/stable_diffusion.rs +++ b/atoma-inference/src/models/candle/stable_diffusion.rs @@ -20,7 +20,7 @@ use crate::{ use super::{convert_to_image, device, save_tensor_to_file}; -#[derive(Deserialize)] +#[derive(Debug, Deserialize)] pub struct StableDiffusionInput { pub prompt: String, pub uncond_prompt: String, @@ -685,6 +685,7 @@ mod tests { img2img: None, img2img_strength: 1.0, }; + println!("Running inference on input: {:?}", input); let output = model.run(input).expect("Failed to run inference"); println!("{:?}", output);