Merge branch 'main' into 20240407-testing
jorgeantonio21 committed Apr 10, 2024
2 parents ccf1ad0 + 8410601 commit 8b0ee5a
Showing 14 changed files with 146 additions and 112 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -10,7 +10,7 @@ name: CI
merge_group:

env:
toolchain: nightly-2024-03-06
toolchain: nightly-2024-04-05
CARGO_HTTP_MULTIPLEXING: false
CARGO_TERM_COLOR: always
CARGO_UNSTABLE_SPARSE_REGISTRY: true
1 change: 1 addition & 0 deletions Cargo.toml
@@ -15,6 +15,7 @@ edition = "2021"

[workspace.dependencies]
async-trait = "0.1.78"
axum = "0.7.5"
candle = { git = "https://github.com/huggingface/candle", package = "candle-core", version = "0.5.0" }
candle-flash-attn = { git = "https://github.com/huggingface/candle", package = "candle-flash-attn", version = "0.5.0" }
candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", version = "0.5.0" }
1 change: 1 addition & 0 deletions atoma-inference/Cargo.toml
@@ -4,6 +4,7 @@ version.workspace = true
edition = "2021"

[dependencies]
axum.workspace = true
async-trait.workspace = true
candle.workspace = true
candle-flash-attn = { workspace = true, optional = true }
75 changes: 75 additions & 0 deletions atoma-inference/src/jrpc_server/mod.rs
@@ -0,0 +1,75 @@
use std::{net::Shutdown, sync::Arc};

use axum::{extract::State, http::StatusCode, routing::post, Extension, Json, Router};
use serde_json::{json, Value};
use tokio::sync::{mpsc, oneshot};

pub type RequestSender = mpsc::Sender<(Value, oneshot::Sender<Value>)>;

pub async fn run(sender: RequestSender) {
let (shutdown_signal_sender, mut shutdown_signal_receiver) = mpsc::channel::<()>(1);
let app = Router::new()
.route("/", post(jrpc_call))
.route("/healthz", post(healthz))
.layer(Extension(Arc::new(sender)))
.layer(Extension(Arc::new(shutdown_signal_sender)));

let listener = tokio::net::TcpListener::bind("0.0.0.0:21212")
.await
.unwrap();
axum::serve(listener, app)
.with_graceful_shutdown(async move { shutdown_signal_receiver.recv().await.unwrap() })
.await
.unwrap();
}

async fn healthz() -> Json<Value> {
Json(json!({
"status": "ok"
}))
}

async fn jrpc_call(
Extension(sender): Extension<Arc<RequestSender>>,
Extension(shutdown): Extension<Arc<mpsc::Sender<()>>>,
Json(input): Json<Value>,
) -> Json<Value> {
match inner_jrpc_call(sender, input, shutdown).await {
Ok(response) => Json(json!({
"result":response
})),
Err(err) => Json(json!({
"jsonrpc": "2.0",
"id": 1,
"error": {
"code": StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
"message": err
}
})),
}
}

async fn inner_jrpc_call(
sender: Arc<RequestSender>,
input: Value,
shutdown: Arc<mpsc::Sender<()>>,
) -> Result<Value, String> {
match input.get("request") {
Some(request) => {
let (one_sender, one_receiver) = oneshot::channel();
sender
.send((request.clone(), one_sender))
.await
.map_err(|e| e.to_string())?;
if let Ok(response) = one_receiver.await {
Ok(response)
} else {
Err("The request failed".to_string())
}
}
None => {
shutdown.send(()).await.unwrap();
Ok(serde_json::Value::Null)
}
}
}
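
For reference, a minimal sketch (not part of this commit) of the two body shapes the handler above distinguishes. The field names inside "request" are illustrative only; the dispatcher in model_thread.rs only requires a "model" key, and any body without a "request" key takes the graceful-shutdown path instead of being forwarded:

use serde_json::json;

fn main() {
    // Forwarded to the model threads and answered as {"result": ...};
    // routed on the "model" field by ModelThreadDispatcher.
    let inference = json!({
        "request": {
            "model": "llama_tiny_llama_1_1b_chat",
            "prompt": "Leon, the professional is a movie"
        }
    });
    // No "request" key: jrpc_call sends on the shutdown channel and the
    // axum server exits via with_graceful_shutdown.
    let shutdown = json!({});
    println!("{inference}\n{shutdown}");
}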
1 change: 1 addition & 0 deletions atoma-inference/src/lib.rs
@@ -1,4 +1,5 @@
pub mod apis;
pub mod jrpc_server;
pub mod model_thread;
pub mod models;
pub mod service;
55 changes: 4 additions & 51 deletions atoma-inference/src/main.rs
@@ -1,8 +1,7 @@
use std::time::Duration;

use ed25519_consensus::SigningKey as PrivateKey;
use inference::{
models::{config::ModelsConfig, types::StableDiffusionRequest},
jrpc_server,
models::config::ModelsConfig,
service::{ModelService, ModelServiceError},
};

@@ -11,7 +10,6 @@ async fn main() -> Result<(), ModelServiceError> {
tracing_subscriber::fmt::init();

let (req_sender, req_receiver) = tokio::sync::mpsc::channel(32);
let (resp_sender, mut resp_receiver) = tokio::sync::mpsc::channel(32);

let model_config = ModelsConfig::from_file_path("../inference.toml".parse().unwrap());
let private_key_bytes =
@@ -21,60 +19,15 @@
.expect("Incorrect private key bytes length");

let private_key = PrivateKey::from(private_key_bytes);
let mut service = ModelService::start(model_config, private_key, req_receiver, resp_sender)
let mut service = ModelService::start(model_config, private_key, req_receiver)
.expect("Failed to start inference service");

let pk = service.public_key();

tokio::spawn(async move {
service.run().await?;
Ok::<(), ModelServiceError>(())
});

tokio::time::sleep(Duration::from_millis(5_000)).await;

// req_sender
// .send(serde_json::to_value(TextRequest {
// request_id: 0,
// prompt: "Leon, the professional is a movie".to_string(),
// model: "llama_tiny_llama_1_1b_chat".to_string(),
// max_tokens: 512,
// temperature: Some(0.0),
// random_seed: 42,
// repeat_last_n: 64,
// repeat_penalty: 1.1,
// sampled_nodes: vec![pk],
// top_p: Some(1.0),
// _top_k: 10,
// }).unwrap())
// .await
// .expect("Failed to send request");

req_sender
.send(
serde_json::to_value(StableDiffusionRequest {
request_id: 0,
prompt: "A depiction of Natalie Portman".to_string(),
uncond_prompt: "".to_string(),
height: Some(256),
width: Some(256),
num_samples: 1,
n_steps: None,
model: "stable_diffusion_v1-5".to_string(),
guidance_scale: None,
img2img: None,
img2img_strength: 0.8,
random_seed: Some(42),
sampled_nodes: vec![pk],
})
.unwrap(),
)
.await
.expect("Failed to send request");

if let Some(response) = resp_receiver.recv().await {
println!("Got a response: {:?}", response);
}
jrpc_server::run(req_sender).await;

Ok(())
}
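
The StableDiffusionRequest that used to be pushed onto the channel directly now has to arrive through the JSON-RPC endpoint. A hedged sketch of the equivalent call, assuming an HTTP client such as reqwest with its json feature (not a dependency of this crate); the field values are copied from the removed block above, with request_id and sampled_nodes omitted for brevity:

use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let payload = json!({
        "request": {
            "model": "stable_diffusion_v1-5",
            "prompt": "A depiction of Natalie Portman",
            "uncond_prompt": "",
            "height": 256,
            "width": 256,
            "num_samples": 1,
            "img2img_strength": 0.8,
            "random_seed": 42
        }
    });
    // POST to the address the server binds in jrpc_server::run.
    let response: serde_json::Value = reqwest::Client::new()
        .post("http://0.0.0.0:21212/")
        .json(&payload)
        .send()
        .await?
        .json()
        .await?;
    println!("Got a response: {response:?}");
    Ok(())
}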
17 changes: 6 additions & 11 deletions atoma-inference/src/model_thread.rs
@@ -3,7 +3,6 @@ use std::{
};

use ed25519_consensus::VerificationKey as PublicKey;
use futures::stream::FuturesUnordered;
use thiserror::Error;
use tokio::sync::oneshot::{self, error::RecvError};
use tracing::{debug, error, info, warn};
@@ -84,7 +83,6 @@
// error!("Current node, with verification key = {:?} is not authorized to run request with id = {}", public_key, request.request_id());
// continue;
// }

let model_input = serde_json::from_value(request)?;
let model_output = self.model.run(model_input)?;
let response = serde_json::to_value(model_output)?;
@@ -97,7 +95,6 @@

pub struct ModelThreadDispatcher {
model_senders: HashMap<ModelId, mpsc::Sender<ModelThreadCommand>>,
pub(crate) responses: FuturesUnordered<oneshot::Receiver<serde_json::Value>>,
}

impl ModelThreadDispatcher {
@@ -136,10 +133,7 @@ impl ModelThreadDispatcher {
});
}

let model_dispatcher = ModelThreadDispatcher {
model_senders,
responses: FuturesUnordered::new(),
};
let model_dispatcher = ModelThreadDispatcher { model_senders };

Ok((model_dispatcher, handles))
}
@@ -149,7 +143,7 @@ impl ModelThreadDispatcher {
let model_id = if let Some(model_id) = request.get("model") {
model_id.as_str().unwrap().to_string()
} else {
error!("Request malformed: Missing model_id from request");
error!("Request malformed: Missing 'model' from request");
return;
};

@@ -167,13 +161,14 @@
}

impl ModelThreadDispatcher {
pub(crate) fn run_inference(&self, request: serde_json::Value) {
let (sender, receiver) = oneshot::channel();
pub(crate) fn run_inference(
&self,
(request, sender): (serde_json::Value, oneshot::Sender<serde_json::Value>),
) {
self.send(ModelThreadCommand {
request,
response_sender: sender,
});
self.responses.push(receiver);
}
}
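
The loop that drives this new signature is not shown in the diff; a hedged sketch of how the service side is presumably expected to pump the channel. Each (request, response-sender) pair produced by the JSON-RPC handler is handed to run_inference unchanged, so the oneshot sender now travels with the request instead of being created here and parked on the removed responses queue:

use serde_json::{json, Value};
use tokio::sync::{mpsc, oneshot};

// Stand-in for ModelThreadDispatcher::run_inference, for illustration only.
fn run_inference((request, sender): (Value, oneshot::Sender<Value>)) {
    let _ = sender.send(json!({ "echo": request }));
}

#[tokio::main]
async fn main() {
    let (req_sender, mut req_receiver) = mpsc::channel::<(Value, oneshot::Sender<Value>)>(32);

    // In the real binary req_sender is handed to jrpc_server::run; here one
    // request is pushed by hand so the loop below has something to forward.
    let (tx, rx) = oneshot::channel();
    req_sender.send((json!({ "model": "demo" }), tx)).await.unwrap();
    drop(req_sender);

    // Each pair is forwarded as-is: the oneshot sender travels with the request.
    while let Some(pair) = req_receiver.recv().await {
        run_inference(pair);
    }
    println!("{:?}", rx.await.unwrap());
}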

6 changes: 4 additions & 2 deletions atoma-inference/src/models/candle/falcon.rs
@@ -158,6 +158,7 @@ impl ModelTrait for FalconModel {
let mut output = String::new();

let start_gen = Instant::now();
let mut generated_tokens = 0;
for index in 0..max_tokens {
let start_gen = Instant::now();
let context_size = if self.model.config().use_cache && index > 0 {
@@ -181,12 +182,13 @@
new_tokens.push(next_token);
debug!("> {:?}", start_gen);
output.push_str(&self.tokenizer.decode(&[next_token], true)?);
generated_tokens += 1;
}
let dt = start_gen.elapsed();

info!(
"{max_tokens} tokens generated ({} token/s)\n----\n{}\n----",
max_tokens as f64 / dt.as_secs_f64(),
"{generated_tokens} tokens generated ({} token/s)\n----\n{}\n----",
generated_tokens as f64 / dt.as_secs_f64(),
self.tokenizer.decode(&new_tokens, true)?,
);

12 changes: 12 additions & 0 deletions atoma-inference/src/models/candle/llama.rs
@@ -144,6 +144,9 @@ impl ModelTrait for LlamaModel {
);
let mut index_pos = 0;
let mut res = String::new();
let mut generated_tokens = 0;

let start_gen = Instant::now();
for index in 0..input.max_tokens {
let (context_size, context_index) = if self.cache.use_kv_cache && index > 0 {
(1, index_pos)
@@ -176,10 +179,19 @@
if let Some(t) = tokenizer.next_token(next_token)? {
res += &t;
}

generated_tokens += 1;
}
if let Some(rest) = tokenizer.decode_rest()? {
res += &rest;
}

let dt = start_gen.elapsed();
info!(
"{generated_tokens} tokens generated ({} token/s)\n",
generated_tokens as f64 / dt.as_secs_f64(),
);

Ok(res)
}
}
8 changes: 5 additions & 3 deletions atoma-inference/src/models/candle/mamba.rs
@@ -148,9 +148,11 @@ impl ModelTrait for MambaModel {
..
} = input;

// clean tokenizer state
self.tokenizer.clear();

info!("Running inference on prompt: {:?}", prompt);

self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
@@ -160,7 +162,6 @@
let mut logits_processor =
LogitsProcessor::new(random_seed, Some(temperature), Some(top_p));

let mut generated_tokens = 0_usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => bail!("Invalid eos token"),
@@ -198,7 +199,6 @@

let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;

if next_token == eos_token {
break;
@@ -216,10 +216,12 @@
output.push_str(rest.as_str());
}

let generated_tokens = self.tokenizer.get_num_generated_tokens();
info!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);

Ok(output)
}
}
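
A hedged sketch of the counting pattern this change adopts, using a stand-in type rather than the real TokenOutputStream (which wraps the actual tokenizer): clear() empties the stream's token buffer before a new prompt, each next_token() call grows it, and get_num_generated_tokens() (added in token_output_stream.rs below) simply reads the buffer length back for the tokens/s log:

// Minimal stand-in for the relevant part of TokenOutputStream.
struct TokenStream {
    tokens: Vec<u32>,
}

impl TokenStream {
    fn clear(&mut self) {
        self.tokens.clear();
    }
    fn next_token(&mut self, token: u32) {
        self.tokens.push(token);
    }
    fn get_num_generated_tokens(&self) -> usize {
        self.tokens.len()
    }
}

fn main() {
    let mut stream = TokenStream { tokens: Vec::new() };
    stream.clear(); // reset before each new prompt, as mamba.rs now does
    for token in [7_u32, 8, 9] {
        stream.next_token(token); // one call per token pushed into the stream
    }
    // The count that feeds the "{n} tokens generated" log line.
    assert_eq!(stream.get_num_generated_tokens(), 3);
}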
7 changes: 6 additions & 1 deletion atoma-inference/src/models/candle/stable_diffusion.rs
@@ -1,4 +1,4 @@
use std::{path::PathBuf, str::FromStr};
use std::{path::PathBuf, str::FromStr, time::Instant};

use candle_transformers::models::stable_diffusion::{
self, clip::ClipTextTransformer, unet_2d::UNet2DConditionModel, vae::AutoEncoderKL,
@@ -243,6 +243,8 @@ impl ModelTrait for StableDiffusion {
)))?
}

let start_gen = Instant::now();

let height = input.height.unwrap_or(512);
let width = input.width.unwrap_or(512);

@@ -380,11 +382,14 @@ impl ModelTrait for StableDiffusion {
debug!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);
}

let dt = start_gen.elapsed();
info!("Generated response in {:?}", dt);
debug!(
"Generating the final image for sample {}/{}.",
idx + 1,
input.num_samples
);

save_tensor_to_file(&latents, "tensor1")?;
let image = self.vae.decode(&(&latents / vae_scale)?)?;
save_tensor_to_file(&image, "tensor2")?;
4 changes: 4 additions & 0 deletions atoma-inference/src/models/token_output_stream.rs
@@ -78,6 +78,10 @@ impl TokenOutputStream {
&self.tokenizer
}

pub fn get_num_generated_tokens(&self) -> usize {
self.tokens.len()
}

pub fn clear(&mut self) {
self.tokens.clear();
self.prev_index = 0;