diff --git a/atoma-inference/src/models/candle/falcon.rs b/atoma-inference/src/models/candle/falcon.rs
index a7deb45c..b124c520 100644
--- a/atoma-inference/src/models/candle/falcon.rs
+++ b/atoma-inference/src/models/candle/falcon.rs
@@ -14,7 +14,7 @@ use tracing::{debug, error, info};
 use crate::models::{
     candle::hub_load_safetensors,
     config::ModelConfig,
-    types::{LlmLoadData, ModelType, TextModelInput},
+    types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
     ModelError, ModelTrait,
 };
@@ -48,7 +48,7 @@ impl FalconModel {
 impl ModelTrait for FalconModel {
     type Input = TextModelInput;
-    type Output = String;
+    type Output = TextModelOutput;
     type LoadData = LlmLoadData;
 
     fn fetch(
@@ -192,7 +192,11 @@ impl ModelTrait for FalconModel {
             self.tokenizer.decode(&new_tokens, true)?,
         );
 
-        Ok(output)
+        Ok(TextModelOutput {
+            output,
+            time: dt.as_secs_f64(),
+            tokens_count: generated_tokens,
+        })
     }
 }
diff --git a/atoma-inference/src/models/candle/llama.rs b/atoma-inference/src/models/candle/llama.rs
index 0a6f803f..ad7407de 100644
--- a/atoma-inference/src/models/candle/llama.rs
+++ b/atoma-inference/src/models/candle/llama.rs
@@ -15,7 +15,7 @@ use tracing::info;
 use crate::models::{
     config::ModelConfig,
     token_output_stream::TokenOutputStream,
-    types::{LlmLoadData, ModelType, TextModelInput},
+    types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
     ModelError, ModelTrait,
 };
@@ -44,7 +44,7 @@ pub struct LlamaModel {
 impl ModelTrait for LlamaModel {
     type Input = TextModelInput;
-    type Output = String;
+    type Output = TextModelOutput;
     type LoadData = LlmLoadData;
 
     fn fetch(
@@ -192,7 +192,11 @@ impl ModelTrait for LlamaModel {
             generated_tokens as f64 / dt.as_secs_f64(),
         );
 
-        Ok(res)
+        Ok(TextModelOutput {
+            output: res,
+            time: dt.as_secs_f64(),
+            tokens_count: generated_tokens,
+        })
     }
 }
diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs
index cb2836d7..3c7918bd 100644
--- a/atoma-inference/src/models/candle/mamba.rs
+++ b/atoma-inference/src/models/candle/mamba.rs
@@ -17,7 +17,7 @@ use crate::{
         candle::device,
         config::ModelConfig,
         token_output_stream::TokenOutputStream,
-        types::{LlmLoadData, ModelType, TextModelInput},
+        types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
         ModelError, ModelTrait,
     },
 };
@@ -53,7 +53,7 @@ impl MambaModel {
 impl ModelTrait for MambaModel {
     type Input = TextModelInput;
-    type Output = String;
+    type Output = TextModelOutput;
     type LoadData = LlmLoadData;
 
     fn fetch(
@@ -222,7 +222,11 @@ impl ModelTrait for MambaModel {
             generated_tokens as f64 / dt.as_secs_f64(),
         );
 
-        Ok(output)
+        Ok(TextModelOutput {
+            output,
+            time: dt.as_secs_f64(),
+            tokens_count: generated_tokens,
+        })
     }
 }
diff --git a/atoma-inference/src/models/types.rs b/atoma-inference/src/models/types.rs
index 7191aa27..e60a65e2 100644
--- a/atoma-inference/src/models/types.rs
+++ b/atoma-inference/src/models/types.rs
@@ -225,6 +225,13 @@ impl TextModelInput {
     }
 }
 
+#[derive(Serialize)]
+pub struct TextModelOutput {
+    pub output: String,
+    pub time: f64,
+    pub tokens_count: usize,
+}
+
 #[derive(Clone, Debug, Deserialize, Serialize)]
 pub struct TextResponse {
     pub output: String,
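
For context, the new `TextModelOutput` type bundles the generated text with the timing and token-count metadata that each model previously only surfaced via its `info!` logging. A minimal sketch of how a caller might consume it (the values and the `serde_json` round-trip are illustrative assumptions, not code from this PR):

```rust
use serde::Serialize;

// Mirrors the struct added in atoma-inference/src/models/types.rs.
#[derive(Serialize)]
pub struct TextModelOutput {
    pub output: String,       // the generated text
    pub time: f64,            // inference wall time, in seconds
    pub tokens_count: usize,  // number of generated tokens
}

fn main() {
    // Hypothetical values standing in for a real generation run.
    let out = TextModelOutput {
        output: "Hello, world!".to_string(),
        time: 1.25,
        tokens_count: 50,
    };

    // Throughput can now be derived by any caller, rather than being
    // visible only inside each model's `run` implementation.
    println!("{:.2} tokens/s", out.tokens_count as f64 / out.time);

    // `#[derive(Serialize)]` lets the output go straight into a JSON payload.
    println!("{}", serde_json::to_string(&out).unwrap());
}
```

Since all three models (`FalconModel`, `LlamaModel`, `MambaModel`) fill the struct the same way from `dt.as_secs_f64()` and `generated_tokens`, downstream consumers get a uniform output shape regardless of which backend produced the text.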