diff --git a/atoma-client/src/client.rs b/atoma-client/src/client.rs index 606b6ea9..c1565822 100644 --- a/atoma-client/src/client.rs +++ b/atoma-client/src/client.rs @@ -9,7 +9,7 @@ use sui_sdk::{ }; use thiserror::Error; use tokio::sync::mpsc; -use tracing::{debug, info}; +use tracing::{debug, error, info}; use crate::config::AtomaSuiClientConfig; @@ -58,22 +58,56 @@ impl AtomaSuiClient { Self::new_from_config(config, response_rx, output_manager_tx) } + /// Extracts and processes data from a JSON response to generate a byte vector. + /// + /// This method handles two types of data structures within the JSON response: + /// - If the JSON contains a "text" field, it converts the text to a byte vector. + /// - If the JSON contains an array with at least three elements, it interprets: + /// - The first element as an array of bytes (representing an image byte content), + /// - The second element as the image height, + /// - The third element as the image width. + /// These are then combined into a single byte vector where the image data is followed by the height and width. fn get_data(&self, data: serde_json::Value) -> Result, AtomaSuiClientError> { // TODO: rework this when responses get same structure - let data = match data["text"].as_str() { - Some(text) => text.as_bytes().to_owned(), - None => { - if let Some(array) = data.as_array() { - array - .iter() - .map(|b| b.as_u64().unwrap() as u8) - .collect::>() - } else { - return Err(AtomaSuiClientError::FailedResponseJsonParsing); - } + if let Some(text) = data["text"].as_str() { + Ok(text.as_bytes().to_owned()) + } else if let Some(array) = data.as_array() { + if array.len() < 3 { + error!("Incomplete image data"); + return Err(AtomaSuiClientError::MissingOutputData); } - }; - Ok(data) + + let img_data = array + .get(0) + .and_then(|img| img.as_array()) + .ok_or(AtomaSuiClientError::MissingOutputData)?; + let img = img_data + .iter() + .map(|b| b.as_u64().ok_or(AtomaSuiClientError::MissingOutputData)) + .collect::, _>>()? + .into_iter() + .map(|b| b as u8) + .collect::>(); + let height = array + .get(1) + .and_then(|h| h.as_u64()) + .ok_or(AtomaSuiClientError::MissingOutputData)? + .to_le_bytes(); + let width = array + .get(2) + .and_then(|w| w.as_u64()) + .ok_or(AtomaSuiClientError::MissingOutputData)? + .to_le_bytes(); + + let mut result = img; + result.extend_from_slice(&height); + result.extend_from_slice(&width); + + Ok(result) + } else { + error!("Invalid JSON structure for data extraction"); + return Err(AtomaSuiClientError::FailedResponseJsonParsing); + } } /// Upon receiving a response from the `AtomaNode` service, this method extracts @@ -167,4 +201,6 @@ pub enum AtomaSuiClientError { InvalidSampledNode, #[error("Invalid request id")] InvalidRequestId, + #[error("Missing output data")] + MissingOutputData, } diff --git a/atoma-event-subscribe/sui/src/lib.rs b/atoma-event-subscribe/sui/src/lib.rs index 1b8b7a9a..11a1dc65 100644 --- a/atoma-event-subscribe/sui/src/lib.rs +++ b/atoma-event-subscribe/sui/src/lib.rs @@ -1,2 +1,50 @@ +use std::str::FromStr; + +use subscriber::SuiSubscriberError; + pub mod config; pub mod subscriber; + +pub enum AtomaEvent { + DisputeEvent, + FirstSubmissionEvent, + NewlySampledNodesEvent, + NodeRegisteredEvent, + NodeSubscribedToModelEvent, + SettledEvent, + Text2ImagePromptEvent, + Text2TextPromptEvent, +} + +impl FromStr for AtomaEvent { + type Err = SuiSubscriberError; + + fn from_str(s: &str) -> Result { + match s { + "DisputeEvent" => Ok(Self::DisputeEvent), + "FirstSubmissionEvent" => Ok(Self::FirstSubmissionEvent), + "NewlySampledNodesEvent" => Ok(Self::NewlySampledNodesEvent), + "NodeRegisteredEvent" => Ok(Self::NodeRegisteredEvent), + "NodeSubscribedToModelEvent" => Ok(Self::NodeSubscribedToModelEvent), + "SettledEvent" => Ok(Self::SettledEvent), + "Text2ImagePromptEvent" => Ok(Self::Text2ImagePromptEvent), + "Text2TextPromptEvent" => Ok(Self::Text2TextPromptEvent), + _ => panic!("Invalid `AtomaEvent` string"), + } + } +} + +impl ToString for AtomaEvent { + fn to_string(&self) -> String { + match self { + Self::DisputeEvent => "DisputeEvent".into(), + Self::FirstSubmissionEvent => "FirstSubmissionEvent".into(), + Self::NewlySampledNodesEvent => "NewlySampledNodesEvent".into(), + Self::NodeRegisteredEvent => "NodeRegisteredEvent".into(), + Self::NodeSubscribedToModelEvent => "NodeSubscribedToModelEvent".into(), + Self::SettledEvent => "SettledEvent".into(), + Self::Text2ImagePromptEvent => "Text2ImagePromptEvent".into(), + Self::Text2TextPromptEvent => "Text2TextPromptEvent".into(), + } + } +} diff --git a/atoma-event-subscribe/sui/src/subscriber.rs b/atoma-event-subscribe/sui/src/subscriber.rs index 3f44dc35..2323e9fd 100644 --- a/atoma-event-subscribe/sui/src/subscriber.rs +++ b/atoma-event-subscribe/sui/src/subscriber.rs @@ -1,3 +1,4 @@ +use std::str::FromStr; use std::{fmt::Write, path::Path, time::Duration}; use futures::StreamExt; @@ -10,6 +11,7 @@ use tokio::sync::mpsc; use tracing::{debug, error, info}; use crate::config::SuiSubscriberConfig; +use crate::AtomaEvent; use atoma_types::{Request, SmallId, NON_SAMPLED_NODE_ERR}; const REQUEST_ID_HEX_SIZE: usize = 64; @@ -87,15 +89,16 @@ impl SuiSubscriber { impl SuiSubscriber { async fn handle_event(&self, event: SuiEvent) -> Result<(), SuiSubscriberError> { - match event.type_.name.as_str() { - "DisputeEvent" => todo!(), - "FirstSubmissionEvent" - | "NodeRegisteredEvent" - | "NodeSubscribedToModelEvent" - | "SettledEvent" => { - info!("Received event: {}", event.type_.name.as_str()); + let event_type = event.type_.name.as_str(); + match AtomaEvent::from_str(event_type)? { + AtomaEvent::DisputeEvent => todo!(), + AtomaEvent::FirstSubmissionEvent + | AtomaEvent::NodeRegisteredEvent + | AtomaEvent::NodeSubscribedToModelEvent + | AtomaEvent::SettledEvent => { + info!("Received event: {}", event_type); } - "NewlySampledNodesEvent" => { + AtomaEvent::NewlySampledNodesEvent => { let event_data = event.parsed_json; match self.handle_newly_sampled_nodes_event(event_data).await { Ok(()) => {} @@ -104,42 +107,27 @@ impl SuiSubscriber { } } } - "Text2TextPromptEvent" => { + AtomaEvent::Text2ImagePromptEvent | AtomaEvent::Text2TextPromptEvent => { let event_data = event.parsed_json; - match self.handle_text2text_prompt_event(event_data).await { + match self.handle_prompt_event(event_data).await { Ok(()) => {} Err(SuiSubscriberError::TypeConversionError(err)) => { if err.to_string().contains(NON_SAMPLED_NODE_ERR) { - info!("Node has not been sampled for current request"); + info!("Node has not been sampled for current request") } else { error!("Failed to process request, with error: {err}") } } Err(err) => { - error!("Failed to process request, with error: {err}"); + error!("Failed to process request, with error: {err}") } } } - "Text2ImagePromptEvent" => { - let event_data = event.parsed_json; - self.handle_text2image_prompt_event(event_data).await?; - } - _ => panic!("Invalid Event type found!"), } Ok(()) } - async fn handle_text2image_prompt_event( - &self, - _event_data: Value, - ) -> Result<(), SuiSubscriberError> { - Ok(()) - } - - async fn handle_text2text_prompt_event( - &self, - event_data: Value, - ) -> Result<(), SuiSubscriberError> { + async fn handle_prompt_event(&self, event_data: Value) -> Result<(), SuiSubscriberError> { debug!("event data: {}", event_data); let request = Request::try_from((self.id, event_data))?; info!("Received new request: {:?}", request); @@ -165,47 +153,17 @@ impl SuiSubscriber { event_data: Value, ) -> Result<(), SuiSubscriberError> { debug!("event data: {}", event_data); - let newly_sampled_nodes = event_data - .get("new_nodes") - .ok_or(SuiSubscriberError::MalformedEvent( - "missing `new_nodes` field".into(), - ))? - .as_array() - .ok_or(SuiSubscriberError::MalformedEvent( - "invalid `new_nodes` field".into(), - ))? - .iter() - .find_map(|n| { - let node_id = n.get("node_id")?.get("inner")?.as_u64()?; - let index = n.get("order")?.as_u64()?; - if node_id == self.id { - Some(index) - } else { - None - } - }); + let newly_sampled_nodes = extract_sampled_node_index(self.id, &event_data)?; if let Some(sampled_node_index) = newly_sampled_nodes { - let ticket_id = event_data - .get("ticket_id") - .ok_or(SuiSubscriberError::MalformedEvent( - "missing `ticket_id` field".into(), - ))? - .as_str() - .ok_or(SuiSubscriberError::MalformedEvent( - "invalid `ticket_id` field".into(), - ))?; + let ticket_id = extract_ticket_id(&event_data)?; + let event_filter = EventFilter::MoveEventField { + path: "ticket_id".to_string(), + value: serde_json::from_str(ticket_id)?, + }; let data = self .sui_client .event_api() - .query_events( - EventFilter::MoveEventField { - path: "ticket_id".to_string(), - value: serde_json::from_str(ticket_id)?, - }, - None, - Some(1), - false, - ) + .query_events(event_filter, None, Some(1), false) .await?; let event = data .data @@ -231,6 +189,36 @@ impl SuiSubscriber { } } +fn extract_sampled_node_index(id: u64, value: &Value) -> Result, SuiSubscriberError> { + let new_nodes = value + .get("new_nodes") + .ok_or_else(|| SuiSubscriberError::MalformedEvent("missing `new_nodes` field".into()))? + .as_array() + .ok_or_else(|| SuiSubscriberError::MalformedEvent("invalid `new_nodes` field".into()))?; + + Ok(new_nodes.iter().find_map(|n| { + let node_id = n.get("node_id")?.get("inner")?.as_u64()?; + let index = n.get("order")?.as_u64()?; + if node_id == id { + Some(index) + } else { + None + } + })) +} + +fn extract_ticket_id(value: &Value) -> Result<&str, SuiSubscriberError> { + value + .get("ticket_id") + .ok_or(SuiSubscriberError::MalformedEvent( + "missing `ticket_id` field".into(), + ))? + .as_str() + .ok_or(SuiSubscriberError::MalformedEvent( + "invalid `ticket_id` field".into(), + )) +} + #[derive(Debug, Error)] pub enum SuiSubscriberError { #[error("Sui Builder error: `{0}`")] diff --git a/atoma-inference/src/models/candle/stable_diffusion.rs b/atoma-inference/src/models/candle/stable_diffusion.rs index 98672f40..f6694353 100644 --- a/atoma-inference/src/models/candle/stable_diffusion.rs +++ b/atoma-inference/src/models/candle/stable_diffusion.rs @@ -46,7 +46,7 @@ pub struct StableDiffusionInput { pub img2img_strength: f64, /// The seed to use when generating random samples. - pub random_seed: Option, + pub random_seed: Option, } pub struct StableDiffusionLoadData { @@ -76,7 +76,7 @@ pub struct StableDiffusion { impl ModelTrait for StableDiffusion { type Input = StableDiffusionInput; - type Output = Vec<(Vec, usize, usize)>; + type Output = (Vec, usize, usize); type LoadData = StableDiffusionLoadData; fn fetch( @@ -271,7 +271,7 @@ impl ModelTrait for StableDiffusion { let scheduler = self.config.build_scheduler(n_steps)?; if let Some(seed) = input.random_seed { - self.device.set_seed(seed)?; + self.device.set_seed(seed as u64)?; } let use_guide_scale = guidance_scale > 1.0; @@ -325,7 +325,7 @@ impl ModelTrait for StableDiffusion { ModelType::StableDiffusionTurbo => 0.13025, _ => bail!("Invalid stable diffusion model type"), }; - let mut res = Vec::new(); + let mut res = (vec![], 0, 0); for idx in 0..input.num_samples { let timesteps = scheduler.timesteps(); @@ -400,8 +400,10 @@ impl ModelTrait for StableDiffusion { save_image(&image, "./image.png").unwrap(); } save_tensor_to_file(&image, "tensor4")?; - res.push(convert_to_image(&image)?); + + res = convert_to_image(&image)?; } + Ok(res) } } @@ -694,8 +696,8 @@ mod tests { let output = model.run(input).expect("Failed to run inference"); println!("{:?}", output); - assert_eq!(output[0].1, 512); - assert_eq!(output[0].2, 512); + assert_eq!(output.1, 512); + assert_eq!(output.2, 512); std::fs::remove_dir_all(cache_dir).unwrap(); std::fs::remove_file("tensor1").unwrap(); diff --git a/atoma-inference/src/models/types.rs b/atoma-inference/src/models/types.rs index 3d0dfc08..de44a565 100644 --- a/atoma-inference/src/models/types.rs +++ b/atoma-inference/src/models/types.rs @@ -443,7 +443,7 @@ pub struct StableDiffusionRequest { pub img2img_strength: f64, /// The seed to use when generating random samples. - pub random_seed: Option, + pub random_seed: Option, pub sampled_nodes: Vec>, } diff --git a/atoma-types/src/lib.rs b/atoma-types/src/lib.rs index 02220fca..a5a9e428 100644 --- a/atoma-types/src/lib.rs +++ b/atoma-types/src/lib.rs @@ -299,7 +299,7 @@ pub struct Text2ImagePromptParams { guidance_scale: Option, img2img: Option, img2img_strength: f64, - random_seed: Option, + random_seed: Option, } impl Text2ImagePromptParams { @@ -315,7 +315,7 @@ impl Text2ImagePromptParams { guidance_scale: Option, img2img: Option, img2img_strength: f64, - random_seed: Option, + random_seed: Option, ) -> Self { Self { prompt, @@ -372,7 +372,7 @@ impl Text2ImagePromptParams { self.img2img_strength } - pub fn random_seed(&self) -> Option { + pub fn random_seed(&self) -> Option { self.random_seed } } @@ -385,14 +385,14 @@ impl TryFrom for Text2ImagePromptParams { prompt: utils::parse_str(&value["prompt"])?, model: utils::parse_str(&value["model"])?, uncond_prompt: utils::parse_str(&value["uncond_prompt"])?, - random_seed: Some(utils::parse_u64(&value["random_seed"])?), + random_seed: Some(utils::parse_u32(&value["random_seed"])?), height: Some(utils::parse_u64(&value["height"])?), width: Some(utils::parse_u64(&value["width"])?), n_steps: Some(utils::parse_u64(&value["n_steps"])?), num_samples: utils::parse_u64(&value["num_samples"])?, guidance_scale: Some(utils::parse_f32_from_le_bytes(&value["guidance_scale"])? as f64), - img2img: Some(utils::parse_str(&value["img2img"])?), - img2img_strength: utils::parse_f32_from_le_bytes(&value["img2img2_strength"])? as f64, + img2img: utils::parse_optional_str(&value["img2img"]), + img2img_strength: utils::parse_f32_from_le_bytes(&value["img2img_strength"])? as f64, }) } } @@ -454,30 +454,46 @@ mod utils { pub(crate) fn parse_f32_from_le_bytes(value: &Value) -> Result { let u32_value: u32 = value .as_u64() - .ok_or(anyhow!( - "Failed to extract `f32` little endian bytes representation" - ))? + .ok_or_else(|| anyhow!("Expected a u64 for f32 conversion, found none"))? .try_into()?; let f32_le_bytes = u32_value.to_le_bytes(); Ok(f32::from_le_bytes(f32_le_bytes)) } + /// Parses an appropriate JSON value, representing a `u32` number, from a Sui + /// `Text2ImagePromptEvent` or `Text2TextPromptEvent` `u32` fields. + pub(crate) fn parse_u32(value: &Value) -> Result { + value + .as_u64() + .ok_or_else(|| anyhow!("Expected a u64 for u32 parsing, found none")) + .and_then(|v| { + v.try_into() + .map_err(|_| anyhow!("u64 to u32 conversion failed")) + }) + } + /// Parses an appropriate JSON value, representing a `u64` number, from a Sui - /// `Text2TextPromptEvent` `u64` fields. + /// `Text2ImagePromptEvent` or `Text2TextPromptEvent` `u64` fields. pub(crate) fn parse_u64(value: &Value) -> Result { value .as_str() - .ok_or(anyhow!("Failed to extract `u64` number"))? - .parse::() - .map_err(|e| anyhow!("Failed to parse `u64` from string, with error: {e}")) + .ok_or_else(|| anyhow!("Expected a string for u64 parsing, found none")) + .and_then(|s| { + s.parse::() + .map_err(|e| anyhow!("Failed to parse u64: {}", e)) + }) } /// Parses an appropriate JSON value, representing a `String` value, from a Sui /// `Text2TextPromptEvent` `String` fields. pub(crate) fn parse_str(value: &Value) -> Result { - Ok(value + value .as_str() - .ok_or(anyhow!("Failed to extract `String` from JSON value"))? - .to_string()) + .ok_or_else(|| anyhow!("Expected a string, found none")) + .map(|s| s.to_string()) + } + + pub(crate) fn parse_optional_str(value: &Value) -> Option { + value.as_str().map(|s| s.to_string()) } }