Skip to content

Commit

Permalink
resolve a set of bugs for stable diffusion
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgeantonio21 committed May 9, 2024
1 parent 7c13e93 commit 708fa72
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 92 deletions.
19 changes: 16 additions & 3 deletions atoma-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use sui_sdk::{
};
use thiserror::Error;
use tokio::sync::mpsc;
use tracing::{debug, info};
use tracing::{debug, error, info};

use crate::config::AtomaSuiClientConfig;

Expand Down Expand Up @@ -64,10 +64,21 @@ impl AtomaSuiClient {
Some(text) => text.as_bytes().to_owned(),
None => {
if let Some(array) = data.as_array() {
array
if !array.is_empty() {
let mut img = array[0].as_array().ok_or(AtomaSuiClientError::MissingOutputData)?
.iter()
.map(|b| b.as_u64().unwrap() as u8)
.collect::<Vec<_>>()
.collect::<Vec<_>>();
let height = data[1].as_u64().unwrap().to_le_bytes();
let width = data[2].as_u64().unwrap().to_le_bytes();
img.extend([height, width].concat());
img
}
else {
error!("Empty image generation");
return Err(AtomaSuiClientError::MissingOutputData);
}

} else {
return Err(AtomaSuiClientError::FailedResponseJsonParsing);
}
Expand Down Expand Up @@ -167,4 +178,6 @@ pub enum AtomaSuiClientError {
InvalidSampledNode,
#[error("Invalid request id")]
InvalidRequestId,
#[error("Missing output data")]
MissingOutputData,
}
48 changes: 48 additions & 0 deletions atoma-event-subscribe/sui/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,50 @@
use std::str::FromStr;

use subscriber::SuiSubscriberError;

pub mod config;
pub mod subscriber;

pub enum AtomaEvent {
DisputeEvent,
FirstSubmissionEvent,
NewlySampledNodesEvent,
NodeRegisteredEvent,
NodeSubscribedToModelEvent,
SettledEvent,
Text2ImagePromptEvent,
Text2TextPromptEvent,
}

impl FromStr for AtomaEvent {
type Err = SuiSubscriberError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"DisputeEvent" => Ok(Self::DisputeEvent),
"FirstSubmissionEvent" => Ok(Self::FirstSubmissionEvent),
"NewlySampledNodesEvent" => Ok(Self::NewlySampledNodesEvent),
"NodeRegisteredEvent" => Ok(Self::NodeRegisteredEvent),
"NodeSubscribedToModelEvent" => Ok(Self::NodeSubscribedToModelEvent),
"SettledEvent" => Ok(Self::SettledEvent),
"Text2ImagePromptEvent" => Ok(Self::Text2ImagePromptEvent),
"Text2TextPromptEvent" => Ok(Self::Text2TextPromptEvent),
_ => panic!("Invalid `AtomaEvent` string"),
}
}
}

impl ToString for AtomaEvent {
fn to_string(&self) -> String {
match self {
Self::DisputeEvent => "DisputeEvent".into(),
Self::FirstSubmissionEvent => "FirstSubmissionEvent".into(),
Self::NewlySampledNodesEvent => "NewlySampledNodesEvent".into(),
Self::NodeRegisteredEvent => "NodeRegisteredEvent".into(),
Self::NodeSubscribedToModelEvent => "NodeSubscribedToModelEvent".into(),
Self::SettledEvent => "SettledEvent".into(),
Self::Text2ImagePromptEvent => "Text2ImagePromptEvent".into(),
Self::Text2TextPromptEvent => "Text2TextPromptEvent".into(),
}
}
}
142 changes: 69 additions & 73 deletions atoma-event-subscribe/sui/src/subscriber.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::str::FromStr;
use std::{fmt::Write, path::Path, time::Duration};

use futures::StreamExt;
Expand All @@ -10,6 +11,7 @@ use tokio::sync::mpsc;
use tracing::{debug, error, info};

use crate::config::SuiSubscriberConfig;
use crate::AtomaEvent;
use atoma_types::{Request, SmallId, NON_SAMPLED_NODE_ERR};

const REQUEST_ID_HEX_SIZE: usize = 64;
Expand Down Expand Up @@ -87,15 +89,16 @@ impl SuiSubscriber {

impl SuiSubscriber {
async fn handle_event(&self, event: SuiEvent) -> Result<(), SuiSubscriberError> {
match event.type_.name.as_str() {
"DisputeEvent" => todo!(),
"FirstSubmissionEvent"
| "NodeRegisteredEvent"
| "NodeSubscribedToModelEvent"
| "SettledEvent" => {
info!("Received event: {}", event.type_.name.as_str());
let event_type = event.type_.name.as_str();
match AtomaEvent::from_str(event_type)? {
AtomaEvent::DisputeEvent => todo!(),
AtomaEvent::FirstSubmissionEvent
| AtomaEvent::NodeRegisteredEvent
| AtomaEvent::NodeSubscribedToModelEvent
| AtomaEvent::SettledEvent => {
info!("Received event: {}", event_type);
}
"NewlySampledNodesEvent" => {
AtomaEvent::NewlySampledNodesEvent => {
let event_data = event.parsed_json;
match self.handle_newly_sampled_nodes_event(event_data).await {
Ok(()) => {}
Expand All @@ -104,42 +107,27 @@ impl SuiSubscriber {
}
}
}
"Text2TextPromptEvent" => {
AtomaEvent::Text2ImagePromptEvent | AtomaEvent::Text2TextPromptEvent => {
let event_data = event.parsed_json;
match self.handle_text2text_prompt_event(event_data).await {
match self.handle_prompt_event(event_data).await {
Ok(()) => {}
Err(SuiSubscriberError::TypeConversionError(err)) => {
if err.to_string().contains(NON_SAMPLED_NODE_ERR) {
info!("Node has not been sampled for current request");
info!("Node has not been sampled for current request")
} else {
error!("Failed to process request, with error: {err}")
}
}
Err(err) => {
error!("Failed to process request, with error: {err}");
error!("Failed to process request, with error: {err}")
}
}
}
"Text2ImagePromptEvent" => {
let event_data = event.parsed_json;
self.handle_text2image_prompt_event(event_data).await?;
}
_ => panic!("Invalid Event type found!"),
}
Ok(())
}

async fn handle_text2image_prompt_event(
&self,
_event_data: Value,
) -> Result<(), SuiSubscriberError> {
Ok(())
}

async fn handle_text2text_prompt_event(
&self,
event_data: Value,
) -> Result<(), SuiSubscriberError> {
async fn handle_prompt_event(&self, event_data: Value) -> Result<(), SuiSubscriberError> {
debug!("event data: {}", event_data);
let request = Request::try_from((self.id, event_data))?;
info!("Received new request: {:?}", request);
Expand All @@ -165,54 +153,11 @@ impl SuiSubscriber {
event_data: Value,
) -> Result<(), SuiSubscriberError> {
debug!("event data: {}", event_data);
let newly_sampled_nodes = event_data
.get("new_nodes")
.ok_or(SuiSubscriberError::MalformedEvent(
"missing `new_nodes` field".into(),
))?
.as_array()
.ok_or(SuiSubscriberError::MalformedEvent(
"invalid `new_nodes` field".into(),
))?
.iter()
.map(|n| {
let node_id = n
.get("node_id")
.ok_or(SuiSubscriberError::MalformedEvent(
"missing `node_id` field".into(),
))?
.get("inner")
.ok_or(SuiSubscriberError::MalformedEvent(
"invalid `inner` field".into(),
))?
.as_u64()
.ok_or(SuiSubscriberError::MalformedEvent(
"invalid `node_id` `inner` field".into(),
))?;
let index = n
.get("order")
.ok_or(SuiSubscriberError::MalformedEvent(
"missing `order` field".into(),
))?
.as_u64()
.ok_or(SuiSubscriberError::MalformedEvent(
"invalid `order` field".into(),
))?;
Ok::<_, SuiSubscriberError>((node_id, index))
})
.collect::<Result<Vec<_>, _>>()?;
let newly_sampled_nodes = extract_newly_sampled_nodes(&event_data)?;
if let Some((_, sampled_node_index)) =
newly_sampled_nodes.iter().find(|(id, _)| id == &self.id)
{
let ticket_id = event_data
.get("ticket_id")
.ok_or(SuiSubscriberError::MalformedEvent(
"missing `ticket_id` field".into(),
))?
.as_str()
.ok_or(SuiSubscriberError::MalformedEvent(
"invalid `ticket_id` field".into(),
))?;
let ticket_id = extract_ticket_id(&event_data)?;
let data = self
.sui_client
.event_api()
Expand Down Expand Up @@ -250,6 +195,57 @@ impl SuiSubscriber {
}
}

fn extract_newly_sampled_nodes(value: &Value) -> Result<Vec<(u64, u64)>, SuiSubscriberError> {
value
.get("new_nodes")
.ok_or(SuiSubscriberError::MalformedEvent(
"missing `new_nodes` field".into(),
))?
.as_array()
.ok_or(SuiSubscriberError::MalformedEvent(
"invalid `new_nodes` field".into(),
))?
.iter()
.map(|n| {
let node_id = n
.get("node_id")
.ok_or(SuiSubscriberError::MalformedEvent(
"missing `node_id` field".into(),
))?
.get("inner")
.ok_or(SuiSubscriberError::MalformedEvent(
"invalid `inner` field".into(),
))?
.as_u64()
.ok_or(SuiSubscriberError::MalformedEvent(
"invalid `node_id` `inner` field".into(),
))?;
let index = n
.get("order")
.ok_or(SuiSubscriberError::MalformedEvent(
"missing `order` field".into(),
))?
.as_u64()
.ok_or(SuiSubscriberError::MalformedEvent(
"invalid `order` field".into(),
))?;
Ok::<_, SuiSubscriberError>((node_id, index))
})
.collect::<Result<Vec<_>, _>>()
}

fn extract_ticket_id(value: &Value) -> Result<&str, SuiSubscriberError> {
value
.get("ticket_id")
.ok_or(SuiSubscriberError::MalformedEvent(
"missing `ticket_id` field".into(),
))?
.as_str()
.ok_or(SuiSubscriberError::MalformedEvent(
"invalid `ticket_id` field".into(),
))
}

#[derive(Debug, Error)]
pub enum SuiSubscriberError {
#[error("Sui Builder error: `{0}`")]
Expand Down
16 changes: 9 additions & 7 deletions atoma-inference/src/models/candle/stable_diffusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub struct StableDiffusionInput {
pub img2img_strength: f64,

/// The seed to use when generating random samples.
pub random_seed: Option<u64>,
pub random_seed: Option<u32>,
}

pub struct StableDiffusionLoadData {
Expand Down Expand Up @@ -76,7 +76,7 @@ pub struct StableDiffusion {

impl ModelTrait for StableDiffusion {
type Input = StableDiffusionInput;
type Output = Vec<(Vec<u8>, usize, usize)>;
type Output = (Vec<u8>, usize, usize);
type LoadData = StableDiffusionLoadData;

fn fetch(
Expand Down Expand Up @@ -271,7 +271,7 @@ impl ModelTrait for StableDiffusion {

let scheduler = self.config.build_scheduler(n_steps)?;
if let Some(seed) = input.random_seed {
self.device.set_seed(seed)?;
self.device.set_seed(seed as u64)?;
}
let use_guide_scale = guidance_scale > 1.0;

Expand Down Expand Up @@ -325,7 +325,7 @@ impl ModelTrait for StableDiffusion {
ModelType::StableDiffusionTurbo => 0.13025,
_ => bail!("Invalid stable diffusion model type"),
};
let mut res = Vec::new();
let mut res = (vec![], 0, 0);

for idx in 0..input.num_samples {
let timesteps = scheduler.timesteps();
Expand Down Expand Up @@ -400,8 +400,10 @@ impl ModelTrait for StableDiffusion {
save_image(&image, "./image.png").unwrap();
}
save_tensor_to_file(&image, "tensor4")?;
res.push(convert_to_image(&image)?);

res = convert_to_image(&image)?;
}

Ok(res)
}
}
Expand Down Expand Up @@ -694,8 +696,8 @@ mod tests {
let output = model.run(input).expect("Failed to run inference");
println!("{:?}", output);

assert_eq!(output[0].1, 512);
assert_eq!(output[0].2, 512);
assert_eq!(output.1, 512);
assert_eq!(output.2, 512);

std::fs::remove_dir_all(cache_dir).unwrap();
std::fs::remove_file("tensor1").unwrap();
Expand Down
2 changes: 1 addition & 1 deletion atoma-inference/src/models/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ pub struct StableDiffusionRequest {
pub img2img_strength: f64,

/// The seed to use when generating random samples.
pub random_seed: Option<u64>,
pub random_seed: Option<u32>,

pub sampled_nodes: Vec<Vec<u8>>,
}
Expand Down
Loading

0 comments on commit 708fa72

Please sign in to comment.