From 708fa72d1fc4317b82ce53a9601093a49a573ba8 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Thu, 9 May 2024 12:02:59 +0100 Subject: [PATCH] resolve a set of bugs for stable diffusion --- atoma-client/src/client.rs | 19 ++- atoma-event-subscribe/sui/src/lib.rs | 48 ++++++ atoma-event-subscribe/sui/src/subscriber.rs | 142 +++++++++--------- .../src/models/candle/stable_diffusion.rs | 16 +- atoma-inference/src/models/types.rs | 2 +- atoma-types/src/lib.rs | 32 +++- 6 files changed, 167 insertions(+), 92 deletions(-) diff --git a/atoma-client/src/client.rs b/atoma-client/src/client.rs index 606b6ea9..588d8f86 100644 --- a/atoma-client/src/client.rs +++ b/atoma-client/src/client.rs @@ -9,7 +9,7 @@ use sui_sdk::{ }; use thiserror::Error; use tokio::sync::mpsc; -use tracing::{debug, info}; +use tracing::{debug, error, info}; use crate::config::AtomaSuiClientConfig; @@ -64,10 +64,21 @@ impl AtomaSuiClient { Some(text) => text.as_bytes().to_owned(), None => { if let Some(array) = data.as_array() { - array + if !array.is_empty() { + let mut img = array[0].as_array().ok_or(AtomaSuiClientError::MissingOutputData)? .iter() .map(|b| b.as_u64().unwrap() as u8) - .collect::>() + .collect::>(); + let height = data[1].as_u64().unwrap().to_le_bytes(); + let width = data[2].as_u64().unwrap().to_le_bytes(); + img.extend([height, width].concat()); + img + } + else { + error!("Empty image generation"); + return Err(AtomaSuiClientError::MissingOutputData); + } + } else { return Err(AtomaSuiClientError::FailedResponseJsonParsing); } @@ -167,4 +178,6 @@ pub enum AtomaSuiClientError { InvalidSampledNode, #[error("Invalid request id")] InvalidRequestId, + #[error("Missing output data")] + MissingOutputData, } diff --git a/atoma-event-subscribe/sui/src/lib.rs b/atoma-event-subscribe/sui/src/lib.rs index 1b8b7a9a..11a1dc65 100644 --- a/atoma-event-subscribe/sui/src/lib.rs +++ b/atoma-event-subscribe/sui/src/lib.rs @@ -1,2 +1,50 @@ +use std::str::FromStr; + +use subscriber::SuiSubscriberError; + pub mod config; pub mod subscriber; + +pub enum AtomaEvent { + DisputeEvent, + FirstSubmissionEvent, + NewlySampledNodesEvent, + NodeRegisteredEvent, + NodeSubscribedToModelEvent, + SettledEvent, + Text2ImagePromptEvent, + Text2TextPromptEvent, +} + +impl FromStr for AtomaEvent { + type Err = SuiSubscriberError; + + fn from_str(s: &str) -> Result { + match s { + "DisputeEvent" => Ok(Self::DisputeEvent), + "FirstSubmissionEvent" => Ok(Self::FirstSubmissionEvent), + "NewlySampledNodesEvent" => Ok(Self::NewlySampledNodesEvent), + "NodeRegisteredEvent" => Ok(Self::NodeRegisteredEvent), + "NodeSubscribedToModelEvent" => Ok(Self::NodeSubscribedToModelEvent), + "SettledEvent" => Ok(Self::SettledEvent), + "Text2ImagePromptEvent" => Ok(Self::Text2ImagePromptEvent), + "Text2TextPromptEvent" => Ok(Self::Text2TextPromptEvent), + _ => panic!("Invalid `AtomaEvent` string"), + } + } +} + +impl ToString for AtomaEvent { + fn to_string(&self) -> String { + match self { + Self::DisputeEvent => "DisputeEvent".into(), + Self::FirstSubmissionEvent => "FirstSubmissionEvent".into(), + Self::NewlySampledNodesEvent => "NewlySampledNodesEvent".into(), + Self::NodeRegisteredEvent => "NodeRegisteredEvent".into(), + Self::NodeSubscribedToModelEvent => "NodeSubscribedToModelEvent".into(), + Self::SettledEvent => "SettledEvent".into(), + Self::Text2ImagePromptEvent => "Text2ImagePromptEvent".into(), + Self::Text2TextPromptEvent => "Text2TextPromptEvent".into(), + } + } +} diff --git a/atoma-event-subscribe/sui/src/subscriber.rs b/atoma-event-subscribe/sui/src/subscriber.rs index a234912c..96108993 100644 --- a/atoma-event-subscribe/sui/src/subscriber.rs +++ b/atoma-event-subscribe/sui/src/subscriber.rs @@ -1,3 +1,4 @@ +use std::str::FromStr; use std::{fmt::Write, path::Path, time::Duration}; use futures::StreamExt; @@ -10,6 +11,7 @@ use tokio::sync::mpsc; use tracing::{debug, error, info}; use crate::config::SuiSubscriberConfig; +use crate::AtomaEvent; use atoma_types::{Request, SmallId, NON_SAMPLED_NODE_ERR}; const REQUEST_ID_HEX_SIZE: usize = 64; @@ -87,15 +89,16 @@ impl SuiSubscriber { impl SuiSubscriber { async fn handle_event(&self, event: SuiEvent) -> Result<(), SuiSubscriberError> { - match event.type_.name.as_str() { - "DisputeEvent" => todo!(), - "FirstSubmissionEvent" - | "NodeRegisteredEvent" - | "NodeSubscribedToModelEvent" - | "SettledEvent" => { - info!("Received event: {}", event.type_.name.as_str()); + let event_type = event.type_.name.as_str(); + match AtomaEvent::from_str(event_type)? { + AtomaEvent::DisputeEvent => todo!(), + AtomaEvent::FirstSubmissionEvent + | AtomaEvent::NodeRegisteredEvent + | AtomaEvent::NodeSubscribedToModelEvent + | AtomaEvent::SettledEvent => { + info!("Received event: {}", event_type); } - "NewlySampledNodesEvent" => { + AtomaEvent::NewlySampledNodesEvent => { let event_data = event.parsed_json; match self.handle_newly_sampled_nodes_event(event_data).await { Ok(()) => {} @@ -104,42 +107,27 @@ impl SuiSubscriber { } } } - "Text2TextPromptEvent" => { + AtomaEvent::Text2ImagePromptEvent | AtomaEvent::Text2TextPromptEvent => { let event_data = event.parsed_json; - match self.handle_text2text_prompt_event(event_data).await { + match self.handle_prompt_event(event_data).await { Ok(()) => {} Err(SuiSubscriberError::TypeConversionError(err)) => { if err.to_string().contains(NON_SAMPLED_NODE_ERR) { - info!("Node has not been sampled for current request"); + info!("Node has not been sampled for current request") } else { error!("Failed to process request, with error: {err}") } } Err(err) => { - error!("Failed to process request, with error: {err}"); + error!("Failed to process request, with error: {err}") } } } - "Text2ImagePromptEvent" => { - let event_data = event.parsed_json; - self.handle_text2image_prompt_event(event_data).await?; - } - _ => panic!("Invalid Event type found!"), } Ok(()) } - async fn handle_text2image_prompt_event( - &self, - _event_data: Value, - ) -> Result<(), SuiSubscriberError> { - Ok(()) - } - - async fn handle_text2text_prompt_event( - &self, - event_data: Value, - ) -> Result<(), SuiSubscriberError> { + async fn handle_prompt_event(&self, event_data: Value) -> Result<(), SuiSubscriberError> { debug!("event data: {}", event_data); let request = Request::try_from((self.id, event_data))?; info!("Received new request: {:?}", request); @@ -165,54 +153,11 @@ impl SuiSubscriber { event_data: Value, ) -> Result<(), SuiSubscriberError> { debug!("event data: {}", event_data); - let newly_sampled_nodes = event_data - .get("new_nodes") - .ok_or(SuiSubscriberError::MalformedEvent( - "missing `new_nodes` field".into(), - ))? - .as_array() - .ok_or(SuiSubscriberError::MalformedEvent( - "invalid `new_nodes` field".into(), - ))? - .iter() - .map(|n| { - let node_id = n - .get("node_id") - .ok_or(SuiSubscriberError::MalformedEvent( - "missing `node_id` field".into(), - ))? - .get("inner") - .ok_or(SuiSubscriberError::MalformedEvent( - "invalid `inner` field".into(), - ))? - .as_u64() - .ok_or(SuiSubscriberError::MalformedEvent( - "invalid `node_id` `inner` field".into(), - ))?; - let index = n - .get("order") - .ok_or(SuiSubscriberError::MalformedEvent( - "missing `order` field".into(), - ))? - .as_u64() - .ok_or(SuiSubscriberError::MalformedEvent( - "invalid `order` field".into(), - ))?; - Ok::<_, SuiSubscriberError>((node_id, index)) - }) - .collect::, _>>()?; + let newly_sampled_nodes = extract_newly_sampled_nodes(&event_data)?; if let Some((_, sampled_node_index)) = newly_sampled_nodes.iter().find(|(id, _)| id == &self.id) { - let ticket_id = event_data - .get("ticket_id") - .ok_or(SuiSubscriberError::MalformedEvent( - "missing `ticket_id` field".into(), - ))? - .as_str() - .ok_or(SuiSubscriberError::MalformedEvent( - "invalid `ticket_id` field".into(), - ))?; + let ticket_id = extract_ticket_id(&event_data)?; let data = self .sui_client .event_api() @@ -250,6 +195,57 @@ impl SuiSubscriber { } } +fn extract_newly_sampled_nodes(value: &Value) -> Result, SuiSubscriberError> { + value + .get("new_nodes") + .ok_or(SuiSubscriberError::MalformedEvent( + "missing `new_nodes` field".into(), + ))? + .as_array() + .ok_or(SuiSubscriberError::MalformedEvent( + "invalid `new_nodes` field".into(), + ))? + .iter() + .map(|n| { + let node_id = n + .get("node_id") + .ok_or(SuiSubscriberError::MalformedEvent( + "missing `node_id` field".into(), + ))? + .get("inner") + .ok_or(SuiSubscriberError::MalformedEvent( + "invalid `inner` field".into(), + ))? + .as_u64() + .ok_or(SuiSubscriberError::MalformedEvent( + "invalid `node_id` `inner` field".into(), + ))?; + let index = n + .get("order") + .ok_or(SuiSubscriberError::MalformedEvent( + "missing `order` field".into(), + ))? + .as_u64() + .ok_or(SuiSubscriberError::MalformedEvent( + "invalid `order` field".into(), + ))?; + Ok::<_, SuiSubscriberError>((node_id, index)) + }) + .collect::, _>>() +} + +fn extract_ticket_id(value: &Value) -> Result<&str, SuiSubscriberError> { + value + .get("ticket_id") + .ok_or(SuiSubscriberError::MalformedEvent( + "missing `ticket_id` field".into(), + ))? + .as_str() + .ok_or(SuiSubscriberError::MalformedEvent( + "invalid `ticket_id` field".into(), + )) +} + #[derive(Debug, Error)] pub enum SuiSubscriberError { #[error("Sui Builder error: `{0}`")] diff --git a/atoma-inference/src/models/candle/stable_diffusion.rs b/atoma-inference/src/models/candle/stable_diffusion.rs index 98672f40..f6694353 100644 --- a/atoma-inference/src/models/candle/stable_diffusion.rs +++ b/atoma-inference/src/models/candle/stable_diffusion.rs @@ -46,7 +46,7 @@ pub struct StableDiffusionInput { pub img2img_strength: f64, /// The seed to use when generating random samples. - pub random_seed: Option, + pub random_seed: Option, } pub struct StableDiffusionLoadData { @@ -76,7 +76,7 @@ pub struct StableDiffusion { impl ModelTrait for StableDiffusion { type Input = StableDiffusionInput; - type Output = Vec<(Vec, usize, usize)>; + type Output = (Vec, usize, usize); type LoadData = StableDiffusionLoadData; fn fetch( @@ -271,7 +271,7 @@ impl ModelTrait for StableDiffusion { let scheduler = self.config.build_scheduler(n_steps)?; if let Some(seed) = input.random_seed { - self.device.set_seed(seed)?; + self.device.set_seed(seed as u64)?; } let use_guide_scale = guidance_scale > 1.0; @@ -325,7 +325,7 @@ impl ModelTrait for StableDiffusion { ModelType::StableDiffusionTurbo => 0.13025, _ => bail!("Invalid stable diffusion model type"), }; - let mut res = Vec::new(); + let mut res = (vec![], 0, 0); for idx in 0..input.num_samples { let timesteps = scheduler.timesteps(); @@ -400,8 +400,10 @@ impl ModelTrait for StableDiffusion { save_image(&image, "./image.png").unwrap(); } save_tensor_to_file(&image, "tensor4")?; - res.push(convert_to_image(&image)?); + + res = convert_to_image(&image)?; } + Ok(res) } } @@ -694,8 +696,8 @@ mod tests { let output = model.run(input).expect("Failed to run inference"); println!("{:?}", output); - assert_eq!(output[0].1, 512); - assert_eq!(output[0].2, 512); + assert_eq!(output.1, 512); + assert_eq!(output.2, 512); std::fs::remove_dir_all(cache_dir).unwrap(); std::fs::remove_file("tensor1").unwrap(); diff --git a/atoma-inference/src/models/types.rs b/atoma-inference/src/models/types.rs index 3d0dfc08..de44a565 100644 --- a/atoma-inference/src/models/types.rs +++ b/atoma-inference/src/models/types.rs @@ -443,7 +443,7 @@ pub struct StableDiffusionRequest { pub img2img_strength: f64, /// The seed to use when generating random samples. - pub random_seed: Option, + pub random_seed: Option, pub sampled_nodes: Vec>, } diff --git a/atoma-types/src/lib.rs b/atoma-types/src/lib.rs index 02220fca..79357f1b 100644 --- a/atoma-types/src/lib.rs +++ b/atoma-types/src/lib.rs @@ -299,7 +299,7 @@ pub struct Text2ImagePromptParams { guidance_scale: Option, img2img: Option, img2img_strength: f64, - random_seed: Option, + random_seed: Option, } impl Text2ImagePromptParams { @@ -315,7 +315,7 @@ impl Text2ImagePromptParams { guidance_scale: Option, img2img: Option, img2img_strength: f64, - random_seed: Option, + random_seed: Option, ) -> Self { Self { prompt, @@ -372,7 +372,7 @@ impl Text2ImagePromptParams { self.img2img_strength } - pub fn random_seed(&self) -> Option { + pub fn random_seed(&self) -> Option { self.random_seed } } @@ -381,18 +381,19 @@ impl TryFrom for Text2ImagePromptParams { type Error = Error; fn try_from(value: Value) -> Result { + println!("got value:{value}"); Ok(Self { prompt: utils::parse_str(&value["prompt"])?, model: utils::parse_str(&value["model"])?, uncond_prompt: utils::parse_str(&value["uncond_prompt"])?, - random_seed: Some(utils::parse_u64(&value["random_seed"])?), + random_seed: Some(utils::parse_u32(&value["random_seed"])?), height: Some(utils::parse_u64(&value["height"])?), width: Some(utils::parse_u64(&value["width"])?), n_steps: Some(utils::parse_u64(&value["n_steps"])?), num_samples: utils::parse_u64(&value["num_samples"])?, guidance_scale: Some(utils::parse_f32_from_le_bytes(&value["guidance_scale"])? as f64), - img2img: Some(utils::parse_str(&value["img2img"])?), - img2img_strength: utils::parse_f32_from_le_bytes(&value["img2img2_strength"])? as f64, + img2img: utils::parse_optional_str(&value["img2img"]), + img2img_strength: utils::parse_f32_from_le_bytes(&value["img2img_strength"])? as f64, }) } } @@ -452,6 +453,7 @@ mod utils { /// representing the extracted u32 value into little endian byte representation, and then applying `f32::from_le_bytes`. /// See https://github.com/atoma-network/atoma-contracts/blob/main/sui/packages/atoma/sources/gate.move#L26 pub(crate) fn parse_f32_from_le_bytes(value: &Value) -> Result { + println!("value got: {value}"); let u32_value: u32 = value .as_u64() .ok_or(anyhow!( @@ -462,14 +464,24 @@ mod utils { Ok(f32::from_le_bytes(f32_le_bytes)) } + /// Parses an appropriate JSON value, representing a `u32` number, from a Sui + /// `Text2ImagePromptEvent` or `Text2TextPromptEvent` `u32` fields. + pub(crate) fn parse_u32(value: &Value) -> Result { + value + .as_u64() + .ok_or(anyhow!("Failed to extract `u32` number"))? + .try_into() + .map_err(|e| anyhow!("Failed to parse `u32` from `u64`, with error: {e}")) + } + /// Parses an appropriate JSON value, representing a `u64` number, from a Sui - /// `Text2TextPromptEvent` `u64` fields. + /// `Text2ImagePromptEvent` or `Text2TextPromptEvent` `u64` fields. pub(crate) fn parse_u64(value: &Value) -> Result { value .as_str() .ok_or(anyhow!("Failed to extract `u64` number"))? .parse::() - .map_err(|e| anyhow!("Failed to parse `u64` from string, with error: {e}")) + .map_err(|e| anyhow!("Failed to parse `u64` from `String`, with error: {e}")) } /// Parses an appropriate JSON value, representing a `String` value, from a Sui @@ -480,4 +492,8 @@ mod utils { .ok_or(anyhow!("Failed to extract `String` from JSON value"))? .to_string()) } + + pub(crate) fn parse_optional_str(value: &Value) -> Option { + value.as_str().map(|s| s.to_string()) + } }