diff --git a/atoma-client/src/client.rs b/atoma-client/src/client.rs
index fa996ce3..eacc4a0f 100644
--- a/atoma-client/src/client.rs
+++ b/atoma-client/src/client.rs
@@ -140,7 +140,7 @@ impl AtomaSuiClient {
         for event in events.data.iter() {
             let event_value = &event.parsed_json;
             if let Some(true) = event_value["is_first_submission"].as_bool() {
-                let _ = self.output_manager_tx.send((tx_digest, response)).await?;
+                self.output_manager_tx.send((tx_digest, response)).await?;
                 break; // we don't need to check other events, as at this point the node knows it has been selected for this request
             }
         }
diff --git a/atoma-event-subscribe/sui/src/subscriber.rs b/atoma-event-subscribe/sui/src/subscriber.rs
index 55a2a258..7a91e7a2 100644
--- a/atoma-event-subscribe/sui/src/subscriber.rs
+++ b/atoma-event-subscribe/sui/src/subscriber.rs
@@ -1,7 +1,8 @@
-use std::{path::Path, time::Duration};
+use std::{fmt::Write, path::Path, time::Duration};

 use futures::StreamExt;
-use sui_sdk::rpc_types::EventFilter;
+use serde_json::Value;
+use sui_sdk::rpc_types::{EventFilter, SuiEvent};
 use sui_sdk::types::base_types::{ObjectID, ObjectIDParseError};
 use sui_sdk::{SuiClient, SuiClientBuilder};
 use thiserror::Error;
@@ -11,6 +12,8 @@ use tracing::{debug, error, info};
 use crate::config::SuiSubscriberConfig;
 use atoma_types::{Request, SmallId};

+const REQUEST_ID_HEX_SIZE: usize = 64;
+
 pub struct SuiSubscriber {
     id: SmallId,
     sui_client: SuiClient,
@@ -68,34 +71,11 @@ impl SuiSubscriber {
     pub async fn subscribe(self) -> Result<(), SuiSubscriberError> {
         let event_api = self.sui_client.event_api();
-        let mut subscribe_event = event_api.subscribe_event(self.filter).await?;
+        let mut subscribe_event = event_api.subscribe_event(self.filter.clone()).await?;
         info!("Starting event while loop");
         while let Some(event) = subscribe_event.next().await {
             match event {
-                Ok(event) => {
-                    let event_data = event.parsed_json;
-                    if event_data["is_first_submission"].as_bool().is_some() {
-                        continue;
-                    }
-                    debug!("event data: {}", event_data);
-                    let request = Request::try_from(event_data)?;
-                    info!("Received new request: {:?}", request);
-                    let request_id = request
-                        .id()
-                        .iter()
-                        .map(|b| format!("{:02x}", b))
-                        .collect::<String>();
-                    let sampled_nodes = request.sampled_nodes();
-                    if sampled_nodes.contains(&self.id) {
-                        info!(
-                            "Current node has been sampled for request with id: {}",
-                            request_id
-                        );
-                        self.event_sender.send(request).await?;
-                    } else {
-                        info!("Current node has not been sampled for request with id: {}, ignoring it..", request_id);
-                    }
-                }
+                Ok(event) => self.handle_event(event).await?,
                 Err(e) => {
                     error!("Failed to get event with error: {e}");
                 }
@@ -105,6 +85,64 @@ impl SuiSubscriber {
     }
 }

+impl SuiSubscriber {
+    async fn handle_event(&self, event: SuiEvent) -> Result<(), SuiSubscriberError> {
+        match event.type_.name.as_str() {
+            "DisputeEvent" => todo!(),
+            "FirstSubmission" | "NodeRegisteredEvent" | "NodeSubscribedToModelEvent" => {}
+            "Text2TextPromptEvent" | "NewlySampledNodesEvent" => {
+                let event_data = event.parsed_json;
+                self.handle_text2text_prompt_event(event_data).await?;
+            }
+            "Text2ImagePromptEvent" => {
+                let event_data = event.parsed_json;
+                self.handle_text2image_prompt_event(event_data).await?;
+            }
+            _ => panic!("Invalid Event type found!"),
+        }
+        Ok(())
+    }
+
+    async fn handle_text2image_prompt_event(
+        &self,
+        _event_data: Value,
+    ) -> Result<(), SuiSubscriberError> {
+        Ok(())
+    }
+
+    async fn handle_text2text_prompt_event(
+        &self,
+        event_data: Value,
+    ) -> Result<(), SuiSubscriberError> {
+        debug!("event data: {}", event_data);
+        let request = Request::try_from(event_data)?;
+        info!("Received new request: {:?}", request);
+        let request_id =
+            request
+                .id()
+                .iter()
+                .fold(String::with_capacity(REQUEST_ID_HEX_SIZE), |mut acc, &b| {
+                    write!(acc, "{:02x}", b).expect("Failed to write to request_id");
+                    acc
+                });
+        info!("request_id: {request_id}");
+        let sampled_nodes = request.sampled_nodes();
+        if sampled_nodes.contains(&self.id) {
+            info!(
+                "Current node has been sampled for request with id: {}",
+                request_id
+            );
+            self.event_sender.send(request).await.map_err(Box::new)?;
+        } else {
+            info!(
+                "Current node has not been sampled for request with id: {}, ignoring it..",
+                request_id
+            );
+        }
+        Ok(())
+    }
+}
+
 #[derive(Debug, Error)]
 pub enum SuiSubscriberError {
     #[error("Sui Builder error: `{0}`")]
@@ -114,7 +152,7 @@ pub enum SuiSubscriberError {
     #[error("Object ID parse error: `{0}`")]
     ObjectIDParseError(#[from] ObjectIDParseError),
     #[error("Sender error: `{0}`")]
-    SendError(#[from] mpsc::error::SendError<Request>),
+    SendError(#[from] Box<mpsc::error::SendError<Request>>),
     #[error("Type conversion error: `{0}`")]
     TypeConversionError(#[from] anyhow::Error),
 }
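Reviewer note on the subscriber change above: the request id is now hex-encoded with a single pre-sized `String` via `fold`/`write!`, instead of allocating one `String` per byte with `map(|b| format!(...)).collect::<String>()`. A minimal, self-contained sketch of the same pattern (illustrative names, not part of the patch); `REQUEST_ID_HEX_SIZE` is 64 because a 32-byte digest renders as 64 hex characters:

use std::fmt::Write;

const REQUEST_ID_HEX_SIZE: usize = 64;

fn hex_request_id(id: &[u8]) -> String {
    // Single up-front allocation; `write!` appends two hex chars per byte.
    id.iter()
        .fold(String::with_capacity(REQUEST_ID_HEX_SIZE), |mut acc, &b| {
            write!(acc, "{:02x}", b).expect("writing to a String cannot fail");
            acc
        })
}

fn main() {
    assert_eq!(hex_request_id(&[0x00, 0xab, 0xff]), "00abff");
}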
diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs
index 242225c8..25580ad4 100644
--- a/atoma-inference/src/model_thread.rs
+++ b/atoma-inference/src/model_thread.rs
@@ -30,27 +30,15 @@ pub struct ModelThreadCommand {
 #[derive(Debug, Error)]
 pub enum ModelThreadError {
     #[error("Model thread shutdown: `{0}`")]
-    ApiError(ApiError),
+    ApiError(#[from] ApiError),
     #[error("Model thread shutdown: `{0}`")]
-    ModelError(ModelError),
+    ModelError(#[from] ModelError),
     #[error("Core thread shutdown: `{0}`")]
     Shutdown(RecvError),
     #[error("Serde error: `{0}`")]
     SerdeError(#[from] serde_json::Error),
 }

-impl From<ModelError> for ModelThreadError {
-    fn from(error: ModelError) -> Self {
-        Self::ModelError(error)
-    }
-}
-
-impl From<ApiError> for ModelThreadError {
-    fn from(error: ApiError) -> Self {
-        Self::ApiError(error)
-    }
-}
-
 pub struct ModelThreadHandle {
     sender: mpsc::Sender<ModelThreadCommand>,
     join_handle: std::thread::JoinHandle<Result<(), ModelThreadError>>,
@@ -79,8 +67,8 @@ where
         let ModelThreadCommand { request, sender } = command;
         let request_id = request.id();
         let sampled_nodes = request.sampled_nodes();
-        let body = request.body();
-        let model_input = serde_json::from_value(body)?;
+        let params = request.params();
+        let model_input = M::Input::try_from(params)?;
         let model_output = self.model.run(model_input)?;
         let output = serde_json::to_value(model_output)?;
         let response = Response::new(request_id, sampled_nodes, output);
diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs
index 99f81538..19e4cdef 100644
--- a/atoma-inference/src/models/mod.rs
+++ b/atoma-inference/src/models/mod.rs
@@ -1,7 +1,8 @@
 use std::path::PathBuf;

 use ::candle::{DTypeParseError, Error as CandleError};
-use serde::{de::DeserializeOwned, Serialize};
+use atoma_types::PromptParams;
+use serde::Serialize;
 use thiserror::Error;

 use self::{config::ModelConfig, types::ModelType};
@@ -14,7 +15,7 @@ pub mod types;
 pub type ModelId = String;

 pub trait ModelTrait {
-    type Input: DeserializeOwned;
+    type Input: TryFrom<PromptParams, Error = ModelError>;
     type Output: Serialize;
     type LoadData;
@@ -66,6 +67,8 @@ pub enum ModelError {
     DTypeParseError(#[from] DTypeParseError),
     #[error("Invalid model type: `{0}`")]
     InvalidModelType(String),
+    #[error("Invalid model input")]
+    InvalidModelInput,
 }

 #[macro_export]
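To make the new `Input: TryFrom<PromptParams, Error = ModelError>` bound concrete: the model thread can now convert prompt params generically with `M::Input::try_from(params)?` instead of deserializing an untyped JSON body per model. A hedged, self-contained sketch of what an implementor provides; `PromptParams`, `ModelError`, and `EchoInput` below are simplified local stand-ins, not the crate's actual types:

enum PromptParams {
    Text2Text { prompt: String },
    Text2Image,
}

#[derive(Debug)]
enum ModelError {
    InvalidModelInput,
}

struct EchoInput {
    prompt: String,
}

// Each model declares how to build its typed input from the shared params,
// rejecting modalities it does not support.
impl TryFrom<PromptParams> for EchoInput {
    type Error = ModelError;

    fn try_from(params: PromptParams) -> Result<Self, Self::Error> {
        match params {
            PromptParams::Text2Text { prompt } => Ok(Self { prompt }),
            PromptParams::Text2Image => Err(ModelError::InvalidModelInput),
        }
    }
}

fn main() {
    let input = EchoInput::try_from(PromptParams::Text2Text { prompt: "hi".into() }).unwrap();
    assert_eq!(input.prompt, "hi");
}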
diff --git a/atoma-inference/src/models/types.rs b/atoma-inference/src/models/types.rs
index 7085dc47..3d0dfc08 100644
--- a/atoma-inference/src/models/types.rs
+++ b/atoma-inference/src/models/types.rs
@@ -1,5 +1,6 @@
 use std::{fmt::Display, path::PathBuf, str::FromStr};

+use atoma_types::PromptParams;
 use candle::{DType, Device};
 use serde::{Deserialize, Serialize};
@@ -359,6 +360,26 @@ impl TextModelInput {
     }
 }

+impl TryFrom<PromptParams> for TextModelInput {
+    type Error = ModelError;
+
+    fn try_from(value: PromptParams) -> Result<Self, Self::Error> {
+        match value {
+            PromptParams::Text2TextPromptParams(p) => Ok(Self {
+                prompt: p.prompt(),
+                temperature: p.temperature(),
+                random_seed: p.random_seed(),
+                repeat_penalty: p.repeat_penalty(),
+                repeat_last_n: p.repeat_last_n().try_into().unwrap(),
+                max_tokens: p.max_tokens().try_into().unwrap(),
+                top_k: p.top_k().map(|t| t.try_into().unwrap()),
+                top_p: p.top_p(),
+            }),
+            PromptParams::Text2ImagePromptParams(_) => Err(ModelError::InvalidModelInput),
+        }
+    }
+}
+
 #[derive(Serialize)]
 pub struct TextModelOutput {
     pub text: String,
@@ -455,6 +476,29 @@ impl Request for StableDiffusionRequest {
     }
 }

+impl TryFrom<PromptParams> for StableDiffusionInput {
+    type Error = ModelError;
+
+    fn try_from(value: PromptParams) -> Result<Self, Self::Error> {
+        match value {
+            PromptParams::Text2ImagePromptParams(p) => Ok(Self {
+                prompt: p.prompt(),
+                uncond_prompt: p.uncond_prompt(),
+                height: p.height().map(|t| t.try_into().unwrap()),
+                width: p.width().map(|t| t.try_into().unwrap()),
+                n_steps: p.n_steps().map(|t| t.try_into().unwrap()),
+                num_samples: p.num_samples() as i64,
+                model: p.model(),
+                guidance_scale: p.guidance_scale(),
+                img2img: p.img2img(),
+                img2img_strength: p.img2img_strength(),
+                random_seed: p.random_seed(),
+            }),
+            _ => Err(ModelError::InvalidModelInput),
+        }
+    }
+}
+
 #[derive(Clone, Debug, Deserialize, Serialize)]
 pub struct StableDiffusionResponse {
     pub output: Vec<(Vec<u8>, usize, usize)>,
diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs
index 8f099f4f..0f9dbfa6 100644
--- a/atoma-inference/src/service.rs
+++ b/atoma-inference/src/service.rs
@@ -120,10 +120,11 @@ pub enum ModelServiceError {

 #[cfg(test)]
 mod tests {
+    use atoma_types::PromptParams;
     use std::io::Write;
     use toml::{toml, Value};

-    use crate::models::{config::ModelConfig, ModelTrait, Request, Response};
+    use crate::models::{config::ModelConfig, ModelError, ModelTrait, Request, Response};

     use super::*;
@@ -150,8 +151,18 @@ mod tests {
     #[derive(Clone)]
     struct TestModelInstance {}

+    struct MockInput {}
+
+    impl TryFrom<PromptParams> for MockInput {
+        type Error = ModelError;
+
+        fn try_from(_: PromptParams) -> Result<Self, Self::Error> {
+            Ok(Self {})
+        }
+    }
+
     impl ModelTrait for TestModelInstance {
-        type Input = ();
+        type Input = MockInput;
         type Output = ();
         type LoadData = ();
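One caveat with the `TryFrom<PromptParams>` impls added in types.rs above: the `try_into().unwrap()` calls panic if a `u64` field does not fit the target integer type. A hedged alternative (a sketch, not part of the patch) would surface the failure through the new `ModelError::InvalidModelInput` variant instead:

#[derive(Debug)]
enum ModelError {
    InvalidModelInput, // stand-in for the crate's ModelError
}

// Fallible u64 -> usize conversion that reports a model error
// instead of panicking inside the model thread.
fn to_usize(v: u64) -> Result<usize, ModelError> {
    usize::try_from(v).map_err(|_| ModelError::InvalidModelInput)
}

fn main() {
    assert_eq!(to_usize(42).unwrap(), 42);
    // On a 32-bit target, u64::MAX becomes an Err instead of a panic.
    let _ = to_usize(u64::MAX);
}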
diff --git a/atoma-inference/src/tests/mod.rs b/atoma-inference/src/tests/mod.rs
index b38ba8f6..ff2eedae 100644
--- a/atoma-inference/src/tests/mod.rs
+++ b/atoma-inference/src/tests/mod.rs
@@ -2,11 +2,13 @@ use crate::models::{config::ModelConfig, types::ModelType, ModelError, ModelTrait};
 use std::{path::PathBuf, time::Duration};

 mod prompts;
+use atoma_types::Text2TextPromptParams;
 use prompts::PROMPTS;
+use serde::Serialize;

 use std::{collections::HashMap, sync::mpsc};

-use atoma_types::Request;
+use atoma_types::{PromptParams, Request};
 use futures::{stream::FuturesUnordered, StreamExt};
 use reqwest::Client;
 use serde_json::json;
@@ -24,9 +26,24 @@ struct TestModel {
     duration: Duration,
 }

+#[derive(Debug, Serialize)]
+struct MockInputOutput {
+    id: u64,
+}
+
+impl TryFrom<PromptParams> for MockInputOutput {
+    type Error = ModelError;
+
+    fn try_from(value: PromptParams) -> Result<Self, Self::Error> {
+        Ok(Self {
+            id: value.into_text2text_prompt_params().unwrap().max_tokens(),
+        })
+    }
+}
+
 impl ModelTrait for TestModel {
-    type Input = Value;
-    type Output = Value;
+    type Input = MockInputOutput;
+    type Output = MockInputOutput;
     type LoadData = Duration;

     fn fetch(
@@ -51,7 +68,7 @@ impl ModelTrait for TestModel {
     fn run(&mut self, input: Self::Input) -> Result<Self::Output, ModelError> {
         std::thread::sleep(self.duration);
         println!(
-            "Finished waiting time for {:?} and input = {}",
+            "Finished waiting time for {:?} and input = {:?}",
             self.duration, input
         );
         Ok(input)
@@ -100,14 +117,27 @@ async fn test_mock_model_thread() {
     for i in 0..NUM_REQUESTS {
         for sender in model_thread_dispatcher.model_senders.values() {
             let (response_sender, response_receiver) = oneshot::channel();
-            let request = Request::new(vec![0], vec![], json!(i));
+            let max_tokens = i as u64;
+            let prompt_params = PromptParams::Text2TextPromptParams(Text2TextPromptParams::new(
+                "".to_string(),
+                "".to_string(),
+                0.0,
+                1,
+                1.0,
+                0,
+                max_tokens,
+                Some(0),
+                Some(1.0),
+            ));
+            let request = Request::new(vec![0], vec![], prompt_params);
             let command = ModelThreadCommand {
                 request: request.clone(),
                 sender: response_sender,
             };
             sender.send(command).expect("Failed to send command");
             responses.push(response_receiver);
-            should_be_received_responses.push(request.body().as_u64().unwrap());
+            should_be_received_responses
+                .push(MockInputOutput::try_from(request.params()).unwrap().id);
         }
     }
diff --git a/atoma-node/src/atoma_node.rs b/atoma-node/src/atoma_node.rs
index e30154d5..df095242 100644
--- a/atoma-node/src/atoma_node.rs
+++ b/atoma-node/src/atoma_node.rs
@@ -10,7 +10,10 @@ use atoma_sui::subscriber::{SuiSubscriber, SuiSubscriberError};
 use atoma_types::{Request, Response};
 use thiserror::Error;
 use tokio::{
-    sync::{mpsc, mpsc::Receiver, oneshot},
+    sync::{
+        mpsc::{self, Receiver},
+        oneshot,
+    },
     task::JoinHandle,
 };
 use tracing::info;
@@ -19,10 +22,10 @@ const ATOMA_OUTPUT_MANAGER_FIREBASE_URL: &str = "https://atoma-demo-default-rtdb
 const CHANNEL_SIZE: usize = 32;

 pub struct AtomaNode {
-    pub model_service_handle: JoinHandle<Result<(), ModelServiceError>>,
-    pub sui_subscriber_handle: JoinHandle<Result<(), SuiSubscriberError>>,
     pub atoma_sui_client_handle: JoinHandle<Result<(), AtomaSuiClientError>>,
     pub atoma_output_manager_handle: JoinHandle<Result<(), AtomaOutputManagerError>>,
+    pub model_service_handle: JoinHandle<Result<(), ModelServiceError>>,
+    pub sui_subscriber_handle: JoinHandle<Result<(), SuiSubscriberError>>,
 }

 impl AtomaNode {
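The atoma-types hunk below relies on `parse_f32_from_le_bytes`: the Sui contract publishes each `f32` as the `u32` holding its little-endian bytes (see the gate.move links in the doc comments). A self-contained round-trip sketch of that encoding, with illustrative function names:

fn encode_f32_as_u32(x: f32) -> u32 {
    // What the contract side effectively stores on-chain.
    u32::from_le_bytes(x.to_le_bytes())
}

fn decode_u32_as_f32(v: u32) -> f32 {
    // What parse_f32_from_le_bytes recovers off-chain.
    f32::from_le_bytes(v.to_le_bytes())
}

fn main() {
    let temperature = 0.75_f32;
    let on_chain = encode_f32_as_u32(temperature);
    assert_eq!(decode_u32_as_f32(on_chain), temperature);
}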
diff --git a/atoma-types/src/lib.rs b/atoma-types/src/lib.rs
index 61ae1eee..894b2ad4 100644
--- a/atoma-types/src/lib.rs
+++ b/atoma-types/src/lib.rs
@@ -1,6 +1,6 @@
 use anyhow::{anyhow, Error, Result};
 use serde::{Deserialize, Serialize};
-use serde_json::{json, Value};
+use serde_json::Value;

 pub type Digest = [u8; 32];
 pub type SmallId = u64;
@@ -19,16 +19,15 @@ pub struct Request {
     id: Vec<u8>,
     #[serde(rename(deserialize = "nodes"))]
     sampled_nodes: Vec<SmallId>,
-    #[serde(rename(deserialize = "params"))]
-    body: Value,
+    params: PromptParams,
 }

 impl Request {
-    pub fn new(id: Vec<u8>, sampled_nodes: Vec<SmallId>, body: Value) -> Self {
+    pub fn new(id: Vec<u8>, sampled_nodes: Vec<SmallId>, params: PromptParams) -> Self {
         Self {
             id,
             sampled_nodes,
-            body,
+            params,
         }
     }

@@ -37,15 +36,15 @@ impl Request {
     }

     pub fn model(&self) -> String {
-        self.body["model"].as_str().unwrap().to_string()
+        self.params.model()
     }

     pub fn sampled_nodes(&self) -> Vec<SmallId> {
         self.sampled_nodes.clone()
     }

-    pub fn body(&self) -> Value {
-        self.body.clone()
+    pub fn params(&self) -> PromptParams {
+        self.params.clone()
     }
 }

@@ -63,13 +62,293 @@ impl TryFrom<Value> for Request {
             .as_array()
             .unwrap()
             .iter()
-            .map(|v| parse_u64(&v["inner"]))
+            .map(|v| utils::parse_u64(&v["inner"]))
             .collect::<Result<Vec<SmallId>>>()?;
-        let body = parse_body(value["params"].clone())?;
+        let body = PromptParams::try_from(value["params"].clone())?;
         Ok(Request::new(id, sampled_nodes, body))
     }
 }

+/// Enum encapsulating the possible modal prompt params, including both:
+/// - text to text prompt parameters;
+/// - text to image prompt parameters.
+#[derive(Clone, Debug, Deserialize)]
+pub enum PromptParams {
+    Text2TextPromptParams(Text2TextPromptParams),
+    Text2ImagePromptParams(Text2ImagePromptParams),
+}
+
+impl PromptParams {
+    pub fn model(&self) -> String {
+        match self {
+            Self::Text2ImagePromptParams(p) => p.model(),
+            Self::Text2TextPromptParams(p) => p.model(),
+        }
+    }
+
+    /// Extracts a `Text2TextPromptParams` from a `PromptParams` enum, or `None`
+    /// if the value is not a `PromptParams::Text2TextPromptParams`
+    pub fn into_text2text_prompt_params(self) -> Option<Text2TextPromptParams> {
+        match self {
+            Self::Text2TextPromptParams(p) => Some(p),
+            Self::Text2ImagePromptParams(_) => None,
+        }
+    }
+
+    /// Extracts a `Text2ImagePromptParams` from a `PromptParams` enum, or `None`
+    /// if the value is not a `PromptParams::Text2ImagePromptParams`
+    pub fn into_text2image_prompt_params(self) -> Option<Text2ImagePromptParams> {
+        match self {
+            Self::Text2ImagePromptParams(p) => Some(p),
+            Self::Text2TextPromptParams(_) => None,
+        }
+    }
+}
+
+impl TryFrom<Value> for PromptParams {
+    type Error = Error;
+
+    fn try_from(value: Value) -> Result<Self> {
+        if value["temperature"].is_null() {
+            Ok(Self::Text2ImagePromptParams(
+                Text2ImagePromptParams::try_from(value)?,
+            ))
+        } else {
+            Ok(Self::Text2TextPromptParams(
+                Text2TextPromptParams::try_from(value)?,
+            ))
+        }
+    }
+}
+
+/// Text to text prompt parameters. It includes:
+/// - prompt: prompt to be passed to the model as input;
+/// - model: name of the model;
+/// - temperature: parameter controlling the creativity of the model;
+/// - random_seed: seed parameter for sampling;
+/// - repeat_penalty: parameter penalizing token repetition (it should be >= 1.0);
+/// - repeat_last_n: parameter penalizing repetition of the last `n` tokens;
+/// - max_tokens: maximum number of tokens to generate;
+/// - top_k: parameter controlling the `k` top tokens for sampling;
+/// - top_p: parameter controlling the cumulative probability of top tokens.
+#[derive(Clone, Debug, Deserialize)]
+pub struct Text2TextPromptParams {
+    prompt: String,
+    model: String,
+    temperature: f64,
+    random_seed: u64,
+    repeat_penalty: f32,
+    repeat_last_n: u64,
+    max_tokens: u64,
+    top_k: Option<u64>,
+    top_p: Option<f64>,
+}
+
+impl Text2TextPromptParams {
+    #[allow(clippy::too_many_arguments)]
+    pub fn new(
+        prompt: String,
+        model: String,
+        temperature: f64,
+        random_seed: u64,
+        repeat_penalty: f32,
+        repeat_last_n: u64,
+        max_tokens: u64,
+        top_k: Option<u64>,
+        top_p: Option<f64>,
+    ) -> Self {
+        Self {
+            prompt,
+            model,
+            temperature,
+            random_seed,
+            repeat_penalty,
+            repeat_last_n,
+            max_tokens,
+            top_k,
+            top_p,
+        }
+    }
+
+    pub fn prompt(&self) -> String {
+        self.prompt.clone()
+    }
+
+    pub fn model(&self) -> String {
+        self.model.clone()
+    }
+
+    pub fn temperature(&self) -> f64 {
+        self.temperature
+    }
+
+    pub fn random_seed(&self) -> u64 {
+        self.random_seed
+    }
+
+    pub fn repeat_penalty(&self) -> f32 {
+        self.repeat_penalty
+    }
+
+    pub fn repeat_last_n(&self) -> u64 {
+        self.repeat_last_n
+    }
+
+    pub fn max_tokens(&self) -> u64 {
+        self.max_tokens
+    }
+
+    pub fn top_k(&self) -> Option<u64> {
+        self.top_k
+    }
+
+    pub fn top_p(&self) -> Option<f64> {
+        self.top_p
+    }
+}
+
+impl TryFrom<Value> for Text2TextPromptParams {
+    type Error = Error;
+
+    fn try_from(value: Value) -> Result<Self> {
+        Ok(Self {
+            prompt: utils::parse_str(&value["prompt"])?,
+            model: utils::parse_str(&value["model"])?,
+            temperature: utils::parse_f32_from_le_bytes(&value["temperature"])? as f64,
+            random_seed: utils::parse_u64(&value["random_seed"])?,
+            repeat_penalty: utils::parse_f32_from_le_bytes(&value["repeat_penalty"])?,
+            repeat_last_n: utils::parse_u64(&value["repeat_last_n"])?,
+            max_tokens: utils::parse_u64(&value["max_tokens"])?,
+            top_k: Some(utils::parse_u64(&value["top_k"])?),
+            top_p: Some(utils::parse_f32_from_le_bytes(&value["top_p"])? as f64),
+        })
+    }
+}
+
+#[derive(Clone, Debug, Deserialize)]
+/// Text to image prompt parameters. It includes:
+/// - prompt: prompt to be passed to the model as input;
+/// - model: name of the model;
+/// - uncond_prompt: unconditional prompt;
+/// - height: output image height;
+/// - width: output image width;
+/// - n_steps: the number of steps to run the diffusion for;
+/// - num_samples: number of samples to generate;
+/// - guidance_scale: classifier-free guidance scale, controlling how closely
+///   the output follows the prompt;
+/// - img2img: generate new AI images from an input image and text prompt.
+///   The output image will follow the color and composition of the input image;
+/// - img2img_strength: the strength indicates how much to transform the initial
+///   image. The value must be between 0 and 1; a value of 1 discards the initial
+///   image information;
+/// - random_seed: the seed to use when generating random samples.
+pub struct Text2ImagePromptParams {
+    prompt: String,
+    model: String,
+    uncond_prompt: String,
+    height: Option<u64>,
+    width: Option<u64>,
+    n_steps: Option<u64>,
+    num_samples: u64,
+    guidance_scale: Option<f64>,
+    img2img: Option<String>,
+    img2img_strength: f64,
+    random_seed: Option<u64>,
+}
+
+impl Text2ImagePromptParams {
+    #[allow(clippy::too_many_arguments)]
+    pub fn new(
+        prompt: String,
+        model: String,
+        uncond_prompt: String,
+        height: Option<u64>,
+        width: Option<u64>,
+        n_steps: Option<u64>,
+        num_samples: u64,
+        guidance_scale: Option<f64>,
+        img2img: Option<String>,
+        img2img_strength: f64,
+        random_seed: Option<u64>,
+    ) -> Self {
+        Self {
+            prompt,
+            model,
+            uncond_prompt,
+            height,
+            width,
+            n_steps,
+            num_samples,
+            guidance_scale,
+            img2img,
+            img2img_strength,
+            random_seed,
+        }
+    }
+
+    pub fn prompt(&self) -> String {
+        self.prompt.clone()
+    }
+
+    pub fn model(&self) -> String {
+        self.model.clone()
+    }
+
+    pub fn uncond_prompt(&self) -> String {
+        self.uncond_prompt.clone()
+    }
+
+    pub fn height(&self) -> Option<u64> {
+        self.height
+    }
+
+    pub fn width(&self) -> Option<u64> {
+        self.width
+    }
+
+    pub fn n_steps(&self) -> Option<u64> {
+        self.n_steps
+    }
+
+    pub fn num_samples(&self) -> u64 {
+        self.num_samples
+    }
+
+    pub fn guidance_scale(&self) -> Option<f64> {
+        self.guidance_scale
+    }
+
+    pub fn img2img(&self) -> Option<String> {
+        self.img2img.clone()
+    }
+
+    pub fn img2img_strength(&self) -> f64 {
+        self.img2img_strength
+    }
+
+    pub fn random_seed(&self) -> Option<u64> {
+        self.random_seed
+    }
+}
+
+impl TryFrom<Value> for Text2ImagePromptParams {
+    type Error = Error;
+
+    fn try_from(value: Value) -> Result<Self> {
+        Ok(Self {
+            prompt: utils::parse_str(&value["prompt"])?,
+            model: utils::parse_str(&value["model"])?,
+            uncond_prompt: utils::parse_str(&value["uncond_prompt"])?,
+            random_seed: Some(utils::parse_u64(&value["random_seed"])?),
+            height: Some(utils::parse_u64(&value["height"])?),
+            width: Some(utils::parse_u64(&value["width"])?),
+            n_steps: Some(utils::parse_u64(&value["n_steps"])?),
+            num_samples: utils::parse_u64(&value["num_samples"])?,
+            guidance_scale: Some(utils::parse_f32_from_le_bytes(&value["guidance_scale"])? as f64),
+            img2img: Some(utils::parse_str(&value["img2img"])?),
+            img2img_strength: utils::parse_f32_from_le_bytes(&value["img2img_strength"])? as f64,
+        })
+    }
+}
+
 /// Represents a response object containing information about a response, including an ID, sampled nodes, and the response data.
 ///
 /// Fields:
@@ -105,44 +384,39 @@ impl Response {
     }
 }

-/// Parses the body of a JSON value. This JSON value is supposed to be obtained
-/// from a Sui `Text2TextPromptEvent`,
-/// see https://github.com/atoma-network/atoma-contracts/blob/main/sui/packages/atoma/sources/gate.move#L28
-fn parse_body(json: Value) -> Result<Value> {
-    let output = json!({
-        "max_tokens": parse_u64(&json["max_tokens"])?,
-        "model": json["model"],
-        "prompt": json["prompt"],
-        "random_seed": parse_u64(&json["random_seed"])?,
-        "repeat_last_n": parse_u64(&json["repeat_last_n"])?,
-        "repeat_penalty": parse_f32_from_le_bytes(&json["repeat_penalty"])?,
-        "temperature": parse_f32_from_le_bytes(&json["temperature"])?,
-        "top_k": parse_u64(&json["top_k"])?,
-        "top_p": parse_f32_from_le_bytes(&json["top_p"])?,
-    });
-    Ok(output)
-}
+mod utils {
+    use super::*;

-/// Parses an appropriate JSON value, from a number (represented as a `u32`) to a `f32` type, by
-/// representing the extracted u32 value into little endian byte representation, and then applying `f32::from_le_bytes`.
-/// See https://github.com/atoma-network/atoma-contracts/blob/main/sui/packages/atoma/sources/gate.move#L26
-fn parse_f32_from_le_bytes(value: &Value) -> Result<f32> {
-    let u32_value: u32 = value
-        .as_u64()
-        .ok_or(anyhow!(
-            "Failed to extract `f32` little endian bytes representation"
-        ))?
-        .try_into()?;
-    let f32_le_bytes = u32_value.to_le_bytes();
-    Ok(f32::from_le_bytes(f32_le_bytes))
-}
+    /// Parses an appropriate JSON value, from a number (represented as a `u32`) to a `f32` type, by
+    /// converting the extracted `u32` value into its little endian byte representation, and then applying `f32::from_le_bytes`.
+    /// See https://github.com/atoma-network/atoma-contracts/blob/main/sui/packages/atoma/sources/gate.move#L26
+    pub(crate) fn parse_f32_from_le_bytes(value: &Value) -> Result<f32> {
+        let u32_value: u32 = value
+            .as_u64()
+            .ok_or(anyhow!(
+                "Failed to extract `f32` little endian bytes representation"
+            ))?
+            .try_into()?;
+        let f32_le_bytes = u32_value.to_le_bytes();
+        Ok(f32::from_le_bytes(f32_le_bytes))
+    }

-/// Parses an appropriate JSON value, representing a `u64` number, from a Sui
-/// `Text2TextPromptEvent` `u64` fields.
-fn parse_u64(value: &Value) -> Result<u64> {
-    value
-        .as_str()
-        .ok_or(anyhow!("Failed to extract `u64` number"))?
-        .parse::<u64>()
-        .map_err(|e| anyhow!("Failed to parse `u64` from string, with error: {e}"))
-}
+    /// Parses an appropriate JSON value, representing a `u64` number, from a Sui
+    /// `Text2TextPromptEvent` `u64` fields.
+    pub(crate) fn parse_u64(value: &Value) -> Result<u64> {
+        value
+            .as_str()
+            .ok_or(anyhow!("Failed to extract `u64` number"))?
+            .parse::<u64>()
+            .map_err(|e| anyhow!("Failed to parse `u64` from string, with error: {e}"))
+    }
+
+    /// Parses an appropriate JSON value, representing a `String` value, from a Sui
+    /// `Text2TextPromptEvent` `String` fields.
+    pub(crate) fn parse_str(value: &Value) -> Result<String> {
+        Ok(value
+            .as_str()
+            .ok_or(anyhow!("Failed to extract `String` from JSON value"))?
+            .to_string())
+    }
+}
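Finally, a hedged sketch of how the `temperature`-based dispatch in `PromptParams::try_from` classifies an event payload; the JSON shapes below are illustrative only, with `u64` fields as strings and floats as little-endian `u32` bit patterns, as the parsers above expect:

use serde_json::json; // serde_json = "1"

fn main() {
    // A Text2TextPromptEvent payload carries a "temperature" field
    // (here, the f32 0.75 encoded as its little-endian u32 bit pattern)...
    let text2text = json!({ "temperature": 1061158912, "prompt": "hello", "max_tokens": "64" });
    // ...while a Text2ImagePromptEvent payload does not.
    let text2image = json!({ "prompt": "a cat", "n_steps": "30" });

    assert!(!text2text["temperature"].is_null()); // parsed as Text2TextPromptParams
    assert!(text2image["temperature"].is_null()); // parsed as Text2ImagePromptParams
}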