diff --git a/Cargo.toml b/Cargo.toml index 1ccca77c..1758ec34 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ candle-flash-attn = { git = "https://github.com/huggingface/candle", package = " candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", version = "0.4.2" } candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", version = "0.4.2" } config = "0.14.0" +dotenv = "0.15.0" ed25519-consensus = "2.1.0" futures = "0.3.30" hf-hub = "0.3.2" diff --git a/atoma-inference/Cargo.toml b/atoma-inference/Cargo.toml index 74f6cdaa..c6037821 100644 --- a/atoma-inference/Cargo.toml +++ b/atoma-inference/Cargo.toml @@ -10,7 +10,8 @@ candle.workspace = true candle-flash-attn = { workspace = true, optional = true } candle-nn.workspace = true candle-transformers.workspace = true -config.true = true +config.workspace = true +dotenv.workspace = true ed25519-consensus.workspace = true futures.workspace = true hf-hub.workspace = true diff --git a/atoma-inference/src/apis/hugging_face.rs b/atoma-inference/src/apis/hugging_face.rs index bc80099a..41b002c0 100644 --- a/atoma-inference/src/apis/hugging_face.rs +++ b/atoma-inference/src/apis/hugging_face.rs @@ -1,95 +1,16 @@ use std::path::PathBuf; use async_trait::async_trait; -use hf_hub::api::sync::{Api, ApiBuilder}; +use hf_hub::{ + api::sync::{Api, ApiBuilder}, + Repo, RepoType, +}; +use tracing::error; use crate::models::ModelId; use super::{ApiError, ApiTrait}; -struct FilePaths { - file_paths: Vec, -} - -fn get_model_safe_tensors_from_hf(model_id: &ModelId) -> (String, FilePaths) { - match model_id.as_str() { - "Llama2_7b" => ( - String::from("meta-llama/Llama-2-7b-hf"), - FilePaths { - file_paths: vec![ - "model-00001-of-00002.safetensors".to_string(), - "model-00002-of-00002.safetensors".to_string(), - ], - }, - ), - "Mamba3b" => ( - String::from("state-spaces/mamba-2.8b-hf"), - FilePaths { - file_paths: vec![ - "model-00001-of-00003.safetensors".to_string(), - "model-00002-of-00003.safetensors".to_string(), - "model-00003-of-00003.safetensors".to_string(), - ], - }, - ), - "Mistral7b" => ( - String::from("mistralai/Mistral-7B-Instruct-v0.2"), - FilePaths { - file_paths: vec![ - "model-00001-of-00003.safetensors".to_string(), - "model-00002-of-00003.safetensors".to_string(), - "model-00003-of-00003.safetensors".to_string(), - ], - }, - ), - "Mixtral8x7b" => ( - String::from("mistralai/Mixtral-8x7B-Instruct-v0.1"), - FilePaths { - file_paths: vec![ - "model-00001-of-00019.safetensors".to_string(), - "model-00002-of-00019.safetensors".to_string(), - "model-00003-of-00019.safetensors".to_string(), - "model-00004-of-00019.safetensors".to_string(), - "model-00005-of-00019.safetensors".to_string(), - "model-00006-of-00019.safetensors".to_string(), - "model-00007-of-00019.safetensors".to_string(), - "model-00008-of-00019.safetensors".to_string(), - "model-00009-of-00019.safetensors".to_string(), - "model-000010-of-00019.safetensors".to_string(), - "model-000011-of-00019.safetensors".to_string(), - "model-000012-of-00019.safetensors".to_string(), - "model-000013-of-00019.safetensors".to_string(), - "model-000014-of-00019.safetensors".to_string(), - "model-000015-of-00019.safetensors".to_string(), - "model-000016-of-00019.safetensors".to_string(), - "model-000017-of-00019.safetensors".to_string(), - "model-000018-of-00019.safetensors".to_string(), - "model-000019-of-00019.safetensors".to_string(), - ], - }, - ), - "StableDiffusion2" => ( - 
String::from("stabilityai/stable-diffusion-2"), - FilePaths { - file_paths: vec!["768-v-ema.safetensors".to_string()], - }, - ), - "StableDiffusionXl" => ( - String::from("stabilityai/stable-diffusion-xl-base-1.0"), - FilePaths { - file_paths: vec![ - "sd_xl_base_1.0.safetensors".to_string(), - "sd_xl_base_1.0_0.9vae.safetensors".to_string(), - "sd_xl_offset_example-lora_1.0.safetensors".to_string(), - ], - }, - ), - _ => { - panic!("Invalid model id") - } - } -} - #[async_trait] impl ApiTrait for Api { fn create(api_key: String, cache_dir: PathBuf) -> Result @@ -103,15 +24,26 @@ impl ApiTrait for Api { .build()?) } - fn fetch(&self, model_id: &ModelId) -> Result, ApiError> { - let (model_path, files) = get_model_safe_tensors_from_hf(model_id); - let api_repo = self.model(model_path); - let mut path_bufs = Vec::with_capacity(files.file_paths.len()); - - for file in files.file_paths { - path_bufs.push(api_repo.get(&file)?); - } - - Ok(path_bufs) + fn fetch(&self, model_id: ModelId, revision: String) -> Result, ApiError> { + let repo = self.repo(Repo::with_revision( + model_id.clone(), + RepoType::Model, + revision, + )); + + Ok(vec![ + repo.get("config.json")?, + if model_id.contains("mamba") { + self.model("EleutherAI/gpt-neox-20b".to_string()) + .get("tokenizer.json") + .map_err(|e| { + error!("Failed to fetch tokenizer file: {e}"); + e + })? + } else { + repo.get("tokenizer.json")? + }, + repo.get("model.safetensors")?, + ]) } } diff --git a/atoma-inference/src/apis/mod.rs b/atoma-inference/src/apis/mod.rs index e6d27941..c2b4ba16 100644 --- a/atoma-inference/src/apis/mod.rs +++ b/atoma-inference/src/apis/mod.rs @@ -22,7 +22,7 @@ impl From for ApiError { } pub trait ApiTrait: Send { - fn fetch(&self, model_id: &ModelId) -> Result, ApiError>; + fn fetch(&self, model_id: ModelId, revision: String) -> Result, ApiError>; fn create(api_key: String, cache_dir: PathBuf) -> Result where Self: Sized; diff --git a/atoma-inference/src/candle/mod.rs b/atoma-inference/src/candle/mod.rs deleted file mode 100644 index 8fe9d9b6..00000000 --- a/atoma-inference/src/candle/mod.rs +++ /dev/null @@ -1,89 +0,0 @@ -pub mod stable_diffusion; -pub mod token_output_stream; - -use std::{fs::File, io::Write, path::PathBuf}; - -use candle::{ - utils::{cuda_is_available, metal_is_available}, - DType, Device, Tensor, -}; -use tracing::info; - -use crate::models::ModelError; - -pub trait CandleModel { - type Fetch; - type Input; - fn fetch(fetch: &Self::Fetch) -> Result<(), ModelError>; - fn inference(input: Self::Input) -> Result, ModelError>; -} - -pub fn device() -> Result { - if cuda_is_available() { - info!("Using CUDA"); - Device::new_cuda(0) - } else if metal_is_available() { - info!("Using Metal"); - Device::new_metal(0) - } else { - info!("Using Cpu"); - Ok(Device::Cpu) - } -} - -pub fn hub_load_safetensors( - repo: &hf_hub::api::sync::ApiRepo, - json_file: &str, -) -> candle::Result> { - let json_file = repo.get(json_file).map_err(candle::Error::wrap)?; - let json_file = std::fs::File::open(json_file)?; - let json: serde_json::Value = - serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?; - let weight_map = match json.get("weight_map") { - None => candle::bail!("no weight map in {json_file:?}"), - Some(serde_json::Value::Object(map)) => map, - Some(_) => candle::bail!("weight map in {json_file:?} is not a map"), - }; - let mut safetensors_files = std::collections::HashSet::new(); - for value in weight_map.values() { - if let Some(file) = value.as_str() { - 
safetensors_files.insert(file.to_string()); - } - } - let safetensors_files = safetensors_files - .iter() - .map(|v| repo.get(v).map_err(candle::Error::wrap)) - .collect::>>()?; - Ok(safetensors_files) -} - -pub fn save_image>(img: &Tensor, p: P) -> candle::Result<()> { - let p = p.as_ref(); - let (channel, height, width) = img.dims3()?; - if channel != 3 { - candle::bail!("save_image expects an input of shape (3, height, width)") - } - let img = img.permute((1, 2, 0))?.flatten_all()?; - let pixels = img.to_vec1::()?; - let image: image::ImageBuffer, Vec> = - match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) { - Some(image) => image, - None => candle::bail!("error saving image {p:?}"), - }; - image.save(p).map_err(candle::Error::wrap)?; - Ok(()) -} - -pub fn save_tensor_to_file(tensor: &Tensor, filename: &str) -> Result<(), candle::Error> { - let json_output = serde_json::to_string( - &tensor - .to_device(&Device::Cpu)? - .flatten_all()? - .to_dtype(DType::F64)? - .to_vec1::()?, - ) - .unwrap(); - let mut file = File::create(PathBuf::from(filename))?; - file.write_all(json_output.as_bytes())?; - Ok(()) -} diff --git a/atoma-inference/src/lib.rs b/atoma-inference/src/lib.rs index e96cd8bb..4ad5a4d4 100644 --- a/atoma-inference/src/lib.rs +++ b/atoma-inference/src/lib.rs @@ -1,7 +1,6 @@ -pub mod apis; -pub mod candle; pub mod model_thread; -pub mod models; pub mod service; pub mod specs; -pub mod types; + +pub mod apis; +pub mod models; diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index 1a614417..bfb0181d 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -1,16 +1,68 @@ -// use hf_hub::api::sync::Api; -// use inference::service::ModelService; +use std::time::Duration; + +use ed25519_consensus::SigningKey as PrivateKey; +use hf_hub::api::sync::Api; +use inference::{ + models::{ + candle::mamba::MambaModel, + config::ModelConfig, + types::{TextRequest, TextResponse}, + }, + service::{ModelService, ModelServiceError}, +}; #[tokio::main] -async fn main() { +async fn main() -> Result<(), ModelServiceError> { tracing_subscriber::fmt::init(); - // let (_, receiver) = tokio::sync::mpsc::channel(32); + let (req_sender, req_receiver) = tokio::sync::mpsc::channel::(32); + let (resp_sender, mut resp_receiver) = tokio::sync::mpsc::channel::(32); + + let model_config = ModelConfig::from_file_path("../inference.toml".parse().unwrap()); + let private_key_bytes = + std::fs::read("../private_key").map_err(ModelServiceError::PrivateKeyError)?; + let private_key_bytes: [u8; 32] = private_key_bytes + .try_into() + .expect("Incorrect private key bytes length"); + + let private_key = PrivateKey::from(private_key_bytes); + let mut service = ModelService::start::( + model_config, + private_key, + req_receiver, + resp_sender, + ) + .expect("Failed to start inference service"); + + let pk = service.public_key(); + + tokio::spawn(async move { + service.run().await?; + Ok::<(), ModelServiceError>(()) + }); + + tokio::time::sleep(Duration::from_millis(5000)).await; + + req_sender + .send(TextRequest { + request_id: 0, + prompt: "Leon, the professional is a movie".to_string(), + model: "state-spaces/mamba-130m".to_string(), + max_tokens: 512, + temperature: Some(0.0), + random_seed: 42, + repeat_last_n: 64, + repeat_penalty: 1.1, + sampled_nodes: vec![pk], + top_p: Some(1.0), + top_k: 10, + }) + .await + .expect("Failed to send request"); + + if let Some(response) = resp_receiver.recv().await { + println!("Got a response: {:?}", response); + } - // 
let _ = ModelService::start::( - // "../inference.toml".parse().unwrap(), - // "../private_key".parse().unwrap(), - // receiver, - // ) - // .expect("Failed to start inference service"); + Ok(()) } diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 5e5db91c..287190ec 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -1,10 +1,13 @@ -use std::{collections::HashMap, sync::mpsc}; +use std::{ + collections::HashMap, + sync::{mpsc, Arc}, +}; use ed25519_consensus::VerificationKey as PublicKey; use futures::stream::FuturesUnordered; use thiserror::Error; use tokio::sync::oneshot::{self, error::RecvError}; -use tracing::{debug, error, warn}; +use tracing::{debug, error, info, warn}; use crate::{ apis::{ApiError, ApiTrait}, @@ -69,7 +72,7 @@ where Req: Request, Resp: Response, { - pub fn run(self, public_key: PublicKey) -> Result<(), ModelThreadError> { + pub fn run(mut self, public_key: PublicKey) -> Result<(), ModelThreadError> { debug!("Start Model thread"); while let Ok(command) = self.receiver.recv() { @@ -112,24 +115,29 @@ where public_key: PublicKey, ) -> Result<(Self, Vec>), ModelThreadError> where - F: ApiTrait, + F: ApiTrait + Send + Sync + 'static, M: ModelTrait + Send + 'static, { let model_ids = config.model_ids(); let api_key = config.api_key(); let storage_path = config.storage_path(); - let api = F::create(api_key, storage_path)?; + let api = Arc::new(F::create(api_key, storage_path)?); let mut handles = Vec::with_capacity(model_ids.len()); let mut model_senders = HashMap::with_capacity(model_ids.len()); - for model_id in model_ids { - let filenames = api.fetch(&model_id)?; + for (model_id, precision, revision) in model_ids { + info!("Spawning new thread for model: {model_id}"); + let api = api.clone(); let (model_sender, model_receiver) = mpsc::channel::>(); + let model_name = model_id.clone(); let join_handle = std::thread::spawn(move || { - let model = M::load(filenames)?; // TODO: for now this piece of code cannot be shared among threads safely + info!("Fetching files for model: {model_name}"); + let filenames = api.fetch(model_name, revision)?; + + let model = M::load(filenames, precision)?; let model_thread = ModelThread { model, receiver: model_receiver, @@ -161,11 +169,11 @@ where fn send(&self, command: ModelThreadCommand) { let request = command.0.clone(); - let model_type = request.requested_model(); + let model_id = request.requested_model(); let sender = self .model_senders - .get(&model_type) + .get(&model_id) .expect("Failed to get model thread, this should not happen !"); if let Err(e) = sender.send(command) { diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs new file mode 100644 index 00000000..fd7b5d25 --- /dev/null +++ b/atoma-inference/src/models/candle/mamba.rs @@ -0,0 +1,215 @@ +use std::{path::PathBuf, time::Instant}; + +use candle::{ + utils::{cuda_is_available, metal_is_available}, + DType, Device, Tensor, +}; +use candle_nn::VarBuilder; +use candle_transformers::{ + generation::LogitsProcessor, + models::mamba::{Config, Model, State}, + utils::apply_repeat_penalty, +}; +use tokenizers::Tokenizer; +use tracing::info; + +use crate::{ + bail, + models::types::{PrecisionBits, TextModelInput}, + models::{token_output_stream::TokenOutputStream, ModelError, ModelId, ModelTrait}, +}; + +pub struct MambaModel { + model: Model, + config: Config, + device: Device, + dtype: DType, + tokenizer: TokenOutputStream, + which: Which, +} + 
+impl MambaModel { + pub fn new( + model: Model, + config: Config, + device: Device, + dtype: DType, + tokenizer: Tokenizer, + ) -> Self { + let which = Which::from_config(&config); + Self { + model, + config, + device, + dtype, + tokenizer: TokenOutputStream::new(tokenizer), + which, + } + } +} + +impl ModelTrait for MambaModel { + type Input = TextModelInput; + type Output = String; + + fn load(filenames: Vec, precision: PrecisionBits) -> Result + where + Self: Sized, + { + info!("Loading Mamba model ..."); + + let start = Instant::now(); + + let config_filename = filenames[0].clone(); + let tokenizer_filename = filenames[1].clone(); + let weights_filenames = filenames[2..].to_vec(); + + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(ModelError::BoxedError)?; + + let config: Config = + serde_json::from_slice(&std::fs::read(config_filename).map_err(ModelError::IoError)?) + .map_err(ModelError::DeserializeError)?; + let device = if cuda_is_available() { + Device::new_cuda(0).map_err(ModelError::CandleError)? + } else if metal_is_available() { + Device::new_metal(0).map_err(ModelError::CandleError)? + } else { + Device::Cpu + }; + let dtype = precision.into_dtype(); + + info!("Loading model weights.."); + let var_builder = + unsafe { VarBuilder::from_mmaped_safetensors(&weights_filenames, dtype, &device)? }; + let model = Model::new(&config, var_builder.pp("backbone"))?; + info!("Loaded Mamba model in {:?}", start.elapsed()); + + Ok(Self::new(model, config, device, dtype, tokenizer)) + } + + fn model_id(&self) -> ModelId { + self.which.model_id().to_string() + } + + fn run(&mut self, input: Self::Input) -> Result { + let TextModelInput { + prompt, + temperature, + random_seed, + repeat_penalty, + repeat_last_n, + max_tokens, + top_p, + .. + } = input; + + info!("Running inference on prompt: {:?}", prompt); + + self.tokenizer.clear(); + let mut tokens = self + .tokenizer + .tokenizer() + .encode(prompt, true) + .map_err(ModelError::BoxedError)? + .get_ids() + .to_vec(); + let mut logits_processor = + LogitsProcessor::new(random_seed, Some(temperature), Some(top_p)); + + let mut generated_tokens = 0_usize; + let eos_token = match self.tokenizer.get_token("<|endoftext|>") { + Some(token) => token, + None => bail!("Invalid eos token"), + }; + + let mut state = State::new(1, &self.config, &self.device)?; // TODO: handle larger batch sizes + + let mut next_logits = None; + let mut output = String::new(); + + for &token in tokens.iter() { + let input = Tensor::new(&[token], &self.device)?; + let logits = self.model.forward(&input, &mut state)?; + + next_logits = Some(logits); + if let Some(t) = self.tokenizer.next_token(token)? { + output.push_str(t.as_str()); + } + } + + let start_gen = Instant::now(); + for _ in 0..max_tokens { + let logits = match next_logits.as_ref() { + Some(logits) => logits, + None => bail!("cannot work on an empty prompt"), + }; + + let logits = logits.squeeze(0)?.to_dtype(self.dtype)?; + let logits = if repeat_penalty == 1.0 { + logits + } else { + let start_at = tokens.len().saturating_sub(repeat_last_n); + apply_repeat_penalty(&logits, repeat_penalty, &tokens[start_at..])? + }; + + let next_token = logits_processor.sample(&logits)?; + tokens.push(next_token); + generated_tokens += 1; + + if next_token == eos_token { + break; + } + + if let Some(t) = self.tokenizer.next_token(next_token)? 
{ + output.push_str(t.as_str()); + } + + let input = Tensor::new(&[next_token], &self.device)?; + next_logits = Some(self.model.forward(&input, &mut state)?); + } + let dt = start_gen.elapsed(); + if let Some(rest) = self.tokenizer.decode_rest()? { + output.push_str(rest.as_str()); + } + + info!( + "\n{generated_tokens} tokens generated ({:.2} token/s)", + generated_tokens as f64 / dt.as_secs_f64(), + ); + Ok(output) + } +} + +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +enum Which { + Mamba130m, + Mamba370m, + Mamba790m, + Mamba1_4b, + Mamba2_8b, + // Mamba2_8bSlimPj, TODO: add this +} + +impl Which { + fn model_id(&self) -> &'static str { + match self { + Self::Mamba130m => "state-spaces/mamba-130m", + Self::Mamba370m => "state-spaces/mamba-370m", + Self::Mamba790m => "state-spaces/mamba-790m", + Self::Mamba1_4b => "state-spaces/mamba-1.4b", + Self::Mamba2_8b => "state-spaces/mamba-2.8b", + // Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj'", + } + } + + fn from_config(config: &Config) -> Self { + match config.d_model { + 768 => Self::Mamba130m, + 1024 => Self::Mamba370m, + 1536 => Self::Mamba790m, + 2048 => Self::Mamba1_4b, + 2560 => Self::Mamba2_8b, + _ => panic!("Invalid config d_model value"), + } + } +} diff --git a/atoma-inference/src/models/candle/mod.rs b/atoma-inference/src/models/candle/mod.rs new file mode 100644 index 00000000..b053452e --- /dev/null +++ b/atoma-inference/src/models/candle/mod.rs @@ -0,0 +1,2 @@ +pub mod mamba; +pub mod stable_diffusion; diff --git a/atoma-inference/src/candle/stable_diffusion.rs b/atoma-inference/src/models/candle/stable_diffusion.rs similarity index 95% rename from atoma-inference/src/candle/stable_diffusion.rs rename to atoma-inference/src/models/candle/stable_diffusion.rs index 46dadbc1..0b93fa62 100644 --- a/atoma-inference/src/candle/stable_diffusion.rs +++ b/atoma-inference/src/models/candle/stable_diffusion.rs @@ -4,14 +4,17 @@ extern crate accelerate_src; #[cfg(feature = "mkl")] extern crate intel_mkl_src; +use std::path::PathBuf; + use candle_transformers::models::stable_diffusion::{self}; use candle::{DType, Device, IndexOp, Module, Tensor, D}; use tokenizers::Tokenizer; -use crate::{candle::device, models::ModelError}; - -use super::{save_tensor_to_file, CandleModel}; +use crate::models::{ + types::{PrecisionBits, TextModelInput}, + ModelError, ModelTrait, +}; pub struct Input { prompt: String, @@ -108,11 +111,11 @@ pub struct Fetch { unet_weights: Option, } -impl CandleModel for StableDiffusion { - type Input = Input; - type Fetch = Fetch; +impl ModelTrait for StableDiffusion { + type Input = TextModelInput; + type Output = Vec; - fn fetch(fetch: &Self::Fetch) -> Result<(), ModelError> { + fn load(filenames: Vec, precision: PrecisionBits) -> Result { let which = match fetch.sd_version { StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => vec![true, false], _ => vec![true], @@ -132,7 +135,7 @@ impl CandleModel for StableDiffusion { Ok(()) } - fn inference(input: Self::Input) -> Result, ModelError> { + fn run(&mut self, input: Self::Input) -> Result { if !(0. ..=1.).contains(&input.img2img_strength) { Err(ModelError::Config(format!( "img2img_strength must be between 0 and 1, got {}", @@ -536,3 +539,17 @@ impl StableDiffusion { Ok(img) } } + +pub fn save_tensor_to_file(tensor: &Tensor, filename: &str) -> Result<(), candle::Error> { + let json_output = serde_json::to_string( + &tensor + .to_device(&Device::Cpu)? + .flatten_all()? + .to_dtype(DType::F64)? 
+ .to_vec1::()?, + ) + .unwrap(); + let mut file = std::fs::File::create(PathBuf::from(filename))?; + file.write_all(json_output.as_bytes())?; + Ok(()) +} diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs index ff5bf9a6..e5790163 100644 --- a/atoma-inference/src/models/config.rs +++ b/atoma-inference/src/models/config.rs @@ -1,14 +1,18 @@ use std::path::PathBuf; use config::Config; +use dotenv::dotenv; use serde::{Deserialize, Serialize}; -use crate::models::ModelId; +use crate::{models::types::PrecisionBits, models::ModelId}; + +type Revision = String; #[derive(Debug, Deserialize, Serialize)] pub struct ModelConfig { api_key: String, - models: Vec, + flush_storage: bool, + models: Vec<(ModelId, PrecisionBits, Revision)>, storage_path: PathBuf, tracing: bool, } @@ -16,12 +20,14 @@ pub struct ModelConfig { impl ModelConfig { pub fn new( api_key: String, - models: Vec, + flush_storage: bool, + models: Vec<(ModelId, PrecisionBits, Revision)>, storage_path: PathBuf, tracing: bool, ) -> Self { Self { api_key, + flush_storage, models, storage_path, tracing, @@ -32,7 +38,11 @@ impl ModelConfig { self.api_key.clone() } - pub fn model_ids(&self) -> Vec { + pub fn flush_storage(&self) -> bool { + self.flush_storage + } + + pub fn model_ids(&self) -> Vec<(ModelId, PrecisionBits, Revision)> { self.models.clone() } @@ -55,6 +65,36 @@ impl ModelConfig { .try_deserialize::() .expect("Failed to generated config file") } + + pub fn from_env_file() -> Self { + dotenv().ok(); + + let api_key = std::env::var("API_KEY").expect("Failed to retrieve api key, from .env file"); + let flush_storage = std::env::var("FLUSH_STORAGE") + .unwrap_or_default() + .parse() + .unwrap(); + let models = serde_json::from_str( + &std::env::var("MODELS").expect("Failed to retrieve models metadata, from .env file"), + ) + .unwrap(); + let storage_path = std::env::var("STORAGE_PATH") + .expect("Failed to retrieve storage path, from .env file") + .parse() + .unwrap(); + let tracing = std::env::var("TRACING") + .unwrap_or_default() + .parse() + .unwrap(); + + Self { + api_key, + flush_storage, + models, + storage_path, + tracing, + } + } } #[cfg(test)] @@ -65,13 +105,14 @@ pub mod tests { fn test_config() { let config = ModelConfig::new( String::from("my_key"), - vec!["Llama2_7b".to_string()], + true, + vec![("Llama2_7b".to_string(), PrecisionBits::F16, "".to_string())], "storage_path".parse().unwrap(), true, ); let toml_str = toml::to_string(&config).unwrap(); - let should_be_toml_str = "api_key = \"my_key\"\nmodels = [\"Llama2_7b\"]\nstorage_path = \"storage_path\"\ntracing = true\n"; + let should_be_toml_str = "api_key = \"my_key\"\nflush_storage = true\nmodels = [[\"Llama2_7b\", \"F16\", \"\"]]\nstorage_path = \"storage_path\"\ntracing = true\n"; assert_eq!(toml_str, should_be_toml_str); } } diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs index 72aed1f3..3fe8c7b3 100644 --- a/atoma-inference/src/models/mod.rs +++ b/atoma-inference/src/models/mod.rs @@ -1,27 +1,27 @@ use std::path::PathBuf; +use ::candle::Error as CandleError; use ed25519_consensus::VerificationKey as PublicKey; use thiserror::Error; +use crate::models::types::PrecisionBits; + +pub mod candle; pub mod config; +pub mod token_output_stream; +pub mod types; pub type ModelId = String; -pub trait ModelBuilder { - fn try_from_file(path: PathBuf) -> Result - where - Self: Sized; -} - pub trait ModelTrait { type Input; type Output; - fn load(filenames: Vec) -> Result + fn load(filenames: Vec, 
precision: PrecisionBits) -> Result where Self: Sized; fn model_id(&self) -> ModelId; - fn run(&self, input: Self::Input) -> Result; + fn run(&mut self, input: Self::Input) -> Result; } pub trait Request: Send + 'static { @@ -42,7 +42,7 @@ pub trait Response: Send + 'static { #[derive(Debug, Error)] pub enum ModelError { #[error("Candle error: `{0}`")] - CandleError(#[from] candle::Error), + CandleError(#[from] CandleError), #[error("Config error: `{0}`")] Config(String), #[error("Image error: `{0}`")] @@ -55,4 +55,25 @@ pub enum ModelError { Msg(String), #[error("ApiError error: `{0}`")] ApiError(#[from] hf_hub::api::sync::ApiError), + #[error("Deserialize error: `{0}`")] + DeserializeError(serde_json::Error), +} + +// impl From for ModelError { +// fn from(error: CandleError) -> Self { +// Self::CandleError(error) +// } +// } + +#[macro_export] +macro_rules! bail { + ($msg:literal $(,)?) => { + return Err(ModelError::Msg(format!($msg).into())) + }; + ($err:expr $(,)?) => { + return Err(ModelError::Msg(format!($err).into()).bt()) + }; + ($fmt:expr, $($arg:tt)*) => { + return Err(ModelError::Msg(format!($fmt, $($arg)*).into()).bt()) + }; } diff --git a/atoma-inference/src/candle/token_output_stream.rs b/atoma-inference/src/models/token_output_stream.rs similarity index 88% rename from atoma-inference/src/candle/token_output_stream.rs rename to atoma-inference/src/models/token_output_stream.rs index 1f507c5e..33bfb27a 100644 --- a/atoma-inference/src/candle/token_output_stream.rs +++ b/atoma-inference/src/models/token_output_stream.rs @@ -1,4 +1,4 @@ -use candle::Result; +use crate::{bail, models::ModelError}; /// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a /// streaming way rather than having to wait for the full decoding. 
@@ -23,15 +23,15 @@ impl TokenOutputStream { self.tokenizer } - fn decode(&self, tokens: &[u32]) -> Result { + fn decode(&self, tokens: &[u32]) -> Result { match self.tokenizer.decode(tokens, true) { Ok(str) => Ok(str), - Err(err) => candle::bail!("cannot decode: {err}"), + Err(err) => bail!("cannot decode: {err}"), } } // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68 - pub fn next_token(&mut self, token: u32) -> Result> { + pub fn next_token(&mut self, token: u32) -> Result, ModelError> { let prev_text = if self.tokens.is_empty() { String::new() } else { @@ -50,7 +50,7 @@ impl TokenOutputStream { } } - pub fn decode_rest(&self) -> Result> { + pub fn decode_rest(&self) -> Result, ModelError> { let prev_text = if self.tokens.is_empty() { String::new() } else { @@ -66,7 +66,7 @@ impl TokenOutputStream { } } - pub fn decode_all(&self) -> Result { + pub fn decode_all(&self) -> Result { self.decode(&self.tokens) } diff --git a/atoma-inference/src/models/types.rs b/atoma-inference/src/models/types.rs new file mode 100644 index 00000000..f3d5fa05 --- /dev/null +++ b/atoma-inference/src/models/types.rs @@ -0,0 +1,217 @@ +use candle::DType; +use ed25519_consensus::VerificationKey as PublicKey; +use serde::{Deserialize, Serialize}; + +use crate::models::{ModelId, Request, Response}; + +pub type NodeId = PublicKey; + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct TextRequest { + pub request_id: usize, + pub prompt: String, + pub model: ModelId, + pub max_tokens: usize, + pub random_seed: usize, + pub repeat_last_n: usize, + pub repeat_penalty: f32, + pub sampled_nodes: Vec, + pub temperature: Option, + pub top_k: usize, + pub top_p: Option, +} + +impl Request for TextRequest { + type ModelInput = TextModelInput; + + fn into_model_input(self) -> Self::ModelInput { + TextModelInput::new( + self.prompt, + self.temperature.unwrap_or_default() as f64, + self.random_seed as u64, + self.repeat_penalty, + self.repeat_last_n, + self.max_tokens, + self.top_k, + self.top_p.unwrap_or_default() as f64, + ) + } + + fn request_id(&self) -> usize { + self.request_id + } + + fn is_node_authorized(&self, public_key: &PublicKey) -> bool { + self.sampled_nodes.contains(public_key) + } + + fn requested_model(&self) -> ModelId { + self.model.clone() + } +} + +pub struct TextModelInput { + pub(crate) prompt: String, + pub(crate) temperature: f64, + pub(crate) random_seed: u64, + pub(crate) repeat_penalty: f32, + pub(crate) repeat_last_n: usize, + pub(crate) max_tokens: usize, + pub(crate) _top_k: usize, + pub(crate) top_p: f64, +} + +impl TextModelInput { + #[allow(clippy::too_many_arguments)] + pub fn new( + prompt: String, + temperature: f64, + random_seed: u64, + repeat_penalty: f32, + repeat_last_n: usize, + max_tokens: usize, + _top_k: usize, + top_p: f64, + ) -> Self { + Self { + prompt, + temperature, + random_seed, + repeat_penalty, + repeat_last_n, + max_tokens, + _top_k, + top_p, + } + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct TextResponse { + pub output: String, + pub is_success: bool, + pub status: String, +} + +impl Response for TextResponse { + type ModelOutput = String; + + fn from_model_output(model_output: Self::ModelOutput) -> Self { + Self { + output: model_output, + is_success: true, + status: "Successful".to_string(), + } + } +} + +pub struct ImageRequest { + pub request_id: usize, + pub prompt: String, + pub uncond_prompt: String, + pub height: Option, + 
pub width: Option, + pub model: ModelId, + pub random_seed: usize, + pub sampled_nodes: Vec, + pub temperature: Option, + pub num_samples: i64, + /// The strength, indicates how much to transform the initial image. The + /// value must be between 0 and 1, a value of 1 discards the initial image + /// information. + pub img2img_strength: f64, + + + /// The size of the sliced attention or 0 for automatic slicing (disabled by default) + sliced_attention_size: Option, + + /// The number of steps to run the diffusion for. + n_steps: Option, + + /// The number of samples to generate. + num_samples: i64, + + sd_version: StableDiffusionVersion, + + use_flash_attn: bool, + + use_f16: bool, + + guidance_scale: Option, + + img2img: Option, + +} + +impl Request for ImageRequest { + type ModelInput = TextModelInput; + + fn into_model_input(self) -> Self::ModelInput { + Self::ModelInput::new( + self.prompt, + self.temperature.unwrap_or_default() as f64, + self.random_seed as u64, + self.repeat_penalty, + self.repeat_last_n, + self.max_tokens, + self.top_k, + self.top_p.unwrap_or_default() as f64, + ) + } + + fn is_node_authorized(&self, public_key: &PublicKey) -> bool { + self.sampled_nodes.contains(public_key) + } + + fn request_id(&self) -> usize { + self.request_id + } + + fn requested_model(&self) -> ModelId { + self.model.clone() + } +} + +pub struct ImageResponse { + pub output: Vec, + pub is_success: bool, + pub status: String, +} + +impl Response for ImageResponse { + type ModelOutput = Vec; + + fn from_model_output(model_output: Self::ModelOutput) -> Self { + Self { + output: model_output, + is_success: true, + status: "Successful".to_string(), + } + } +} + +#[derive(Copy, Clone, Debug, Deserialize, Serialize)] +pub enum PrecisionBits { + BF16, + F16, + F32, + F64, + I64, + U8, + U32, +} + +impl PrecisionBits { + #[allow(dead_code)] + pub(crate) fn into_dtype(self) -> DType { + match self { + Self::BF16 => DType::BF16, + Self::F16 => DType::F16, + Self::F32 => DType::F32, + Self::F64 => DType::F64, + Self::I64 => DType::I64, + Self::U8 => DType::U8, + Self::U32 => DType::U32, + } + } +} diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index a01bd5f4..1d13024f 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -1,8 +1,8 @@ use candle::Error as CandleError; -use ed25519_consensus::SigningKey as PrivateKey; +use ed25519_consensus::{SigningKey as PrivateKey, VerificationKey as PublicKey}; use futures::StreamExt; use std::{io, path::PathBuf, time::Instant}; -use tokio::sync::mpsc::Receiver; +use tokio::sync::mpsc::{Receiver, Sender}; use tracing::{error, info}; use thiserror::Error; @@ -21,7 +21,11 @@ where model_thread_handle: Vec>, dispatcher: ModelThreadDispatcher, start_time: Instant, + flush_storage: bool, + public_key: PublicKey, + storage_path: PathBuf, request_receiver: Receiver, + response_sender: Sender, } impl ModelService @@ -30,23 +34,19 @@ where Resp: std::fmt::Debug + Response, { pub fn start( - config_file_path: PathBuf, - private_key_path: PathBuf, + model_config: ModelConfig, + private_key: PrivateKey, request_receiver: Receiver, + response_sender: Sender, ) -> Result where M: ModelTrait + Send + 'static, - F: ApiTrait, + F: ApiTrait + Send + Sync + 'static, { - let private_key_bytes = - std::fs::read(private_key_path).map_err(ModelServiceError::PrivateKeyError)?; - let private_key_bytes: [u8; 32] = private_key_bytes - .try_into() - .expect("Incorrect private key bytes length"); - - let private_key = 
PrivateKey::from(private_key_bytes); let public_key = private_key.verification_key(); - let model_config = ModelConfig::from_file_path(config_file_path); + + let flush_storage = model_config.flush_storage(); + let storage_path = model_config.storage_path(); let (dispatcher, model_thread_handle) = ModelThreadDispatcher::start::(model_config, public_key) @@ -57,7 +57,11 @@ where dispatcher, model_thread_handle, start_time, + flush_storage, + storage_path, + public_key, request_receiver, + response_sender, }) } @@ -74,6 +78,7 @@ where match resp { Ok(response) => { info!("Received a new inference response: {:?}", response); + self.response_sender.send(response).await.map_err(|e| ModelServiceError::SendError(e.to_string()))?; } Err(e) => { error!("Found error in generating inference response: {e}"); @@ -84,6 +89,10 @@ where } } } + + pub fn public_key(&self) -> PublicKey { + self.public_key + } } impl ModelService @@ -97,6 +106,13 @@ where self.start_time.elapsed() ); + if self.flush_storage { + match std::fs::remove_dir(self.storage_path) { + Ok(()) => {} + Err(e) => error!("Failed to remove storage folder, on shutdown: {e}"), + }; + } + let _ = self .model_thread_handle .drain(..) @@ -108,7 +124,7 @@ where #[derive(Debug, Error)] pub enum ModelServiceError { #[error("Failed to run inference: `{0}`")] - FailedInference(Box), + FailedInference(Box), #[error("Failed to fecth model: `{0}`")] FailedModelFetch(String), #[error("Failed to generate private key: `{0}`")] @@ -119,6 +135,8 @@ pub enum ModelServiceError { ApiError(ApiError), #[error("Candle error: `{0}`")] CandleError(CandleError), + #[error("Sender error: `{0}`")] + SendError(String), } impl From for ModelServiceError { @@ -140,7 +158,7 @@ mod tests { use std::io::Write; use toml::{toml, Value}; - use crate::models::ModelId; + use crate::{models::types::PrecisionBits, models::ModelId}; use super::*; @@ -154,7 +172,7 @@ mod tests { Ok(Self {}) } - fn fetch(&self, _: &ModelId) -> Result, ApiError> { + fn fetch(&self, _: ModelId, _: String) -> Result, ApiError> { Ok(vec![]) } } @@ -162,9 +180,7 @@ mod tests { impl Request for () { type ModelInput = (); - fn into_model_input(self) -> Self::ModelInput { - () - } + fn into_model_input(self) -> Self::ModelInput {} fn is_node_authorized(&self, _: &PublicKey) -> bool { true @@ -182,9 +198,7 @@ mod tests { impl Response for () { type ModelOutput = (); - fn from_model_output(_: Self::ModelOutput) -> Self { - () - } + fn from_model_output(_: Self::ModelOutput) -> Self {} } #[derive(Clone)] @@ -194,7 +208,7 @@ mod tests { type Input = (); type Output = (); - fn load(_: Vec) -> Result { + fn load(_: Vec, _: PrecisionBits) -> Result { Ok(Self {}) } @@ -202,7 +216,7 @@ mod tests { String::from("") } - fn run(&self, _: Self::Input) -> Result { + fn run(&mut self, _: Self::Input) -> Result { Ok(()) } } @@ -210,16 +224,15 @@ mod tests { #[tokio::test] async fn test_inference_service_initialization() { const CONFIG_FILE_PATH: &str = "./inference.toml"; - const PRIVATE_KEY_FILE_PATH: &str = "./private_key"; - let private_key = PrivateKey::new(&mut OsRng); - std::fs::write(PRIVATE_KEY_FILE_PATH, private_key.to_bytes()).unwrap(); + let private_key = PrivateKey::new(OsRng); let config_data = Value::Table(toml! 
{ api_key = "your_api_key" - models = ["Mamba3b"] + models = [["Mamba3b", "F16", "", ""]] storage_path = "./storage_path/" tokenizer_file_path = "./tokenizer_file_path/" + flush_storage = true tracing = true }); let toml_string = @@ -229,16 +242,19 @@ mod tests { file.write_all(toml_string.as_bytes()) .expect("Failed to write to file"); - let (_, receiver) = tokio::sync::mpsc::channel::<()>(1); + let (_, req_receiver) = tokio::sync::mpsc::channel::<()>(1); + let (resp_sender, _) = tokio::sync::mpsc::channel::<()>(1); + + let config = ModelConfig::from_file_path(CONFIG_FILE_PATH.parse().unwrap()); let _ = ModelService::<(), ()>::start::( - PathBuf::try_from(CONFIG_FILE_PATH).unwrap(), - PathBuf::try_from(PRIVATE_KEY_FILE_PATH).unwrap(), - receiver, + config, + private_key, + req_receiver, + resp_sender, ) .unwrap(); std::fs::remove_file(CONFIG_FILE_PATH).unwrap(); - std::fs::remove_file(PRIVATE_KEY_FILE_PATH).unwrap(); } } diff --git a/atoma-inference/src/types.rs b/atoma-inference/src/types.rs deleted file mode 100644 index 82643c19..00000000 --- a/atoma-inference/src/types.rs +++ /dev/null @@ -1,65 +0,0 @@ -use candle::DType; -use ed25519_consensus::VerificationKey; -use serde::{Deserialize, Serialize}; - -use crate::models::ModelId; - -pub type NodeId = VerificationKey; -pub type Temperature = f32; - -#[derive(Clone, Debug)] -pub struct InferenceRequest { - pub request_id: u128, - pub prompt: String, - pub model: ModelId, - pub max_tokens: usize, - pub random_seed: usize, - pub repeat_last_n: usize, - pub repeat_penalty: f32, - pub sampled_nodes: Vec, - pub temperature: Option, - pub top_k: usize, - pub top_p: Option, -} - -#[derive(Clone, Debug)] -#[allow(dead_code)] -pub struct InferenceResponse { - // TODO: possibly a Merkle root hash - // pub(crate) response_hash: [u8; 32], - // pub(crate) node_id: NodeId, - // pub(crate) signature: Vec, - pub(crate) response: String, -} - -#[derive(Clone, Debug)] -pub enum QuantizationMethod { - Ggml(PrecisionBits), - Gptq(PrecisionBits), -} - -#[derive(Copy, Clone, Debug, Deserialize, Serialize)] -pub enum PrecisionBits { - BF16, - F16, - F32, - F64, - I64, - U8, - U32, -} - -impl PrecisionBits { - #[allow(dead_code)] - pub(crate) fn into_dtype(self) -> DType { - match self { - Self::BF16 => DType::BF16, - Self::F16 => DType::F16, - Self::F32 => DType::F32, - Self::F64 => DType::F64, - Self::I64 => DType::I64, - Self::U8 => DType::U8, - Self::U32 => DType::U32, - } - } -}
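
Note on the new environment-based configuration introduced in models/config.rs: the sketch below shows one way ModelConfig::from_env_file could be exercised, based only on the variable names it reads (API_KEY, FLUSH_STORAGE, MODELS, STORAGE_PATH, TRACING) and the (ModelId, PrecisionBits, Revision) tuple format used elsewhere in this change. The concrete model entry, revision string, token placeholder, and paths are illustrative assumptions, not values taken from this diff.

    use inference::models::config::ModelConfig;

    fn main() {
        // Mirror the variables read by `ModelConfig::from_env_file`; in practice these
        // would live in a `.env` file picked up by `dotenv()`. All values below are
        // illustrative assumptions.
        std::env::set_var("API_KEY", "<your-huggingface-token>");
        std::env::set_var("FLUSH_STORAGE", "true");
        // MODELS is JSON that deserializes into Vec<(ModelId, PrecisionBits, Revision)>.
        std::env::set_var("MODELS", r#"[["state-spaces/mamba-130m", "F32", "main"]]"#);
        std::env::set_var("STORAGE_PATH", "./storage_path/");
        std::env::set_var("TRACING", "true");

        let config = ModelConfig::from_env_file();
        assert!(config.flush_storage());
        assert_eq!(config.model_ids().len(), 1);
    }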