fix llama and mamba
Cifko committed Apr 11, 2024
1 parent 4a9bc27 commit 0076048
Showing 2 changed files with 11 additions and 11 deletions.
20 changes: 10 additions & 10 deletions atoma-inference/src/models/candle/llama.rs
@@ -4,7 +4,7 @@
 use candle::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::{
     generation::LogitsProcessor,
-    models::llama::{Cache, LlamaConfig},
+    models::llama::{Config, LlamaConfig},
 };
 use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};

@@ -32,14 +32,13 @@
     TinyLlama1_1BChat,
 }

-pub struct Config {}
-
 pub struct LlamaModel {
-    cache: Cache,
     device: Device,
     model: model::Llama,
     model_type: ModelType,
     tokenizer: Tokenizer,
+    config: Config,
+    dtype: DType,
 }

 impl ModelTrait for LlamaModel {
@@ -103,7 +102,7 @@ impl ModelTrait for LlamaModel {

         let device = load_data.device;
         let dtype = load_data.dtype;
-        let (model, tokenizer_filename, cache) = {
+        let (model, tokenizer_filename, config) = {
             let config_filename = load_data.file_paths[0].clone();
             let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;

@@ -113,18 +112,18 @@
             let vb = unsafe {
                 VarBuilder::from_mmaped_safetensors(&load_data.file_paths[2..], dtype, &device)?
             };
-            let cache = model::Cache::new(true, dtype, &config, &device)?; // TODO: use from config
-            (model::Llama::load(vb, &config)?, tokenizer_filename, cache)
+            (model::Llama::load(vb, &config)?, tokenizer_filename, config)
         };
         let tokenizer = Tokenizer::from_file(tokenizer_filename)?;
         info!("Loaded Llama model in {:?}", start.elapsed());

         Ok(Self {
-            cache,
             device,
             model,
             model_type: load_data.model_type,
             tokenizer,
+            config,
+            dtype,
         })
     }

@@ -147,8 +146,9 @@
         let mut generated_tokens = 0;

         let start_gen = Instant::now();
+        let mut cache = model::Cache::new(true, self.dtype, &self.config, &self.device)?;
         for index in 0..input.max_tokens {
-            let (context_size, context_index) = if self.cache.use_kv_cache && index > 0 {
+            let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
                 (1, index_pos)
             } else {
                 (tokens.len(), 0)
@@ -157,7 +157,7 @@
             let input_tensor = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
             let logits = self
                 .model
-                .forward(&input_tensor, context_index, &mut self.cache)?;
+                .forward(&input_tensor, context_index, &mut cache)?;
             let logits = logits.squeeze(0)?;
             let logits = if input.repeat_penalty == 1. {
                 logits
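In short, the llama change drops the locally defined empty Config, keeps the parsed llama Config and DType on the struct instead of a long-lived Cache, and rebuilds the KV cache at the start of each run so state from one request cannot leak into the next. Below is a minimal sketch of the resulting run-loop shape. It assumes `model` aliases `candle_transformers::models::llama`, and the `run` helper, its seed, and its sampling setup are hypothetical; only the field names and call shapes taken from the diff above are from the commit itself.

// Sketch only: call shapes follow the diff above; the `run` helper and its
// sampling details are assumptions, not the repository's actual code.
use candle::{DType, Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::llama as model;
use candle_transformers::models::llama::Config;

pub struct LlamaModel {
    device: Device,
    model: model::Llama,
    config: Config,
    dtype: DType,
}

impl LlamaModel {
    // Hypothetical helper mirroring the inference loop touched by this commit.
    pub fn run(&self, prompt_tokens: &[u32], max_tokens: usize) -> candle::Result<Vec<u32>> {
        // A fresh KV cache per call (`true` enables the KV cache), built from
        // the config and dtype now stored on the struct.
        let mut cache = model::Cache::new(true, self.dtype, &self.config, &self.device)?;
        let mut logits_processor = LogitsProcessor::new(42, None, None);
        let mut tokens = prompt_tokens.to_vec();
        let mut index_pos = 0;
        for index in 0..max_tokens {
            // Once the cache is warm, only the most recent token is fed back in.
            let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
                (1, index_pos)
            } else {
                (tokens.len(), 0)
            };
            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self
                .model
                .forward(&input, context_index, &mut cache)?
                .squeeze(0)?
                .to_dtype(DType::F32)?;
            index_pos += ctxt.len();
            let next_token = logits_processor.sample(&logits)?;
            tokens.push(next_token);
        }
        Ok(tokens)
    }
}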
2 changes: 1 addition & 1 deletion atoma-inference/src/models/candle/mamba.rs
@@ -167,7 +167,7 @@ impl ModelTrait for MambaModel {
             None => bail!("Invalid eos token"),
         };

-        let mut state = State::new(1, &self.config, &self.device)?; // TODO: handle larger batch sizes
+        let mut state = State::new(1, &self.config, self.dtype, &self.device)?; // TODO: handle larger batch sizes

         let mut next_logits = None;
         let mut output = String::new();
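The mamba fix is the matching one-liner: the recurrent state constructor now also receives the dtype, so the model's stored dtype is threaded through. A minimal sketch of the updated call, assuming the `State` and `Config` types come from `candle_transformers::models::mamba` and that `fresh_state` is a hypothetical helper:

use candle::{DType, Device};
use candle_transformers::models::mamba::{Config, State};

// Builds the recurrent state the way the patched call does. Batch size is
// fixed at 1, matching the TODO about larger batch sizes in the diff.
fn fresh_state(config: &Config, dtype: DType, device: &Device) -> candle::Result<State> {
    State::new(1, config, dtype, device)
}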
