Merge pull request #27 from atoma-network/20240407-testing

add testing for model interfaces
Cifko authored Apr 10, 2024
2 parents 4ba4aa9 + 8b0ee5a commit 0d5c8d8
Showing 5 changed files with 434 additions and 3 deletions.
168 changes: 168 additions & 0 deletions atoma-inference/src/models/candle/falcon.rs
@@ -195,3 +195,171 @@ impl ModelTrait for FalconModel {
Ok(output)
}
}

#[cfg(test)]
mod tests {
#[test]
#[cfg(feature = "metal")]
fn test_falcon_model_interface_with_metal() {
use super::*;

let api_key = "".to_string();
let cache_dir: PathBuf = "./test_falcon_cache_dir/".try_into().unwrap();
let model_id = "falcon_7b".to_string();
let dtype = "f32".to_string();
let revision = "refs/pr/43".to_string();
let device_id = 0;
let use_flash_attention = false;
let config = ModelConfig::new(
model_id,
dtype.clone(),
revision,
device_id,
use_flash_attention,
);
let load_data = FalconModel::fetch(api_key, cache_dir.clone(), config)
.expect("Failed to fetch falcon model");

println!("model device = {:?}", load_data.device);
let should_be_device = device(device_id).unwrap();
if should_be_device.is_cpu() {
assert!(load_data.device.is_cpu());
} else if should_be_device.is_cuda() {
assert!(load_data.device.is_cuda());
} else if should_be_device.is_metal() {
assert!(load_data.device.is_metal());
} else {
panic!("Invalid device")
}

assert_eq!(load_data.file_paths.len(), 4);
assert_eq!(load_data.use_flash_attention, use_flash_attention);
assert_eq!(load_data.model_type, ModelType::Falcon7b);

let should_be_dtype = DType::from_str(&dtype).unwrap();
assert_eq!(load_data.dtype, should_be_dtype);
let mut model = FalconModel::load(load_data).expect("Failed to load model");

if should_be_device.is_cpu() {
assert!(model.device.is_cpu());
} else if should_be_device.is_cuda() {
assert!(model.device.is_cuda());
} else if should_be_device.is_metal() {
assert!(model.device.is_metal());
} else {
panic!("Invalid device")
}

assert_eq!(model.dtype, should_be_dtype);
assert_eq!(model.model_type, ModelType::Falcon7b);

let prompt = "Write a hello world rust program: ".to_string();
let temperature = 0.6;
let random_seed = 42;
let repeat_penalty = 1.0;
let repeat_last_n = 20;
let max_tokens = 1;
let top_k = 10;
let top_p = 0.6;

let input = TextModelInput::new(
prompt.clone(),
temperature,
random_seed,
repeat_penalty,
repeat_last_n,
max_tokens,
top_k,
top_p,
);
let output = model.run(input).expect("Failed to run inference");

assert!(output.len() >= 1);
assert!(output.split(" ").collect::<Vec<_>>().len() <= max_tokens);

std::fs::remove_dir_all(cache_dir).unwrap();
}

#[test]
#[cfg(feature = "cuda")]
fn test_falcon_model_interface_with_cuda() {
use super::*;

let api_key = "".to_string();
let cache_dir: PathBuf = "./test_falcon_cache_dir/".try_into().unwrap();
let model_id = "falcon_7b".to_string();
let dtype = "f32".to_string();
let revision = "refs/pr/43".to_string();
let device_id = 0;
let use_flash_attention = false;
let config = ModelConfig::new(
model_id,
dtype.clone(),
revision,
device_id,
use_flash_attention,
);
let load_data = FalconModel::fetch(api_key, cache_dir.clone(), config)
.expect("Failed to fetch falcon model");

println!("model device = {:?}", load_data.device);
let should_be_device = device(device_id).unwrap();
if should_be_device.is_cpu() {
assert!(load_data.device.is_cpu());
} else if should_be_device.is_cuda() {
assert!(load_data.device.is_cuda());
} else if should_be_device.is_metal() {
assert!(load_data.device.is_metal());
} else {
panic!("Invalid device")
}

assert_eq!(load_data.file_paths.len(), 4);
assert_eq!(load_data.use_flash_attention, use_flash_attention);
assert_eq!(load_data.model_type, ModelType::Falcon7b);

let should_be_dtype = DType::from_str(&dtype).unwrap();
assert_eq!(load_data.dtype, should_be_dtype);
let mut model = FalconModel::load(load_data).expect("Failed to load model");

if should_be_device.is_cpu() {
assert!(model.device.is_cpu());
} else if should_be_device.is_cuda() {
assert!(model.device.is_cuda());
} else if should_be_device.is_metal() {
assert!(model.device.is_metal());
} else {
panic!("Invalid device")
}

assert_eq!(model.dtype, should_be_dtype);
assert_eq!(model.model_type, ModelType::Falcon7b);

let prompt = "Write a hello world rust program: ".to_string();
let temperature = 0.6;
let random_seed = 42;
let repeat_penalty = 1.0;
let repeat_last_n = 20;
let max_tokens = 1;
let top_k = 10;
let top_p = 0.6;

let input = TextModelInput::new(
prompt.clone(),
temperature,
random_seed,
repeat_penalty,
repeat_last_n,
max_tokens,
top_k,
top_p,
);
let output = model.run(input).expect("Failed to run inference");
println!("{output}");

assert!(output.len() >= 1);
assert!(output.split(" ").collect::<Vec<_>>().len() <= max_tokens);

std::fs::remove_dir_all(cache_dir).unwrap();
}
}
85 changes: 85 additions & 0 deletions atoma-inference/src/models/candle/llama.rs
@@ -195,3 +195,88 @@ impl ModelTrait for LlamaModel {
Ok(res)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_llama_model_interface() {
let api_key = "".to_string();
let cache_dir: PathBuf = "./test_llama_cache_dir/".try_into().unwrap();
let model_id = "llama_tiny_llama_1_1b_chat".to_string();
let dtype = "f32".to_string();
let revision = "main".to_string();
let device_id = 0;
let use_flash_attention = false;
let config = ModelConfig::new(
model_id,
dtype.clone(),
revision,
device_id,
use_flash_attention,
);
let load_data = LlamaModel::fetch(api_key, cache_dir.clone(), config)
.expect("Failed to fetch llama model");

println!("model device = {:?}", load_data.device);
let should_be_device = device(device_id).unwrap();
if should_be_device.is_cpu() {
assert!(load_data.device.is_cpu());
} else if should_be_device.is_cuda() {
assert!(load_data.device.is_cuda());
} else if should_be_device.is_metal() {
assert!(load_data.device.is_metal());
} else {
panic!("Invalid device")
}

assert_eq!(load_data.file_paths.len(), 3);
assert_eq!(load_data.use_flash_attention, use_flash_attention);
assert_eq!(load_data.model_type, ModelType::LlamaTinyLlama1_1BChat);

let should_be_dtype = DType::from_str(&dtype).unwrap();
assert_eq!(load_data.dtype, should_be_dtype);
let mut model = LlamaModel::load(load_data).expect("Failed to load model");

if should_be_device.is_cpu() {
assert!(model.device.is_cpu());
} else if should_be_device.is_cuda() {
assert!(model.device.is_cuda());
} else if should_be_device.is_metal() {
assert!(model.device.is_metal());
} else {
panic!("Invalid device")
}

assert_eq!(model.cache.use_kv_cache, true);
assert_eq!(model.model_type, ModelType::LlamaTinyLlama1_1BChat);

let prompt = "Write a hello world rust program: ".to_string();
let temperature = 0.6;
let random_seed = 42;
let repeat_penalty = 1.0;
let repeat_last_n = 20;
let max_tokens = 128;
let top_k = 10;
let top_p = 0.6;

let input = TextModelInput::new(
prompt.clone(),
temperature,
random_seed,
repeat_penalty,
repeat_last_n,
max_tokens,
top_k,
top_p,
);
let output = model.run(input).expect("Failed to run inference");
println!("output = {output}");

assert!(output.len() > 1);
assert!(output.split(" ").collect::<Vec<_>>().len() <= max_tokens);

std::fs::remove_dir_all(cache_dir).unwrap();
}
}
85 changes: 85 additions & 0 deletions atoma-inference/src/models/candle/mamba.rs
@@ -225,3 +225,88 @@ impl ModelTrait for MambaModel {
Ok(output)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_mamba_model_interface() {
let api_key = "".to_string();
let cache_dir: PathBuf = "./test_mamba_cache_dir/".try_into().unwrap();
let model_id = "mamba_130m".to_string();
let dtype = "f32".to_string();
let revision = "refs/pr/1".to_string();
let device_id = 0;
let use_flash_attention = false;
let config = ModelConfig::new(
model_id,
dtype.clone(),
revision,
device_id,
use_flash_attention,
);
let load_data = MambaModel::fetch(api_key, cache_dir.clone(), config)
.expect("Failed to fetch mamba model");

println!("model device = {:?}", load_data.device);
let should_be_device = device(device_id).unwrap();
if should_be_device.is_cpu() {
assert!(load_data.device.is_cpu());
} else if should_be_device.is_cuda() {
assert!(load_data.device.is_cuda());
} else if should_be_device.is_metal() {
assert!(load_data.device.is_metal());
} else {
panic!("Invalid device")
}

assert_eq!(load_data.file_paths.len(), 3);
assert_eq!(load_data.use_flash_attention, use_flash_attention);
assert_eq!(load_data.model_type, ModelType::Mamba130m);

let should_be_dtype = DType::from_str(&dtype).unwrap();
assert_eq!(load_data.dtype, should_be_dtype);
let mut model = MambaModel::load(load_data).expect("Failed to load model");

if should_be_device.is_cpu() {
assert!(model.device.is_cpu());
} else if should_be_device.is_cuda() {
assert!(model.device.is_cuda());
} else if should_be_device.is_metal() {
assert!(model.device.is_metal());
} else {
panic!("Invalid device")
}

assert_eq!(model.dtype, should_be_dtype);
assert_eq!(model.model_type, ModelType::Mamba130m);

let prompt = "Write a hello world rust program: ".to_string();
let temperature = 0.6;
let random_seed = 42;
let repeat_penalty = 1.0;
let repeat_last_n = 20;
let max_tokens = 128;
let top_k = 10;
let top_p = 0.6;

let input = TextModelInput::new(
prompt.clone(),
temperature,
random_seed,
repeat_penalty,
repeat_last_n,
max_tokens,
top_k,
top_p,
);
let output = model.run(input).expect("Failed to run inference");
println!("{output}");

assert!(output.contains(&prompt));
assert!(output.len() > prompt.len());

std::fs::remove_dir_all(cache_dir).unwrap();
}
}
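
The device-variant check is repeated twice in each test above. A minimal helper sketch (hypothetical, not part of this commit, assuming the same candle `Device` type these tests already use) could factor it out:

// Hypothetical test helper, not part of this commit. It assumes the `Device`
// type already imported in these modules, which exposes the `is_cpu`,
// `is_cuda`, and `is_metal` checks used in the tests above.
fn assert_same_device_kind(expected: &Device, actual: &Device) {
    if expected.is_cpu() {
        assert!(actual.is_cpu());
    } else if expected.is_cuda() {
        assert!(actual.is_cuda());
    } else if expected.is_metal() {
        assert!(actual.is_metal());
    } else {
        panic!("Invalid device")
    }
}

Each test could then replace its two if/else blocks with calls such as `assert_same_device_kind(&should_be_device, &load_data.device)` and `assert_same_device_kind(&should_be_device, &model.device)`.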