diff --git a/.gitignore b/.gitignore index 000dd564..a25c2384 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ target/ /models/ inference.toml sui_subscriber.toml +sui_client.toml diff --git a/Cargo.toml b/Cargo.toml index 0e23548a..66d4c9b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ dotenv = "0.15.0" ethers = "2.0.14" futures = "0.3.30" futures-util = "0.3.30" +hex = "0.4.3" hf-hub = "0.3.2" image = { version = "0.25.0", default-features = false, features = ["jpeg", "png"] } rand = "0.8.5" diff --git a/atoma-client/src/client.rs b/atoma-client/src/client.rs index ba28f58f..3779eb3d 100644 --- a/atoma-client/src/client.rs +++ b/atoma-client/src/client.rs @@ -1,91 +1,77 @@ -use std::{path::Path, str::FromStr, time::Duration}; +use std::path::Path; use atoma_crypto::{calculate_commitment, Blake2b}; -use atoma_types::Response; -use sui_keys::keystore::AccountKeystore; +use atoma_types::{Response, SmallId}; use sui_sdk::{ json::SuiJsonValue, types::{ - base_types::{ObjectID, ObjectIDParseError, SuiAddress}, - crypto::Signature, + base_types::{ObjectIDParseError, SuiAddress}, digests::TransactionDigest, }, wallet_context::WalletContext, }; use thiserror::Error; use tokio::sync::mpsc; -use tracing::info; +use tracing::{debug, info}; use crate::config::AtomaSuiClientConfig; const GAS_BUDGET: u64 = 5_000_000; // 0.005 SUI -const PACKAGE_ID: &str = ""; -const MODULE_ID: &str = ""; -const METHOD: &str = "command"; +const MODULE_ID: &str = "settlement"; +const METHOD: &str = "submit_commitment"; pub struct AtomaSuiClient { - node_id: u64, address: SuiAddress, + config: AtomaSuiClientConfig, wallet_ctx: WalletContext, response_receiver: mpsc::Receiver, } impl AtomaSuiClient { - pub fn new>( - node_id: u64, - config_path: P, - request_timeout: Option, - max_concurrent_requests: Option, + pub fn new_from_config( + config: AtomaSuiClientConfig, response_receiver: mpsc::Receiver, ) -> Result { info!("Initializing Sui wallet.."); let mut wallet_ctx = WalletContext::new( - config_path.as_ref(), - request_timeout, - max_concurrent_requests, + config.config_path().as_ref(), + Some(config.request_timeout()), + Some(config.max_concurrent_requests()), )?; let active_address = wallet_ctx.active_address()?; info!("Set Sui client, with active address: {}", active_address); Ok(Self { - node_id, address: active_address, + config, wallet_ctx, response_receiver, }) } - pub fn new_from_config>( - node_id: u64, + pub fn new_from_config_file>( config_path: P, response_receiver: mpsc::Receiver, ) -> Result { let config = AtomaSuiClientConfig::from_file_path(config_path); - let config_path = config.config_path(); - let request_timeout = config.request_timeout(); - let max_concurrent_requests = config.max_concurrent_requests(); - - Self::new( - node_id, - config_path, - Some(request_timeout), - Some(max_concurrent_requests), - response_receiver, - ) + Self::new_from_config(config, response_receiver) } - fn get_index(&self, sampled_nodes: Vec) -> Result<(usize, usize), AtomaSuiClientError> { + fn get_index( + &self, + sampled_nodes: Vec, + ) -> Result<(usize, usize), AtomaSuiClientError> { let num_leaves = sampled_nodes.len(); let index = sampled_nodes .iter() - .position(|nid| nid == &self.node_id) + .position(|nid| nid == &self.config.small_id()) .ok_or(AtomaSuiClientError::InvalidSampledNode)?; Ok((index, num_leaves)) } fn get_data(&self, data: serde_json::Value) -> Result, AtomaSuiClientError> { // TODO: rework this when responses get same structure - let data = match data.as_str() { + let data = match data["text"].as_str() { Some(text) => text.as_bytes().to_owned(), None => { if let Some(array) = data.as_array() { @@ -101,17 +87,6 @@ impl AtomaSuiClient { Ok(data) } - fn sign_root_commitment( - &self, - merkle_root: [u8; 32], - ) -> Result { - self.wallet_ctx - .config - .keystore - .sign_hashed(&self.address, merkle_root.as_slice()) - .map_err(|e| AtomaSuiClientError::FailedSignature(e.to_string())) - } - /// Upon receiving a response from the `AtomaNode` service, this method extracts /// the output data and computes a cryptographic commitment. The commitment includes /// the root of an n-ary Merkle Tree built from the output data, represented as a `Vec`, @@ -131,20 +106,21 @@ impl AtomaSuiClient { let data = self.get_data(response.response())?; let (index, num_leaves) = self.get_index(response.sampled_nodes())?; let (root, pre_image) = calculate_commitment::, _>(data, index, num_leaves); - let signature = self.sign_root_commitment(root)?; let client = self.wallet_ctx.get_client().await?; let tx = client .transaction_builder() .move_call( self.address, - ObjectID::from_str(PACKAGE_ID)?, + self.config.package_id(), MODULE_ID, METHOD, vec![], vec![ + SuiJsonValue::from_object_id(self.config.atoma_db_id()), + SuiJsonValue::from_object_id(self.config.node_badge_id()), SuiJsonValue::new(request_id.into())?, - SuiJsonValue::new(signature.as_ref().into())?, + SuiJsonValue::new(root.as_ref().into())?, SuiJsonValue::new(pre_image.as_ref().into())?, ], None, @@ -155,12 +131,13 @@ impl AtomaSuiClient { let tx = self.wallet_ctx.sign_transaction(&tx); let resp = self.wallet_ctx.execute_transaction_must_succeed(tx).await; - + debug!("Submitted transaction with response: {:?}", resp); Ok(resp.digest) } pub async fn run(mut self) -> Result<(), AtomaSuiClientError> { while let Some(response) = self.response_receiver.recv().await { + info!("Received new response: {:?}", response); self.submit_response_commitment(response).await?; } Ok(()) diff --git a/atoma-client/src/config.rs b/atoma-client/src/config.rs index ecb73f75..b2299eeb 100644 --- a/atoma-client/src/config.rs +++ b/atoma-client/src/config.rs @@ -1,13 +1,19 @@ use std::{path::Path, time::Duration}; +use atoma_types::SmallId; use config::Config; use serde::Deserialize; +use sui_sdk::types::base_types::ObjectID; #[derive(Debug, Deserialize)] pub struct AtomaSuiClientConfig { config_path: String, - request_timeout: Duration, + node_badge_id: ObjectID, + small_id: SmallId, + package_id: ObjectID, + atoma_db_id: ObjectID, max_concurrent_requests: u64, + request_timeout: Duration, } impl AtomaSuiClientConfig { @@ -27,11 +33,27 @@ impl AtomaSuiClientConfig { self.config_path.clone() } - pub fn request_timeout(&self) -> Duration { - self.request_timeout + pub fn node_badge_id(&self) -> ObjectID { + self.node_badge_id + } + + pub fn small_id(&self) -> SmallId { + self.small_id + } + + pub fn package_id(&self) -> ObjectID { + self.package_id + } + + pub fn atoma_db_id(&self) -> ObjectID { + self.atoma_db_id } pub fn max_concurrent_requests(&self) -> u64 { self.max_concurrent_requests } + + pub fn request_timeout(&self) -> Duration { + self.request_timeout + } } diff --git a/atoma-event-subscribe/sui/src/config.rs b/atoma-event-subscribe/sui/src/config.rs index 851ea7a9..603c64b4 100644 --- a/atoma-event-subscribe/sui/src/config.rs +++ b/atoma-event-subscribe/sui/src/config.rs @@ -1,5 +1,6 @@ use std::{path::Path, time::Duration}; +use atoma_types::SmallId; use config::Config; use serde::{Deserialize, Serialize}; use sui_sdk::types::base_types::ObjectID; @@ -8,22 +9,25 @@ use sui_sdk::types::base_types::ObjectID; pub struct SuiSubscriberConfig { http_url: String, ws_url: String, - object_id: ObjectID, + package_id: ObjectID, request_timeout: Duration, + small_id: u64, } impl SuiSubscriberConfig { pub fn new( http_url: String, ws_url: String, - object_id: ObjectID, + package_id: ObjectID, request_timeout: Duration, + small_id: u64, ) -> Self { Self { http_url, ws_url, - object_id, + package_id, request_timeout, + small_id, } } @@ -35,14 +39,18 @@ impl SuiSubscriberConfig { self.ws_url.clone() } - pub fn object_id(&self) -> ObjectID { - self.object_id + pub fn package_id(&self) -> ObjectID { + self.package_id } pub fn request_timeout(&self) -> Duration { self.request_timeout } + pub fn small_id(&self) -> SmallId { + self.small_id + } + pub fn from_file_path>(config_file_path: P) -> Self { let builder = Config::builder().add_source(config::File::with_name( config_file_path.as_ref().to_str().unwrap(), @@ -69,10 +77,11 @@ pub mod tests { .parse() .unwrap(), Duration::from_secs(5 * 60), + 0, ); let toml_str = toml::to_string(&config).unwrap(); - let should_be_toml_str = "http_url = \"\"\nws_url = \"\"\nobject_id = \"0x8d97f1cd6ac663735be08d1d2b6d02a159e711586461306ce60a2b7a6a565a9e\"\n\n[request_timeout]\nsecs = 300\nnanos = 0\n"; + let should_be_toml_str = "http_url = \"\"\nws_url = \"\"\npackage_id = \"0x8d97f1cd6ac663735be08d1d2b6d02a159e711586461306ce60a2b7a6a565a9e\"\nsmall_id = 0\n\n[request_timeout]\nsecs = 300\nnanos = 0\n"; assert_eq!(toml_str, should_be_toml_str); } } diff --git a/atoma-event-subscribe/sui/src/main.rs b/atoma-event-subscribe/sui/src/main.rs index c82f3890..cf52591e 100644 --- a/atoma-event-subscribe/sui/src/main.rs +++ b/atoma-event-subscribe/sui/src/main.rs @@ -11,10 +11,10 @@ struct Args { #[arg(long)] pub package_id: String, /// HTTP node's address for Sui client - #[arg(long, default_value = "https://fullnode.mainnet.sui.io:443")] + #[arg(long, default_value = "https://fullnode.testnet.sui.io:443")] pub http_addr: String, /// RPC node's web socket address for Sui client - #[arg(long, default_value = "wss://fullnode.mainnet.sui.io:443")] + #[arg(long, default_value = "wss://fullnode.testnet.sui.io:443")] pub ws_addr: String, } @@ -30,7 +30,7 @@ async fn main() -> Result<(), SuiSubscriberError> { let (event_sender, mut event_receiver) = tokio::sync::mpsc::channel(32); let event_subscriber = SuiSubscriber::new( - 0, + 1, &http_url, Some(&ws_url), package_id, diff --git a/atoma-event-subscribe/sui/src/subscriber.rs b/atoma-event-subscribe/sui/src/subscriber.rs index 950c20c9..92120e08 100644 --- a/atoma-event-subscribe/sui/src/subscriber.rs +++ b/atoma-event-subscribe/sui/src/subscriber.rs @@ -6,13 +6,13 @@ use sui_sdk::types::base_types::{ObjectID, ObjectIDParseError}; use sui_sdk::{SuiClient, SuiClientBuilder}; use thiserror::Error; use tokio::sync::mpsc; -use tracing::{error, info}; +use tracing::{debug, error, info}; use crate::config::SuiSubscriberConfig; -use atoma_types::Request; +use atoma_types::{Request, SmallId}; pub struct SuiSubscriber { - id: u64, + id: SmallId, sui_client: SuiClient, filter: EventFilter, event_sender: mpsc::Sender, @@ -20,10 +20,10 @@ pub struct SuiSubscriber { impl SuiSubscriber { pub async fn new( - id: u64, + id: SmallId, http_url: &str, ws_url: Option<&str>, - object_id: ObjectID, + package_id: ObjectID, event_sender: mpsc::Sender, request_timeout: Option, ) -> Result { @@ -36,7 +36,7 @@ impl SuiSubscriber { } info!("Starting sui client.."); let sui_client = sui_client_builder.build(http_url).await?; - let filter = EventFilter::Package(object_id); + let filter = EventFilter::Package(package_id); Ok(Self { id, sui_client, @@ -46,20 +46,20 @@ impl SuiSubscriber { } pub async fn new_from_config>( - id: u64, config_path: P, event_sender: mpsc::Sender, ) -> Result { let config = SuiSubscriberConfig::from_file_path(config_path); + let small_id = config.small_id(); let http_url = config.http_url(); let ws_url = config.ws_url(); - let object_id = config.object_id(); + let package_id = config.package_id(); let request_timeout = config.request_timeout(); Self::new( - id, + small_id, &http_url, Some(&ws_url), - object_id, + package_id, event_sender, Some(request_timeout), ) @@ -74,9 +74,14 @@ impl SuiSubscriber { match event { Ok(event) => { let event_data = event.parsed_json; - let request = serde_json::from_value::(event_data)?; + debug!("event data: {}", event_data); + let request = Request::try_from(event_data)?; info!("Received new request: {:?}", request); - let request_id = request.id(); + let request_id = request + .id() + .iter() + .map(|b| format!("{:02x}", b)) + .collect::(); let sampled_nodes = request.sampled_nodes(); if sampled_nodes.contains(&self.id) { info!( @@ -107,4 +112,6 @@ pub enum SuiSubscriberError { ObjectIDParseError(#[from] ObjectIDParseError), #[error("Sender error: `{0}`")] SendError(#[from] mpsc::error::SendError), + #[error("Type conversion error: `{0}`")] + TypeConversionError(#[from] anyhow::Error), } diff --git a/atoma-inference/src/models/candle/phi3.rs b/atoma-inference/src/models/candle/phi3.rs index dd5f3ed0..f8bc8746 100644 --- a/atoma-inference/src/models/candle/phi3.rs +++ b/atoma-inference/src/models/candle/phi3.rs @@ -105,7 +105,10 @@ impl ModelTrait for Phi3Model { } fn run(&mut self, input: Self::Input) -> Result { - info!("Running inference on prompt: {}", input.prompt); + info!( + "Running inference on prompt: {}, with inputs = {:?}", + input.prompt, input + ); // clean tokenizer state self.tokenizer.clear(); diff --git a/atoma-inference/src/models/types.rs b/atoma-inference/src/models/types.rs index 36c8b26c..7085dc47 100644 --- a/atoma-inference/src/models/types.rs +++ b/atoma-inference/src/models/types.rs @@ -322,7 +322,7 @@ impl Request for TextRequest { } } -#[derive(Deserialize)] +#[derive(Debug, Deserialize)] pub struct TextModelInput { pub(crate) prompt: String, pub(crate) temperature: f64, diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 6cc857dd..8f099f4f 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -111,25 +111,13 @@ pub enum ModelServiceError { #[error("Core error: `{0}`")] ModelThreadError(ModelThreadError), #[error("Api error: `{0}`")] - ApiError(ApiError), + ApiError(#[from] ApiError), #[error("Candle error: `{0}`")] - CandleError(CandleError), + CandleError(#[from] CandleError), #[error("Sender error: `{0}`")] SendError(String), } -impl From for ModelServiceError { - fn from(error: ApiError) -> Self { - Self::ApiError(error) - } -} - -impl From for ModelServiceError { - fn from(error: CandleError) -> Self { - Self::CandleError(error) - } -} - #[cfg(test)] mod tests { use std::io::Write; diff --git a/atoma-inference/src/tests/mod.rs b/atoma-inference/src/tests/mod.rs index 54899c0a..b38ba8f6 100644 --- a/atoma-inference/src/tests/mod.rs +++ b/atoma-inference/src/tests/mod.rs @@ -100,7 +100,7 @@ async fn test_mock_model_thread() { for i in 0..NUM_REQUESTS { for sender in model_thread_dispatcher.model_senders.values() { let (response_sender, response_receiver) = oneshot::channel(); - let request = Request::new(0, vec![], "".to_string(), json!(i)); + let request = Request::new(vec![0], vec![], json!(i)); let command = ModelThreadCommand { request: request.clone(), sender: response_sender, diff --git a/atoma-node/src/atoma_node.rs b/atoma-node/src/atoma_node.rs index 3d4fb2bb..98018a41 100644 --- a/atoma-node/src/atoma_node.rs +++ b/atoma-node/src/atoma_node.rs @@ -24,7 +24,6 @@ pub struct AtomaNode { impl AtomaNode { pub async fn start

( - node_id: u64, atoma_sui_client_config_path: P, model_config_path: P, sui_subscriber_path: P, @@ -55,8 +54,7 @@ impl AtomaNode { let sui_subscriber_handle = tokio::spawn(async move { info!("Starting Sui subscriber service.."); let sui_event_subscriber = - SuiSubscriber::new_from_config(node_id, sui_subscriber_path, subscriber_req_tx) - .await?; + SuiSubscriber::new_from_config(sui_subscriber_path, subscriber_req_tx).await?; sui_event_subscriber .subscribe() .await @@ -65,8 +63,7 @@ impl AtomaNode { let atoma_sui_client_handle = tokio::spawn(async move { info!("Starting Atoma Sui client service.."); - let atoma_sui_client = AtomaSuiClient::new_from_config( - node_id, + let atoma_sui_client = AtomaSuiClient::new_from_config_file( atoma_sui_client_config_path, atoma_node_resp_rx, )?; diff --git a/atoma-node/src/main.rs b/atoma-node/src/main.rs index 7dae873c..650defcb 100644 --- a/atoma-node/src/main.rs +++ b/atoma-node/src/main.rs @@ -9,8 +9,6 @@ struct Args { #[arg(long)] atoma_sui_client_config_path: String, #[arg(long)] - node_id: u64, - #[arg(long)] model_config_path: String, #[arg(long)] sui_subscriber_path: String, @@ -22,13 +20,11 @@ async fn main() -> Result<(), AtomaNodeError> { let args = Args::parse(); let atoma_sui_client_config_path = args.atoma_sui_client_config_path; - let node_id = args.node_id; let model_config_path = args.model_config_path; let sui_subscriber_path = args.sui_subscriber_path; let (_, json_rpc_server_rx) = mpsc::channel(CHANNEL_BUFFER); let _atoma_node = AtomaNode::start( - node_id, atoma_sui_client_config_path, model_config_path, sui_subscriber_path, @@ -36,5 +32,5 @@ async fn main() -> Result<(), AtomaNodeError> { ) .await?; - Ok(()) + loop {} } diff --git a/atoma-types/Cargo.toml b/atoma-types/Cargo.toml index 7361c834..cd2127a8 100644 --- a/atoma-types/Cargo.toml +++ b/atoma-types/Cargo.toml @@ -6,5 +6,7 @@ edition.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +anyhow.workspace = true +hex.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/atoma-types/src/lib.rs b/atoma-types/src/lib.rs index 0dc63279..895e340f 100644 --- a/atoma-types/src/lib.rs +++ b/atoma-types/src/lib.rs @@ -1,34 +1,38 @@ +use anyhow::{anyhow, Error, Result}; use serde::Deserialize; -use serde_json::Value; +use serde_json::{json, Value}; + +pub type SmallId = u64; #[derive(Clone, Debug, Deserialize)] pub struct Request { - id: usize, - sampled_nodes: Vec, - model: String, + #[serde(rename(deserialize = "ticket_id"))] + id: Vec, + #[serde(rename(deserialize = "nodes"))] + sampled_nodes: Vec, + #[serde(rename(deserialize = "params"))] body: Value, } impl Request { - pub fn new(id: usize, sampled_nodes: Vec, model: String, body: Value) -> Self { + pub fn new(id: Vec, sampled_nodes: Vec, body: Value) -> Self { Self { id, sampled_nodes, - model, body, } } - pub fn id(&self) -> usize { - self.id + pub fn id(&self) -> Vec { + self.id.clone() } - pub fn sampled_nodes(&self) -> Vec { - self.sampled_nodes.clone() + pub fn model(&self) -> String { + self.body["model"].as_str().unwrap().to_string() } - pub fn model(&self) -> String { - self.model.clone() + pub fn sampled_nodes(&self) -> Vec { + self.sampled_nodes.clone() } pub fn body(&self) -> Value { @@ -36,15 +40,73 @@ impl Request { } } +impl TryFrom for Request { + type Error = Error; + + fn try_from(value: Value) -> Result { + let id = hex::decode( + value["ticket_id"] + .as_str() + .ok_or(anyhow!("Failed to decode hex string for request ticket_id"))? + .replace("0x", ""), + )?; + let sampled_nodes = value["nodes"] + .as_array() + .unwrap() + .iter() + .map(|v| parse_u64(&v["inner"])) + .collect::>>()?; + let body = parse_body(value["params"].clone())?; + Ok(Request::new(id, sampled_nodes, body)) + } +} + +fn parse_body(json: Value) -> Result { + let output = json!({ + "max_tokens": parse_u64(&json["max_tokens"])?, + "model": json["model"], + "prompt": json["prompt"], + "random_seed": parse_u64(&json["random_seed"])?, + "repeat_last_n": parse_u64(&json["repeat_last_n"])?, + "repeat_penalty": parse_f32_from_le_bytes(&json["repeat_penalty"])?, + "temperature": parse_f32_from_le_bytes(&json["temperature"])?, + "top_k": parse_u64(&json["top_k"])?, + "top_p": parse_f32_from_le_bytes(&json["top_p"])?, + }); + Ok(output) +} + +/// Parses an appropriate JSON value, from a number (represented as a `u32`) to a `f32` type, by +/// representing the extracted u32 value into little endian byte representation, and then applying `f32::from_le_bytes`. +/// See https://github.com/atoma-network/atoma-contracts/blob/main/sui/packages/atoma/sources/gate.move#L26 +fn parse_f32_from_le_bytes(value: &Value) -> Result { + let u32_value: u32 = value + .as_u64() + .ok_or(anyhow!( + "Failed to extract `f32` little endian bytes representation" + ))? + .try_into()?; + let f32_le_bytes = u32_value.to_le_bytes(); + Ok(f32::from_le_bytes(f32_le_bytes)) +} + +fn parse_u64(value: &Value) -> Result { + value + .as_str() + .ok_or(anyhow!("Failed to extract `u64` number"))? + .parse::() + .map_err(|e| anyhow!("Failed to parse `u64` from string, with error: {e}")) +} + #[derive(Debug)] pub struct Response { - id: usize, - sampled_nodes: Vec, + id: Vec, + sampled_nodes: Vec, response: Value, } impl Response { - pub fn new(id: usize, sampled_nodes: Vec, response: Value) -> Self { + pub fn new(id: Vec, sampled_nodes: Vec, response: Value) -> Self { Self { id, sampled_nodes, @@ -52,11 +114,11 @@ impl Response { } } - pub fn id(&self) -> usize { - self.id + pub fn id(&self) -> Vec { + self.id.clone() } - pub fn sampled_nodes(&self) -> Vec { + pub fn sampled_nodes(&self) -> Vec { self.sampled_nodes.clone() }