From cda82dbc7ac63f62fd093bae5ca5c22d3b4a0e2b Mon Sep 17 00:00:00 2001
From: Martin Stefcek
Date: Wed, 3 Apr 2024 18:48:37 +0400
Subject: [PATCH] add candle::llama

---
 atoma-inference/src/candle/llama.rs        | 234 ++++++++++++++++++
 atoma-inference/src/candle/mod.rs          |   1 +
 .../src/candle/stable_diffusion.rs         |   6 +-
 atoma-inference/src/main.rs                |   3 +
 atoma-inference/src/models/candle/mamba.rs |   3 +-
 atoma-inference/src/models/mod.rs          |   3 +-
 atoma-inference/src/service.rs             |   8 +-
 7 files changed, 251 insertions(+), 7 deletions(-)
 create mode 100644 atoma-inference/src/candle/llama.rs

diff --git a/atoma-inference/src/candle/llama.rs b/atoma-inference/src/candle/llama.rs
new file mode 100644
index 00000000..feff72d6
--- /dev/null
+++ b/atoma-inference/src/candle/llama.rs
@@ -0,0 +1,234 @@
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+use std::path::PathBuf;
+
+use candle::{DType, Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::{
+    generation::LogitsProcessor,
+    models::llama::{Cache, LlamaConfig},
+};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+
+use candle_transformers::models::llama as model;
+use tokenizers::Tokenizer;
+
+use crate::{
+    candle::{device, hub_load_safetensors, token_output_stream::TokenOutputStream},
+    models::{types::PrecisionBits, ModelError, ModelTrait},
+};
+
+use super::CandleModel;
+
+const EOS_TOKEN: &str = "</s>";
+
+#[allow(dead_code)]
+#[derive(Clone, Debug, Copy, PartialEq, Eq)]
+enum Which {
+    V1,
+    V2,
+    Solar10_7B,
+    TinyLlama1_1BChat,
+}
+
+pub struct Config {}
+
+pub struct Llama {
+    device: Device,
+    tokenizer: Tokenizer,
+    dtype: DType,
+    llama: model::Llama,
+    cache: Cache,
+}
+
+pub struct Input {
+    prompt: String,
+    temperature: Option<f64>,
+    top_p: Option<f64>,
+    seed: u64,
+    sample_len: usize,
+    no_kv_cache: bool,
+    dtype: Option<String>,
+    model_id: Option<String>,
+    revision: Option<String>,
+    which: Which,
+    use_flash_attn: bool,
+    repeat_penalty: f32,
+    repeat_last_n: usize,
+}
+
+impl Input {
+    pub fn default_prompt(prompt: String) -> Self {
+        Self {
+            prompt,
+            temperature: None,
+            top_p: None,
+            seed: 0,
+            sample_len: 10000,
+            no_kv_cache: false,
+            dtype: None,
+            model_id: None,
+            revision: None,
+            which: Which::TinyLlama1_1BChat,
+            use_flash_attn: false,
+            repeat_penalty: 1.,
+            repeat_last_n: 64,
+        }
+    }
+}
+
+pub struct Fetch {
+    model_id: Option<String>,
+    revision: Option<String>,
+    which: Which,
+    use_flash_attn: bool,
+    no_kv_cache: bool,
+    dtype: Option<String>,
+}
+
+impl Default for Fetch {
+    fn default() -> Self {
+        Self {
+            model_id: None,
+            revision: None,
+            which: Which::TinyLlama1_1BChat,
+            use_flash_attn: false,
+            no_kv_cache: false,
+            dtype: None,
+        }
+    }
+}
+
+pub struct Load {
+    precision: PrecisionBits,
+    use_flash_attn: bool,
+    no_kv_cache: bool,
+}
+
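+// `ModelTrait` wiring for the candle llama model: `fetch` resolves the chosen repo on
+// the Hugging Face hub and downloads the tokenizer, config and safetensors weights,
+// `load` memory-maps the weights into a `model::Llama` together with its KV `Cache`,
+// and `run` samples tokens autoregressively from the prompt.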
"TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(), + }); + let revision = fetch.revision.clone().unwrap_or("main".to_string()); + let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision)); + api.get("tokenizer.json")?; + let config_filename = api.get("config.json")?; + let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?; + let config = config.into_config(fetch.use_flash_attn); + let filenames = match fetch.which { + Which::V1 | Which::V2 | Which::Solar10_7B => { + hub_load_safetensors(&api, "model.safetensors.index.json")? + } + Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?], + }; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; + model::Llama::load(vb, &config)?; + Ok(()) + } + + fn model_id(&self) -> crate::models::ModelId { + "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string() + } + + fn load(filenames: Vec, cfg: Self::Load) -> Result { + let device = device()?; + let dtype = cfg.precision.into_dtype(); + let (llama, tokenizer_filename, mut cache) = { + let api = Api::new()?; + + let tokenizer_filename = filenames[0]; + let config_filename = filenames[1]; + let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?; + let config = config.into_config(cfg.use_flash_attn); + + let vb = + unsafe { VarBuilder::from_mmaped_safetensors(&filenames[1..], dtype, &device)? }; + let cache = model::Cache::new(!cfg.no_kv_cache, dtype, &config, &device)?; + (model::Llama::load(vb, &config)?, tokenizer_filename, cache) + }; + let tokenizer = Tokenizer::from_file(tokenizer_filename)?; + Ok(Llama { + device, + tokenizer, + dtype, + llama, + cache, + }) + } + + fn run(&mut self, input: Self::Input) -> Result { + let eos_token_id = self.tokenizer.token_to_id(EOS_TOKEN); + let mut tokens = self + .tokenizer + .encode(input.prompt.clone(), true)? + .get_ids() + .to_vec(); + + let mut tokenizer = TokenOutputStream::new(self.tokenizer); + let mut logits_processor = LogitsProcessor::new(input.seed, input.temperature, input.top_p); + let mut index_pos = 0; + let mut res = String::new(); + let mut result = Vec::new(); + for index in 0..input.sample_len { + let (context_size, context_index) = if self.cache.use_kv_cache && index > 0 { + (1, index_pos) + } else { + (tokens.len(), 0) + }; + let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; + let input_tensor = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; + let logits = self + .llama + .forward(&input_tensor, context_index, &mut self.cache)?; + let logits = logits.squeeze(0)?; + let logits = if input.repeat_penalty == 1. { + logits + } else { + let start_at = tokens.len().saturating_sub(input.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + input.repeat_penalty, + &tokens[start_at..], + )? + }; + index_pos += ctxt.len(); + let next_token = logits_processor.sample(&logits)?; + result.push(logits); + tokens.push(next_token); + + if Some(next_token) == eos_token_id { + break; + } + if let Some(t) = tokenizer.next_token(next_token)? { + res += &t; + } + } + if let Some(rest) = tokenizer.decode_rest()? 
+    fn run(&mut self, input: Self::Input) -> Result<Self::Output, ModelError> {
+        let eos_token_id = self.tokenizer.token_to_id(EOS_TOKEN);
+        let mut tokens = self
+            .tokenizer
+            .encode(input.prompt.clone(), true)?
+            .get_ids()
+            .to_vec();
+
+        let mut tokenizer = TokenOutputStream::new(self.tokenizer.clone());
+        let mut logits_processor = LogitsProcessor::new(input.seed, input.temperature, input.top_p);
+        let mut index_pos = 0;
+        let mut res = String::new();
+        let mut result = Vec::new();
+        for index in 0..input.sample_len {
+            let (context_size, context_index) = if self.cache.use_kv_cache && index > 0 {
+                (1, index_pos)
+            } else {
+                (tokens.len(), 0)
+            };
+            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
+            let input_tensor = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+            let logits = self
+                .llama
+                .forward(&input_tensor, context_index, &mut self.cache)?;
+            let logits = logits.squeeze(0)?;
+            let logits = if input.repeat_penalty == 1. {
+                logits
+            } else {
+                let start_at = tokens.len().saturating_sub(input.repeat_last_n);
+                candle_transformers::utils::apply_repeat_penalty(
+                    &logits,
+                    input.repeat_penalty,
+                    &tokens[start_at..],
+                )?
+            };
+            index_pos += ctxt.len();
+            let next_token = logits_processor.sample(&logits)?;
+            result.push(logits);
+            tokens.push(next_token);
+
+            if Some(next_token) == eos_token_id {
+                break;
+            }
+            if let Some(t) = tokenizer.next_token(next_token)? {
+                res += &t;
+            }
+        }
+        if let Some(rest) = tokenizer.decode_rest()? {
+            res += &rest;
+        }
+        println!("Result {}", res);
+        Ok(result)
+    }
+}
diff --git a/atoma-inference/src/candle/mod.rs b/atoma-inference/src/candle/mod.rs
index 8fe9d9b6..7b5da9fe 100644
--- a/atoma-inference/src/candle/mod.rs
+++ b/atoma-inference/src/candle/mod.rs
@@ -1,3 +1,4 @@
+pub mod llama;
 pub mod stable_diffusion;
 pub mod token_output_stream;
 
diff --git a/atoma-inference/src/candle/stable_diffusion.rs b/atoma-inference/src/candle/stable_diffusion.rs
index 013dfe36..1ab13c46 100644
--- a/atoma-inference/src/candle/stable_diffusion.rs
+++ b/atoma-inference/src/candle/stable_diffusion.rs
@@ -114,12 +114,10 @@ pub struct Fetch {
 impl ModelTrait for StableDiffusion {
     type Input = Input;
     type Fetch = Fetch;
+    type Load = ();
     type Output = Vec;
 
-    fn load(
-        _filenames: Vec<PathBuf>,
-        _precision: PrecisionBits,
-    ) -> Result<Self, ModelError>
+    fn load(_filenames: Vec<PathBuf>, _config: Self::Load) -> Result<Self, ModelError>
     where
         Self: Sized,
     {
diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs
index bfb0181d..c623425e 100644
--- a/atoma-inference/src/main.rs
+++ b/atoma-inference/src/main.rs
@@ -14,6 +14,9 @@ use inference::{
 #[tokio::main]
 async fn main() -> Result<(), ModelServiceError> {
     tracing_subscriber::fmt::init();
+    let input = inference::candle::llama::Input::default_prompt("Painting is like".to_string());
+    inference::candle::llama::Llama::fetch(&Default::default()).unwrap();
+    let x = inference::candle::llama::Llama::inference(input).unwrap();
 
     let (req_sender, req_receiver) = tokio::sync::mpsc::channel::(32);
     let (resp_sender, mut resp_receiver) = tokio::sync::mpsc::channel::(32);
diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs
index bdaa7e2d..a1833e17 100644
--- a/atoma-inference/src/models/candle/mamba.rs
+++ b/atoma-inference/src/models/candle/mamba.rs
@@ -50,10 +50,11 @@ impl MambaModel {
 
 impl ModelTrait for MambaModel {
     type Fetch = ();
+    type Load = PrecisionBits;
     type Input = TextModelInput;
     type Output = String;
 
-    fn load(filenames: Vec<PathBuf>, precision: PrecisionBits) -> Result<Self, ModelError>
+    fn load(filenames: Vec<PathBuf>, precision: Self::Load) -> Result<Self, ModelError>
     where
         Self: Sized,
     {
diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs
index 467b2398..e25c1c1e 100644
--- a/atoma-inference/src/models/mod.rs
+++ b/atoma-inference/src/models/mod.rs
@@ -15,13 +15,14 @@ pub type ModelId = String;
 
 pub trait ModelTrait {
     type Fetch;
+    type Load;
     type Input;
     type Output;
 
     fn fetch(_fetch: &Self::Fetch) -> Result<(), ModelError> {
         Ok(())
     }
-    fn load(filenames: Vec<PathBuf>, precision: PrecisionBits) -> Result<Self, ModelError>
+    fn load(filenames: Vec<PathBuf>, config: Self::Load) -> Result<Self, ModelError>
     where
         Self: Sized;
     fn model_id(&self) -> ModelId;
diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs
index 1d13024f..17214f56 100644
--- a/atoma-inference/src/service.rs
+++ b/atoma-inference/src/service.rs
@@ -207,8 +207,14 @@ mod tests {
     impl ModelTrait for TestModelInstance {
         type Input = ();
         type Output = ();
+        type Load = ();
+        type Fetch = ();
 
-        fn load(_: Vec<PathBuf>, _: PrecisionBits) -> Result<Self, crate::models::ModelError> {
+        fn fetch(_fetch: &Self::Fetch) -> Result<(), crate::models::ModelError> {
+            Ok(())
+        }
+
+        fn load(_: Vec<PathBuf>, _: Self::Load) -> Result<Self, crate::models::ModelError> {
             Ok(Self {})
         }