handle different events #69

Merged · 6 commits · May 6, 2024
2 changes: 1 addition & 1 deletion atoma-client/src/client.rs
@@ -140,7 +140,7 @@ impl AtomaSuiClient {
for event in events.data.iter() {
let event_value = &event.parsed_json;
if let Some(true) = event_value["is_first_submission"].as_bool() {
let _ = self.output_manager_tx.send((tx_digest, response)).await?;
self.output_manager_tx.send((tx_digest, response)).await?;
break; // we don't need to check other events, as at this point the node knows it has been selected for
}
}
94 changes: 66 additions & 28 deletions atoma-event-subscribe/sui/src/subscriber.rs
@@ -1,7 +1,8 @@
use std::{path::Path, time::Duration};
use std::{fmt::Write, path::Path, time::Duration};

use futures::StreamExt;
use sui_sdk::rpc_types::EventFilter;
use serde_json::Value;
use sui_sdk::rpc_types::{EventFilter, SuiEvent};
use sui_sdk::types::base_types::{ObjectID, ObjectIDParseError};
use sui_sdk::{SuiClient, SuiClientBuilder};
use thiserror::Error;
@@ -11,6 +12,8 @@ use tracing::{debug, error, info};
use crate::config::SuiSubscriberConfig;
use atoma_types::{Request, SmallId};

const REQUEST_ID_HEX_SIZE: usize = 64;

pub struct SuiSubscriber {
id: SmallId,
sui_client: SuiClient,
@@ -68,34 +71,11 @@ impl SuiSubscriber {

pub async fn subscribe(self) -> Result<(), SuiSubscriberError> {
let event_api = self.sui_client.event_api();
let mut subscribe_event = event_api.subscribe_event(self.filter).await?;
let mut subscribe_event = event_api.subscribe_event(self.filter.clone()).await?;
info!("Starting event while loop");
while let Some(event) = subscribe_event.next().await {
match event {
Ok(event) => {
let event_data = event.parsed_json;
if event_data["is_first_submission"].as_bool().is_some() {
continue;
}
debug!("event data: {}", event_data);
let request = Request::try_from(event_data)?;
info!("Received new request: {:?}", request);
let request_id = request
.id()
.iter()
.map(|b| format!("{:02x}", b))
.collect::<String>();
let sampled_nodes = request.sampled_nodes();
if sampled_nodes.contains(&self.id) {
info!(
"Current node has been sampled for request with id: {}",
request_id
);
self.event_sender.send(request).await?;
} else {
info!("Current node has not been sampled for request with id: {}, ignoring it..", request_id);
}
}
Ok(event) => self.handle_event(event).await?,
Err(e) => {
error!("Failed to get event with error: {e}");
}
@@ -105,6 +85,64 @@ impl SuiSubscriber {
}
}

impl SuiSubscriber {
async fn handle_event(&self, event: SuiEvent) -> Result<(), SuiSubscriberError> {
match event.type_.name.as_str() {
"DisputeEvent" => todo!(),
"FirstSubmission" | "NodeRegisteredEvent" | "NodeSubscribedToModelEvent" => {}
"Text2TextPromptEvent" | "NewlySampledNodesEvent" => {
Collaborator: What's NewlySampledNodesEvent and why does it trigger the text2text event?

Contributor (author): If a node times out for a request it has become unresponsive, so we resample new nodes for that request; the newly sampled node then handles the same text2text prompt.

let event_data = event.parsed_json;
self.handle_text2text_prompt_event(event_data).await?;
}
"Text2ImagePromptEvent" => {
let event_data = event.parsed_json;
self.handle_text2image_prompt_event(event_data).await?;
}
_ => panic!("Invalid Event type found!"),
}
Ok(())
}

async fn handle_text2image_prompt_event(
&self,
_event_data: Value,
) -> Result<(), SuiSubscriberError> {
Ok(())
}

async fn handle_text2text_prompt_event(
&self,
event_data: Value,
) -> Result<(), SuiSubscriberError> {
debug!("event data: {}", event_data);
let request = Request::try_from(event_data)?;
info!("Received new request: {:?}", request);
let request_id =
request
.id()
.iter()
.fold(String::with_capacity(REQUEST_ID_HEX_SIZE), |mut acc, &b| {
write!(acc, "{:02x}", b).expect("Failed to write to request_id");
acc
});
info!("request_id: {request_id}");
let sampled_nodes = request.sampled_nodes();
if sampled_nodes.contains(&self.id) {
info!(
"Current node has been sampled for request with id: {}",
request_id
);
self.event_sender.send(request).await.map_err(Box::new)?;
} else {
info!(
"Current node has not been sampled for request with id: {}, ignoring it..",
request_id
);
}
Ok(())
}
}
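
For context, a minimal self-contained sketch of the hex-encoding fold used in handle_text2text_prompt_event above; the helper name hex_id and the demo bytes are illustrative and not part of the PR:

```rust
use std::fmt::Write;

// Pre-allocating the String (a 32-byte id becomes 64 hex characters) and
// appending with `write!` avoids the per-byte intermediate Strings that the
// previous `map(...).collect::<String>()` produced.
fn hex_id(bytes: &[u8]) -> String {
    bytes
        .iter()
        .fold(String::with_capacity(2 * bytes.len()), |mut acc, &b| {
            write!(acc, "{:02x}", b).expect("writing to a String cannot fail");
            acc
        })
}

fn main() {
    assert_eq!(hex_id(&[0xde, 0xad, 0xbe, 0xef]), "deadbeef");
}
```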

#[derive(Debug, Error)]
pub enum SuiSubscriberError {
#[error("Sui Builder error: `{0}`")]
@@ -114,7 +152,7 @@ pub enum SuiSubscriberError {
#[error("Object ID parse error: `{0}`")]
ObjectIDParseError(#[from] ObjectIDParseError),
#[error("Sender error: `{0}`")]
SendError(#[from] mpsc::error::SendError<Request>),
SendError(#[from] Box<mpsc::error::SendError<Request>>),
#[error("Type conversion error: `{0}`")]
TypeConversionError(#[from] anyhow::Error),
}
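
The boxed SendError variant above works together with the `.map_err(Box::new)?` call in handle_text2text_prompt_event. A minimal sketch of the pattern, with a hypothetical Request stand-in and function name (only the boxing itself comes from this diff):

```rust
use thiserror::Error;
use tokio::sync::mpsc;

// Hypothetical stand-in for atoma_types::Request.
#[derive(Debug)]
pub struct Request(pub Vec<u8>);

#[derive(Debug, Error)]
pub enum SubscriberError {
    // SendError<Request> hands the unsent Request back by value; boxing it
    // keeps the error enum small instead of growing every variant to that size.
    #[error("Sender error: `{0}`")]
    SendError(#[from] Box<mpsc::error::SendError<Request>>),
}

pub async fn forward(tx: &mpsc::Sender<Request>, req: Request) -> Result<(), SubscriberError> {
    // `#[from]` converts only the boxed type, so the raw SendError is boxed
    // explicitly at the call site, mirroring `.map_err(Box::new)?` above.
    tx.send(req).await.map_err(Box::new)?;
    Ok(())
}
```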
20 changes: 4 additions & 16 deletions atoma-inference/src/model_thread.rs
@@ -30,27 +30,15 @@ pub struct ModelThreadCommand {
#[derive(Debug, Error)]
pub enum ModelThreadError {
#[error("Model thread shutdown: `{0}`")]
ApiError(ApiError),
ApiError(#[from] ApiError),
#[error("Model thread shutdown: `{0}`")]
ModelError(ModelError),
ModelError(#[from] ModelError),
#[error("Core thread shutdown: `{0}`")]
Shutdown(RecvError),
#[error("Serde error: `{0}`")]
SerdeError(#[from] serde_json::Error),
}

impl From<ModelError> for ModelThreadError {
fn from(error: ModelError) -> Self {
Self::ModelError(error)
}
}

impl From<ApiError> for ModelThreadError {
fn from(error: ApiError) -> Self {
Self::ApiError(error)
}
}
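
The hand-written From impls deleted above are now derived by thiserror's #[from] attribute; a minimal sketch of the equivalence, using hypothetical stand-ins for the crate's error types:

```rust
use thiserror::Error;

// Hypothetical stand-ins for the crate's ApiError and ModelError.
#[derive(Debug, Error)]
#[error("api failure")]
pub struct ApiError;

#[derive(Debug, Error)]
#[error("model failure")]
pub struct ModelError;

#[derive(Debug, Error)]
pub enum ModelThreadError {
    #[error("Model thread shutdown: `{0}`")]
    ApiError(#[from] ApiError),
    #[error("Model thread shutdown: `{0}`")]
    ModelError(#[from] ModelError),
}

fn fallible_model_step() -> Result<(), ModelError> {
    Err(ModelError)
}

pub fn run_model() -> Result<(), ModelThreadError> {
    // `?` converts ModelError into ModelThreadError through the From impl
    // that `#[from]` derives, replacing the impls removed in this diff.
    fallible_model_step()?;
    Ok(())
}
```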

pub struct ModelThreadHandle {
sender: mpsc::Sender<ModelThreadCommand>,
join_handle: std::thread::JoinHandle<Result<(), ModelThreadError>>,
@@ -79,8 +67,8 @@
let ModelThreadCommand { request, sender } = command;
let request_id = request.id();
let sampled_nodes = request.sampled_nodes();
let body = request.body();
let model_input = serde_json::from_value(body)?;
let params = request.params();
let model_input = M::Input::try_from(params)?;
let model_output = self.model.run(model_input)?;
let output = serde_json::to_value(model_output)?;
let response = Response::new(request_id, sampled_nodes, output);
7 changes: 5 additions & 2 deletions atoma-inference/src/models/mod.rs
@@ -1,7 +1,8 @@
use std::path::PathBuf;

use ::candle::{DTypeParseError, Error as CandleError};
use serde::{de::DeserializeOwned, Serialize};
use atoma_types::PromptParams;
use serde::Serialize;
use thiserror::Error;

use self::{config::ModelConfig, types::ModelType};
@@ -14,7 +15,7 @@ pub mod types;
pub type ModelId = String;

pub trait ModelTrait {
type Input: DeserializeOwned;
type Input: TryFrom<PromptParams, Error = ModelError>;
type Output: Serialize;
type LoadData;

@@ -66,6 +67,8 @@ pub enum ModelError {
DTypeParseError(#[from] DTypeParseError),
#[error("Invalid model type: `{0}`")]
InvalidModelType(String),
#[error("Invalid model input")]
InvalidModelInput,
}

#[macro_export]
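With Input now bound by TryFrom<PromptParams, Error = ModelError> instead of DeserializeOwned, implementors build their input from typed prompt parameters rather than raw JSON. A minimal sketch of what the bound asks for, written as if it lived in this module; EchoInput is hypothetical, while the variant and accessor names follow the impls added to types.rs below:

```rust
use atoma_types::PromptParams;

// Hypothetical input type for a simple text model.
pub struct EchoInput {
    pub prompt: String,
}

impl TryFrom<PromptParams> for EchoInput {
    type Error = ModelError;

    fn try_from(params: PromptParams) -> Result<Self, Self::Error> {
        match params {
            // Text prompts carry everything a text model needs.
            PromptParams::Text2TextPromptParams(p) => Ok(Self { prompt: p.prompt() }),
            // An image prompt cannot drive a text model, so reject it.
            _ => Err(ModelError::InvalidModelInput),
        }
    }
}
```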
44 changes: 44 additions & 0 deletions atoma-inference/src/models/types.rs
@@ -1,5 +1,6 @@
use std::{fmt::Display, path::PathBuf, str::FromStr};

use atoma_types::PromptParams;
use candle::{DType, Device};
use serde::{Deserialize, Serialize};

@@ -359,6 +360,26 @@ impl TextModelInput {
}
}

impl TryFrom<PromptParams> for TextModelInput {
type Error = ModelError;

fn try_from(value: PromptParams) -> Result<Self, Self::Error> {
match value {
PromptParams::Text2TextPromptParams(p) => Ok(Self {
prompt: p.prompt(),
temperature: p.temperature(),
random_seed: p.random_seed(),
repeat_penalty: p.repeat_penalty(),
repeat_last_n: p.repeat_last_n().try_into().unwrap(),
max_tokens: p.max_tokens().try_into().unwrap(),
top_k: p.top_k().map(|t| t.try_into().unwrap()),
top_p: p.top_p(),
}),
PromptParams::Text2ImagePromptParams(_) => Err(ModelError::InvalidModelInput),
}
}
}

#[derive(Serialize)]
pub struct TextModelOutput {
pub text: String,
@@ -455,6 +476,29 @@ impl Request for StableDiffusionRequest {
}
}

impl TryFrom<PromptParams> for StableDiffusionInput {
type Error = ModelError;

fn try_from(value: PromptParams) -> Result<Self, Self::Error> {
match value {
PromptParams::Text2ImagePromptParams(p) => Ok(Self {
prompt: p.prompt(),
uncond_prompt: p.uncond_prompt(),
height: p.height().map(|t| t.try_into().unwrap()),
width: p.width().map(|t| t.try_into().unwrap()),
n_steps: p.n_steps().map(|t| t.try_into().unwrap()),
num_samples: p.num_samples() as i64,
model: p.model(),
guidance_scale: p.guidance_scale(),
img2img: p.img2img(),
img2img_strength: p.img2img_strength(),
random_seed: p.random_seed(),
}),
_ => Err(ModelError::InvalidModelInput),
}
}
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct StableDiffusionResponse {
pub output: Vec<(Vec<u8>, usize, usize)>,
15 changes: 13 additions & 2 deletions atoma-inference/src/service.rs
@@ -120,10 +120,11 @@ pub enum ModelServiceError {

#[cfg(test)]
mod tests {
use atoma_types::PromptParams;
use std::io::Write;
use toml::{toml, Value};

use crate::models::{config::ModelConfig, ModelTrait, Request, Response};
use crate::models::{config::ModelConfig, ModelError, ModelTrait, Request, Response};

use super::*;

@@ -150,8 +151,18 @@
#[derive(Clone)]
struct TestModelInstance {}

struct MockInput {}

impl TryFrom<PromptParams> for MockInput {
type Error = ModelError;

fn try_from(_: PromptParams) -> Result<Self, Self::Error> {
Ok(Self {})
}
}

impl ModelTrait for TestModelInstance {
type Input = ();
type Input = MockInput;
type Output = ();
type LoadData = ();

42 changes: 36 additions & 6 deletions atoma-inference/src/tests/mod.rs
@@ -2,11 +2,13 @@ use crate::models::{config::ModelConfig, types::ModelType, ModelError, ModelTrai
use std::{path::PathBuf, time::Duration};

mod prompts;
use atoma_types::Text2TextPromptParams;
use prompts::PROMPTS;
use serde::Serialize;

use std::{collections::HashMap, sync::mpsc};

use atoma_types::Request;
use atoma_types::{PromptParams, Request};
use futures::{stream::FuturesUnordered, StreamExt};
use reqwest::Client;
use serde_json::json;
@@ -24,9 +26,24 @@ struct TestModel {
duration: Duration,
}

#[derive(Debug, Serialize)]
struct MockInputOutput {
id: u64,
}

impl TryFrom<PromptParams> for MockInputOutput {
type Error = ModelError;

fn try_from(value: PromptParams) -> Result<Self, Self::Error> {
Ok(Self {
id: value.into_text2text_prompt_params().unwrap().max_tokens(),
})
}
}

impl ModelTrait for TestModel {
type Input = Value;
type Output = Value;
type Input = MockInputOutput;
type Output = MockInputOutput;
type LoadData = Duration;

fn fetch(
@@ -51,7 +68,7 @@ impl ModelTrait for TestModel {
fn run(&mut self, input: Self::Input) -> Result<Self::Output, ModelError> {
std::thread::sleep(self.duration);
println!(
"Finished waiting time for {:?} and input = {}",
"Finished waiting time for {:?} and input = {:?}",
self.duration, input
);
Ok(input)
@@ -100,14 +117,27 @@ async fn test_mock_model_thread() {
for i in 0..NUM_REQUESTS {
for sender in model_thread_dispatcher.model_senders.values() {
let (response_sender, response_receiver) = oneshot::channel();
let request = Request::new(vec![0], vec![], json!(i));
let max_tokens = i as u64;
let prompt_params = PromptParams::Text2TextPromptParams(Text2TextPromptParams::new(
"".to_string(),
"".to_string(),
0.0,
1,
1.0,
0,
max_tokens,
Some(0),
Some(1.0),
));
let request = Request::new(vec![0], vec![], prompt_params);
let command = ModelThreadCommand {
request: request.clone(),
sender: response_sender,
};
sender.send(command).expect("Failed to send command");
responses.push(response_receiver);
should_be_received_responses.push(request.body().as_u64().unwrap());
should_be_received_responses
.push(MockInputOutput::try_from(request.params()).unwrap().id);
}
}
