Commit 18a1b10

first commit
jorgeantonio21 committed Apr 8, 2024
1 parent a58d50b commit 18a1b10
Showing 3 changed files with 52 additions and 38 deletions.
72 changes: 36 additions & 36 deletions atoma-inference/src/main.rs
@@ -2,7 +2,7 @@ use std::time::Duration;
 
 use ed25519_consensus::SigningKey as PrivateKey;
 use inference::{
-    models::{config::ModelsConfig, types::StableDiffusionRequest},
+    models::{config::ModelsConfig, types::{StableDiffusionRequest, TextRequest}},
     service::{ModelService, ModelServiceError},
 };
 
@@ -33,45 +33,45 @@ async fn main() -> Result<(), ModelServiceError> {
 
     tokio::time::sleep(Duration::from_millis(5_000)).await;
 
-    // req_sender
-    //     .send(serde_json::to_value(TextRequest {
-    //         request_id: 0,
-    //         prompt: "Leon, the professional is a movie".to_string(),
-    //         model: "llama_tiny_llama_1_1b_chat".to_string(),
-    //         max_tokens: 512,
-    //         temperature: Some(0.0),
-    //         random_seed: 42,
-    //         repeat_last_n: 64,
-    //         repeat_penalty: 1.1,
-    //         sampled_nodes: vec![pk],
-    //         top_p: Some(1.0),
-    //         _top_k: 10,
-    //     }).unwrap())
-    //     .await
-    //     .expect("Failed to send request");
-
     req_sender
-        .send(
-            serde_json::to_value(StableDiffusionRequest {
-                request_id: 0,
-                prompt: "A depiction of Natalie Portman".to_string(),
-                uncond_prompt: "".to_string(),
-                height: Some(256),
-                width: Some(256),
-                num_samples: 1,
-                n_steps: None,
-                model: "stable_diffusion_v1-5".to_string(),
-                guidance_scale: None,
-                img2img: None,
-                img2img_strength: 0.8,
-                random_seed: Some(42),
-                sampled_nodes: vec![pk],
-            })
-            .unwrap(),
-        )
+        .send(serde_json::to_value(TextRequest {
+            request_id: 0,
+            prompt: "Leon, the professional is a movie".to_string(),
+            model: "mamba_370m".to_string(),
+            max_tokens: 512,
+            temperature: Some(0.0),
+            random_seed: 42,
+            repeat_last_n: 64,
+            repeat_penalty: 1.1,
+            sampled_nodes: vec![pk],
+            top_p: Some(1.0),
+            _top_k: 10,
+        }).unwrap())
         .await
         .expect("Failed to send request");
 
+    // req_sender
+    //     .send(
+    //         serde_json::to_value(StableDiffusionRequest {
+    //             request_id: 0,
+    //             prompt: "A depiction of Natalie Portman".to_string(),
+    //             uncond_prompt: "".to_string(),
+    //             height: Some(256),
+    //             width: Some(256),
+    //             num_samples: 1,
+    //             n_steps: None,
+    //             model: "stable_diffusion_v1-5".to_string(),
+    //             guidance_scale: None,
+    //             img2img: None,
+    //             img2img_strength: 0.8,
+    //             random_seed: Some(42),
+    //             sampled_nodes: vec![pk],
+    //         })
+    //         .unwrap(),
+    //     )
+    //     .await
+    //     .expect("Failed to send request");
+
     if let Some(response) = resp_receiver.recv().await {
         println!("Got a response: {:?}", response);
     }
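Note: the harness above serializes each request to a serde_json::Value and sends it to the model service over a channel. Below is a minimal, self-contained sketch of that flow; the struct fields are copied from the diff, but the field types, the channel capacity, and the mpsc wiring are assumptions rather than the crate's actual definitions.

```rust
use serde::Serialize;
use tokio::sync::mpsc;

// Hypothetical reconstruction of the request type; field names come from the
// diff above, but the field types are guesses and may not match the real crate.
#[derive(Serialize)]
struct TextRequest {
    request_id: u64,
    prompt: String,
    model: String,
    max_tokens: usize,
    temperature: Option<f64>,
    random_seed: u64,
    repeat_last_n: usize,
    repeat_penalty: f32,
    sampled_nodes: Vec<Vec<u8>>, // stand-in for the ed25519 public-key type
    top_p: Option<f64>,
    _top_k: usize,
}

#[tokio::main]
async fn main() {
    // Requests travel as untyped JSON values, as in main.rs above.
    let (req_sender, mut req_receiver) = mpsc::channel::<serde_json::Value>(32);

    let request = TextRequest {
        request_id: 0,
        prompt: "Leon, the professional is a movie".to_string(),
        model: "mamba_370m".to_string(),
        max_tokens: 512,
        temperature: Some(0.0),
        random_seed: 42,
        repeat_last_n: 64,
        repeat_penalty: 1.1,
        sampled_nodes: vec![], // the real code passes vec![pk]
        top_p: Some(1.0),
        _top_k: 10,
    };

    req_sender
        .send(serde_json::to_value(request).unwrap())
        .await
        .expect("Failed to send request");

    // The service side would deserialize the value and dispatch it to the
    // model named in the `model` field.
    if let Some(value) = req_receiver.recv().await {
        println!("Received request: {value}");
    }
}
```

Sending untyped Value keeps the channel agnostic to the request kind, which is presumably why TextRequest and StableDiffusionRequest can share the same req_sender.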
6 changes: 4 additions & 2 deletions atoma-inference/src/models/candle/falcon.rs
@@ -158,6 +158,7 @@ impl ModelTrait for FalconModel {
         let mut output = String::new();
 
         let start_gen = Instant::now();
+        let mut tokens_generated = 0;
         for index in 0..max_tokens {
             let start_gen = Instant::now();
             let context_size = if self.model.config().use_cache && index > 0 {
@@ -181,12 +182,13 @@
             new_tokens.push(next_token);
             debug!("> {:?}", start_gen);
             output.push_str(&self.tokenizer.decode(&[next_token], true)?);
+            tokens_generated += 1;
         }
         let dt = start_gen.elapsed();
 
         info!(
-            "{max_tokens} tokens generated ({} token/s)\n----\n{}\n----",
-            max_tokens as f64 / dt.as_secs_f64(),
+            "{tokens_generated} tokens generated ({} token/s)\n----\n{}\n----",
+            tokens_generated as f64 / dt.as_secs_f64(),
             self.tokenizer.decode(&new_tokens, true)?,
         );
 
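Note: this change replaces max_tokens with a count of tokens actually produced when computing the token/s figure. In the hunk shown the loop always runs to max_tokens, so the two numbers coincide today, but the counter keeps the metric honest if an early exit (e.g., on an end-of-sequence token) is ever added. A minimal sketch of the pattern, with a placeholder closure standing in for the model's forward pass and sampling:

```rust
use std::time::Instant;

// Sketch of the counting pattern from falcon.rs: time the loop and count
// emitted tokens rather than assuming `max_tokens` were produced.
// `generate_next` is a hypothetical stand-in for the model step.
fn generate(max_tokens: usize, mut generate_next: impl FnMut() -> Option<u32>) -> usize {
    let start_gen = Instant::now();
    let mut tokens_generated = 0;

    for _index in 0..max_tokens {
        match generate_next() {
            Some(_token) => tokens_generated += 1,
            None => break, // e.g., an end-of-sequence token
        }
    }

    let dt = start_gen.elapsed();
    println!(
        "{tokens_generated} tokens generated ({:.2} token/s)",
        tokens_generated as f64 / dt.as_secs_f64()
    );
    tokens_generated
}

fn main() {
    // Pretend the model stops after 5 tokens even though 512 were allowed:
    // the reported count is 5, not 512.
    let mut left = 5;
    generate(512, move || {
        if left == 0 {
            None
        } else {
            left -= 1;
            Some(0)
        }
    });
}
```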
12 changes: 12 additions & 0 deletions atoma-inference/src/models/candle/llama.rs
@@ -144,6 +144,9 @@ impl ModelTrait for LlamaModel {
         );
         let mut index_pos = 0;
         let mut res = String::new();
+        let mut tokens_generated = 0;
+
+        let start_gen = Instant::now();
         for index in 0..input.max_tokens {
             let (context_size, context_index) = if self.cache.use_kv_cache && index > 0 {
                 (1, index_pos)
@@ -176,10 +179,19 @@
             if let Some(t) = tokenizer.next_token(next_token)? {
                 res += &t;
             }
+
+            tokens_generated += 1;
         }
         if let Some(rest) = tokenizer.decode_rest()? {
             res += &rest;
         }
+
+        let dt = start_gen.elapsed();
+        info!(
+            "{tokens_generated} tokens generated ({} token/s)\n",
+            tokens_generated as f64 / dt.as_secs_f64(),
+        );
+
         Ok(res)
     }
 }
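Note: llama.rs gains the same throughput metric, with the timer started before the loop and the rate logged only after decode_rest() flushes the decoder's tail. The streaming decoder appears to be candle's TokenOutputStream; the toy below is an assumption-labeled stand-in that shows why the tail flush matters: a streaming decoder can withhold text until enough bytes arrive to form complete output.

```rust
use std::time::Instant;

// Toy stand-in for the streaming token decoder (NOT the real candle type).
// It holds back alternating pieces, the way a real streaming decoder holds
// bytes that do not yet form a complete UTF-8 character.
struct ToyTokenStream {
    buffered: String,
}

impl ToyTokenStream {
    fn next_token(&mut self, piece: &str) -> Option<String> {
        if self.buffered.is_empty() {
            self.buffered.push_str(piece);
            None
        } else {
            let mut out = std::mem::take(&mut self.buffered);
            out.push_str(piece);
            Some(out)
        }
    }

    fn decode_rest(&mut self) -> Option<String> {
        if self.buffered.is_empty() {
            None
        } else {
            Some(std::mem::take(&mut self.buffered))
        }
    }
}

fn main() {
    let pieces = ["Hel", "lo", ", ", "wor", "ld"];
    let mut stream = ToyTokenStream { buffered: String::new() };
    let mut res = String::new();
    let mut tokens_generated = 0;

    let start_gen = Instant::now();
    for piece in pieces {
        if let Some(t) = stream.next_token(piece) {
            res += &t;
        }
        tokens_generated += 1;
    }
    // Without this flush, the trailing "ld" would be lost.
    if let Some(rest) = stream.decode_rest() {
        res += &rest;
    }

    let dt = start_gen.elapsed();
    println!(
        "{tokens_generated} tokens generated ({:.2} token/s): {res}",
        tokens_generated as f64 / dt.as_secs_f64()
    );
}
```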
