Commit

add logic to log tokens generated per second for a few models (#31)
* first commit

* add generated-tokens count to the token output stream

* fmt

* clippy
jorgeantonio21 authored Apr 8, 2024
1 parent a58d50b commit ce7c9a1
Showing 5 changed files with 58 additions and 35 deletions.
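
The change shared by the Falcon, Llama and Mamba backends below: record a start time before the generation loop, count every token actually emitted, and log the count divided by elapsed seconds once the loop ends. A minimal, self-contained sketch of that pattern follows; next_token, decode, and is_eos are illustrative stand-ins, not the repository's API.

use std::thread;
use std::time::{Duration, Instant};

// Stand-ins for a model's sampling and decoding steps; purely illustrative,
// not the repository's API.
fn next_token(i: usize) -> u32 {
    i as u32
}

fn decode(token: u32) -> String {
    format!("tok{token} ")
}

fn is_eos(token: u32) -> bool {
    token == 6 // pretend the end-of-sequence token arrives early
}

fn main() {
    let max_tokens = 16;
    let mut output = String::new();
    let mut generated_tokens = 0_usize;

    let start_gen = Instant::now();
    for i in 0..max_tokens {
        let token = next_token(i);
        thread::sleep(Duration::from_millis(10)); // pretend sampling takes time
        output.push_str(&decode(token));
        generated_tokens += 1;
        if is_eos(token) {
            break; // generation can stop before max_tokens
        }
    }
    let dt = start_gen.elapsed();

    // Same shape as the info! lines added in this commit.
    println!(
        "{generated_tokens} tokens generated ({:.2} token/s)",
        generated_tokens as f64 / dt.as_secs_f64()
    );
}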
63 changes: 33 additions & 30 deletions atoma-inference/src/main.rs
@@ -2,7 +2,7 @@ use std::time::Duration;

use ed25519_consensus::SigningKey as PrivateKey;
use inference::{
models::{config::ModelsConfig, types::StableDiffusionRequest},
models::{config::ModelsConfig, types::TextRequest},
service::{ModelService, ModelServiceError},
};

@@ -33,45 +33,48 @@ async fn main() -> Result<(), ModelServiceError> {

tokio::time::sleep(Duration::from_millis(5_000)).await;

// req_sender
// .send(serde_json::to_value(TextRequest {
// request_id: 0,
// prompt: "Leon, the professional is a movie".to_string(),
// model: "llama_tiny_llama_1_1b_chat".to_string(),
// max_tokens: 512,
// temperature: Some(0.0),
// random_seed: 42,
// repeat_last_n: 64,
// repeat_penalty: 1.1,
// sampled_nodes: vec![pk],
// top_p: Some(1.0),
// _top_k: 10,
// }).unwrap())
// .await
// .expect("Failed to send request");

req_sender
.send(
serde_json::to_value(StableDiffusionRequest {
serde_json::to_value(TextRequest {
request_id: 0,
prompt: "A depiction of Natalie Portman".to_string(),
uncond_prompt: "".to_string(),
height: Some(256),
width: Some(256),
num_samples: 1,
n_steps: None,
model: "stable_diffusion_v1-5".to_string(),
guidance_scale: None,
img2img: None,
img2img_strength: 0.8,
random_seed: Some(42),
prompt: "Leon, the professional is a movie".to_string(),
model: "mamba_370m".to_string(),
max_tokens: 512,
temperature: Some(0.0),
random_seed: 42,
repeat_last_n: 64,
repeat_penalty: 1.1,
sampled_nodes: vec![pk],
top_p: Some(1.0),
_top_k: 10,
})
.unwrap(),
)
.await
.expect("Failed to send request");

// req_sender
// .send(
// serde_json::to_value(StableDiffusionRequest {
// request_id: 0,
// prompt: "A depiction of Natalie Portman".to_string(),
// uncond_prompt: "".to_string(),
// height: Some(256),
// width: Some(256),
// num_samples: 1,
// n_steps: None,
// model: "stable_diffusion_v1-5".to_string(),
// guidance_scale: None,
// img2img: None,
// img2img_strength: 0.8,
// random_seed: Some(42),
// sampled_nodes: vec![pk],
// })
// .unwrap(),
// )
// .await
// .expect("Failed to send request");

if let Some(response) = resp_receiver.recv().await {
println!("Got a response: {:?}", response);
}
6 changes: 4 additions & 2 deletions atoma-inference/src/models/candle/falcon.rs
@@ -158,6 +158,7 @@ impl ModelTrait for FalconModel {
let mut output = String::new();

let start_gen = Instant::now();
let mut generated_tokens = 0;
for index in 0..max_tokens {
let start_gen = Instant::now();
let context_size = if self.model.config().use_cache && index > 0 {
@@ -181,12 +182,13 @@
new_tokens.push(next_token);
debug!("> {:?}", start_gen);
output.push_str(&self.tokenizer.decode(&[next_token], true)?);
generated_tokens += 1;
}
let dt = start_gen.elapsed();

info!(
"{max_tokens} tokens generated ({} token/s)\n----\n{}\n----",
max_tokens as f64 / dt.as_secs_f64(),
"{generated_tokens} tokens generated ({} token/s)\n----\n{}\n----",
generated_tokens as f64 / dt.as_secs_f64(),
self.tokenizer.decode(&new_tokens, true)?,
);

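The Falcon log line now reports the number of tokens actually produced rather than the requested max_tokens; the two differ whenever generation stops before the budget is exhausted (the Mamba loop later in this commit breaks on its end-of-sequence token). A small sketch with hypothetical numbers shows how far the old figure could over-report:

// Hypothetical numbers, only to illustrate the change above: generation that
// stops early might produce 300 of 512 requested tokens in 4.0 seconds.
fn main() {
    let max_tokens = 512_usize;
    let generated_tokens = 300_usize;
    let elapsed_secs = 4.0_f64;

    // Before: the rate was computed from the requested token budget.
    let old_rate = max_tokens as f64 / elapsed_secs; // 128.0 token/s, over-reported
    // After: the rate reflects the tokens actually produced.
    let new_rate = generated_tokens as f64 / elapsed_secs; // 75.0 token/s

    println!("old: {old_rate:.1} token/s, new: {new_rate:.1} token/s");
}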
12 changes: 12 additions & 0 deletions atoma-inference/src/models/candle/llama.rs
@@ -144,6 +144,9 @@ impl ModelTrait for LlamaModel {
);
let mut index_pos = 0;
let mut res = String::new();
let mut generated_tokens = 0;

let start_gen = Instant::now();
for index in 0..input.max_tokens {
let (context_size, context_index) = if self.cache.use_kv_cache && index > 0 {
(1, index_pos)
@@ -176,10 +179,19 @@ impl ModelTrait for LlamaModel {
if let Some(t) = tokenizer.next_token(next_token)? {
res += &t;
}

generated_tokens += 1;
}
if let Some(rest) = tokenizer.decode_rest()? {
res += &rest;
}

let dt = start_gen.elapsed();
info!(
"{generated_tokens} tokens generated ({} token/s)\n",
generated_tokens as f64 / dt.as_secs_f64(),
);

Ok(res)
}
}
8 changes: 5 additions & 3 deletions atoma-inference/src/models/candle/mamba.rs
@@ -148,9 +148,11 @@ impl ModelTrait for MambaModel {
..
} = input;

// clean tokenizer state
self.tokenizer.clear();

info!("Running inference on prompt: {:?}", prompt);

self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
@@ -160,7 +162,6 @@
let mut logits_processor =
LogitsProcessor::new(random_seed, Some(temperature), Some(top_p));

let mut generated_tokens = 0_usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => bail!("Invalid eos token"),
@@ -198,7 +199,6 @@

let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;

if next_token == eos_token {
break;
@@ -216,10 +216,12 @@
output.push_str(rest.as_str());
}

let generated_tokens = self.tokenizer.get_num_generated_tokens();
info!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);

Ok(output)
}
}
4 changes: 4 additions & 0 deletions atoma-inference/src/models/token_output_stream.rs
@@ -78,6 +78,10 @@ impl TokenOutputStream {
&self.tokenizer
}

pub fn get_num_generated_tokens(&self) -> usize {
self.tokens.len()
}

pub fn clear(&mut self) {
self.tokens.clear();
self.prev_index = 0;
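The new accessor simply exposes the length of the stream's accumulated token buffer; this is what mamba.rs above now reads instead of keeping a separate generated_tokens counter. A standalone usage sketch follows, with a simplified stand-in type rather than the real TokenOutputStream:

// Minimal stand-in for the accessor added above; the real TokenOutputStream
// wraps a tokenizers::Tokenizer plus decoding state, so this is only a sketch.
struct TokenStream {
    tokens: Vec<u32>,
}

impl TokenStream {
    fn push(&mut self, token: u32) {
        self.tokens.push(token);
    }

    // Mirrors TokenOutputStream::get_num_generated_tokens.
    fn get_num_generated_tokens(&self) -> usize {
        self.tokens.len()
    }

    fn clear(&mut self) {
        self.tokens.clear();
    }
}

fn main() {
    let mut stream = TokenStream { tokens: Vec::new() };
    stream.clear(); // reset state before a new prompt, as mamba.rs now does up front
    for token in [11_u32, 42, 7] {
        stream.push(token);
    }
    // Read the count from the stream itself instead of keeping a separate counter.
    println!("{} tokens generated", stream.get_num_generated_tokens());
}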
