diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs
index d2024e71..dc76d499 100644
--- a/atoma-inference/src/main.rs
+++ b/atoma-inference/src/main.rs
@@ -13,7 +13,7 @@ async fn main() -> Result<(), ModelServiceError> {
     let model_config = ModelsConfig::from_file_path("../inference.toml".parse().unwrap());
 
     let private_key_bytes =
-        std::fs::read("../private_key").map_err(ModelServiceError::PrivateKeyError)?;
+        std::fs::read("./private_key").map_err(ModelServiceError::PrivateKeyError)?;
     let private_key_bytes: [u8; 32] = private_key_bytes
         .try_into()
         .expect("Incorrect private key bytes length");
diff --git a/atoma-inference/src/models/candle/falcon.rs b/atoma-inference/src/models/candle/falcon.rs
index d0376351..5ecbd736 100644
--- a/atoma-inference/src/models/candle/falcon.rs
+++ b/atoma-inference/src/models/candle/falcon.rs
@@ -11,11 +11,14 @@ use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
 use tokenizers::Tokenizer;
 use tracing::{debug, error, info};
 
-use crate::models::{
-    candle::hub_load_safetensors,
-    config::ModelConfig,
-    types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
-    ModelError, ModelTrait,
+use crate::{
+    bail,
+    models::{
+        candle::hub_load_safetensors,
+        config::ModelConfig,
+        types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
+        ModelError, ModelTrait,
+    },
 };
 
 use super::device;
@@ -111,7 +114,7 @@ impl ModelTrait for FalconModel {
         config.validate()?;
 
         if load_data.dtype != DType::BF16 && load_data.dtype != DType::F32 {
-            panic!("Invalid dtype, it must be either BF16 or F32 precision");
+            bail!("Invalid DType for Falcon model architecture");
         }
 
         let vb = unsafe {
@@ -119,9 +122,16 @@
                 &weights_filenames,
                 load_data.dtype,
                 &load_data.device,
-            )?
+            )
+            .map_err(|e| {
+                info!("Failed to load model weights: {e}");
+                e
+            })?
         };
-        let model = Falcon::load(vb, config.clone())?;
+        let model = Falcon::load(vb, config.clone()).map_err(|e| {
+            info!("Failed to load model: {e}");
+            e
+        })?;
 
         info!("Loaded Falcon model in {:?}", start.elapsed());
         Ok(Self::new(
@@ -199,171 +209,3 @@ impl ModelTrait for FalconModel {
         })
     }
 }
-
-#[cfg(test)]
-mod tests {
-    #[test]
-    #[cfg(feature = "metal")]
-    fn test_falcon_model_interface_with_metal() {
-        use super::*;
-
-        let api_key = "".to_string();
-        let cache_dir: PathBuf = "./test_falcon_cache_dir/".try_into().unwrap();
-        let model_id = "falcon_7b".to_string();
-        let dtype = "f32".to_string();
-        let revision = "refs/pr/43".to_string();
-        let device_id = 0;
-        let use_flash_attention = false;
-        let config = ModelConfig::new(
-            model_id,
-            dtype.clone(),
-            revision,
-            device_id,
-            use_flash_attention,
-        );
-        let load_data = FalconModel::fetch(api_key, cache_dir.clone(), config)
-            .expect("Failed to fetch falcon model");
-
-        println!("model device = {:?}", load_data.device);
-        let should_be_device = device(device_id).unwrap();
-        if should_be_device.is_cpu() {
-            assert!(load_data.device.is_cpu());
-        } else if should_be_device.is_cuda() {
-            assert!(load_data.device.is_cuda());
-        } else if should_be_device.is_metal() {
-            assert!(load_data.device.is_metal());
-        } else {
-            panic!("Invalid device")
-        }
-
-        assert_eq!(load_data.file_paths.len(), 4);
-        assert_eq!(load_data.use_flash_attention, use_flash_attention);
-        assert_eq!(load_data.model_type, ModelType::Falcon7b);
-
-        let should_be_dtype = DType::from_str(&dtype).unwrap();
-        assert_eq!(load_data.dtype, should_be_dtype);
-        let mut model = FalconModel::load(load_data).expect("Failed to load model");
-
-        if should_be_device.is_cpu() {
-            assert!(model.device.is_cpu());
-        } else if should_be_device.is_cuda() {
-            assert!(model.device.is_cuda());
-        } else if should_be_device.is_metal() {
-            assert!(model.device.is_metal());
-        } else {
-            panic!("Invalid device")
-        }
-
-        assert_eq!(model.dtype, should_be_dtype);
-        assert_eq!(model.model_type, ModelType::Falcon7b);
-
-        let prompt = "Write a hello world rust program: ".to_string();
-        let temperature = 0.6;
-        let random_seed = 42;
-        let repeat_penalty = 1.0;
-        let repeat_last_n = 20;
-        let max_tokens = 1;
-        let top_k = 10;
-        let top_p = 0.6;
-
-        let input = TextModelInput::new(
-            prompt.clone(),
-            temperature,
-            random_seed,
-            repeat_penalty,
-            repeat_last_n,
-            max_tokens,
-            top_k,
-            top_p,
-        );
-        let output = model.run(input).expect("Failed to run inference");
-
-        assert!(output.len() >= 1);
-        assert!(output.split(" ").collect::<Vec<_>>().len() <= max_tokens);
-
-        std::fs::remove_dir_all(cache_dir).unwrap();
-    }
-
-    #[test]
-    #[cfg(feature = "cuda")]
-    fn test_falcon_model_interface_with_cuda() {
-        use super::*;
-
-        let api_key = "".to_string();
-        let cache_dir: PathBuf = "./test_falcon_cache_dir/".try_into().unwrap();
-        let model_id = "falcon_7b".to_string();
-        let dtype = "f32".to_string();
-        let revision = "refs/pr/43".to_string();
-        let device_id = 0;
-        let use_flash_attention = false;
-        let config = ModelConfig::new(
-            model_id,
-            dtype.clone(),
-            revision,
-            device_id,
-            use_flash_attention,
-        );
-        let load_data = FalconModel::fetch(api_key, cache_dir.clone(), config)
-            .expect("Failed to fetch falcon model");
-
-        println!("model device = {:?}", load_data.device);
-        let should_be_device = device(device_id).unwrap();
-        if should_be_device.is_cpu() {
-            assert!(load_data.device.is_cpu());
-        } else if should_be_device.is_cuda() {
-            assert!(load_data.device.is_cuda());
-        } else if should_be_device.is_metal() {
-            assert!(load_data.device.is_metal());
-        } else {
-            panic!("Invalid device")
-        }
-
-        assert_eq!(load_data.file_paths.len(), 3);
-        assert_eq!(load_data.use_flash_attention, use_flash_attention);
-        assert_eq!(load_data.model_type, ModelType::Mamba130m);
-
-        let should_be_dtype = DType::from_str(&dtype).unwrap();
-        assert_eq!(load_data.dtype, should_be_dtype);
-        let mut model = FalconModel::load(load_data).expect("Failed to load model");
-
-        if should_be_device.is_cpu() {
-            assert!(model.device.is_cpu());
-        } else if should_be_device.is_cuda() {
-            assert!(model.device.is_cuda());
-        } else if should_be_device.is_metal() {
-            assert!(model.device.is_metal());
-        } else {
-            panic!("Invalid device")
-        }
-
-        assert_eq!(model.dtype, should_be_dtype);
-        assert_eq!(model.model_type, ModelType::Mamba130m);
-
-        let prompt = "Write a hello world rust program: ".to_string();
-        let temperature = 0.6;
-        let random_seed = 42;
-        let repeat_penalty = 1.0;
-        let repeat_last_n = 20;
-        let max_tokens = 1;
-        let top_k = 10;
-        let top_p = 0.6;
-
-        let input = TextModelInput::new(
-            prompt.clone(),
-            temperature,
-            random_seed,
-            repeat_penalty,
-            repeat_last_n,
-            max_tokens,
-            top_k,
-            top_p,
-        );
-        let output = model.run(input).expect("Failed to run inference");
-        println!("{output}");
-
-        assert!(output.len() >= 1);
-        assert!(output.split(" ").collect::<Vec<_>>().len() <= max_tokens);
-
-        std::fs::remove_dir_all(cache_dir).unwrap();
-    }
-}
diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs
index a0188024..f5abe6dc 100644
--- a/atoma-inference/src/models/candle/mamba.rs
+++ b/atoma-inference/src/models/candle/mamba.rs
@@ -17,7 +17,7 @@ use crate::{
         candle::device,
         config::ModelConfig,
         token_output_stream::TokenOutputStream,
-        types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
+        types::{LlmLoadData, ModelType, TextModelInput},
         ModelError, ModelTrait,
     },
 };
@@ -53,7 +53,7 @@ impl MambaModel {
 
 impl ModelTrait for MambaModel {
     type Input = TextModelInput;
-    type Output = TextModelOutput;
+    type Output = String;
     type LoadData = LlmLoadData;
 
     fn fetch(
@@ -119,6 +119,11 @@
                 &load_data.device,
             )?
         };
+
+        info!(
+            "Loaded model weights with precision: {:?}",
+            var_builder.dtype()
+        );
         let model = Model::new(&config, var_builder.pp("backbone"))?;
 
         info!("Loaded Mamba model in {:?}", start.elapsed());
@@ -222,11 +227,7 @@
             generated_tokens as f64 / dt.as_secs_f64(),
         );
 
-        Ok(TextModelOutput {
-            text: output,
-            time: dt.as_secs_f64(),
-            tokens_count: generated_tokens,
-        })
+        Ok(output)
     }
 }
 
@@ -308,8 +309,8 @@ mod tests {
         let output = model.run(input).expect("Failed to run inference");
         println!("{output}");
 
-        assert!(output.text.contains(&prompt));
-        assert!(output.text.len() > prompt.len());
+        assert!(output.contains(&prompt));
+        assert!(output.len() > prompt.len());
 
         std::fs::remove_dir_all(cache_dir).unwrap();
     }