add logic to log tokens generated per second for a few models #31

Merged: 4 commits, Apr 8, 2024
72 changes: 36 additions & 36 deletions atoma-inference/src/main.rs
@@ -2,7 +2,7 @@ use std::time::Duration;

use ed25519_consensus::SigningKey as PrivateKey;
use inference::{
models::{config::ModelsConfig, types::StableDiffusionRequest},
models::{config::ModelsConfig, types::{StableDiffusionRequest, TextRequest}},
service::{ModelService, ModelServiceError},
};

@@ -33,45 +33,45 @@ async fn main() -> Result<(), ModelServiceError> {

tokio::time::sleep(Duration::from_millis(5_000)).await;

// req_sender
// .send(serde_json::to_value(TextRequest {
// request_id: 0,
// prompt: "Leon, the professional is a movie".to_string(),
// model: "llama_tiny_llama_1_1b_chat".to_string(),
// max_tokens: 512,
// temperature: Some(0.0),
// random_seed: 42,
// repeat_last_n: 64,
// repeat_penalty: 1.1,
// sampled_nodes: vec![pk],
// top_p: Some(1.0),
// _top_k: 10,
// }).unwrap())
// .await
// .expect("Failed to send request");

req_sender
.send(
serde_json::to_value(StableDiffusionRequest {
request_id: 0,
prompt: "A depiction of Natalie Portman".to_string(),
uncond_prompt: "".to_string(),
height: Some(256),
width: Some(256),
num_samples: 1,
n_steps: None,
model: "stable_diffusion_v1-5".to_string(),
guidance_scale: None,
img2img: None,
img2img_strength: 0.8,
random_seed: Some(42),
sampled_nodes: vec![pk],
})
.unwrap(),
)
.send(serde_json::to_value(TextRequest {
request_id: 0,
prompt: "Leon, the professional is a movie".to_string(),
model: "mamba_370m".to_string(),
max_tokens: 512,
temperature: Some(0.0),
random_seed: 42,
repeat_last_n: 64,
repeat_penalty: 1.1,
sampled_nodes: vec![pk],
top_p: Some(1.0),
_top_k: 10,
}).unwrap())
.await
.expect("Failed to send request");

// req_sender
// .send(
// serde_json::to_value(StableDiffusionRequest {
// request_id: 0,
// prompt: "A depiction of Natalie Portman".to_string(),
// uncond_prompt: "".to_string(),
// height: Some(256),
// width: Some(256),
// num_samples: 1,
// n_steps: None,
// model: "stable_diffusion_v1-5".to_string(),
// guidance_scale: None,
// img2img: None,
// img2img_strength: 0.8,
// random_seed: Some(42),
// sampled_nodes: vec![pk],
// })
// .unwrap(),
// )
// .await
// .expect("Failed to send request");

if let Some(response) = resp_receiver.recv().await {
println!("Got a response: {:?}", response);
}
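For context, the main.rs change only swaps the demo request from Stable Diffusion to a text model (`mamba_370m`) so the new throughput log is exercised. Below is a minimal, self-contained sketch of that request flow, assuming plain tokio mpsc channels carrying `serde_json::Value` as main.rs suggests; the `TextRequest` struct here is a stand-in for `inference::models::types::TextRequest`, with `sampled_nodes` and `_top_k` left out to keep it dependency-free.

```rust
use serde::Serialize;
use tokio::sync::mpsc;

// Stand-in for inference::models::types::TextRequest with the fields used above.
#[derive(Serialize)]
struct TextRequest {
    request_id: u64,
    prompt: String,
    model: String,
    max_tokens: usize,
    temperature: Option<f64>,
    random_seed: u64,
    repeat_last_n: usize,
    repeat_penalty: f32,
    top_p: Option<f64>,
}

#[tokio::main]
async fn main() {
    // The model service would hold the receiving end of this channel.
    let (req_sender, mut req_receiver) = mpsc::channel::<serde_json::Value>(8);

    let request = TextRequest {
        request_id: 0,
        prompt: "Leon, the professional is a movie".to_string(),
        model: "mamba_370m".to_string(),
        max_tokens: 512,
        temperature: Some(0.0),
        random_seed: 42,
        repeat_last_n: 64,
        repeat_penalty: 1.1,
        top_p: Some(1.0),
    };

    // Requests cross the channel as loosely typed JSON, as in main.rs.
    req_sender
        .send(serde_json::to_value(request).unwrap())
        .await
        .expect("Failed to send request");

    if let Some(value) = req_receiver.recv().await {
        println!("service received: {value}");
    }
}
```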
6 changes: 4 additions & 2 deletions atoma-inference/src/models/candle/falcon.rs
@@ -158,6 +158,7 @@ impl ModelTrait for FalconModel {
let mut output = String::new();

let start_gen = Instant::now();
let mut generated_tokens = 0;
for index in 0..max_tokens {
let start_gen = Instant::now();
let context_size = if self.model.config().use_cache && index > 0 {
@@ -181,12 +182,13 @@
new_tokens.push(next_token);
debug!("> {:?}", start_gen);
output.push_str(&self.tokenizer.decode(&[next_token], true)?);
generated_tokens += 1;
}
let dt = start_gen.elapsed();

info!(
"{max_tokens} tokens generated ({} token/s)\n----\n{}\n----",
max_tokens as f64 / dt.as_secs_f64(),
"{generated_tokens} tokens generated ({} token/s)\n----\n{}\n----",
generated_tokens as f64 / dt.as_secs_f64(),
self.tokenizer.decode(&new_tokens, true)?,
);

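The Falcon change is representative of the whole PR: time the generation loop once, count only the tokens that were actually produced, and log the ratio, rather than assuming `max_tokens` were generated. A minimal sketch of that pattern, with the model forward pass and sampler replaced by placeholders (`sample_next_token`, a fixed EOS id, and a short sleep standing in for real work):

```rust
use std::thread;
use std::time::{Duration, Instant};

// Stand-in for logits_processor.sample(&logits): returns an EOS token (0) after 100 steps.
fn sample_next_token(step: usize) -> u32 {
    if step < 100 {
        (step % 1000) as u32 + 1
    } else {
        0
    }
}

fn main() {
    let max_tokens = 512usize;
    let eos_token = 0u32;

    let start_gen = Instant::now();
    let mut generated_tokens = 0usize;

    for index in 0..max_tokens {
        // Stand-in for the model forward pass.
        thread::sleep(Duration::from_millis(2));
        let next_token = sample_next_token(index);
        if next_token == eos_token {
            // Generation can stop early, so the real count may be below max_tokens.
            break;
        }
        generated_tokens += 1;
    }

    let dt = start_gen.elapsed();
    println!(
        "{generated_tokens} tokens generated ({:.2} token/s)",
        generated_tokens as f64 / dt.as_secs_f64(),
    );
}
```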
12 changes: 12 additions & 0 deletions atoma-inference/src/models/candle/llama.rs
@@ -144,6 +144,9 @@ impl ModelTrait for LlamaModel {
);
let mut index_pos = 0;
let mut res = String::new();
let mut generated_tokens = 0;

let start_gen = Instant::now();
for index in 0..input.max_tokens {
let (context_size, context_index) = if self.cache.use_kv_cache && index > 0 {
(1, index_pos)
@@ -176,10 +179,19 @@
if let Some(t) = tokenizer.next_token(next_token)? {
res += &t;
}

generated_tokens += 1;
}
if let Some(rest) = tokenizer.decode_rest()? {
res += &rest;
}

let dt = start_gen.elapsed();
info!(
"{generated_tokens} tokens generated ({} token/s)\n",
generated_tokens as f64 / dt.as_secs_f64(),
);

Ok(res)
}
}
8 changes: 5 additions & 3 deletions atoma-inference/src/models/candle/mamba.rs
@@ -148,9 +148,11 @@ impl ModelTrait for MambaModel {
..
} = input;

// clean tokenizer state
self.tokenizer.clear();

info!("Running inference on prompt: {:?}", prompt);

self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
@@ -160,7 +162,6 @@
let mut logits_processor =
LogitsProcessor::new(random_seed, Some(temperature), Some(top_p));

let mut generated_tokens = 0_usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => bail!("Invalid eos token"),
@@ -198,7 +199,6 @@

let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;

if next_token == eos_token {
break;
@@ -216,10 +216,12 @@
output.push_str(rest.as_str());
}

let generated_tokens = self.tokenizer.get_num_generated_tokens();
info!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);

Ok(output)
}
}
4 changes: 4 additions & 0 deletions atoma-inference/src/models/token_output_stream.rs
@@ -78,6 +78,10 @@ impl TokenOutputStream {
&self.tokenizer
}

pub fn get_num_generated_tokens(&self) -> usize {
self.tokens.len()
}

pub fn clear(&mut self) {
self.tokens.clear();
self.prev_index = 0;
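The helper added here is what the Mamba path now uses in place of its old manual counter. A minimal stand-in (the real `TokenOutputStream` also tracks decode indices and wraps a `tokenizers::Tokenizer`) showing the idea: the stream already records every sampled token, so the generated-token count can simply be read back from it.

```rust
// Hypothetical, simplified TokenOutputStream-like type for illustration only.
struct TokenStream {
    tokens: Vec<u32>,
}

impl TokenStream {
    fn new() -> Self {
        Self { tokens: Vec::new() }
    }

    // Called once per sampled token during generation.
    fn push_token(&mut self, token: u32) {
        self.tokens.push(token);
    }

    // Mirrors TokenOutputStream::get_num_generated_tokens from the diff.
    fn get_num_generated_tokens(&self) -> usize {
        self.tokens.len()
    }

    // Mirrors clear(): reset state between requests, as mamba.rs does.
    fn clear(&mut self) {
        self.tokens.clear();
    }
}

fn main() {
    let mut stream = TokenStream::new();
    stream.clear();
    for token in [17u32, 42, 7] {
        stream.push_token(token);
    }
    println!("{} tokens generated", stream.get_num_generated_tokens());
}
```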