
Commit c125547
add compatible code to subscriber and json rpc for inference node
jorgeantonio21 committed Apr 12, 2024
1 parent d7255bd commit c125547
Showing 7 changed files with 121 additions and 66 deletions.
11 changes: 8 additions & 3 deletions atoma-event-subscribe/sui/src/main.rs
@@ -29,10 +29,15 @@ async fn main() -> Result<(), SuiSubscriberError> {
     let ws_url = args
         .ws_socket_addr
         .unwrap_or("wss://fullnode.devnet.sui.io:443".to_string());
-    let event_subscriber = SuiSubscriber::new(&http_url, Some(&ws_url), package_id).await?;

     let (event_sender, mut event_receiver) = tokio::sync::mpsc::channel(32);
-    event_subscriber.subscribe(event_sender).await?;
+
+    let event_subscriber =
+        SuiSubscriber::new(&http_url, Some(&ws_url), package_id, event_sender).await?;
+
+    tokio::spawn(async move {
+        event_subscriber.subscribe().await?;
+        Ok::<_, SuiSubscriberError>(())
+    });

     while let Some(event) = event_receiver.recv().await {
         info!("Processed a new event: {event}")
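The shape of this change is a common tokio pattern: create the channel before the producer, move the sender into the producer, spawn the producer on its own task, and drain the receiver on the main task. A minimal, self-contained sketch of that wiring (the Producer type and String payloads are made up for illustration, not the repo's API):

    use tokio::sync::mpsc;

    // Hypothetical stand-in for SuiSubscriber: it owns the sender half.
    struct Producer {
        tx: mpsc::Sender<String>,
    }

    impl Producer {
        async fn run(self) {
            for i in 0..3 {
                // A real subscriber would forward deserialized chain events here.
                if self.tx.send(format!("event-{i}")).await.is_err() {
                    break; // receiver was dropped; stop producing
                }
            }
        }
    }

    #[tokio::main]
    async fn main() {
        // Channel first: the sender moves into the producer, the receiver stays here.
        let (tx, mut rx) = mpsc::channel(32);
        let producer = Producer { tx };

        // Mirrors the spawned subscribe() task above.
        tokio::spawn(producer.run());

        while let Some(event) = rx.recv().await {
            println!("Processed a new event: {event}");
        }
    }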
29 changes: 7 additions & 22 deletions atoma-event-subscribe/sui/src/subscriber.rs
@@ -6,7 +6,7 @@ use sui_sdk::rpc_types::EventFilter;
 use sui_sdk::types::base_types::{ObjectID, ObjectIDParseError};
 use sui_sdk::{SuiClient, SuiClientBuilder};
 use thiserror::Error;
-use tokio::sync::{mpsc, oneshot};
+use tokio::sync::mpsc;
 use tracing::{error, info};

 use crate::config::SuiSubscriberConfig;
@@ -15,17 +15,15 @@ use crate::TextPromptParams;
 pub struct SuiSubscriber {
     sui_client: SuiClient,
     filter: EventFilter,
-    event_sender: mpsc::Sender<(Value, oneshot::Sender<Value>)>,
-    end_channel_sender: mpsc::Sender<oneshot::Receiver<Value>>,
+    event_sender: mpsc::Sender<Value>,
 }

 impl SuiSubscriber {
     pub async fn new(
         http_url: &str,
         ws_url: Option<&str>,
         object_id: ObjectID,
-        event_sender: mpsc::Sender<(Value, oneshot::Sender<Value>)>,
-        end_channel_sender: mpsc::Sender<oneshot::Receiver<Value>>,
+        event_sender: mpsc::Sender<Value>,
     ) -> Result<Self, SuiSubscriberError> {
         let mut sui_client_builder = SuiClientBuilder::default();
         if let Some(url) = ws_url {
@@ -38,27 +36,18 @@ impl SuiSubscriber {
             sui_client,
             filter,
             event_sender,
-            end_channel_sender,
         })
     }

     pub async fn new_from_config<P: AsRef<Path>>(
         config_path: P,
-        event_sender: mpsc::Sender<(Value, oneshot::Sender<Value>)>,
-        end_channel_sender: mpsc::Sender<oneshot::Receiver<Value>>,
+        event_sender: mpsc::Sender<Value>,
     ) -> Result<Self, SuiSubscriberError> {
         let config = SuiSubscriberConfig::from_file_path(config_path);
         let http_url = config.http_url();
         let ws_url = config.ws_url();
         let object_id = config.object_id();
-        Self::new(
-            &http_url,
-            Some(&ws_url),
-            object_id,
-            event_sender,
-            end_channel_sender,
-        )
-        .await
+        Self::new(&http_url, Some(&ws_url), object_id, event_sender).await
     }

     pub async fn subscribe(self) -> Result<(), SuiSubscriberError> {
@@ -77,9 +66,7 @@ impl SuiSubscriber {
"The request = {:?} and sampled_nodes = {:?}",
request, sampled_nodes
);
let (oneshot_sender, oneshot_receiver) = oneshot::channel();
self.event_sender.send((event_data, oneshot_sender)).await?;
self.end_channel_sender.send(oneshot_receiver).await?;
self.event_sender.send(event_data).await?;
}
Err(e) => {
error!("Failed to get event with error: {e}");
@@ -99,7 +86,5 @@ pub enum SuiSubscriberError {
#[error("Object ID parse error: `{0}`")]
ObjectIDParseError(#[from] ObjectIDParseError),
#[error("Sender error: `{0}`")]
SendError(#[from] mpsc::error::SendError<(Value, oneshot::Sender<Value>)>),
#[error("End channel sender error: `{0}`")]
EndChannelSenderError(#[from] mpsc::error::SendError<oneshot::Receiver<Value>>),
SendError(#[from] mpsc::error::SendError<Value>),
}
13 changes: 11 additions & 2 deletions atoma-inference/src/main.rs
@@ -10,6 +10,8 @@ async fn main() -> Result<(), ModelServiceError> {
     tracing_subscriber::fmt::init();

     let (req_sender, req_receiver) = tokio::sync::mpsc::channel(32);
+    let (_, subscriber_req_rx) = tokio::sync::mpsc::channel(32);
+    let (atoma_node_resp_tx, _) = tokio::sync::mpsc::channel(32);

     let model_config = ModelsConfig::from_file_path("../inference.toml");
     let private_key_bytes =
@@ -20,8 +22,15 @@

     let private_key = PrivateKey::from(private_key_bytes);
     let jrpc_port = model_config.jrpc_port();
-    let mut service = ModelService::start(model_config, private_key, req_receiver)
-        .expect("Failed to start inference service");
+
+    let mut service = ModelService::start(
+        model_config,
+        private_key,
+        req_receiver,
+        subscriber_req_rx,
+        atoma_node_resp_tx,
+    )
+    .expect("Failed to start inference service");

     tokio::spawn(async move {
         service.run().await?;
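One subtlety in the stub channels above: binding the sender half to `_` drops it immediately, which closes the channel, so `recv()` on `subscriber_req_rx` resolves to `None` right away rather than blocking. A minimal demonstration:

    use tokio::sync::mpsc;

    #[tokio::main]
    async fn main() {
        // The sender half is bound to `_`, so it is dropped on the spot and
        // the channel starts out closed.
        let (_, mut rx) = mpsc::channel::<u32>(32);
        assert_eq!(rx.recv().await, None); // resolves immediately, no blocking
    }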
38 changes: 20 additions & 18 deletions atoma-inference/src/model_thread.rs
@@ -3,6 +3,8 @@ use std::{
 };

 use ed25519_consensus::VerificationKey as PublicKey;
+use futures::stream::FuturesUnordered;
+use serde_json::Value;
 use thiserror::Error;
 use tokio::sync::oneshot::{self, error::RecvError};
 use tracing::{debug, error, info, warn};
@@ -21,8 +23,8 @@ use crate::{
 };

 pub struct ModelThreadCommand {
-    request: serde_json::Value,
-    response_sender: oneshot::Sender<serde_json::Value>,
+    request: Value,
+    sender: oneshot::Sender<Value>,
 }

 #[derive(Debug, Error)]
@@ -74,10 +76,7 @@
         debug!("Start Model thread");

         while let Ok(command) = self.receiver.recv() {
-            let ModelThreadCommand {
-                request,
-                response_sender,
-            } = command;
+            let ModelThreadCommand { request, sender } = command;

             // if !request.is_node_authorized(&public_key) {
             //     error!("Current node, with verification key = {:?} is not authorized to run request with id = {}", public_key, request.request_id());
@@ -86,7 +85,7 @@
             let model_input = serde_json::from_value(request)?;
             let model_output = self.model.run(model_input)?;
             let response = serde_json::to_value(model_output)?;
-            response_sender.send(response).ok();
+            sender.send(response).ok();
         }

         Ok(())
@@ -95,6 +94,7 @@

 pub struct ModelThreadDispatcher {
     model_senders: HashMap<ModelId, mpsc::Sender<ModelThreadCommand>>,
+    pub(crate) responses: FuturesUnordered<oneshot::Receiver<Value>>,
 }

 impl ModelThreadDispatcher {
@@ -133,14 +133,16 @@ impl ModelThreadDispatcher {
             });
         }

-        let model_dispatcher = ModelThreadDispatcher { model_senders };
+        let model_dispatcher = ModelThreadDispatcher {
+            model_senders,
+            responses: FuturesUnordered::new(),
+        };

         Ok((model_dispatcher, handles))
     }

     fn send(&self, command: ModelThreadCommand) {
-        let request = command.request.clone();
-        let model_id = if let Some(model_id) = request.get("model") {
+        let model_id = if let Some(model_id) = command.request.get("model") {
             model_id.as_str().unwrap().to_string()
         } else {
             error!("Request malformed: Missing 'model' from request");
@@ -161,14 +163,14 @@
 }

 impl ModelThreadDispatcher {
-    pub(crate) fn run_inference(
-        &self,
-        (request, sender): (serde_json::Value, oneshot::Sender<serde_json::Value>),
-    ) {
-        self.send(ModelThreadCommand {
-            request,
-            response_sender: sender,
-        });
+    pub(crate) fn run_json_inference(&self, (request, sender): (Value, oneshot::Sender<Value>)) {
+        self.send(ModelThreadCommand { request, sender });
+    }
+
+    pub(crate) fn run_subsbriber_inference(&self, request: Value) {
+        let (sender, receiver) = oneshot::channel();
+        self.send(ModelThreadCommand { request, sender });
+        self.responses.push(receiver);
     }
 }

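The new responses field uses a standard futures idiom: push each oneshot::Receiver into a FuturesUnordered and poll the collection as a stream, so results arrive in completion order rather than submission order. A self-contained sketch of that idiom (toy String payloads, not the dispatcher itself):

    use futures::stream::{FuturesUnordered, StreamExt};
    use tokio::sync::oneshot;

    #[tokio::main]
    async fn main() {
        let mut responses: FuturesUnordered<oneshot::Receiver<String>> =
            FuturesUnordered::new();

        for i in 0..3 {
            let (tx, rx) = oneshot::channel();
            responses.push(rx);
            // Each worker replies on its own oneshot channel, like a model thread.
            tokio::spawn(async move {
                tx.send(format!("response-{i}")).ok();
            });
        }

        // next() yields whichever receiver finishes first.
        while let Some(resp) = responses.next().await {
            match resp {
                Ok(value) => println!("got {value}"),
                Err(_) => eprintln!("a worker dropped its sender"),
            }
        }
    }

This is what lets service.rs below treat pending inference results as just another select! branch.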
51 changes: 43 additions & 8 deletions atoma-inference/src/service.rs
@@ -1,8 +1,10 @@
 use candle::Error as CandleError;
 use ed25519_consensus::{SigningKey as PrivateKey, VerificationKey as PublicKey};
+use futures::StreamExt;
+use serde_json::Value;
 use std::fmt::Debug;
 use std::{io, path::PathBuf, time::Instant};
-use tokio::sync::mpsc::Receiver;
+use tokio::sync::mpsc::{Receiver, Sender};
 use tokio::sync::oneshot;
 use tracing::{error, info};

@@ -21,14 +23,18 @@ pub struct ModelService {
     flush_storage: bool,
     public_key: PublicKey,
     cache_dir: PathBuf,
-    request_receiver: Receiver<(serde_json::Value, oneshot::Sender<serde_json::Value>)>,
+    json_server_req_rx: Receiver<(Value, oneshot::Sender<Value>)>,
+    subscriber_req_rx: Receiver<Value>,
+    atoma_node_resp_tx: Sender<Value>,
 }

 impl ModelService {
     pub fn start(
         model_config: ModelsConfig,
         private_key: PrivateKey,
-        request_receiver: Receiver<(serde_json::Value, oneshot::Sender<serde_json::Value>)>,
+        json_server_req_rx: Receiver<(Value, oneshot::Sender<Value>)>,
+        subscriber_req_rx: Receiver<Value>,
+        atoma_node_resp_tx: Sender<Value>,
     ) -> Result<Self, ModelServiceError> {
         let public_key = private_key.verification_key();

@@ -47,16 +53,36 @@ impl ModelService {
             flush_storage,
             cache_dir,
             public_key,
-            request_receiver,
+            json_server_req_rx,
+            subscriber_req_rx,
+            atoma_node_resp_tx,
         })
     }

     pub async fn run(&mut self) -> Result<(), ModelServiceError> {
         loop {
             tokio::select! {
-                message = self.request_receiver.recv() => {
+                message = self.json_server_req_rx.recv() => {
                     if let Some(request) = message {
-                        self.dispatcher.run_inference(request);
+                        self.dispatcher.run_json_inference(request);
                     }
                 },
+                message = self.subscriber_req_rx.recv() => {
+                    if let Some(request) = message {
+                        self.dispatcher.run_subsbriber_inference(request);
+                    }
+                }
+                response = self.dispatcher.responses.next() => {
+                    if let Some(resp) = response {
+                        match resp {
+                            Ok(response) => {
+                                info!("Received a new inference response: {:?}", response);
+                                self.atoma_node_resp_tx.send(response).await.map_err(|e| ModelServiceError::SendError(e.to_string()))?;
+                            }
+                            Err(e) => {
+                                error!("Found error in generating inference response: {e}");
+                            }
+                        }
+                    }
+                }
             }
         }
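The reworked run() is a textbook tokio::select! fan-in: two request sources and the in-flight response stream multiplexed on one task. In miniature, with made-up u32 messages standing in for JSON values:

    use futures::stream::{FuturesUnordered, StreamExt};
    use tokio::sync::{mpsc, oneshot};

    #[tokio::main]
    async fn main() {
        let (json_tx, mut json_rx) = mpsc::channel::<u32>(8);
        let (sub_tx, mut sub_rx) = mpsc::channel::<u32>(8);
        let mut responses: FuturesUnordered<oneshot::Receiver<u32>> =
            FuturesUnordered::new();

        json_tx.send(1).await.unwrap();
        sub_tx.send(2).await.unwrap();

        // Handle two requests and their two responses, then stop.
        let mut handled = 0;
        while handled < 4 {
            tokio::select! {
                Some(req) = json_rx.recv() => {
                    let (tx, rx) = oneshot::channel();
                    responses.push(rx); // response surfaces in the branch below
                    tx.send(req * 10).ok(); // a model thread would reply asynchronously
                    handled += 1;
                }
                Some(req) = sub_rx.recv() => {
                    let (tx, rx) = oneshot::channel();
                    responses.push(rx);
                    tx.send(req * 10).ok();
                    handled += 1;
                }
                Some(Ok(resp)) = responses.next() => {
                    println!("response: {resp}");
                    handled += 1;
                }
                else => break,
            }
        }
    }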
@@ -201,11 +227,20 @@ mod tests {
         file.write_all(toml_string.as_bytes())
             .expect("Failed to write to file");

-        let (_, req_receiver) = tokio::sync::mpsc::channel(1);
+        let (_, json_server_req_rx) = tokio::sync::mpsc::channel(1);
+        let (_, subscriber_req_rx) = tokio::sync::mpsc::channel(1);
+        let (atoma_node_resp_tx, _) = tokio::sync::mpsc::channel(1);

         let config = ModelsConfig::from_file_path(CONFIG_FILE_PATH);

-        let _ = ModelService::start(config, private_key, req_receiver).unwrap();
+        let _ = ModelService::start(
+            config,
+            private_key,
+            json_server_req_rx,
+            subscriber_req_rx,
+            atoma_node_resp_tx,
+        )
+        .unwrap();

         std::fs::remove_file(CONFIG_FILE_PATH).unwrap();
     }
2 changes: 2 additions & 0 deletions atoma-service/Cargo.toml
@@ -11,3 +11,5 @@ atoma-inference = { path = "../atoma-inference/" }
 serde_json.workspace = true
 thiserror.workspace = true
 tokio.workspace = true
+tracing.workspace = true
+tracing-subscriber.workspace = true