diff --git a/atoma-inference/src/models/candle/llama.rs b/atoma-inference/src/models/candle/llama.rs index 1b4090b1..d7020e1c 100644 --- a/atoma-inference/src/models/candle/llama.rs +++ b/atoma-inference/src/models/candle/llama.rs @@ -4,7 +4,7 @@ use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; use candle_transformers::{ generation::LogitsProcessor, - models::llama::{Cache, LlamaConfig}, + models::llama::{Config, LlamaConfig}, }; use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; @@ -32,14 +32,13 @@ enum Which { TinyLlama1_1BChat, } -pub struct Config {} - pub struct LlamaModel { - cache: Cache, device: Device, model: model::Llama, model_type: ModelType, tokenizer: Tokenizer, + config: Config, + dtype: DType, } impl ModelTrait for LlamaModel { @@ -103,7 +102,7 @@ impl ModelTrait for LlamaModel { let device = load_data.device; let dtype = load_data.dtype; - let (model, tokenizer_filename, cache) = { + let (model, tokenizer_filename, config) = { let config_filename = load_data.file_paths[0].clone(); let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?; @@ -113,18 +112,18 @@ impl ModelTrait for LlamaModel { let vb = unsafe { VarBuilder::from_mmaped_safetensors(&load_data.file_paths[2..], dtype, &device)? }; - let cache = model::Cache::new(true, dtype, &config, &device)?; // TODO: use from config - (model::Llama::load(vb, &config)?, tokenizer_filename, cache) + (model::Llama::load(vb, &config)?, tokenizer_filename, config) }; let tokenizer = Tokenizer::from_file(tokenizer_filename)?; info!("Loaded Llama model in {:?}", start.elapsed()); Ok(Self { - cache, device, model, model_type: load_data.model_type, tokenizer, + config, + dtype, }) } @@ -147,8 +146,9 @@ impl ModelTrait for LlamaModel { let mut generated_tokens = 0; let start_gen = Instant::now(); + let mut cache = model::Cache::new(true, self.dtype, &self.config, &self.device)?; for index in 0..input.max_tokens { - let (context_size, context_index) = if self.cache.use_kv_cache && index > 0 { + let (context_size, context_index) = if cache.use_kv_cache && index > 0 { (1, index_pos) } else { (tokens.len(), 0) @@ -157,7 +157,7 @@ impl ModelTrait for LlamaModel { let input_tensor = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; let logits = self .model - .forward(&input_tensor, context_index, &mut self.cache)?; + .forward(&input_tensor, context_index, &mut cache)?; let logits = logits.squeeze(0)?; let logits = if input.repeat_penalty == 1. { logits diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs index a0188024..3a212cb2 100644 --- a/atoma-inference/src/models/candle/mamba.rs +++ b/atoma-inference/src/models/candle/mamba.rs @@ -167,7 +167,7 @@ impl ModelTrait for MambaModel { None => bail!("Invalid eos token"), }; - let mut state = State::new(1, &self.config, &self.device)?; // TODO: handle larger batch sizes + let mut state = State::new(1, &self.config, self.dtype, &self.device)?; // TODO: handle larger batch sizes let mut next_logits = None; let mut output = String::new();