From 343dfe548da810b512ca340315fa7726b6347a57 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Fri, 5 Apr 2024 10:19:36 +0100 Subject: [PATCH 01/21] first commit --- atoma-inference/src/models/candle/falcon.rs | 60 +++++--- atoma-inference/src/models/candle/llama.rs | 158 ++++++++------------ atoma-inference/src/models/candle/mamba.rs | 58 ++++--- atoma-inference/src/models/candle/mod.rs | 2 +- atoma-inference/src/models/config.rs | 10 +- atoma-inference/src/models/mod.rs | 14 +- atoma-inference/src/models/types.rs | 21 ++- atoma-inference/src/service.rs | 8 +- 8 files changed, 180 insertions(+), 151 deletions(-) diff --git a/atoma-inference/src/models/candle/falcon.rs b/atoma-inference/src/models/candle/falcon.rs index 5ec797e1..f1a14847 100644 --- a/atoma-inference/src/models/candle/falcon.rs +++ b/atoma-inference/src/models/candle/falcon.rs @@ -10,13 +10,11 @@ use candle_transformers::{ models::falcon::{Config, Falcon}, utils::apply_repeat_penalty, }; +use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; use tokenizers::Tokenizer; use tracing::{debug, info}; -use crate::{ - models::types::{PrecisionBits, TextModelInput}, - models::{ModelError, ModelId, ModelTrait}, -}; +use crate::models::{types::{LlmFetchData, LlmLoadData, TextModelInput}, ModelError, ModelId, ModelTrait}; pub struct FalconModel { model: Falcon, @@ -46,19 +44,40 @@ impl FalconModel { } impl ModelTrait for FalconModel { - type Fetch = (); + type FetchData = LlmFetchData; type Input = TextModelInput; type Output = String; - type Load = PrecisionBits; - - fn fetch(_fetch: &Self::Fetch) -> Result<(), ModelError> { - Ok(()) + type LoadData = LlmLoadData; + + fn fetch(fetch_data: &Self::FetchData) -> Result, ModelError> { + let api_key = fetch_data.api_key; + let cache_dir = fetch_data.cache_dir; + + let api = ApiBuilder::new() + .with_progress(true) + .with_token(Some(api_key)) + .with_cache_dir(cache_dir) + .build()?; + + let repo = api.repo(Repo::with_revision( + fetch_data.model_id.clone(), + RepoType::Model, + fetch_data.revision, + )); + + let config_file_path = repo.get("config.json")?; + let tokenizer_file_path = repo.get("tokenizer.json")?; + let model_weights_file_path = repo.get("model.safetensors")?; + + Ok(vec![ + config_file_path, + tokenizer_file_path, + model_weights_file_path, + ]) } fn load( - filenames: Vec, - precision: Self::Load, - device_id: usize, + load_data: Self::LoadData ) -> Result where Self: Sized, @@ -67,9 +86,9 @@ impl ModelTrait for FalconModel { let start = Instant::now(); - let config_filename = filenames[0].clone(); - let tokenizer_filename = filenames[1].clone(); - let weights_filenames = filenames[2..].to_vec(); + let config_filename = load_data.file_paths[0].clone(); + let tokenizer_filename = load_data.file_paths[1].clone(); + let weights_filenames = load_data.file_paths[2..].to_vec(); let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(ModelError::BoxedError)?; @@ -79,24 +98,23 @@ impl ModelTrait for FalconModel { config.validate()?; let device = if cuda_is_available() { - Device::new_cuda(device_id).map_err(ModelError::CandleError)? + Device::new_cuda(load_data.device_id).map_err(ModelError::CandleError)? } else if metal_is_available() { - Device::new_metal(device_id).map_err(ModelError::CandleError)? + Device::new_metal(load_data.device_id).map_err(ModelError::CandleError)? 
} else { Device::Cpu }; - let dtype = precision.into_dtype(); - if dtype != DType::BF16 || dtype != DType::F32 { + if load_data.dtype != DType::BF16 || load_data.dtype != DType::F32 { panic!("Invalid dtype, it must be either BF16 or F32 precision"); } let vb = - unsafe { VarBuilder::from_mmaped_safetensors(&weights_filenames, dtype, &device)? }; + unsafe { VarBuilder::from_mmaped_safetensors(&weights_filenames, load_data.dtype, &device)? }; let model = Falcon::load(vb, config.clone())?; info!("loaded the model in {:?}", start.elapsed()); - Ok(Self::new(model, config, device, dtype, tokenizer)) + Ok(Self::new(model, config, device, load_data.dtype, tokenizer)) } fn model_id(&self) -> ModelId { diff --git a/atoma-inference/src/models/candle/llama.rs b/atoma-inference/src/models/candle/llama.rs index 3b6b4079..9f0acdea 100644 --- a/atoma-inference/src/models/candle/llama.rs +++ b/atoma-inference/src/models/candle/llama.rs @@ -4,22 +4,21 @@ extern crate accelerate_src; #[cfg(feature = "mkl")] extern crate intel_mkl_src; -use std::path::PathBuf; - -use candle::{Device, Tensor}; +use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; use candle_transformers::{ generation::LogitsProcessor, models::llama::{Cache, LlamaConfig}, }; -use hf_hub::{api::sync::Api, Repo, RepoType}; +use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; use candle_transformers::models::llama as model; -use serde::Deserialize; use tokenizers::Tokenizer; use crate::models::{ - token_output_stream::TokenOutputStream, types::PrecisionBits, ModelError, ModelTrait, + token_output_stream::TokenOutputStream, + types::{LlmFetchData, LlmLoadData, TextModelInput}, + ModelError, ModelId, ModelTrait, }; use super::{device, hub_load_safetensors}; @@ -37,111 +36,78 @@ enum Which { pub struct Config {} -pub struct Llama { +pub struct LlamaModel { + cache: Cache, device: Device, + dtype: DType, + model: model::Llama, + model_id: ModelId, tokenizer: Tokenizer, - llama: model::Llama, - cache: Cache, -} - -#[derive(Deserialize)] -pub struct Input { - prompt: String, - temperature: Option, - top_p: Option, - seed: u64, - sample_len: usize, - repeat_penalty: f32, - repeat_last_n: usize, -} - -impl Input { - pub fn default_prompt(prompt: String) -> Self { - Self { - prompt, - temperature: None, - top_p: None, - seed: 0, - sample_len: 10000, - repeat_penalty: 1., - repeat_last_n: 64, - } - } } -pub struct Fetch { - model_id: Option, - revision: Option, - which: Which, -} +impl ModelTrait for LlamaModel { + type Input = TextModelInput; + type FetchData = LlmFetchData; + type Output = String; + type LoadData = LlmLoadData; + + fn fetch(fetch_data: &Self::FetchData) -> Result<(), ModelError> { + let api = ApiBuilder::new() + .with_progress(true) + .with_token(Some(fetch_data.api_key)) + .with_cache_dir(fetch_data.cache_dir) + .build()?; + + let api = api.repo(Repo::with_revision( + fetch_data.model_id, + RepoType::Model, + fetch_data.revision, + )); + let config_file_path = api.get("tokenizer.json")?; + let tokenizer_file_path = api.get("config.json")?; + + let model_weights_file_paths = + if &fetch_data.model_id == "TinyLlama/TinyLlama-1.1B-Chat-v1.0" { + vec![api.get("model.safetensors")?] + } else { + hub_load_safetensors(&api, "model.safetensors.index.json")? 
+ }; -impl Default for Fetch { - fn default() -> Self { - Self { - model_id: None, - revision: None, - which: Which::TinyLlama1_1BChat, - } - } -} + let mut output = Vec::with_capacity(2 + model_weights_file_paths.len()); + output.extend(vec![config_file_path, tokenizer_file_path]); + output.extend(model_weights_file_paths); -impl ModelTrait for Llama { - type Input = Input; - type Fetch = Fetch; - type Output = String; - type Load = PrecisionBits; - - fn fetch(fetch: &Self::Fetch) -> Result<(), ModelError> { - let api = Api::new()?; - let model_id = fetch.model_id.clone().unwrap_or_else(|| match fetch.which { - Which::V1 => "Narsil/amall-7b".to_string(), - Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(), - Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(), - Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(), - }); - let revision = fetch.revision.clone().unwrap_or("main".to_string()); - let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision)); - api.get("tokenizer.json")?; - api.get("config.json")?; - match fetch.which { - Which::V1 | Which::V2 | Which::Solar10_7B => { - hub_load_safetensors(&api, "model.safetensors.index.json")?; - } - Which::TinyLlama1_1BChat => { - api.get("model.safetensors")?; - } - }; - Ok(()) + Ok(output) } - fn model_id(&self) -> crate::models::ModelId { - "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string() + fn model_id(&self) -> ModelId { + self.model_id.clone() } - fn load( - filenames: Vec, - precision: PrecisionBits, - device_id: usize, - ) -> Result { - let device = device(device_id)?; - let dtype = precision.into_dtype(); - let (llama, tokenizer_filename, cache) = { - let tokenizer_filename = filenames[0].clone(); - let config_filename = filenames[1].clone(); + fn load(load_data: Self::LoadData) -> Result { + let device = device(load_data.device_id)?; + let dtype = load_data.dtype; + let (model, tokenizer_filename, cache) = { + let config_filename = load_data.file_paths[0].clone(); let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?; - let config = config.into_config(false); // TODO: use from config - let vb = - unsafe { VarBuilder::from_mmaped_safetensors(&filenames[2..], dtype, &device)? }; + let tokenizer_filename = load_data.file_paths[1].clone(); + let config = config.into_config(load_data.use_flash_attention); + + let vb = unsafe { + VarBuilder::from_mmaped_safetensors(&load_data.file_paths[2..], dtype, &device)? 
+ }; let cache = model::Cache::new(true, dtype, &config, &device)?; // TODO: use from config (model::Llama::load(vb, &config)?, tokenizer_filename, cache) }; let tokenizer = Tokenizer::from_file(tokenizer_filename)?; - Ok(Llama { + Ok(Self { + cache, device, + dtype, + model, + model_id: load_data.model_id, tokenizer, - llama, - cache, }) } @@ -154,7 +120,11 @@ impl ModelTrait for Llama { .to_vec(); let mut tokenizer = TokenOutputStream::new(self.tokenizer.clone()); - let mut logits_processor = LogitsProcessor::new(input.seed, input.temperature, input.top_p); + let mut logits_processor = LogitsProcessor::new( + input.random_seed, + Some(input.temperature), + Some(input.top_p), + ); let mut index_pos = 0; let mut res = String::new(); for index in 0..input.sample_len { diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs index d491ec58..bf0dd23e 100644 --- a/atoma-inference/src/models/candle/mamba.rs +++ b/atoma-inference/src/models/candle/mamba.rs @@ -7,7 +7,7 @@ use candle_transformers::{ models::mamba::{Config, Model, State}, utils::apply_repeat_penalty, }; -use serde::Deserialize; +use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; use tokenizers::Tokenizer; use tracing::info; @@ -16,7 +16,7 @@ use crate::{ models::{ candle::device, token_output_stream::TokenOutputStream, - types::{PrecisionBits, TextModelInput}, + types::{LlmFetchData, LlmLoadData, TextModelInput}, ModelError, ModelId, ModelTrait, }, }; @@ -50,22 +50,42 @@ impl MambaModel { } } -#[derive(Debug, Deserialize)] -pub struct Load { - pub precision: PrecisionBits, -} - impl ModelTrait for MambaModel { - type Fetch = (); + type FetchData = LlmFetchData; type Input = TextModelInput; type Output = String; - type Load = Load; + type LoadData = LlmLoadData; + + fn fetch(fetch_data: &Self::FetchData) -> Result, ModelError> { + let api_key = fetch_data.api_key; + let cache_dir = fetch_data.cache_dir; + + let api = ApiBuilder::new() + .with_progress(true) + .with_token(Some(api_key)) + .with_cache_dir(cache_dir) + .build()?; + + let repo = api.repo(Repo::with_revision( + fetch_data.model_id.clone(), + RepoType::Model, + fetch_data.revision, + )); + + let config_file_path = repo.get("config.json")?; + let tokenizer_file_path = api + .model("EleutherAI/gpt-neox-20b".to_string()) + .get("tokenizer.json")?; + let model_weights_file_path = repo.get("model.safetensors")?; + + Ok(vec![ + config_file_path, + tokenizer_file_path, + model_weights_file_path, + ]) + } - fn load( - filenames: Vec, - params: Self::Load, - device_id: usize, - ) -> Result + fn load(load_data: Self::LoadData) -> Result where Self: Sized, { @@ -73,17 +93,17 @@ impl ModelTrait for MambaModel { let start = Instant::now(); - let config_filename = filenames[0].clone(); - let tokenizer_filename = filenames[1].clone(); - let weights_filenames = filenames[2..].to_vec(); + let config_filename = load_data.file_paths[0].clone(); + let tokenizer_filename = load_data.file_paths[1].clone(); + let weights_filenames = load_data.file_paths[2..].to_vec(); let tokenizer = Tokenizer::from_file(tokenizer_filename)?; let config: Config = serde_json::from_slice(&std::fs::read(config_filename).map_err(ModelError::IoError)?) 
.map_err(ModelError::DeserializeError)?; - let device = device(device_id)?; - let dtype = params.precision.into_dtype(); + let device = device(load_data.device_id)?; + let dtype = load_data.dtype; info!("Loading model weights.."); let var_builder = diff --git a/atoma-inference/src/models/candle/mod.rs b/atoma-inference/src/models/candle/mod.rs index 07bddc9c..6f5d3b11 100644 --- a/atoma-inference/src/models/candle/mod.rs +++ b/atoma-inference/src/models/candle/mod.rs @@ -31,7 +31,7 @@ pub fn device(device_id: usize) -> Result { pub fn hub_load_safetensors( repo: &hf_hub::api::sync::ApiRepo, json_file: &str, -) -> Result, ModelError> { +) -> Result, ModelError> { let json_file = repo.get(json_file)?; let json_file = std::fs::File::open(json_file)?; let json: serde_json::Value = serde_json::from_reader(&json_file)?; diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs index 22c404c1..f2ecabc6 100644 --- a/atoma-inference/src/models/config.rs +++ b/atoma-inference/src/models/config.rs @@ -10,10 +10,11 @@ type Revision = String; #[derive(Clone, Debug, Deserialize, Serialize)] pub struct ModelConfig { + device_id: usize, model_id: ModelId, params: serde_json::Value, revision: Revision, - device_id: usize, + use_flash_attention: bool, } impl ModelConfig { @@ -22,12 +23,14 @@ impl ModelConfig { params: serde_json::Value, revision: Revision, device_id: usize, + use_flash_attention: bool, ) -> Self { Self { model_id, params, revision, device_id, + use_flash_attention } } @@ -46,6 +49,10 @@ impl ModelConfig { pub fn device_id(&self) -> usize { self.device_id } + + pub fn use_flash_attention(&self) -> bool { + self.use_flash_attention + } } #[derive(Debug, Deserialize, Serialize)] @@ -153,6 +160,7 @@ pub mod tests { serde_json::to_value(PrecisionBits::F16).unwrap(), "".to_string(), 0, + true )], "storage_path".parse().unwrap(), true, diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs index 5ccccc91..4e7bf7fe 100644 --- a/atoma-inference/src/models/mod.rs +++ b/atoma-inference/src/models/mod.rs @@ -13,19 +13,13 @@ pub mod types; pub type ModelId = String; pub trait ModelTrait { - type Fetch; + type FetchData; type Input: DeserializeOwned; type Output: Serialize; - type Load: DeserializeOwned; + type LoadData: DeserializeOwned; - fn fetch(_fetch: &Self::Fetch) -> Result<(), ModelError> { - Ok(()) - } - fn load( - filenames: Vec, - params: Self::Load, - device_id: usize, - ) -> Result + fn fetch(fetch_data: &Self::FetchData) -> Result, ModelError>; + fn load(load_data: Self::LoadData) -> Result where Self: Sized; fn model_id(&self) -> ModelId; diff --git a/atoma-inference/src/models/types.rs b/atoma-inference/src/models/types.rs index 1a7f00ee..43a07a7e 100644 --- a/atoma-inference/src/models/types.rs +++ b/atoma-inference/src/models/types.rs @@ -1,4 +1,6 @@ -use candle::DType; +use std::path::PathBuf; + +use candle::{DType, Device}; use ed25519_consensus::VerificationKey as PublicKey; use serde::{Deserialize, Serialize}; @@ -6,6 +8,23 @@ use crate::models::{ModelId, Request, Response}; pub type NodeId = PublicKey; +#[derive(Debug, Deserialize)] +pub struct LlmFetchData { + pub api_key: String, + pub cache_dir: PathBuf, + pub model_id: ModelId, + pub revision: String, +} + +#[derive(Debug, Deserialize)] +pub struct LlmLoadData { + pub device_id: Device, + pub dtype: DType, + pub file_paths: Vec, + pub model_id: ModelId, + pub use_flash_attention: bool, +} + #[derive(Clone, Debug, Deserialize, Serialize)] pub struct TextRequest { pub 
request_id: usize, diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 4360538e..a7f97714 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -195,16 +195,16 @@ mod tests { impl ModelTrait for TestModelInstance { type Input = (); type Output = (); - type Fetch = (); - type Load = (); + type FetchData = (); + type LoadData = (); - fn fetch(_fetch: &Self::Fetch) -> Result<(), crate::models::ModelError> { + fn fetch(_fetch: &Self::FetchData) -> Result<(), crate::models::ModelError> { Ok(()) } fn load( _: Vec, - _: Self::Load, + _: Self::LoadData, _device_id: usize, ) -> Result { Ok(Self {}) From d8c0da2412acd2417702e7498183e936d672d01d Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Fri, 5 Apr 2024 18:13:39 +0100 Subject: [PATCH 02/21] refactor models, re-working on stable diffusion --- atoma-inference/src/model_thread.rs | 15 +- atoma-inference/src/models/candle/falcon.rs | 87 ++++++----- atoma-inference/src/models/candle/llama.rs | 62 +++++--- atoma-inference/src/models/candle/mamba.rs | 58 ++++--- .../src/models/candle/stable_diffusion.rs | 114 ++++++-------- atoma-inference/src/models/config.rs | 75 +++++---- atoma-inference/src/models/mod.rs | 20 +-- atoma-inference/src/models/types.rs | 146 ++++++++++++++++-- atoma-inference/src/service.rs | 15 +- 9 files changed, 361 insertions(+), 231 deletions(-) diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 68a3996e..2b5c318e 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -1,7 +1,4 @@ -use std::{ - collections::HashMap, - sync::{mpsc, Arc}, -}; +use std::{collections::HashMap, sync::mpsc}; use ed25519_consensus::VerificationKey as PublicKey; use futures::stream::FuturesUnordered; @@ -106,16 +103,11 @@ impl ModelThreadDispatcher { F: ApiTrait + Send + Sync + 'static, M: ModelTrait, // + Send + 'static, { - let api_key = config.api_key(); - let storage_path = config.storage_path(); - let api = Arc::new(F::create(api_key, storage_path)?); - let mut handles = Vec::new(); let mut model_senders = HashMap::new(); for model_config in config.models() { info!("Spawning new thread for model: {}", model_config.model_id()); - let api = api.clone(); let (model_sender, model_receiver) = mpsc::channel::(); let model_name = model_config.model_id().clone(); @@ -123,10 +115,9 @@ impl ModelThreadDispatcher { let join_handle = std::thread::spawn(move || { info!("Fetching files for model: {model_name}"); - let filenames = api.fetch(model_name, model_config.revision())?; - let x = serde_json::from_value(model_config.params().clone()).unwrap(); + let load_data = M::fetch(model_config)?; - let model = M::load(filenames, x, model_config.device_id())?; + let model = M::load(load_data)?; let model_thread = ModelThread { model, receiver: model_receiver, diff --git a/atoma-inference/src/models/candle/falcon.rs b/atoma-inference/src/models/candle/falcon.rs index f1a14847..f09086a1 100644 --- a/atoma-inference/src/models/candle/falcon.rs +++ b/atoma-inference/src/models/candle/falcon.rs @@ -1,9 +1,6 @@ -use std::{path::PathBuf, time::Instant}; +use std::{str::FromStr, time::Instant}; -use candle::{ - utils::{cuda_is_available, metal_is_available}, - DType, Device, Tensor, -}; +use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; use candle_transformers::{ generation::LogitsProcessor, @@ -14,14 +11,20 @@ use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; use tokenizers::Tokenizer; use tracing::{debug, info}; 
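// `models::candle::device` (imported as `super::device` just below) is not reproduced in
// full in this patch series. A minimal sketch of that device-selection helper, assuming the
// same CUDA -> Metal -> CPU fallback the pre-refactor falcon code performed inline; the name
// `device_sketch` is illustrative only.
use candle::utils::{cuda_is_available, metal_is_available};
use candle::Device;

fn device_sketch(device_id: usize) -> Result<Device, candle::Error> {
    if cuda_is_available() {
        // Prefer a CUDA device when the binary was built with CUDA support and a GPU is present.
        Device::new_cuda(device_id)
    } else if metal_is_available() {
        // Fall back to Metal on Apple hardware.
        Device::new_metal(device_id)
    } else {
        // Otherwise run on the CPU.
        Ok(Device::Cpu)
    }
}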
-use crate::models::{types::{LlmFetchData, LlmLoadData, TextModelInput}, ModelError, ModelId, ModelTrait}; +use crate::models::{ + config::ModelConfig, + types::{LlmLoadData, ModelType, TextModelInput}, + ModelError, ModelTrait, +}; + +use super::device; pub struct FalconModel { model: Falcon, device: Device, dtype: DType, + model_type: ModelType, tokenizer: Tokenizer, - which: Which, } impl FalconModel { @@ -30,55 +33,61 @@ impl FalconModel { config: Config, device: Device, dtype: DType, + model_type: ModelType, tokenizer: Tokenizer, ) -> Self { - let which = Which::from_config(&config); Self { model, device, dtype, tokenizer, - which, + model_type, } } } impl ModelTrait for FalconModel { - type FetchData = LlmFetchData; type Input = TextModelInput; type Output = String; type LoadData = LlmLoadData; - fn fetch(fetch_data: &Self::FetchData) -> Result, ModelError> { - let api_key = fetch_data.api_key; - let cache_dir = fetch_data.cache_dir; + fn fetch(config: ModelConfig) -> Result { + let device = device(config.device_id())?; + let dtype = DType::from_str(&config.dtype())?; + + let api_key = config.api_key(); + let cache_dir = config.cache_dir(); let api = ApiBuilder::new() .with_progress(true) .with_token(Some(api_key)) - .with_cache_dir(cache_dir) + .with_cache_dir(cache_dir.into()) .build()?; let repo = api.repo(Repo::with_revision( - fetch_data.model_id.clone(), + config.model_id(), RepoType::Model, - fetch_data.revision, + config.revision(), )); let config_file_path = repo.get("config.json")?; let tokenizer_file_path = repo.get("tokenizer.json")?; let model_weights_file_path = repo.get("model.safetensors")?; - Ok(vec![ - config_file_path, - tokenizer_file_path, - model_weights_file_path, - ]) + Ok(Self::LoadData { + device, + dtype, + file_paths: vec![ + config_file_path, + tokenizer_file_path, + model_weights_file_path, + ], + model_type: ModelType::from_str(&config.model_id())?, + use_flash_attention: config.use_flash_attention(), + }) } - fn load( - load_data: Self::LoadData - ) -> Result + fn load(load_data: Self::LoadData) -> Result where Self: Sized, { @@ -97,28 +106,32 @@ impl ModelTrait for FalconModel { .map_err(ModelError::DeserializeError)?; config.validate()?; - let device = if cuda_is_available() { - Device::new_cuda(load_data.device_id).map_err(ModelError::CandleError)? - } else if metal_is_available() { - Device::new_metal(load_data.device_id).map_err(ModelError::CandleError)? - } else { - Device::Cpu - }; - if load_data.dtype != DType::BF16 || load_data.dtype != DType::F32 { panic!("Invalid dtype, it must be either BF16 or F32 precision"); } - let vb = - unsafe { VarBuilder::from_mmaped_safetensors(&weights_filenames, load_data.dtype, &device)? }; + let vb = unsafe { + VarBuilder::from_mmaped_safetensors( + &weights_filenames, + load_data.dtype, + &load_data.device, + )? 
+ }; let model = Falcon::load(vb, config.clone())?; info!("loaded the model in {:?}", start.elapsed()); - Ok(Self::new(model, config, device, load_data.dtype, tokenizer)) + Ok(Self::new( + model, + config, + load_data.device, + load_data.dtype, + load_data.model_type, + tokenizer, + )) } - fn model_id(&self) -> ModelId { - self.which.model_id().to_string() + fn model_type(&self) -> ModelType { + self.model_type.clone() } fn run(&mut self, input: Self::Input) -> Result { diff --git a/atoma-inference/src/models/candle/llama.rs b/atoma-inference/src/models/candle/llama.rs index 9f0acdea..dc5d17be 100644 --- a/atoma-inference/src/models/candle/llama.rs +++ b/atoma-inference/src/models/candle/llama.rs @@ -4,6 +4,8 @@ extern crate accelerate_src; #[cfg(feature = "mkl")] extern crate intel_mkl_src; +use std::str::FromStr; + use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; use candle_transformers::{ @@ -16,9 +18,10 @@ use candle_transformers::models::llama as model; use tokenizers::Tokenizer; use crate::models::{ + config::ModelConfig, token_output_stream::TokenOutputStream, - types::{LlmFetchData, LlmLoadData, TextModelInput}, - ModelError, ModelId, ModelTrait, + types::{LlmLoadData, ModelType, TextModelInput}, + ModelError, ModelTrait, }; use super::{device, hub_load_safetensors}; @@ -41,52 +44,61 @@ pub struct LlamaModel { device: Device, dtype: DType, model: model::Llama, - model_id: ModelId, + model_type: ModelType, tokenizer: Tokenizer, } impl ModelTrait for LlamaModel { type Input = TextModelInput; - type FetchData = LlmFetchData; type Output = String; type LoadData = LlmLoadData; - fn fetch(fetch_data: &Self::FetchData) -> Result<(), ModelError> { + fn fetch(config: ModelConfig) -> Result { + let device = device(config.device_id())?; + let dtype = DType::from_str(&config.dtype())?; + let api = ApiBuilder::new() .with_progress(true) - .with_token(Some(fetch_data.api_key)) - .with_cache_dir(fetch_data.cache_dir) + .with_token(Some(config.api_key())) + .with_cache_dir(config.cache_dir()) .build()?; let api = api.repo(Repo::with_revision( - fetch_data.model_id, + config.model_id(), RepoType::Model, - fetch_data.revision, + config.revision(), )); let config_file_path = api.get("tokenizer.json")?; let tokenizer_file_path = api.get("config.json")?; - let model_weights_file_paths = - if &fetch_data.model_id == "TinyLlama/TinyLlama-1.1B-Chat-v1.0" { - vec![api.get("model.safetensors")?] - } else { - hub_load_safetensors(&api, "model.safetensors.index.json")? - }; + let model_weights_file_paths = if &config.model_id() == "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + { + vec![api.get("model.safetensors")?] + } else { + hub_load_safetensors(&api, "model.safetensors.index.json")? 
+ }; - let mut output = Vec::with_capacity(2 + model_weights_file_paths.len()); - output.extend(vec![config_file_path, tokenizer_file_path]); - output.extend(model_weights_file_paths); + let mut file_paths = Vec::with_capacity(2 + model_weights_file_paths.len()); + file_paths.extend(vec![config_file_path, tokenizer_file_path]); + file_paths.extend(model_weights_file_paths); - Ok(output) + Ok(Self::LoadData { + device, + dtype, + file_paths, + model_type: ModelType::from_str(&config.model_id())?, + use_flash_attention: config.use_flash_attention(), + }) } - fn model_id(&self) -> ModelId { - self.model_id.clone() + fn model_type(&self) -> ModelType { + self.model_type.clone() } fn load(load_data: Self::LoadData) -> Result { let device = device(load_data.device_id)?; - let dtype = load_data.dtype; + let dtype = + DType::from_str(&load_data.dtype).map_err(|e| ModelError::Msg(e.to_string()))?; let (model, tokenizer_filename, cache) = { let config_filename = load_data.file_paths[0].clone(); let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?; @@ -106,7 +118,7 @@ impl ModelTrait for LlamaModel { device, dtype, model, - model_id: load_data.model_id, + model_type: load_data.model_type, tokenizer, }) } @@ -127,7 +139,7 @@ impl ModelTrait for LlamaModel { ); let mut index_pos = 0; let mut res = String::new(); - for index in 0..input.sample_len { + for index in 0..input.max_tokens { let (context_size, context_index) = if self.cache.use_kv_cache && index > 0 { (1, index_pos) } else { @@ -136,7 +148,7 @@ impl ModelTrait for LlamaModel { let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; let input_tensor = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; let logits = self - .llama + .model .forward(&input_tensor, context_index, &mut self.cache)?; let logits = logits.squeeze(0)?; let logits = if input.repeat_penalty == 1. 
{ diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs index bf0dd23e..f499e1e4 100644 --- a/atoma-inference/src/models/candle/mamba.rs +++ b/atoma-inference/src/models/candle/mamba.rs @@ -1,4 +1,4 @@ -use std::{path::PathBuf, time::Instant}; +use std::{str::FromStr, time::Instant}; use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; @@ -15,9 +15,10 @@ use crate::{ bail, models::{ candle::device, + config::ModelConfig, token_output_stream::TokenOutputStream, - types::{LlmFetchData, LlmLoadData, TextModelInput}, - ModelError, ModelId, ModelTrait, + types::{LlmLoadData, ModelType, TextModelInput}, + ModelError, ModelTrait, }, }; @@ -51,25 +52,27 @@ impl MambaModel { } impl ModelTrait for MambaModel { - type FetchData = LlmFetchData; type Input = TextModelInput; type Output = String; type LoadData = LlmLoadData; - fn fetch(fetch_data: &Self::FetchData) -> Result, ModelError> { - let api_key = fetch_data.api_key; - let cache_dir = fetch_data.cache_dir; + fn fetch(config: ModelConfig) -> Result { + let api_key = config.api_key(); + let cache_dir = config.cache_dir(); + + let device = device(config.device_id())?; + let dtype = DType::from_str(&config.dtype())?; let api = ApiBuilder::new() .with_progress(true) .with_token(Some(api_key)) - .with_cache_dir(cache_dir) + .with_cache_dir(config.cache_dir().into()) .build()?; let repo = api.repo(Repo::with_revision( - fetch_data.model_id.clone(), + config.model_id(), RepoType::Model, - fetch_data.revision, + config.revision(), )); let config_file_path = repo.get("config.json")?; @@ -78,11 +81,17 @@ impl ModelTrait for MambaModel { .get("tokenizer.json")?; let model_weights_file_path = repo.get("model.safetensors")?; - Ok(vec![ - config_file_path, - tokenizer_file_path, - model_weights_file_path, - ]) + Ok(Self::LoadData { + device, + dtype, + file_paths: vec![ + config_file_path, + tokenizer_file_path, + model_weights_file_path, + ], + model_type: ModelType::from_str(&config.model_id())?, + use_flash_attention: config.use_flash_attention(), + }) } fn load(load_data: Self::LoadData) -> Result @@ -102,20 +111,25 @@ impl ModelTrait for MambaModel { let config: Config = serde_json::from_slice(&std::fs::read(config_filename).map_err(ModelError::IoError)?) .map_err(ModelError::DeserializeError)?; - let device = device(load_data.device_id)?; - let dtype = load_data.dtype; info!("Loading model weights.."); - let var_builder = - unsafe { VarBuilder::from_mmaped_safetensors(&weights_filenames, dtype, &device)? }; + let var_builder = unsafe { + VarBuilder::from_mmaped_safetensors(&weights_filenames, load_data.dtype, &device)? 
+ }; let model = Model::new(&config, var_builder.pp("backbone"))?; info!("Loaded Mamba model in {:?}", start.elapsed()); - Ok(Self::new(model, config, device, dtype, tokenizer)) + Ok(Self::new( + model, + config, + load_data.device, + load_data.dtype, + tokenizer, + )) } - fn model_id(&self) -> ModelId { - self.which.model_id().to_string() + fn model_type(&self) -> ModelType { + self.model_type.clone() } fn run(&mut self, input: Self::Input) -> Result { diff --git a/atoma-inference/src/models/candle/stable_diffusion.rs b/atoma-inference/src/models/candle/stable_diffusion.rs index f3be6ad6..3b9e1d14 100644 --- a/atoma-inference/src/models/candle/stable_diffusion.rs +++ b/atoma-inference/src/models/candle/stable_diffusion.rs @@ -4,13 +4,16 @@ extern crate accelerate_src; #[cfg(feature = "mkl")] extern crate intel_mkl_src; +use std::path::PathBuf; + use candle_transformers::models::stable_diffusion::{self}; use candle::{DType, Device, IndexOp, Module, Tensor, D}; +use hf_hub::api::sync::ApiBuilder; use serde::Deserialize; use tokenizers::Tokenizer; -use crate::models::{types::PrecisionBits, ModelError, ModelId, ModelTrait}; +use crate::models::{config::ModelConfig, types::{ModelType, PrecisionBits}, ModelError, ModelId, ModelTrait}; use super::{convert_to_image, device, save_tensor_to_file}; @@ -99,46 +102,30 @@ impl From<&Input> for Fetch { } } } -pub struct StableDiffusion { - device_id: usize, -} - -pub struct Fetch { - tokenizer: Option, - sd_version: StableDiffusionVersion, - use_f16: bool, - clip_weights: Option, - vae_weights: Option, - unet_weights: Option, -} -#[derive(Debug, Deserialize)] -pub struct Load { - pub filenames: Vec, - pub precision: PrecisionBits, - pub device_id: usize, +pub struct StableDiffusion { + device: Device, + dtype: DType, } impl ModelTrait for StableDiffusion { type Input = Input; - type Fetch = Fetch; type Output = Vec<(Vec, usize, usize)>; - type Load = Load; + type LoadData = Self; fn load( - _filenames: Vec, - _precision: Self::Load, - device_id: usize, + load_data: Self::LoadData ) -> Result where Self: Sized, { - Ok(Self { device_id }) + Ok(load_data) } - fn fetch(fetch: &Self::Fetch) -> Result<(), ModelError> { - let which = match fetch.sd_version { - StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => vec![true, false], + fn fetch(config: ModelConfig) -> Result { + let model_type = ModelType::from_str(&config.model_id())?; + let which = match model_type { + ModelType::StableDiffusionXl | ModelType::StableDiffusionTurbo => vec![true, false], _ => vec![true], }; for first in which { @@ -148,16 +135,20 @@ impl ModelTrait for StableDiffusion { (ModelFile::Clip2, ModelFile::Tokenizer2) }; - clip_weights_file.get(fetch.clip_weights.clone(), fetch.sd_version, false)?; - ModelFile::Vae.get(fetch.vae_weights.clone(), fetch.sd_version, fetch.use_f16)?; - tokenizer_file.get(fetch.tokenizer.clone(), fetch.sd_version, fetch.use_f16)?; - ModelFile::Unet.get(fetch.unet_weights.clone(), fetch.sd_version, fetch.use_f16)?; + let api_key = config.api_key(); + let cache_dir = config.cache_dir().into(); + let use_f16 = config.dtype() == "f16"; + + clip_weights_file.get(api_key, cache_dir,model_type, false)?; + ModelFile::Vae.get(api_key, cache_dir, model_type, use_f16)?; + tokenizer_file.get(api_key, cache_dir, model_type, use_f16)?; + ModelFile::Unet.get(api_key, cache_dir, model_type, use_f16)?; } - Ok(()) + Ok(Load) } - fn model_id(&self) -> ModelId { - "candle/stable_diffusion".to_string() + fn model_type(&self) -> ModelType { + self.model_type } fn 
run(&mut self, input: Self::Input) -> Result { @@ -358,61 +349,56 @@ enum ModelFile { Vae, } -impl StableDiffusionVersion { - fn repo(&self) -> &'static str { - match self { - Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0", - Self::V2_1 => "stabilityai/stable-diffusion-2-1", - Self::V1_5 => "runwayml/stable-diffusion-v1-5", - Self::Turbo => "stabilityai/sdxl-turbo", - } - } - +impl ModelType { fn unet_file(&self, use_f16: bool) -> &'static str { match self { - Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { + Self::StableDiffusionV1_5 | Self::StableDiffusionV2_1 | Self::StableDiffusionXl | Self::StableDiffusionTurbo => { if use_f16 { "unet/diffusion_pytorch_model.fp16.safetensors" } else { "unet/diffusion_pytorch_model.safetensors" } } + _ => panic!("Invalid stable diffusion model type") } } fn vae_file(&self, use_f16: bool) -> &'static str { match self { - Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { + Self::StableDiffusionV1_5 | Self::StableDiffusionV2_1 | Self::StableDiffusionXl | Self::StableDiffusionTurbo => { if use_f16 { "vae/diffusion_pytorch_model.fp16.safetensors" } else { "vae/diffusion_pytorch_model.safetensors" } } + _ => panic!("Invalid stable diffusion model type") } } fn clip_file(&self, use_f16: bool) -> &'static str { match self { - Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { + Self::StableDiffusionV1_5 | Self::StableDiffusionV2_1 | Self::StableDiffusionXl | Self::StableDiffusionTurbo => { if use_f16 { "text_encoder/model.fp16.safetensors" } else { "text_encoder/model.safetensors" } } + _ => panic!("Invalid stable diffusion model type") } } fn clip2_file(&self, use_f16: bool) -> &'static str { match self { - Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { + Self::StableDiffusionV1_5 | Self::StableDiffusionV2_1 | Self::StableDiffusionXl | Self::StableDiffusionTurbo => { if use_f16 { "text_encoder_2/model.fp16.safetensors" } else { "text_encoder_2/model.safetensors" } } + _ => panic!("Invalid stable diffusion model type") } } } @@ -420,17 +406,15 @@ impl StableDiffusionVersion { impl ModelFile { fn get( &self, - filename: Option, - version: StableDiffusionVersion, + api_key: String, + cache_dir: PathBuf, + + model_type: ModelType, use_f16: bool, ) -> Result { - use hf_hub::api::sync::Api; - match filename { - Some(filename) => Ok(std::path::PathBuf::from(filename)), - None => { let (repo, path) = match self { Self::Tokenizer => { - let tokenizer_repo = match version { + let tokenizer_repo = match model_type { StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => { "openai/clip-vit-base-patch32" } @@ -445,15 +429,15 @@ impl ModelFile { Self::Tokenizer2 => { ("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", "tokenizer.json") } - Self::Clip => (version.repo(), version.clip_file(use_f16)), - Self::Clip2 => (version.repo(), version.clip2_file(use_f16)), - Self::Unet => (version.repo(), version.unet_file(use_f16)), + Self::Clip => (model_type.repo(), model_type.clip_file(use_f16)), + Self::Clip2 => (model_type.repo(), model_type.clip2_file(use_f16)), + Self::Unet => (model_type.repo(), model_type.unet_file(use_f16)), Self::Vae => { // Override for SDXL when using f16 weights. 
// See https://github.com/huggingface/candle/issues/1060 if matches!( - version, - StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo, + model_type, + ModelType::StableDiffusionXl | ModelType::StableDiffusionTurbo, ) && use_f16 { ( @@ -461,14 +445,16 @@ impl ModelFile { "diffusion_pytorch_model.safetensors", ) } else { - (version.repo(), version.vae_file(use_f16)) + (model_type.repo(), model_type.vae_file(use_f16)) } } }; - let filename = Api::new()?.model(repo.to_string()).get(path)?; + let filename = ApiBuilder::new() + .with_progress(true) + .with_token(Some(api_key)) + .with_cache_dir(cache_dir) + .build()?; Ok(filename) - } - } } } diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs index f2ecabc6..282efe9e 100644 --- a/atoma-inference/src/models/config.rs +++ b/atoma-inference/src/models/config.rs @@ -10,36 +10,53 @@ type Revision = String; #[derive(Clone, Debug, Deserialize, Serialize)] pub struct ModelConfig { + api_key: String, + cache_dir: String, device_id: usize, + dtype: String, model_id: ModelId, - params: serde_json::Value, revision: Revision, use_flash_attention: bool, + sliced_attention_size: Option, } impl ModelConfig { pub fn new( + api_key: String, + cache_dir: String, model_id: ModelId, - params: serde_json::Value, + dtype: String, revision: Revision, device_id: usize, use_flash_attention: bool, + sliced_attention_size: Option, ) -> Self { Self { + api_key, + cache_dir, + dtype, model_id, - params, revision, device_id, - use_flash_attention + use_flash_attention, + sliced_attention_size } } - pub fn model_id(&self) -> &ModelId { - &self.model_id + pub fn api_key(&self) -> String { + self.api_key.clone() + } + + pub fn cache_dir(&self) -> String { + self.cache_dir.clone() } - pub fn params(&self) -> &serde_json::Value { - &self.params + pub fn dtype(&self) -> String { + self.dtype.clone() + } + + pub fn model_id(&self) -> ModelId { + self.model_id.clone() } pub fn revision(&self) -> Revision { @@ -53,38 +70,28 @@ impl ModelConfig { pub fn use_flash_attention(&self) -> bool { self.use_flash_attention } + + pub fn sliced_attention_size(&self) -> Option { + self.sliced_attention_size + } } #[derive(Debug, Deserialize, Serialize)] pub struct ModelsConfig { - api_key: String, flush_storage: bool, models: Vec, - storage_path: PathBuf, tracing: bool, } impl ModelsConfig { - pub fn new( - api_key: String, - flush_storage: bool, - models: Vec, - storage_path: PathBuf, - tracing: bool, - ) -> Self { + pub fn new(flush_storage: bool, models: Vec, tracing: bool) -> Self { Self { - api_key, flush_storage, models, - storage_path, tracing, } } - pub fn api_key(&self) -> String { - self.api_key.clone() - } - pub fn flush_storage(&self) -> bool { self.flush_storage } @@ -93,10 +100,6 @@ impl ModelsConfig { self.models.clone() } - pub fn storage_path(&self) -> PathBuf { - self.storage_path.clone() - } - pub fn tracing(&self) -> bool { self.tracing } @@ -116,7 +119,6 @@ impl ModelsConfig { pub fn from_env_file() -> Self { dotenv().ok(); - let api_key = std::env::var("API_KEY").expect("Failed to retrieve api key, from .env file"); let flush_storage = std::env::var("FLUSH_STORAGE") .unwrap_or_default() .parse() @@ -125,20 +127,14 @@ impl ModelsConfig { &std::env::var("MODELS").expect("Failed to retrieve models metadata, from .env file"), ) .unwrap(); - let storage_path = std::env::var("STORAGE_PATH") - .expect("Failed to retrieve storage path, from .env file") - .parse() - .unwrap(); let tracing = std::env::var("TRACING") .unwrap_or_default() 
.parse() .unwrap(); Self { - api_key, flush_storage, models, - storage_path, tracing, } } @@ -146,23 +142,22 @@ impl ModelsConfig { #[cfg(test)] pub mod tests { - use crate::models::types::PrecisionBits; - use super::*; #[test] fn test_config() { let config = ModelsConfig::new( - String::from("my_key"), true, vec![ModelConfig::new( + "my_key".to_string(), + "/".to_string(), + "F16".to_string(), "Llama2_7b".to_string(), - serde_json::to_value(PrecisionBits::F16).unwrap(), "".to_string(), 0, - true + true, + Some(0) )], - "storage_path".parse().unwrap(), true, ); diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs index 4e7bf7fe..7ce61b37 100644 --- a/atoma-inference/src/models/mod.rs +++ b/atoma-inference/src/models/mod.rs @@ -1,10 +1,9 @@ -use std::path::PathBuf; - use ::candle::Error as CandleError; use ed25519_consensus::VerificationKey as PublicKey; -use serde::{de::DeserializeOwned, Serialize}; use thiserror::Error; +use self::{config::ModelConfig, types::ModelType}; + pub mod candle; pub mod config; pub mod token_output_stream; @@ -13,16 +12,15 @@ pub mod types; pub type ModelId = String; pub trait ModelTrait { - type FetchData; - type Input: DeserializeOwned; - type Output: Serialize; - type LoadData: DeserializeOwned; + type Input; + type Output; + type LoadData; - fn fetch(fetch_data: &Self::FetchData) -> Result, ModelError>; + fn fetch(config: ModelConfig) -> Result; fn load(load_data: Self::LoadData) -> Result where Self: Sized; - fn model_id(&self) -> ModelId; + fn model_type(&self) -> ModelType; fn run(&mut self, input: Self::Input) -> Result; } @@ -59,6 +57,10 @@ pub enum ModelError { BoxedError(#[from] Box), #[error("ApiError error: `{0}`")] ApiError(#[from] hf_hub::api::sync::ApiError), + #[error("DTypeParseError: `{0}`")] + DTypeParseError(#[from] DTypeParseError), + #[error("Invalid model type: `{0}`")] + InvalidModelType(String), } #[macro_export] diff --git a/atoma-inference/src/models/types.rs b/atoma-inference/src/models/types.rs index 43a07a7e..363326b5 100644 --- a/atoma-inference/src/models/types.rs +++ b/atoma-inference/src/models/types.rs @@ -1,4 +1,4 @@ -use std::path::PathBuf; +use std::{path::PathBuf, str::FromStr}; use candle::{DType, Device}; use ed25519_consensus::VerificationKey as PublicKey; @@ -6,25 +6,147 @@ use serde::{Deserialize, Serialize}; use crate::models::{ModelId, Request, Response}; -pub type NodeId = PublicKey; +use super::ModelError; -#[derive(Debug, Deserialize)] -pub struct LlmFetchData { - pub api_key: String, - pub cache_dir: PathBuf, - pub model_id: ModelId, - pub revision: String, -} +pub type NodeId = PublicKey; -#[derive(Debug, Deserialize)] +#[derive(Debug)] pub struct LlmLoadData { - pub device_id: Device, + pub device: Device, pub dtype: DType, pub file_paths: Vec, - pub model_id: ModelId, + pub model_type: ModelType, pub use_flash_attention: bool, } +#[derive(Clone, Debug, Deserialize)] +pub enum ModelType { + Falcon7b, + Falcon40b, + Falcon180b, + LlamaV1, + LlamaV2, + LlamaSolar10_7B, + LlamaTinyLlama1_1BChat, + Mamba130m, + Mamba370m, + Mamba790m, + Mamba1_4b, + Mamba2_8b, + Mistral7b, + Mixtral8x7b, + StableDiffusionV1_5, + StableDiffusionV2_1, + StableDiffusionXl, + StableDiffusionTurbo, +} + +impl FromStr for ModelType { + type Err = ModelError; + + fn from_str(s: &str) -> Result { + match s { + "falcon_7b" => Ok(Self::Falcon7b), + "falcon_40b" => Ok(Self::Falcon40b), + "falcon_180b" => Ok(Self::Falcon180b0), + "llama_v1" => Ok(Self::LlamaV1), + "llama_v2" => Ok(Self::LlamaV2), + 
"llama_solar_10_7b" => Ok(Self::LlamaSolar10_7B), + "llama_tiny_llama_1_1b_chat" => Ok(Self::LlamaTinyLlama1_1BChat), + "mamba_130m" => Ok(Self::Mamba130m), + "mamba_370m" => Ok(Self::Mamba370m), + "mamba_790m" => Ok(Self::Mamba790m), + "mamba_1-4b" => Ok(Self::Mamba1_4b), + "mamba_2-8b" => Ok(Self::Mamba2_8b), + "mistral_7b" => Ok(Self::Mistral7b), + "mixtral_8x7b" => Ok(Self::Mixtral8x7b), + "stable_diffusion_v1-5" => Ok(Self::StableDiffusionV1_5), + "stable_diffusion_v2-1" => Ok(Self::StableDiffusionV2_1), + "stable_diffusion_xl" => Ok(Self::StableDiffusionXl), + "stable_diffusion_turbo" => Ok(Self::StableDiffusionTurbo), + _ => { + return Err(ModelError::InvalidModelType(format!( + "Invalid string model type descryption" + ))) + } + } + } +} + +impl ModelType { + pub fn repo(&self) -> &'static str { + match self { + Self::Falcon7b => "tiiuae/falcon-7b", + Self::Falcon40b => "tiiuae/falcon-40b", + Self::Falcon180b => "tiiuae/falcon-180b", + Self::LlamaV1 => "Narsil/amall-7b", + Self::LlamaV2 => "meta-llama/Llama-2-7b-hf", + Self::LlamaSolar10_7B => "upstage/SOLAR-10.7B-v1.0", + Self::LlamaTinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + Self::Mamba130m => "state-spaces/mamba-130m", + Self::Mamba370m => "state-spaces/mamba-370m", + Self::Mamba790m => "state-spaces/mamba-790m", + Self::Mamba1_4b => "state-spaces/mamba-1.4b", + Self::Mamba2_8b => "state-spaces/mamba-2.8b", + Self::Mistral7b => "TODO", + Self::Mixtral8x7b => "TODO", + Self::StableDiffusionV1_5 => "runwayml/stable-diffusion-v1-5", + Self::StableDiffusionV2_1 => "stabilityai/stable-diffusion-2-1", + Self::StableDiffusionV1_5 => "runwayml/stable-diffusion-v1-5", + Self::StableDiffusionXl => "stabilityai/stable-diffusion-xl-base-1.0", + Self::StableDiffusionTurbo => "stabilityai/sdxl-turbo", + } + } + + pub fn default_revision(&self) -> &'static str { + match self { + Self::Falcon7b => "refs/pr/43", + Self::Falcon40b => "refs/pr/43", + Self::Falcon180b => "refs/pr/43", + Self::LlamaV1 => "main", + Self::LlamaV2 => "main", + Self::LlamaSolar10_7B => "main", + Self::LlamaTinyLlama1_1BChat => "main", + Self::Mamba130m => "refs/pr/1", + Self::Mamba370m => "refs/pr/1", + Self::Mamba790m => "refs/pr/1", + Self::Mamba1_4b => "refs/pr/1", + Self::Mamba2_8b => "refs/pr/4", + Self::Mistral7b => "TODO", + Self::Mixtral8x7b => "TODO", + Self::StableDiffusionV1_5 => "", + Self::StableDiffusionV2_1 => "", + Self::StableDiffusionTurbo => "", + Self::StableDiffusionXl => "", + } + } +} + +impl ToString for ModelType { + fn to_string(&self) -> String { + match self { + Self::Falcon7b => "falcon_7b".to_string(), + Self::Falcon40b => "falcon_40b".to_string(), + Self::Falcon180b => "falcon_180b".to_string(), + Self::LlamaV1 => "llama_v1".to_string(), + Self::LlamaV2 => "llama_v2".to_string(), + Self::LlamaSolar10_7B => "llama_solar_10_7b".to_string(), + Self::LlamaTinyLlama1_1BChat => "llama_tiny_llama_1_1b_chat".to_string(), + Self::Mamba130m => "mamba_130m".to_string(), + Self::Mamba370m => "mamba_370m".to_string(), + Self::Mamba790m => "mamba_790m".to_string(), + Self::Mamba1_4b => "mamba_1-4b".to_string(), + Self::Mamba2_8b => "mamba_2-8b".to_string(), + Self::Mistral7b => "mistral_7b".to_string(), + Self::Mixtral8x7b => "mixtral_8x7b".to_string(), + Self::StableDiffusionV1_5 => "stable_diffusion_v1-5".to_string(), + Self::StableDiffusionV2_1 => "stable_diffusion_v2-1".to_string(), + Self::StableDiffusionXl => "stable_diffusion_xl".to_string(), + Self::StableDiffusionTurbo => "stable_diffusion_turbo".to_string(), + } + } +} + 
#[derive(Clone, Debug, Deserialize, Serialize)] pub struct TextRequest { pub request_id: usize, diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index a7f97714..60deea54 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -146,7 +146,7 @@ mod tests { use std::io::Write; use toml::{toml, Value}; - use crate::models::{ModelId, Request, Response}; + use crate::models::{config::ModelConfig, ModelId, Request, Response}; use super::*; @@ -195,23 +195,18 @@ mod tests { impl ModelTrait for TestModelInstance { type Input = (); type Output = (); - type FetchData = (); type LoadData = (); - fn fetch(_fetch: &Self::FetchData) -> Result<(), crate::models::ModelError> { + fn fetch(config: ModelConfig) -> Result<(), crate::models::ModelError> { Ok(()) } - fn load( - _: Vec, - _: Self::LoadData, - _device_id: usize, - ) -> Result { + fn load(_: Self::LoadData) -> Result { Ok(Self {}) } - fn model_id(&self) -> crate::models::ModelId { - String::from("") + fn model_type(&self) -> crate::models::types::ModelType { + crate::models::types::ModelType::LlamaV1 } fn run(&mut self, _: Self::Input) -> Result { From 48d561e640f0c987e9a201b49fc65ccd9e1696a7 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Fri, 5 Apr 2024 19:39:29 +0100 Subject: [PATCH 03/21] refactor stable diffusion model implementation --- atoma-inference/src/models/candle/mamba.rs | 7 +- .../src/models/candle/stable_diffusion.rs | 420 +++++++++++------- atoma-inference/src/models/config.rs | 6 +- atoma-inference/src/models/mod.rs | 2 +- 4 files changed, 257 insertions(+), 178 deletions(-) diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs index f499e1e4..e661b106 100644 --- a/atoma-inference/src/models/candle/mamba.rs +++ b/atoma-inference/src/models/candle/mamba.rs @@ -27,8 +27,8 @@ pub struct MambaModel { config: Config, device: Device, dtype: DType, + model_type: ModelType, tokenizer: TokenOutputStream, - which: Which, } impl MambaModel { @@ -37,16 +37,16 @@ impl MambaModel { config: Config, device: Device, dtype: DType, + model_type: ModelType, tokenizer: Tokenizer, ) -> Self { - let which = Which::from_config(&config); Self { model, config, device, dtype, + model_type, tokenizer: TokenOutputStream::new(tokenizer), - which, } } } @@ -124,6 +124,7 @@ impl ModelTrait for MambaModel { config, load_data.device, load_data.dtype, + load_data.model_type, tokenizer, )) } diff --git a/atoma-inference/src/models/candle/stable_diffusion.rs b/atoma-inference/src/models/candle/stable_diffusion.rs index 3b9e1d14..43275020 100644 --- a/atoma-inference/src/models/candle/stable_diffusion.rs +++ b/atoma-inference/src/models/candle/stable_diffusion.rs @@ -4,16 +4,24 @@ extern crate accelerate_src; #[cfg(feature = "mkl")] extern crate intel_mkl_src; -use std::path::PathBuf; +use std::{path::PathBuf, str::FromStr}; -use candle_transformers::models::stable_diffusion::{self}; +use candle_transformers::models::{ + clip::text_model::ClipTextTransformer, + stable_diffusion::{ + self, unet_2d::UNet2DConditionModel, vae::AutoEncoderKL, StableDiffusionConfig, + }, +}; use candle::{DType, Device, IndexOp, Module, Tensor, D}; use hf_hub::api::sync::ApiBuilder; use serde::Deserialize; use tokenizers::Tokenizer; -use crate::models::{config::ModelConfig, types::{ModelType, PrecisionBits}, ModelError, ModelId, ModelTrait}; +use crate::{ + bail, + models::{config::ModelConfig, types::ModelType, ModelError, ModelTrait}, +}; use super::{convert_to_image, device, 
save_tensor_to_file}; @@ -90,44 +98,56 @@ impl Input { } } -impl From<&Input> for Fetch { - fn from(input: &Input) -> Self { - Self { - tokenizer: input.tokenizer.clone(), - sd_version: input.sd_version, - use_f16: input.use_f16, - clip_weights: input.clip_weights.clone(), - vae_weights: input.vae_weights.clone(), - unet_weights: input.unet_weights.clone(), - } - } +pub struct StableDiffusionLoadData { + device: Device, + dtype: DType, + model_type: ModelType, + sliced_attention_size: Option, + clip_weights_file_paths: Vec, + tokenizer_file_paths: Vec, + vae_weights_file_path: PathBuf, + unet_weights_file_path: PathBuf, + use_flash_attention: bool, } pub struct StableDiffusion { + config: StableDiffusionConfig, device: Device, dtype: DType, + model_type: ModelType, + text_model: ClipTextTransformer, + text_model_2: Option, + tokenizer: Tokenizer, + tokenizer_2: Option, + unet: UNet2DConditionModel, + vae: AutoEncoderKL, } impl ModelTrait for StableDiffusion { type Input = Input; type Output = Vec<(Vec, usize, usize)>; - type LoadData = Self; - - fn load( - load_data: Self::LoadData - ) -> Result - where - Self: Sized, - { - Ok(load_data) - } + type LoadData = StableDiffusionLoadData; fn fetch(config: ModelConfig) -> Result { + let device = device(config.device_id())?; + let dtype = DType::from_str(&config.dtype())?; let model_type = ModelType::from_str(&config.model_id())?; let which = match model_type { ModelType::StableDiffusionXl | ModelType::StableDiffusionTurbo => vec![true, false], _ => vec![true], }; + + let api_key = config.api_key(); + let cache_dir = config.cache_dir().into(); + let use_f16 = config.dtype() == "f16"; + + let vae_weights_file_path = ModelFile::Vae.get(api_key, cache_dir, model_type, use_f16)?; + let unet_weights_file_path = + ModelFile::Unet.get(api_key, cache_dir, model_type, use_f16)?; + + let mut clip_weights_file_paths = vec![]; + let mut tokenizer_file_paths = vec![]; + for first in which { let (clip_weights_file, tokenizer_file) = if first { (ModelFile::Clip, ModelFile::Tokenizer) @@ -135,16 +155,104 @@ impl ModelTrait for StableDiffusion { (ModelFile::Clip2, ModelFile::Tokenizer2) }; - let api_key = config.api_key(); - let cache_dir = config.cache_dir().into(); - let use_f16 = config.dtype() == "f16"; + let clip_weights_file_path = + clip_weights_file.get(api_key, cache_dir, model_type, false)?; + let tokenizer_file_path = + tokenizer_file.get(api_key, cache_dir, model_type, use_f16)?; - clip_weights_file.get(api_key, cache_dir,model_type, false)?; - ModelFile::Vae.get(api_key, cache_dir, model_type, use_f16)?; - tokenizer_file.get(api_key, cache_dir, model_type, use_f16)?; - ModelFile::Unet.get(api_key, cache_dir, model_type, use_f16)?; + clip_weights_file_paths.push(clip_weights_file_path); + tokenizer_file_paths.push(tokenizer_file_path); } - Ok(Load) + + Ok(Self::LoadData { + device, + dtype, + model_type, + sliced_attention_size: config.sliced_attention_size(), + clip_weights_file_paths, + tokenizer_file_paths, + vae_weights_file_path, + unet_weights_file_path, + use_flash_attention: config.use_flash_attention(), + }) + } + + fn load(load_data: Self::LoadData) -> Result + where + Self: Sized, + { + let sliced_attention_size = load_data.sliced_attention_size; + let config = match load_data.model_type { + ModelType::StableDiffusionV1_5 => { + stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size, None, None) + } + ModelType::StableDiffusionV2_1 => { + stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, None, None) + } + 
ModelType::StableDiffusionXl => { + stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, None, None) + } + ModelType::StableDiffusionTurbo => stable_diffusion::StableDiffusionConfig::sdxl_turbo( + sliced_attention_size, + None, + None, + ), + _ => bail!("Invalid stable diffusion model type"), + }; + + let (tokenizer, tokenizer_2) = match load_data.model_type { + ModelType::StableDiffusionXl | ModelType::StableDiffusionTurbo => ( + Tokenizer::from_file(load_data.tokenizer_file_paths[0]), + Some(Tokenizer::from_file(load_data.tokenizer_file_paths[1])), + ), + _ => ( + Tokenizer::from_file(load_data.tokenizer_file_paths[0]), + None, + ), // INTEGRITY: we have checked previously if the model type is valid for the family of stable diffusion models + }; + + let text_model = stable_diffusion::build_clip_transformer( + &config.clip, + load_data.clip_weights_file_paths, + &load_data.device, + load_data.dtype, + )?; + let text_model_2 = if let Some(clip_config_2) = config.clip2 { + Some(stable_diffusion::build_clip_transformer( + &clip_config_2, + load_data.clip_weights_file_paths, + &load_data.device, + load_data.dtype, + )?) + } else { + None + }; + + let vae = config.build_vae( + load_data.vae_weights_file_path, + &load_data.device, + load_data.dtype, + )?; + let unet = config.build_unet( + load_data.unet_weights_file_path, + &load_data.device, + 4, // see https://github.com/huggingface/candle/blob/main/candle-examples/examples/stable-diffusion/main.rs#L492 + load_data.use_flash_attention, + load_data.dtype, + ); + + Ok(Self { + config, + device: load_data.device, + dtype: load_data.dtype, + model_type: load_data.model_type, + tokenizer, + tokenizer_2, + text_model, + text_model_2, + vae, + unet, + }) } fn model_type(&self) -> ModelType { @@ -159,76 +267,59 @@ impl ModelTrait for StableDiffusion { )))? 
} + // self.config.height = input.height; + // self.config.width = input.width; + let guidance_scale = match input.guidance_scale { Some(guidance_scale) => guidance_scale, - None => match input.sd_version { - StableDiffusionVersion::V1_5 - | StableDiffusionVersion::V2_1 - | StableDiffusionVersion::Xl => 7.5, - StableDiffusionVersion::Turbo => 0., + None => match self.model_type { + ModelType::StableDiffusionV1_5 + | ModelType::StableDiffusionV2_1 + | ModelType::StableDiffusionXl => 7.5, + ModelType::StableDiffusionTurbo => 0., + _ => bail!("Invalid stable diffusion model type"), }, }; let n_steps = match input.n_steps { Some(n_steps) => n_steps, - None => match input.sd_version { - StableDiffusionVersion::V1_5 - | StableDiffusionVersion::V2_1 - | StableDiffusionVersion::Xl => 30, - StableDiffusionVersion::Turbo => 1, + None => match self.model_type { + ModelType::StableDiffusionV1_5 + | ModelType::StableDiffusionV2_1 + | ModelType::StableDiffusionXl => 30, + ModelType::StableDiffusionTurbo => 1, + _ => bail!("Invalid stable diffusion model type"), }, }; - let dtype = if input.use_f16 { - DType::F16 - } else { - DType::F32 - }; - let sd_config = match input.sd_version { - StableDiffusionVersion::V1_5 => stable_diffusion::StableDiffusionConfig::v1_5( - input.sliced_attention_size, - input.height, - input.width, - ), - StableDiffusionVersion::V2_1 => stable_diffusion::StableDiffusionConfig::v2_1( - input.sliced_attention_size, - input.height, - input.width, - ), - StableDiffusionVersion::Xl => stable_diffusion::StableDiffusionConfig::sdxl( - input.sliced_attention_size, - input.height, - input.width, - ), - StableDiffusionVersion::Turbo => stable_diffusion::StableDiffusionConfig::sdxl_turbo( - input.sliced_attention_size, - input.height, - input.width, - ), - }; - let scheduler = sd_config.build_scheduler(n_steps)?; - let device = device(self.device_id)?; + let scheduler = self.config.build_scheduler(n_steps)?; if let Some(seed) = input.seed { device.set_seed(seed)?; } let use_guide_scale = guidance_scale > 1.0; - let which = match input.sd_version { + let which = match self.model_type { StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => vec![true, false], - _ => vec![true], + _ => vec![true], // INTEGRITY: we have checked previously if the model type is valid for the family of stable diffusion models }; let text_embeddings = which .iter() .map(|first| { + let (tokenizer, text_model) = if first { + (&self.tokenizer, &self.text_model) + } else { + (&self.tokenizer_2.unwrap(), &self.text_model_2.unwrap()) + }; + Self::text_embeddings( &input.prompt, &input.uncond_prompt, - input.tokenizer.clone(), - input.clip_weights.clone(), - input.sd_version, - &sd_config, + tokenizer, + text_model, + self.model_type, + &self.config, input.use_f16, - &device, - dtype, + &self.device, + self.dtype, use_guide_scale, *first, ) @@ -237,18 +328,13 @@ impl ModelTrait for StableDiffusion { let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?; - let vae_weights = ModelFile::Vae.get(input.vae_weights, input.sd_version, input.use_f16)?; - let vae = sd_config.build_vae(vae_weights, &device, dtype)?; let init_latent_dist = match &input.img2img { None => None, Some(image) => { let image = Self::image_preprocess(image)?.to_device(&device)?; - Some(vae.encode(&image)?) + Some(self.vae.encode(&image)?) 
} }; - let unet_weights = - ModelFile::Unet.get(input.unet_weights, input.sd_version, input.use_f16)?; - let unet = sd_config.build_unet(unet_weights, &device, 4, input.use_flash_attn, dtype)?; let t_start = if input.img2img.is_some() { n_steps - (n_steps as f64 * input.img2img_strength) as usize @@ -281,14 +367,14 @@ impl ModelTrait for StableDiffusion { let latents = Tensor::randn( 0f32, 1f32, - (bsize, 4, sd_config.height / 8, sd_config.width / 8), - &device, + (bsize, 4, input.height / 8, input.width / 8), + &self.device, )?; // scale the initial noise by the standard deviation required by the scheduler (latents * scheduler.init_noise_sigma())? } }; - let mut latents = latents.to_dtype(dtype)?; + let mut latents = latents.to_dtype(self.dtype)?; for (timestep_index, ×tep) in timesteps.iter().enumerate() { if timestep_index < t_start { @@ -303,7 +389,8 @@ impl ModelTrait for StableDiffusion { let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?; let noise_pred = - unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?; + self.unet + .forward(&latent_model_input, timestep as f64, &text_embeddings)?; let noise_pred = if use_guide_scale { let noise_pred = noise_pred.chunk(2, 0)?; @@ -318,7 +405,7 @@ impl ModelTrait for StableDiffusion { latents = scheduler.step(&noise_pred, timestep, &latents)?; } save_tensor_to_file(&latents, "tensor1")?; - let image = vae.decode(&(&latents / vae_scale)?)?; + let image = self.vae.decode(&(&latents / vae_scale)?)?; save_tensor_to_file(&image, "tensor2")?; let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?; save_tensor_to_file(&image, "tensor3")?; @@ -339,70 +426,82 @@ enum StableDiffusionVersion { Turbo, } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum ModelFile { - Tokenizer, - Tokenizer2, - Clip, - Clip2, - Unet, - Vae, -} - impl ModelType { fn unet_file(&self, use_f16: bool) -> &'static str { match self { - Self::StableDiffusionV1_5 | Self::StableDiffusionV2_1 | Self::StableDiffusionXl | Self::StableDiffusionTurbo => { + Self::StableDiffusionV1_5 + | Self::StableDiffusionV2_1 + | Self::StableDiffusionXl + | Self::StableDiffusionTurbo => { if use_f16 { "unet/diffusion_pytorch_model.fp16.safetensors" } else { "unet/diffusion_pytorch_model.safetensors" } } - _ => panic!("Invalid stable diffusion model type") + _ => panic!("Invalid stable diffusion model type"), } } fn vae_file(&self, use_f16: bool) -> &'static str { match self { - Self::StableDiffusionV1_5 | Self::StableDiffusionV2_1 | Self::StableDiffusionXl | Self::StableDiffusionTurbo => { + Self::StableDiffusionV1_5 + | Self::StableDiffusionV2_1 + | Self::StableDiffusionXl + | Self::StableDiffusionTurbo => { if use_f16 { "vae/diffusion_pytorch_model.fp16.safetensors" } else { "vae/diffusion_pytorch_model.safetensors" } } - _ => panic!("Invalid stable diffusion model type") + _ => panic!("Invalid stable diffusion model type"), } } fn clip_file(&self, use_f16: bool) -> &'static str { match self { - Self::StableDiffusionV1_5 | Self::StableDiffusionV2_1 | Self::StableDiffusionXl | Self::StableDiffusionTurbo => { + Self::StableDiffusionV1_5 + | Self::StableDiffusionV2_1 + | Self::StableDiffusionXl + | Self::StableDiffusionTurbo => { if use_f16 { "text_encoder/model.fp16.safetensors" } else { "text_encoder/model.safetensors" } } - _ => panic!("Invalid stable diffusion model type") + _ => panic!("Invalid stable diffusion model type"), } } fn clip2_file(&self, use_f16: bool) -> &'static str { match self { - Self::StableDiffusionV1_5 | 
Self::StableDiffusionV2_1 | Self::StableDiffusionXl | Self::StableDiffusionTurbo => { + Self::StableDiffusionV1_5 + | Self::StableDiffusionV2_1 + | Self::StableDiffusionXl + | Self::StableDiffusionTurbo => { if use_f16 { "text_encoder_2/model.fp16.safetensors" } else { "text_encoder_2/model.safetensors" } } - _ => panic!("Invalid stable diffusion model type") + _ => panic!("Invalid stable diffusion model type"), } } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ModelFile { + Tokenizer, + Tokenizer2, + Clip, + Clip2, + Unet, + Vae, +} + impl ModelFile { fn get( &self, @@ -412,49 +511,47 @@ impl ModelFile { model_type: ModelType, use_f16: bool, ) -> Result { - let (repo, path) = match self { - Self::Tokenizer => { - let tokenizer_repo = match model_type { - StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => { - "openai/clip-vit-base-patch32" - } - StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => { - // This seems similar to the patch32 version except some very small - // difference in the split regex. - "openai/clip-vit-large-patch14" - } - }; - (tokenizer_repo, "tokenizer.json") - } - Self::Tokenizer2 => { - ("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", "tokenizer.json") + let (repo, path) = match self { + Self::Tokenizer => { + let tokenizer_repo = match model_type { + StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => { + "openai/clip-vit-base-patch32" } - Self::Clip => (model_type.repo(), model_type.clip_file(use_f16)), - Self::Clip2 => (model_type.repo(), model_type.clip2_file(use_f16)), - Self::Unet => (model_type.repo(), model_type.unet_file(use_f16)), - Self::Vae => { - // Override for SDXL when using f16 weights. - // See https://github.com/huggingface/candle/issues/1060 - if matches!( - model_type, - ModelType::StableDiffusionXl | ModelType::StableDiffusionTurbo, - ) && use_f16 - { - ( - "madebyollin/sdxl-vae-fp16-fix", - "diffusion_pytorch_model.safetensors", - ) - } else { - (model_type.repo(), model_type.vae_file(use_f16)) - } + StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => { + // This seems similar to the patch32 version except some very small + // difference in the split regex. + "openai/clip-vit-large-patch14" } }; - let filename = ApiBuilder::new() - .with_progress(true) - .with_token(Some(api_key)) - .with_cache_dir(cache_dir) - .build()?; - Ok(filename) + (tokenizer_repo, "tokenizer.json") + } + Self::Tokenizer2 => ("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", "tokenizer.json"), + Self::Clip => (model_type.repo(), model_type.clip_file(use_f16)), + Self::Clip2 => (model_type.repo(), model_type.clip2_file(use_f16)), + Self::Unet => (model_type.repo(), model_type.unet_file(use_f16)), + Self::Vae => { + // Override for SDXL when using f16 weights. 
+ // See https://github.com/huggingface/candle/issues/1060 + if matches!( + model_type, + ModelType::StableDiffusionXl | ModelType::StableDiffusionTurbo, + ) && use_f16 + { + ( + "madebyollin/sdxl-vae-fp16-fix", + "diffusion_pytorch_model.safetensors", + ) + } else { + (model_type.repo(), model_type.vae_file(use_f16)) + } + } + }; + let filename = ApiBuilder::new() + .with_progress(true) + .with_token(Some(api_key)) + .with_cache_dir(cache_dir) + .build()?; + Ok(filename) } } @@ -463,23 +560,16 @@ impl StableDiffusion { fn text_embeddings( prompt: &str, uncond_prompt: &str, - tokenizer: Option, - clip_weights: Option, - sd_version: StableDiffusionVersion, - sd_config: &stable_diffusion::StableDiffusionConfig, + tokenizer: &Tokenizer, + text_model: &ClipTextTransformer, + model_type: ModelType, + sd_config: &StableDiffusionConfig, use_f16: bool, device: &Device, dtype: DType, use_guide_scale: bool, first: bool, ) -> Result { - let (clip_weights_file, tokenizer_file) = if first { - (ModelFile::Clip, ModelFile::Tokenizer) - } else { - (ModelFile::Clip2, ModelFile::Tokenizer2) - }; - let tokenizer = tokenizer_file.get(tokenizer, sd_version, use_f16)?; - let tokenizer = Tokenizer::from_file(tokenizer)?; let pad_id = match &sd_config.clip.pad_with { Some(padding) => { *tokenizer @@ -500,18 +590,6 @@ impl StableDiffusion { } let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?; - let clip_weights = clip_weights_file.get(clip_weights, sd_version, false)?; - let clip_config = if first { - &sd_config.clip - } else { - sd_config.clip2.as_ref().unwrap() - }; - let text_model = stable_diffusion::build_clip_transformer( - clip_config, - clip_weights, - device, - DType::F64, - )?; let text_embeddings = text_model.forward(&tokens)?; let text_embeddings = if use_guide_scale { diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs index 282efe9e..4ec3a618 100644 --- a/atoma-inference/src/models/config.rs +++ b/atoma-inference/src/models/config.rs @@ -39,7 +39,7 @@ impl ModelConfig { revision, device_id, use_flash_attention, - sliced_attention_size + sliced_attention_size, } } @@ -71,7 +71,7 @@ impl ModelConfig { self.use_flash_attention } - pub fn sliced_attention_size(&self) -> Option { + pub fn sliced_attention_size(&self) -> Option { self.sliced_attention_size } } @@ -156,7 +156,7 @@ pub mod tests { "".to_string(), 0, true, - Some(0) + Some(0), )], true, ); diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs index 7ce61b37..f6f24008 100644 --- a/atoma-inference/src/models/mod.rs +++ b/atoma-inference/src/models/mod.rs @@ -1,4 +1,4 @@ -use ::candle::Error as CandleError; +use ::candle::{DTypeParseError, Error as CandleError}; use ed25519_consensus::VerificationKey as PublicKey; use thiserror::Error; From 3e09d62ec0ccc56677885627c300cc07c2fddfc5 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Fri, 5 Apr 2024 20:19:42 +0100 Subject: [PATCH 04/21] further refactoring --- atoma-inference/src/models/candle/falcon.rs | 7 ++- atoma-inference/src/models/candle/llama.rs | 11 ++-- atoma-inference/src/models/candle/mamba.rs | 13 +++-- .../src/models/candle/stable_diffusion.rs | 51 ++++++++++--------- atoma-inference/src/models/config.rs | 27 ++++++---- atoma-inference/src/models/mod.rs | 4 +- atoma-inference/src/models/types.rs | 2 +- atoma-inference/src/service.rs | 12 ++--- 8 files changed, 72 insertions(+), 55 deletions(-) diff --git a/atoma-inference/src/models/candle/falcon.rs b/atoma-inference/src/models/candle/falcon.rs index 
f09086a1..4c425605 100644 --- a/atoma-inference/src/models/candle/falcon.rs +++ b/atoma-inference/src/models/candle/falcon.rs @@ -1,4 +1,4 @@ -use std::{str::FromStr, time::Instant}; +use std::{path::PathBuf, str::FromStr, time::Instant}; use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; @@ -51,17 +51,16 @@ impl ModelTrait for FalconModel { type Output = String; type LoadData = LlmLoadData; - fn fetch(config: ModelConfig) -> Result { + fn fetch(cache_dir: PathBuf, config: ModelConfig) -> Result { let device = device(config.device_id())?; let dtype = DType::from_str(&config.dtype())?; let api_key = config.api_key(); - let cache_dir = config.cache_dir(); let api = ApiBuilder::new() .with_progress(true) .with_token(Some(api_key)) - .with_cache_dir(cache_dir.into()) + .with_cache_dir(cache_dir) .build()?; let repo = api.repo(Repo::with_revision( diff --git a/atoma-inference/src/models/candle/llama.rs b/atoma-inference/src/models/candle/llama.rs index dc5d17be..24d0dc91 100644 --- a/atoma-inference/src/models/candle/llama.rs +++ b/atoma-inference/src/models/candle/llama.rs @@ -4,7 +4,7 @@ extern crate accelerate_src; #[cfg(feature = "mkl")] extern crate intel_mkl_src; -use std::str::FromStr; +use std::{path::PathBuf, str::FromStr}; use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; @@ -53,14 +53,14 @@ impl ModelTrait for LlamaModel { type Output = String; type LoadData = LlmLoadData; - fn fetch(config: ModelConfig) -> Result { + fn fetch(cache_dir: PathBuf, config: ModelConfig) -> Result { let device = device(config.device_id())?; let dtype = DType::from_str(&config.dtype())?; let api = ApiBuilder::new() .with_progress(true) .with_token(Some(config.api_key())) - .with_cache_dir(config.cache_dir()) + .with_cache_dir(cache_dir) .build()?; let api = api.repo(Repo::with_revision( @@ -96,9 +96,8 @@ impl ModelTrait for LlamaModel { } fn load(load_data: Self::LoadData) -> Result { - let device = device(load_data.device_id)?; - let dtype = - DType::from_str(&load_data.dtype).map_err(|e| ModelError::Msg(e.to_string()))?; + let device = load_data.device; + let dtype = load_data.dtype; let (model, tokenizer_filename, cache) = { let config_filename = load_data.file_paths[0].clone(); let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?; diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs index e661b106..38d3e8d0 100644 --- a/atoma-inference/src/models/candle/mamba.rs +++ b/atoma-inference/src/models/candle/mamba.rs @@ -1,4 +1,4 @@ -use std::{str::FromStr, time::Instant}; +use std::{path::PathBuf, str::FromStr, time::Instant}; use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; @@ -56,9 +56,8 @@ impl ModelTrait for MambaModel { type Output = String; type LoadData = LlmLoadData; - fn fetch(config: ModelConfig) -> Result { + fn fetch(cache_dir: PathBuf, config: ModelConfig) -> Result { let api_key = config.api_key(); - let cache_dir = config.cache_dir(); let device = device(config.device_id())?; let dtype = DType::from_str(&config.dtype())?; @@ -66,7 +65,7 @@ impl ModelTrait for MambaModel { let api = ApiBuilder::new() .with_progress(true) .with_token(Some(api_key)) - .with_cache_dir(config.cache_dir().into()) + .with_cache_dir(cache_dir) .build()?; let repo = api.repo(Repo::with_revision( @@ -114,7 +113,11 @@ impl ModelTrait for MambaModel { info!("Loading model weights.."); let var_builder = unsafe { - VarBuilder::from_mmaped_safetensors(&weights_filenames, load_data.dtype, &device)? 
+ VarBuilder::from_mmaped_safetensors( + &weights_filenames, + load_data.dtype, + &load_data.device, + )? }; let model = Model::new(&config, var_builder.pp("backbone"))?; info!("Loaded Mamba model in {:?}", start.elapsed()); diff --git a/atoma-inference/src/models/candle/stable_diffusion.rs b/atoma-inference/src/models/candle/stable_diffusion.rs index 43275020..10103b2e 100644 --- a/atoma-inference/src/models/candle/stable_diffusion.rs +++ b/atoma-inference/src/models/candle/stable_diffusion.rs @@ -6,11 +6,9 @@ extern crate intel_mkl_src; use std::{path::PathBuf, str::FromStr}; -use candle_transformers::models::{ - clip::text_model::ClipTextTransformer, - stable_diffusion::{ - self, unet_2d::UNet2DConditionModel, vae::AutoEncoderKL, StableDiffusionConfig, - }, +use candle_transformers::models::stable_diffusion::{ + self, clip::ClipTextTransformer, unet_2d::UNet2DConditionModel, vae::AutoEncoderKL, + StableDiffusionConfig, }; use candle::{DType, Device, IndexOp, Module, Tensor, D}; @@ -128,7 +126,7 @@ impl ModelTrait for StableDiffusion { type Output = Vec<(Vec, usize, usize)>; type LoadData = StableDiffusionLoadData; - fn fetch(config: ModelConfig) -> Result { + fn fetch(cache_dir: PathBuf, config: ModelConfig) -> Result { let device = device(config.device_id())?; let dtype = DType::from_str(&config.dtype())?; let model_type = ModelType::from_str(&config.model_id())?; @@ -138,7 +136,6 @@ impl ModelTrait for StableDiffusion { }; let api_key = config.api_key(); - let cache_dir = config.cache_dir().into(); let use_f16 = config.dtype() == "f16"; let vae_weights_file_path = ModelFile::Vae.get(api_key, cache_dir, model_type, use_f16)?; @@ -202,25 +199,25 @@ impl ModelTrait for StableDiffusion { let (tokenizer, tokenizer_2) = match load_data.model_type { ModelType::StableDiffusionXl | ModelType::StableDiffusionTurbo => ( - Tokenizer::from_file(load_data.tokenizer_file_paths[0]), - Some(Tokenizer::from_file(load_data.tokenizer_file_paths[1])), + Tokenizer::from_file(load_data.tokenizer_file_paths[0])?, + Some(Tokenizer::from_file(load_data.tokenizer_file_paths[1])?), ), _ => ( - Tokenizer::from_file(load_data.tokenizer_file_paths[0]), + Tokenizer::from_file(load_data.tokenizer_file_paths[0])?, None, ), // INTEGRITY: we have checked previously if the model type is valid for the family of stable diffusion models }; let text_model = stable_diffusion::build_clip_transformer( &config.clip, - load_data.clip_weights_file_paths, + load_data.clip_weights_file_paths[0], &load_data.device, load_data.dtype, )?; let text_model_2 = if let Some(clip_config_2) = config.clip2 { Some(stable_diffusion::build_clip_transformer( &clip_config_2, - load_data.clip_weights_file_paths, + load_data.clip_weights_file_paths[1], &load_data.device, load_data.dtype, )?) 
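// Illustrative sketch, not part of this patch: how the two-phase API in the hunks above
// is meant to be driven at this point in the series. `fetch(cache_dir, config)` resolves
// the tokenizer/CLIP/VAE/UNet files into a `StableDiffusionLoadData`, and `load` memory-maps
// those files into a runnable model. The function name and its usage here are assumptions
// for illustration only (a later patch in this series also threads an `api_key` through `fetch`).
fn warm_up_stable_diffusion(
    cache_dir: std::path::PathBuf,
    config: ModelConfig,
) -> Result<StableDiffusion, ModelError> {
    // Phase 1: download (or reuse cached) artifacts and record their file paths.
    let load_data = StableDiffusion::fetch(cache_dir, config)?;
    // Phase 2: build tokenizers, CLIP text encoders, VAE and UNet from those paths.
    StableDiffusion::load(load_data)
}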
@@ -239,7 +236,7 @@ impl ModelTrait for StableDiffusion { 4, // see https://github.com/huggingface/candle/blob/main/candle-examples/examples/stable-diffusion/main.rs#L492 load_data.use_flash_attention, load_data.dtype, - ); + )?; Ok(Self { config, @@ -293,18 +290,18 @@ impl ModelTrait for StableDiffusion { let scheduler = self.config.build_scheduler(n_steps)?; if let Some(seed) = input.seed { - device.set_seed(seed)?; + self.device.set_seed(seed)?; } let use_guide_scale = guidance_scale > 1.0; let which = match self.model_type { - StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => vec![true, false], + ModelType::StableDiffusionXl | ModelType::StableDiffusionTurbo => vec![true, false], _ => vec![true], // INTEGRITY: we have checked previously if the model type is valid for the family of stable diffusion models }; let text_embeddings = which .iter() .map(|first| { - let (tokenizer, text_model) = if first { + let (tokenizer, text_model) = if *first { (&self.tokenizer, &self.text_model) } else { (&self.tokenizer_2.unwrap(), &self.text_model_2.unwrap()) @@ -331,7 +328,7 @@ impl ModelTrait for StableDiffusion { let init_latent_dist = match &input.img2img { None => None, Some(image) => { - let image = Self::image_preprocess(image)?.to_device(&device)?; + let image = Self::image_preprocess(image)?.to_device(&self.device)?; Some(self.vae.encode(&image)?) } }; @@ -355,7 +352,8 @@ impl ModelTrait for StableDiffusion { let timesteps = scheduler.timesteps(); let latents = match &init_latent_dist { Some(init_latent_dist) => { - let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?; + let latents = + (init_latent_dist.sample()? * vae_scale)?.to_device(&self.device)?; if t_start < timesteps.len() { let noise = latents.randn_like(0f64, 1f64)?; scheduler.add_noise(&latents, noise, timesteps[t_start])? @@ -367,7 +365,12 @@ impl ModelTrait for StableDiffusion { let latents = Tensor::randn( 0f32, 1f32, - (bsize, 4, input.height / 8, input.width / 8), + ( + bsize, + 4, + input.height.unwrap_or(512) / 8, + input.width.unwrap_or(512) / 8, + ), &self.device, )?; // scale the initial noise by the standard deviation required by the scheduler @@ -510,18 +513,19 @@ impl ModelFile { model_type: ModelType, use_f16: bool, - ) -> Result { + ) -> Result { let (repo, path) = match self { Self::Tokenizer => { let tokenizer_repo = match model_type { - StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => { + ModelType::StableDiffusionV1_5 | ModelType::StableDiffusionV2_1 => { "openai/clip-vit-base-patch32" } - StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => { + ModelType::StableDiffusionXl | ModelType::StableDiffusionTurbo => { // This seems similar to the patch32 version except some very small // difference in the split regex. 
"openai/clip-vit-large-patch14" } + _ => bail!("Invalid stable diffusion model type"), }; (tokenizer_repo, "tokenizer.json") } @@ -546,11 +550,12 @@ impl ModelFile { } } }; - let filename = ApiBuilder::new() + let api = ApiBuilder::new() .with_progress(true) .with_token(Some(api_key)) .with_cache_dir(cache_dir) .build()?; + let filename = api.model(repo.to_string()).get(path)?; Ok(filename) } } diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs index 4ec3a618..f527a32f 100644 --- a/atoma-inference/src/models/config.rs +++ b/atoma-inference/src/models/config.rs @@ -11,7 +11,6 @@ type Revision = String; #[derive(Clone, Debug, Deserialize, Serialize)] pub struct ModelConfig { api_key: String, - cache_dir: String, device_id: usize, dtype: String, model_id: ModelId, @@ -23,7 +22,6 @@ pub struct ModelConfig { impl ModelConfig { pub fn new( api_key: String, - cache_dir: String, model_id: ModelId, dtype: String, revision: Revision, @@ -33,7 +31,6 @@ impl ModelConfig { ) -> Self { Self { api_key, - cache_dir, dtype, model_id, revision, @@ -47,10 +44,6 @@ impl ModelConfig { self.api_key.clone() } - pub fn cache_dir(&self) -> String { - self.cache_dir.clone() - } - pub fn dtype(&self) -> String { self.dtype.clone() } @@ -78,20 +71,31 @@ impl ModelConfig { #[derive(Debug, Deserialize, Serialize)] pub struct ModelsConfig { + cache_dir: PathBuf, flush_storage: bool, models: Vec, tracing: bool, } impl ModelsConfig { - pub fn new(flush_storage: bool, models: Vec, tracing: bool) -> Self { + pub fn new( + cache_dir: PathBuf, + flush_storage: bool, + models: Vec, + tracing: bool, + ) -> Self { Self { + cache_dir, flush_storage, models, tracing, } } + pub fn cache_dir(&self) -> PathBuf { + self.cache_dir.clone() + } + pub fn flush_storage(&self) -> bool { self.flush_storage } @@ -119,6 +123,10 @@ impl ModelsConfig { pub fn from_env_file() -> Self { dotenv().ok(); + let cache_dir = std::env::var("CACHE_DIR") + .unwrap_or_default() + .parse() + .unwrap(); let flush_storage = std::env::var("FLUSH_STORAGE") .unwrap_or_default() .parse() @@ -133,6 +141,7 @@ impl ModelsConfig { .unwrap(); Self { + cache_dir, flush_storage, models, tracing, @@ -147,10 +156,10 @@ pub mod tests { #[test] fn test_config() { let config = ModelsConfig::new( + "/".to_string().into(), true, vec![ModelConfig::new( "my_key".to_string(), - "/".to_string(), "F16".to_string(), "Llama2_7b".to_string(), "".to_string(), diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs index f6f24008..521e4e3e 100644 --- a/atoma-inference/src/models/mod.rs +++ b/atoma-inference/src/models/mod.rs @@ -1,3 +1,5 @@ +use std::path::PathBuf; + use ::candle::{DTypeParseError, Error as CandleError}; use ed25519_consensus::VerificationKey as PublicKey; use thiserror::Error; @@ -16,7 +18,7 @@ pub trait ModelTrait { type Output; type LoadData; - fn fetch(config: ModelConfig) -> Result; + fn fetch(cache_dir: PathBuf, config: ModelConfig) -> Result; fn load(load_data: Self::LoadData) -> Result where Self: Sized; diff --git a/atoma-inference/src/models/types.rs b/atoma-inference/src/models/types.rs index 363326b5..3c11ead0 100644 --- a/atoma-inference/src/models/types.rs +++ b/atoma-inference/src/models/types.rs @@ -48,7 +48,7 @@ impl FromStr for ModelType { match s { "falcon_7b" => Ok(Self::Falcon7b), "falcon_40b" => Ok(Self::Falcon40b), - "falcon_180b" => Ok(Self::Falcon180b0), + "falcon_180b" => Ok(Self::Falcon180b), "llama_v1" => Ok(Self::LlamaV1), "llama_v2" => Ok(Self::LlamaV2), "llama_solar_10_7b" => 
Ok(Self::LlamaSolar10_7B), diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 60deea54..32918f26 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -19,7 +19,7 @@ pub struct ModelService { start_time: Instant, flush_storage: bool, public_key: PublicKey, - storage_path: PathBuf, + cache_dir: PathBuf, request_receiver: Receiver, response_sender: Sender, } @@ -38,7 +38,7 @@ impl ModelService { let public_key = private_key.verification_key(); let flush_storage = model_config.flush_storage(); - let storage_path = model_config.storage_path(); + let cache_dir = model_config.cache_dir(); let (dispatcher, model_thread_handle) = ModelThreadDispatcher::start::(model_config, public_key) @@ -50,7 +50,7 @@ impl ModelService { model_thread_handle, start_time, flush_storage, - storage_path, + cache_dir, public_key, request_receiver, response_sender, @@ -95,7 +95,7 @@ impl ModelService { ); if self.flush_storage { - match std::fs::remove_dir(self.storage_path) { + match std::fs::remove_dir(self.cache_dir) { Ok(()) => {} Err(e) => error!("Failed to remove storage folder, on shutdown: {e}"), }; @@ -197,7 +197,7 @@ mod tests { type Output = (); type LoadData = (); - fn fetch(config: ModelConfig) -> Result<(), crate::models::ModelError> { + fn fetch(_: PathBuf, _: ModelConfig) -> Result<(), crate::models::ModelError> { Ok(()) } @@ -223,7 +223,7 @@ mod tests { let config_data = Value::Table(toml! { api_key = "your_api_key" models = [["Mamba3b", "F16", "", ""]] - storage_path = "./storage_path/" + cache_dir = "./cache_dir/" tokenizer_file_path = "./tokenizer_file_path/" flush_storage = true tracing = true From 4c9d0b8a8d11fe96a72191f98d318692d379e40d Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Fri, 5 Apr 2024 20:52:16 +0100 Subject: [PATCH 05/21] resolve compiler errors --- atoma-inference/src/model_thread.rs | 5 +- atoma-inference/src/models/candle/falcon.rs | 27 ------ atoma-inference/src/models/candle/llama.rs | 2 - atoma-inference/src/models/candle/mamba.rs | 34 ------- .../src/models/candle/stable_diffusion.rs | 91 +++++-------------- atoma-inference/src/models/mod.rs | 5 +- atoma-inference/src/models/types.rs | 1 - 7 files changed, 30 insertions(+), 135 deletions(-) diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 2b5c318e..9f44df25 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -106,16 +106,19 @@ impl ModelThreadDispatcher { let mut handles = Vec::new(); let mut model_senders = HashMap::new(); + let cache_dir = config.cache_dir(); + for model_config in config.models() { info!("Spawning new thread for model: {}", model_config.model_id()); + let model_cache_dir = cache_dir.clone(); let (model_sender, model_receiver) = mpsc::channel::(); let model_name = model_config.model_id().clone(); model_senders.insert(model_name.clone(), model_sender.clone()); let join_handle = std::thread::spawn(move || { info!("Fetching files for model: {model_name}"); - let load_data = M::fetch(model_config)?; + let load_data = M::fetch(model_cache_dir, model_config)?; let model = M::load(load_data)?; let model_thread = ModelThread { diff --git a/atoma-inference/src/models/candle/falcon.rs b/atoma-inference/src/models/candle/falcon.rs index 4c425605..cb816566 100644 --- a/atoma-inference/src/models/candle/falcon.rs +++ b/atoma-inference/src/models/candle/falcon.rs @@ -30,7 +30,6 @@ pub struct FalconModel { impl FalconModel { pub fn new( model: Falcon, - config: Config, 
device: Device, dtype: DType, model_type: ModelType, @@ -121,7 +120,6 @@ impl ModelTrait for FalconModel { Ok(Self::new( model, - config, load_data.device, load_data.dtype, load_data.model_type, @@ -201,28 +199,3 @@ impl ModelTrait for FalconModel { Ok(output) } } - -enum Which { - Falcon7b, - Falcon40b, - Falcon180b, -} - -impl Which { - fn model_id(&self) -> &'static str { - match self { - Self::Falcon7b => "tiiuae/falcon-7b", - Self::Falcon40b => "tiiuae/falcon-40b", - Self::Falcon180b => "tiiuae/falcon-180b", - } - } - - fn from_config(config: &Config) -> Self { - match config.hidden_size { - 4544 => Self::Falcon7b, - 8192 => Self::Falcon40b, - 14848 => Self::Falcon180b, - _ => panic!("Invalid config hidden size value"), - } - } -} diff --git a/atoma-inference/src/models/candle/llama.rs b/atoma-inference/src/models/candle/llama.rs index 24d0dc91..3ff633bf 100644 --- a/atoma-inference/src/models/candle/llama.rs +++ b/atoma-inference/src/models/candle/llama.rs @@ -42,7 +42,6 @@ pub struct Config {} pub struct LlamaModel { cache: Cache, device: Device, - dtype: DType, model: model::Llama, model_type: ModelType, tokenizer: Tokenizer, @@ -115,7 +114,6 @@ impl ModelTrait for LlamaModel { Ok(Self { cache, device, - dtype, model, model_type: load_data.model_type, tokenizer, diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs index 38d3e8d0..3daceeb4 100644 --- a/atoma-inference/src/models/candle/mamba.rs +++ b/atoma-inference/src/models/candle/mamba.rs @@ -223,37 +223,3 @@ impl ModelTrait for MambaModel { Ok(output) } } - -#[derive(Clone, Copy, PartialEq, Eq, Debug)] -enum Which { - Mamba130m, - Mamba370m, - Mamba790m, - Mamba1_4b, - Mamba2_8b, - // Mamba2_8bSlimPj, TODO: add this -} - -impl Which { - fn model_id(&self) -> &'static str { - match self { - Self::Mamba130m => "state-spaces/mamba-130m", - Self::Mamba370m => "state-spaces/mamba-370m", - Self::Mamba790m => "state-spaces/mamba-790m", - Self::Mamba1_4b => "state-spaces/mamba-1.4b", - Self::Mamba2_8b => "state-spaces/mamba-2.8b", - // Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj'", - } - } - - fn from_config(config: &Config) -> Self { - match config.d_model { - 768 => Self::Mamba130m, - 1024 => Self::Mamba370m, - 1536 => Self::Mamba790m, - 2048 => Self::Mamba1_4b, - 2560 => Self::Mamba2_8b, - _ => panic!("Invalid config d_model value"), - } - } -} diff --git a/atoma-inference/src/models/candle/stable_diffusion.rs b/atoma-inference/src/models/candle/stable_diffusion.rs index 10103b2e..08d7e076 100644 --- a/atoma-inference/src/models/candle/stable_diffusion.rs +++ b/atoma-inference/src/models/candle/stable_diffusion.rs @@ -31,21 +31,6 @@ pub struct Input { height: Option, width: Option, - /// The UNet weight file, in .safetensors format. - unet_weights: Option, - - /// The CLIP weight file, in .safetensors format. - clip_weights: Option, - - /// The VAE weight file, in .safetensors format. - vae_weights: Option, - - /// The file specifying the tokenizer to used for tokenization. - tokenizer: Option, - - /// The size of the sliced attention or 0 for automatic slicing (disabled by default) - sliced_attention_size: Option, - /// The number of steps to run the diffusion for. 
n_steps: Option, @@ -54,10 +39,6 @@ pub struct Input { sd_version: StableDiffusionVersion, - use_flash_attn: bool, - - use_f16: bool, - guidance_scale: Option, img2img: Option, @@ -71,31 +52,6 @@ pub struct Input { seed: Option, } -impl Input { - pub fn default_prompt(prompt: String) -> Self { - Self { - prompt, - uncond_prompt: "".to_string(), - height: Some(256), - width: Some(256), - unet_weights: None, - clip_weights: None, - vae_weights: None, - tokenizer: None, - sliced_attention_size: None, - n_steps: Some(20), - num_samples: 1, - sd_version: StableDiffusionVersion::V1_5, - use_flash_attn: false, - use_f16: true, - guidance_scale: None, - img2img: None, - img2img_strength: 0.8, - seed: Some(0), - } - } -} - pub struct StableDiffusionLoadData { device: Device, dtype: DType, @@ -138,9 +94,9 @@ impl ModelTrait for StableDiffusion { let api_key = config.api_key(); let use_f16 = config.dtype() == "f16"; - let vae_weights_file_path = ModelFile::Vae.get(api_key, cache_dir, model_type, use_f16)?; + let vae_weights_file_path = ModelFile::Vae.get(api_key.clone(), cache_dir.clone(), model_type.clone(), use_f16)?; let unet_weights_file_path = - ModelFile::Unet.get(api_key, cache_dir, model_type, use_f16)?; + ModelFile::Unet.get(api_key.clone(), cache_dir.clone(), model_type.clone(), use_f16)?; let mut clip_weights_file_paths = vec![]; let mut tokenizer_file_paths = vec![]; @@ -153,9 +109,9 @@ impl ModelTrait for StableDiffusion { }; let clip_weights_file_path = - clip_weights_file.get(api_key, cache_dir, model_type, false)?; + clip_weights_file.get(api_key.clone(), cache_dir.clone(), model_type.clone(), false)?; let tokenizer_file_path = - tokenizer_file.get(api_key, cache_dir, model_type, use_f16)?; + tokenizer_file.get(api_key.clone(), cache_dir.clone(), model_type.clone(), use_f16)?; clip_weights_file_paths.push(clip_weights_file_path); tokenizer_file_paths.push(tokenizer_file_path); @@ -199,25 +155,25 @@ impl ModelTrait for StableDiffusion { let (tokenizer, tokenizer_2) = match load_data.model_type { ModelType::StableDiffusionXl | ModelType::StableDiffusionTurbo => ( - Tokenizer::from_file(load_data.tokenizer_file_paths[0])?, - Some(Tokenizer::from_file(load_data.tokenizer_file_paths[1])?), + Tokenizer::from_file(load_data.tokenizer_file_paths[0].clone())?, + Some(Tokenizer::from_file(load_data.tokenizer_file_paths[1].clone())?), ), _ => ( - Tokenizer::from_file(load_data.tokenizer_file_paths[0])?, + Tokenizer::from_file(load_data.tokenizer_file_paths[0].clone())?, None, ), // INTEGRITY: we have checked previously if the model type is valid for the family of stable diffusion models }; let text_model = stable_diffusion::build_clip_transformer( &config.clip, - load_data.clip_weights_file_paths[0], + load_data.clip_weights_file_paths[0].clone(), &load_data.device, load_data.dtype, )?; - let text_model_2 = if let Some(clip_config_2) = config.clip2 { + let text_model_2 = if let Some(clip_config_2) = &config.clip2 { Some(stable_diffusion::build_clip_transformer( - &clip_config_2, - load_data.clip_weights_file_paths[1], + clip_config_2, + load_data.clip_weights_file_paths[1].clone(), &load_data.device, load_data.dtype, )?) 
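// Illustrative sketch, not part of this patch: what each `ModelFile::…get(api_key, cache_dir, …)`
// call in the fetch hunk above amounts to once the hub client is built. `repo_id` and `path`
// stand in for the values derived from `ModelType`; the conversion of hub errors into
// `ModelError` via `?` is assumed to exist, as in the surrounding code.
fn resolve_hub_file(
    api_key: String,
    cache_dir: std::path::PathBuf,
    repo_id: &str,
    path: &str,
) -> Result<std::path::PathBuf, ModelError> {
    let api = hf_hub::api::sync::ApiBuilder::new()
        .with_progress(true)
        .with_token(Some(api_key))
        .with_cache_dir(cache_dir)
        .build()?;
    // The first call downloads into the cache directory; later calls return the cached path.
    Ok(api.model(repo_id.to_string()).get(path)?)
}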
@@ -253,7 +209,7 @@ impl ModelTrait for StableDiffusion { } fn model_type(&self) -> ModelType { - self.model_type + self.model_type.clone() } fn run(&mut self, input: Self::Input) -> Result { @@ -301,20 +257,14 @@ impl ModelTrait for StableDiffusion { let text_embeddings = which .iter() .map(|first| { - let (tokenizer, text_model) = if *first { - (&self.tokenizer, &self.text_model) - } else { - (&self.tokenizer_2.unwrap(), &self.text_model_2.unwrap()) - }; - Self::text_embeddings( &input.prompt, &input.uncond_prompt, - tokenizer, - text_model, - self.model_type, + &self.tokenizer, + self.tokenizer_2.as_ref(), + &self.text_model, + self.text_model_2.as_ref(), &self.config, - input.use_f16, &self.device, self.dtype, use_guide_scale, @@ -566,15 +516,20 @@ impl StableDiffusion { prompt: &str, uncond_prompt: &str, tokenizer: &Tokenizer, + tokenizer_2: Option<&Tokenizer>, text_model: &ClipTextTransformer, - model_type: ModelType, + text_model_2: Option<&ClipTextTransformer>, sd_config: &StableDiffusionConfig, - use_f16: bool, device: &Device, dtype: DType, use_guide_scale: bool, first: bool, ) -> Result { + let (tokenizer, text_model) = if first { + (tokenizer, text_model) + } else { + (tokenizer_2.unwrap(), text_model_2.unwrap()) + }; let pad_id = match &sd_config.clip.pad_with { Some(padding) => { *tokenizer diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs index 521e4e3e..4dcfbbf5 100644 --- a/atoma-inference/src/models/mod.rs +++ b/atoma-inference/src/models/mod.rs @@ -2,6 +2,7 @@ use std::path::PathBuf; use ::candle::{DTypeParseError, Error as CandleError}; use ed25519_consensus::VerificationKey as PublicKey; +use serde::{de::DeserializeOwned, Serialize}; use thiserror::Error; use self::{config::ModelConfig, types::ModelType}; @@ -14,8 +15,8 @@ pub mod types; pub type ModelId = String; pub trait ModelTrait { - type Input; - type Output; + type Input: DeserializeOwned; + type Output: Serialize; type LoadData; fn fetch(cache_dir: PathBuf, config: ModelConfig) -> Result; diff --git a/atoma-inference/src/models/types.rs b/atoma-inference/src/models/types.rs index 3c11ead0..a422298a 100644 --- a/atoma-inference/src/models/types.rs +++ b/atoma-inference/src/models/types.rs @@ -92,7 +92,6 @@ impl ModelType { Self::Mixtral8x7b => "TODO", Self::StableDiffusionV1_5 => "runwayml/stable-diffusion-v1-5", Self::StableDiffusionV2_1 => "stabilityai/stable-diffusion-2-1", - Self::StableDiffusionV1_5 => "runwayml/stable-diffusion-v1-5", Self::StableDiffusionXl => "stabilityai/stable-diffusion-xl-base-1.0", Self::StableDiffusionTurbo => "stabilityai/sdxl-turbo", } From f2278ff513fa8da737f23a51a60f101636b32fce Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Fri, 5 Apr 2024 20:54:38 +0100 Subject: [PATCH 06/21] cargo clippy --- atoma-inference/src/main.rs | 11 ++---- atoma-inference/src/model_thread.rs | 5 +-- .../src/models/candle/stable_diffusion.rs | 39 ++++++++++++++----- atoma-inference/src/models/types.rs | 8 ++-- atoma-inference/src/service.rs | 26 +++---------- 5 files changed, 42 insertions(+), 47 deletions(-) diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index 45d5974b..2fdd4fa8 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -1,7 +1,6 @@ use std::time::Duration; use ed25519_consensus::SigningKey as PrivateKey; -use hf_hub::api::sync::Api; use inference::{ models::{candle::mamba::MambaModel, config::ModelsConfig, types::TextRequest}, service::{ModelService, ModelServiceError}, @@ -22,13 +21,9 @@ async fn main() 
-> Result<(), ModelServiceError> { .expect("Incorrect private key bytes length"); let private_key = PrivateKey::from(private_key_bytes); - let mut service = ModelService::start::( - model_config, - private_key, - req_receiver, - resp_sender, - ) - .expect("Failed to start inference service"); + let mut service = + ModelService::start::(model_config, private_key, req_receiver, resp_sender) + .expect("Failed to start inference service"); let pk = service.public_key(); diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 9f44df25..9e608b37 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -7,7 +7,7 @@ use tokio::sync::oneshot::{self, error::RecvError}; use tracing::{debug, error, info, warn}; use crate::{ - apis::{ApiError, ApiTrait}, + apis::ApiError, models::{config::ModelsConfig, ModelError, ModelId, ModelTrait}, }; @@ -95,12 +95,11 @@ pub struct ModelThreadDispatcher { } impl ModelThreadDispatcher { - pub(crate) fn start( + pub(crate) fn start( config: ModelsConfig, public_key: PublicKey, ) -> Result<(Self, Vec), ModelThreadError> where - F: ApiTrait + Send + Sync + 'static, M: ModelTrait, // + Send + 'static, { let mut handles = Vec::new(); diff --git a/atoma-inference/src/models/candle/stable_diffusion.rs b/atoma-inference/src/models/candle/stable_diffusion.rs index 08d7e076..fde4954a 100644 --- a/atoma-inference/src/models/candle/stable_diffusion.rs +++ b/atoma-inference/src/models/candle/stable_diffusion.rs @@ -94,9 +94,18 @@ impl ModelTrait for StableDiffusion { let api_key = config.api_key(); let use_f16 = config.dtype() == "f16"; - let vae_weights_file_path = ModelFile::Vae.get(api_key.clone(), cache_dir.clone(), model_type.clone(), use_f16)?; - let unet_weights_file_path = - ModelFile::Unet.get(api_key.clone(), cache_dir.clone(), model_type.clone(), use_f16)?; + let vae_weights_file_path = ModelFile::Vae.get( + api_key.clone(), + cache_dir.clone(), + model_type.clone(), + use_f16, + )?; + let unet_weights_file_path = ModelFile::Unet.get( + api_key.clone(), + cache_dir.clone(), + model_type.clone(), + use_f16, + )?; let mut clip_weights_file_paths = vec![]; let mut tokenizer_file_paths = vec![]; @@ -108,10 +117,18 @@ impl ModelTrait for StableDiffusion { (ModelFile::Clip2, ModelFile::Tokenizer2) }; - let clip_weights_file_path = - clip_weights_file.get(api_key.clone(), cache_dir.clone(), model_type.clone(), false)?; - let tokenizer_file_path = - tokenizer_file.get(api_key.clone(), cache_dir.clone(), model_type.clone(), use_f16)?; + let clip_weights_file_path = clip_weights_file.get( + api_key.clone(), + cache_dir.clone(), + model_type.clone(), + false, + )?; + let tokenizer_file_path = tokenizer_file.get( + api_key.clone(), + cache_dir.clone(), + model_type.clone(), + use_f16, + )?; clip_weights_file_paths.push(clip_weights_file_path); tokenizer_file_paths.push(tokenizer_file_path); @@ -156,7 +173,9 @@ impl ModelTrait for StableDiffusion { let (tokenizer, tokenizer_2) = match load_data.model_type { ModelType::StableDiffusionXl | ModelType::StableDiffusionTurbo => ( Tokenizer::from_file(load_data.tokenizer_file_paths[0].clone())?, - Some(Tokenizer::from_file(load_data.tokenizer_file_paths[1].clone())?), + Some(Tokenizer::from_file( + load_data.tokenizer_file_paths[1].clone(), + )?), ), _ => ( Tokenizer::from_file(load_data.tokenizer_file_paths[0].clone())?, @@ -525,9 +544,9 @@ impl StableDiffusion { use_guide_scale: bool, first: bool, ) -> Result { - let (tokenizer, text_model) = if first { + let 
(tokenizer, text_model) = if first { (tokenizer, text_model) - } else { + } else { (tokenizer_2.unwrap(), text_model_2.unwrap()) }; let pad_id = match &sd_config.clip.pad_with { diff --git a/atoma-inference/src/models/types.rs b/atoma-inference/src/models/types.rs index a422298a..256ba56c 100644 --- a/atoma-inference/src/models/types.rs +++ b/atoma-inference/src/models/types.rs @@ -64,11 +64,9 @@ impl FromStr for ModelType { "stable_diffusion_v2-1" => Ok(Self::StableDiffusionV2_1), "stable_diffusion_xl" => Ok(Self::StableDiffusionXl), "stable_diffusion_turbo" => Ok(Self::StableDiffusionTurbo), - _ => { - return Err(ModelError::InvalidModelType(format!( - "Invalid string model type descryption" - ))) - } + _ => Err(ModelError::InvalidModelType( + "Invalid string model type descryption".to_string() + )), } } } diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 32918f26..9feef69f 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -8,7 +8,7 @@ use tracing::{error, info}; use thiserror::Error; use crate::{ - apis::{ApiError, ApiTrait}, + apis::ApiError, model_thread::{ModelThreadDispatcher, ModelThreadError, ModelThreadHandle}, models::{config::ModelsConfig, ModelTrait}, }; @@ -25,7 +25,7 @@ pub struct ModelService { } impl ModelService { - pub fn start( + pub fn start( model_config: ModelsConfig, private_key: PrivateKey, request_receiver: Receiver, @@ -33,7 +33,6 @@ impl ModelService { ) -> Result where M: ModelTrait + Send + 'static, - F: ApiTrait + Send + Sync + 'static, { let public_key = private_key.verification_key(); @@ -41,7 +40,7 @@ impl ModelService { let cache_dir = model_config.cache_dir(); let (dispatcher, model_thread_handle) = - ModelThreadDispatcher::start::(model_config, public_key) + ModelThreadDispatcher::start::(model_config, public_key) .map_err(ModelServiceError::ModelThreadError)?; let start_time = Instant::now(); @@ -146,25 +145,10 @@ mod tests { use std::io::Write; use toml::{toml, Value}; - use crate::models::{config::ModelConfig, ModelId, Request, Response}; + use crate::models::{config::ModelConfig, Request, Response}; use super::*; - struct MockApi {} - - impl ApiTrait for MockApi { - fn create(_: String, _: PathBuf) -> Result - where - Self: Sized, - { - Ok(Self {}) - } - - fn fetch(&self, _: ModelId, _: String) -> Result, ApiError> { - Ok(vec![]) - } - } - impl Request for () { type ModelInput = (); @@ -240,7 +224,7 @@ mod tests { let config = ModelsConfig::from_file_path(CONFIG_FILE_PATH.parse().unwrap()); - let _ = ModelService::start::( + let _ = ModelService::start::( config, private_key, req_receiver, From 82bfd6dd701051f1ceeb97fb1ca3fe9c95c7e753 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Fri, 5 Apr 2024 21:19:17 +0100 Subject: [PATCH 07/21] remove extern crates bug dependencies --- atoma-inference/src/models/candle/llama.rs | 6 ------ atoma-inference/src/models/candle/stable_diffusion.rs | 6 ------ atoma-inference/src/models/config.rs | 2 +- atoma-inference/src/service.rs | 2 +- 4 files changed, 2 insertions(+), 14 deletions(-) diff --git a/atoma-inference/src/models/candle/llama.rs b/atoma-inference/src/models/candle/llama.rs index 3ff633bf..1fe28c4d 100644 --- a/atoma-inference/src/models/candle/llama.rs +++ b/atoma-inference/src/models/candle/llama.rs @@ -1,9 +1,3 @@ -#[cfg(feature = "accelerate")] -extern crate accelerate_src; - -#[cfg(feature = "mkl")] -extern crate intel_mkl_src; - use std::{path::PathBuf, str::FromStr}; use candle::{DType, Device, Tensor}; diff --git 
a/atoma-inference/src/models/candle/stable_diffusion.rs b/atoma-inference/src/models/candle/stable_diffusion.rs index fde4954a..12480fa1 100644 --- a/atoma-inference/src/models/candle/stable_diffusion.rs +++ b/atoma-inference/src/models/candle/stable_diffusion.rs @@ -1,9 +1,3 @@ -#[cfg(feature = "accelerate")] -extern crate accelerate_src; - -#[cfg(feature = "mkl")] -extern crate intel_mkl_src; - use std::{path::PathBuf, str::FromStr}; use candle_transformers::models::stable_diffusion::{ diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs index f527a32f..0ee2dfcd 100644 --- a/atoma-inference/src/models/config.rs +++ b/atoma-inference/src/models/config.rs @@ -171,7 +171,7 @@ pub mod tests { ); let toml_str = toml::to_string(&config).unwrap(); - let should_be_toml_str = "api_key = \"my_key\"\nflush_storage = true\nmodels = [[\"Llama2_7b\", \"F16\", \"\"]]\nstorage_path = \"storage_path\"\ntracing = true\n"; + let should_be_toml_str = "cache_dir = \"/\"\nflush_storage = true\ntracing = true\n\n[[models]]\napi_key = \"my_key\"\ndevice_id = 0\ndtype = \"Llama2_7b\"\nmodel_id = \"F16\"\nrevision = \"\"\nuse_flash_attention = true\nsliced_attention_size = 0\n"; assert_eq!(toml_str, should_be_toml_str); } } diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 9feef69f..e876fc46 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -206,7 +206,7 @@ mod tests { let config_data = Value::Table(toml! { api_key = "your_api_key" - models = [["Mamba3b", "F16", "", ""]] + models = [["Mamba370m", 0, "f16", "", "", true, 0]] cache_dir = "./cache_dir/" tokenizer_file_path = "./tokenizer_file_path/" flush_storage = true From 15b43f9cad7f74d43e31ef4a6d7477ed49a5d5e1 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Fri, 5 Apr 2024 22:59:44 +0100 Subject: [PATCH 08/21] resolve issues with running models --- atoma-inference/src/main.rs | 39 ++++---- atoma-inference/src/model_thread.rs | 89 +++++++++++++------ atoma-inference/src/models/candle/falcon.rs | 18 ++-- atoma-inference/src/models/candle/llama.rs | 26 +++--- atoma-inference/src/models/candle/mamba.rs | 18 ++-- .../src/models/candle/stable_diffusion.rs | 7 +- atoma-inference/src/models/config.rs | 25 ++++-- atoma-inference/src/models/mod.rs | 11 ++- atoma-inference/src/models/types.rs | 2 +- atoma-inference/src/service.rs | 48 ++++++---- 10 files changed, 178 insertions(+), 105 deletions(-) diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index 2fdd4fa8..de8f7e3f 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -2,7 +2,11 @@ use std::time::Duration; use ed25519_consensus::SigningKey as PrivateKey; use inference::{ - models::{candle::mamba::MambaModel, config::ModelsConfig, types::TextRequest}, + models::{ + candle::mamba::MambaModel, + config::ModelsConfig, + types::{TextRequest, TextResponse}, + }, service::{ModelService, ModelServiceError}, }; @@ -10,8 +14,8 @@ use inference::{ async fn main() -> Result<(), ModelServiceError> { tracing_subscriber::fmt::init(); - let (req_sender, req_receiver) = tokio::sync::mpsc::channel::(32); - let (resp_sender, mut resp_receiver) = tokio::sync::mpsc::channel::(32); + let (req_sender, req_receiver) = tokio::sync::mpsc::channel::(32); + let (resp_sender, mut resp_receiver) = tokio::sync::mpsc::channel::(32); let model_config = ModelsConfig::from_file_path("../inference.toml".parse().unwrap()); let private_key_bytes = @@ -35,22 +39,19 @@ async fn main() -> Result<(), 
ModelServiceError> { tokio::time::sleep(Duration::from_millis(5000)).await; req_sender - .send( - serde_json::to_value(TextRequest { - request_id: 0, - prompt: "Leon, the professional is a movie".to_string(), - model: "state-spaces/mamba-130m".to_string(), - max_tokens: 512, - temperature: Some(0.0), - random_seed: 42, - repeat_last_n: 64, - repeat_penalty: 1.1, - sampled_nodes: vec![pk], - top_p: Some(1.0), - top_k: 10, - }) - .unwrap(), - ) + .send(TextRequest { + request_id: 0, + prompt: "Leon, the professional is a movie".to_string(), + model: "mamba_370m".to_string(), + max_tokens: 512, + temperature: Some(0.0), + random_seed: 42, + repeat_last_n: 64, + repeat_penalty: 1.1, + sampled_nodes: vec![pk], + top_p: Some(1.0), + top_k: 10, + }) .await .expect("Failed to send request"); diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 9e608b37..0d2235fa 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, sync::mpsc}; +use std::{collections::HashMap, fmt::Debug, sync::mpsc}; use ed25519_consensus::VerificationKey as PublicKey; use futures::stream::FuturesUnordered; @@ -8,12 +8,16 @@ use tracing::{debug, error, info, warn}; use crate::{ apis::ApiError, - models::{config::ModelsConfig, ModelError, ModelId, ModelTrait}, + models::{config::ModelsConfig, ModelError, ModelId, ModelTrait, Request, Response}, }; -pub struct ModelThreadCommand { - request: serde_json::Value, - response_sender: oneshot::Sender, +pub struct ModelThreadCommand +where + Req: Request, + Resp: Response, +{ + request: Req, + response_sender: oneshot::Sender, } #[derive(Debug, Error)] @@ -40,26 +44,41 @@ impl From for ModelThreadError { } } -pub struct ModelThreadHandle { - sender: mpsc::Sender, +pub struct ModelThreadHandle +where + Req: Request, + Resp: Response, +{ + sender: mpsc::Sender>, join_handle: std::thread::JoinHandle>, } -impl ModelThreadHandle { +impl ModelThreadHandle +where + Req: Request, + Resp: Response, +{ pub fn stop(self) { drop(self.sender); self.join_handle.join().ok(); } } -pub struct ModelThread { +pub struct ModelThread +where + M: ModelTrait, + Req: Request, + Resp: Response, +{ model: M, - receiver: mpsc::Receiver, + receiver: mpsc::Receiver>, } -impl ModelThread +impl ModelThread where - M: ModelTrait, + M: ModelTrait, + Req: Request, + Resp: Response, { pub fn run(mut self, _public_key: PublicKey) -> Result<(), ModelThreadError> { debug!("Start Model thread"); @@ -76,12 +95,12 @@ where // continue; // } - let model_input = serde_json::from_value(request).unwrap(); + let model_input = request.into_model_input(); let model_output = self .model .run(model_input) .map_err(ModelThreadError::ModelError)?; - let response = serde_json::to_value(model_output)?; + let response = Response::from_model_output(model_output); response_sender.send(response).ok(); } @@ -89,35 +108,45 @@ where } } -pub struct ModelThreadDispatcher { - model_senders: HashMap>, - pub(crate) responses: FuturesUnordered>, +pub struct ModelThreadDispatcher +where + Req: Request, + Resp: Response, +{ + model_senders: HashMap>>, + pub(crate) responses: FuturesUnordered>, } -impl ModelThreadDispatcher { +impl ModelThreadDispatcher +where + Req: Clone + Request, + Resp: Response, +{ pub(crate) fn start( config: ModelsConfig, public_key: PublicKey, - ) -> Result<(Self, Vec), ModelThreadError> + ) -> Result<(Self, Vec>), ModelThreadError> where - M: ModelTrait, // + Send + 'static, + M: ModelTrait + Send + 
'static, { let mut handles = Vec::new(); let mut model_senders = HashMap::new(); + let api_key = config.api_key(); let cache_dir = config.cache_dir(); for model_config in config.models() { info!("Spawning new thread for model: {}", model_config.model_id()); + let model_api_key = api_key.clone(); let model_cache_dir = cache_dir.clone(); - let (model_sender, model_receiver) = mpsc::channel::(); + let (model_sender, model_receiver) = mpsc::channel::>(); let model_name = model_config.model_id().clone(); model_senders.insert(model_name.clone(), model_sender.clone()); let join_handle = std::thread::spawn(move || { info!("Fetching files for model: {model_name}"); - let load_data = M::fetch(model_cache_dir, model_config)?; + let load_data = M::fetch(model_api_key, model_cache_dir, model_config)?; let model = M::load(load_data)?; let model_thread = ModelThread { @@ -148,12 +177,12 @@ impl ModelThreadDispatcher { Ok((model_dispatcher, handles)) } - fn send(&self, command: ModelThreadCommand) { + fn send(&self, command: ModelThreadCommand) { let request = command.request.clone(); - let model_id = request.get("model").unwrap().as_str().unwrap().to_string(); - println!("model_id {model_id}"); + let model_id = request.requested_model(); + + info!("model_id {model_id}"); - println!("{:?}", self.model_senders); let sender = self .model_senders .get(&model_id) @@ -165,8 +194,12 @@ impl ModelThreadDispatcher { } } -impl ModelThreadDispatcher { - pub(crate) fn run_inference(&self, request: serde_json::Value) { +impl ModelThreadDispatcher +where + Req: Clone + Debug + Request, + Resp: Debug + Response, +{ + pub(crate) fn run_inference(&self, request: Req) { let (sender, receiver) = oneshot::channel(); self.send(ModelThreadCommand { request, diff --git a/atoma-inference/src/models/candle/falcon.rs b/atoma-inference/src/models/candle/falcon.rs index cb816566..f57db91e 100644 --- a/atoma-inference/src/models/candle/falcon.rs +++ b/atoma-inference/src/models/candle/falcon.rs @@ -50,23 +50,25 @@ impl ModelTrait for FalconModel { type Output = String; type LoadData = LlmLoadData; - fn fetch(cache_dir: PathBuf, config: ModelConfig) -> Result { + fn fetch( + api_key: String, + cache_dir: PathBuf, + config: ModelConfig, + ) -> Result { let device = device(config.device_id())?; let dtype = DType::from_str(&config.dtype())?; - let api_key = config.api_key(); - let api = ApiBuilder::new() .with_progress(true) .with_token(Some(api_key)) .with_cache_dir(cache_dir) .build()?; - let repo = api.repo(Repo::with_revision( - config.model_id(), - RepoType::Model, - config.revision(), - )); + let model_type = ModelType::from_str(&config.model_id())?; + let repo_id = model_type.repo().to_string(); + let revision = model_type.default_revision().to_string(); + + let repo = api.repo(Repo::with_revision(repo_id, RepoType::Model, revision)); let config_file_path = repo.get("config.json")?; let tokenizer_file_path = repo.get("tokenizer.json")?; diff --git a/atoma-inference/src/models/candle/llama.rs b/atoma-inference/src/models/candle/llama.rs index 1fe28c4d..bc2d45ad 100644 --- a/atoma-inference/src/models/candle/llama.rs +++ b/atoma-inference/src/models/candle/llama.rs @@ -46,29 +46,33 @@ impl ModelTrait for LlamaModel { type Output = String; type LoadData = LlmLoadData; - fn fetch(cache_dir: PathBuf, config: ModelConfig) -> Result { + fn fetch( + api_key: String, + cache_dir: PathBuf, + config: ModelConfig, + ) -> Result { let device = device(config.device_id())?; let dtype = DType::from_str(&config.dtype())?; let api = 
ApiBuilder::new() .with_progress(true) - .with_token(Some(config.api_key())) + .with_token(Some(api_key)) .with_cache_dir(cache_dir) .build()?; - let api = api.repo(Repo::with_revision( - config.model_id(), - RepoType::Model, - config.revision(), - )); - let config_file_path = api.get("tokenizer.json")?; - let tokenizer_file_path = api.get("config.json")?; + let model_type = ModelType::from_str(&config.model_id())?; + let repo_id = model_type.repo().to_string(); + let revision = model_type.default_revision().to_string(); + + let repo = api.repo(Repo::with_revision(repo_id, RepoType::Model, revision)); + let config_file_path = repo.get("tokenizer.json")?; + let tokenizer_file_path = repo.get("config.json")?; let model_weights_file_paths = if &config.model_id() == "TinyLlama/TinyLlama-1.1B-Chat-v1.0" { - vec![api.get("model.safetensors")?] + vec![repo.get("model.safetensors")?] } else { - hub_load_safetensors(&api, "model.safetensors.index.json")? + hub_load_safetensors(&repo, "model.safetensors.index.json")? }; let mut file_paths = Vec::with_capacity(2 + model_weights_file_paths.len()); diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs index 3daceeb4..c0bc890c 100644 --- a/atoma-inference/src/models/candle/mamba.rs +++ b/atoma-inference/src/models/candle/mamba.rs @@ -56,9 +56,11 @@ impl ModelTrait for MambaModel { type Output = String; type LoadData = LlmLoadData; - fn fetch(cache_dir: PathBuf, config: ModelConfig) -> Result { - let api_key = config.api_key(); - + fn fetch( + api_key: String, + cache_dir: PathBuf, + config: ModelConfig, + ) -> Result { let device = device(config.device_id())?; let dtype = DType::from_str(&config.dtype())?; @@ -68,11 +70,11 @@ impl ModelTrait for MambaModel { .with_cache_dir(cache_dir) .build()?; - let repo = api.repo(Repo::with_revision( - config.model_id(), - RepoType::Model, - config.revision(), - )); + let model_type = ModelType::from_str(&config.model_id())?; + let repo_id = model_type.repo().to_string(); + let revision = model_type.default_revision().to_string(); + + let repo = api.repo(Repo::with_revision(repo_id, RepoType::Model, revision)); let config_file_path = repo.get("config.json")?; let tokenizer_file_path = api diff --git a/atoma-inference/src/models/candle/stable_diffusion.rs b/atoma-inference/src/models/candle/stable_diffusion.rs index 12480fa1..825d374a 100644 --- a/atoma-inference/src/models/candle/stable_diffusion.rs +++ b/atoma-inference/src/models/candle/stable_diffusion.rs @@ -76,7 +76,11 @@ impl ModelTrait for StableDiffusion { type Output = Vec<(Vec, usize, usize)>; type LoadData = StableDiffusionLoadData; - fn fetch(cache_dir: PathBuf, config: ModelConfig) -> Result { + fn fetch( + api_key: String, + cache_dir: PathBuf, + config: ModelConfig, + ) -> Result { let device = device(config.device_id())?; let dtype = DType::from_str(&config.dtype())?; let model_type = ModelType::from_str(&config.model_id())?; @@ -85,7 +89,6 @@ impl ModelTrait for StableDiffusion { _ => vec![true], }; - let api_key = config.api_key(); let use_f16 = config.dtype() == "f16"; let vae_weights_file_path = ModelFile::Vae.get( diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs index 0ee2dfcd..52fa2e77 100644 --- a/atoma-inference/src/models/config.rs +++ b/atoma-inference/src/models/config.rs @@ -3,6 +3,7 @@ use std::path::PathBuf; use config::Config; use dotenv::dotenv; use serde::{Deserialize, Serialize}; +use tracing::error; use crate::models::ModelId; @@ -10,7 +11,6 @@ 
type Revision = String; #[derive(Clone, Debug, Deserialize, Serialize)] pub struct ModelConfig { - api_key: String, device_id: usize, dtype: String, model_id: ModelId, @@ -21,7 +21,6 @@ pub struct ModelConfig { impl ModelConfig { pub fn new( - api_key: String, model_id: ModelId, dtype: String, revision: Revision, @@ -30,7 +29,6 @@ impl ModelConfig { sliced_attention_size: Option, ) -> Self { Self { - api_key, dtype, model_id, revision, @@ -40,10 +38,6 @@ impl ModelConfig { } } - pub fn api_key(&self) -> String { - self.api_key.clone() - } - pub fn dtype(&self) -> String { self.dtype.clone() } @@ -71,6 +65,7 @@ impl ModelConfig { #[derive(Debug, Deserialize, Serialize)] pub struct ModelsConfig { + api_key: String, cache_dir: PathBuf, flush_storage: bool, models: Vec, @@ -79,12 +74,14 @@ pub struct ModelsConfig { impl ModelsConfig { pub fn new( + api_key: String, cache_dir: PathBuf, flush_storage: bool, models: Vec, tracing: bool, ) -> Self { Self { + api_key, cache_dir, flush_storage, models, @@ -92,6 +89,10 @@ impl ModelsConfig { } } + pub fn api_key(&self) -> String { + self.api_key.clone() + } + pub fn cache_dir(&self) -> PathBuf { self.cache_dir.clone() } @@ -114,6 +115,9 @@ impl ModelsConfig { )); let config = builder .build() + .map_err(|e| { + error!("{:?}", e); + }) .expect("Failed to generate inference configuration file"); config .try_deserialize::() @@ -123,6 +127,10 @@ impl ModelsConfig { pub fn from_env_file() -> Self { dotenv().ok(); + let api_key = std::env::var("API_KEY") + .unwrap_or_default() + .parse() + .unwrap(); let cache_dir = std::env::var("CACHE_DIR") .unwrap_or_default() .parse() @@ -141,6 +149,7 @@ impl ModelsConfig { .unwrap(); Self { + api_key, cache_dir, flush_storage, models, @@ -156,10 +165,10 @@ pub mod tests { #[test] fn test_config() { let config = ModelsConfig::new( + "my_key".to_string(), "/".to_string().into(), true, vec![ModelConfig::new( - "my_key".to_string(), "F16".to_string(), "Llama2_7b".to_string(), "".to_string(), diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs index 4dcfbbf5..d17f6085 100644 --- a/atoma-inference/src/models/mod.rs +++ b/atoma-inference/src/models/mod.rs @@ -2,7 +2,6 @@ use std::path::PathBuf; use ::candle::{DTypeParseError, Error as CandleError}; use ed25519_consensus::VerificationKey as PublicKey; -use serde::{de::DeserializeOwned, Serialize}; use thiserror::Error; use self::{config::ModelConfig, types::ModelType}; @@ -15,11 +14,15 @@ pub mod types; pub type ModelId = String; pub trait ModelTrait { - type Input: DeserializeOwned; - type Output: Serialize; + type Input; + type Output; type LoadData; - fn fetch(cache_dir: PathBuf, config: ModelConfig) -> Result; + fn fetch( + api_key: String, + cache_dir: PathBuf, + config: ModelConfig, + ) -> Result; fn load(load_data: Self::LoadData) -> Result where Self: Sized; diff --git a/atoma-inference/src/models/types.rs b/atoma-inference/src/models/types.rs index 256ba56c..0905dc33 100644 --- a/atoma-inference/src/models/types.rs +++ b/atoma-inference/src/models/types.rs @@ -65,7 +65,7 @@ impl FromStr for ModelType { "stable_diffusion_xl" => Ok(Self::StableDiffusionXl), "stable_diffusion_turbo" => Ok(Self::StableDiffusionTurbo), _ => Err(ModelError::InvalidModelType( - "Invalid string model type descryption".to_string() + "Invalid string model type descryption".to_string(), )), } } diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index e876fc46..599a4bd4 100644 --- a/atoma-inference/src/service.rs +++ 
b/atoma-inference/src/service.rs @@ -1,6 +1,7 @@ use candle::Error as CandleError; use ed25519_consensus::{SigningKey as PrivateKey, VerificationKey as PublicKey}; use futures::StreamExt; +use std::fmt::Debug; use std::{io, path::PathBuf, time::Instant}; use tokio::sync::mpsc::{Receiver, Sender}; use tracing::{error, info}; @@ -10,29 +11,37 @@ use thiserror::Error; use crate::{ apis::ApiError, model_thread::{ModelThreadDispatcher, ModelThreadError, ModelThreadHandle}, - models::{config::ModelsConfig, ModelTrait}, + models::{config::ModelsConfig, ModelTrait, Request, Response}, }; -pub struct ModelService { - model_thread_handle: Vec, - dispatcher: ModelThreadDispatcher, +pub struct ModelService +where + Req: Request, + Resp: Response, +{ + model_thread_handle: Vec>, + dispatcher: ModelThreadDispatcher, start_time: Instant, flush_storage: bool, public_key: PublicKey, cache_dir: PathBuf, - request_receiver: Receiver, - response_sender: Sender, + request_receiver: Receiver, + response_sender: Sender, } -impl ModelService { +impl ModelService +where + Req: Clone + Debug + Request, + Resp: Debug + Response, +{ pub fn start( model_config: ModelsConfig, private_key: PrivateKey, - request_receiver: Receiver, - response_sender: Sender, + request_receiver: Receiver, + response_sender: Sender, ) -> Result where - M: ModelTrait + Send + 'static, + M: ModelTrait + Send + 'static, { let public_key = private_key.verification_key(); @@ -56,7 +65,7 @@ impl ModelService { }) } - pub async fn run(&mut self) -> Result { + pub async fn run(&mut self) -> Result { loop { tokio::select! { message = self.request_receiver.recv() => { @@ -86,7 +95,11 @@ impl ModelService { } } -impl ModelService { +impl ModelService +where + Req: Request, + Resp: Response, +{ pub async fn stop(mut self) { info!( "Stopping Inference Service, running time: {:?}", @@ -145,7 +158,10 @@ mod tests { use std::io::Write; use toml::{toml, Value}; - use crate::models::{config::ModelConfig, Request, Response}; + use crate::models::{ + config::ModelConfig, + Request, Response, + }; use super::*; @@ -181,7 +197,7 @@ mod tests { type Output = (); type LoadData = (); - fn fetch(_: PathBuf, _: ModelConfig) -> Result<(), crate::models::ModelError> { + fn fetch(_: String, _: PathBuf, _: ModelConfig) -> Result<(), crate::models::ModelError> { Ok(()) } @@ -219,8 +235,8 @@ mod tests { file.write_all(toml_string.as_bytes()) .expect("Failed to write to file"); - let (_, req_receiver) = tokio::sync::mpsc::channel::(1); - let (resp_sender, _) = tokio::sync::mpsc::channel::(1); + let (_, req_receiver) = tokio::sync::mpsc::channel::<()>(1); + let (resp_sender, _) = tokio::sync::mpsc::channel::<()>(1); let config = ModelsConfig::from_file_path(CONFIG_FILE_PATH.parse().unwrap()); From 69b268629118d615ea09dbd8a92dd305efaf79a3 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Fri, 5 Apr 2024 23:29:52 +0100 Subject: [PATCH 09/21] resolve tests --- atoma-inference/src/models/config.rs | 2 +- atoma-inference/src/service.rs | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs index 52fa2e77..98460377 100644 --- a/atoma-inference/src/models/config.rs +++ b/atoma-inference/src/models/config.rs @@ -180,7 +180,7 @@ pub mod tests { ); let toml_str = toml::to_string(&config).unwrap(); - let should_be_toml_str = "cache_dir = \"/\"\nflush_storage = true\ntracing = true\n\n[[models]]\napi_key = \"my_key\"\ndevice_id = 0\ndtype = \"Llama2_7b\"\nmodel_id = \"F16\"\nrevision = 
\"\"\nuse_flash_attention = true\nsliced_attention_size = 0\n"; + let should_be_toml_str = "api_key = \"my_key\"\ncache_dir = \"/\"\nflush_storage = true\ntracing = true\n\n[[models]]\ndevice_id = 0\ndtype = \"Llama2_7b\"\nmodel_id = \"F16\"\nrevision = \"\"\nuse_flash_attention = true\nsliced_attention_size = 0\n"; assert_eq!(toml_str, should_be_toml_str); } } diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 599a4bd4..53603e62 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -158,10 +158,7 @@ mod tests { use std::io::Write; use toml::{toml, Value}; - use crate::models::{ - config::ModelConfig, - Request, Response, - }; + use crate::models::{config::ModelConfig, Request, Response}; use super::*; @@ -222,7 +219,7 @@ mod tests { let config_data = Value::Table(toml! { api_key = "your_api_key" - models = [["Mamba370m", 0, "f16", "", "", true, 0]] + models = [[0, "f32", "mamba_370m", "", false, 0]] cache_dir = "./cache_dir/" tokenizer_file_path = "./tokenizer_file_path/" flush_storage = true From 27919b2bab3a5a36c3f5edf52b829cad1e58ce6a Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Fri, 5 Apr 2024 23:47:51 +0100 Subject: [PATCH 10/21] resolve tests --- atoma-inference/src/main.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index de8f7e3f..ec11424b 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -3,7 +3,7 @@ use std::time::Duration; use ed25519_consensus::SigningKey as PrivateKey; use inference::{ models::{ - candle::mamba::MambaModel, + candle::{llama::LlamaModel, mamba::MambaModel}, config::ModelsConfig, types::{TextRequest, TextResponse}, }, @@ -26,7 +26,7 @@ async fn main() -> Result<(), ModelServiceError> { let private_key = PrivateKey::from(private_key_bytes); let mut service = - ModelService::start::(model_config, private_key, req_receiver, resp_sender) + ModelService::start::(model_config, private_key, req_receiver, resp_sender) .expect("Failed to start inference service"); let pk = service.public_key(); From b25d7b356430003ff6377290c88171056b40fc6e Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sat, 6 Apr 2024 00:06:18 +0100 Subject: [PATCH 11/21] clean up code and resolve minor bug with llama --- atoma-inference/src/main.rs | 4 ++-- atoma-inference/src/model_thread.rs | 3 +-- atoma-inference/src/models/candle/falcon.rs | 14 +++++--------- atoma-inference/src/models/candle/llama.rs | 8 ++++---- atoma-inference/src/models/candle/mamba.rs | 3 +-- atoma-inference/src/models/config.rs | 4 ---- 6 files changed, 13 insertions(+), 23 deletions(-) diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index ec11424b..758273d3 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -3,7 +3,7 @@ use std::time::Duration; use ed25519_consensus::SigningKey as PrivateKey; use inference::{ models::{ - candle::{llama::LlamaModel, mamba::MambaModel}, + candle::llama::LlamaModel, config::ModelsConfig, types::{TextRequest, TextResponse}, }, @@ -42,7 +42,7 @@ async fn main() -> Result<(), ModelServiceError> { .send(TextRequest { request_id: 0, prompt: "Leon, the professional is a movie".to_string(), - model: "mamba_370m".to_string(), + model: "llama_tiny_llama_1_1b_chat".to_string(), max_tokens: 512, temperature: Some(0.0), random_seed: 42, diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 0d2235fa..35439a4e 100644 --- 
a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -98,8 +98,7 @@ where let model_input = request.into_model_input(); let model_output = self .model - .run(model_input) - .map_err(ModelThreadError::ModelError)?; + .run(model_input)?; let response = Response::from_model_output(model_output); response_sender.send(response).ok(); } diff --git a/atoma-inference/src/models/candle/falcon.rs b/atoma-inference/src/models/candle/falcon.rs index f57db91e..f953c017 100644 --- a/atoma-inference/src/models/candle/falcon.rs +++ b/atoma-inference/src/models/candle/falcon.rs @@ -99,11 +99,10 @@ impl ModelTrait for FalconModel { let tokenizer_filename = load_data.file_paths[1].clone(); let weights_filenames = load_data.file_paths[2..].to_vec(); - let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(ModelError::BoxedError)?; + let tokenizer = Tokenizer::from_file(tokenizer_filename)?; let config: Config = - serde_json::from_slice(&std::fs::read(config_filename).map_err(ModelError::IoError)?) - .map_err(ModelError::DeserializeError)?; + serde_json::from_slice(&std::fs::read(config_filename)?)?; config.validate()?; if load_data.dtype != DType::BF16 || load_data.dtype != DType::F32 { @@ -150,8 +149,7 @@ impl ModelTrait for FalconModel { info!("Running inference on prompt: {:?}", prompt); let mut tokens = self .tokenizer - .encode(prompt, true) - .map_err(ModelError::BoxedError)? + .encode(prompt, true)? .get_ids() .to_vec(); @@ -184,8 +182,7 @@ impl ModelTrait for FalconModel { output.push_str( &self .tokenizer - .decode(&[next_token], true) - .map_err(ModelError::BoxedError)?, + .decode(&[next_token], true)?, ); } let dt = start_gen.elapsed(); @@ -194,8 +191,7 @@ impl ModelTrait for FalconModel { "{max_tokens} tokens generated ({} token/s)\n----\n{}\n----", max_tokens as f64 / dt.as_secs_f64(), self.tokenizer - .decode(&new_tokens, true) - .map_err(ModelError::BoxedError)?, + .decode(&new_tokens, true)?, ); Ok(output) diff --git a/atoma-inference/src/models/candle/llama.rs b/atoma-inference/src/models/candle/llama.rs index bc2d45ad..4233b514 100644 --- a/atoma-inference/src/models/candle/llama.rs +++ b/atoma-inference/src/models/candle/llama.rs @@ -64,11 +64,11 @@ impl ModelTrait for LlamaModel { let repo_id = model_type.repo().to_string(); let revision = model_type.default_revision().to_string(); - let repo = api.repo(Repo::with_revision(repo_id, RepoType::Model, revision)); - let config_file_path = repo.get("tokenizer.json")?; - let tokenizer_file_path = repo.get("config.json")?; + let repo = api.repo(Repo::with_revision(repo_id.clone(), RepoType::Model, revision)); + let config_file_path = repo.get("config.json")?; + let tokenizer_file_path = repo.get("tokenizer.json")?; - let model_weights_file_paths = if &config.model_id() == "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + let model_weights_file_paths = if &repo_id == "TinyLlama/TinyLlama-1.1B-Chat-v1.0" { vec![repo.get("model.safetensors")?] } else { diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs index c0bc890c..0c278019 100644 --- a/atoma-inference/src/models/candle/mamba.rs +++ b/atoma-inference/src/models/candle/mamba.rs @@ -110,8 +110,7 @@ impl ModelTrait for MambaModel { let tokenizer = Tokenizer::from_file(tokenizer_filename)?; let config: Config = - serde_json::from_slice(&std::fs::read(config_filename).map_err(ModelError::IoError)?) 
- .map_err(ModelError::DeserializeError)?; + serde_json::from_slice(&std::fs::read(config_filename)?)?; info!("Loading model weights.."); let var_builder = unsafe { diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs index 98460377..a50991cb 100644 --- a/atoma-inference/src/models/config.rs +++ b/atoma-inference/src/models/config.rs @@ -3,7 +3,6 @@ use std::path::PathBuf; use config::Config; use dotenv::dotenv; use serde::{Deserialize, Serialize}; -use tracing::error; use crate::models::ModelId; @@ -115,9 +114,6 @@ impl ModelsConfig { )); let config = builder .build() - .map_err(|e| { - error!("{:?}", e); - }) .expect("Failed to generate inference configuration file"); config .try_deserialize::() From 9315a834dad14ac9b4cd82c2df407b6cebd47088 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sat, 6 Apr 2024 00:07:00 +0100 Subject: [PATCH 12/21] clean up code and resolve minor bug with llama --- atoma-inference/src/main.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index 758273d3..f7dcbca7 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -3,7 +3,7 @@ use std::time::Duration; use ed25519_consensus::SigningKey as PrivateKey; use inference::{ models::{ - candle::llama::LlamaModel, + candle::{falcon::FalconModel, llama::LlamaModel}, config::ModelsConfig, types::{TextRequest, TextResponse}, }, @@ -26,7 +26,7 @@ async fn main() -> Result<(), ModelServiceError> { let private_key = PrivateKey::from(private_key_bytes); let mut service = - ModelService::start::(model_config, private_key, req_receiver, resp_sender) + ModelService::start::(model_config, private_key, req_receiver, resp_sender) .expect("Failed to start inference service"); let pk = service.public_key(); From b7115aeefc7a74bdb0995eb0e607074b22e5f559 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sat, 6 Apr 2024 09:07:06 +0100 Subject: [PATCH 13/21] minor changes --- atoma-inference/src/main.rs | 4 +-- atoma-inference/src/model_thread.rs | 15 ++++----- atoma-inference/src/models/candle/falcon.rs | 37 +++++++-------------- atoma-inference/src/models/candle/llama.rs | 9 +++-- atoma-inference/src/models/candle/mamba.rs | 3 +- 5 files changed, 27 insertions(+), 41 deletions(-) diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index f7dcbca7..db1eb100 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -3,7 +3,7 @@ use std::time::Duration; use ed25519_consensus::SigningKey as PrivateKey; use inference::{ models::{ - candle::{falcon::FalconModel, llama::LlamaModel}, + candle::falcon::FalconModel, config::ModelsConfig, types::{TextRequest, TextResponse}, }, @@ -36,7 +36,7 @@ async fn main() -> Result<(), ModelServiceError> { Ok::<(), ModelServiceError>(()) }); - tokio::time::sleep(Duration::from_millis(5000)).await; + tokio::time::sleep(Duration::from_millis(50000000)).await; req_sender .send(TextRequest { diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 35439a4e..288ed54c 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -80,7 +80,7 @@ where Req: Request, Resp: Response, { - pub fn run(mut self, _public_key: PublicKey) -> Result<(), ModelThreadError> { + pub fn run(mut self, public_key: PublicKey) -> Result<(), ModelThreadError> { debug!("Start Model thread"); while let Ok(command) = self.receiver.recv() { @@ -89,16 +89,13 @@ where response_sender, } = 
command; - // TODO: Implement node authorization - // if !request.is_node_authorized(&public_key) { - // error!("Current node, with verification key = {:?} is not authorized to run request with id = {}", public_key, request.request_id()); - // continue; - // } + if !request.is_node_authorized(&public_key) { + error!("Current node, with verification key = {:?} is not authorized to run request with id = {}", public_key, request.request_id()); + continue; + } let model_input = request.into_model_input(); - let model_output = self - .model - .run(model_input)?; + let model_output = self.model.run(model_input)?; let response = Response::from_model_output(model_output); response_sender.send(response).ok(); } diff --git a/atoma-inference/src/models/candle/falcon.rs b/atoma-inference/src/models/candle/falcon.rs index f953c017..420ca7d3 100644 --- a/atoma-inference/src/models/candle/falcon.rs +++ b/atoma-inference/src/models/candle/falcon.rs @@ -12,9 +12,7 @@ use tokenizers::Tokenizer; use tracing::{debug, info}; use crate::models::{ - config::ModelConfig, - types::{LlmLoadData, ModelType, TextModelInput}, - ModelError, ModelTrait, + candle::hub_load_safetensors, config::ModelConfig, types::{LlmLoadData, ModelType, TextModelInput}, ModelError, ModelTrait }; use super::device; @@ -68,20 +66,19 @@ impl ModelTrait for FalconModel { let repo_id = model_type.repo().to_string(); let revision = model_type.default_revision().to_string(); + info!("{repo_id} <> {revision}"); + let repo = api.repo(Repo::with_revision(repo_id, RepoType::Model, revision)); - let config_file_path = repo.get("config.json")?; - let tokenizer_file_path = repo.get("tokenizer.json")?; - let model_weights_file_path = repo.get("model.safetensors")?; + let mut file_paths = vec![]; + file_paths.push(repo.get("config.json")?); + file_paths.push(repo.get("tokenizer.json")?); + file_paths.extend(hub_load_safetensors(&repo, "model.safetensors.index.json")?); Ok(Self::LoadData { device, dtype, - file_paths: vec![ - config_file_path, - tokenizer_file_path, - model_weights_file_path, - ], + file_paths, model_type: ModelType::from_str(&config.model_id())?, use_flash_attention: config.use_flash_attention(), }) @@ -101,8 +98,7 @@ impl ModelTrait for FalconModel { let tokenizer = Tokenizer::from_file(tokenizer_filename)?; - let config: Config = - serde_json::from_slice(&std::fs::read(config_filename)?)?; + let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?; config.validate()?; if load_data.dtype != DType::BF16 || load_data.dtype != DType::F32 { @@ -147,11 +143,7 @@ impl ModelTrait for FalconModel { let mut logits_processor = LogitsProcessor::new(random_seed, Some(temperature), Some(top_p)); info!("Running inference on prompt: {:?}", prompt); - let mut tokens = self - .tokenizer - .encode(prompt, true)? 
- .get_ids() - .to_vec(); + let mut tokens = self.tokenizer.encode(prompt, true)?.get_ids().to_vec(); let mut new_tokens = vec![]; let mut output = String::new(); @@ -179,19 +171,14 @@ impl ModelTrait for FalconModel { tokens.push(next_token); new_tokens.push(next_token); debug!("> {:?}", start_gen); - output.push_str( - &self - .tokenizer - .decode(&[next_token], true)?, - ); + output.push_str(&self.tokenizer.decode(&[next_token], true)?); } let dt = start_gen.elapsed(); info!( "{max_tokens} tokens generated ({} token/s)\n----\n{}\n----", max_tokens as f64 / dt.as_secs_f64(), - self.tokenizer - .decode(&new_tokens, true)?, + self.tokenizer.decode(&new_tokens, true)?, ); Ok(output) diff --git a/atoma-inference/src/models/candle/llama.rs b/atoma-inference/src/models/candle/llama.rs index 4233b514..396b0e39 100644 --- a/atoma-inference/src/models/candle/llama.rs +++ b/atoma-inference/src/models/candle/llama.rs @@ -64,12 +64,15 @@ impl ModelTrait for LlamaModel { let repo_id = model_type.repo().to_string(); let revision = model_type.default_revision().to_string(); - let repo = api.repo(Repo::with_revision(repo_id.clone(), RepoType::Model, revision)); + let repo = api.repo(Repo::with_revision( + repo_id.clone(), + RepoType::Model, + revision, + )); let config_file_path = repo.get("config.json")?; let tokenizer_file_path = repo.get("tokenizer.json")?; - let model_weights_file_paths = if &repo_id == "TinyLlama/TinyLlama-1.1B-Chat-v1.0" - { + let model_weights_file_paths = if &repo_id == "TinyLlama/TinyLlama-1.1B-Chat-v1.0" { vec![repo.get("model.safetensors")?] } else { hub_load_safetensors(&repo, "model.safetensors.index.json")? diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs index 0c278019..c77c68f0 100644 --- a/atoma-inference/src/models/candle/mamba.rs +++ b/atoma-inference/src/models/candle/mamba.rs @@ -109,8 +109,7 @@ impl ModelTrait for MambaModel { let tokenizer = Tokenizer::from_file(tokenizer_filename)?; - let config: Config = - serde_json::from_slice(&std::fs::read(config_filename)?)?; + let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?; info!("Loading model weights.."); let var_builder = unsafe { From ceb52cfcb35759ae471f8b23fbf91e666b0c5e21 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sat, 6 Apr 2024 18:15:32 +0100 Subject: [PATCH 14/21] correct minor bugs --- atoma-inference/src/main.rs | 63 ++++++++++---- atoma-inference/src/model_thread.rs | 67 ++++++++++----- atoma-inference/src/models/candle/falcon.rs | 23 +++-- .../src/models/candle/stable_diffusion.rs | 54 ++++++------ atoma-inference/src/models/types.rs | 85 ++++++++++++++++++- 5 files changed, 214 insertions(+), 78 deletions(-) diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index db1eb100..a0559864 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -3,9 +3,11 @@ use std::time::Duration; use ed25519_consensus::SigningKey as PrivateKey; use inference::{ models::{ - candle::falcon::FalconModel, + candle::stable_diffusion::StableDiffusion, config::ModelsConfig, - types::{TextRequest, TextResponse}, + types::{ + ModelType, StableDiffusionRequest, StableDiffusionResponse, + }, }, service::{ModelService, ModelServiceError}, }; @@ -14,8 +16,9 @@ use inference::{ async fn main() -> Result<(), ModelServiceError> { tracing_subscriber::fmt::init(); - let (req_sender, req_receiver) = tokio::sync::mpsc::channel::(32); - let (resp_sender, mut resp_receiver) = tokio::sync::mpsc::channel::(32); + let 
(req_sender, req_receiver) = tokio::sync::mpsc::channel::(32); + let (resp_sender, mut resp_receiver) = + tokio::sync::mpsc::channel::(32); let model_config = ModelsConfig::from_file_path("../inference.toml".parse().unwrap()); let private_key_bytes = @@ -25,9 +28,14 @@ async fn main() -> Result<(), ModelServiceError> { .expect("Incorrect private key bytes length"); let private_key = PrivateKey::from(private_key_bytes); - let mut service = - ModelService::start::(model_config, private_key, req_receiver, resp_sender) - .expect("Failed to start inference service"); + let mut service: ModelService = + ModelService::start::( + model_config, + private_key, + req_receiver, + resp_sender, + ) + .expect("Failed to start inference service"); let pk = service.public_key(); @@ -36,21 +44,40 @@ async fn main() -> Result<(), ModelServiceError> { Ok::<(), ModelServiceError>(()) }); - tokio::time::sleep(Duration::from_millis(50000000)).await; + tokio::time::sleep(Duration::from_millis(5_000)).await; + + // req_sender + // .send(TextRequest { + // request_id: 0, + // prompt: "Leon, the professional is a movie".to_string(), + // model: "falcon_7b".to_string(), + // max_tokens: 512, + // temperature: Some(0.0), + // random_seed: 42, + // repeat_last_n: 64, + // repeat_penalty: 1.1, + // sampled_nodes: vec![pk], + // top_p: Some(1.0), + // top_k: 10, + // }) + // .await + // .expect("Failed to send request"); req_sender - .send(TextRequest { + .send(StableDiffusionRequest { request_id: 0, - prompt: "Leon, the professional is a movie".to_string(), - model: "llama_tiny_llama_1_1b_chat".to_string(), - max_tokens: 512, - temperature: Some(0.0), - random_seed: 42, - repeat_last_n: 64, - repeat_penalty: 1.1, + prompt: "A portrait of young Natalie Portman".to_string(), + uncond_prompt: "".to_string(), + height: None, + width: None, + num_samples: 1, + n_steps: None, + model_type: ModelType::StableDiffusionV1_5, + guidance_scale: None, + img2img: None, + img2img_strength: 0.5, + random_seed: Some(42), sampled_nodes: vec![pk], - top_p: Some(1.0), - top_k: 10, }) .await .expect("Failed to send request"); diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 288ed54c..79c7a655 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, fmt::Debug, sync::mpsc}; +use std::{collections::HashMap, fmt::Debug, path::PathBuf, sync::mpsc, thread::JoinHandle}; use ed25519_consensus::VerificationKey as PublicKey; use futures::stream::FuturesUnordered; @@ -8,7 +8,10 @@ use tracing::{debug, error, info, warn}; use crate::{ apis::ApiError, - models::{config::ModelsConfig, ModelError, ModelId, ModelTrait, Request, Response}, + models::{ + config::{ModelConfig, ModelsConfig}, + ModelError, ModelId, ModelTrait, Request, Response, + }, }; pub struct ModelThreadCommand @@ -134,31 +137,19 @@ where for model_config in config.models() { info!("Spawning new thread for model: {}", model_config.model_id()); - let model_api_key = api_key.clone(); - let model_cache_dir = cache_dir.clone(); let (model_sender, model_receiver) = mpsc::channel::>(); let model_name = model_config.model_id().clone(); model_senders.insert(model_name.clone(), model_sender.clone()); - let join_handle = std::thread::spawn(move || { - info!("Fetching files for model: {model_name}"); - let load_data = M::fetch(model_api_key, model_cache_dir, model_config)?; - - let model = M::load(load_data)?; - let model_thread = ModelThread { - model, - receiver: 
model_receiver, - }; - - if let Err(e) = model_thread.run(public_key) { - error!("Model thread error: {e}"); - if !matches!(e, ModelThreadError::Shutdown(_)) { - panic!("Fatal error occurred: {e}"); - } - } + let join_handle = Self::start_model_thread::( + model_name, + api_key.clone(), + cache_dir.clone(), + model_config, + public_key, + model_receiver, + ); - Ok(()) - }); handles.push(ModelThreadHandle { join_handle, sender: model_sender.clone(), @@ -173,6 +164,38 @@ where Ok((model_dispatcher, handles)) } + fn start_model_thread( + model_name: String, + api_key: String, + cache_dir: PathBuf, + model_config: ModelConfig, + public_key: PublicKey, + model_receiver: mpsc::Receiver>, + ) -> JoinHandle> + where + M: ModelTrait + Send + 'static, + { + std::thread::spawn(move || { + info!("Fetching files for model: {model_name}"); + let load_data = M::fetch(api_key, cache_dir, model_config)?; + + let model = M::load(load_data)?; + let model_thread = ModelThread { + model, + receiver: model_receiver, + }; + + if let Err(e) = model_thread.run(public_key) { + error!("Model thread error: {e}"); + if !matches!(e, ModelThreadError::Shutdown(_)) { + panic!("Fatal error occurred: {e}"); + } + } + + Ok(()) + }) + } + fn send(&self, command: ModelThreadCommand) { let request = command.request.clone(); let model_id = request.requested_model(); diff --git a/atoma-inference/src/models/candle/falcon.rs b/atoma-inference/src/models/candle/falcon.rs index 420ca7d3..b13acc40 100644 --- a/atoma-inference/src/models/candle/falcon.rs +++ b/atoma-inference/src/models/candle/falcon.rs @@ -9,10 +9,13 @@ use candle_transformers::{ }; use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; use tokenizers::Tokenizer; -use tracing::{debug, info}; +use tracing::{debug, error, info}; use crate::models::{ - candle::hub_load_safetensors, config::ModelConfig, types::{LlmLoadData, ModelType, TextModelInput}, ModelError, ModelTrait + candle::hub_load_safetensors, + config::ModelConfig, + types::{LlmLoadData, ModelType, TextModelInput}, + ModelError, ModelTrait, }; use super::device; @@ -68,12 +71,19 @@ impl ModelTrait for FalconModel { info!("{repo_id} <> {revision}"); - let repo = api.repo(Repo::with_revision(repo_id, RepoType::Model, revision)); - let mut file_paths = vec![]; + let repo = api.repo(Repo::new(repo_id.clone(), RepoType::Model)); file_paths.push(repo.get("config.json")?); + + let repo = api.repo(Repo::with_revision(repo_id, RepoType::Model, revision)); file_paths.push(repo.get("tokenizer.json")?); - file_paths.extend(hub_load_safetensors(&repo, "model.safetensors.index.json")?); + + file_paths.extend( + hub_load_safetensors(&repo, "model.safetensors.index.json").map_err(|e| { + error!("{e}"); + e + })?, + ); Ok(Self::LoadData { device, @@ -97,11 +107,10 @@ impl ModelTrait for FalconModel { let weights_filenames = load_data.file_paths[2..].to_vec(); let tokenizer = Tokenizer::from_file(tokenizer_filename)?; - let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?; config.validate()?; - if load_data.dtype != DType::BF16 || load_data.dtype != DType::F32 { + if load_data.dtype != DType::BF16 && load_data.dtype != DType::F32 { panic!("Invalid dtype, it must be either BF16 or F32 precision"); } diff --git a/atoma-inference/src/models/candle/stable_diffusion.rs b/atoma-inference/src/models/candle/stable_diffusion.rs index 825d374a..a6c1bffe 100644 --- a/atoma-inference/src/models/candle/stable_diffusion.rs +++ b/atoma-inference/src/models/candle/stable_diffusion.rs @@ -9,6 +9,7 @@ use 
candle::{DType, Device, IndexOp, Module, Tensor, D}; use hf_hub::api::sync::ApiBuilder; use serde::Deserialize; use tokenizers::Tokenizer; +use tracing::info; use crate::{ bail, @@ -18,32 +19,32 @@ use crate::{ use super::{convert_to_image, device, save_tensor_to_file}; #[derive(Deserialize)] -pub struct Input { - prompt: String, - uncond_prompt: String, +pub struct StableDiffusionInput { + pub prompt: String, + pub uncond_prompt: String, - height: Option, - width: Option, + pub height: Option, + pub width: Option, /// The number of steps to run the diffusion for. - n_steps: Option, + pub n_steps: Option, /// The number of samples to generate. - num_samples: i64, + pub num_samples: i64, - sd_version: StableDiffusionVersion, + pub model_type: ModelType, - guidance_scale: Option, + pub guidance_scale: Option, - img2img: Option, + pub img2img: Option, /// The strength, indicates how much to transform the initial image. The /// value must be between 0 and 1, a value of 1 discards the initial image /// information. - img2img_strength: f64, + pub img2img_strength: f64, /// The seed to use when generating random samples. - seed: Option, + pub random_seed: Option, } pub struct StableDiffusionLoadData { @@ -72,7 +73,7 @@ pub struct StableDiffusion { } impl ModelTrait for StableDiffusion { - type Input = Input; + type Input = StableDiffusionInput; type Output = Vec<(Vec, usize, usize)>; type LoadData = StableDiffusionLoadData; @@ -236,8 +237,8 @@ impl ModelTrait for StableDiffusion { )))? } - // self.config.height = input.height; - // self.config.width = input.width; + // self.config.height = input.height.unwrap_or(512); + // self.config.width = input.width.unwrap_or(512); let guidance_scale = match input.guidance_scale { Some(guidance_scale) => guidance_scale, @@ -261,7 +262,7 @@ impl ModelTrait for StableDiffusion { }; let scheduler = self.config.build_scheduler(n_steps)?; - if let Some(seed) = input.seed { + if let Some(seed) = input.random_seed { self.device.set_seed(seed)?; } let use_guide_scale = guidance_scale > 1.0; @@ -306,11 +307,12 @@ impl ModelTrait for StableDiffusion { }; let bsize = 1; - let vae_scale = match input.sd_version { - StableDiffusionVersion::V1_5 - | StableDiffusionVersion::V2_1 - | StableDiffusionVersion::Xl => 0.18215, - StableDiffusionVersion::Turbo => 0.13025, + let vae_scale = match input.model_type { + ModelType::StableDiffusionV1_5 + | ModelType::StableDiffusionV2_1 + | ModelType::StableDiffusionXl => 0.18215, + ModelType::StableDiffusionTurbo => 0.13025, + _ => bail!("Invalid stable diffusion model type"), }; let mut res = Vec::new(); @@ -355,6 +357,7 @@ impl ModelTrait for StableDiffusion { latents.clone() }; + info!("FLAG: {:?}", latent_model_input.shape()); let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?; let noise_pred = @@ -386,15 +389,6 @@ impl ModelTrait for StableDiffusion { } } -#[allow(dead_code)] -#[derive(Clone, Copy, Deserialize)] -enum StableDiffusionVersion { - V1_5, - V2_1, - Xl, - Turbo, -} - impl ModelType { fn unet_file(&self, use_f16: bool) -> &'static str { match self { diff --git a/atoma-inference/src/models/types.rs b/atoma-inference/src/models/types.rs index 0905dc33..7af052be 100644 --- a/atoma-inference/src/models/types.rs +++ b/atoma-inference/src/models/types.rs @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize}; use crate::models::{ModelId, Request, Response}; -use super::ModelError; +use super::{candle::stable_diffusion::StableDiffusionInput, ModelError}; pub type NodeId = PublicKey; @@ -244,6 
+244,89 @@ impl Response for TextResponse { } } +#[derive(Clone, Debug, Deserialize)] +pub struct StableDiffusionRequest { + pub request_id: usize, + pub prompt: String, + pub uncond_prompt: String, + + pub height: Option, + pub width: Option, + + /// The number of steps to run the diffusion for. + pub n_steps: Option, + + /// The number of samples to generate. + pub num_samples: i64, + + pub model_type: ModelType, + + pub guidance_scale: Option, + + pub img2img: Option, + + /// The strength, indicates how much to transform the initial image. The + /// value must be between 0 and 1, a value of 1 discards the initial image + /// information. + pub img2img_strength: f64, + + /// The seed to use when generating random samples. + pub random_seed: Option, + + pub sampled_nodes: Vec, +} + +impl Request for StableDiffusionRequest { + type ModelInput = StableDiffusionInput; + + fn into_model_input(self) -> Self::ModelInput { + Self::ModelInput { + prompt: self.prompt, + uncond_prompt: self.uncond_prompt, + height: self.height, + width: self.width, + n_steps: self.n_steps, + num_samples: self.num_samples, + model_type: self.model_type, + guidance_scale: self.guidance_scale, + img2img: self.img2img, + img2img_strength: self.img2img_strength, + random_seed: self.random_seed, + } + } + + fn is_node_authorized(&self, public_key: &PublicKey) -> bool { + self.sampled_nodes.contains(public_key) + } + + fn request_id(&self) -> usize { + self.request_id + } + + fn requested_model(&self) -> ModelId { + self.model_type.to_string() + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct StableDiffusionResponse { + pub output: Vec<(Vec, usize, usize)>, + pub is_success: bool, + pub status: String, +} + +impl Response for StableDiffusionResponse { + type ModelOutput = Vec<(Vec, usize, usize)>; + + fn from_model_output(model_output: Self::ModelOutput) -> Self { + Self { + output: model_output, + is_success: true, + status: "Successful".to_string(), + } + } +} + #[derive(Copy, Clone, Debug, Deserialize, Serialize)] pub enum PrecisionBits { BF16, From b5fb727a95ef6acf802fe9996693d48d8aaae4fb Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sat, 6 Apr 2024 18:18:17 +0100 Subject: [PATCH 15/21] correct minor bugs --- atoma-inference/src/models/candle/stable_diffusion.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/atoma-inference/src/models/candle/stable_diffusion.rs b/atoma-inference/src/models/candle/stable_diffusion.rs index a6c1bffe..c16db69e 100644 --- a/atoma-inference/src/models/candle/stable_diffusion.rs +++ b/atoma-inference/src/models/candle/stable_diffusion.rs @@ -357,7 +357,6 @@ impl ModelTrait for StableDiffusion { latents.clone() }; - info!("FLAG: {:?}", latent_model_input.shape()); let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?; let noise_pred = From 0fa06d33e7f7ee9ab845334a72aed68011298c21 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sat, 6 Apr 2024 19:57:19 +0100 Subject: [PATCH 16/21] small changes --- atoma-inference/src/main.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index a0559864..872907d6 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -72,10 +72,10 @@ async fn main() -> Result<(), ModelServiceError> { width: None, num_samples: 1, n_steps: None, - model_type: ModelType::StableDiffusionV1_5, + model_type: ModelType::StableDiffusionV2_1, guidance_scale: None, img2img: None, - img2img_strength: 0.5, + 
img2img_strength: 0.8, random_seed: Some(42), sampled_nodes: vec![pk], }) From e0b070764af41eeb5afeb18c5a7e9849cee7f0b8 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sat, 6 Apr 2024 21:30:27 +0100 Subject: [PATCH 17/21] remove sliced_attention_window from config --- atoma-inference/src/main.rs | 8 ++--- .../src/models/candle/stable_diffusion.rs | 35 ++++++++++++++----- atoma-inference/src/models/config.rs | 10 +----- 3 files changed, 31 insertions(+), 22 deletions(-) diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index 872907d6..c1ef682d 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -5,9 +5,7 @@ use inference::{ models::{ candle::stable_diffusion::StableDiffusion, config::ModelsConfig, - types::{ - ModelType, StableDiffusionRequest, StableDiffusionResponse, - }, + types::{ModelType, StableDiffusionRequest, StableDiffusionResponse}, }, service::{ModelService, ModelServiceError}, }; @@ -66,13 +64,13 @@ async fn main() -> Result<(), ModelServiceError> { req_sender .send(StableDiffusionRequest { request_id: 0, - prompt: "A portrait of young Natalie Portman".to_string(), + prompt: "A depiction of Natalie Portman".to_string(), uncond_prompt: "".to_string(), height: None, width: None, num_samples: 1, n_steps: None, - model_type: ModelType::StableDiffusionV2_1, + model_type: ModelType::StableDiffusionV1_5, guidance_scale: None, img2img: None, img2img_strength: 0.8, diff --git a/atoma-inference/src/models/candle/stable_diffusion.rs b/atoma-inference/src/models/candle/stable_diffusion.rs index c16db69e..306e9375 100644 --- a/atoma-inference/src/models/candle/stable_diffusion.rs +++ b/atoma-inference/src/models/candle/stable_diffusion.rs @@ -9,11 +9,11 @@ use candle::{DType, Device, IndexOp, Module, Tensor, D}; use hf_hub::api::sync::ApiBuilder; use serde::Deserialize; use tokenizers::Tokenizer; -use tracing::info; +use tracing::{debug, info}; use crate::{ bail, - models::{config::ModelConfig, types::ModelType, ModelError, ModelTrait}, + models::{candle::save_image, config::ModelConfig, types::ModelType, ModelError, ModelTrait}, }; use super::{convert_to_image, device, save_tensor_to_file}; @@ -136,7 +136,7 @@ impl ModelTrait for StableDiffusion { device, dtype, model_type, - sliced_attention_size: config.sliced_attention_size(), + sliced_attention_size: None, clip_weights_file_paths, tokenizer_file_paths, vae_weights_file_path, @@ -181,6 +181,7 @@ impl ModelTrait for StableDiffusion { ), // INTEGRITY: we have checked previously if the model type is valid for the family of stable diffusion models }; + info!("Loading text model..."); let text_model = stable_diffusion::build_clip_transformer( &config.clip, load_data.clip_weights_file_paths[0].clone(), @@ -188,6 +189,7 @@ impl ModelTrait for StableDiffusion { load_data.dtype, )?; let text_model_2 = if let Some(clip_config_2) = &config.clip2 { + info!("Loading second text model..."); Some(stable_diffusion::build_clip_transformer( clip_config_2, load_data.clip_weights_file_paths[1].clone(), @@ -198,11 +200,13 @@ impl ModelTrait for StableDiffusion { None }; + info!("Loading variational auto encoder model..."); let vae = config.build_vae( load_data.vae_weights_file_path, &load_data.device, load_data.dtype, )?; + info!("Loading unet model..."); let unet = config.build_unet( load_data.unet_weights_file_path, &load_data.device, @@ -237,8 +241,8 @@ impl ModelTrait for StableDiffusion { )))? 
} - // self.config.height = input.height.unwrap_or(512); - // self.config.width = input.width.unwrap_or(512); + let height = input.height.unwrap_or(512); + let width = input.width.unwrap_or(512); let guidance_scale = match input.guidance_scale { Some(guidance_scale) => guidance_scale, @@ -271,6 +275,8 @@ impl ModelTrait for StableDiffusion { ModelType::StableDiffusionXl | ModelType::StableDiffusionTurbo => vec![true, false], _ => vec![true], // INTEGRITY: we have checked previously if the model type is valid for the family of stable diffusion models }; + + debug!("Computing text embeddings..."); let text_embeddings = which .iter() .map(|first| { @@ -316,7 +322,7 @@ impl ModelTrait for StableDiffusion { }; let mut res = Vec::new(); - for _ in 0..input.num_samples { + for idx in 0..input.num_samples { let timesteps = scheduler.timesteps(); let latents = match &init_latent_dist { Some(init_latent_dist) => { @@ -336,8 +342,8 @@ impl ModelTrait for StableDiffusion { ( bsize, 4, - input.height.unwrap_or(512) / 8, - input.width.unwrap_or(512) / 8, + height / 8, + width / 8, ), &self.device, )?; @@ -351,6 +357,7 @@ impl ModelTrait for StableDiffusion { if timestep_index < t_start { continue; } + let start_time = std::time::Instant::now(); let latent_model_input = if use_guide_scale { Tensor::cat(&[&latents, &latents], 0)? } else { @@ -359,6 +366,7 @@ impl ModelTrait for StableDiffusion { let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?; + debug!("Computing noise prediction..."); let noise_pred = self.unet .forward(&latent_model_input, timestep as f64, &text_embeddings)?; @@ -374,13 +382,24 @@ impl ModelTrait for StableDiffusion { }; latents = scheduler.step(&noise_pred, timestep, &latents)?; + let dt = start_time.elapsed().as_secs_f32(); + debug!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt); } + + debug!( + "Generating the final image for sample {}/{}.", + idx + 1, + input.num_samples + ); save_tensor_to_file(&latents, "tensor1")?; let image = self.vae.decode(&(&latents / vae_scale)?)?; save_tensor_to_file(&image, "tensor2")?; let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?; save_tensor_to_file(&image, "tensor3")?; let image = (image.clamp(0f32, 1.)? 
* 255.)?.to_dtype(DType::U8)?.i(0)?; + if idx == input.num_samples - 1 { + save_image(&image, "./image.png").unwrap(); + } save_tensor_to_file(&image, "tensor4")?; res.push(convert_to_image(&image)?); } diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs index a50991cb..cc6a4231 100644 --- a/atoma-inference/src/models/config.rs +++ b/atoma-inference/src/models/config.rs @@ -15,7 +15,6 @@ pub struct ModelConfig { model_id: ModelId, revision: Revision, use_flash_attention: bool, - sliced_attention_size: Option, } impl ModelConfig { @@ -25,7 +24,6 @@ impl ModelConfig { revision: Revision, device_id: usize, use_flash_attention: bool, - sliced_attention_size: Option, ) -> Self { Self { dtype, @@ -33,7 +31,6 @@ impl ModelConfig { revision, device_id, use_flash_attention, - sliced_attention_size, } } @@ -56,10 +53,6 @@ impl ModelConfig { pub fn use_flash_attention(&self) -> bool { self.use_flash_attention } - - pub fn sliced_attention_size(&self) -> Option { - self.sliced_attention_size - } } #[derive(Debug, Deserialize, Serialize)] @@ -170,13 +163,12 @@ pub mod tests { "".to_string(), 0, true, - Some(0), )], true, ); let toml_str = toml::to_string(&config).unwrap(); - let should_be_toml_str = "api_key = \"my_key\"\ncache_dir = \"/\"\nflush_storage = true\ntracing = true\n\n[[models]]\ndevice_id = 0\ndtype = \"Llama2_7b\"\nmodel_id = \"F16\"\nrevision = \"\"\nuse_flash_attention = true\nsliced_attention_size = 0\n"; + let should_be_toml_str = "api_key = \"my_key\"\ncache_dir = \"/\"\nflush_storage = true\ntracing = true\n\n[[models]]\ndevice_id = 0\ndtype = \"Llama2_7b\"\nmodel_id = \"F16\"\nrevision = \"\"\nuse_flash_attention = true\n"; assert_eq!(toml_str, should_be_toml_str); } } From 4158a236671f29f16cb8948622e3f622ef78edbf Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sat, 6 Apr 2024 23:16:52 +0100 Subject: [PATCH 18/21] revert back to serde_json::value --- atoma-inference/src/main.rs | 52 ++-- atoma-inference/src/model_thread.rs | 234 ++++++++++-------- .../src/models/candle/stable_diffusion.rs | 22 +- atoma-inference/src/models/mod.rs | 5 +- atoma-inference/src/models/types.rs | 10 +- atoma-inference/src/service.rs | 57 ++--- 6 files changed, 193 insertions(+), 187 deletions(-) diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index c1ef682d..d2e118f1 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -2,11 +2,7 @@ use std::time::Duration; use ed25519_consensus::SigningKey as PrivateKey; use inference::{ - models::{ - candle::stable_diffusion::StableDiffusion, - config::ModelsConfig, - types::{ModelType, StableDiffusionRequest, StableDiffusionResponse}, - }, + models::{config::ModelsConfig, types::StableDiffusionRequest}, service::{ModelService, ModelServiceError}, }; @@ -14,9 +10,8 @@ use inference::{ async fn main() -> Result<(), ModelServiceError> { tracing_subscriber::fmt::init(); - let (req_sender, req_receiver) = tokio::sync::mpsc::channel::(32); - let (resp_sender, mut resp_receiver) = - tokio::sync::mpsc::channel::(32); + let (req_sender, req_receiver) = tokio::sync::mpsc::channel(32); + let (resp_sender, mut resp_receiver) = tokio::sync::mpsc::channel(32); let model_config = ModelsConfig::from_file_path("../inference.toml".parse().unwrap()); let private_key_bytes = @@ -26,13 +21,7 @@ async fn main() -> Result<(), ModelServiceError> { .expect("Incorrect private key bytes length"); let private_key = PrivateKey::from(private_key_bytes); - let mut service: ModelService = - 
ModelService::start::( - model_config, - private_key, - req_receiver, - resp_sender, - ) + let mut service = ModelService::start(model_config, private_key, req_receiver, resp_sender) .expect("Failed to start inference service"); let pk = service.public_key(); @@ -62,21 +51,24 @@ async fn main() -> Result<(), ModelServiceError> { // .expect("Failed to send request"); req_sender - .send(StableDiffusionRequest { - request_id: 0, - prompt: "A depiction of Natalie Portman".to_string(), - uncond_prompt: "".to_string(), - height: None, - width: None, - num_samples: 1, - n_steps: None, - model_type: ModelType::StableDiffusionV1_5, - guidance_scale: None, - img2img: None, - img2img_strength: 0.8, - random_seed: Some(42), - sampled_nodes: vec![pk], - }) + .send( + serde_json::to_value(StableDiffusionRequest { + request_id: 0, + prompt: "A depiction of Natalie Portman".to_string(), + uncond_prompt: "".to_string(), + height: None, + width: None, + num_samples: 1, + n_steps: None, + model: "stable_diffusion_v1-5".to_string(), + guidance_scale: None, + img2img: None, + img2img_strength: 0.8, + random_seed: Some(42), + sampled_nodes: vec![pk], + }) + .unwrap(), + ) .await .expect("Failed to send request"); diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 79c7a655..57a07d69 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -1,4 +1,6 @@ -use std::{collections::HashMap, fmt::Debug, path::PathBuf, sync::mpsc, thread::JoinHandle}; +use std::{ + collections::HashMap, fmt::Debug, path::PathBuf, str::FromStr, sync::mpsc, thread::JoinHandle, +}; use ed25519_consensus::VerificationKey as PublicKey; use futures::stream::FuturesUnordered; @@ -9,18 +11,19 @@ use tracing::{debug, error, info, warn}; use crate::{ apis::ApiError, models::{ + candle::{ + falcon::FalconModel, llama::LlamaModel, mamba::MambaModel, + stable_diffusion::StableDiffusion, + }, config::{ModelConfig, ModelsConfig}, - ModelError, ModelId, ModelTrait, Request, Response, + types::ModelType, + ModelError, ModelId, ModelTrait, }, }; -pub struct ModelThreadCommand -where - Req: Request, - Resp: Response, -{ - request: Req, - response_sender: oneshot::Sender, +pub struct ModelThreadCommand { + request: serde_json::Value, + response_sender: oneshot::Sender, } #[derive(Debug, Error)] @@ -47,43 +50,28 @@ impl From for ModelThreadError { } } -pub struct ModelThreadHandle -where - Req: Request, - Resp: Response, -{ - sender: mpsc::Sender>, +pub struct ModelThreadHandle { + sender: mpsc::Sender, join_handle: std::thread::JoinHandle>, } -impl ModelThreadHandle -where - Req: Request, - Resp: Response, -{ +impl ModelThreadHandle { pub fn stop(self) { drop(self.sender); self.join_handle.join().ok(); } } -pub struct ModelThread -where - M: ModelTrait, - Req: Request, - Resp: Response, -{ +pub struct ModelThread { model: M, - receiver: mpsc::Receiver>, + receiver: mpsc::Receiver, } -impl ModelThread +impl ModelThread where - M: ModelTrait, - Req: Request, - Resp: Response, + M: ModelTrait, { - pub fn run(mut self, public_key: PublicKey) -> Result<(), ModelThreadError> { + pub fn run(mut self, _public_key: PublicKey) -> Result<(), ModelThreadError> { debug!("Start Model thread"); while let Ok(command) = self.receiver.recv() { @@ -92,14 +80,14 @@ where response_sender, } = command; - if !request.is_node_authorized(&public_key) { - error!("Current node, with verification key = {:?} is not authorized to run request with id = {}", public_key, request.request_id()); - continue; - } + 
// if !request.is_node_authorized(&public_key) { + // error!("Current node, with verification key = {:?} is not authorized to run request with id = {}", public_key, request.request_id()); + // continue; + // } - let model_input = request.into_model_input(); + let model_input = serde_json::from_value(request)?; let model_output = self.model.run(model_input)?; - let response = Response::from_model_output(model_output); + let response = serde_json::to_value(model_output)?; response_sender.send(response).ok(); } @@ -107,27 +95,16 @@ where } } -pub struct ModelThreadDispatcher -where - Req: Request, - Resp: Response, -{ - model_senders: HashMap>>, - pub(crate) responses: FuturesUnordered>, +pub struct ModelThreadDispatcher { + model_senders: HashMap>, + pub(crate) responses: FuturesUnordered>, } -impl ModelThreadDispatcher -where - Req: Clone + Request, - Resp: Response, -{ - pub(crate) fn start( +impl ModelThreadDispatcher { + pub(crate) fn start( config: ModelsConfig, public_key: PublicKey, - ) -> Result<(Self, Vec>), ModelThreadError> - where - M: ModelTrait + Send + 'static, - { + ) -> Result<(Self, Vec), ModelThreadError> { let mut handles = Vec::new(); let mut model_senders = HashMap::new(); @@ -137,11 +114,14 @@ where for model_config in config.models() { info!("Spawning new thread for model: {}", model_config.model_id()); - let (model_sender, model_receiver) = mpsc::channel::>(); let model_name = model_config.model_id().clone(); + let model_type = ModelType::from_str(&model_name)?; + + let (model_sender, model_receiver) = mpsc::channel::(); model_senders.insert(model_name.clone(), model_sender.clone()); - let join_handle = Self::start_model_thread::( + let join_handle = dispatch_model_thread( + model_type, model_name, api_key.clone(), cache_dir.clone(), @@ -164,41 +144,14 @@ where Ok((model_dispatcher, handles)) } - fn start_model_thread( - model_name: String, - api_key: String, - cache_dir: PathBuf, - model_config: ModelConfig, - public_key: PublicKey, - model_receiver: mpsc::Receiver>, - ) -> JoinHandle> - where - M: ModelTrait + Send + 'static, - { - std::thread::spawn(move || { - info!("Fetching files for model: {model_name}"); - let load_data = M::fetch(api_key, cache_dir, model_config)?; - - let model = M::load(load_data)?; - let model_thread = ModelThread { - model, - receiver: model_receiver, - }; - - if let Err(e) = model_thread.run(public_key) { - error!("Model thread error: {e}"); - if !matches!(e, ModelThreadError::Shutdown(_)) { - panic!("Fatal error occurred: {e}"); - } - } - - Ok(()) - }) - } - - fn send(&self, command: ModelThreadCommand) { + fn send(&self, command: ModelThreadCommand) { let request = command.request.clone(); - let model_id = request.requested_model(); + let model_id = if let Some(model_id) = request.get("model") { + model_id.as_str().unwrap().to_string() + } else { + error!("Request malformed: Missing model_id from request"); + return; + }; info!("model_id {model_id}"); @@ -213,12 +166,8 @@ where } } -impl ModelThreadDispatcher -where - Req: Clone + Debug + Request, - Resp: Debug + Response, -{ - pub(crate) fn run_inference(&self, request: Req) { +impl ModelThreadDispatcher { + pub(crate) fn run_inference(&self, request: serde_json::Value) { let (sender, receiver) = oneshot::channel(); self.send(ModelThreadCommand { request, @@ -227,3 +176,94 @@ where self.responses.push(receiver); } } + +fn dispatch_model_thread( + model_type: ModelType, + model_name: String, + api_key: String, + cache_dir: PathBuf, + model_config: ModelConfig, + public_key: PublicKey, 
+ model_receiver: mpsc::Receiver, +) -> JoinHandle> { + match model_type { + ModelType::Falcon7b | ModelType::Falcon40b | ModelType::Falcon180b => { + spawn_model_thread::( + model_name, + api_key.clone(), + cache_dir.clone(), + model_config, + public_key, + model_receiver, + ) + } + ModelType::LlamaV1 + | ModelType::LlamaV2 + | ModelType::LlamaTinyLlama1_1BChat + | ModelType::LlamaSolar10_7B => spawn_model_thread::( + model_name, + api_key, + cache_dir, + model_config, + public_key, + model_receiver, + ), + ModelType::Mamba130m + | ModelType::Mamba370m + | ModelType::Mamba790m + | ModelType::Mamba1_4b + | ModelType::Mamba2_8b => spawn_model_thread::( + model_name, + api_key, + cache_dir, + model_config, + public_key, + model_receiver, + ), + ModelType::Mistral7b => todo!(), + ModelType::Mixtral8x7b => todo!(), + ModelType::StableDiffusionV1_5 + | ModelType::StableDiffusionV2_1 + | ModelType::StableDiffusionTurbo + | ModelType::StableDiffusionXl => spawn_model_thread::( + model_name, + api_key, + cache_dir, + model_config, + public_key, + model_receiver, + ), + } +} + +fn spawn_model_thread( + model_name: String, + api_key: String, + cache_dir: PathBuf, + model_config: ModelConfig, + public_key: PublicKey, + model_receiver: mpsc::Receiver, +) -> JoinHandle> +where + M: ModelTrait + Send + 'static, +{ + std::thread::spawn(move || { + info!("Fetching files for model: {model_name}"); + let load_data = M::fetch(api_key, cache_dir, model_config)?; + + let model = M::load(load_data)?; + let model_thread = ModelThread { + model, + receiver: model_receiver, + }; + + if let Err(e) = model_thread.run(public_key) { + error!("Model thread error: {e}"); + if !matches!(e, ModelThreadError::Shutdown(_)) { + panic!("Fatal error occurred: {e}"); + } + } + + Ok(()) + }) +} diff --git a/atoma-inference/src/models/candle/stable_diffusion.rs b/atoma-inference/src/models/candle/stable_diffusion.rs index 306e9375..ea42f8be 100644 --- a/atoma-inference/src/models/candle/stable_diffusion.rs +++ b/atoma-inference/src/models/candle/stable_diffusion.rs @@ -13,7 +13,9 @@ use tracing::{debug, info}; use crate::{ bail, - models::{candle::save_image, config::ModelConfig, types::ModelType, ModelError, ModelTrait}, + models::{ + candle::save_image, config::ModelConfig, types::ModelType, ModelError, ModelId, ModelTrait, + }, }; use super::{convert_to_image, device, save_tensor_to_file}; @@ -32,7 +34,7 @@ pub struct StableDiffusionInput { /// The number of samples to generate. pub num_samples: i64, - pub model_type: ModelType, + pub model: ModelId, pub guidance_scale: Option, @@ -313,7 +315,8 @@ impl ModelTrait for StableDiffusion { }; let bsize = 1; - let vae_scale = match input.model_type { + let model_type = ModelType::from_str(&input.model)?; + let vae_scale = match model_type { ModelType::StableDiffusionV1_5 | ModelType::StableDiffusionV2_1 | ModelType::StableDiffusionXl => 0.18215, @@ -336,17 +339,8 @@ impl ModelTrait for StableDiffusion { } } None => { - let latents = Tensor::randn( - 0f32, - 1f32, - ( - bsize, - 4, - height / 8, - width / 8, - ), - &self.device, - )?; + let latents = + Tensor::randn(0f32, 1f32, (bsize, 4, height / 8, width / 8), &self.device)?; // scale the initial noise by the standard deviation required by the scheduler (latents * scheduler.init_noise_sigma())? 
} diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs index d17f6085..bd8fdd1d 100644 --- a/atoma-inference/src/models/mod.rs +++ b/atoma-inference/src/models/mod.rs @@ -2,6 +2,7 @@ use std::path::PathBuf; use ::candle::{DTypeParseError, Error as CandleError}; use ed25519_consensus::VerificationKey as PublicKey; +use serde::{de::DeserializeOwned, Serialize}; use thiserror::Error; use self::{config::ModelConfig, types::ModelType}; @@ -14,8 +15,8 @@ pub mod types; pub type ModelId = String; pub trait ModelTrait { - type Input; - type Output; + type Input: DeserializeOwned; + type Output: Serialize; type LoadData; fn fetch( diff --git a/atoma-inference/src/models/types.rs b/atoma-inference/src/models/types.rs index 7af052be..75c9f586 100644 --- a/atoma-inference/src/models/types.rs +++ b/atoma-inference/src/models/types.rs @@ -19,7 +19,7 @@ pub struct LlmLoadData { pub use_flash_attention: bool, } -#[derive(Clone, Debug, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub enum ModelType { Falcon7b, Falcon40b, @@ -244,7 +244,7 @@ impl Response for TextResponse { } } -#[derive(Clone, Debug, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct StableDiffusionRequest { pub request_id: usize, pub prompt: String, @@ -259,7 +259,7 @@ pub struct StableDiffusionRequest { /// The number of samples to generate. pub num_samples: i64, - pub model_type: ModelType, + pub model: ModelId, pub guidance_scale: Option, @@ -287,7 +287,7 @@ impl Request for StableDiffusionRequest { width: self.width, n_steps: self.n_steps, num_samples: self.num_samples, - model_type: self.model_type, + model: self.model, guidance_scale: self.guidance_scale, img2img: self.img2img, img2img_strength: self.img2img_strength, @@ -304,7 +304,7 @@ impl Request for StableDiffusionRequest { } fn requested_model(&self) -> ModelId { - self.model_type.to_string() + self.model.clone() } } diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 53603e62..7cf4cc8a 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -11,45 +11,34 @@ use thiserror::Error; use crate::{ apis::ApiError, model_thread::{ModelThreadDispatcher, ModelThreadError, ModelThreadHandle}, - models::{config::ModelsConfig, ModelTrait, Request, Response}, + models::config::ModelsConfig, }; -pub struct ModelService -where - Req: Request, - Resp: Response, -{ - model_thread_handle: Vec>, - dispatcher: ModelThreadDispatcher, +pub struct ModelService { + model_thread_handle: Vec, + dispatcher: ModelThreadDispatcher, start_time: Instant, flush_storage: bool, public_key: PublicKey, cache_dir: PathBuf, - request_receiver: Receiver, - response_sender: Sender, + request_receiver: Receiver, + response_sender: Sender, } -impl ModelService -where - Req: Clone + Debug + Request, - Resp: Debug + Response, -{ - pub fn start( +impl ModelService { + pub fn start( model_config: ModelsConfig, private_key: PrivateKey, - request_receiver: Receiver, - response_sender: Sender, - ) -> Result - where - M: ModelTrait + Send + 'static, - { + request_receiver: Receiver, + response_sender: Sender, + ) -> Result { let public_key = private_key.verification_key(); let flush_storage = model_config.flush_storage(); let cache_dir = model_config.cache_dir(); let (dispatcher, model_thread_handle) = - ModelThreadDispatcher::start::(model_config, public_key) + ModelThreadDispatcher::start(model_config, public_key) .map_err(ModelServiceError::ModelThreadError)?; let start_time = Instant::now(); @@ 
-65,7 +54,7 @@ where }) } - pub async fn run(&mut self) -> Result { + pub async fn run(&mut self) -> Result<(), ModelServiceError> { loop { tokio::select! { message = self.request_receiver.recv() => { @@ -95,11 +84,7 @@ where } } -impl ModelService -where - Req: Request, - Resp: Response, -{ +impl ModelService { pub async fn stop(mut self) { info!( "Stopping Inference Service, running time: {:?}", @@ -158,7 +143,7 @@ mod tests { use std::io::Write; use toml::{toml, Value}; - use crate::models::{config::ModelConfig, Request, Response}; + use crate::models::{config::ModelConfig, ModelTrait, Request, Response}; use super::*; @@ -232,18 +217,12 @@ mod tests { file.write_all(toml_string.as_bytes()) .expect("Failed to write to file"); - let (_, req_receiver) = tokio::sync::mpsc::channel::<()>(1); - let (resp_sender, _) = tokio::sync::mpsc::channel::<()>(1); + let (_, req_receiver) = tokio::sync::mpsc::channel(1); + let (resp_sender, _) = tokio::sync::mpsc::channel(1); let config = ModelsConfig::from_file_path(CONFIG_FILE_PATH.parse().unwrap()); - let _ = ModelService::start::( - config, - private_key, - req_receiver, - resp_sender, - ) - .unwrap(); + let _ = ModelService::start(config, private_key, req_receiver, resp_sender).unwrap(); std::fs::remove_file(CONFIG_FILE_PATH).unwrap(); } From 9ca01c2a548c8122312e37f0f56aa3db932cefa5 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sat, 6 Apr 2024 23:36:18 +0100 Subject: [PATCH 19/21] add logs to falcon and llama load methods --- atoma-inference/Cargo.toml | 3 +-- atoma-inference/src/main.rs | 14 +++++++------- atoma-inference/src/models/candle/falcon.rs | 2 +- atoma-inference/src/models/candle/llama.rs | 9 ++++++++- atoma-inference/src/models/types.rs | 4 ++-- 5 files changed, 19 insertions(+), 13 deletions(-) diff --git a/atoma-inference/Cargo.toml b/atoma-inference/Cargo.toml index 3da618eb..8089cfcd 100644 --- a/atoma-inference/Cargo.toml +++ b/atoma-inference/Cargo.toml @@ -14,12 +14,11 @@ dotenv.workspace = true ed25519-consensus.workspace = true futures.workspace = true hf-hub.workspace = true -reqwest = { workspace = true, features = ["json"] } serde = { workspace = true, features = ["derive"] } serde_json.workspace = true image = { workspace = true } thiserror.workspace = true -tokenizers = { workspace = true, features = ["onig"] } +tokenizers = { workspace = true } tokio = { workspace = true, features = ["full", "tracing"] } tracing.workspace = true tracing-subscriber.workspace = true diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index d2e118f1..9356d85b 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -2,7 +2,7 @@ use std::time::Duration; use ed25519_consensus::SigningKey as PrivateKey; use inference::{ - models::{config::ModelsConfig, types::StableDiffusionRequest}, + models::{config::ModelsConfig, types::{StableDiffusionRequest}}, service::{ModelService, ModelServiceError}, }; @@ -34,10 +34,10 @@ async fn main() -> Result<(), ModelServiceError> { tokio::time::sleep(Duration::from_millis(5_000)).await; // req_sender - // .send(TextRequest { + // .send(serde_json::to_value(TextRequest { // request_id: 0, // prompt: "Leon, the professional is a movie".to_string(), - // model: "falcon_7b".to_string(), + // model: "llama_tiny_llama_1_1b_chat".to_string(), // max_tokens: 512, // temperature: Some(0.0), // random_seed: 42, @@ -45,8 +45,8 @@ async fn main() -> Result<(), ModelServiceError> { // repeat_penalty: 1.1, // sampled_nodes: vec![pk], // top_p: Some(1.0), - // top_k: 10, - // }) + 
// _top_k: 10, + // }).unwrap()) // .await // .expect("Failed to send request"); @@ -56,8 +56,8 @@ async fn main() -> Result<(), ModelServiceError> { request_id: 0, prompt: "A depiction of Natalie Portman".to_string(), uncond_prompt: "".to_string(), - height: None, - width: None, + height: Some(256), + width: Some(256), num_samples: 1, n_steps: None, model: "stable_diffusion_v1-5".to_string(), diff --git a/atoma-inference/src/models/candle/falcon.rs b/atoma-inference/src/models/candle/falcon.rs index b13acc40..12c5164d 100644 --- a/atoma-inference/src/models/candle/falcon.rs +++ b/atoma-inference/src/models/candle/falcon.rs @@ -122,7 +122,7 @@ impl ModelTrait for FalconModel { )? }; let model = Falcon::load(vb, config.clone())?; - info!("loaded the model in {:?}", start.elapsed()); + info!("Loaded Falcon model in {:?}", start.elapsed()); Ok(Self::new( model, diff --git a/atoma-inference/src/models/candle/llama.rs b/atoma-inference/src/models/candle/llama.rs index 396b0e39..d47155dc 100644 --- a/atoma-inference/src/models/candle/llama.rs +++ b/atoma-inference/src/models/candle/llama.rs @@ -1,4 +1,4 @@ -use std::{path::PathBuf, str::FromStr}; +use std::{path::PathBuf, str::FromStr, time::Instant}; use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; @@ -10,6 +10,7 @@ use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; use candle_transformers::models::llama as model; use tokenizers::Tokenizer; +use tracing::info; use crate::models::{ config::ModelConfig, @@ -96,6 +97,10 @@ impl ModelTrait for LlamaModel { } fn load(load_data: Self::LoadData) -> Result { + info!("Loading Llama model ..."); + + let start = Instant::now(); + let device = load_data.device; let dtype = load_data.dtype; let (model, tokenizer_filename, cache) = { @@ -112,6 +117,8 @@ impl ModelTrait for LlamaModel { (model::Llama::load(vb, &config)?, tokenizer_filename, cache) }; let tokenizer = Tokenizer::from_file(tokenizer_filename)?; + info!("Loaded Llama model in {:?}", start.elapsed()); + Ok(Self { cache, device, diff --git a/atoma-inference/src/models/types.rs b/atoma-inference/src/models/types.rs index 75c9f586..dfb2001f 100644 --- a/atoma-inference/src/models/types.rs +++ b/atoma-inference/src/models/types.rs @@ -155,7 +155,7 @@ pub struct TextRequest { pub repeat_penalty: f32, pub sampled_nodes: Vec, pub temperature: Option, - pub top_k: usize, + pub _top_k: usize, pub top_p: Option, } @@ -170,7 +170,7 @@ impl Request for TextRequest { self.repeat_penalty, self.repeat_last_n, self.max_tokens, - self.top_k, + self._top_k, self.top_p.unwrap_or_default() as f64, ) } From a65ac25d67ec51d9e2914e5a3547fca092e8a021 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sat, 6 Apr 2024 23:36:49 +0100 Subject: [PATCH 20/21] cargo fmt --- atoma-inference/src/main.rs | 2 +- atoma-inference/src/models/candle/llama.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index 9356d85b..2c99ba93 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -2,7 +2,7 @@ use std::time::Duration; use ed25519_consensus::SigningKey as PrivateKey; use inference::{ - models::{config::ModelsConfig, types::{StableDiffusionRequest}}, + models::{config::ModelsConfig, types::StableDiffusionRequest}, service::{ModelService, ModelServiceError}, }; diff --git a/atoma-inference/src/models/candle/llama.rs b/atoma-inference/src/models/candle/llama.rs index d47155dc..7e4a9026 100644 --- a/atoma-inference/src/models/candle/llama.rs +++ 
b/atoma-inference/src/models/candle/llama.rs @@ -118,7 +118,7 @@ impl ModelTrait for LlamaModel { }; let tokenizer = Tokenizer::from_file(tokenizer_filename)?; info!("Loaded Llama model in {:?}", start.elapsed()); - + Ok(Self { cache, device, From 45a43ec2b4682850bcf71568ead366507c5097a3 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sun, 7 Apr 2024 11:41:13 +0100 Subject: [PATCH 21/21] minor changes --- atoma-inference/src/model_thread.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 57a07d69..b4c66d94 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -121,10 +121,10 @@ impl ModelThreadDispatcher { model_senders.insert(model_name.clone(), model_sender.clone()); let join_handle = dispatch_model_thread( - model_type, - model_name, api_key.clone(), cache_dir.clone(), + model_name, + model_type, model_config, public_key, model_receiver, @@ -178,10 +178,10 @@ impl ModelThreadDispatcher { } fn dispatch_model_thread( - model_type: ModelType, - model_name: String, api_key: String, cache_dir: PathBuf, + model_name: String, + model_type: ModelType, model_config: ModelConfig, public_key: PublicKey, model_receiver: mpsc::Receiver,
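The patches above replace the generic `Request`/`Response` plumbing with plain `serde_json::Value` messages and route each request to a per-model worker thread by its `"model"` field. Below is a minimal, self-contained sketch of that routing pattern using only std channels and `serde_json`; the names `Command` and `"echo_model"` and the echo worker body are illustrative stand-ins, not part of the atoma-inference API.

use std::collections::HashMap;
use std::sync::mpsc;
use std::thread;

use serde_json::{json, Value};

// Pairs a JSON request with a reply channel, mirroring `ModelThreadCommand` in spirit.
struct Command {
    request: Value,
    response_sender: mpsc::Sender<Value>,
}

fn main() {
    // One sender per model id, as in `ModelThreadDispatcher::model_senders`.
    let mut senders: HashMap<String, mpsc::Sender<Command>> = HashMap::new();
    let (tx, rx) = mpsc::channel::<Command>();
    senders.insert("echo_model".to_string(), tx);

    // Worker thread: stands in for `spawn_model_thread` + `ModelThread::run`.
    let worker = thread::spawn(move || {
        while let Ok(command) = rx.recv() {
            let prompt = command.request["prompt"].as_str().unwrap_or("").to_string();
            let response = json!({ "text": format!("echo: {prompt}") });
            command.response_sender.send(response).ok();
        }
    });

    // Dispatch side: route on the request's `"model"` field, as `send` does above.
    let request = json!({ "model": "echo_model", "prompt": "hello" });
    let Some(model_id) = request.get("model").and_then(Value::as_str) else {
        eprintln!("Request malformed: missing `model` field");
        return;
    };

    let (resp_tx, resp_rx) = mpsc::channel();
    senders[model_id]
        .send(Command { request: request.clone(), response_sender: resp_tx })
        .expect("model thread has shut down");
    println!("{}", resp_rx.recv().unwrap());

    drop(senders); // drop the last sender so the worker loop ends
    worker.join().unwrap();
}

The real dispatcher additionally maps the model id to a `ModelType` and picks a monomorphized `spawn_model_thread::<M>` per architecture, but the channel layout and routing step have the same shape as this sketch.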
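Under the same caveat that `SimpleModel`, `run_as_json`, and `Echo` are hypothetical names rather than the crate's real types, the sketch below shows why the new `Input: DeserializeOwned` and `Output: Serialize` bounds on `ModelTrait` are all the thread loop needs in order to speak JSON at the channel boundary while each model keeps a typed interface.

use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;

// Minimal analogue of `ModelTrait`, reduced to the parts relevant to JSON conversion.
trait SimpleModel {
    type Input: DeserializeOwned;
    type Output: Serialize;
    fn run(&mut self, input: Self::Input) -> Self::Output;
}

// Converts at the boundary: JSON in, typed inference, JSON out.
fn run_as_json<M: SimpleModel>(model: &mut M, request: Value) -> serde_json::Result<Value> {
    let input: M::Input = serde_json::from_value(request)?;
    serde_json::to_value(model.run(input))
}

struct Echo;

impl SimpleModel for Echo {
    type Input = String;
    type Output = String;
    fn run(&mut self, input: String) -> String {
        format!("echo: {input}")
    }
}

fn main() -> serde_json::Result<()> {
    let mut model = Echo;
    let response = run_as_json(&mut model, serde_json::json!("hello"))?;
    println!("{response}");
    Ok(())
}

The trade-off of this design is that a malformed request is rejected at run time by `serde_json::from_value` instead of being ruled out at compile time by the old generic bounds.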