diff --git a/atoma-inference/src/candle/llama.rs b/atoma-inference/src/candle/llama.rs index 3cb4d995..f40c5579 100644 --- a/atoma-inference/src/candle/llama.rs +++ b/atoma-inference/src/candle/llama.rs @@ -4,6 +4,8 @@ extern crate accelerate_src; #[cfg(feature = "mkl")] extern crate intel_mkl_src; +use std::path::PathBuf; + use anyhow::{bail, Error as E, Result}; use candle::{DType, Tensor}; @@ -17,6 +19,8 @@ use tokenizers::Tokenizer; use crate::candle::{device, hub_load_safetensors, token_output_stream::TokenOutputStream}; +use super::save_tensor_to_file; + const EOS_TOKEN: &str = ""; #[derive(Clone, Debug, Copy, PartialEq, Eq)] @@ -47,7 +51,7 @@ impl Default for Config { Self { temperature: None, top_p: None, - seed: 299792458, + seed: 0, sample_len: 10000, no_kv_cache: false, dtype: None, @@ -110,6 +114,7 @@ pub fn run(prompt: String, cfg: Config) -> Result { let mut logits_processor = LogitsProcessor::new(cfg.seed, cfg.temperature, cfg.top_p); let mut index_pos = 0; let mut res = String::new(); + let mut result = Vec::new(); for index in 0..cfg.sample_len { let (context_size, context_index) = if cache.use_kv_cache && index > 0 { (1, index_pos) @@ -131,8 +136,8 @@ pub fn run(prompt: String, cfg: Config) -> Result { )? }; index_pos += ctxt.len(); - let next_token = logits_processor.sample(&logits)?; + result.push(logits); tokens.push(next_token); if Some(next_token) == eos_token_id { @@ -145,5 +150,12 @@ pub fn run(prompt: String, cfg: Config) -> Result { if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? { res += &rest; } + for (i, tensor) in result.iter().enumerate() { + save_tensor_to_file( + tensor, + &PathBuf::from(format!("llama_logits_{}.pt", i.to_string())), + ) + .unwrap(); + } Ok(res) } diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index 2cdd69f8..46ae6c83 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -22,11 +22,7 @@ fn stable_diffusion() { fn llama1() { use crate::candle::llama::run; - let x = run( - "The most important thing is ".to_string(), - Default::default(), - ) - .unwrap(); + let x = run("If we find a solution to ".to_string(), Default::default()).unwrap(); println!("{}", x); }