Skip to content

Commit

Permalink
feat: refactor code for smart contract integration (#58)
Browse files Browse the repository at this point in the history
* first commit

* refactor atoma sui client on atoma-node

* integrate config directly into sui atoma client

* add small id to event subscriber config

* refactor deserialization of event data, in the case of text prompts

* fmt

* minor bug refactor

* simplify code

* address PR comments

* fmt

* address further PR comments
  • Loading branch information
jorgeantonio21 authored Apr 30, 2024
1 parent c72c8ab commit 32164b8
Show file tree
Hide file tree
Showing 15 changed files with 184 additions and 119 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ target/
/models/
inference.toml
sui_subscriber.toml
sui_client.toml
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ dotenv = "0.15.0"
ethers = "2.0.14"
futures = "0.3.30"
futures-util = "0.3.30"
hex = "0.4.3"
hf-hub = "0.3.2"
image = { version = "0.25.0", default-features = false, features = ["jpeg", "png"] }
rand = "0.8.5"
Expand Down
77 changes: 27 additions & 50 deletions atoma-client/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,91 +1,77 @@
use std::{path::Path, str::FromStr, time::Duration};
use std::path::Path;

use atoma_crypto::{calculate_commitment, Blake2b};
use atoma_types::Response;
use sui_keys::keystore::AccountKeystore;
use atoma_types::{Response, SmallId};
use sui_sdk::{
json::SuiJsonValue,
types::{
base_types::{ObjectID, ObjectIDParseError, SuiAddress},
crypto::Signature,
base_types::{ObjectIDParseError, SuiAddress},
digests::TransactionDigest,
},
wallet_context::WalletContext,
};
use thiserror::Error;
use tokio::sync::mpsc;
use tracing::info;
use tracing::{debug, info};

use crate::config::AtomaSuiClientConfig;

const GAS_BUDGET: u64 = 5_000_000; // 0.005 SUI

const PACKAGE_ID: &str = "<TODO>";
const MODULE_ID: &str = "";
const METHOD: &str = "command";
const MODULE_ID: &str = "settlement";
const METHOD: &str = "submit_commitment";

pub struct AtomaSuiClient {
node_id: u64,
address: SuiAddress,
config: AtomaSuiClientConfig,
wallet_ctx: WalletContext,
response_receiver: mpsc::Receiver<Response>,
}

impl AtomaSuiClient {
pub fn new<P: AsRef<Path>>(
node_id: u64,
config_path: P,
request_timeout: Option<Duration>,
max_concurrent_requests: Option<u64>,
pub fn new_from_config(
config: AtomaSuiClientConfig,
response_receiver: mpsc::Receiver<Response>,
) -> Result<Self, AtomaSuiClientError> {
info!("Initializing Sui wallet..");
let mut wallet_ctx = WalletContext::new(
config_path.as_ref(),
request_timeout,
max_concurrent_requests,
config.config_path().as_ref(),
Some(config.request_timeout()),
Some(config.max_concurrent_requests()),
)?;
let active_address = wallet_ctx.active_address()?;
info!("Set Sui client, with active address: {}", active_address);
Ok(Self {
node_id,
address: active_address,
config,
wallet_ctx,
response_receiver,
})
}

pub fn new_from_config<P: AsRef<Path>>(
node_id: u64,
pub fn new_from_config_file<P: AsRef<Path>>(
config_path: P,
response_receiver: mpsc::Receiver<Response>,
) -> Result<Self, AtomaSuiClientError> {
let config = AtomaSuiClientConfig::from_file_path(config_path);
let config_path = config.config_path();
let request_timeout = config.request_timeout();
let max_concurrent_requests = config.max_concurrent_requests();

Self::new(
node_id,
config_path,
Some(request_timeout),
Some(max_concurrent_requests),
response_receiver,
)
Self::new_from_config(config, response_receiver)
}

fn get_index(&self, sampled_nodes: Vec<u64>) -> Result<(usize, usize), AtomaSuiClientError> {
fn get_index(
&self,
sampled_nodes: Vec<SmallId>,
) -> Result<(usize, usize), AtomaSuiClientError> {
let num_leaves = sampled_nodes.len();
let index = sampled_nodes
.iter()
.position(|nid| nid == &self.node_id)
.position(|nid| nid == &self.config.small_id())
.ok_or(AtomaSuiClientError::InvalidSampledNode)?;
Ok((index, num_leaves))
}

fn get_data(&self, data: serde_json::Value) -> Result<Vec<u8>, AtomaSuiClientError> {
// TODO: rework this when responses get same structure
let data = match data.as_str() {
let data = match data["text"].as_str() {
Some(text) => text.as_bytes().to_owned(),
None => {
if let Some(array) = data.as_array() {
Expand All @@ -101,17 +87,6 @@ impl AtomaSuiClient {
Ok(data)
}

fn sign_root_commitment(
&self,
merkle_root: [u8; 32],
) -> Result<Signature, AtomaSuiClientError> {
self.wallet_ctx
.config
.keystore
.sign_hashed(&self.address, merkle_root.as_slice())
.map_err(|e| AtomaSuiClientError::FailedSignature(e.to_string()))
}

/// Upon receiving a response from the `AtomaNode` service, this method extracts
/// the output data and computes a cryptographic commitment. The commitment includes
/// the root of an n-ary Merkle Tree built from the output data, represented as a `Vec<u8>`,
Expand All @@ -131,20 +106,21 @@ impl AtomaSuiClient {
let data = self.get_data(response.response())?;
let (index, num_leaves) = self.get_index(response.sampled_nodes())?;
let (root, pre_image) = calculate_commitment::<Blake2b<_>, _>(data, index, num_leaves);
let signature = self.sign_root_commitment(root)?;

let client = self.wallet_ctx.get_client().await?;
let tx = client
.transaction_builder()
.move_call(
self.address,
ObjectID::from_str(PACKAGE_ID)?,
self.config.package_id(),
MODULE_ID,
METHOD,
vec![],
vec![
SuiJsonValue::from_object_id(self.config.atoma_db_id()),
SuiJsonValue::from_object_id(self.config.node_badge_id()),
SuiJsonValue::new(request_id.into())?,
SuiJsonValue::new(signature.as_ref().into())?,
SuiJsonValue::new(root.as_ref().into())?,
SuiJsonValue::new(pre_image.as_ref().into())?,
],
None,
Expand All @@ -155,12 +131,13 @@ impl AtomaSuiClient {

let tx = self.wallet_ctx.sign_transaction(&tx);
let resp = self.wallet_ctx.execute_transaction_must_succeed(tx).await;

debug!("Submitted transaction with response: {:?}", resp);
Ok(resp.digest)
}

pub async fn run(mut self) -> Result<(), AtomaSuiClientError> {
while let Some(response) = self.response_receiver.recv().await {
info!("Received new response: {:?}", response);
self.submit_response_commitment(response).await?;
}
Ok(())
Expand Down
28 changes: 25 additions & 3 deletions atoma-client/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
use std::{path::Path, time::Duration};

use atoma_types::SmallId;
use config::Config;
use serde::Deserialize;
use sui_sdk::types::base_types::ObjectID;

#[derive(Debug, Deserialize)]
pub struct AtomaSuiClientConfig {
config_path: String,
request_timeout: Duration,
node_badge_id: ObjectID,
small_id: SmallId,
package_id: ObjectID,
atoma_db_id: ObjectID,
max_concurrent_requests: u64,
request_timeout: Duration,
}

impl AtomaSuiClientConfig {
Expand All @@ -27,11 +33,27 @@ impl AtomaSuiClientConfig {
self.config_path.clone()
}

pub fn request_timeout(&self) -> Duration {
self.request_timeout
pub fn node_badge_id(&self) -> ObjectID {
self.node_badge_id
}

pub fn small_id(&self) -> SmallId {
self.small_id
}

pub fn package_id(&self) -> ObjectID {
self.package_id
}

pub fn atoma_db_id(&self) -> ObjectID {
self.atoma_db_id
}

pub fn max_concurrent_requests(&self) -> u64 {
self.max_concurrent_requests
}

pub fn request_timeout(&self) -> Duration {
self.request_timeout
}
}
21 changes: 15 additions & 6 deletions atoma-event-subscribe/sui/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{path::Path, time::Duration};

use atoma_types::SmallId;
use config::Config;
use serde::{Deserialize, Serialize};
use sui_sdk::types::base_types::ObjectID;
Expand All @@ -8,22 +9,25 @@ use sui_sdk::types::base_types::ObjectID;
pub struct SuiSubscriberConfig {
http_url: String,
ws_url: String,
object_id: ObjectID,
package_id: ObjectID,
request_timeout: Duration,
small_id: u64,
}

impl SuiSubscriberConfig {
pub fn new(
http_url: String,
ws_url: String,
object_id: ObjectID,
package_id: ObjectID,
request_timeout: Duration,
small_id: u64,
) -> Self {
Self {
http_url,
ws_url,
object_id,
package_id,
request_timeout,
small_id,
}
}

Expand All @@ -35,14 +39,18 @@ impl SuiSubscriberConfig {
self.ws_url.clone()
}

pub fn object_id(&self) -> ObjectID {
self.object_id
pub fn package_id(&self) -> ObjectID {
self.package_id
}

pub fn request_timeout(&self) -> Duration {
self.request_timeout
}

pub fn small_id(&self) -> SmallId {
self.small_id
}

pub fn from_file_path<P: AsRef<Path>>(config_file_path: P) -> Self {
let builder = Config::builder().add_source(config::File::with_name(
config_file_path.as_ref().to_str().unwrap(),
Expand All @@ -69,10 +77,11 @@ pub mod tests {
.parse()
.unwrap(),
Duration::from_secs(5 * 60),
0,
);

let toml_str = toml::to_string(&config).unwrap();
let should_be_toml_str = "http_url = \"\"\nws_url = \"\"\nobject_id = \"0x8d97f1cd6ac663735be08d1d2b6d02a159e711586461306ce60a2b7a6a565a9e\"\n\n[request_timeout]\nsecs = 300\nnanos = 0\n";
let should_be_toml_str = "http_url = \"\"\nws_url = \"\"\npackage_id = \"0x8d97f1cd6ac663735be08d1d2b6d02a159e711586461306ce60a2b7a6a565a9e\"\nsmall_id = 0\n\n[request_timeout]\nsecs = 300\nnanos = 0\n";
assert_eq!(toml_str, should_be_toml_str);
}
}
6 changes: 3 additions & 3 deletions atoma-event-subscribe/sui/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ struct Args {
#[arg(long)]
pub package_id: String,
/// HTTP node's address for Sui client
#[arg(long, default_value = "https://fullnode.mainnet.sui.io:443")]
#[arg(long, default_value = "https://fullnode.testnet.sui.io:443")]
pub http_addr: String,
/// RPC node's web socket address for Sui client
#[arg(long, default_value = "wss://fullnode.mainnet.sui.io:443")]
#[arg(long, default_value = "wss://fullnode.testnet.sui.io:443")]
pub ws_addr: String,
}

Expand All @@ -30,7 +30,7 @@ async fn main() -> Result<(), SuiSubscriberError> {
let (event_sender, mut event_receiver) = tokio::sync::mpsc::channel(32);

let event_subscriber = SuiSubscriber::new(
0,
1,
&http_url,
Some(&ws_url),
package_id,
Expand Down
Loading

0 comments on commit 32164b8

Please sign in to comment.