From c125547d2f5818c394edc6c89892ae326bcff1a9 Mon Sep 17 00:00:00 2001
From: jorgeantonio21
Date: Fri, 12 Apr 2024 22:52:41 +0100
Subject: [PATCH] add compatible code to subscriber and json rpc for inference
 node

---
 atoma-event-subscribe/sui/src/main.rs       | 11 +++--
 atoma-event-subscribe/sui/src/subscriber.rs | 29 +++---------
 atoma-inference/src/main.rs                 | 13 +++++-
 atoma-inference/src/model_thread.rs         | 38 ++++++++--------
 atoma-inference/src/service.rs              | 51 ++++++++++++++++----
 atoma-service/Cargo.toml                    |  2 +
 atoma-service/src/lib.rs                    | 45 +++++++++++-------
 7 files changed, 123 insertions(+), 66 deletions(-)

diff --git a/atoma-event-subscribe/sui/src/main.rs b/atoma-event-subscribe/sui/src/main.rs
index b8072e67..b40c265c 100644
--- a/atoma-event-subscribe/sui/src/main.rs
+++ b/atoma-event-subscribe/sui/src/main.rs
@@ -29,10 +29,15 @@ async fn main() -> Result<(), SuiSubscriberError> {
     let ws_url = args
         .ws_socket_addr
         .unwrap_or("wss://fullnode.devnet.sui.io:443".to_string());
-    let event_subscriber = SuiSubscriber::new(&http_url, Some(&ws_url), package_id).await?;
-
     let (event_sender, mut event_receiver) = tokio::sync::mpsc::channel(32);
-    event_subscriber.subscribe(event_sender).await?;
+
+    let event_subscriber =
+        SuiSubscriber::new(&http_url, Some(&ws_url), package_id, event_sender).await?;
+
+    tokio::spawn(async move {
+        event_subscriber.subscribe().await?;
+        Ok::<_, SuiSubscriberError>(())
+    });
 
     while let Some(event) = event_receiver.recv().await {
         info!("Processed a new event: {event}")
diff --git a/atoma-event-subscribe/sui/src/subscriber.rs b/atoma-event-subscribe/sui/src/subscriber.rs
index 4f508b41..6d62a925 100644
--- a/atoma-event-subscribe/sui/src/subscriber.rs
+++ b/atoma-event-subscribe/sui/src/subscriber.rs
@@ -6,7 +6,7 @@ use sui_sdk::rpc_types::EventFilter;
 use sui_sdk::types::base_types::{ObjectID, ObjectIDParseError};
 use sui_sdk::{SuiClient, SuiClientBuilder};
 use thiserror::Error;
-use tokio::sync::{mpsc, oneshot};
+use tokio::sync::mpsc;
 use tracing::{error, info};
 
 use crate::config::SuiSubscriberConfig;
@@ -15,8 +15,7 @@ use crate::TextPromptParams;
 pub struct SuiSubscriber {
     sui_client: SuiClient,
     filter: EventFilter,
-    event_sender: mpsc::Sender<(Value, oneshot::Sender<Value>)>,
-    end_channel_sender: mpsc::Sender<oneshot::Receiver<Value>>,
+    event_sender: mpsc::Sender<Value>,
 }
 
 impl SuiSubscriber {
@@ -24,8 +23,7 @@ impl SuiSubscriber {
         http_url: &str,
         ws_url: Option<&str>,
         object_id: ObjectID,
-        event_sender: mpsc::Sender<(Value, oneshot::Sender<Value>)>,
-        end_channel_sender: mpsc::Sender<oneshot::Receiver<Value>>,
+        event_sender: mpsc::Sender<Value>,
     ) -> Result<Self, SuiSubscriberError> {
         let mut sui_client_builder = SuiClientBuilder::default();
         if let Some(url) = ws_url {
@@ -38,27 +36,18 @@ impl SuiSubscriber {
             sui_client,
             filter,
             event_sender,
-            end_channel_sender,
         })
     }
 
     pub async fn new_from_config<P: AsRef<Path>>(
         config_path: P,
-        event_sender: mpsc::Sender<(Value, oneshot::Sender<Value>)>,
-        end_channel_sender: mpsc::Sender<oneshot::Receiver<Value>>,
+        event_sender: mpsc::Sender<Value>,
     ) -> Result<Self, SuiSubscriberError> {
         let config = SuiSubscriberConfig::from_file_path(config_path);
         let http_url = config.http_url();
         let ws_url = config.ws_url();
         let object_id = config.object_id();
-        Self::new(
-            &http_url,
-            Some(&ws_url),
-            object_id,
-            event_sender,
-            end_channel_sender,
-        )
-        .await
+        Self::new(&http_url, Some(&ws_url), object_id, event_sender).await
     }
 
     pub async fn subscribe(self) -> Result<(), SuiSubscriberError> {
@@ -77,9 +66,7 @@ impl SuiSubscriber {
                         "The request = {:?} and sampled_nodes = {:?}",
                         request, sampled_nodes
                     );
-                    let (oneshot_sender, oneshot_receiver) = oneshot::channel();
-                    self.event_sender.send((event_data, oneshot_sender)).await?;
-                    self.end_channel_sender.send(oneshot_receiver).await?;
+                    self.event_sender.send(event_data).await?;
                 }
                 Err(e) => {
                     error!("Failed to get event with error: {e}");
@@ -99,7 +86,5 @@ pub enum SuiSubscriberError {
     #[error("Object ID parse error: `{0}`")]
     ObjectIDParseError(#[from] ObjectIDParseError),
     #[error("Sender error: `{0}`")]
-    SendError(#[from] mpsc::error::SendError<(Value, oneshot::Sender<Value>)>),
-    #[error("End channel sender error: `{0}`")]
-    EndChannelSenderError(#[from] mpsc::error::SendError<oneshot::Receiver<Value>>),
+    SendError(#[from] mpsc::error::SendError<Value>),
 }
diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs
index e194bdd3..388ea546 100644
--- a/atoma-inference/src/main.rs
+++ b/atoma-inference/src/main.rs
@@ -10,6 +10,8 @@ async fn main() -> Result<(), ModelServiceError> {
     tracing_subscriber::fmt::init();
 
     let (req_sender, req_receiver) = tokio::sync::mpsc::channel(32);
+    let (_, subscriber_req_rx) = tokio::sync::mpsc::channel(32);
+    let (atoma_node_resp_tx, _) = tokio::sync::mpsc::channel(32);
 
     let model_config = ModelsConfig::from_file_path("../inference.toml");
     let private_key_bytes =
@@ -20,8 +22,15 @@ async fn main() -> Result<(), ModelServiceError> {
     let private_key = PrivateKey::from(private_key_bytes);
 
     let jrpc_port = model_config.jrpc_port();
-    let mut service = ModelService::start(model_config, private_key, req_receiver)
-        .expect("Failed to start inference service");
+
+    let mut service = ModelService::start(
+        model_config,
+        private_key,
+        req_receiver,
+        subscriber_req_rx,
+        atoma_node_resp_tx,
+    )
+    .expect("Failed to start inference service");
 
     tokio::spawn(async move {
         service.run().await?;
diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs
index 3339c9a2..dd004ab6 100644
--- a/atoma-inference/src/model_thread.rs
+++ b/atoma-inference/src/model_thread.rs
@@ -3,6 +3,8 @@ use std::{
 };
 
 use ed25519_consensus::VerificationKey as PublicKey;
+use futures::stream::FuturesUnordered;
+use serde_json::Value;
 use thiserror::Error;
 use tokio::sync::oneshot::{self, error::RecvError};
 use tracing::{debug, error, info, warn};
@@ -21,8 +23,8 @@ use crate::{
 };
 
 pub struct ModelThreadCommand {
-    request: serde_json::Value,
-    response_sender: oneshot::Sender<serde_json::Value>,
+    request: Value,
+    sender: oneshot::Sender<Value>,
 }
 
 #[derive(Debug, Error)]
@@ -74,10 +76,7 @@ where
         debug!("Start Model thread");
 
         while let Ok(command) = self.receiver.recv() {
-            let ModelThreadCommand {
-                request,
-                response_sender,
-            } = command;
+            let ModelThreadCommand { request, sender } = command;
 
             // if !request.is_node_authorized(&public_key) {
             //     error!("Current node, with verification key = {:?} is not authorized to run request with id = {}", public_key, request.request_id());
             // }
 
             let model_input = serde_json::from_value(request)?;
             let model_output = self.model.run(model_input)?;
             let response = serde_json::to_value(model_output)?;
-            response_sender.send(response).ok();
+            sender.send(response).ok();
         }
 
         Ok(())
     }
@@ -95,6 +94,7 @@ where
 
 pub struct ModelThreadDispatcher {
     model_senders: HashMap<ModelId, mpsc::Sender<ModelThreadCommand>>,
+    pub(crate) responses: FuturesUnordered<oneshot::Receiver<Value>>,
 }
 
 impl ModelThreadDispatcher {
@@ -133,14 +133,16 @@ impl ModelThreadDispatcher {
             });
         }
 
-        let model_dispatcher = ModelThreadDispatcher { model_senders };
+        let model_dispatcher = ModelThreadDispatcher {
+            model_senders,
+            responses: FuturesUnordered::new(),
+        };
 
         Ok((model_dispatcher, handles))
     }
 
     fn send(&self, command: ModelThreadCommand) {
-        let request = command.request.clone();
-        let model_id = if let Some(model_id) = request.get("model") {
+        let model_id = if let Some(model_id) = command.request.get("model") {
             model_id.as_str().unwrap().to_string()
         } else {
             error!("Request malformed: Missing 'model' from request");
@@ -161,14 +163,14 @@ impl ModelThreadDispatcher {
 }
 
 impl ModelThreadDispatcher {
-    pub(crate) fn run_inference(
-        &self,
-        (request, sender): (serde_json::Value, oneshot::Sender<serde_json::Value>),
-    ) {
-        self.send(ModelThreadCommand {
-            request,
-            response_sender: sender,
-        });
+    pub(crate) fn run_json_inference(&self, (request, sender): (Value, oneshot::Sender<Value>)) {
+        self.send(ModelThreadCommand { request, sender });
+    }
+
+    pub(crate) fn run_subscriber_inference(&self, request: Value) {
+        let (sender, receiver) = oneshot::channel();
+        self.send(ModelThreadCommand { request, sender });
+        self.responses.push(receiver);
     }
 }
 
diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs
index 657d7fc3..5f5570d3 100644
--- a/atoma-inference/src/service.rs
+++ b/atoma-inference/src/service.rs
@@ -1,8 +1,10 @@
 use candle::Error as CandleError;
 use ed25519_consensus::{SigningKey as PrivateKey, VerificationKey as PublicKey};
+use futures::StreamExt;
+use serde_json::Value;
 use std::fmt::Debug;
 use std::{io, path::PathBuf, time::Instant};
-use tokio::sync::mpsc::Receiver;
+use tokio::sync::mpsc::{Receiver, Sender};
 use tokio::sync::oneshot;
 use tracing::{error, info};
 
@@ -21,14 +23,18 @@ pub struct ModelService {
     flush_storage: bool,
     public_key: PublicKey,
     cache_dir: PathBuf,
-    request_receiver: Receiver<(serde_json::Value, oneshot::Sender<serde_json::Value>)>,
+    json_server_req_rx: Receiver<(Value, oneshot::Sender<Value>)>,
+    subscriber_req_rx: Receiver<Value>,
+    atoma_node_resp_tx: Sender<Value>,
 }
 
 impl ModelService {
     pub fn start(
         model_config: ModelsConfig,
         private_key: PrivateKey,
-        request_receiver: Receiver<(serde_json::Value, oneshot::Sender<serde_json::Value>)>,
+        json_server_req_rx: Receiver<(Value, oneshot::Sender<Value>)>,
+        subscriber_req_rx: Receiver<Value>,
+        atoma_node_resp_tx: Sender<Value>,
     ) -> Result<Self, ModelServiceError> {
         let public_key = private_key.verification_key();
 
@@ -47,16 +53,36 @@ impl ModelService {
             flush_storage,
             cache_dir,
             public_key,
-            request_receiver,
+            json_server_req_rx,
+            subscriber_req_rx,
+            atoma_node_resp_tx,
         })
     }
 
     pub async fn run(&mut self) -> Result<(), ModelServiceError> {
         loop {
             tokio::select! {
-                message = self.request_receiver.recv() => {
+                message = self.json_server_req_rx.recv() => {
                     if let Some(request) = message {
-                        self.dispatcher.run_inference(request);
+                        self.dispatcher.run_json_inference(request);
+                    }
+                },
+                message = self.subscriber_req_rx.recv() => {
+                    if let Some(request) = message {
+                        self.dispatcher.run_subscriber_inference(request);
+                    }
+                }
+                response = self.dispatcher.responses.next() => {
+                    if let Some(resp) = response {
+                        match resp {
+                            Ok(response) => {
+                                info!("Received a new inference response: {:?}", response);
+                                self.atoma_node_resp_tx.send(response).await.map_err(|e| ModelServiceError::SendError(e.to_string()))?;
+                            }
+                            Err(e) => {
+                                error!("Found error in generating inference response: {e}");
+                            }
+                        }
                     }
                 }
             }
@@ -201,11 +227,20 @@ mod tests {
         file.write_all(toml_string.as_bytes())
             .expect("Failed to write to file");
 
-        let (_, req_receiver) = tokio::sync::mpsc::channel(1);
+        let (_, json_server_req_rx) = tokio::sync::mpsc::channel(1);
+        let (_, subscriber_req_rx) = tokio::sync::mpsc::channel(1);
+        let (atoma_node_resp_tx, _) = tokio::sync::mpsc::channel(1);
 
         let config = ModelsConfig::from_file_path(CONFIG_FILE_PATH);
 
-        let _ = ModelService::start(config, private_key, req_receiver).unwrap();
+        let _ = ModelService::start(
+            config,
+            private_key,
+            json_server_req_rx,
+            subscriber_req_rx,
+            atoma_node_resp_tx,
+        )
+        .unwrap();
 
         std::fs::remove_file(CONFIG_FILE_PATH).unwrap();
     }
diff --git a/atoma-service/Cargo.toml b/atoma-service/Cargo.toml
index b0a02d96..ea1087ca 100644
--- a/atoma-service/Cargo.toml
+++ b/atoma-service/Cargo.toml
@@ -11,3 +11,5 @@ atoma-inference = { path = "../atoma-inference/" }
 serde_json.workspace = true
 thiserror.workspace = true
 tokio.workspace = true
+tracing.workspace = true
+tracing-subscriber.workspace = true
diff --git a/atoma-service/src/lib.rs b/atoma-service/src/lib.rs
index 0fa7bd89..bbcb108d 100644
--- a/atoma-service/src/lib.rs
+++ b/atoma-service/src/lib.rs
@@ -1,16 +1,18 @@
-use std::{
-    io,
-    path::{Path, PathBuf},
-};
+use std::{io, path::Path};
 
 use atoma_inference::{
-    models::config::{ModelConfig, ModelsConfig},
+    models::config::ModelsConfig,
     service::{ModelService, ModelServiceError},
     PrivateKey,
 };
 use atoma_sui::subscriber::{SuiSubscriber, SuiSubscriberError};
+use serde_json::Value;
 use thiserror::Error;
-use tokio::{sync::mpsc, task::JoinHandle};
+use tokio::{
+    sync::{mpsc, mpsc::Receiver, oneshot},
+    task::JoinHandle,
+};
+use tracing::info;
 
 const CHANNEL_SIZE: usize = 32;
 
@@ -20,11 +22,15 @@ pub struct AtomaNode {
 }
 
 impl AtomaNode {
-    pub async fn start<P: AsRef<Path>>(
+    pub async fn start<P>(
         model_config_path: P,
         private_key_path: P,
         sui_subscriber_path: P,
-    ) -> Result<Self, AtomaNodeError> {
+        json_server_req_rx: Receiver<(Value, oneshot::Sender<Value>)>,
+    ) -> Result<Self, AtomaNodeError>
+    where
+        P: AsRef<Path> + Send + 'static,
+    {
         let model_config = ModelsConfig::from_file_path(model_config_path);
 
         let private_key_bytes = std::fs::read(private_key_path)?;
@@ -34,10 +40,17 @@ impl AtomaNode {
         let private_key = PrivateKey::from(private_key_bytes);
 
-        let (request_sender, request_receiver) = mpsc::channel(CHANNEL_SIZE);
+        let (subscriber_req_tx, subscriber_req_rx) = mpsc::channel(CHANNEL_SIZE);
+        let (atoma_node_resp_tx, mut atoma_node_resp_rx) = mpsc::channel(CHANNEL_SIZE);
 
         let inference_service_handle = tokio::spawn(async move {
-            let model_service = ModelService::start(model_config, private_key, request_receiver)?;
+            let mut model_service = ModelService::start(
+                model_config,
+                private_key,
+                json_server_req_rx,
+                subscriber_req_rx,
+                atoma_node_resp_tx,
+            )?;
             model_service
                 .run()
                 .await
@@ -46,17 +59,23 @@ impl AtomaNode {
                 .map_err(AtomaNodeError::ModelServiceError)
         });
 
         let sui_subscriber_handle = tokio::spawn(async move {
             let sui_event_subscriber =
-                SuiSubscriber::new_from_config(sui_subscriber_path, request_sender).await?;
+                SuiSubscriber::new_from_config(sui_subscriber_path, subscriber_req_tx).await?;
             sui_event_subscriber
                 .subscribe()
                 .await
                 .map_err(AtomaNodeError::SuiSubscriberError)
        });
 
-        Self {
+        tokio::spawn(async move {
+            while let Some(response) = atoma_node_resp_rx.recv().await {
+                info!("Received new response: {response}");
+            }
+        });
+
+        Ok(Self {
             inference_service_handle,
             sui_subscriber_handle,
-        }
+        })
     }
 }
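
Usage sketch (outside the diff above): a minimal, self-contained example of how a JSON-RPC handler would drive the request path this patch wires up. Only the `(Value, oneshot::Sender<Value>)` request shape, the `"model"` field the dispatcher routes on, and the channel direction come from the patch; the echo task standing in for the model service, the channel capacity, the model id, and the prompt field are placeholders.

use serde_json::{json, Value};
use tokio::sync::{mpsc, oneshot};

#[tokio::main]
async fn main() {
    // The JSON-RPC server keeps the sending half; the receiving half is the
    // `json_server_req_rx` argument that ModelService::start and
    // AtomaNode::start now take.
    let (json_server_req_tx, mut json_server_req_rx) =
        mpsc::channel::<(Value, oneshot::Sender<Value>)>(32);

    // Stand-in for the spawned ModelService: like run(), it receives a request
    // paired with a per-request oneshot and answers through that oneshot.
    tokio::spawn(async move {
        while let Some((request, sender)) = json_server_req_rx.recv().await {
            // The real dispatcher routes on request["model"] to a model thread;
            // here we just echo the request back.
            let reply = json!({ "echo": request });
            sender.send(reply).ok();
        }
    });

    // What a handler does per incoming JSON-RPC call.
    let (resp_tx, resp_rx) = oneshot::channel();
    let request = json!({
        "model": "mamba_370m", // placeholder; must name a model id from the config
        "prompt": "Hello, world!"
    });
    json_server_req_tx
        .send((request, resp_tx))
        .await
        .expect("inference service dropped the request receiver");

    let response = resp_rx.await.expect("model thread dropped the response sender");
    println!("{response}");
}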