From 5cb119d9a5671197a643dcce390e16337eea6004 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Mon, 1 Apr 2024 10:29:40 +0100 Subject: [PATCH 01/19] first commit --- atoma-inference/src/models/candle/mamba.rs | 26 ++++++++++++++++++++++ atoma-inference/src/models/candle/mod.rs | 1 + atoma-inference/src/models/mod.rs | 2 +- atoma-inference/src/service.rs | 1 - 4 files changed, 28 insertions(+), 2 deletions(-) create mode 100644 atoma-inference/src/models/candle/mamba.rs create mode 100644 atoma-inference/src/models/candle/mod.rs diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs new file mode 100644 index 00000000..a6a730d6 --- /dev/null +++ b/atoma-inference/src/models/candle/mamba.rs @@ -0,0 +1,26 @@ +use std::path::PathBuf; + +use candle_nn::VarBuilder; +use candle_transformers::models::mamba::Model as MambaModel; + +use crate::models::{ModelError, ModelId, ModelTrait}; + +impl ModelTrait for MambaModel { + type Input = String; + type Output = String; + + fn load(filenames: Vec) -> Result + where + Self: Sized { + todo!() + } + + fn model_id(&self) -> ModelId { + todo!() + } + + fn run(&self, input: Self::Input) -> Result { + todo!() + } +} + diff --git a/atoma-inference/src/models/candle/mod.rs b/atoma-inference/src/models/candle/mod.rs new file mode 100644 index 00000000..6857668a --- /dev/null +++ b/atoma-inference/src/models/candle/mod.rs @@ -0,0 +1 @@ +pub mod mamba; \ No newline at end of file diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs index efc3978e..5168cd1b 100644 --- a/atoma-inference/src/models/mod.rs +++ b/atoma-inference/src/models/mod.rs @@ -4,6 +4,7 @@ use ed25519_consensus::VerificationKey as PublicKey; use thiserror::Error; pub mod config; +pub mod candle; pub type ModelId = String; @@ -14,7 +15,6 @@ pub trait ModelBuilder { } pub trait ModelTrait { - type Builder: Send + Sync + 'static; type Input; type Output; diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 5425bdfa..a01bd5f4 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -191,7 +191,6 @@ mod tests { struct TestModelInstance {} impl ModelTrait for TestModelInstance { - type Builder = (); type Input = (); type Output = (); From 05675f73b423ece42cfcab5627f912913c5bdaef Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Mon, 1 Apr 2024 12:21:44 +0100 Subject: [PATCH 02/19] add load logic for Mamba model --- atoma-inference/src/model_thread.rs | 18 ++- atoma-inference/src/models/candle/mamba.rs | 154 +++++++++++++++++++-- atoma-inference/src/models/candle/mod.rs | 2 +- atoma-inference/src/models/config.rs | 10 +- atoma-inference/src/models/mod.rs | 18 ++- atoma-inference/src/service.rs | 6 +- 6 files changed, 180 insertions(+), 28 deletions(-) diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 5e5db91c..f3b58e2e 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -1,4 +1,7 @@ -use std::{collections::HashMap, sync::mpsc}; +use std::{ + collections::HashMap, + sync::{mpsc, Arc}, +}; use ed25519_consensus::VerificationKey as PublicKey; use futures::stream::FuturesUnordered; @@ -112,24 +115,27 @@ where public_key: PublicKey, ) -> Result<(Self, Vec>), ModelThreadError> where - F: ApiTrait, + F: ApiTrait + Send + Sync + 'static, M: ModelTrait + Send + 'static, { let model_ids = config.model_ids(); let api_key = config.api_key(); let storage_path = config.storage_path(); - let api = 
F::create(api_key, storage_path)?; + let api = Arc::new(F::create(api_key, storage_path)?); let mut handles = Vec::with_capacity(model_ids.len()); let mut model_senders = HashMap::with_capacity(model_ids.len()); - for model_id in model_ids { - let filenames = api.fetch(&model_id)?; + for (model_id, precision) in model_ids { + let api = api.clone(); let (model_sender, model_receiver) = mpsc::channel::>(); + let model_name = model_id.clone(); let join_handle = std::thread::spawn(move || { - let model = M::load(filenames)?; // TODO: for now this piece of code cannot be shared among threads safely + let filenames = api.fetch(&model_name)?; + + let model = M::load(filenames, precision)?; let model_thread = ModelThread { model, receiver: model_receiver, diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs index a6a730d6..5d5a5e35 100644 --- a/atoma-inference/src/models/candle/mamba.rs +++ b/atoma-inference/src/models/candle/mamba.rs @@ -1,22 +1,123 @@ -use std::path::PathBuf; +use std::{path::PathBuf, time::Instant}; +use candle::{ + utils::{cuda_is_available, metal_is_available}, + DType, Device, +}; use candle_nn::VarBuilder; -use candle_transformers::models::mamba::Model as MambaModel; +use candle_transformers::{ + generation::LogitsProcessor, + models::mamba::{Config, Model}, +}; +use tokenizers::Tokenizer; +use tracing::info; -use crate::models::{ModelError, ModelId, ModelTrait}; +use crate::{ + models::{ModelError, ModelId, ModelTrait}, + types::PrecisionBits, +}; -impl ModelTrait for MambaModel { - type Input = String; +pub struct MambaModel { + model: Model, + config: Config, + device: Device, + dtype: DType, + tokenizer: Tokenizer, + which: Which, +} + +impl MambaModel { + pub fn new( + model: Model, + config: Config, + device: Device, + dtype: DType, + tokenizer: Tokenizer, + ) -> Self { + let which = Which::from_config(&config); + Self { + model, + config, + device, + dtype, + tokenizer, + which, + } + } +} + +pub struct MambaInput { + prompt: String, + temperature: f64, + random_seed: u64, + repeat_penalty: f32, + repeat_last_n: usize, + top_p: f64, +} + +impl MambaInput { + pub fn new( + prompt: String, + temperature: f64, + random_seed: u64, + repeat_penalty: f32, + repeat_last_n: usize, + top_p: f64, + ) -> Self { + Self { + prompt, + temperature, + random_seed, + repeat_penalty, + repeat_last_n, + top_p, + } + } +} + +impl ModelTrait for MambaModel { + type Input = MambaInput; type Output = String; - fn load(filenames: Vec) -> Result - where - Self: Sized { - todo!() + fn load(filenames: Vec, precision: PrecisionBits) -> Result + where + Self: Sized, + { + info!("Loading Mamba model ..."); + + let start = Instant::now(); + + let tokenizer_filename = filenames[0].clone(); + let config_filename = filenames[1].clone(); + let weights_filenames = filenames[2..].to_vec(); + + let tokenizer = + Tokenizer::from_file(tokenizer_filename).map_err(ModelError::TokenizerError)?; + + let config: Config = + serde_json::from_slice(&std::fs::read(config_filename).map_err(ModelError::IoError)?) + .map_err(ModelError::DeserializeError)?; + let device = if cuda_is_available() { + Device::new_cuda(0).map_err(ModelError::CandleError)? + } else if metal_is_available() { + Device::new_metal(0).map_err(ModelError::CandleError)? + } else { + Device::Cpu + }; + let dtype = precision.into_dtype(); + + let var_builder = unsafe { + VarBuilder::from_mmaped_safetensors(&weights_filenames, dtype, &device) + .map_err(ModelError::CandleError)? 
+ }; + let model = Model::new(&config, var_builder).map_err(ModelError::CandleError)?; + info!("Loaded Mamba model in {:?}", start.elapsed()); + + Ok(Self::new(model, config, device, dtype, tokenizer)) } fn model_id(&self) -> ModelId { - todo!() + self.which.model_id().to_string() } fn run(&self, input: Self::Input) -> Result { @@ -24,3 +125,36 @@ impl ModelTrait for MambaModel { } } +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +enum Which { + Mamba130m, + Mamba370m, + Mamba790m, + Mamba1_4b, + Mamba2_8b, + Mamba2_8bSlimPj, +} + +impl Which { + fn model_id(&self) -> &'static str { + match self { + Self::Mamba130m => "state-spaces/mamba-130m", + Self::Mamba370m => "state-spaces/mamba-370m", + Self::Mamba790m => "state-spaces/mamba-790m", + Self::Mamba1_4b => "state-spaces/mamba-1.4b", + Self::Mamba2_8b => "state-spaces/mamba-2.8b", + Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj'", + } + } + + fn from_config(config: &Config) -> Self { + match config.d_model { + 768 => Self::Mamba130m, + 1024 => Self::Mamba370m, + 1536 => Self::Mamba790m, + 2048 => Self::Mamba1_4b, + 2560 => Self::Mamba2_8b, + _ => panic!("Invalid config d_model value"), + } + } +} diff --git a/atoma-inference/src/models/candle/mod.rs b/atoma-inference/src/models/candle/mod.rs index 6857668a..323f72f5 100644 --- a/atoma-inference/src/models/candle/mod.rs +++ b/atoma-inference/src/models/candle/mod.rs @@ -1 +1 @@ -pub mod mamba; \ No newline at end of file +pub mod mamba; diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs index ff5bf9a6..eb75a984 100644 --- a/atoma-inference/src/models/config.rs +++ b/atoma-inference/src/models/config.rs @@ -3,12 +3,12 @@ use std::path::PathBuf; use config::Config; use serde::{Deserialize, Serialize}; -use crate::models::ModelId; +use crate::{models::ModelId, types::PrecisionBits}; #[derive(Debug, Deserialize, Serialize)] pub struct ModelConfig { api_key: String, - models: Vec, + models: Vec<(ModelId, PrecisionBits)>, storage_path: PathBuf, tracing: bool, } @@ -16,7 +16,7 @@ pub struct ModelConfig { impl ModelConfig { pub fn new( api_key: String, - models: Vec, + models: Vec<(ModelId, PrecisionBits)>, storage_path: PathBuf, tracing: bool, ) -> Self { @@ -32,7 +32,7 @@ impl ModelConfig { self.api_key.clone() } - pub fn model_ids(&self) -> Vec { + pub fn model_ids(&self) -> Vec<(ModelId, PrecisionBits)> { self.models.clone() } @@ -65,7 +65,7 @@ pub mod tests { fn test_config() { let config = ModelConfig::new( String::from("my_key"), - vec!["Llama2_7b".to_string()], + vec![("Llama2_7b".to_string(), PrecisionBits::F16)], "storage_path".parse().unwrap(), true, ); diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs index 5168cd1b..3fb73aa6 100644 --- a/atoma-inference/src/models/mod.rs +++ b/atoma-inference/src/models/mod.rs @@ -1,10 +1,13 @@ use std::path::PathBuf; +use ::candle::Error as CandleError; use ed25519_consensus::VerificationKey as PublicKey; use thiserror::Error; -pub mod config; +use crate::types::PrecisionBits; + pub mod candle; +pub mod config; pub type ModelId = String; @@ -18,7 +21,7 @@ pub trait ModelTrait { type Input; type Output; - fn load(filenames: Vec) -> Result + fn load(filenames: Vec, precision: PrecisionBits) -> Result where Self: Sized; fn model_id(&self) -> ModelId; @@ -41,4 +44,13 @@ pub trait Response: Send + 'static { } #[derive(Debug, Error)] -pub enum ModelError {} +pub enum ModelError { + #[error("Tokenizer error: `{0}`")] + TokenizerError(Box), + #[error("IO error: `{0}`")] + 
IoError(std::io::Error),
+    #[error("Deserialize error: `{0}`")]
+    DeserializeError(serde_json::Error),
+    #[error("Candle error: `{0}`")]
+    CandleError(CandleError),
+}
diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs
index a01bd5f4..559f89a5 100644
--- a/atoma-inference/src/service.rs
+++ b/atoma-inference/src/service.rs
@@ -36,7 +36,7 @@ where
     ) -> Result
     where
         M: ModelTrait + Send + 'static,
-        F: ApiTrait,
+        F: ApiTrait + Send + Sync + 'static,
     {
         let private_key_bytes =
             std::fs::read(private_key_path).map_err(ModelServiceError::PrivateKeyError)?;
@@ -140,7 +140,7 @@ mod tests {
     use std::io::Write;
     use toml::{toml, Value};
 
-    use crate::models::ModelId;
+    use crate::{models::ModelId, types::PrecisionBits};
 
     use super::*;
 
@@ -194,7 +194,7 @@ mod tests {
         type Input = ();
         type Output = ();
 
-        fn load(_: Vec) -> Result {
+        fn load(_: Vec, _: PrecisionBits) -> Result {
             Ok(Self {})
         }
 

From 632f5cc8a0144fd636fa6a580705b63ea2a0b93e Mon Sep 17 00:00:00 2001
From: jorgeantonio21
Date: Mon, 1 Apr 2024 16:15:43 +0100
Subject: [PATCH 03/19] Fix model_thread.rs run method to take ownership of
 self

The `run` method in `model_thread.rs` now takes ownership of `self` by using
`mut self` instead of `self`, ensuring exclusive access to the thread's state
during execution.

Refactor MambaModel to implement run method

The `MambaModel` struct in `mamba.rs` now implements the `run` method of the
`ModelTrait` trait. The method is refactored to handle input parameters and
generate output accordingly. It utilizes the tokenizer to encode the prompt,
generates tokens using the model, and constructs the output string based on
the generated tokens.

Introduce TokenOutputStream

A new module, `token_output_stream`, is introduced in `src/models` to define
the `TokenOutputStream` struct. This struct handles tokenization and decoding
of model outputs.

Update mod.rs to include token_output_stream

The `mod.rs` file in the `models` module is updated to include the
`token_output_stream` module, enabling access to the `TokenOutputStream`
struct.

Update service.rs to use mutable reference in run method

The `run` method of the `TestModelInstance` struct in `service.rs` is updated
to take a mutable reference to `self`, ensuring that it can modify the
internal state if necessary.
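An illustrative sketch only (not code added by this patch) of the streaming
decode loop that `TokenOutputStream` is meant to support, assuming a
`tokenizers::Tokenizer` has already been loaded and the token ids have been
sampled elsewhere:

    use crate::models::{token_output_stream::TokenOutputStream, ModelError};

    // Sketch: decode already-sampled token ids incrementally.
    fn stream_decode(
        tokenizer: tokenizers::Tokenizer,
        generated: &[u32],
    ) -> Result<String, ModelError> {
        let mut stream = TokenOutputStream::new(tokenizer);
        let mut output = String::new();
        for &token in generated {
            // next_token yields text only once the buffered tokens decode to a
            // printable chunk, so output can be surfaced as it is produced.
            if let Some(piece) = stream.next_token(token)? {
                output.push_str(&piece);
            }
        }
        // Flush anything still buffered after the final token.
        if let Some(rest) = stream.decode_rest()? {
            output.push_str(&rest);
        }
        Ok(output)
    }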
--- atoma-inference/src/model_thread.rs | 2 +- atoma-inference/src/models/candle/mamba.rs | 99 +++++++++++++++++-- atoma-inference/src/models/mod.rs | 24 ++++- .../src/models/token_output_stream.rs | 86 ++++++++++++++++ atoma-inference/src/service.rs | 2 +- 5 files changed, 203 insertions(+), 10 deletions(-) create mode 100644 atoma-inference/src/models/token_output_stream.rs diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index f3b58e2e..55a1ec94 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -72,7 +72,7 @@ where Req: Request, Resp: Response, { - pub fn run(self, public_key: PublicKey) -> Result<(), ModelThreadError> { + pub fn run(mut self, public_key: PublicKey) -> Result<(), ModelThreadError> { debug!("Start Model thread"); while let Ok(command) = self.receiver.recv() { diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs index 5d5a5e35..11face7b 100644 --- a/atoma-inference/src/models/candle/mamba.rs +++ b/atoma-inference/src/models/candle/mamba.rs @@ -2,18 +2,20 @@ use std::{path::PathBuf, time::Instant}; use candle::{ utils::{cuda_is_available, metal_is_available}, - DType, Device, + DType, Device, Tensor, }; use candle_nn::VarBuilder; use candle_transformers::{ generation::LogitsProcessor, - models::mamba::{Config, Model}, + models::mamba::{Config, Model, State}, + utils::apply_repeat_penalty, }; use tokenizers::Tokenizer; use tracing::info; use crate::{ - models::{ModelError, ModelId, ModelTrait}, + bail, + models::{token_output_stream::TokenOutputStream, ModelError, ModelId, ModelTrait}, types::PrecisionBits, }; @@ -22,7 +24,7 @@ pub struct MambaModel { config: Config, device: Device, dtype: DType, - tokenizer: Tokenizer, + tokenizer: TokenOutputStream, which: Which, } @@ -40,7 +42,7 @@ impl MambaModel { config, device, dtype, - tokenizer, + tokenizer: TokenOutputStream::new(tokenizer), which, } } @@ -52,6 +54,7 @@ pub struct MambaInput { random_seed: u64, repeat_penalty: f32, repeat_last_n: usize, + max_tokens: usize, top_p: f64, } @@ -62,6 +65,7 @@ impl MambaInput { random_seed: u64, repeat_penalty: f32, repeat_last_n: usize, + max_tokens: usize, top_p: f64, ) -> Self { Self { @@ -70,6 +74,7 @@ impl MambaInput { random_seed, repeat_penalty, repeat_last_n, + max_tokens, top_p, } } @@ -120,8 +125,88 @@ impl ModelTrait for MambaModel { self.which.model_id().to_string() } - fn run(&self, input: Self::Input) -> Result { - todo!() + fn run(&mut self, input: Self::Input) -> Result { + let MambaInput { + prompt, + temperature, + random_seed, + repeat_penalty, + repeat_last_n, + max_tokens, + top_p, + } = input; + + self.tokenizer.clear(); + let mut tokens = self + .tokenizer + .tokenizer() + .encode(prompt, true) + .map_err(ModelError::TokenizerError)? + .get_ids() + .to_vec(); + let mut logits_processor = + LogitsProcessor::new(random_seed, Some(temperature), Some(top_p)); + + let mut generated_tokens = 0_usize; + let eos_token = match self.tokenizer.get_token("<|endoftext|>") { + Some(token) => token, + None => bail!("Invalid eos token"), + }; + + let mut state = State::new(1, &self.config, &self.device)?; // TODO: handle larger batch sizes + + let mut next_logits = None; + for &t in tokens.iter() { + let input = Tensor::new(&[t], &self.device)?; + let logits = self.model.forward(&input, &mut state)?; + next_logits = Some(logits); + if let Some(t) = self.tokenizer.next_token(t)? 
{ + print!("{t}") + } + } + + let mut output = String::new(); + + let start_gen = Instant::now(); + for _ in 0..max_tokens { + let logits = match next_logits.as_ref() { + Some(logits) => logits, + None => bail!("cannot work on an empty prompt"), + }; + + let logits = logits.squeeze(0)?.to_dtype(self.dtype)?; + let logits = if repeat_penalty == 1.0 { + logits + } else { + let start_at = tokens.len().saturating_sub(repeat_last_n); + apply_repeat_penalty(&logits, repeat_penalty, &tokens[start_at..])? + }; + + let next_token = logits_processor.sample(&logits)?; + tokens.push(next_token); + generated_tokens += 1; + + if next_token == eos_token { + break; + } + + if let Some(t) = self.tokenizer.next_token(next_token)? { + output.push_str(t.as_str()); + } + + let input = Tensor::new(&[next_token], &self.device)?; + next_logits = Some(self.model.forward(&input, &mut state)?); + } + let dt = start_gen.elapsed(); + if let Some(rest) = self.tokenizer.decode_rest()? { + output.push_str(rest.as_str()); + } + + println!( + "\n{generated_tokens} tokens generated ({:.2} token/s)", + generated_tokens as f64 / dt.as_secs_f64(), + ); + Ok(output) } } diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs index 3fb73aa6..bad0d10c 100644 --- a/atoma-inference/src/models/mod.rs +++ b/atoma-inference/src/models/mod.rs @@ -8,6 +8,7 @@ use crate::types::PrecisionBits; pub mod candle; pub mod config; +pub mod token_output_stream; pub type ModelId = String; @@ -25,7 +26,7 @@ pub trait ModelTrait { where Self: Sized; fn model_id(&self) -> ModelId; - fn run(&self, input: Self::Input) -> Result; + fn run(&mut self, input: Self::Input) -> Result; } pub trait Request: Send + 'static { @@ -53,4 +54,25 @@ pub enum ModelError { DeserializeError(serde_json::Error), #[error("Candle error: `{0}`")] CandleError(CandleError), + #[error("{0}")] + Msg(String), +} + +impl From for ModelError { + fn from(error: CandleError) -> Self { + Self::CandleError(error) + } +} + +#[macro_export] +macro_rules! bail { + ($msg:literal $(,)?) => { + return Err(ModelError::Msg(format!($msg).into())) + }; + ($err:expr $(,)?) => { + return Err(ModelError::Msg(format!($err).into()).bt()) + }; + ($fmt:expr, $($arg:tt)*) => { + return Err(ModelError::Msg(format!($fmt, $($arg)*).into()).bt()) + }; } diff --git a/atoma-inference/src/models/token_output_stream.rs b/atoma-inference/src/models/token_output_stream.rs new file mode 100644 index 00000000..33bfb27a --- /dev/null +++ b/atoma-inference/src/models/token_output_stream.rs @@ -0,0 +1,86 @@ +use crate::{bail, models::ModelError}; + +/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a +/// streaming way rather than having to wait for the full decoding. 
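+///
+/// Typical usage (illustrative): feed each sampled token id to `next_token`,
+/// append any returned text fragment to the running output, and call
+/// `decode_rest` after the last token to flush whatever is still buffered.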
+pub struct TokenOutputStream { + tokenizer: tokenizers::Tokenizer, + tokens: Vec, + prev_index: usize, + current_index: usize, +} + +impl TokenOutputStream { + pub fn new(tokenizer: tokenizers::Tokenizer) -> Self { + Self { + tokenizer, + tokens: Vec::new(), + prev_index: 0, + current_index: 0, + } + } + + pub fn into_inner(self) -> tokenizers::Tokenizer { + self.tokenizer + } + + fn decode(&self, tokens: &[u32]) -> Result { + match self.tokenizer.decode(tokens, true) { + Ok(str) => Ok(str), + Err(err) => bail!("cannot decode: {err}"), + } + } + + // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68 + pub fn next_token(&mut self, token: u32) -> Result, ModelError> { + let prev_text = if self.tokens.is_empty() { + String::new() + } else { + let tokens = &self.tokens[self.prev_index..self.current_index]; + self.decode(tokens)? + }; + self.tokens.push(token); + let text = self.decode(&self.tokens[self.prev_index..])?; + if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() { + let text = text.split_at(prev_text.len()); + self.prev_index = self.current_index; + self.current_index = self.tokens.len(); + Ok(Some(text.1.to_string())) + } else { + Ok(None) + } + } + + pub fn decode_rest(&self) -> Result, ModelError> { + let prev_text = if self.tokens.is_empty() { + String::new() + } else { + let tokens = &self.tokens[self.prev_index..self.current_index]; + self.decode(tokens)? + }; + let text = self.decode(&self.tokens[self.prev_index..])?; + if text.len() > prev_text.len() { + let text = text.split_at(prev_text.len()); + Ok(Some(text.1.to_string())) + } else { + Ok(None) + } + } + + pub fn decode_all(&self) -> Result { + self.decode(&self.tokens) + } + + pub fn get_token(&self, token_s: &str) -> Option { + self.tokenizer.get_vocab(true).get(token_s).copied() + } + + pub fn tokenizer(&self) -> &tokenizers::Tokenizer { + &self.tokenizer + } + + pub fn clear(&mut self) { + self.tokens.clear(); + self.prev_index = 0; + self.current_index = 0; + } +} diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 559f89a5..b191f0b9 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -202,7 +202,7 @@ mod tests { String::from("") } - fn run(&self, _: Self::Input) -> Result { + fn run(&mut self, _: Self::Input) -> Result { Ok(()) } } From eb1d9fff0bb4ed1c62dc526db238c496d78f8edd Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Mon, 1 Apr 2024 16:16:27 +0100 Subject: [PATCH 04/19] remove Mamba2_b8SlimPj from `Which` --- atoma-inference/src/models/candle/mamba.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs index 11face7b..69dc90d0 100644 --- a/atoma-inference/src/models/candle/mamba.rs +++ b/atoma-inference/src/models/candle/mamba.rs @@ -217,7 +217,7 @@ enum Which { Mamba790m, Mamba1_4b, Mamba2_8b, - Mamba2_8bSlimPj, + // Mamba2_8bSlimPj, TODO: add this } impl Which { @@ -228,7 +228,7 @@ impl Which { Self::Mamba790m => "state-spaces/mamba-790m", Self::Mamba1_4b => "state-spaces/mamba-1.4b", Self::Mamba2_8b => "state-spaces/mamba-2.8b", - Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj'", + // Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj'", } } From 7023ea0c429237b3819d731f837e783408b0a3e9 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Mon, 1 Apr 2024 16:18:37 +0100 Subject: [PATCH 
05/19] correct tests --- atoma-inference/src/models/config.rs | 2 +- atoma-inference/src/service.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs index eb75a984..98222c31 100644 --- a/atoma-inference/src/models/config.rs +++ b/atoma-inference/src/models/config.rs @@ -71,7 +71,7 @@ pub mod tests { ); let toml_str = toml::to_string(&config).unwrap(); - let should_be_toml_str = "api_key = \"my_key\"\nmodels = [\"Llama2_7b\"]\nstorage_path = \"storage_path\"\ntracing = true\n"; + let should_be_toml_str = "api_key = \"my_key\"\nmodels = [[\"Llama2_7b\", \"F16\"]]\nstorage_path = \"storage_path\"\ntracing = true\n"; assert_eq!(toml_str, should_be_toml_str); } } diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index b191f0b9..919a04f0 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -217,7 +217,7 @@ mod tests { let config_data = Value::Table(toml! { api_key = "your_api_key" - models = ["Mamba3b"] + models = [["Mamba3b", "F16"]] storage_path = "./storage_path/" tokenizer_file_path = "./tokenizer_file_path/" tracing = true From ff500a9443de4805e9793beb913d87c6b8663377 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Mon, 1 Apr 2024 17:45:37 +0100 Subject: [PATCH 06/19] refactor model fetching logic --- atoma-inference/src/apis/hugging_face.rs | 103 +++----------------- atoma-inference/src/apis/mod.rs | 2 +- atoma-inference/src/main.rs | 58 ++++++++++-- atoma-inference/src/model_thread.rs | 4 +- atoma-inference/src/models/candle/mamba.rs | 39 +------- atoma-inference/src/models/candle/mod.rs | 1 + atoma-inference/src/models/candle/types.rs | 104 +++++++++++++++++++++ atoma-inference/src/models/config.rs | 18 +++- atoma-inference/src/service.rs | 32 ++++++- 9 files changed, 213 insertions(+), 148 deletions(-) create mode 100644 atoma-inference/src/models/candle/types.rs diff --git a/atoma-inference/src/apis/hugging_face.rs b/atoma-inference/src/apis/hugging_face.rs index bc80099a..8c5ab2a6 100644 --- a/atoma-inference/src/apis/hugging_face.rs +++ b/atoma-inference/src/apis/hugging_face.rs @@ -1,95 +1,15 @@ use std::path::PathBuf; use async_trait::async_trait; -use hf_hub::api::sync::{Api, ApiBuilder}; +use hf_hub::{ + api::sync::{Api, ApiBuilder}, + Repo, RepoType, +}; use crate::models::ModelId; use super::{ApiError, ApiTrait}; -struct FilePaths { - file_paths: Vec, -} - -fn get_model_safe_tensors_from_hf(model_id: &ModelId) -> (String, FilePaths) { - match model_id.as_str() { - "Llama2_7b" => ( - String::from("meta-llama/Llama-2-7b-hf"), - FilePaths { - file_paths: vec![ - "model-00001-of-00002.safetensors".to_string(), - "model-00002-of-00002.safetensors".to_string(), - ], - }, - ), - "Mamba3b" => ( - String::from("state-spaces/mamba-2.8b-hf"), - FilePaths { - file_paths: vec![ - "model-00001-of-00003.safetensors".to_string(), - "model-00002-of-00003.safetensors".to_string(), - "model-00003-of-00003.safetensors".to_string(), - ], - }, - ), - "Mistral7b" => ( - String::from("mistralai/Mistral-7B-Instruct-v0.2"), - FilePaths { - file_paths: vec![ - "model-00001-of-00003.safetensors".to_string(), - "model-00002-of-00003.safetensors".to_string(), - "model-00003-of-00003.safetensors".to_string(), - ], - }, - ), - "Mixtral8x7b" => ( - String::from("mistralai/Mixtral-8x7B-Instruct-v0.1"), - FilePaths { - file_paths: vec![ - "model-00001-of-00019.safetensors".to_string(), - "model-00002-of-00019.safetensors".to_string(), - 
"model-00003-of-00019.safetensors".to_string(), - "model-00004-of-00019.safetensors".to_string(), - "model-00005-of-00019.safetensors".to_string(), - "model-00006-of-00019.safetensors".to_string(), - "model-00007-of-00019.safetensors".to_string(), - "model-00008-of-00019.safetensors".to_string(), - "model-00009-of-00019.safetensors".to_string(), - "model-000010-of-00019.safetensors".to_string(), - "model-000011-of-00019.safetensors".to_string(), - "model-000012-of-00019.safetensors".to_string(), - "model-000013-of-00019.safetensors".to_string(), - "model-000014-of-00019.safetensors".to_string(), - "model-000015-of-00019.safetensors".to_string(), - "model-000016-of-00019.safetensors".to_string(), - "model-000017-of-00019.safetensors".to_string(), - "model-000018-of-00019.safetensors".to_string(), - "model-000019-of-00019.safetensors".to_string(), - ], - }, - ), - "StableDiffusion2" => ( - String::from("stabilityai/stable-diffusion-2"), - FilePaths { - file_paths: vec!["768-v-ema.safetensors".to_string()], - }, - ), - "StableDiffusionXl" => ( - String::from("stabilityai/stable-diffusion-xl-base-1.0"), - FilePaths { - file_paths: vec![ - "sd_xl_base_1.0.safetensors".to_string(), - "sd_xl_base_1.0_0.9vae.safetensors".to_string(), - "sd_xl_offset_example-lora_1.0.safetensors".to_string(), - ], - }, - ), - _ => { - panic!("Invalid model id") - } - } -} - #[async_trait] impl ApiTrait for Api { fn create(api_key: String, cache_dir: PathBuf) -> Result @@ -103,15 +23,14 @@ impl ApiTrait for Api { .build()?) } - fn fetch(&self, model_id: &ModelId) -> Result, ApiError> { - let (model_path, files) = get_model_safe_tensors_from_hf(model_id); - let api_repo = self.model(model_path); - let mut path_bufs = Vec::with_capacity(files.file_paths.len()); + fn fetch(&self, model_id: ModelId, revision: String) -> Result, ApiError> { + let repo = self.repo(Repo::with_revision(model_id, RepoType::Model, revision)); + let mut filenames = Vec::with_capacity(3); - for file in files.file_paths { - path_bufs.push(api_repo.get(&file)?); - } + filenames.push(repo.get("tokenizer.json")?); + filenames.push(repo.get("config.json")?); + filenames.push(repo.get("model.safetensors")?); - Ok(path_bufs) + Ok(filenames) } } diff --git a/atoma-inference/src/apis/mod.rs b/atoma-inference/src/apis/mod.rs index e6d27941..c2b4ba16 100644 --- a/atoma-inference/src/apis/mod.rs +++ b/atoma-inference/src/apis/mod.rs @@ -22,7 +22,7 @@ impl From for ApiError { } pub trait ApiTrait: Send { - fn fetch(&self, model_id: &ModelId) -> Result, ApiError>; + fn fetch(&self, model_id: ModelId, revision: String) -> Result, ApiError>; fn create(api_key: String, cache_dir: PathBuf) -> Result where Self: Sized; diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index 1a614417..1e1ac920 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -1,16 +1,54 @@ -// use hf_hub::api::sync::Api; -// use inference::service::ModelService; +use hf_hub::api::sync::Api; +use inference::{ + models::candle::{ + mamba::MambaModel, + types::{TextRequest, TextResponse}, + }, + service::{ModelService, ModelServiceError}, +}; #[tokio::main] -async fn main() { +async fn main() -> Result<(), ModelServiceError> { tracing_subscriber::fmt::init(); - // let (_, receiver) = tokio::sync::mpsc::channel(32); + let (req_sender, req_receiver) = tokio::sync::mpsc::channel::(32); + let (resp_sender, mut resp_receiver) = tokio::sync::mpsc::channel::(32); - // let _ = ModelService::start::( - // "../inference.toml".parse().unwrap(), - // 
"../private_key".parse().unwrap(), - // receiver, - // ) - // .expect("Failed to start inference service"); + let mut service = ModelService::start::( + "../inference.toml".parse().unwrap(), + "../private_key".parse().unwrap(), + req_receiver, + resp_sender, + ) + .expect("Failed to start inference service"); + + tokio::spawn(async move { + service.run().await?; + Ok::<(), ModelServiceError>(()) + }); + + let sampled_nodes = vec![]; + + req_sender + .send(TextRequest { + request_id: 0, + prompt: "Natalie Portman".to_string(), + model: "state-spaces/mamba-2.8b".to_string(), + max_tokens: 512, + temperature: Some(0.6), + random_seed: 42, + repeat_last_n: 15, + repeat_penalty: 0.6, + sampled_nodes, + top_p: Some(1.0), + top_k: 10, + }) + .await + .expect("Failed to send request"); + + if let Some(response) = resp_receiver.recv().await { + println!("Got a response: {:?}", response); + } + + Ok(()) } diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 55a1ec94..5cd6f511 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -126,14 +126,14 @@ where let mut handles = Vec::with_capacity(model_ids.len()); let mut model_senders = HashMap::with_capacity(model_ids.len()); - for (model_id, precision) in model_ids { + for (model_id, precision, revision) in model_ids { let api = api.clone(); let (model_sender, model_receiver) = mpsc::channel::>(); let model_name = model_id.clone(); let join_handle = std::thread::spawn(move || { - let filenames = api.fetch(&model_name)?; + let filenames = api.fetch(model_name, revision)?; let model = M::load(filenames, precision)?; let model_thread = ModelThread { diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs index 69dc90d0..1f05bf7f 100644 --- a/atoma-inference/src/models/candle/mamba.rs +++ b/atoma-inference/src/models/candle/mamba.rs @@ -19,6 +19,8 @@ use crate::{ types::PrecisionBits, }; +use super::types::TextModelInput; + pub struct MambaModel { model: Model, config: Config, @@ -48,40 +50,8 @@ impl MambaModel { } } -pub struct MambaInput { - prompt: String, - temperature: f64, - random_seed: u64, - repeat_penalty: f32, - repeat_last_n: usize, - max_tokens: usize, - top_p: f64, -} - -impl MambaInput { - pub fn new( - prompt: String, - temperature: f64, - random_seed: u64, - repeat_penalty: f32, - repeat_last_n: usize, - max_tokens: usize, - top_p: f64, - ) -> Self { - Self { - prompt, - temperature, - random_seed, - repeat_penalty, - repeat_last_n, - max_tokens, - top_p, - } - } -} - impl ModelTrait for MambaModel { - type Input = MambaInput; + type Input = TextModelInput; type Output = String; fn load(filenames: Vec, precision: PrecisionBits) -> Result @@ -126,7 +96,7 @@ impl ModelTrait for MambaModel { } fn run(&mut self, input: Self::Input) -> Result { - let MambaInput { + let TextModelInput { prompt, temperature, random_seed, @@ -134,6 +104,7 @@ impl ModelTrait for MambaModel { repeat_last_n, max_tokens, top_p, + .. 
} = input; self.tokenizer.clear(); diff --git a/atoma-inference/src/models/candle/mod.rs b/atoma-inference/src/models/candle/mod.rs index 323f72f5..724d7103 100644 --- a/atoma-inference/src/models/candle/mod.rs +++ b/atoma-inference/src/models/candle/mod.rs @@ -1 +1,2 @@ pub mod mamba; +pub mod types; diff --git a/atoma-inference/src/models/candle/types.rs b/atoma-inference/src/models/candle/types.rs new file mode 100644 index 00000000..d49bfd61 --- /dev/null +++ b/atoma-inference/src/models/candle/types.rs @@ -0,0 +1,104 @@ +use ed25519_consensus::VerificationKey as PublicKey; +use serde::{Deserialize, Serialize}; + +use crate::models::{ModelId, Request, Response}; + +pub type NodeId = PublicKey; + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct TextRequest { + pub request_id: usize, + pub prompt: String, + pub model: ModelId, + pub max_tokens: usize, + pub random_seed: usize, + pub repeat_last_n: usize, + pub repeat_penalty: f32, + pub sampled_nodes: Vec, + pub temperature: Option, + pub top_k: usize, + pub top_p: Option, +} + +impl Request for TextRequest { + type ModelInput = TextModelInput; + + fn into_model_input(self) -> Self::ModelInput { + TextModelInput::new( + self.prompt, + self.temperature.unwrap_or_default() as f64, + self.random_seed as u64, + self.repeat_penalty, + self.repeat_last_n, + self.max_tokens, + self.top_k, + self.top_p.unwrap_or_default() as f64, + ) + } + + fn request_id(&self) -> usize { + self.request_id + } + + fn is_node_authorized(&self, public_key: &PublicKey) -> bool { + self.sampled_nodes.contains(&public_key) + } + + fn requested_model(&self) -> ModelId { + self.model.clone() + } +} + +pub struct TextModelInput { + pub(crate) prompt: String, + pub(crate) temperature: f64, + pub(crate) random_seed: u64, + pub(crate) repeat_penalty: f32, + pub(crate) repeat_last_n: usize, + pub(crate) max_tokens: usize, + pub(crate) _top_k: usize, + pub(crate) top_p: f64, +} + +impl TextModelInput { + pub fn new( + prompt: String, + temperature: f64, + random_seed: u64, + repeat_penalty: f32, + repeat_last_n: usize, + max_tokens: usize, + _top_k: usize, + top_p: f64, + ) -> Self { + Self { + prompt, + temperature, + random_seed, + repeat_penalty, + repeat_last_n, + max_tokens, + _top_k, + top_p, + } + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct TextResponse { + pub output: String, + pub is_success: bool, + pub status: String, +} + +impl Response for TextResponse { + type ModelOutput = String; + + fn from_model_output(model_output: Self::ModelOutput) -> Self { + Self { + output: model_output, + is_success: true, + status: "Successful".to_string(), + } + } +} diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs index 98222c31..b5002c36 100644 --- a/atoma-inference/src/models/config.rs +++ b/atoma-inference/src/models/config.rs @@ -5,10 +5,13 @@ use serde::{Deserialize, Serialize}; use crate::{models::ModelId, types::PrecisionBits}; +type Revision = String; + #[derive(Debug, Deserialize, Serialize)] pub struct ModelConfig { api_key: String, - models: Vec<(ModelId, PrecisionBits)>, + flush_storage: bool, + models: Vec<(ModelId, PrecisionBits, Revision)>, storage_path: PathBuf, tracing: bool, } @@ -16,12 +19,14 @@ pub struct ModelConfig { impl ModelConfig { pub fn new( api_key: String, - models: Vec<(ModelId, PrecisionBits)>, + flush_storage: bool, + models: Vec<(ModelId, PrecisionBits, Revision)>, storage_path: PathBuf, tracing: bool, ) -> Self { Self { api_key, + flush_storage, models, storage_path, 
tracing, @@ -32,7 +37,11 @@ impl ModelConfig { self.api_key.clone() } - pub fn model_ids(&self) -> Vec<(ModelId, PrecisionBits)> { + pub fn flush_storage(&self) -> bool { + self.flush_storage + } + + pub fn model_ids(&self) -> Vec<(ModelId, PrecisionBits, Revision)> { self.models.clone() } @@ -65,7 +74,8 @@ pub mod tests { fn test_config() { let config = ModelConfig::new( String::from("my_key"), - vec![("Llama2_7b".to_string(), PrecisionBits::F16)], + true, + vec![("Llama2_7b".to_string(), PrecisionBits::F16, "".to_string())], "storage_path".parse().unwrap(), true, ); diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 919a04f0..480f95a0 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -2,7 +2,7 @@ use candle::Error as CandleError; use ed25519_consensus::SigningKey as PrivateKey; use futures::StreamExt; use std::{io, path::PathBuf, time::Instant}; -use tokio::sync::mpsc::Receiver; +use tokio::sync::mpsc::{Receiver, Sender}; use tracing::{error, info}; use thiserror::Error; @@ -21,7 +21,10 @@ where model_thread_handle: Vec>, dispatcher: ModelThreadDispatcher, start_time: Instant, + flush_storage: bool, + storage_path: PathBuf, request_receiver: Receiver, + response_sender: Sender, } impl ModelService @@ -33,6 +36,7 @@ where config_file_path: PathBuf, private_key_path: PathBuf, request_receiver: Receiver, + response_sender: Sender, ) -> Result where M: ModelTrait + Send + 'static, @@ -48,6 +52,9 @@ where let public_key = private_key.verification_key(); let model_config = ModelConfig::from_file_path(config_file_path); + let flush_storage = model_config.flush_storage(); + let storage_path = model_config.storage_path(); + let (dispatcher, model_thread_handle) = ModelThreadDispatcher::start::(model_config, public_key) .map_err(ModelServiceError::ModelThreadError)?; @@ -57,7 +64,10 @@ where dispatcher, model_thread_handle, start_time, + flush_storage, + storage_path, request_receiver, + response_sender, }) } @@ -74,6 +84,7 @@ where match resp { Ok(response) => { info!("Received a new inference response: {:?}", response); + self.response_sender.send(response).await.map_err(|e| ModelServiceError::SendError(e.to_string()))?; } Err(e) => { error!("Found error in generating inference response: {e}"); @@ -97,6 +108,13 @@ where self.start_time.elapsed() ); + if self.flush_storage { + match std::fs::remove_dir(self.storage_path) { + Ok(()) => {} + Err(e) => error!("Failed to remove storage folder, on shutdown: {e}"), + }; + } + let _ = self .model_thread_handle .drain(..) 
@@ -108,7 +126,7 @@ where #[derive(Debug, Error)] pub enum ModelServiceError { #[error("Failed to run inference: `{0}`")] - FailedInference(Box), + FailedInference(Box), #[error("Failed to fecth model: `{0}`")] FailedModelFetch(String), #[error("Failed to generate private key: `{0}`")] @@ -119,6 +137,8 @@ pub enum ModelServiceError { ApiError(ApiError), #[error("Candle error: `{0}`")] CandleError(CandleError), + #[error("Sender error: `{0}`")] + SendError(String), } impl From for ModelServiceError { @@ -154,7 +174,7 @@ mod tests { Ok(Self {}) } - fn fetch(&self, _: &ModelId) -> Result, ApiError> { + fn fetch(&self, _: ModelId, _: String) -> Result, ApiError> { Ok(vec![]) } } @@ -229,12 +249,14 @@ mod tests { file.write_all(toml_string.as_bytes()) .expect("Failed to write to file"); - let (_, receiver) = tokio::sync::mpsc::channel::<()>(1); + let (_, req_receiver) = tokio::sync::mpsc::channel::<()>(1); + let (resp_sender, _) = tokio::sync::mpsc::channel::<()>(1); let _ = ModelService::<(), ()>::start::( PathBuf::try_from(CONFIG_FILE_PATH).unwrap(), PathBuf::try_from(PRIVATE_KEY_FILE_PATH).unwrap(), - receiver, + req_receiver, + resp_sender, ) .unwrap(); From b01482c76b91faef5c383058ebe313eb55dcb200 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Mon, 1 Apr 2024 17:51:22 +0100 Subject: [PATCH 07/19] cargo clippy --- atoma-inference/src/apis/hugging_face.rs | 10 +++++----- atoma-inference/src/models/candle/types.rs | 3 ++- atoma-inference/src/models/config.rs | 2 +- atoma-inference/src/service.rs | 3 ++- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/atoma-inference/src/apis/hugging_face.rs b/atoma-inference/src/apis/hugging_face.rs index 8c5ab2a6..3d53724e 100644 --- a/atoma-inference/src/apis/hugging_face.rs +++ b/atoma-inference/src/apis/hugging_face.rs @@ -25,11 +25,11 @@ impl ApiTrait for Api { fn fetch(&self, model_id: ModelId, revision: String) -> Result, ApiError> { let repo = self.repo(Repo::with_revision(model_id, RepoType::Model, revision)); - let mut filenames = Vec::with_capacity(3); - - filenames.push(repo.get("tokenizer.json")?); - filenames.push(repo.get("config.json")?); - filenames.push(repo.get("model.safetensors")?); + let filenames = vec![ + repo.get("tokenizer.json")?, + repo.get("config.json")?, + repo.get("model.safetensors")?, + ]; Ok(filenames) } diff --git a/atoma-inference/src/models/candle/types.rs b/atoma-inference/src/models/candle/types.rs index d49bfd61..983a0ba5 100644 --- a/atoma-inference/src/models/candle/types.rs +++ b/atoma-inference/src/models/candle/types.rs @@ -41,7 +41,7 @@ impl Request for TextRequest { } fn is_node_authorized(&self, public_key: &PublicKey) -> bool { - self.sampled_nodes.contains(&public_key) + self.sampled_nodes.contains(public_key) } fn requested_model(&self) -> ModelId { @@ -61,6 +61,7 @@ pub struct TextModelInput { } impl TextModelInput { + #[allow(clippy::too_many_arguments)] pub fn new( prompt: String, temperature: f64, diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs index b5002c36..e015b861 100644 --- a/atoma-inference/src/models/config.rs +++ b/atoma-inference/src/models/config.rs @@ -81,7 +81,7 @@ pub mod tests { ); let toml_str = toml::to_string(&config).unwrap(); - let should_be_toml_str = "api_key = \"my_key\"\nmodels = [[\"Llama2_7b\", \"F16\"]]\nstorage_path = \"storage_path\"\ntracing = true\n"; + let should_be_toml_str = "api_key = \"my_key\"\nflush_storage = true\nmodels = [[\"Llama2_7b\", \"F16\", \"\"]]\nstorage_path = \"storage_path\"\ntracing = 
true\n"; assert_eq!(toml_str, should_be_toml_str); } } diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 480f95a0..84915f11 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -237,9 +237,10 @@ mod tests { let config_data = Value::Table(toml! { api_key = "your_api_key" - models = [["Mamba3b", "F16"]] + models = [["Mamba3b", "F16", ""]] storage_path = "./storage_path/" tokenizer_file_path = "./tokenizer_file_path/" + flush_storage = true tracing = true }); let toml_string = From 76ff68e1523fc8b5aa094215d53d5986d953cd46 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Mon, 1 Apr 2024 21:34:34 +0100 Subject: [PATCH 08/19] add consistency with candle github repo --- Cargo.toml | 6 ++--- atoma-inference/Cargo.toml | 2 +- atoma-inference/src/apis/hugging_face.rs | 26 +++++++++++++++++----- atoma-inference/src/main.rs | 18 +++++++++------ atoma-inference/src/model_thread.rs | 4 +++- atoma-inference/src/models/candle/mamba.rs | 26 +++++++++++++--------- atoma-inference/src/service.rs | 10 +++++++-- 7 files changed, 62 insertions(+), 30 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ba5f2c71..f712f8d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,9 +8,9 @@ version = "0.1.0" [workspace.dependencies] async-trait = "0.1.78" -candle = { git = "https://github.com/jorgeantonio21/candle/", package = "candle-core", branch = "main" } -candle-nn = { git = "https://github.com/jorgeantonio21/candle/", package = "candle-nn", branch = "main" } -candle-transformers = { git = "https://github.com/jorgeantonio21/candle/", package = "candle-transformers", branch = "main" } +candle = { git = "https://github.com/huggingface/candle", package = "candle-core", branch = "main" } +candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", branch = "main" } +candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", branch = "main" } config = "0.14.0" ed25519-consensus = "2.1.0" futures = "0.3.30" diff --git a/atoma-inference/Cargo.toml b/atoma-inference/Cargo.toml index f05bd27b..58bcd5aa 100644 --- a/atoma-inference/Cargo.toml +++ b/atoma-inference/Cargo.toml @@ -10,7 +10,7 @@ async-trait.workspace = true candle.workspace = true candle-nn.workspace = true candle-transformers.workspace = true -config.true = true +config.workspace = true ed25519-consensus.workspace = true futures.workspace = true hf-hub.workspace = true diff --git a/atoma-inference/src/apis/hugging_face.rs b/atoma-inference/src/apis/hugging_face.rs index 3d53724e..69ac9efc 100644 --- a/atoma-inference/src/apis/hugging_face.rs +++ b/atoma-inference/src/apis/hugging_face.rs @@ -5,6 +5,7 @@ use hf_hub::{ api::sync::{Api, ApiBuilder}, Repo, RepoType, }; +use tracing::error; use crate::models::ModelId; @@ -24,13 +25,28 @@ impl ApiTrait for Api { } fn fetch(&self, model_id: ModelId, revision: String) -> Result, ApiError> { + let mut tokenizer_file = None; + if model_id.contains("mamba") { + tokenizer_file = Some( + self.model("EleutherAI/gpt-neox-20b".to_string()) + .get("tokenizer.json") + .map_err(|e| { + error!("Failed to fetch tokenizer file: {e}"); + e + })?, + ) + } + let repo = self.repo(Repo::with_revision(model_id, RepoType::Model, revision)); - let filenames = vec![ - repo.get("tokenizer.json")?, + + Ok(vec![ repo.get("config.json")?, + if let Some(tkn) = tokenizer_file { + tkn + } else { + repo.get("tokenizer.json")? 
+ }, repo.get("model.safetensors")?, - ]; - - Ok(filenames) + ]) } } diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index 1e1ac920..251951ce 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -1,3 +1,5 @@ +use std::time::Duration; + use hf_hub::api::sync::Api; use inference::{ models::candle::{ @@ -22,24 +24,26 @@ async fn main() -> Result<(), ModelServiceError> { ) .expect("Failed to start inference service"); + let pk = service.public_key(); + tokio::spawn(async move { service.run().await?; Ok::<(), ModelServiceError>(()) }); - let sampled_nodes = vec![]; + tokio::time::sleep(Duration::from_millis(5000)).await; req_sender .send(TextRequest { request_id: 0, - prompt: "Natalie Portman".to_string(), - model: "state-spaces/mamba-2.8b".to_string(), + prompt: "Who was the first american president ? ".to_string(), + model: "state-spaces/mamba-130m".to_string(), max_tokens: 512, - temperature: Some(0.6), + temperature: Some(0.0), random_seed: 42, - repeat_last_n: 15, - repeat_penalty: 0.6, - sampled_nodes, + repeat_last_n: 64, + repeat_penalty: 1.1, + sampled_nodes: vec![pk], top_p: Some(1.0), top_k: 10, }) diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 5cd6f511..8fe01b5e 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -7,7 +7,7 @@ use ed25519_consensus::VerificationKey as PublicKey; use futures::stream::FuturesUnordered; use thiserror::Error; use tokio::sync::oneshot::{self, error::RecvError}; -use tracing::{debug, error, warn}; +use tracing::{debug, error, info, warn}; use crate::{ apis::{ApiError, ApiTrait}, @@ -127,12 +127,14 @@ where let mut model_senders = HashMap::with_capacity(model_ids.len()); for (model_id, precision, revision) in model_ids { + info!("Spawning new thread for model: {model_id}"); let api = api.clone(); let (model_sender, model_receiver) = mpsc::channel::>(); let model_name = model_id.clone(); let join_handle = std::thread::spawn(move || { + info!("Fetching files for model: {model_name}"); let filenames = api.fetch(model_name, revision)?; let model = M::load(filenames, precision)?; diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs index 1f05bf7f..151a721f 100644 --- a/atoma-inference/src/models/candle/mamba.rs +++ b/atoma-inference/src/models/candle/mamba.rs @@ -62,8 +62,8 @@ impl ModelTrait for MambaModel { let start = Instant::now(); - let tokenizer_filename = filenames[0].clone(); - let config_filename = filenames[1].clone(); + let config_filename = filenames[0].clone(); + let tokenizer_filename = filenames[1].clone(); let weights_filenames = filenames[2..].to_vec(); let tokenizer = @@ -81,11 +81,10 @@ impl ModelTrait for MambaModel { }; let dtype = precision.into_dtype(); - let var_builder = unsafe { - VarBuilder::from_mmaped_safetensors(&weights_filenames, dtype, &device) - .map_err(ModelError::CandleError)? - }; - let model = Model::new(&config, var_builder).map_err(ModelError::CandleError)?; + info!("Loading model weights.."); + let var_builder = + unsafe { VarBuilder::from_mmaped_safetensors(&weights_filenames, dtype, &device)? }; + let model = Model::new(&config, var_builder.pp("backbone"))?; info!("Loaded Mamba model in {:?}", start.elapsed()); Ok(Self::new(model, config, device, dtype, tokenizer)) @@ -107,6 +106,8 @@ impl ModelTrait for MambaModel { .. 
} = input; + info!("Running inference on prompt: {:?}", prompt); + self.tokenizer.clear(); let mut tokens = self .tokenizer @@ -127,17 +128,19 @@ impl ModelTrait for MambaModel { let mut state = State::new(1, &self.config, &self.device)?; // TODO: handle larger batch sizes let mut next_logits = None; + let mut output = String::new(); + for &t in tokens.iter() { let input = Tensor::new(&[t], &self.device)?; let logits = self.model.forward(&input, &mut state)?; + next_logits = Some(logits); if let Some(t) = self.tokenizer.next_token(t)? { - print!("{t}") + info!("{:?}", t); + output.push_str(t.as_str()); } } - let mut output = String::new(); - let start_gen = Instant::now(); for _ in 0..max_tokens { let logits = match next_logits.as_ref() { @@ -162,6 +165,7 @@ impl ModelTrait for MambaModel { } if let Some(t) = self.tokenizer.next_token(next_token)? { + info!("{:?}", t); output.push_str(t.as_str()); } @@ -173,7 +177,7 @@ impl ModelTrait for MambaModel { output.push_str(rest.as_str()); } - println!( + info!( "\n{generated_tokens} tokens generated ({:.2} token/s)", generated_tokens as f64 / dt.as_secs_f64(), ); diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 84915f11..f6e986d1 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -1,5 +1,5 @@ use candle::Error as CandleError; -use ed25519_consensus::SigningKey as PrivateKey; +use ed25519_consensus::{SigningKey as PrivateKey, VerificationKey as PublicKey}; use futures::StreamExt; use std::{io, path::PathBuf, time::Instant}; use tokio::sync::mpsc::{Receiver, Sender}; @@ -22,6 +22,7 @@ where dispatcher: ModelThreadDispatcher, start_time: Instant, flush_storage: bool, + public_key: PublicKey, storage_path: PathBuf, request_receiver: Receiver, response_sender: Sender, @@ -66,6 +67,7 @@ where start_time, flush_storage, storage_path, + public_key, request_receiver, response_sender, }) @@ -95,6 +97,10 @@ where } } } + + pub fn public_key(&self) -> PublicKey { + self.public_key + } } impl ModelService @@ -237,7 +243,7 @@ mod tests { let config_data = Value::Table(toml! { api_key = "your_api_key" - models = [["Mamba3b", "F16", ""]] + models = [["Mamba3b", "F16", "", ""]] storage_path = "./storage_path/" tokenizer_file_path = "./tokenizer_file_path/" flush_storage = true From c4c8c9e7c25f03beed72d414b19909aba05bcc0e Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Mon, 1 Apr 2024 21:41:53 +0100 Subject: [PATCH 09/19] minor refactor --- atoma-inference/src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index 251951ce..1049a147 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -36,7 +36,7 @@ async fn main() -> Result<(), ModelServiceError> { req_sender .send(TextRequest { request_id: 0, - prompt: "Who was the first american president ? 
".to_string(), + prompt: "Who was the first american president ?".to_string(), model: "state-spaces/mamba-130m".to_string(), max_tokens: 512, temperature: Some(0.0), From 9647425e92294bdf33fdabcaad10a83081bda08f Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Mon, 1 Apr 2024 22:02:41 +0100 Subject: [PATCH 10/19] rollback candle version 0.4.2 --- Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f712f8d7..44602d2d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,9 +8,9 @@ version = "0.1.0" [workspace.dependencies] async-trait = "0.1.78" -candle = { git = "https://github.com/huggingface/candle", package = "candle-core", branch = "main" } -candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", branch = "main" } -candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", branch = "main" } +candle = { git = "https://github.com/huggingface/candle", package = "candle-core", version = "0.4.2" } +candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", version = "0.4.2" } +candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", version = "0.4.2" } config = "0.14.0" ed25519-consensus = "2.1.0" futures = "0.3.30" From 0bc9acc97f488d2dc7c1694e78d682fe150df0b9 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Mon, 1 Apr 2024 23:54:36 +0100 Subject: [PATCH 11/19] remove unused trait --- atoma-inference/src/models/mod.rs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs index bad0d10c..6366478a 100644 --- a/atoma-inference/src/models/mod.rs +++ b/atoma-inference/src/models/mod.rs @@ -12,12 +12,6 @@ pub mod token_output_stream; pub type ModelId = String; -pub trait ModelBuilder { - fn try_from_file(path: PathBuf) -> Result - where - Self: Sized; -} - pub trait ModelTrait { type Input; type Output; From 2f65f9849278be775091bf98599feddb7463f99d Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Tue, 2 Apr 2024 00:02:24 +0100 Subject: [PATCH 12/19] move file types.rs --- atoma-inference/src/lib.rs | 1 - atoma-inference/src/main.rs | 4 +- atoma-inference/src/models/candle/mamba.rs | 4 +- atoma-inference/src/models/candle/mod.rs | 1 - atoma-inference/src/models/config.rs | 2 +- atoma-inference/src/models/mod.rs | 3 +- .../src/models/{candle => }/types.rs | 27 ++++++++ atoma-inference/src/service.rs | 2 +- atoma-inference/src/types.rs | 65 ------------------- 9 files changed, 34 insertions(+), 75 deletions(-) rename atoma-inference/src/models/{candle => }/types.rs (81%) delete mode 100644 atoma-inference/src/types.rs diff --git a/atoma-inference/src/lib.rs b/atoma-inference/src/lib.rs index 539230f0..4ad5a4d4 100644 --- a/atoma-inference/src/lib.rs +++ b/atoma-inference/src/lib.rs @@ -1,7 +1,6 @@ pub mod model_thread; pub mod service; pub mod specs; -pub mod types; pub mod apis; pub mod models; diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index 1049a147..7ebe1835 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -2,8 +2,8 @@ use std::time::Duration; use hf_hub::api::sync::Api; use inference::{ - models::candle::{ - mamba::MambaModel, + models::{ + candle::mamba::MambaModel, types::{TextRequest, TextResponse}, }, service::{ModelService, ModelServiceError}, diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs index 
151a721f..1950d4ba 100644 --- a/atoma-inference/src/models/candle/mamba.rs +++ b/atoma-inference/src/models/candle/mamba.rs @@ -15,12 +15,10 @@ use tracing::info; use crate::{ bail, + models::types::{PrecisionBits, TextModelInput}, models::{token_output_stream::TokenOutputStream, ModelError, ModelId, ModelTrait}, - types::PrecisionBits, }; -use super::types::TextModelInput; - pub struct MambaModel { model: Model, config: Config, diff --git a/atoma-inference/src/models/candle/mod.rs b/atoma-inference/src/models/candle/mod.rs index 724d7103..323f72f5 100644 --- a/atoma-inference/src/models/candle/mod.rs +++ b/atoma-inference/src/models/candle/mod.rs @@ -1,2 +1 @@ pub mod mamba; -pub mod types; diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs index e015b861..e17a3582 100644 --- a/atoma-inference/src/models/config.rs +++ b/atoma-inference/src/models/config.rs @@ -3,7 +3,7 @@ use std::path::PathBuf; use config::Config; use serde::{Deserialize, Serialize}; -use crate::{models::ModelId, types::PrecisionBits}; +use crate::{models::types::PrecisionBits, models::ModelId}; type Revision = String; diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs index 6366478a..dc82f4c1 100644 --- a/atoma-inference/src/models/mod.rs +++ b/atoma-inference/src/models/mod.rs @@ -4,11 +4,12 @@ use ::candle::Error as CandleError; use ed25519_consensus::VerificationKey as PublicKey; use thiserror::Error; -use crate::types::PrecisionBits; +use crate::models::types::PrecisionBits; pub mod candle; pub mod config; pub mod token_output_stream; +pub mod types; pub type ModelId = String; diff --git a/atoma-inference/src/models/candle/types.rs b/atoma-inference/src/models/types.rs similarity index 81% rename from atoma-inference/src/models/candle/types.rs rename to atoma-inference/src/models/types.rs index 983a0ba5..54d85bcf 100644 --- a/atoma-inference/src/models/candle/types.rs +++ b/atoma-inference/src/models/types.rs @@ -1,3 +1,4 @@ +use candle::DType; use ed25519_consensus::VerificationKey as PublicKey; use serde::{Deserialize, Serialize}; @@ -103,3 +104,29 @@ impl Response for TextResponse { } } } + +#[derive(Copy, Clone, Debug, Deserialize, Serialize)] +pub enum PrecisionBits { + BF16, + F16, + F32, + F64, + I64, + U8, + U32, +} + +impl PrecisionBits { + #[allow(dead_code)] + pub(crate) fn into_dtype(self) -> DType { + match self { + Self::BF16 => DType::BF16, + Self::F16 => DType::F16, + Self::F32 => DType::F32, + Self::F64 => DType::F64, + Self::I64 => DType::I64, + Self::U8 => DType::U8, + Self::U32 => DType::U32, + } + } +} diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index f6e986d1..c43e1aa0 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -166,7 +166,7 @@ mod tests { use std::io::Write; use toml::{toml, Value}; - use crate::{models::ModelId, types::PrecisionBits}; + use crate::{models::types::PrecisionBits, models::ModelId}; use super::*; diff --git a/atoma-inference/src/types.rs b/atoma-inference/src/types.rs deleted file mode 100644 index 82643c19..00000000 --- a/atoma-inference/src/types.rs +++ /dev/null @@ -1,65 +0,0 @@ -use candle::DType; -use ed25519_consensus::VerificationKey; -use serde::{Deserialize, Serialize}; - -use crate::models::ModelId; - -pub type NodeId = VerificationKey; -pub type Temperature = f32; - -#[derive(Clone, Debug)] -pub struct InferenceRequest { - pub request_id: u128, - pub prompt: String, - pub model: ModelId, - pub max_tokens: usize, - pub 
random_seed: usize, - pub repeat_last_n: usize, - pub repeat_penalty: f32, - pub sampled_nodes: Vec, - pub temperature: Option, - pub top_k: usize, - pub top_p: Option, -} - -#[derive(Clone, Debug)] -#[allow(dead_code)] -pub struct InferenceResponse { - // TODO: possibly a Merkle root hash - // pub(crate) response_hash: [u8; 32], - // pub(crate) node_id: NodeId, - // pub(crate) signature: Vec, - pub(crate) response: String, -} - -#[derive(Clone, Debug)] -pub enum QuantizationMethod { - Ggml(PrecisionBits), - Gptq(PrecisionBits), -} - -#[derive(Copy, Clone, Debug, Deserialize, Serialize)] -pub enum PrecisionBits { - BF16, - F16, - F32, - F64, - I64, - U8, - U32, -} - -impl PrecisionBits { - #[allow(dead_code)] - pub(crate) fn into_dtype(self) -> DType { - match self { - Self::BF16 => DType::BF16, - Self::F16 => DType::F16, - Self::F32 => DType::F32, - Self::F64 => DType::F64, - Self::I64 => DType::I64, - Self::U8 => DType::U8, - Self::U32 => DType::U32, - } - } -} From efe64cb3042c28fe523fc62102c66f3149970ab3 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Tue, 2 Apr 2024 08:17:26 +0100 Subject: [PATCH 13/19] address PR comments --- atoma-inference/src/apis/hugging_face.rs | 19 ++++++------------- atoma-inference/src/model_thread.rs | 4 ++-- atoma-inference/src/models/candle/mamba.rs | 8 +++----- 3 files changed, 11 insertions(+), 20 deletions(-) diff --git a/atoma-inference/src/apis/hugging_face.rs b/atoma-inference/src/apis/hugging_face.rs index 69ac9efc..3a2df3c0 100644 --- a/atoma-inference/src/apis/hugging_face.rs +++ b/atoma-inference/src/apis/hugging_face.rs @@ -25,24 +25,17 @@ impl ApiTrait for Api { } fn fetch(&self, model_id: ModelId, revision: String) -> Result, ApiError> { - let mut tokenizer_file = None; - if model_id.contains("mamba") { - tokenizer_file = Some( + let repo = self.repo(Repo::with_revision(model_id.clone(), RepoType::Model, revision)); + + Ok(vec![ + repo.get("config.json")?, + if model_id.contains("mamba") { self.model("EleutherAI/gpt-neox-20b".to_string()) .get("tokenizer.json") .map_err(|e| { error!("Failed to fetch tokenizer file: {e}"); e - })?, - ) - } - - let repo = self.repo(Repo::with_revision(model_id, RepoType::Model, revision)); - - Ok(vec![ - repo.get("config.json")?, - if let Some(tkn) = tokenizer_file { - tkn + })? } else { repo.get("tokenizer.json")? }, diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 8fe01b5e..287190ec 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -169,11 +169,11 @@ where fn send(&self, command: ModelThreadCommand) { let request = command.0.clone(); - let model_type = request.requested_model(); + let model_id = request.requested_model(); let sender = self .model_senders - .get(&model_type) + .get(&model_id) .expect("Failed to get model thread, this should not happen !"); if let Err(e) = sender.send(command) { diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs index 1950d4ba..284e25e4 100644 --- a/atoma-inference/src/models/candle/mamba.rs +++ b/atoma-inference/src/models/candle/mamba.rs @@ -128,13 +128,12 @@ impl ModelTrait for MambaModel { let mut next_logits = None; let mut output = String::new(); - for &t in tokens.iter() { - let input = Tensor::new(&[t], &self.device)?; + for &token in tokens.iter() { + let input = Tensor::new(&[token], &self.device)?; let logits = self.model.forward(&input, &mut state)?; next_logits = Some(logits); - if let Some(t) = self.tokenizer.next_token(t)? 
{ - info!("{:?}", t); + if let Some(t) = self.tokenizer.next_token(token)? { output.push_str(t.as_str()); } } @@ -163,7 +162,6 @@ impl ModelTrait for MambaModel { } if let Some(t) = self.tokenizer.next_token(next_token)? { - info!("{:?}", t); output.push_str(t.as_str()); } From 233488c819cd0b971fb731bc20770942eee68048 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Tue, 2 Apr 2024 08:28:24 +0100 Subject: [PATCH 14/19] cargo fmt --- atoma-inference/src/apis/hugging_face.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/atoma-inference/src/apis/hugging_face.rs b/atoma-inference/src/apis/hugging_face.rs index 3a2df3c0..41b002c0 100644 --- a/atoma-inference/src/apis/hugging_face.rs +++ b/atoma-inference/src/apis/hugging_face.rs @@ -25,7 +25,11 @@ impl ApiTrait for Api { } fn fetch(&self, model_id: ModelId, revision: String) -> Result, ApiError> { - let repo = self.repo(Repo::with_revision(model_id.clone(), RepoType::Model, revision)); + let repo = self.repo(Repo::with_revision( + model_id.clone(), + RepoType::Model, + revision, + )); Ok(vec![ repo.get("config.json")?, From 9696f22c0631b639f11b469ae836c38b3f353697 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Tue, 2 Apr 2024 08:43:25 +0100 Subject: [PATCH 15/19] make it compatible with cargo-lints clippy --- atoma-inference/src/service.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index c43e1aa0..a4b0dbb6 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -189,7 +189,6 @@ mod tests { type ModelInput = (); fn into_model_input(self) -> Self::ModelInput { - () } fn is_node_authorized(&self, _: &PublicKey) -> bool { @@ -209,7 +208,6 @@ mod tests { type ModelOutput = (); fn from_model_output(_: Self::ModelOutput) -> Self { - () } } @@ -238,7 +236,7 @@ mod tests { const CONFIG_FILE_PATH: &str = "./inference.toml"; const PRIVATE_KEY_FILE_PATH: &str = "./private_key"; - let private_key = PrivateKey::new(&mut OsRng); + let private_key = PrivateKey::new(OsRng); std::fs::write(PRIVATE_KEY_FILE_PATH, private_key.to_bytes()).unwrap(); let config_data = Value::Table(toml! 
{ @@ -260,8 +258,8 @@ mod tests { let (resp_sender, _) = tokio::sync::mpsc::channel::<()>(1); let _ = ModelService::<(), ()>::start::( - PathBuf::try_from(CONFIG_FILE_PATH).unwrap(), - PathBuf::try_from(PRIVATE_KEY_FILE_PATH).unwrap(), + PathBuf::from(CONFIG_FILE_PATH), + PathBuf::from(PRIVATE_KEY_FILE_PATH), req_receiver, resp_sender, ) From 32975c343f5f6350f28957c3afa1adb28cda151c Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Tue, 2 Apr 2024 10:26:46 +0100 Subject: [PATCH 16/19] cargo fmt --- atoma-inference/src/main.rs | 2 +- atoma-inference/src/service.rs | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index 7ebe1835..36bfea0b 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -36,7 +36,7 @@ async fn main() -> Result<(), ModelServiceError> { req_sender .send(TextRequest { request_id: 0, - prompt: "Who was the first american president ?".to_string(), + prompt: "Leon, the professional is a movie".to_string(), model: "state-spaces/mamba-130m".to_string(), max_tokens: 512, temperature: Some(0.0), diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index a4b0dbb6..8dce6923 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -188,8 +188,7 @@ mod tests { impl Request for () { type ModelInput = (); - fn into_model_input(self) -> Self::ModelInput { - } + fn into_model_input(self) -> Self::ModelInput {} fn is_node_authorized(&self, _: &PublicKey) -> bool { true @@ -207,8 +206,7 @@ mod tests { impl Response for () { type ModelOutput = (); - fn from_model_output(_: Self::ModelOutput) -> Self { - } + fn from_model_output(_: Self::ModelOutput) -> Self {} } #[derive(Clone)] From f5b92af4378e08a1a100c1c0fc8c263c2b161dde Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Tue, 2 Apr 2024 10:49:13 +0100 Subject: [PATCH 17/19] address PR comments --- atoma-inference/src/main.rs | 14 ++++++++++++-- atoma-inference/src/service.rs | 21 ++++++--------------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index 36bfea0b..bfb0181d 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -1,9 +1,11 @@ use std::time::Duration; +use ed25519_consensus::SigningKey as PrivateKey; use hf_hub::api::sync::Api; use inference::{ models::{ candle::mamba::MambaModel, + config::ModelConfig, types::{TextRequest, TextResponse}, }, service::{ModelService, ModelServiceError}, @@ -16,9 +18,17 @@ async fn main() -> Result<(), ModelServiceError> { let (req_sender, req_receiver) = tokio::sync::mpsc::channel::(32); let (resp_sender, mut resp_receiver) = tokio::sync::mpsc::channel::(32); + let model_config = ModelConfig::from_file_path("../inference.toml".parse().unwrap()); + let private_key_bytes = + std::fs::read("../private_key").map_err(ModelServiceError::PrivateKeyError)?; + let private_key_bytes: [u8; 32] = private_key_bytes + .try_into() + .expect("Incorrect private key bytes length"); + + let private_key = PrivateKey::from(private_key_bytes); let mut service = ModelService::start::( - "../inference.toml".parse().unwrap(), - "../private_key".parse().unwrap(), + model_config, + private_key, req_receiver, resp_sender, ) diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 8dce6923..1d13024f 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -34,8 +34,8 @@ where Resp: std::fmt::Debug + Response, { pub fn 
start( - config_file_path: PathBuf, - private_key_path: PathBuf, + model_config: ModelConfig, + private_key: PrivateKey, request_receiver: Receiver, response_sender: Sender, ) -> Result @@ -43,15 +43,7 @@ where M: ModelTrait + Send + 'static, F: ApiTrait + Send + Sync + 'static, { - let private_key_bytes = - std::fs::read(private_key_path).map_err(ModelServiceError::PrivateKeyError)?; - let private_key_bytes: [u8; 32] = private_key_bytes - .try_into() - .expect("Incorrect private key bytes length"); - - let private_key = PrivateKey::from(private_key_bytes); let public_key = private_key.verification_key(); - let model_config = ModelConfig::from_file_path(config_file_path); let flush_storage = model_config.flush_storage(); let storage_path = model_config.storage_path(); @@ -232,10 +224,8 @@ mod tests { #[tokio::test] async fn test_inference_service_initialization() { const CONFIG_FILE_PATH: &str = "./inference.toml"; - const PRIVATE_KEY_FILE_PATH: &str = "./private_key"; let private_key = PrivateKey::new(OsRng); - std::fs::write(PRIVATE_KEY_FILE_PATH, private_key.to_bytes()).unwrap(); let config_data = Value::Table(toml! { api_key = "your_api_key" @@ -255,15 +245,16 @@ mod tests { let (_, req_receiver) = tokio::sync::mpsc::channel::<()>(1); let (resp_sender, _) = tokio::sync::mpsc::channel::<()>(1); + let config = ModelConfig::from_file_path(CONFIG_FILE_PATH.parse().unwrap()); + let _ = ModelService::<(), ()>::start::( - PathBuf::from(CONFIG_FILE_PATH), - PathBuf::from(PRIVATE_KEY_FILE_PATH), + config, + private_key, req_receiver, resp_sender, ) .unwrap(); std::fs::remove_file(CONFIG_FILE_PATH).unwrap(); - std::fs::remove_file(PRIVATE_KEY_FILE_PATH).unwrap(); } } From c8ab1adc4eff0969c640e21f8c1c581497388459 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Tue, 2 Apr 2024 10:57:12 +0100 Subject: [PATCH 18/19] address PR comments --- Cargo.toml | 1 + atoma-inference/Cargo.toml | 1 + atoma-inference/src/models/config.rs | 31 ++++++++++++++++++++++++++++ 3 files changed, 33 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index e37a9e35..c2ccab6a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ candle-flash-attn = { git = "https://github.com/huggingface/candle", package = " candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", version = "0.4.2" } candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", version = "0.4.2" } config = "0.14.0" +dotenv = "0.15.0" ed25519-consensus = "2.1.0" futures = "0.3.30" hf-hub = "0.3.2" diff --git a/atoma-inference/Cargo.toml b/atoma-inference/Cargo.toml index 87aeee0f..a9cbffc5 100644 --- a/atoma-inference/Cargo.toml +++ b/atoma-inference/Cargo.toml @@ -12,6 +12,7 @@ candle-flash-attn = { workspace = true, optional = true } candle-nn.workspace = true candle-transformers.workspace = true config.workspace = true +dotenv.workspace = true ed25519-consensus.workspace = true futures.workspace = true hf-hub.workspace = true diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs index e17a3582..37ea846c 100644 --- a/atoma-inference/src/models/config.rs +++ b/atoma-inference/src/models/config.rs @@ -1,6 +1,7 @@ use std::path::PathBuf; use config::Config; +use dotenv::dotenv; use serde::{Deserialize, Serialize}; use crate::{models::types::PrecisionBits, models::ModelId}; @@ -64,6 +65,36 @@ impl ModelConfig { .try_deserialize::() .expect("Failed to generated config file") } + + pub fn from_env_file() -> Self { + dotenv().ok(); + + let api_key = 
std::env::var("API_KEY").expect("Failed to retrieve api key, from .env file"); + let flush_storage = std::env::var("FLUSH_STORAGE") + .expect("Failed to retrieve flush storage variable, from .env file") + .parse() + .unwrap(); + let models = serde_json::from_str( + &std::env::var("MODELS").expect("Failed to retrieve models metadata, from .env file"), + ) + .unwrap(); + let storage_path = std::env::var("STORAGE_PATH") + .expect("Failed to retrieve storage path, from .env file") + .parse() + .unwrap(); + let tracing = std::env::var("TRACING") + .expect("Failed to retrieve tracing variable, from .env file") + .parse() + .unwrap(); + + Self { + api_key, + flush_storage, + models, + storage_path, + tracing, + } + } } #[cfg(test)] From 9eeac33e48942e918159b961b3c7c88732714e33 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Wed, 3 Apr 2024 08:00:15 +0100 Subject: [PATCH 19/19] address PR comments --- atoma-inference/src/models/config.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs index 37ea846c..e5790163 100644 --- a/atoma-inference/src/models/config.rs +++ b/atoma-inference/src/models/config.rs @@ -71,7 +71,7 @@ impl ModelConfig { let api_key = std::env::var("API_KEY").expect("Failed to retrieve api key, from .env file"); let flush_storage = std::env::var("FLUSH_STORAGE") - .expect("Failed to retrieve flush storage variable, from .env file") + .unwrap_or_default() .parse() .unwrap(); let models = serde_json::from_str( @@ -83,7 +83,7 @@ impl ModelConfig { .parse() .unwrap(); let tracing = std::env::var("TRACING") - .expect("Failed to retrieve tracing variable, from .env file") + .unwrap_or_default() .parse() .unwrap();