Skip to content

Commit

Permalink
feat: add info to text output
Browse files Browse the repository at this point in the history
  • Loading branch information
Cifko committed Apr 10, 2024
1 parent 1fa6475 commit a9817bb
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 9 deletions.
10 changes: 7 additions & 3 deletions atoma-inference/src/models/candle/falcon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use tracing::{debug, error, info};
use crate::models::{
candle::hub_load_safetensors,
config::ModelConfig,
types::{LlmLoadData, ModelType, TextModelInput},
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
ModelError, ModelTrait,
};

Expand Down Expand Up @@ -48,7 +48,7 @@ impl FalconModel {

impl ModelTrait for FalconModel {
type Input = TextModelInput;
type Output = String;
type Output = TextModelOutput;
type LoadData = LlmLoadData;

fn fetch(
Expand Down Expand Up @@ -192,7 +192,11 @@ impl ModelTrait for FalconModel {
self.tokenizer.decode(&new_tokens, true)?,
);

Ok(output)
Ok(TextModelOutput {
output,
time: dt.as_secs_f64(),
tokens_count: generated_tokens,
})
}
}

Expand Down
10 changes: 7 additions & 3 deletions atoma-inference/src/models/candle/llama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use tracing::info;
use crate::models::{
config::ModelConfig,
token_output_stream::TokenOutputStream,
types::{LlmLoadData, ModelType, TextModelInput},
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
ModelError, ModelTrait,
};

Expand Down Expand Up @@ -44,7 +44,7 @@ pub struct LlamaModel {

impl ModelTrait for LlamaModel {
type Input = TextModelInput;
type Output = String;
type Output = TextModelOutput;
type LoadData = LlmLoadData;

fn fetch(
Expand Down Expand Up @@ -192,7 +192,11 @@ impl ModelTrait for LlamaModel {
generated_tokens as f64 / dt.as_secs_f64(),
);

Ok(res)
Ok(TextModelOutput {
output: res,
time: dt.as_secs_f64(),
tokens_count: generated_tokens,
})
}
}

Expand Down
10 changes: 7 additions & 3 deletions atoma-inference/src/models/candle/mamba.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::{
candle::device,
config::ModelConfig,
token_output_stream::TokenOutputStream,
types::{LlmLoadData, ModelType, TextModelInput},
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
ModelError, ModelTrait,
},
};
Expand Down Expand Up @@ -53,7 +53,7 @@ impl MambaModel {

impl ModelTrait for MambaModel {
type Input = TextModelInput;
type Output = String;
type Output = TextModelOutput;
type LoadData = LlmLoadData;

fn fetch(
Expand Down Expand Up @@ -222,7 +222,11 @@ impl ModelTrait for MambaModel {
generated_tokens as f64 / dt.as_secs_f64(),
);

Ok(output)
Ok(TextModelOutput {
output,
time: dt.as_secs_f64(),
tokens_count: generated_tokens,
})
}
}

Expand Down
7 changes: 7 additions & 0 deletions atoma-inference/src/models/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,13 @@ impl TextModelInput {
}
}

/// Result of a text-generation run, bundling the generated text with basic
/// run statistics.
///
/// Returned as the `Output` of the text models' `ModelTrait` implementations
/// (Falcon, Llama, Mamba) in place of a bare `String`.
#[derive(Clone, Debug, Serialize)]
pub struct TextModelOutput {
    /// The generated text.
    pub output: String,
    /// Elapsed time in seconds; the models populate this from
    /// `dt.as_secs_f64()`.
    // NOTE(review): exactly which phases `dt` spans (generation only vs.
    // prompt processing as well) is decided by each model — confirm there.
    pub time: f64,
    /// Number of tokens generated during the run (the models'
    /// `generated_tokens` counter).
    pub tokens_count: usize,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TextResponse {
pub output: String,
Expand Down

0 comments on commit a9817bb

Please sign in to comment.