[wip] precision conversions #36

Closed · wants to merge 8 commits
atoma-inference/src/main.rs: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ async fn main() -> Result<(), ModelServiceError> {

let model_config = ModelsConfig::from_file_path("../inference.toml".parse().unwrap());
let private_key_bytes =
std::fs::read("../private_key").map_err(ModelServiceError::PrivateKeyError)?;
std::fs::read("./private_key").map_err(ModelServiceError::PrivateKeyError)?;
let private_key_bytes: [u8; 32] = private_key_bytes
.try_into()
.expect("Incorrect private key bytes length");
atoma-inference/src/models/candle/falcon.rs: 18 additions & 176 deletions
@@ -11,11 +11,14 @@ use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use tokenizers::Tokenizer;
use tracing::{debug, error, info};

-use crate::models::{
-candle::hub_load_safetensors,
-config::ModelConfig,
-types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
-ModelError, ModelTrait,
+use crate::{
+bail,
+models::{
+candle::hub_load_safetensors,
+config::ModelConfig,
+types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
+ModelError, ModelTrait,
+},
};

use super::device;
@@ -111,17 +114,24 @@ impl ModelTrait for FalconModel {
config.validate()?;

if load_data.dtype != DType::BF16 && load_data.dtype != DType::F32 {
panic!("Invalid dtype, it must be either BF16 or F32 precision");
bail!("Invalid DType for Falcon model architecture");
}

let vb = unsafe {
VarBuilder::from_mmaped_safetensors(
&weights_filenames,
load_data.dtype,
&load_data.device,
-)?
+)
+.map_err(|e| {
+info!("Failed to load model weights: {e}");
+e
+})?
};
-let model = Falcon::load(vb, config.clone())?;
+let model = Falcon::load(vb, config.clone()).map_err(|e| {
+info!("Failed to load model: {e}");
+e
+})?;
info!("Loaded Falcon model in {:?}", start.elapsed());

Ok(Self::new(
@@ -199,171 +209,3 @@ impl ModelTrait for FalconModel {
})
}
}

#[cfg(test)]
mod tests {
#[test]
#[cfg(feature = "metal")]
fn test_falcon_model_interface_with_metal() {
use super::*;

let api_key = "".to_string();
let cache_dir: PathBuf = "./test_falcon_cache_dir/".try_into().unwrap();
let model_id = "falcon_7b".to_string();
let dtype = "f32".to_string();
let revision = "refs/pr/43".to_string();
let device_id = 0;
let use_flash_attention = false;
let config = ModelConfig::new(
model_id,
dtype.clone(),
revision,
device_id,
use_flash_attention,
);
let load_data = FalconModel::fetch(api_key, cache_dir.clone(), config)
.expect("Failed to fetch falcon model");

println!("model device = {:?}", load_data.device);
let should_be_device = device(device_id).unwrap();
if should_be_device.is_cpu() {
assert!(load_data.device.is_cpu());
} else if should_be_device.is_cuda() {
assert!(load_data.device.is_cuda());
} else if should_be_device.is_metal() {
assert!(load_data.device.is_metal());
} else {
panic!("Invalid device")
}

assert_eq!(load_data.file_paths.len(), 4);
assert_eq!(load_data.use_flash_attention, use_flash_attention);
assert_eq!(load_data.model_type, ModelType::Falcon7b);

let should_be_dtype = DType::from_str(&dtype).unwrap();
assert_eq!(load_data.dtype, should_be_dtype);
let mut model = FalconModel::load(load_data).expect("Failed to load model");

if should_be_device.is_cpu() {
assert!(model.device.is_cpu());
} else if should_be_device.is_cuda() {
assert!(model.device.is_cuda());
} else if should_be_device.is_metal() {
assert!(model.device.is_metal());
} else {
panic!("Invalid device")
}

assert_eq!(model.dtype, should_be_dtype);
assert_eq!(model.model_type, ModelType::Falcon7b);

let prompt = "Write a hello world rust program: ".to_string();
let temperature = 0.6;
let random_seed = 42;
let repeat_penalty = 1.0;
let repeat_last_n = 20;
let max_tokens = 1;
let top_k = 10;
let top_p = 0.6;

let input = TextModelInput::new(
prompt.clone(),
temperature,
random_seed,
repeat_penalty,
repeat_last_n,
max_tokens,
top_k,
top_p,
);
let output = model.run(input).expect("Failed to run inference");

assert!(output.len() >= 1);
assert!(output.split(" ").collect::<Vec<_>>().len() <= max_tokens);

std::fs::remove_dir_all(cache_dir).unwrap();
}

#[test]
#[cfg(feature = "cuda")]
fn test_falcon_model_interface_with_cuda() {
use super::*;

let api_key = "".to_string();
let cache_dir: PathBuf = "./test_falcon_cache_dir/".try_into().unwrap();
let model_id = "falcon_7b".to_string();
let dtype = "f32".to_string();
let revision = "refs/pr/43".to_string();
let device_id = 0;
let use_flash_attention = false;
let config = ModelConfig::new(
model_id,
dtype.clone(),
revision,
device_id,
use_flash_attention,
);
let load_data = FalconModel::fetch(api_key, cache_dir.clone(), config)
.expect("Failed to fetch falcon model");

println!("model device = {:?}", load_data.device);
let should_be_device = device(device_id).unwrap();
if should_be_device.is_cpu() {
assert!(load_data.device.is_cpu());
} else if should_be_device.is_cuda() {
assert!(load_data.device.is_cuda());
} else if should_be_device.is_metal() {
assert!(load_data.device.is_metal());
} else {
panic!("Invalid device")
}

assert_eq!(load_data.file_paths.len(), 3);
assert_eq!(load_data.use_flash_attention, use_flash_attention);
assert_eq!(load_data.model_type, ModelType::Mamba130m);

let should_be_dtype = DType::from_str(&dtype).unwrap();
assert_eq!(load_data.dtype, should_be_dtype);
let mut model = FalconModel::load(load_data).expect("Failed to load model");

if should_be_device.is_cpu() {
assert!(model.device.is_cpu());
} else if should_be_device.is_cuda() {
assert!(model.device.is_cuda());
} else if should_be_device.is_metal() {
assert!(model.device.is_metal());
} else {
panic!("Invalid device")
}

assert_eq!(model.dtype, should_be_dtype);
assert_eq!(model.model_type, ModelType::Mamba130m);

let prompt = "Write a hello world rust program: ".to_string();
let temperature = 0.6;
let random_seed = 42;
let repeat_penalty = 1.0;
let repeat_last_n = 20;
let max_tokens = 1;
let top_k = 10;
let top_p = 0.6;

let input = TextModelInput::new(
prompt.clone(),
temperature,
random_seed,
repeat_penalty,
repeat_last_n,
max_tokens,
top_k,
top_p,
);
let output = model.run(input).expect("Failed to run inference");
println!("{output}");

assert!(output.len() >= 1);
assert!(output.split(" ").collect::<Vec<_>>().len() <= max_tokens);

std::fs::remove_dir_all(cache_dir).unwrap();
}
}
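The falcon.rs hunks above replace a hard `panic!` on unsupported precisions with a crate-level `bail!` macro (note the new `use crate::{bail, ...}` import), so an invalid `DType` now surfaces as an error instead of aborting the process. The macro's definition is not part of this diff; the block below is only a minimal, self-contained sketch of the usual pattern, assuming `ModelError` exposes a message-carrying variant (named `Msg` here, a hypothetical name):

```rust
// Sketch only: the real `bail!` and `ModelError` live elsewhere in this crate.
#[derive(Debug)]
pub enum ModelError {
    Msg(String),
    // other variants elided
}

macro_rules! bail {
    ($($arg:tt)*) => {
        return Err(ModelError::Msg(format!($($arg)*)))
    };
}

fn validate_dtype(is_bf16_or_f32: bool) -> Result<(), ModelError> {
    if !is_bf16_or_f32 {
        // Early-returns Err(ModelError::Msg(..)) from the enclosing function.
        bail!("Invalid DType for Falcon model architecture");
    }
    Ok(())
}
```

With something like that in scope, the `bail!("Invalid DType for Falcon model architecture")` call returns early from `load` rather than panicking, and the new `map_err` closures log the failure via `info!` before propagating it.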
atoma-inference/src/models/candle/mamba.rs: 10 additions & 9 deletions
@@ -17,7 +17,7 @@ use crate::{
candle::device,
config::ModelConfig,
token_output_stream::TokenOutputStream,
-types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
+types::{LlmLoadData, ModelType, TextModelInput},
ModelError, ModelTrait,
},
};
@@ -53,7 +53,7 @@ impl MambaModel {

impl ModelTrait for MambaModel {
type Input = TextModelInput;
-type Output = TextModelOutput;
+type Output = String;
type LoadData = LlmLoadData;

fn fetch(
@@ -119,6 +119,11 @@ impl ModelTrait for MambaModel {
&load_data.device,
)?
};

+info!(
+"Loaded model weights with precision: {:?}",
+var_builder.dtype()
+);
let model = Model::new(&config, var_builder.pp("backbone"))?;
info!("Loaded Mamba model in {:?}", start.elapsed());

@@ -222,11 +227,7 @@ impl ModelTrait for MambaModel {
generated_tokens as f64 / dt.as_secs_f64(),
);

-Ok(TextModelOutput {
-text: output,
-time: dt.as_secs_f64(),
-tokens_count: generated_tokens,
-})
+Ok(output)
}
}
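With `Output` now a bare `String`, the caller no longer receives the timing and token-count metadata that `TextModelOutput` carried. Judging from the constructor removed in this hunk, the wrapper presumably looks roughly like the sketch below; the field types are inferred from usage, not confirmed by this diff:

```rust
// Inferred from the removed `Ok(TextModelOutput { .. })` above; field types are guesses.
pub struct TextModelOutput {
    pub text: String,        // the generated text
    pub time: f64,           // generation wall-clock time, from dt.as_secs_f64()
    pub tokens_count: usize, // number of generated tokens
}
```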

@@ -308,8 +309,8 @@ mod tests {
let output = model.run(input).expect("Failed to run inference");
println!("{output}");

-assert!(output.text.contains(&prompt));
-assert!(output.text.len() > prompt.len());
+assert!(output.contains(&prompt));
+assert!(output.len() > prompt.len());

std::fs::remove_dir_all(cache_dir).unwrap();
}
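For context on the "precision conversions" theme of this PR: the hunks above only validate the requested `DType` and log the precision the weights were loaded with (`var_builder.dtype()`). The sketch below is not code from this repository; it is a minimal, self-contained illustration of an explicit precision cast, assuming the standard `candle_core` crate and its `Tensor::to_dtype` API:

```rust
use candle_core::{DType, Device, Result, Tensor};

// Illustrative only: casts a tensor's elements to a target precision.
fn to_precision(t: &Tensor, dtype: DType) -> Result<Tensor> {
    t.to_dtype(dtype)
}

fn main() -> Result<()> {
    let t = Tensor::new(&[1.0f32, 2.0, 3.0], &Device::Cpu)?;
    let half = to_precision(&t, DType::BF16)?;
    // Prints: F32 -> BF16
    println!("{:?} -> {:?}", t.dtype(), half.dtype());
    Ok(())
}
```

Passing `load_data.dtype` to `VarBuilder::from_mmaped_safetensors`, as both models above do, is intended to achieve the same effect at weight-load time: the weights are materialized in the requested precision rather than the precision stored on disk.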