add tokens output stream generated tokens count
jorgeantonio21 committed Apr 8, 2024
1 parent 18a1b10 commit ed44986
Showing 4 changed files with 17 additions and 11 deletions.
8 changes: 4 additions & 4 deletions atoma-inference/src/models/candle/falcon.rs
@@ -158,7 +158,7 @@ impl ModelTrait for FalconModel {
         let mut output = String::new();

         let start_gen = Instant::now();
-        let mut tokens_generated = 0;
+        let mut generated_tokens = 0;
         for index in 0..max_tokens {
             let start_gen = Instant::now();
             let context_size = if self.model.config().use_cache && index > 0 {
@@ -182,13 +182,13 @@ impl ModelTrait for FalconModel {
             new_tokens.push(next_token);
             debug!("> {:?}", start_gen);
             output.push_str(&self.tokenizer.decode(&[next_token], true)?);
-            tokens_generated += 1;
+            generated_tokens += 1;
         }
         let dt = start_gen.elapsed();

         info!(
-            "{tokens_generated} tokens generated ({} token/s)\n----\n{}\n----",
-            tokens_generated as f64 / dt.as_secs_f64(),
+            "{generated_tokens} tokens generated ({} token/s)\n----\n{}\n----",
+            generated_tokens as f64 / dt.as_secs_f64(),
             self.tokenizer.decode(&new_tokens, true)?,
         );
8 changes: 4 additions & 4 deletions atoma-inference/src/models/candle/llama.rs
@@ -144,7 +144,7 @@ impl ModelTrait for LlamaModel {
         );
         let mut index_pos = 0;
         let mut res = String::new();
-        let mut tokens_generated = 0;
+        let mut generated_tokens = 0;

         let start_gen = Instant::now();
         for index in 0..input.max_tokens {
@@ -180,16 +180,16 @@
                 res += &t;
             }

-            tokens_generated += 1;
+            generated_tokens += 1;
         }
         if let Some(rest) = tokenizer.decode_rest()? {
             res += &rest;
         }

         let dt = start_gen.elapsed();
         info!(
-            "{tokens_generated} tokens generated ({} token/s)\n",
-            tokens_generated as f64 / dt.as_secs_f64(),
+            "{generated_tokens} tokens generated ({} token/s)\n",
+            generated_tokens as f64 / dt.as_secs_f64(),
         );

         Ok(res)
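The falcon.rs and llama.rs hunks are a pure rename, tokens_generated → generated_tokens, which aligns both generation loops with the naming used in mamba.rs below. Both loops share the same timing pattern; here is a minimal, self-contained sketch of it, where generate_step is a hypothetical stand-in for the model forward pass and sampling:

use std::time::Instant;

// Hypothetical stand-in for one forward pass + sampling step.
fn generate_step(index: usize) -> Option<u32> {
    if index < 16 { Some(index as u32) } else { None } // stop after 16 tokens
}

fn main() {
    let start_gen = Instant::now();
    let mut generated_tokens = 0usize;
    for index in 0..32 {
        let Some(_next_token) = generate_step(index) else { break };
        generated_tokens += 1;
    }
    let dt = start_gen.elapsed();
    // The same tokens-per-second computation the falcon/llama loops log via info!.
    println!(
        "{generated_tokens} tokens generated ({:.2} token/s)",
        generated_tokens as f64 / dt.as_secs_f64(),
    );
}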
8 changes: 5 additions & 3 deletions atoma-inference/src/models/candle/mamba.rs
@@ -148,9 +148,11 @@ impl ModelTrait for MambaModel {
             ..
         } = input;

+        // clean tokenizer state
+        self.tokenizer.clear();
+
         info!("Running inference on prompt: {:?}", prompt);

-        self.tokenizer.clear();
         let mut tokens = self
             .tokenizer
             .tokenizer()
@@ -160,7 +162,6 @@
         let mut logits_processor =
             LogitsProcessor::new(random_seed, Some(temperature), Some(top_p));

-        let mut generated_tokens = 0_usize;
         let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
             Some(token) => token,
             None => bail!("Invalid eos token"),
@@ -198,7 +199,6 @@

             let next_token = logits_processor.sample(&logits)?;
             tokens.push(next_token);
-            generated_tokens += 1;

             if next_token == eos_token {
                 break;
@@ -216,10 +216,12 @@
             output.push_str(rest.as_str());
         }

+        let generated_tokens = self.tokenizer.get_num_generated_tokens();
         info!(
             "\n{generated_tokens} tokens generated ({:.2} token/s)",
             generated_tokens as f64 / dt.as_secs_f64(),
         );

         Ok(output)
     }
 }
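The mamba.rs change is more than a rename: the hand-maintained generated_tokens counter is removed and the count is instead read off the token output stream after generation, which is why self.tokenizer.clear() moves ahead of the run. Clearing first is what keeps the derived count accurate; stale tokens from a previous call would otherwise inflate it. A sketch of the reordered flow under that assumption, with the stream state modeled as a bare Vec<u32> and run_sketch as a hypothetical reduction of MambaModel::run:

// The stream state is modeled as a bare Vec<u32>; the real type is TokenOutputStream.
fn run_sketch(stream_tokens: &mut Vec<u32>, sampled: &[u32]) -> usize {
    // 1. Clean tokenizer state first -- leftovers from a previous call
    //    would otherwise be counted below.
    stream_tokens.clear();

    // 2. Generation loop: every sampled token passes through the stream.
    for &token in sampled {
        stream_tokens.push(token);
    }

    // 3. The count is derived from the stream, replacing `generated_tokens += 1`.
    stream_tokens.len()
}

fn main() {
    let mut stream_tokens = vec![7, 8]; // stale tokens from an earlier run
    let n = run_sketch(&mut stream_tokens, &[10, 11, 12]);
    assert_eq!(n, 3); // only this run's tokens are counted
}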
4 changes: 4 additions & 0 deletions atoma-inference/src/models/token_output_stream.rs
@@ -78,6 +78,10 @@ impl TokenOutputStream {
         &self.tokenizer
     }

+    pub fn get_num_generated_tokens(&self) -> usize {
+        self.tokens.len()
+    }
+
     pub fn clear(&mut self) {
         self.tokens.clear();
         self.prev_index = 0;
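The new accessor is what backs the mamba.rs change above: the stream accumulates the tokens it decodes in self.tokens (as clear() suggests), so the generated-token count is simply the buffer length, and clear() resets it between runs. A minimal compilable slice of the type, keeping only the fields the diff shows (the real TokenOutputStream also wraps the tokenizer itself, elided here):

// Minimal slice of TokenOutputStream: a token buffer whose length doubles
// as the generated-token count, plus the clear() that resets it between runs.
pub struct TokenOutputStream {
    tokens: Vec<u32>,
    prev_index: usize,
}

impl TokenOutputStream {
    pub fn get_num_generated_tokens(&self) -> usize {
        self.tokens.len()
    }

    pub fn clear(&mut self) {
        self.tokens.clear();
        self.prev_index = 0;
    }
}

fn main() {
    let mut stream = TokenOutputStream { tokens: vec![1, 2], prev_index: 2 };
    stream.clear();
    assert_eq!(stream.prev_index, 0);
    stream.tokens.extend([3, 4, 5]); // stand-in for tokens pushed during decoding
    assert_eq!(stream.get_num_generated_tokens(), 3);
}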
