address new PR comments
jorgeantonio21 committed Mar 27, 2024
1 parent 54e0abd commit e60b586
Showing 6 changed files with 45 additions and 26 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
@@ -13,6 +13,7 @@ candle-nn = { git = "https://github.com/jorgeantonio21/candle/", package = "cand
 candle-transformers = { git = "https://github.com/jorgeantonio21/candle/", package = "candle-transformers", branch = "ja-send-sync-sd-scheduler" }
 config = "0.14.0"
 ed25519-consensus = "2.1.0"
+futures = "0.3.30"
 hf-hub = "0.3.2"
 serde = "1.0.197"
 serde_json = "1.0.114"
1 change: 1 addition & 0 deletions atoma-inference/Cargo.toml
@@ -12,6 +12,7 @@ candle-nn.workspace = true
 candle-transformers.workspace = true
 config.workspace = true
 ed25519-consensus.workspace = true
+futures.workspace = true
 hf-hub.workspace = true
 reqwest = { workspace = true, features = ["json"] }
 serde = { workspace = true, features = ["derive"] }
14 changes: 7 additions & 7 deletions atoma-inference/src/main.rs
@@ -5,14 +5,14 @@ use inference::service::InferenceService
 async fn main() {
     tracing_subscriber::fmt::init();
 
-    let (_, receiver) = tokio::sync::mpsc::channel(32);
+    // let (_, receiver) = tokio::sync::mpsc::channel(32);
 
-    let _ = InferenceService::start::<Model>(
-        "../inference.toml".parse().unwrap(),
-        "../private_key".parse().unwrap(),
-        receiver,
-    )
-    .expect("Failed to start inference service");
+    // let _ = InferenceService::start::<Model>(
+    //     "../inference.toml".parse().unwrap(),
+    //     "../private_key".parse().unwrap(),
+    //     receiver,
+    // )
+    // .expect("Failed to start inference service");
 
     // inference_service
     //     .run_inference(InferenceRequest {
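With the old wiring commented out, main.rs no longer starts the service. For context, here is a minimal sketch of how it could drive the new `run` loop once re-enabled; it assumes the `Model` type, its import, and the file paths from the commented-out code plus the `run` method added to `InferenceService` in this commit, and is not part of the commit itself.

```rust
use inference::service::InferenceService;

#[tokio::main]
async fn main() {
    tracing_subscriber::fmt::init();

    // Keep the sender half; anything pushed on it is picked up by `run` below.
    let (request_sender, request_receiver) = tokio::sync::mpsc::channel(32);

    // `Model` and the relative paths are assumptions carried over from the
    // commented-out version of this file.
    let mut service = InferenceService::start::<Model>(
        "../inference.toml".parse().unwrap(),
        "../private_key".parse().unwrap(),
        request_receiver,
    )
    .expect("Failed to start inference service");

    // Hand the sender to whatever produces InferenceRequest values
    // (an RPC front end, a test harness, ...).
    let _request_sender = request_sender;

    // Drive the dispatch loop until it returns an error.
    if let Err(e) = service.run().await {
        tracing::error!("inference service stopped: {e}");
    }
}
```
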
15 changes: 8 additions & 7 deletions atoma-inference/src/model_thread.rs
@@ -2,6 +2,7 @@ use std::collections::HashMap
 
 use candle_nn::VarBuilder;
 use ed25519_consensus::VerificationKey as PublicKey;
+use futures::stream::FuturesUnordered;
 use thiserror::Error;
 use tokio::sync::oneshot::{self, error::RecvError};
 use tracing::{debug, error, warn};
@@ -87,9 +88,9 @@
     }
 }
 
-#[derive(Clone)]
 pub struct ModelThreadDispatcher {
     model_senders: HashMap<ModelType, std::sync::mpsc::Sender<ModelThreadCommand>>,
+    pub(crate) responses: FuturesUnordered<oneshot::Receiver<InferenceResponse>>,
 }
 
 impl ModelThreadDispatcher {
@@ -125,7 +126,10 @@ impl ModelThreadDispatcher {
             model_senders.insert(model_type, model_sender);
         }
 
-        let model_dispatcher = ModelThreadDispatcher { model_senders };
+        let model_dispatcher = ModelThreadDispatcher {
+            model_senders,
+            responses: FuturesUnordered::new(),
+        };
 
         Ok((model_dispatcher, handles))
     }
@@ -146,12 +150,9 @@
 }
 
 impl ModelThreadDispatcher {
-    pub(crate) async fn run_inference(
-        &self,
-        request: InferenceRequest,
-    ) -> Result<InferenceResponse, ModelThreadError> {
+    pub(crate) fn run_inference(&self, request: InferenceRequest) {
         let (sender, receiver) = oneshot::channel();
         self.send(ModelThreadCommand(request, sender));
-        receiver.await.map_err(ModelThreadError::Shutdown)
+        self.responses.push(receiver);
     }
 }
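The dispatcher change above turns `run_inference` into a fire-and-forget submit: the `oneshot::Receiver` is parked in a `FuturesUnordered` set and harvested later in completion order. A self-contained sketch of that pattern, with plain `String` responses standing in for the crate's `InferenceResponse`:

```rust
use futures::stream::{FuturesUnordered, StreamExt};
use tokio::sync::oneshot;

#[tokio::main]
async fn main() {
    // Pending responses; polling the set yields results in completion order,
    // not in the order the requests were submitted.
    let mut responses: FuturesUnordered<oneshot::Receiver<String>> = FuturesUnordered::new();

    for i in 0..3u32 {
        let (sender, receiver) = oneshot::channel();
        // Non-async "submit", like the new run_inference: just park the receiver.
        responses.push(receiver);
        // Stand-in for a model thread answering on its own schedule.
        tokio::spawn(async move {
            let _ = sender.send(format!("response to request {i}"));
        });
    }

    while let Some(result) = responses.next().await {
        match result {
            Ok(response) => println!("{response}"),
            Err(e) => eprintln!("model thread dropped its sender: {e}"),
        }
    }
}
```

Because `FuturesUnordered::push` takes `&self`, the dispatcher can queue a receiver through a shared reference, which is why `run_inference` no longer needs to be `async` or return a `Result`.
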
39 changes: 27 additions & 12 deletions atoma-inference/src/service.rs
@@ -2,11 +2,12 @@ use candle::{Device, Error as CandleError}
 use candle_nn::var_builder::VarBuilder;
 use candle_transformers::models::llama::Cache as LlamaCache;
 use ed25519_consensus::SigningKey as PrivateKey;
+use futures::StreamExt;
 use hf_hub::api::sync::Api;
 use std::{io, path::PathBuf, time::Instant};
 use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{error::SendError, Receiver};
-use tracing::info;
+use tracing::{error, info};
 
 use thiserror::Error;
 
@@ -22,14 +23,14 @@ pub struct InferenceService {
     model_thread_handle: Vec<ModelThreadHandle>,
     dispatcher: ModelThreadDispatcher,
     start_time: Instant,
-    _request_receiver: Receiver<InferenceRequest>,
+    request_receiver: Receiver<InferenceRequest>,
 }
 
 impl InferenceService {
     pub fn start<T>(
         config_file_path: PathBuf,
         private_key_path: PathBuf,
-        _request_receiver: Receiver<InferenceRequest>,
+        request_receiver: Receiver<InferenceRequest>,
     ) -> Result<Self, InferenceServiceError>
     where
         T: ModelApi + Send + 'static,
@@ -120,18 +121,32 @@ impl InferenceService {
             dispatcher,
             model_thread_handle,
             start_time,
-            _request_receiver,
+            request_receiver,
         })
     }
 
-    pub async fn run_inference(
-        &self,
-        inference_request: InferenceRequest,
-    ) -> Result<InferenceResponse, InferenceServiceError> {
-        self.dispatcher
-            .run_inference(inference_request)
-            .await
-            .map_err(InferenceServiceError::ModelThreadError)
+    pub async fn run(&mut self) -> Result<InferenceResponse, InferenceServiceError> {
+        loop {
+            tokio::select! {
+                message = self.request_receiver.recv() => {
+                    if let Some(request) = message {
+                        self.dispatcher.run_inference(request);
+                    }
+                }
+                response = self.dispatcher.responses.next() => {
+                    if let Some(resp) = response {
+                        match resp {
+                            Ok(response) => {
+                                info!("Received a new inference response: {:?}", response);
+                            }
+                            Err(e) => {
+                                error!("Found error in generating inference response: {e}");
+                            }
+                        }
+                    }
+                }
+            }
+        }
    }
 }
 
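The new `run` method is an event loop multiplexing two sources: requests arriving on the mpsc channel and responses completing in the dispatcher's `FuturesUnordered` set. Below is a self-contained sketch of that loop shape with stand-in `String` messages; the `if` guard on the response branch is an addition of this sketch, not of the commit, and simply keeps an empty set (which resolves to `None` immediately) from waking the loop.

```rust
use futures::stream::{FuturesUnordered, StreamExt};
use tokio::sync::{mpsc, oneshot};

#[tokio::main]
async fn main() {
    let (request_tx, mut request_rx) = mpsc::channel::<String>(32);
    let mut responses: FuturesUnordered<oneshot::Receiver<String>> = FuturesUnordered::new();

    // Pretend client: sends a couple of requests, then hangs up.
    tokio::spawn(async move {
        for prompt in ["hello", "world"] {
            request_tx.send(prompt.to_string()).await.unwrap();
        }
    });

    let mut accepting = true;
    while accepting || !responses.is_empty() {
        tokio::select! {
            maybe_request = request_rx.recv(), if accepting => {
                match maybe_request {
                    Some(prompt) => {
                        // Stand-in for dispatching to a model thread.
                        let (sender, receiver) = oneshot::channel();
                        responses.push(receiver);
                        tokio::spawn(async move {
                            let _ = sender.send(format!("echo: {prompt}"));
                        });
                    }
                    None => accepting = false, // all request senders dropped
                }
            }
            // An empty FuturesUnordered resolves to None right away, so the
            // guard keeps this branch quiet until there is something to await.
            Some(result) = responses.next(), if !responses.is_empty() => {
                match result {
                    Ok(response) => println!("received: {response}"),
                    Err(e) => eprintln!("model thread dropped its sender: {e}"),
                }
            }
        }
    }
}
```
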
1 change: 1 addition & 0 deletions atoma-inference/src/types.rs
@@ -8,6 +8,7 @@ pub type Temperature = f32;
 
 #[derive(Clone, Debug)]
 pub struct InferenceRequest {
+    pub request_id: u128,
     pub prompt: String,
     pub model: ModelType,
     pub max_tokens: usize,
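The new `request_id` field is a `u128`, wide enough to hold a 128-bit UUID. As a hedged illustration only (the `uuid` crate is an assumption, not a dependency added by this commit), a caller could derive an id like this:

```rust
// Hypothetical helper: derive a u128 request id from a random UUID.
// Requires the `uuid` crate with the "v4" feature enabled.
use uuid::Uuid;

fn new_request_id() -> u128 {
    Uuid::new_v4().as_u128()
}
```
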
