diff --git a/atoma-inference/src/models/candle/falcon.rs b/atoma-inference/src/models/candle/falcon.rs
index bbde8577..46f4ffb7 100644
--- a/atoma-inference/src/models/candle/falcon.rs
+++ b/atoma-inference/src/models/candle/falcon.rs
@@ -158,7 +158,7 @@ impl ModelTrait for FalconModel {
         let mut output = String::new();
 
         let start_gen = Instant::now();
-        let mut tokens_generated = 0;
+        let mut generated_tokens = 0;
         for index in 0..max_tokens {
             let start_gen = Instant::now();
             let context_size = if self.model.config().use_cache && index > 0 {
@@ -182,13 +182,13 @@ impl ModelTrait for FalconModel {
             new_tokens.push(next_token);
             debug!("> {:?}", start_gen);
             output.push_str(&self.tokenizer.decode(&[next_token], true)?);
-            tokens_generated += 1;
+            generated_tokens += 1;
         }
 
         let dt = start_gen.elapsed();
         info!(
-            "{tokens_generated} tokens generated ({} token/s)\n----\n{}\n----",
-            tokens_generated as f64 / dt.as_secs_f64(),
+            "{generated_tokens} tokens generated ({} token/s)\n----\n{}\n----",
+            generated_tokens as f64 / dt.as_secs_f64(),
             self.tokenizer.decode(&new_tokens, true)?,
         );
 
diff --git a/atoma-inference/src/models/candle/llama.rs b/atoma-inference/src/models/candle/llama.rs
index 4d2c0a90..eaed4304 100644
--- a/atoma-inference/src/models/candle/llama.rs
+++ b/atoma-inference/src/models/candle/llama.rs
@@ -144,7 +144,7 @@ impl ModelTrait for LlamaModel {
         );
         let mut index_pos = 0;
         let mut res = String::new();
-        let mut tokens_generated = 0;
+        let mut generated_tokens = 0;
         let start_gen = Instant::now();
 
         for index in 0..input.max_tokens {
@@ -180,7 +180,7 @@ impl ModelTrait for LlamaModel {
                 res += &t;
             }
 
-            tokens_generated += 1;
+            generated_tokens += 1;
         }
         if let Some(rest) = tokenizer.decode_rest()? {
             res += &rest;
@@ -188,8 +188,8 @@ impl ModelTrait for LlamaModel {
 
         let dt = start_gen.elapsed();
         info!(
-            "{tokens_generated} tokens generated ({} token/s)\n",
-            tokens_generated as f64 / dt.as_secs_f64(),
+            "{generated_tokens} tokens generated ({} token/s)\n",
+            generated_tokens as f64 / dt.as_secs_f64(),
         );
 
         Ok(res)
diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs
index c77c68f0..0b1213ff 100644
--- a/atoma-inference/src/models/candle/mamba.rs
+++ b/atoma-inference/src/models/candle/mamba.rs
@@ -148,9 +148,11 @@ impl ModelTrait for MambaModel {
             ..
         } = input;
 
+        // clean tokenizer state
+        self.tokenizer.clear();
+
         info!("Running inference on prompt: {:?}", prompt);
-        self.tokenizer.clear();
 
         let mut tokens = self
             .tokenizer
             .tokenizer()
@@ -160,7 +162,6 @@ impl ModelTrait for MambaModel {
         let mut logits_processor =
             LogitsProcessor::new(random_seed, Some(temperature), Some(top_p));
 
-        let mut generated_tokens = 0_usize;
         let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
             Some(token) => token,
             None => bail!("Invalid eos token"),
@@ -198,7 +199,6 @@ impl ModelTrait for MambaModel {
 
             let next_token = logits_processor.sample(&logits)?;
             tokens.push(next_token);
-            generated_tokens += 1;
 
             if next_token == eos_token {
                 break;
@@ -216,10 +216,12 @@ impl ModelTrait for MambaModel {
             output.push_str(rest.as_str());
         }
 
+        let generated_tokens = self.tokenizer.get_num_generated_tokens();
         info!(
             "\n{generated_tokens} tokens generated ({:.2} token/s)",
             generated_tokens as f64 / dt.as_secs_f64(),
         );
+
         Ok(output)
     }
 }
diff --git a/atoma-inference/src/models/token_output_stream.rs b/atoma-inference/src/models/token_output_stream.rs
index 33bfb27a..585a3f51 100644
--- a/atoma-inference/src/models/token_output_stream.rs
+++ b/atoma-inference/src/models/token_output_stream.rs
@@ -78,6 +78,10 @@ impl TokenOutputStream {
         &self.tokenizer
     }
 
+    pub fn get_num_generated_tokens(&self) -> usize {
+        self.tokens.len()
+    }
+
     pub fn clear(&mut self) {
         self.tokens.clear();
         self.prev_index = 0;
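
For reference, a minimal sketch (not part of the patch) of how the new TokenOutputStream::get_num_generated_tokens accessor can replace a hand-rolled counter in a generation loop, as mamba.rs now does; the report_throughput helper below is hypothetical:

    use std::time::Duration;

    // Hypothetical helper: reads the token count tracked by the stream itself,
    // so callers no longer need a separate `generated_tokens` variable.
    fn report_throughput(stream: &mut TokenOutputStream, dt: Duration) {
        let generated_tokens = stream.get_num_generated_tokens();
        println!(
            "{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        // Reset the stream before the next request, mirroring the clear() call
        // moved to the top of MambaModel::run above.
        stream.clear();
    }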