diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs
index 80d1db18..45d5974b 100644
--- a/atoma-inference/src/main.rs
+++ b/atoma-inference/src/main.rs
@@ -3,11 +3,7 @@ use std::time::Duration;
 use ed25519_consensus::SigningKey as PrivateKey;
 use hf_hub::api::sync::Api;
 use inference::{
-    models::{
-        candle::mamba::MambaModel,
-        config::ModelsConfig,
-        types::{TextRequest, TextResponse},
-    },
+    models::{candle::mamba::MambaModel, config::ModelsConfig, types::TextRequest},
     service::{ModelService, ModelServiceError},
 };
 
@@ -15,8 +11,8 @@ use inference::{
 async fn main() -> Result<(), ModelServiceError> {
     tracing_subscriber::fmt::init();
 
-    let (req_sender, req_receiver) = tokio::sync::mpsc::channel::<TextRequest>(32);
-    let (resp_sender, mut resp_receiver) = tokio::sync::mpsc::channel::<TextResponse>(32);
+    let (req_sender, req_receiver) = tokio::sync::mpsc::channel::<serde_json::Value>(32);
+    let (resp_sender, mut resp_receiver) = tokio::sync::mpsc::channel::<serde_json::Value>(32);
 
     let model_config = ModelsConfig::from_file_path("../inference.toml".parse().unwrap());
     let private_key_bytes =
@@ -44,19 +40,22 @@ async fn main() -> Result<(), ModelServiceError> {
     tokio::time::sleep(Duration::from_millis(5000)).await;
 
     req_sender
-        .send(TextRequest {
-            request_id: 0,
-            prompt: "Leon, the professional is a movie".to_string(),
-            model: "state-spaces/mamba-130m".to_string(),
-            max_tokens: 512,
-            temperature: Some(0.0),
-            random_seed: 42,
-            repeat_last_n: 64,
-            repeat_penalty: 1.1,
-            sampled_nodes: vec![pk],
-            top_p: Some(1.0),
-            top_k: 10,
-        })
+        .send(
+            serde_json::to_value(TextRequest {
+                request_id: 0,
+                prompt: "Leon, the professional is a movie".to_string(),
+                model: "state-spaces/mamba-130m".to_string(),
+                max_tokens: 512,
+                temperature: Some(0.0),
+                random_seed: 42,
+                repeat_last_n: 64,
+                repeat_penalty: 1.1,
+                sampled_nodes: vec![pk],
+                top_p: Some(1.0),
+                top_k: 10,
+            })
+            .unwrap(),
+        )
         .await
         .expect("Failed to send request");
 
diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs
index 388554ae..68a3996e 100644
--- a/atoma-inference/src/model_thread.rs
+++ b/atoma-inference/src/model_thread.rs
@@ -11,13 +11,13 @@ use tracing::{debug, error, info, warn};
 
 use crate::{
     apis::{ApiError, ApiTrait},
-    models::{config::ModelsConfig, ModelError, ModelId, ModelTrait, Request, Response},
+    models::{config::ModelsConfig, ModelError, ModelId, ModelTrait},
 };
 
-pub struct ModelThreadCommand<Req, Resp>(Req, oneshot::Sender<Resp>)
-where
-    Req: Request,
-    Resp: Response;
+pub struct ModelThreadCommand {
+    request: serde_json::Value,
+    response_sender: oneshot::Sender<serde_json::Value>,
+}
 
 #[derive(Debug, Error)]
 pub enum ModelThreadError {
@@ -27,6 +27,8 @@ pub enum ModelThreadError {
     ModelError(ModelError),
     #[error("Core thread shutdown: `{0}`")]
     Shutdown(RecvError),
+    #[error("Serde error: `{0}`")]
+    SerdeError(#[from] serde_json::Error),
 }
 
 impl From for ModelThreadError {
@@ -41,82 +43,68 @@
     }
 }
 
-pub struct ModelThreadHandle<Req, Resp>
-where
-    Req: Request,
-    Resp: Response,
-{
-    sender: mpsc::Sender<ModelThreadCommand<Req, Resp>>,
+pub struct ModelThreadHandle {
+    sender: mpsc::Sender<ModelThreadCommand>,
     join_handle: std::thread::JoinHandle<Result<(), ModelThreadError>>,
 }
 
-impl<Req, Resp> ModelThreadHandle<Req, Resp>
-where
-    Req: Request,
-    Resp: Response,
-{
+impl ModelThreadHandle {
     pub fn stop(self) {
         drop(self.sender);
         self.join_handle.join().ok();
     }
 }
 
-pub struct ModelThread<M, Req, Resp> {
+pub struct ModelThread<M> {
     model: M,
-    receiver: mpsc::Receiver<ModelThreadCommand<Req, Resp>>,
+    receiver: mpsc::Receiver<ModelThreadCommand>,
 }
 
-impl<M, Req, Resp> ModelThread<M, Req, Resp>
+impl<M> ModelThread<M>
 where
-    M: ModelTrait,
-    Req: Request,
-    Resp: Response,
+    M: ModelTrait,
 {
-    pub fn run(mut self, public_key: PublicKey) -> Result<(), ModelThreadError> {
+    pub fn run(mut self, _public_key: PublicKey) -> Result<(), ModelThreadError> {
         debug!("Start Model thread");
 
         while let Ok(command) = self.receiver.recv() {
-            let ModelThreadCommand(request, sender) = command;
-
-            if !request.is_node_authorized(&public_key) {
-                error!("Current node, with verification key = {:?} is not authorized to run request with id = {}", public_key, request.request_id());
-                continue;
-            }
-
-            let model_input = request.into_model_input();
+            let ModelThreadCommand {
+                request,
+                response_sender,
+            } = command;
+
+            // TODO: Implement node authorization
+            // if !request.is_node_authorized(&public_key) {
+            //     error!("Current node, with verification key = {:?} is not authorized to run request with id = {}", public_key, request.request_id());
+            //     continue;
+            // }
+
+            let model_input = serde_json::from_value(request).unwrap();
             let model_output = self
                 .model
                 .run(model_input)
                 .map_err(ModelThreadError::ModelError)?;
-            let response = Resp::from_model_output(model_output);
-            sender.send(response).ok();
+            let response = serde_json::to_value(model_output)?;
+            response_sender.send(response).ok();
         }
 
         Ok(())
     }
 }
 
-pub struct ModelThreadDispatcher<Req, Resp>
-where
-    Req: Request,
-    Resp: Response,
-{
-    model_senders: HashMap<ModelId, mpsc::Sender<ModelThreadCommand<Req, Resp>>>,
-    pub(crate) responses: FuturesUnordered<oneshot::Receiver<Resp>>,
+pub struct ModelThreadDispatcher {
+    model_senders: HashMap<ModelId, mpsc::Sender<ModelThreadCommand>>,
+    pub(crate) responses: FuturesUnordered<oneshot::Receiver<serde_json::Value>>,
 }
 
-impl<Req, Resp> ModelThreadDispatcher<Req, Resp>
-where
-    Req: Clone + Request,
-    Resp: Response,
-{
+impl ModelThreadDispatcher {
     pub(crate) fn start<F, M>(
         config: ModelsConfig,
         public_key: PublicKey,
-    ) -> Result<(Self, Vec<ModelThreadHandle<Req, Resp>>), ModelThreadError>
+    ) -> Result<(Self, Vec<ModelThreadHandle>), ModelThreadError>
     where
         F: ApiTrait + Send + Sync + 'static,
-        M: ModelTrait + Send + 'static,
+        M: ModelTrait, // + Send + 'static,
     {
         let api_key = config.api_key();
         let storage_path = config.storage_path();
@@ -129,19 +117,16 @@ where
             info!("Spawning new thread for model: {}", model_config.model_id());
 
             let api = api.clone();
-            let (model_sender, model_receiver) = mpsc::channel::<ModelThreadCommand<Req, Resp>>();
+            let (model_sender, model_receiver) = mpsc::channel::<ModelThreadCommand>();
             let model_name = model_config.model_id().clone();
             model_senders.insert(model_name.clone(), model_sender.clone());
 
             let join_handle = std::thread::spawn(move || {
                 info!("Fetching files for model: {model_name}");
                 let filenames = api.fetch(model_name, model_config.revision())?;
+                let x = serde_json::from_value(model_config.params().clone()).unwrap();
 
-                let model = M::load(
-                    filenames,
-                    model_config.precision(),
-                    model_config.device_id(),
-                )?;
+                let model = M::load(filenames, x, model_config.device_id())?;
                 let model_thread = ModelThread {
                     model,
                     receiver: model_receiver,
@@ -170,10 +155,12 @@ where
         Ok((model_dispatcher, handles))
     }
 
-    fn send(&self, command: ModelThreadCommand<Req, Resp>) {
-        let request = command.0.clone();
-        let model_id = request.requested_model();
+    fn send(&self, command: ModelThreadCommand) {
+        let request = command.request.clone();
+        let model_id = request.get("model").unwrap().as_str().unwrap().to_string();
+        println!("model_id {model_id}");
+        println!("{:?}", self.model_senders);
 
         let sender = self
             .model_senders
             .get(&model_id)
@@ -185,14 +172,13 @@ where
     }
 }
 
-impl<T, U> ModelThreadDispatcher<T, U>
-where
-    T: Clone + Request,
-    U: Response,
-{
-    pub(crate) fn run_inference(&self, request: T) {
+impl ModelThreadDispatcher {
+    pub(crate) fn run_inference(&self, request: serde_json::Value) {
         let (sender, receiver) = oneshot::channel();
-        self.send(ModelThreadCommand(request, sender));
+        self.send(ModelThreadCommand {
+            request,
+            response_sender: sender,
+        });
         self.responses.push(receiver);
     }
 }
diff --git a/atoma-inference/src/models/candle/falcon.rs b/atoma-inference/src/models/candle/falcon.rs
index 2fd96a73..5ec797e1 100644
--- a/atoma-inference/src/models/candle/falcon.rs
+++ b/atoma-inference/src/models/candle/falcon.rs
@@ -49,6 +49,7 @@ impl ModelTrait for FalconModel {
     type Fetch = ();
     type Input = TextModelInput;
     type Output = String;
+    type Load = PrecisionBits;
 
     fn fetch(_fetch: &Self::Fetch) -> Result<(), ModelError> {
         Ok(())
@@ -56,7 +57,7 @@ impl ModelTrait for FalconModel {
 
     fn load(
         filenames: Vec<PathBuf>,
-        precision: PrecisionBits,
+        precision: Self::Load,
         device_id: usize,
     ) -> Result<Self, ModelError>
     where
diff --git a/atoma-inference/src/models/candle/llama.rs b/atoma-inference/src/models/candle/llama.rs
index b3eeb83f..3b6b4079 100644
--- a/atoma-inference/src/models/candle/llama.rs
+++ b/atoma-inference/src/models/candle/llama.rs
@@ -6,7 +6,7 @@ extern crate intel_mkl_src;
 
 use std::path::PathBuf;
 
-use candle::{DType, Device, Tensor};
+use candle::{Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::{
     generation::LogitsProcessor,
@@ -15,6 +15,7 @@ use candle_transformers::{
 use hf_hub::{api::sync::Api, Repo, RepoType};
 
 use candle_transformers::models::llama as model;
+use serde::Deserialize;
 use tokenizers::Tokenizer;
 
 use crate::models::{
@@ -43,6 +44,7 @@ pub struct Llama {
     cache: Cache,
 }
 
+#[derive(Deserialize)]
 pub struct Input {
     prompt: String,
     temperature: Option<f64>,
@@ -71,8 +73,6 @@ pub struct Fetch {
     model_id: Option<String>,
     revision: Option<String>,
     which: Which,
-    use_flash_attn: bool,
-    dtype: Option<String>,
 }
 
 impl Default for Fetch {
@@ -81,8 +81,6 @@ impl Default for Fetch {
             model_id: None,
             revision: None,
            which: Which::TinyLlama1_1BChat,
-            use_flash_attn: false,
-            dtype: None,
         }
     }
 }
@@ -90,17 +88,10 @@ impl Default for Fetch {
 impl ModelTrait for Llama {
     type Input = Input;
     type Fetch = Fetch;
-    type Output = Vec<Tensor>;
+    type Output = String;
+    type Load = PrecisionBits;
 
     fn fetch(fetch: &Self::Fetch) -> Result<(), ModelError> {
-        let device = device()?;
-        let dtype = match fetch.dtype.as_deref() {
-            Some("f16") => DType::F16,
-            Some("bf16") => DType::BF16,
-            Some("f32") => DType::F32,
-            Some(dtype) => Err(ModelError::Config(format!("Invalid dtype : {dtype}")))?,
-            None => DType::F16,
-        };
         let api = Api::new()?;
         let model_id = fetch.model_id.clone().unwrap_or_else(|| match fetch.which {
             Which::V1 => "Narsil/amall-7b".to_string(),
@@ -111,17 +102,15 @@ impl ModelTrait for Llama {
         let revision = fetch.revision.clone().unwrap_or("main".to_string());
         let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
         api.get("tokenizer.json")?;
-        let config_filename = api.get("config.json")?;
-        let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
-        let config = config.into_config(fetch.use_flash_attn);
-        let filenames = match fetch.which {
+        api.get("config.json")?;
+        match fetch.which {
             Which::V1 | Which::V2 | Which::Solar10_7B => {
-                hub_load_safetensors(&api, "model.safetensors.index.json")?
+                hub_load_safetensors(&api, "model.safetensors.index.json")?;
+            }
+            Which::TinyLlama1_1BChat => {
+                api.get("model.safetensors")?;
             }
-            Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
         };
-        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
-        model::Llama::load(vb, &config)?;
 
         Ok(())
     }
@@ -129,8 +118,12 @@ impl ModelTrait for Llama {
         "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string()
     }
 
-    fn load(filenames: Vec<PathBuf>, precision: PrecisionBits) -> Result<Self, ModelError> {
-        let device = device()?;
+    fn load(
+        filenames: Vec<PathBuf>,
+        precision: PrecisionBits,
+        device_id: usize,
+    ) -> Result<Self, ModelError> {
+        let device = device(device_id)?;
         let dtype = precision.into_dtype();
         let (llama, tokenizer_filename, cache) = {
             let tokenizer_filename = filenames[0].clone();
@@ -164,7 +157,6 @@ impl ModelTrait for Llama {
         let mut logits_processor = LogitsProcessor::new(input.seed, input.temperature, input.top_p);
         let mut index_pos = 0;
         let mut res = String::new();
-        let mut result = Vec::new();
         for index in 0..input.sample_len {
             let (context_size, context_index) = if self.cache.use_kv_cache && index > 0 {
                 (1, index_pos)
@@ -189,7 +181,6 @@ impl ModelTrait for Llama {
             };
             index_pos += ctxt.len();
             let next_token = logits_processor.sample(&logits)?;
-            result.push(logits);
             tokens.push(next_token);
 
             if Some(next_token) == eos_token_id {
@@ -202,7 +193,6 @@ impl ModelTrait for Llama {
         if let Some(rest) = tokenizer.decode_rest()? {
             res += &rest;
         }
-        println!("Result {}", res);
-        Ok(result)
+        Ok(res)
     }
 }
diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs
index 3552cb6b..d491ec58 100644
--- a/atoma-inference/src/models/candle/mamba.rs
+++ b/atoma-inference/src/models/candle/mamba.rs
@@ -7,6 +7,7 @@ use candle_transformers::{
     models::mamba::{Config, Model, State},
     utils::apply_repeat_penalty,
 };
+use serde::Deserialize;
 use tokenizers::Tokenizer;
 use tracing::info;
 
@@ -49,14 +50,20 @@ impl MambaModel {
     }
 }
 
+#[derive(Debug, Deserialize)]
+pub struct Load {
+    pub precision: PrecisionBits,
+}
+
 impl ModelTrait for MambaModel {
     type Fetch = ();
     type Input = TextModelInput;
     type Output = String;
+    type Load = Load;
 
     fn load(
         filenames: Vec<PathBuf>,
-        precision: PrecisionBits,
+        params: Self::Load,
         device_id: usize,
     ) -> Result<Self, ModelError>
     where
@@ -76,7 +83,7 @@ impl ModelTrait for MambaModel {
             serde_json::from_slice(&std::fs::read(config_filename).map_err(ModelError::IoError)?)
                .map_err(ModelError::DeserializeError)?;
         let device = device(device_id)?;
-        let dtype = precision.into_dtype();
+        let dtype = params.precision.into_dtype();
 
         info!("Loading model weights..");
         let var_builder =
diff --git a/atoma-inference/src/models/candle/mod.rs b/atoma-inference/src/models/candle/mod.rs
index 373d20c9..07bddc9c 100644
--- a/atoma-inference/src/models/candle/mod.rs
+++ b/atoma-inference/src/models/candle/mod.rs
@@ -53,14 +53,19 @@ pub fn hub_load_safetensors(
     Ok(safetensors_files)
 }
 
-pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<(), ModelError> {
-    let p = p.as_ref();
+pub fn convert_to_image(img: &Tensor) -> Result<(Vec<u8>, usize, usize), ModelError> {
     let (channel, height, width) = img.dims3()?;
     if channel != 3 {
         bail!("save_image expects an input of shape (3, height, width)")
     }
     let img = img.permute((1, 2, 0))?.flatten_all()?;
     let pixels = img.to_vec1::<u8>()?;
+    Ok((pixels, width, height))
+}
+
+pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<(), ModelError> {
+    let p = p.as_ref();
+    let (pixels, width, height) = convert_to_image(img)?;
     let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
         match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
             Some(image) => image,
diff --git a/atoma-inference/src/models/candle/stable_diffusion.rs b/atoma-inference/src/models/candle/stable_diffusion.rs
index a04b8e50..f3be6ad6 100644
--- a/atoma-inference/src/models/candle/stable_diffusion.rs
+++ b/atoma-inference/src/models/candle/stable_diffusion.rs
@@ -7,12 +7,14 @@ extern crate intel_mkl_src;
 use candle_transformers::models::stable_diffusion::{self};
 
 use candle::{DType, Device, IndexOp, Module, Tensor, D};
+use serde::Deserialize;
 use tokenizers::Tokenizer;
 
 use crate::models::{types::PrecisionBits, ModelError, ModelId, ModelTrait};
 
-use super::{device, save_tensor_to_file};
+use super::{convert_to_image, device, save_tensor_to_file};
 
+#[derive(Deserialize)]
 pub struct Input {
     prompt: String,
     uncond_prompt: String,
@@ -110,14 +112,22 @@ pub struct Fetch {
     unet_weights: Option<String>,
 }
 
+#[derive(Debug, Deserialize)]
+pub struct Load {
+    pub filenames: Vec<PathBuf>,
+    pub precision: PrecisionBits,
+    pub device_id: usize,
+}
+
 impl ModelTrait for StableDiffusion {
     type Input = Input;
     type Fetch = Fetch;
-    type Output = Vec<Tensor>;
+    type Output = Vec<(Vec<u8>, usize, usize)>;
+    type Load = Load;
 
     fn load(
         _filenames: Vec<PathBuf>,
-        _precision: PrecisionBits,
+        _precision: Self::Load,
         device_id: usize,
     ) -> Result<Self, ModelError>
     where
@@ -323,14 +333,14 @@ impl ModelTrait for StableDiffusion {
            save_tensor_to_file(&image, "tensor3")?;
            let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
            save_tensor_to_file(&image, "tensor4")?;
-            res.push(image);
+            res.push(convert_to_image(&image)?);
         }
         Ok(res)
     }
 }
 
 #[allow(dead_code)]
-#[derive(Clone, Copy)]
+#[derive(Clone, Copy, Deserialize)]
 enum StableDiffusionVersion {
     V1_5,
     V2_1,
diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs
index 8606ebcb..22c404c1 100644
--- a/atoma-inference/src/models/config.rs
+++ b/atoma-inference/src/models/config.rs
@@ -4,14 +4,14 @@ use config::Config;
 use dotenv::dotenv;
 use serde::{Deserialize, Serialize};
 
-use crate::{models::types::PrecisionBits, models::ModelId};
+use crate::models::ModelId;
 
 type Revision = String;
 
 #[derive(Clone, Debug, Deserialize, Serialize)]
 pub struct ModelConfig {
     model_id: ModelId,
-    precision: PrecisionBits,
+    params: serde_json::Value,
     revision: Revision,
     device_id: usize,
 }
@@ -19,13 +19,13 @@ impl ModelConfig {
     pub fn new(
         model_id: ModelId,
-        precision: PrecisionBits,
+        params: serde_json::Value,
         revision: Revision,
         device_id: usize,
     ) -> Self {
         Self {
             model_id,
-            precision,
+            params,
             revision,
             device_id,
         }
@@ -35,8 +35,8 @@ impl ModelConfig {
         &self.model_id
     }
 
-    pub fn precision(&self) -> PrecisionBits {
-        self.precision
+    pub fn params(&self) -> &serde_json::Value {
+        &self.params
     }
 
     pub fn revision(&self) -> Revision {
@@ -139,6 +139,8 @@ impl ModelsConfig {
 
 #[cfg(test)]
 pub mod tests {
+    use crate::models::types::PrecisionBits;
+
     use super::*;
 
     #[test]
@@ -148,7 +150,7 @@ pub mod tests {
             true,
             vec![ModelConfig::new(
                 "Llama2_7b".to_string(),
-                PrecisionBits::F16,
+                serde_json::to_value(PrecisionBits::F16).unwrap(),
                 "".to_string(),
                 0,
             )],
diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs
index f12a7408..5ccccc91 100644
--- a/atoma-inference/src/models/mod.rs
+++ b/atoma-inference/src/models/mod.rs
@@ -2,10 +2,9 @@ use std::path::PathBuf;
 
 use ::candle::Error as CandleError;
 use ed25519_consensus::VerificationKey as PublicKey;
+use serde::{de::DeserializeOwned, Serialize};
 use thiserror::Error;
 
-use crate::models::types::PrecisionBits;
-
 pub mod candle;
 pub mod config;
 pub mod token_output_stream;
@@ -15,15 +14,16 @@ pub type ModelId = String;
 
 pub trait ModelTrait {
     type Fetch;
-    type Input;
-    type Output;
+    type Input: DeserializeOwned;
+    type Output: Serialize;
+    type Load: DeserializeOwned;
 
     fn fetch(_fetch: &Self::Fetch) -> Result<(), ModelError> {
         Ok(())
     }
 
     fn load(
         filenames: Vec<PathBuf>,
-        precision: PrecisionBits,
+        params: Self::Load,
         device_id: usize,
     ) -> Result<Self, ModelError>
     where
diff --git a/atoma-inference/src/models/types.rs b/atoma-inference/src/models/types.rs
index 54d85bcf..1a7f00ee 100644
--- a/atoma-inference/src/models/types.rs
+++ b/atoma-inference/src/models/types.rs
@@ -50,6 +50,7 @@ impl Request for TextRequest {
     }
 }
 
+#[derive(Deserialize)]
 pub struct TextModelInput {
     pub(crate) prompt: String,
     pub(crate) temperature: f64,
diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs
index d3e08830..4360538e 100644
--- a/atoma-inference/src/service.rs
+++ b/atoma-inference/src/service.rs
@@ -10,37 +10,29 @@ use thiserror::Error;
 
 use crate::{
     apis::{ApiError, ApiTrait},
     model_thread::{ModelThreadDispatcher, ModelThreadError, ModelThreadHandle},
-    models::{config::ModelsConfig, ModelTrait, Request, Response},
+    models::{config::ModelsConfig, ModelTrait},
 };
 
-pub struct ModelService<Req, Resp>
-where
-    Req: Request,
-    Resp: Response,
-{
-    model_thread_handle: Vec<ModelThreadHandle<Req, Resp>>,
-    dispatcher: ModelThreadDispatcher<Req, Resp>,
+pub struct ModelService {
+    model_thread_handle: Vec<ModelThreadHandle>,
+    dispatcher: ModelThreadDispatcher,
     start_time: Instant,
     flush_storage: bool,
     public_key: PublicKey,
     storage_path: PathBuf,
-    request_receiver: Receiver<Req>,
-    response_sender: Sender<Resp>,
+    request_receiver: Receiver<serde_json::Value>,
+    response_sender: Sender<serde_json::Value>,
 }
 
-impl<Req, Resp> ModelService<Req, Resp>
-where
-    Req: Clone + Request,
-    Resp: std::fmt::Debug + Response,
-{
+impl ModelService {
     pub fn start<M, F>(
         model_config: ModelsConfig,
         private_key: PrivateKey,
-        request_receiver: Receiver<Req>,
-        response_sender: Sender<Resp>,
+        request_receiver: Receiver<serde_json::Value>,
+        response_sender: Sender<serde_json::Value>,
     ) -> Result<Self, ModelServiceError>
     where
-        M: ModelTrait + Send + 'static,
+        M: ModelTrait + Send + 'static,
         F: ApiTrait + Send + Sync + 'static,
     {
         let public_key = private_key.verification_key();
@@ -65,7 +57,7 @@ where
         })
     }
 
-    pub async fn run(&mut self) -> Result<Resp, ModelServiceError> {
+    pub async fn run(&mut self) -> Result<serde_json::Value, ModelServiceError> {
         loop {
             tokio::select! {
                 message = self.request_receiver.recv() => {
@@ -95,11 +87,7 @@ where
     }
 }
 
-impl<Req, Resp> ModelService<Req, Resp>
-where
-    Req: Request,
-    Resp: Response,
-{
+impl ModelService {
     pub async fn stop(mut self) {
         info!(
             "Stopping Inference Service, running time: {:?}",
@@ -158,7 +146,7 @@ mod tests {
     use std::io::Write;
 
     use toml::{toml, Value};
-    use crate::{models::types::PrecisionBits, models::ModelId};
+    use crate::models::{ModelId, Request, Response};
 
     use super::*;
 
@@ -208,6 +196,7 @@ mod tests {
         type Input = ();
         type Output = ();
         type Fetch = ();
+        type Load = ();
 
         fn fetch(_fetch: &Self::Fetch) -> Result<(), crate::models::ModelError> {
             Ok(())
@@ -215,7 +204,7 @@ mod tests {
 
         fn load(
             _: Vec<PathBuf>,
-            _: PrecisionBits,
+            _: Self::Load,
             _device_id: usize,
         ) -> Result<Self, crate::models::ModelError> {
             Ok(Self {})
@@ -251,12 +240,12 @@ mod tests {
         file.write_all(toml_string.as_bytes())
             .expect("Failed to write to file");
 
-        let (_, req_receiver) = tokio::sync::mpsc::channel::<()>(1);
-        let (resp_sender, _) = tokio::sync::mpsc::channel::<()>(1);
+        let (_, req_receiver) = tokio::sync::mpsc::channel::<serde_json::Value>(1);
+        let (resp_sender, _) = tokio::sync::mpsc::channel::<serde_json::Value>(1);
 
         let config = ModelsConfig::from_file_path(CONFIG_FILE_PATH.parse().unwrap());
 
-        let _ = ModelService::<(), ()>::start::(
+        let _ = ModelService::start::(
            config,
            private_key,
            req_receiver,