Skip to content

Commit

Permalink
feat: address stable diffusion events (#76)
Browse files Browse the repository at this point in the history
* first commit

* resolve a set of bugs for stable diffusion

* fmt

* remove unnecessary print lines, and improve code readability

* address PR comments
  • Loading branch information
jorgeantonio21 authored May 10, 2024
1 parent 9d740f8 commit dd7829d
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 103 deletions.
64 changes: 50 additions & 14 deletions atoma-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use sui_sdk::{
};
use thiserror::Error;
use tokio::sync::mpsc;
use tracing::{debug, info};
use tracing::{debug, error, info};

use crate::config::AtomaSuiClientConfig;

Expand Down Expand Up @@ -58,22 +58,56 @@ impl AtomaSuiClient {
Self::new_from_config(config, response_rx, output_manager_tx)
}

/// Extracts and processes data from a JSON response to generate a byte vector.
///
/// This method handles two types of data structures within the JSON response:
/// - If the JSON contains a "text" field, it converts the text to a byte vector.
/// - If the JSON contains an array with at least three elements, it interprets:
/// - The first element as an array of bytes (representing an image byte content),
/// - The second element as the image height,
/// - The third element as the image width.
/// These are then combined into a single byte vector where the image data is followed by the height and width.
fn get_data(&self, data: serde_json::Value) -> Result<Vec<u8>, AtomaSuiClientError> {
// TODO: rework this when responses get same structure
let data = match data["text"].as_str() {
Some(text) => text.as_bytes().to_owned(),
None => {
if let Some(array) = data.as_array() {
array
.iter()
.map(|b| b.as_u64().unwrap() as u8)
.collect::<Vec<_>>()
} else {
return Err(AtomaSuiClientError::FailedResponseJsonParsing);
}
if let Some(text) = data["text"].as_str() {
Ok(text.as_bytes().to_owned())
} else if let Some(array) = data.as_array() {
if array.len() < 3 {
error!("Incomplete image data");
return Err(AtomaSuiClientError::MissingOutputData);
}
};
Ok(data)

let img_data = array
.get(0)
.and_then(|img| img.as_array())
.ok_or(AtomaSuiClientError::MissingOutputData)?;
let img = img_data
.iter()
.map(|b| b.as_u64().ok_or(AtomaSuiClientError::MissingOutputData))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.map(|b| b as u8)
.collect::<Vec<_>>();
let height = array
.get(1)
.and_then(|h| h.as_u64())
.ok_or(AtomaSuiClientError::MissingOutputData)?
.to_le_bytes();
let width = array
.get(2)
.and_then(|w| w.as_u64())
.ok_or(AtomaSuiClientError::MissingOutputData)?
.to_le_bytes();

let mut result = img;
result.extend_from_slice(&height);
result.extend_from_slice(&width);

Ok(result)
} else {
error!("Invalid JSON structure for data extraction");
return Err(AtomaSuiClientError::FailedResponseJsonParsing);
}
}

/// Upon receiving a response from the `AtomaNode` service, this method extracts
Expand Down Expand Up @@ -167,4 +201,6 @@ pub enum AtomaSuiClientError {
InvalidSampledNode,
#[error("Invalid request id")]
InvalidRequestId,
#[error("Missing output data")]
MissingOutputData,
}
48 changes: 48 additions & 0 deletions atoma-event-subscribe/sui/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,50 @@
use std::str::FromStr;

use subscriber::SuiSubscriberError;

pub mod config;
pub mod subscriber;

pub enum AtomaEvent {
DisputeEvent,
FirstSubmissionEvent,
NewlySampledNodesEvent,
NodeRegisteredEvent,
NodeSubscribedToModelEvent,
SettledEvent,
Text2ImagePromptEvent,
Text2TextPromptEvent,
}

impl FromStr for AtomaEvent {
type Err = SuiSubscriberError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"DisputeEvent" => Ok(Self::DisputeEvent),
"FirstSubmissionEvent" => Ok(Self::FirstSubmissionEvent),
"NewlySampledNodesEvent" => Ok(Self::NewlySampledNodesEvent),
"NodeRegisteredEvent" => Ok(Self::NodeRegisteredEvent),
"NodeSubscribedToModelEvent" => Ok(Self::NodeSubscribedToModelEvent),
"SettledEvent" => Ok(Self::SettledEvent),
"Text2ImagePromptEvent" => Ok(Self::Text2ImagePromptEvent),
"Text2TextPromptEvent" => Ok(Self::Text2TextPromptEvent),
_ => panic!("Invalid `AtomaEvent` string"),
}
}
}

impl ToString for AtomaEvent {
fn to_string(&self) -> String {
match self {
Self::DisputeEvent => "DisputeEvent".into(),
Self::FirstSubmissionEvent => "FirstSubmissionEvent".into(),
Self::NewlySampledNodesEvent => "NewlySampledNodesEvent".into(),
Self::NodeRegisteredEvent => "NodeRegisteredEvent".into(),
Self::NodeSubscribedToModelEvent => "NodeSubscribedToModelEvent".into(),
Self::SettledEvent => "SettledEvent".into(),
Self::Text2ImagePromptEvent => "Text2ImagePromptEvent".into(),
Self::Text2TextPromptEvent => "Text2TextPromptEvent".into(),
}
}
}
118 changes: 53 additions & 65 deletions atoma-event-subscribe/sui/src/subscriber.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::str::FromStr;
use std::{fmt::Write, path::Path, time::Duration};

use futures::StreamExt;
Expand All @@ -10,6 +11,7 @@ use tokio::sync::mpsc;
use tracing::{debug, error, info};

use crate::config::SuiSubscriberConfig;
use crate::AtomaEvent;
use atoma_types::{Request, SmallId, NON_SAMPLED_NODE_ERR};

const REQUEST_ID_HEX_SIZE: usize = 64;
Expand Down Expand Up @@ -87,15 +89,16 @@ impl SuiSubscriber {

impl SuiSubscriber {
async fn handle_event(&self, event: SuiEvent) -> Result<(), SuiSubscriberError> {
match event.type_.name.as_str() {
"DisputeEvent" => todo!(),
"FirstSubmissionEvent"
| "NodeRegisteredEvent"
| "NodeSubscribedToModelEvent"
| "SettledEvent" => {
info!("Received event: {}", event.type_.name.as_str());
let event_type = event.type_.name.as_str();
match AtomaEvent::from_str(event_type)? {
AtomaEvent::DisputeEvent => todo!(),
AtomaEvent::FirstSubmissionEvent
| AtomaEvent::NodeRegisteredEvent
| AtomaEvent::NodeSubscribedToModelEvent
| AtomaEvent::SettledEvent => {
info!("Received event: {}", event_type);
}
"NewlySampledNodesEvent" => {
AtomaEvent::NewlySampledNodesEvent => {
let event_data = event.parsed_json;
match self.handle_newly_sampled_nodes_event(event_data).await {
Ok(()) => {}
Expand All @@ -104,42 +107,27 @@ impl SuiSubscriber {
}
}
}
"Text2TextPromptEvent" => {
AtomaEvent::Text2ImagePromptEvent | AtomaEvent::Text2TextPromptEvent => {
let event_data = event.parsed_json;
match self.handle_text2text_prompt_event(event_data).await {
match self.handle_prompt_event(event_data).await {
Ok(()) => {}
Err(SuiSubscriberError::TypeConversionError(err)) => {
if err.to_string().contains(NON_SAMPLED_NODE_ERR) {
info!("Node has not been sampled for current request");
info!("Node has not been sampled for current request")
} else {
error!("Failed to process request, with error: {err}")
}
}
Err(err) => {
error!("Failed to process request, with error: {err}");
error!("Failed to process request, with error: {err}")
}
}
}
"Text2ImagePromptEvent" => {
let event_data = event.parsed_json;
self.handle_text2image_prompt_event(event_data).await?;
}
_ => panic!("Invalid Event type found!"),
}
Ok(())
}

async fn handle_text2image_prompt_event(
&self,
_event_data: Value,
) -> Result<(), SuiSubscriberError> {
Ok(())
}

async fn handle_text2text_prompt_event(
&self,
event_data: Value,
) -> Result<(), SuiSubscriberError> {
async fn handle_prompt_event(&self, event_data: Value) -> Result<(), SuiSubscriberError> {
debug!("event data: {}", event_data);
let request = Request::try_from((self.id, event_data))?;
info!("Received new request: {:?}", request);
Expand All @@ -165,47 +153,17 @@ impl SuiSubscriber {
event_data: Value,
) -> Result<(), SuiSubscriberError> {
debug!("event data: {}", event_data);
let newly_sampled_nodes = event_data
.get("new_nodes")
.ok_or(SuiSubscriberError::MalformedEvent(
"missing `new_nodes` field".into(),
))?
.as_array()
.ok_or(SuiSubscriberError::MalformedEvent(
"invalid `new_nodes` field".into(),
))?
.iter()
.find_map(|n| {
let node_id = n.get("node_id")?.get("inner")?.as_u64()?;
let index = n.get("order")?.as_u64()?;
if node_id == self.id {
Some(index)
} else {
None
}
});
let newly_sampled_nodes = extract_sampled_node_index(self.id, &event_data)?;
if let Some(sampled_node_index) = newly_sampled_nodes {
let ticket_id = event_data
.get("ticket_id")
.ok_or(SuiSubscriberError::MalformedEvent(
"missing `ticket_id` field".into(),
))?
.as_str()
.ok_or(SuiSubscriberError::MalformedEvent(
"invalid `ticket_id` field".into(),
))?;
let ticket_id = extract_ticket_id(&event_data)?;
let event_filter = EventFilter::MoveEventField {
path: "ticket_id".to_string(),
value: serde_json::from_str(ticket_id)?,
};
let data = self
.sui_client
.event_api()
.query_events(
EventFilter::MoveEventField {
path: "ticket_id".to_string(),
value: serde_json::from_str(ticket_id)?,
},
None,
Some(1),
false,
)
.query_events(event_filter, None, Some(1), false)
.await?;
let event = data
.data
Expand All @@ -231,6 +189,36 @@ impl SuiSubscriber {
}
}

fn extract_sampled_node_index(id: u64, value: &Value) -> Result<Option<u64>, SuiSubscriberError> {
let new_nodes = value
.get("new_nodes")
.ok_or_else(|| SuiSubscriberError::MalformedEvent("missing `new_nodes` field".into()))?
.as_array()
.ok_or_else(|| SuiSubscriberError::MalformedEvent("invalid `new_nodes` field".into()))?;

Ok(new_nodes.iter().find_map(|n| {
let node_id = n.get("node_id")?.get("inner")?.as_u64()?;
let index = n.get("order")?.as_u64()?;
if node_id == id {
Some(index)
} else {
None
}
}))
}

fn extract_ticket_id(value: &Value) -> Result<&str, SuiSubscriberError> {
value
.get("ticket_id")
.ok_or(SuiSubscriberError::MalformedEvent(
"missing `ticket_id` field".into(),
))?
.as_str()
.ok_or(SuiSubscriberError::MalformedEvent(
"invalid `ticket_id` field".into(),
))
}

#[derive(Debug, Error)]
pub enum SuiSubscriberError {
#[error("Sui Builder error: `{0}`")]
Expand Down
16 changes: 9 additions & 7 deletions atoma-inference/src/models/candle/stable_diffusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub struct StableDiffusionInput {
pub img2img_strength: f64,

/// The seed to use when generating random samples.
pub random_seed: Option<u64>,
pub random_seed: Option<u32>,
}

pub struct StableDiffusionLoadData {
Expand Down Expand Up @@ -76,7 +76,7 @@ pub struct StableDiffusion {

impl ModelTrait for StableDiffusion {
type Input = StableDiffusionInput;
type Output = Vec<(Vec<u8>, usize, usize)>;
type Output = (Vec<u8>, usize, usize);
type LoadData = StableDiffusionLoadData;

fn fetch(
Expand Down Expand Up @@ -271,7 +271,7 @@ impl ModelTrait for StableDiffusion {

let scheduler = self.config.build_scheduler(n_steps)?;
if let Some(seed) = input.random_seed {
self.device.set_seed(seed)?;
self.device.set_seed(seed as u64)?;
}
let use_guide_scale = guidance_scale > 1.0;

Expand Down Expand Up @@ -325,7 +325,7 @@ impl ModelTrait for StableDiffusion {
ModelType::StableDiffusionTurbo => 0.13025,
_ => bail!("Invalid stable diffusion model type"),
};
let mut res = Vec::new();
let mut res = (vec![], 0, 0);

for idx in 0..input.num_samples {
let timesteps = scheduler.timesteps();
Expand Down Expand Up @@ -400,8 +400,10 @@ impl ModelTrait for StableDiffusion {
save_image(&image, "./image.png").unwrap();
}
save_tensor_to_file(&image, "tensor4")?;
res.push(convert_to_image(&image)?);

res = convert_to_image(&image)?;
}

Ok(res)
}
}
Expand Down Expand Up @@ -694,8 +696,8 @@ mod tests {
let output = model.run(input).expect("Failed to run inference");
println!("{:?}", output);

assert_eq!(output[0].1, 512);
assert_eq!(output[0].2, 512);
assert_eq!(output.1, 512);
assert_eq!(output.2, 512);

std::fs::remove_dir_all(cache_dir).unwrap();
std::fs::remove_file("tensor1").unwrap();
Expand Down
2 changes: 1 addition & 1 deletion atoma-inference/src/models/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ pub struct StableDiffusionRequest {
pub img2img_strength: f64,

/// The seed to use when generating random samples.
pub random_seed: Option<u64>,
pub random_seed: Option<u32>,

pub sampled_nodes: Vec<Vec<u8>>,
}
Expand Down
Loading

0 comments on commit dd7829d

Please sign in to comment.