Commit

add small modification to tests
jorgeantonio21 committed Apr 11, 2024
1 parent 0076048 commit 0aa7604
Showing 3 changed files with 8 additions and 9 deletions.
8 changes: 4 additions & 4 deletions Cargo.toml
@@ -16,10 +16,10 @@ edition = "2021"
 [workspace.dependencies]
 async-trait = "0.1.78"
 axum = "0.7.5"
-candle = { git = "https://github.com/huggingface/candle", package = "candle-core", version = "0.5.0" }
-candle-flash-attn = { git = "https://github.com/huggingface/candle", package = "candle-flash-attn", version = "0.5.0" }
-candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", version = "0.5.0" }
-candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", version = "0.5.0" }
+candle = { git = "https://github.com/huggingface/candle", package = "candle-core", branch = "main" }
+candle-flash-attn = { git = "https://github.com/huggingface/candle", package = "candle-flash-attn", branch = "main" }
+candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", branch = "main" }
+candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", branch = "main" }
 config = "0.14.0"
 dotenv = "0.15.0"
 ed25519-consensus = "2.1.0"
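
A side note on this dependency change (not part of the commit itself): pointing at branch = "main" means every dependency refresh pulls whatever candle's main branch holds at that moment. Cargo.lock does record the resolved commit, so builds stay stable until cargo update runs, but if a reproducible pin is preferred the same git dependency can name a specific commit with rev. The sha below is a placeholder, not a real pin from this repository:

candle = { git = "https://github.com/huggingface/candle", package = "candle-core", rev = "<commit-sha>" }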
8 changes: 4 additions & 4 deletions atoma-inference/src/models/candle/falcon.rs
@@ -278,8 +278,8 @@ mod tests {
         );
         let output = model.run(input).expect("Failed to run inference");

-        assert!(output.len() >= 1);
-        assert!(output.split(" ").collect::<Vec<_>>().len() <= max_tokens);
+        assert!(output.text.len() >= 1);
+        assert!(output.text.split(" ").collect::<Vec<_>>().len() <= max_tokens);

         std::fs::remove_dir_all(cache_dir).unwrap();
     }
@@ -361,8 +361,8 @@ mod tests {
         let output = model.run(input).expect("Failed to run inference");
         println!("{output}");

-        assert!(output.len() >= 1);
-        assert!(output.split(" ").collect::<Vec<_>>().len() <= max_tokens);
+        assert!(output.text.len() >= 1);
+        assert!(output.text.split(" ").collect::<Vec<_>>().len() <= max_tokens);

         std::fs::remove_dir_all(cache_dir).unwrap();
     }
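
The assertions changing from output.len() to output.text.len() imply that run now returns a struct with a text field rather than a plain String. A minimal sketch of the implied shape — the struct name and anything beyond the text field are assumptions, not taken from this diff:

// Sketch of the inference result the falcon tests now index into.
pub struct InferenceOutput {
    /// Generated text; the tests assert it is non-empty and that its
    /// space-separated word count stays within `max_tokens`.
    pub text: String,
}

As a side observation, splitting on a single space treats newline- or tab-joined words as one item, so the count can understate the real word count; split_whitespace() would be the more robust choice if that ever matters here.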
1 change: 0 additions & 1 deletion atoma-inference/src/models/candle/llama.rs
@@ -253,7 +253,6 @@ mod tests {
             panic!("Invalid device")
         }

-        assert!(model.cache.use_kv_cache);
         assert_eq!(model.model_type, ModelType::LlamaTinyLlama1_1BChat);

         let prompt = "Write a hello world rust program: ".to_string();
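
The dropped assertion reached into model.cache.use_kv_cache. The commit message does not say why, but since the same commit moves the candle crates to branch = "main", one plausible cause is that the field is no longer publicly accessible there. If the flag still needs coverage, it can be asserted where the cache is built instead; this sketch assumes the candle-transformers Llama Cache constructor and that dtype, config, and device are already in scope in the test:

// Capture the KV-cache flag at construction time instead of reading a
// (possibly private) field off the built model.
let use_kv_cache = true;
let cache = candle_transformers::models::llama::Cache::new(use_kv_cache, dtype, &config, &device)
    .expect("Failed to build cache");
assert!(use_kv_cache);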
