From ce7c9a1cc626763f7478266c9fb8f294a983a1fc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jorge=20Ant=C3=B3nio?=
Date: Mon, 8 Apr 2024 12:51:57 +0100
Subject: [PATCH] add logic to log tokens generated per second for a few
 models (#31)

* first commit

* add generated tokens count to token output stream

* fmt

* clippy
---
 atoma-inference/src/main.rs                   | 63 ++++++++++---------
 atoma-inference/src/models/candle/falcon.rs   |  6 +-
 atoma-inference/src/models/candle/llama.rs    | 12 ++++
 atoma-inference/src/models/candle/mamba.rs    |  8 ++-
 .../src/models/token_output_stream.rs         |  4 ++
 5 files changed, 58 insertions(+), 35 deletions(-)

diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs
index 2c99ba93..41c37701 100644
--- a/atoma-inference/src/main.rs
+++ b/atoma-inference/src/main.rs
@@ -2,7 +2,7 @@ use std::time::Duration;
 
 use ed25519_consensus::SigningKey as PrivateKey;
 use inference::{
-    models::{config::ModelsConfig, types::StableDiffusionRequest},
+    models::{config::ModelsConfig, types::TextRequest},
     service::{ModelService, ModelServiceError},
 };
 
@@ -33,45 +33,48 @@ async fn main() -> Result<(), ModelServiceError> {
 
     tokio::time::sleep(Duration::from_millis(5_000)).await;
 
-    // req_sender
-    //     .send(serde_json::to_value(TextRequest {
-    //         request_id: 0,
-    //         prompt: "Leon, the professional is a movie".to_string(),
-    //         model: "llama_tiny_llama_1_1b_chat".to_string(),
-    //         max_tokens: 512,
-    //         temperature: Some(0.0),
-    //         random_seed: 42,
-    //         repeat_last_n: 64,
-    //         repeat_penalty: 1.1,
-    //         sampled_nodes: vec![pk],
-    //         top_p: Some(1.0),
-    //         _top_k: 10,
-    //     }).unwrap())
-    //     .await
-    //     .expect("Failed to send request");
-
     req_sender
         .send(
-            serde_json::to_value(StableDiffusionRequest {
+            serde_json::to_value(TextRequest {
                 request_id: 0,
-                prompt: "A depiction of Natalie Portman".to_string(),
-                uncond_prompt: "".to_string(),
-                height: Some(256),
-                width: Some(256),
-                num_samples: 1,
-                n_steps: None,
-                model: "stable_diffusion_v1-5".to_string(),
-                guidance_scale: None,
-                img2img: None,
-                img2img_strength: 0.8,
-                random_seed: Some(42),
+                prompt: "Leon, the professional is a movie".to_string(),
+                model: "mamba_370m".to_string(),
+                max_tokens: 512,
+                temperature: Some(0.0),
+                random_seed: 42,
+                repeat_last_n: 64,
+                repeat_penalty: 1.1,
                 sampled_nodes: vec![pk],
+                top_p: Some(1.0),
+                _top_k: 10,
             })
             .unwrap(),
         )
         .await
         .expect("Failed to send request");
 
+    // req_sender
+    //     .send(
+    //         serde_json::to_value(StableDiffusionRequest {
+    //             request_id: 0,
+    //             prompt: "A depiction of Natalie Portman".to_string(),
+    //             uncond_prompt: "".to_string(),
+    //             height: Some(256),
+    //             width: Some(256),
+    //             num_samples: 1,
+    //             n_steps: None,
+    //             model: "stable_diffusion_v1-5".to_string(),
+    //             guidance_scale: None,
+    //             img2img: None,
+    //             img2img_strength: 0.8,
+    //             random_seed: Some(42),
+    //             sampled_nodes: vec![pk],
+    //         })
+    //         .unwrap(),
+    //     )
+    //     .await
+    //     .expect("Failed to send request");
+
     if let Some(response) = resp_receiver.recv().await {
         println!("Got a response: {:?}", response);
     }
diff --git a/atoma-inference/src/models/candle/falcon.rs b/atoma-inference/src/models/candle/falcon.rs
index 12c5164d..46f4ffb7 100644
--- a/atoma-inference/src/models/candle/falcon.rs
+++ b/atoma-inference/src/models/candle/falcon.rs
@@ -158,6 +158,7 @@ impl ModelTrait for FalconModel {
         let mut output = String::new();
 
         let start_gen = Instant::now();
+        let mut generated_tokens = 0;
         for index in 0..max_tokens {
            let start_gen = Instant::now();
            let context_size = if self.model.config().use_cache && index > 0 {
@@ -181,12 +182,13 @@ impl ModelTrait for FalconModel {
             new_tokens.push(next_token);
             debug!("> {:?}", start_gen);
             output.push_str(&self.tokenizer.decode(&[next_token], true)?);
+            generated_tokens += 1;
         }
 
         let dt = start_gen.elapsed();
         info!(
-            "{max_tokens} tokens generated ({} token/s)\n----\n{}\n----",
-            max_tokens as f64 / dt.as_secs_f64(),
+            "{generated_tokens} tokens generated ({} token/s)\n----\n{}\n----",
+            generated_tokens as f64 / dt.as_secs_f64(),
             self.tokenizer.decode(&new_tokens, true)?,
         );
 
diff --git a/atoma-inference/src/models/candle/llama.rs b/atoma-inference/src/models/candle/llama.rs
index 7e4a9026..eaed4304 100644
--- a/atoma-inference/src/models/candle/llama.rs
+++ b/atoma-inference/src/models/candle/llama.rs
@@ -144,6 +144,9 @@ impl ModelTrait for LlamaModel {
         );
         let mut index_pos = 0;
         let mut res = String::new();
+        let mut generated_tokens = 0;
+
+        let start_gen = Instant::now();
         for index in 0..input.max_tokens {
             let (context_size, context_index) = if self.cache.use_kv_cache && index > 0 {
                 (1, index_pos)
@@ -176,10 +179,19 @@ impl ModelTrait for LlamaModel {
             if let Some(t) = tokenizer.next_token(next_token)? {
                 res += &t;
             }
+
+            generated_tokens += 1;
         }
         if let Some(rest) = tokenizer.decode_rest()? {
             res += &rest;
         }
+
+        let dt = start_gen.elapsed();
+        info!(
+            "{generated_tokens} tokens generated ({} token/s)\n",
+            generated_tokens as f64 / dt.as_secs_f64(),
+        );
+
         Ok(res)
     }
 }
diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs
index c77c68f0..0b1213ff 100644
--- a/atoma-inference/src/models/candle/mamba.rs
+++ b/atoma-inference/src/models/candle/mamba.rs
@@ -148,9 +148,11 @@ impl ModelTrait for MambaModel {
             ..
         } = input;
 
+        // clean tokenizer state
+        self.tokenizer.clear();
+
         info!("Running inference on prompt: {:?}", prompt);
 
-        self.tokenizer.clear();
         let mut tokens = self
             .tokenizer
             .tokenizer()
@@ -160,7 +162,6 @@ impl ModelTrait for MambaModel {
         let mut logits_processor =
             LogitsProcessor::new(random_seed, Some(temperature), Some(top_p));
-        let mut generated_tokens = 0_usize;
         let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
             Some(token) => token,
             None => bail!("Invalid eos token"),
         };
@@ -198,7 +199,6 @@ impl ModelTrait for MambaModel {
 
             let next_token = logits_processor.sample(&logits)?;
             tokens.push(next_token);
-            generated_tokens += 1;
             if next_token == eos_token {
                 break;
             }
@@ -216,10 +216,12 @@ impl ModelTrait for MambaModel {
             output.push_str(rest.as_str());
         }
 
+        let generated_tokens = self.tokenizer.get_num_generated_tokens();
         info!(
             "\n{generated_tokens} tokens generated ({:.2} token/s)",
             generated_tokens as f64 / dt.as_secs_f64(),
         );
+
         Ok(output)
     }
 }
diff --git a/atoma-inference/src/models/token_output_stream.rs b/atoma-inference/src/models/token_output_stream.rs
index 33bfb27a..585a3f51 100644
--- a/atoma-inference/src/models/token_output_stream.rs
+++ b/atoma-inference/src/models/token_output_stream.rs
@@ -78,6 +78,10 @@ impl TokenOutputStream {
         &self.tokenizer
     }
 
+    pub fn get_num_generated_tokens(&self) -> usize {
+        self.tokens.len()
+    }
+
     pub fn clear(&mut self) {
         self.tokens.clear();
         self.prev_index = 0;
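
A note on the throughput pattern these hunks share: each model now counts the tokens it actually emits and divides that count by the wall-clock time of the generation loop, rather than assuming `max_tokens` were produced. Falcon and Llama keep a local `generated_tokens` counter, while Mamba reads the count back from `TokenOutputStream::get_num_generated_tokens()` (which simply returns `self.tokens.len()`). The sketch below is a minimal, self-contained illustration of that timing pattern only; `produce_next_token` is a hypothetical stand-in for a model's forward pass and sampling step, not an API from this repository.

    use std::time::Instant;

    /// Hypothetical stand-in for one forward pass + sampling step.
    /// Returns `None` when an end-of-sequence token is produced.
    fn produce_next_token(index: usize) -> Option<u32> {
        if index < 100 {
            Some(index as u32)
        } else {
            None
        }
    }

    fn main() {
        let max_tokens: usize = 512;
        let start_gen = Instant::now();
        let mut generated_tokens = 0usize;

        for index in 0..max_tokens {
            match produce_next_token(index) {
                Some(_token) => generated_tokens += 1,
                None => break, // stop early on end-of-sequence, like the Mamba loop
            }
        }

        let dt = start_gen.elapsed();
        // Same shape as the `info!` lines added in this patch.
        println!(
            "{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64()
        );
    }

With this shape, a run that stops early at an end-of-sequence token (as the Mamba loop does) still reports the number of tokens actually emitted rather than the requested `max_tokens`, which is also why the Falcon log switches from `{max_tokens}` to `{generated_tokens}`.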