Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgeantonio21 committed May 8, 2024
1 parent 0afc273 commit 7c13e93
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 50 deletions.
16 changes: 2 additions & 14 deletions atoma-client/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::path::Path;

use atoma_crypto::{calculate_commitment, Blake2b};
use atoma_types::{Digest, Response, SmallId};
use atoma_types::{Digest, Response};
use sui_sdk::{
json::SuiJsonValue,
types::base_types::{ObjectIDParseError, SuiAddress},
Expand Down Expand Up @@ -58,18 +58,6 @@ impl AtomaSuiClient {
Self::new_from_config(config, response_rx, output_manager_tx)
}

fn get_index(
&self,
sampled_nodes: Vec<SmallId>,
) -> Result<(usize, usize), AtomaSuiClientError> {
let num_leaves = sampled_nodes.len();
let index = sampled_nodes
.iter()
.position(|nid| nid == &self.config.small_id())
.ok_or(AtomaSuiClientError::InvalidSampledNode)?;
Ok((index, num_leaves))
}

fn get_data(&self, data: serde_json::Value) -> Result<Vec<u8>, AtomaSuiClientError> {
// TODO: rework this when responses get same structure
let data = match data["text"].as_str() {
Expand Down Expand Up @@ -105,7 +93,7 @@ impl AtomaSuiClient {
) -> Result<Digest, AtomaSuiClientError> {
let request_id = response.id();
let data = self.get_data(response.response())?;
let (index, num_leaves) = self.get_index(response.sampled_nodes())?;
let (index, num_leaves) = (response.sampled_node_index(), response.num_sampled_nodes());
let (root, pre_image) = calculate_commitment::<Blake2b<_>, _>(data, index, num_leaves);

let client = self.wallet_ctx.get_client().await?;
Expand Down
132 changes: 118 additions & 14 deletions atoma-event-subscribe/sui/src/subscriber.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use tokio::sync::mpsc;
use tracing::{debug, error, info};

use crate::config::SuiSubscriberConfig;
use atoma_types::{Request, SmallId};
use atoma_types::{Request, SmallId, NON_SAMPLED_NODE_ERR};

const REQUEST_ID_HEX_SIZE: usize = 64;

Expand Down Expand Up @@ -95,9 +95,30 @@ impl SuiSubscriber {
| "SettledEvent" => {
info!("Received event: {}", event.type_.name.as_str());
}
"Text2TextPromptEvent" | "NewlySampledNodesEvent" => {
"NewlySampledNodesEvent" => {
let event_data = event.parsed_json;
self.handle_text2text_prompt_event(event_data).await?;
match self.handle_newly_sampled_nodes_event(event_data).await {
Ok(()) => {}
Err(err) => {
error!("Failed to process request, with error: {err}")
}
}
}
"Text2TextPromptEvent" => {
let event_data = event.parsed_json;
match self.handle_text2text_prompt_event(event_data).await {
Ok(()) => {}
Err(SuiSubscriberError::TypeConversionError(err)) => {
if err.to_string().contains(NON_SAMPLED_NODE_ERR) {
info!("Node has not been sampled for current request");
} else {
error!("Failed to process request, with error: {err}")
}
}
Err(err) => {
error!("Failed to process request, with error: {err}");
}
}
}
"Text2ImagePromptEvent" => {
let event_data = event.parsed_json;
Expand All @@ -120,7 +141,7 @@ impl SuiSubscriber {
event_data: Value,
) -> Result<(), SuiSubscriberError> {
debug!("event data: {}", event_data);
let request = Request::try_from(event_data)?;
let request = Request::try_from((self.id, event_data))?;
info!("Received new request: {:?}", request);
let request_id =
request
Expand All @@ -130,20 +151,101 @@ impl SuiSubscriber {
write!(acc, "{:02x}", b).expect("Failed to write to request_id");
acc
});
info!("request_id: {request_id}");
let sampled_nodes = request.sampled_nodes();
if sampled_nodes.contains(&self.id) {
info!(
"Current node has been sampled for request with id: {}",
request_id
);
self.event_sender.send(request).await.map_err(Box::new)?;

Ok(())
}

async fn handle_newly_sampled_nodes_event(
&self,
event_data: Value,
) -> Result<(), SuiSubscriberError> {
debug!("event data: {}", event_data);
let newly_sampled_nodes = event_data
.get("new_nodes")
.ok_or(SuiSubscriberError::MalformedEvent(
"missing `new_nodes` field".into(),
))?
.as_array()
.ok_or(SuiSubscriberError::MalformedEvent(
"invalid `new_nodes` field".into(),
))?
.iter()
.map(|n| {
let node_id = n
.get("node_id")
.ok_or(SuiSubscriberError::MalformedEvent(
"missing `node_id` field".into(),
))?
.get("inner")
.ok_or(SuiSubscriberError::MalformedEvent(
"invalid `inner` field".into(),
))?
.as_u64()
.ok_or(SuiSubscriberError::MalformedEvent(
"invalid `node_id` `inner` field".into(),
))?;
let index = n
.get("order")
.ok_or(SuiSubscriberError::MalformedEvent(
"missing `order` field".into(),
))?
.as_u64()
.ok_or(SuiSubscriberError::MalformedEvent(
"invalid `order` field".into(),
))?;
Ok::<_, SuiSubscriberError>((node_id, index))
})
.collect::<Result<Vec<_>, _>>()?;
if let Some((_, sampled_node_index)) =
newly_sampled_nodes.iter().find(|(id, _)| id == &self.id)
{
let ticket_id = event_data
.get("ticket_id")
.ok_or(SuiSubscriberError::MalformedEvent(
"missing `ticket_id` field".into(),
))?
.as_str()
.ok_or(SuiSubscriberError::MalformedEvent(
"invalid `ticket_id` field".into(),
))?;
let data = self
.sui_client
.event_api()
.query_events(
EventFilter::MoveEventField {
path: "ticket_id".to_string(),
value: serde_json::from_str(ticket_id)?,
},
None,
Some(1),
false,
)
.await?;
let event = data
.data
.first()
.ok_or(SuiSubscriberError::MalformedEvent(format!(
"Missing data from event with ticket id = {}",
ticket_id
)))?;
let request = Request::try_from((
ticket_id,
*sampled_node_index as usize,
event.parsed_json.clone(),
))?;
info!("Received new request: {:?}", request);
info!(
"Current node has been sampled for request with id: {}",
request_id
"Current node has been newly sampled for request with id: {}",
ticket_id
);
self.event_sender.send(request).await.map_err(Box::new)?;
} else {
info!(
"Current node has not been sampled for request with id: {}, ignoring it..",
request_id
);
}

Ok(())
}
}
Expand All @@ -160,4 +262,6 @@ pub enum SuiSubscriberError {
SendError(#[from] Box<mpsc::error::SendError<Request>>),
#[error("Type conversion error: `{0}`")]
TypeConversionError(#[from] anyhow::Error),
#[error("Malformed event: `{0}`")]
MalformedEvent(String),
}
5 changes: 3 additions & 2 deletions atoma-inference/src/model_thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,13 @@ where
while let Ok(command) = self.receiver.recv() {
let ModelThreadCommand { request, sender } = command;
let request_id = request.id();
let sampled_nodes = request.sampled_nodes();
let sampled_node_index = request.sampled_node_index();
let num_sampled_nodex = request.num_sampled_nodes();
let params = request.params();
let model_input = M::Input::try_from(params)?;
let model_output = self.model.run(model_input)?;
let output = serde_json::to_value(model_output)?;
let response = Response::new(request_id, sampled_nodes, output);
let response = Response::new(request_id, sampled_node_index, num_sampled_nodex, output);
sender.send(response).ok();
}

Expand Down
2 changes: 1 addition & 1 deletion atoma-inference/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ async fn test_mock_model_thread() {
Some(0),
Some(1.0),
));
let request = Request::new(vec![0], vec![], prompt_params);
let request = Request::new(vec![0], 0, 1, prompt_params);
let command = ModelThreadCommand {
request: request.clone(),
sender: response_sender,
Expand Down
2 changes: 1 addition & 1 deletion atoma-node/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async fn main() -> Result<(), AtomaNodeError> {
let sui_subscriber_path = args.sui_subscriber_path;

let (_, json_rpc_server_rx) = mpsc::channel(CHANNEL_BUFFER);
let _atoma_node = AtomaNode::start(
AtomaNode::start(
atoma_sui_client_config_path,
model_config_path,
sui_subscriber_path,
Expand Down
Loading

0 comments on commit 7c13e93

Please sign in to comment.