add testing for model interfaces #27

Merged · 7 commits · Apr 10, 2024
atoma-inference/src/models/candle/falcon.rs (168 additions, 0 deletions)
@@ -195,3 +195,171 @@ impl ModelTrait for FalconModel {
Ok(output)
}
}

#[cfg(test)]
mod tests {
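// These tests are feature-gated. Assuming they are run from within the
// atoma-inference crate, a likely invocation is:
//   cargo test --features metal   (or --features cuda)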
#[test]
#[cfg(feature = "metal")]
fn test_falcon_model_interface_with_metal() {
use super::*;

let api_key = "".to_string();
let cache_dir: PathBuf = "./test_falcon_cache_dir/".try_into().unwrap();
let model_id = "falcon_7b".to_string();
let dtype = "f32".to_string();
let revision = "refs/pr/43".to_string();
let device_id = 0;
let use_flash_attention = false;
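// Falcon 7B in f32 from the refs/pr/43 revision, flash attention disabled,
// fetched into a throwaway cache directory for this test.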
let config = ModelConfig::new(
model_id,
dtype.clone(),
revision,
device_id,
use_flash_attention,
);
let load_data = FalconModel::fetch(api_key, cache_dir.clone(), config)
.expect("Failed to fetch falcon model");

println!("model device = {:?}", load_data.device);
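// The device recorded at fetch time should match the device resolved from device_id.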
let should_be_device = device(device_id).unwrap();
if should_be_device.is_cpu() {
assert!(load_data.device.is_cpu());
} else if should_be_device.is_cuda() {
assert!(load_data.device.is_cuda());
} else if should_be_device.is_metal() {
assert!(load_data.device.is_metal());
} else {
panic!("Invalid device")
}

assert_eq!(load_data.file_paths.len(), 4);
assert_eq!(load_data.use_flash_attention, use_flash_attention);
assert_eq!(load_data.model_type, ModelType::Falcon7b);

let should_be_dtype = DType::from_str(&dtype).unwrap();
assert_eq!(load_data.dtype, should_be_dtype);
let mut model = FalconModel::load(load_data).expect("Failed to load model");

if should_be_device.is_cpu() {
assert!(model.device.is_cpu());
} else if should_be_device.is_cuda() {
assert!(model.device.is_cuda());
} else if should_be_device.is_metal() {
assert!(model.device.is_metal());
} else {
panic!("Invalid device")
}

assert_eq!(model.dtype, should_be_dtype);
assert_eq!(model.model_type, ModelType::Falcon7b);

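// Sampling parameters; max_tokens is kept at 1 so the test generates a single token.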
let prompt = "Write a hello world rust program: ".to_string();
let temperature = 0.6;
let random_seed = 42;
let repeat_penalty = 1.0;
let repeat_last_n = 20;
let max_tokens = 1;
let top_k = 10;
let top_p = 0.6;

let input = TextModelInput::new(
prompt.clone(),
temperature,
random_seed,
repeat_penalty,
repeat_last_n,
max_tokens,
top_k,
top_p,
);
let output = model.run(input).expect("Failed to run inference");

assert!(!output.is_empty());
assert!(output.split(' ').count() <= max_tokens);

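// Remove the downloaded weights so repeated test runs start from a clean cache.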
std::fs::remove_dir_all(cache_dir).unwrap();
}

#[test]
#[cfg(feature = "cuda")]
fn test_falcon_model_interface_with_cuda() {
use super::*;

let api_key = "".to_string();
let cache_dir: PathBuf = "./test_falcon_cache_dir/".try_into().unwrap();
let model_id = "falcon_7b".to_string();
let dtype = "f32".to_string();
let revision = "refs/pr/43".to_string();
let device_id = 0;
let use_flash_attention = false;
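// Mirrors the Metal test above, but exercised with the cuda feature enabled.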
let config = ModelConfig::new(
model_id,
dtype.clone(),
revision,
device_id,
use_flash_attention,
);
let load_data = FalconModel::fetch(api_key, cache_dir.clone(), config)
.expect("Failed to fetch falcon model");

println!("model device = {:?}", load_data.device);
let should_be_device = device(device_id).unwrap();
if should_be_device.is_cpu() {
assert!(load_data.device.is_cpu());
} else if should_be_device.is_cuda() {
assert!(load_data.device.is_cuda());
} else if should_be_device.is_metal() {
assert!(load_data.device.is_metal());
} else {
panic!("Invalid device")
}

assert_eq!(load_data.file_paths.len(), 4);
assert_eq!(load_data.use_flash_attention, use_flash_attention);
assert_eq!(load_data.model_type, ModelType::Falcon7b);

let should_be_dtype = DType::from_str(&dtype).unwrap();
assert_eq!(load_data.dtype, should_be_dtype);
let mut model = FalconModel::load(load_data).expect("Failed to load model");

if should_be_device.is_cpu() {
assert!(model.device.is_cpu());
} else if should_be_device.is_cuda() {
assert!(model.device.is_cuda());
} else if should_be_device.is_metal() {
assert!(model.device.is_metal());
} else {
panic!("Invalid device")
}

assert_eq!(model.dtype, should_be_dtype);
assert_eq!(model.model_type, ModelType::Falcon7b);

let prompt = "Write a hello world rust program: ".to_string();
let temperature = 0.6;
let random_seed = 42;
let repeat_penalty = 1.0;
let repeat_last_n = 20;
let max_tokens = 1;
let top_k = 10;
let top_p = 0.6;

let input = TextModelInput::new(
prompt.clone(),
temperature,
random_seed,
repeat_penalty,
repeat_last_n,
max_tokens,
top_k,
top_p,
);
let output = model.run(input).expect("Failed to run inference");
println!("{output}");

assert!(!output.is_empty());
assert!(output.split(' ').count() <= max_tokens);

std::fs::remove_dir_all(cache_dir).unwrap();
}
}
atoma-inference/src/models/candle/llama.rs (85 additions, 0 deletions)
@@ -195,3 +195,88 @@ impl ModelTrait for LlamaModel {
Ok(res)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_llama_model_interface() {
let api_key = "".to_string();
let cache_dir: PathBuf = "./test_llama_cache_dir/".try_into().unwrap();
let model_id = "llama_tiny_llama_1_1b_chat".to_string();
let dtype = "f32".to_string();
let revision = "main".to_string();
let device_id = 0;
let use_flash_attention = false;
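// TinyLlama 1.1B Chat from the main revision, loaded in f32 without flash attention.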
let config = ModelConfig::new(
model_id,
dtype.clone(),
revision,
device_id,
use_flash_attention,
);
let load_data = LlamaModel::fetch(api_key, cache_dir.clone(), config)
.expect("Failed to fetch llama model");

println!("model device = {:?}", load_data.device);
let should_be_device = device(device_id).unwrap();
if should_be_device.is_cpu() {
assert!(load_data.device.is_cpu());
} else if should_be_device.is_cuda() {
assert!(load_data.device.is_cuda());
} else if should_be_device.is_metal() {
assert!(load_data.device.is_metal());
} else {
panic!("Invalid device")
}

assert_eq!(load_data.file_paths.len(), 3);
assert_eq!(load_data.use_flash_attention, use_flash_attention);
assert_eq!(load_data.model_type, ModelType::LlamaTinyLlama1_1BChat);

let should_be_dtype = DType::from_str(&dtype).unwrap();
assert_eq!(load_data.dtype, should_be_dtype);
let mut model = LlamaModel::load(load_data).expect("Failed to load model");

if should_be_device.is_cpu() {
assert!(model.device.is_cpu());
} else if should_be_device.is_cuda() {
assert!(model.device.is_cuda());
} else if should_be_device.is_metal() {
assert!(model.device.is_metal());
} else {
panic!("Invalid device")
}

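// The loaded model is expected to have its KV cache enabled.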
assert!(model.cache.use_kv_cache);
assert_eq!(model.model_type, ModelType::LlamaTinyLlama1_1BChat);

let prompt = "Write a hello world rust program: ".to_string();
let temperature = 0.6;
let random_seed = 42;
let repeat_penalty = 1.0;
let repeat_last_n = 20;
let max_tokens = 128;
let top_k = 10;
let top_p = 0.6;

let input = TextModelInput::new(
prompt.clone(),
temperature,
random_seed,
repeat_penalty,
repeat_last_n,
max_tokens,
top_k,
top_p,
);
let output = model.run(input).expect("Failed to run inference");
println!("output = {output}");

assert!(output.len() > 1);
assert!(output.split(' ').count() <= max_tokens);

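// Clean up the downloaded weights after the test.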
std::fs::remove_dir_all(cache_dir).unwrap();
}
}
atoma-inference/src/models/candle/mamba.rs (85 additions, 0 deletions)
@@ -225,3 +225,88 @@ impl ModelTrait for MambaModel {
Ok(output)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_mamba_model_interface() {
let api_key = "".to_string();
let cache_dir: PathBuf = "./test_mamba_cache_dir/".try_into().unwrap();
let model_id = "mamba_130m".to_string();
let dtype = "f32".to_string();
let revision = "refs/pr/1".to_string();
let device_id = 0;
let use_flash_attention = false;
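// Mamba 130M in f32 from the refs/pr/1 revision, flash attention disabled.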
let config = ModelConfig::new(
model_id,
dtype.clone(),
revision,
device_id,
use_flash_attention,
);
let load_data = MambaModel::fetch(api_key, cache_dir.clone(), config)
.expect("Failed to fetch mamba model");

println!("model device = {:?}", load_data.device);
let should_be_device = device(device_id).unwrap();
if should_be_device.is_cpu() {
assert!(load_data.device.is_cpu());
} else if should_be_device.is_cuda() {
assert!(load_data.device.is_cuda());
} else if should_be_device.is_metal() {
assert!(load_data.device.is_metal());
} else {
panic!("Invalid device")
}

assert_eq!(load_data.file_paths.len(), 3);
assert_eq!(load_data.use_flash_attention, use_flash_attention);
assert_eq!(load_data.model_type, ModelType::Mamba130m);

let should_be_dtype = DType::from_str(&dtype).unwrap();
assert_eq!(load_data.dtype, should_be_dtype);
let mut model = MambaModel::load(load_data).expect("Failed to load model");

if should_be_device.is_cpu() {
assert!(model.device.is_cpu());
} else if should_be_device.is_cuda() {
assert!(model.device.is_cuda());
} else if should_be_device.is_metal() {
assert!(model.device.is_metal());
} else {
panic!("Invalid device")
}

assert_eq!(model.dtype, should_be_dtype);
assert_eq!(model.model_type, ModelType::Mamba130m);

let prompt = "Write a hello world rust program: ".to_string();
let temperature = 0.6;
let random_seed = 42;
let repeat_penalty = 1.0;
let repeat_last_n = 20;
let max_tokens = 128;
let top_k = 10;
let top_p = 0.6;

let input = TextModelInput::new(
prompt.clone(),
temperature,
random_seed,
repeat_penalty,
repeat_last_n,
max_tokens,
top_k,
top_p,
);
let output = model.run(input).expect("Failed to run inference");
println!("{output}");

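// The output is expected to echo the prompt and append generated text.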
assert!(output.contains(&prompt));
assert!(output.len() > prompt.len());

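// Remove the test cache once the assertions pass.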
std::fs::remove_dir_all(cache_dir).unwrap();
}
}