diff --git a/Cargo.lock b/Cargo.lock index e2a5abe..ab5c0c9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9451,6 +9451,7 @@ dependencies = [ "test-cluster", "tokio", "tokio-retry", + "tokio-util 0.7.11", "tracing", ] diff --git a/Cargo.toml b/Cargo.toml index 2d33c7c..a9b0432 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ telemetry-subscribers = { git = "https://github.com/MystenLabs/sui", branch = "t anyhow = "1.0.75" async-trait = "0.1.51" -axum = {version = "0.6.6", features = ["headers"]} +axum = { version = "0.6.6", features = ["headers"] } bcs = "0.1.6" clap = "4.4.10" chrono = "0.4.19" @@ -53,6 +53,7 @@ tracing = "0.1.40" tokio = { version = "1.36.0", features = ["full"] } tokio-retry = "0.3.0" serde_json = "1.0.108" +tokio-util = "0.7.10" [dev-dependencies] rand = "0.8.5" diff --git a/src/benchmarks/kms_stress.rs b/src/benchmarks/kms_stress.rs index 28d8705..509d6b2 100644 --- a/src/benchmarks/kms_stress.rs +++ b/src/benchmarks/kms_stress.rs @@ -1,13 +1,15 @@ // Copyright (c) Mysten Labs, Inc. // SPDX-License-Identifier: Apache-2.0 -use crate::tx_signer::{SidecarTxSigner, TxSigner}; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; use std::time::{Duration, Instant}; use sui_types::base_types::{random_object_ref, SuiAddress}; use sui_types::transaction::{ProgrammableTransaction, TransactionData, TransactionKind}; +use crate::tx_signer::sidecar_signer::SidecarTxSigner; +use crate::tx_signer::TxSignerTrait; + pub async fn run_kms_stress_test(kms_url: String, num_tasks: usize) { let signer = SidecarTxSigner::new(kms_url).await; let test_tx_data = TransactionData::new( diff --git a/src/command.rs b/src/command.rs index 1bf957c..001440b 100644 --- a/src/command.rs +++ b/src/command.rs @@ -53,9 +53,9 @@ impl Command { let signer = signer_config.new_signer().await; let storage_metrics = StorageMetrics::new(&prometheus_registry); - let sponsor_address = signer.get_address(); - info!("Sponsor address: {:?}", sponsor_address); - let storage = connect_storage(&gas_pool_config, sponsor_address, storage_metrics).await; + let sponsor_addresses = signer.get_all_addresses(); + info!("Sponsor addresses: {:?}", sponsor_addresses); + let storage = connect_storage(&gas_pool_config, sponsor_addresses, storage_metrics).await; let sui_client = SuiClient::new(&fullnode_url, fullnode_basic_auth).await; let _coin_init_task = if let Some(coin_init_config) = coin_init_config { let task = GasPoolInitializer::start( diff --git a/src/config.rs b/src/config.rs index 88b42c2..46a0168 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,7 +1,9 @@ // Copyright (c) Mysten Labs, Inc. // SPDX-License-Identifier: Apache-2.0 -use crate::tx_signer::{SidecarTxSigner, TestTxSigner, TxSigner}; +use crate::tx_signer::in_memory_signer::InMemoryTxSigner; +use crate::tx_signer::sidecar_signer::SidecarTxSigner; +use crate::tx_signer::{TxSigner, TxSignerTrait}; use serde::{Deserialize, Serialize}; use serde_with::serde_as; use std::net::Ipv4Addr; @@ -82,6 +84,15 @@ impl Default for GasPoolStorageConfig { pub enum TxSignerConfig { Local { keypair: SuiKeyPair }, Sidecar { sidecar_url: String }, + MultiSigner { signers: Vec }, +} + +#[serde_as] +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "kebab-case")] +pub enum SingleSignerType { + Local { keypair: SuiKeyPair }, + Sidecar { sidecar_url: String }, } impl Default for TxSignerConfig { @@ -94,11 +105,28 @@ impl Default for TxSignerConfig { } impl TxSignerConfig { - pub async fn new_signer(self) -> Arc { - match self { - TxSignerConfig::Local { keypair } => TestTxSigner::new(keypair), - TxSignerConfig::Sidecar { sidecar_url } => SidecarTxSigner::new(sidecar_url).await, - } + pub async fn new_signer(self) -> Arc { + let all_signers: Vec> = match self { + TxSignerConfig::Local { keypair } => vec![InMemoryTxSigner::new(keypair)], + TxSignerConfig::Sidecar { sidecar_url } => { + vec![SidecarTxSigner::new(sidecar_url).await] + } + TxSignerConfig::MultiSigner { signers } => { + let mut all_signers: Vec> = Vec::new(); + for signer_config in signers { + match signer_config { + SingleSignerType::Local { keypair } => { + all_signers.push(InMemoryTxSigner::new(keypair)) + } + SingleSignerType::Sidecar { sidecar_url } => { + all_signers.push(SidecarTxSigner::new(sidecar_url).await) + } + } + } + all_signers + } + }; + TxSigner::new(all_signers) } } diff --git a/src/gas_pool/gas_pool_core.rs b/src/gas_pool/gas_pool_core.rs index a142bc9..df21438 100644 --- a/src/gas_pool/gas_pool_core.rs +++ b/src/gas_pool/gas_pool_core.rs @@ -20,6 +20,7 @@ use sui_types::transaction::{ }; use tap::TapFallible; use tokio::task::JoinHandle; +use tokio_util::sync::CancellationToken; use tracing::{debug, error, info}; use super::gas_usage_cap::GasUsageCap; @@ -28,13 +29,12 @@ const EXPIRATION_JOB_INTERVAL: Duration = Duration::from_secs(1); pub struct GasPoolContainer { inner: Arc, - _coin_unlocker_task: JoinHandle<()>, - // This is always Some. It is None only after the drop method is called. - cancel_sender: Option>, + _coin_unlocker_tasks: Vec>, + cancel: CancellationToken, } pub struct GasPool { - signer: Arc, + signer: Arc, gas_pool_store: Arc, sui_client: SuiClient, metrics: Arc, @@ -43,7 +43,7 @@ pub struct GasPool { impl GasPool { pub async fn new( - signer: Arc, + signer: Arc, gas_pool_store: Arc, sui_client: SuiClient, metrics: Arc, @@ -66,10 +66,10 @@ impl GasPool { ) -> anyhow::Result<(SuiAddress, ReservationID, Vec)> { let cur_time = std::time::Instant::now(); self.gas_usage_cap.check_usage().await?; - let sponsor = self.signer.get_address(); + let sponsor = self.signer.get_one_address(); let (reservation_id, gas_coins) = self .gas_pool_store - .reserve_gas_coins(gas_budget, duration.as_millis() as u64) + .reserve_gas_coins(sponsor, gas_budget, duration.as_millis() as u64) .await?; let elapsed = cur_time.elapsed().as_millis(); self.metrics.reserve_gas_latency_ms.observe(elapsed as u64); @@ -106,7 +106,7 @@ impl GasPool { "Payment coins in transaction: {:?}", payment ); self.gas_pool_store - .ready_for_execution(reservation_id) + .ready_for_execution(sponsor, reservation_id) .await?; debug!(?reservation_id, "Reservation is ready for execution"); @@ -161,7 +161,7 @@ impl GasPool { // Regardless of whether the transaction succeeded, we need to release the coins. // Otherwise, we lose track of them. This is because `ready_for_execution` already takes // the coins out of the pool and will not be covered by the auto-release mechanism. - self.release_gas_coins(updated_coins).await; + self.release_gas_coins(sponsor, updated_coins).await; if smashed_coin_count > 0 { info!( ?reservation_id, @@ -260,11 +260,11 @@ impl GasPool { } /// Release gas coins back to the gas pool, by adding them to the storage. - async fn release_gas_coins(&self, gas_coins: Vec) { + async fn release_gas_coins(&self, sponsor: SuiAddress, gas_coins: Vec) { debug!("Trying to release gas coins: {:?}", gas_coins); retry_forever!(async { self.gas_pool_store - .add_new_coins(gas_coins.clone()) + .add_new_coins(sponsor, gas_coins.clone()) .await .tap_err(|err| error!("Failed to call update_gas_coins on storage: {:?}", err)) }) @@ -274,30 +274,26 @@ impl GasPool { /// Performs an end-to-end flow of reserving gas, signing a transaction, and releasing the gas coins. pub async fn debug_check_health(&self) -> anyhow::Result<()> { let gas_budget = MIST_PER_SUI / 10; - let (_address, _reservation_id, gas_coins) = + let (sender, _reservation_id, gas_coins) = self.reserve_gas(gas_budget, Duration::from_secs(3)).await?; let tx_kind = TransactionKind::ProgrammableTransaction( ProgrammableTransactionBuilder::new().finish(), ); // Since we just want to check the health of the signer, we don't need to actually execute the transaction. - let tx_data = TransactionData::new_with_gas_coins( - tx_kind, - SuiAddress::default(), - gas_coins, - gas_budget, - 0, - ); + let tx_data = + TransactionData::new_with_gas_coins(tx_kind, sender, gas_coins, gas_budget, 0); self.signer.sign_transaction(&tx_data).await?; Ok(()) } async fn start_coin_unlock_task( self: Arc, - mut cancel_receiver: tokio::sync::oneshot::Receiver<()>, + sponsor: SuiAddress, + cancel: CancellationToken, ) -> JoinHandle<()> { tokio::task::spawn(async move { loop { - let expire_results = self.gas_pool_store.expire_coins().await; + let expire_results = self.gas_pool_store.expire_coins(sponsor).await; let unlocked_coins = expire_results.unwrap_or_else(|err| { error!("Failed to call expire_coins to the storage: {:?}", err); vec![] @@ -312,12 +308,12 @@ impl GasPool { .flatten() .collect(); let count = latest_coins.len(); - self.release_gas_coins(latest_coins).await; + self.release_gas_coins(sponsor, latest_coins).await; info!("Released {:?} coins after expiration", count); } tokio::select! { _ = tokio::time::sleep(EXPIRATION_JOB_INTERVAL) => {} - _ = &mut cancel_receiver => { + _ = cancel.cancelled() => { info!("Coin unlocker task is cancelled"); break; } @@ -326,9 +322,9 @@ impl GasPool { }) } - pub async fn query_pool_available_coin_count(&self) -> usize { + pub async fn query_pool_available_coin_count(&self, sponsor: SuiAddress) -> usize { self.gas_pool_store - .get_available_coin_count() + .get_available_coin_count(sponsor) .await .unwrap() } @@ -336,12 +332,13 @@ impl GasPool { impl GasPoolContainer { pub async fn new( - signer: Arc, + signer: Arc, gas_pool_store: Arc, sui_client: SuiClient, gas_usage_daily_cap: u64, metrics: Arc, ) -> Self { + let sponsor_addresses = signer.get_all_addresses(); let inner = GasPool::new( signer, gas_pool_store, @@ -350,13 +347,19 @@ impl GasPoolContainer { Arc::new(GasUsageCap::new(gas_usage_daily_cap)), ) .await; - let (cancel_sender, cancel_receiver) = tokio::sync::oneshot::channel(); - let _coin_unlocker_task = inner.clone().start_coin_unlock_task(cancel_receiver).await; + let cancel = CancellationToken::new(); + + let mut _coin_unlocker_tasks = vec![]; + for sponsor in sponsor_addresses { + let inner = inner.clone(); + let task = inner.start_coin_unlock_task(sponsor, cancel.clone()).await; + _coin_unlocker_tasks.push(task); + } Self { inner, - _coin_unlocker_task, - cancel_sender: Some(cancel_sender), + _coin_unlocker_tasks, + cancel, } } @@ -367,6 +370,6 @@ impl GasPoolContainer { impl Drop for GasPoolContainer { fn drop(&mut self) { - self.cancel_sender.take().unwrap().send(()).unwrap(); + self.cancel.cancel(); } } diff --git a/src/gas_pool/mod.rs b/src/gas_pool/mod.rs index 0b887aa..5769eb0 100644 --- a/src/gas_pool/mod.rs +++ b/src/gas_pool/mod.rs @@ -27,14 +27,14 @@ mod tests { .await .unwrap(); assert_eq!(gas_coins.len(), 3); - assert_eq!(station.query_pool_available_coin_count().await, 7); + assert_eq!(station.query_pool_available_coin_count(sponsor1).await, 7); let (sponsor2, _res_id2, gas_coins) = station .reserve_gas(MIST_PER_SUI * 7, Duration::from_secs(10)) .await .unwrap(); assert_eq!(gas_coins.len(), 7); assert_eq!(sponsor1, sponsor2); - assert_eq!(station.query_pool_available_coin_count().await, 0); + assert_eq!(station.query_pool_available_coin_count(sponsor2).await, 0); assert!(station .reserve_gas(1, Duration::from_secs(10)) .await @@ -55,7 +55,7 @@ mod tests { .await .unwrap(); assert_eq!(gas_coins.len(), 1); - assert_eq!(station.query_pool_available_coin_count().await, 0); + assert_eq!(station.query_pool_available_coin_count(sponsor).await, 0); assert!(station .reserve_gas(1, Duration::from_secs(10)) .await @@ -67,7 +67,7 @@ mod tests { .await .unwrap(); assert!(effects.status().is_ok()); - assert_eq!(station.query_pool_available_coin_count().await, 1); + assert_eq!(station.query_pool_available_coin_count(sponsor).await, 1); } #[tokio::test] @@ -93,7 +93,7 @@ mod tests { .await; println!("{:?}", result); assert!(result.is_err()); - assert_eq!(station.query_pool_available_coin_count().await, 1); + assert_eq!(station.query_pool_available_coin_count(sponsor).await, 1); } #[tokio::test] @@ -106,14 +106,14 @@ mod tests { .await .unwrap(); assert_eq!(gas_coins.len(), 1); - assert_eq!(station.query_pool_available_coin_count().await, 0); + assert_eq!(station.query_pool_available_coin_count(sponsor).await, 0); assert!(station .reserve_gas(1, Duration::from_secs(1)) .await .is_err()); // Sleep a little longer to give it enough time to expire. tokio::time::sleep(Duration::from_secs(5)).await; - assert_eq!(station.query_pool_available_coin_count().await, 1); + assert_eq!(station.query_pool_available_coin_count(sponsor).await, 1); let (tx_data, user_sig) = create_test_transaction(&test_cluster, sponsor, gas_coins).await; assert!(station .execute_transaction(reservation_id, tx_data, user_sig) diff --git a/src/gas_pool_initializer.rs b/src/gas_pool_initializer.rs index a2d5702..a9124f6 100644 --- a/src/gas_pool_initializer.rs +++ b/src/gas_pool_initializer.rs @@ -36,7 +36,7 @@ const MAX_INIT_DURATION_SEC: u64 = 60 * 60 * 12; struct CoinSplitEnv { target_init_coin_balance: u64, gas_cost_per_object: u64, - signer: Arc, + signer: Arc, sponsor_address: SuiAddress, sui_client: SuiClient, task_queue: Arc>>>>, @@ -161,7 +161,7 @@ enum RunMode { } pub struct GasPoolInitializer { - _task_handle: JoinHandle<()>, + _fund_task_handle: JoinHandle<()>, // This is always Some. It is None only after the drop method is called. cancel_sender: Option>, } @@ -177,21 +177,24 @@ impl GasPoolInitializer { sui_client: SuiClient, storage: Arc, coin_init_config: CoinInitConfig, - signer: Arc, + signer: Arc, ) -> Self { - if !storage.is_initialized().await.unwrap() { - // If the pool has never been initialized, always run once at the beginning to make sure we have enough coins. - Self::run_once( - sui_client.clone(), - &storage, - RunMode::Init, - coin_init_config.target_init_balance, - &signer, - ) - .await; + for address in signer.get_all_addresses() { + if !storage.is_initialized(address).await.unwrap() { + // If the pool has never been initialized, always run once at the beginning to make sure we have enough coins. + Self::run_once( + address, + sui_client.clone(), + &storage, + RunMode::Init, + coin_init_config.target_init_balance, + &signer, + ) + .await; + } } let (cancel_sender, cancel_receiver) = tokio::sync::oneshot::channel(); - let _task_handle = tokio::spawn(Self::run( + let _fund_task_handle = tokio::spawn(Self::run( sui_client, storage, coin_init_config, @@ -199,7 +202,7 @@ impl GasPoolInitializer { cancel_receiver, )); Self { - _task_handle, + _fund_task_handle, cancel_sender: Some(cancel_sender), } } @@ -208,7 +211,7 @@ impl GasPoolInitializer { sui_client: SuiClient, storage: Arc, coin_init_config: CoinInitConfig, - signer: Arc, + signer: Arc, mut cancel_receiver: tokio::sync::oneshot::Receiver<()>, ) { loop { @@ -220,38 +223,50 @@ impl GasPoolInitializer { } } info!("Coin init task waking up and looking for new coins to initialize"); - Self::run_once( - sui_client.clone(), - &storage, - RunMode::Refresh, - coin_init_config.target_init_balance, - &signer, - ) - .await; + for address in signer.get_all_addresses() { + Self::run_once( + address, + sui_client.clone(), + &storage, + RunMode::Refresh, + coin_init_config.target_init_balance, + &signer, + ) + .await; + } } } async fn run_once( + sponsor_address: SuiAddress, sui_client: SuiClient, storage: &Arc, mode: RunMode, target_init_coin_balance: u64, - signer: &Arc, + signer: &Arc, ) { - let sponsor_address = signer.get_address(); if storage - .acquire_init_lock(MAX_INIT_DURATION_SEC) + .acquire_init_lock(sponsor_address, MAX_INIT_DURATION_SEC) .await .unwrap() { - info!("Acquired init lock. Starting new coin initialization"); + info!( + ?sponsor_address, + "Acquired init lock. Starting new coin initialization" + ); } else { - info!("Another task is already initializing the pool. Skipping this round"); + info!( + ?sponsor_address, + "Another task is already initializing the pool. Skipping this round" + ); return; } let start = Instant::now(); let balance_threshold = if matches!(mode, RunMode::Init) { - info!("The pool has never been initialized. Initializing it for the first time"); + info!( + ?sponsor_address, + "The pool has never been initialized. Initializing it for the first time" + ); 0 } else { target_init_coin_balance * NEW_COIN_BALANCE_FACTOR_THRESHOLD @@ -261,10 +276,11 @@ impl GasPoolInitializer { .await; if coins.is_empty() { info!( + ?sponsor_address, "No coins with balance above {} found. Skipping new coin initialization", balance_threshold ); - storage.release_init_lock().await.unwrap(); + storage.release_init_lock(sponsor_address).await.unwrap(); return; } let total_coin_count = Arc::new(AtomicUsize::new(coins.len())); @@ -288,10 +304,14 @@ impl GasPoolInitializer { ) .await; for chunk in result.chunks(5000) { - storage.add_new_coins(chunk.to_vec()).await.unwrap(); + storage + .add_new_coins(sponsor_address, chunk.to_vec()) + .await + .unwrap(); } - storage.release_init_lock().await.unwrap(); + storage.release_init_lock(sponsor_address).await.unwrap(); info!( + ?sponsor_address, "New coin initialization took {:?}s", start.elapsed().as_secs() ); @@ -343,7 +363,8 @@ mod tests { telemetry_subscribers::init_for_testing(); let (cluster, signer) = start_sui_cluster(vec![1000 * MIST_PER_SUI]).await; let fullnode_url = cluster.fullnode_handle.rpc_url; - let storage = connect_storage_for_testing(signer.get_address()).await; + let sponsor = signer.get_one_address(); + let storage = connect_storage_for_testing(sponsor).await; let sui_client = SuiClient::new(&fullnode_url, None).await; let _ = GasPoolInitializer::start( sui_client, @@ -355,15 +376,16 @@ mod tests { signer, ) .await; - assert!(storage.get_available_coin_count().await.unwrap() > 900); + assert!(storage.get_available_coin_count(sponsor).await.unwrap() > 900); } #[tokio::test] async fn test_init_non_even_split() { telemetry_subscribers::init_for_testing(); let (cluster, signer) = start_sui_cluster(vec![10000000 * MIST_PER_SUI]).await; + let sponsor = signer.get_one_address(); let fullnode_url = cluster.fullnode_handle.rpc_url; - let storage = connect_storage_for_testing(signer.get_address()).await; + let storage = connect_storage_for_testing(sponsor).await; let target_init_balance = 12345 * MIST_PER_SUI; let sui_client = SuiClient::new(&fullnode_url, None).await; let _ = GasPoolInitializer::start( @@ -376,16 +398,16 @@ mod tests { signer, ) .await; - assert!(storage.get_available_coin_count().await.unwrap() > 800); + assert!(storage.get_available_coin_count(sponsor).await.unwrap() > 800); } #[tokio::test] async fn test_add_new_funds_to_pool() { telemetry_subscribers::init_for_testing(); let (cluster, signer) = start_sui_cluster(vec![1000 * MIST_PER_SUI]).await; - let sponsor = signer.get_address(); + let sponsor = signer.get_one_address(); let fullnode_url = cluster.fullnode_handle.rpc_url.clone(); - let storage = connect_storage_for_testing(signer.get_address()).await; + let storage = connect_storage_for_testing(sponsor).await; let sui_client = SuiClient::new(&fullnode_url, None).await; let _init_task = GasPoolInitializer::start( sui_client, @@ -397,8 +419,8 @@ mod tests { signer, ) .await; - assert!(storage.is_initialized().await.unwrap()); - let available_coin_count = storage.get_available_coin_count().await.unwrap(); + assert!(storage.is_initialized(sponsor).await.unwrap()); + let available_coin_count = storage.get_available_coin_count(sponsor).await.unwrap(); tracing::debug!("Available coin count: {}", available_coin_count); // Transfer some new SUI into the sponsor account. @@ -420,7 +442,7 @@ mod tests { // Give it some time for the task to pick up the new coin and split it. tokio::time::sleep(std::time::Duration::from_secs(30)).await; - let new_available_coin_count = storage.get_available_coin_count().await.unwrap(); + let new_available_coin_count = storage.get_available_coin_count(sponsor).await.unwrap(); assert!( // In an ideal world we should have NEW_COIN_BALANCE_FACTOR_THRESHOLD more coins // since we just send a new coin with balance NEW_COIN_BALANCE_FACTOR_THRESHOLD and split diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 7918179..8314a23 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -26,15 +26,24 @@ pub trait Storage: Sync + Send { /// 3. It should never return more than 256 coins at a time since that's the upper bound of gas. async fn reserve_gas_coins( &self, + sponsor_address: SuiAddress, target_budget: u64, reserved_duration_ms: u64, ) -> anyhow::Result<(ReservationID, Vec)>; - async fn ready_for_execution(&self, reservation_id: ReservationID) -> anyhow::Result<()>; + async fn ready_for_execution( + &self, + sponsor_address: SuiAddress, + reservation_id: ReservationID, + ) -> anyhow::Result<()>; - async fn add_new_coins(&self, new_coins: Vec) -> anyhow::Result<()>; + async fn add_new_coins( + &self, + sponsor_address: SuiAddress, + new_coins: Vec, + ) -> anyhow::Result<()>; - async fn expire_coins(&self) -> anyhow::Result>; + async fn expire_coins(&self, sponsor_address: SuiAddress) -> anyhow::Result>; /// Initialize some of the gas pool statistics at the startup. /// Such as the total number of gas coins and the total balance. @@ -43,48 +52,57 @@ pub trait Storage: Sync + Send { /// We only need this once ever though. /// 2. To make sure we start reporting the correct metrics from the beginning. /// Returns the total number of gas coins and the total balance. - async fn init_coin_stats_at_startup(&self) -> anyhow::Result<(u64, u64)>; + async fn init_coin_stats_at_startup( + &self, + sponsor_address: SuiAddress, + ) -> anyhow::Result<(u64, u64)>; /// Whether the gas pool for the given sponsor address is initialized. - async fn is_initialized(&self) -> anyhow::Result; + async fn is_initialized(&self, sponsor_address: SuiAddress) -> anyhow::Result; /// Acquire a lock to initialize the gas pool for the given sponsor address for a certain duration. /// Returns true if the lock is acquired, false otherwise. /// Once the lock is acquired, until it expires, no other caller can acquire the lock. /// The reason we use a lock duration is such that in case the server crashed while holding the lock, /// the lock will be automatically considered as released after the lock duration. - async fn acquire_init_lock(&self, lock_duration_sec: u64) -> anyhow::Result; + async fn acquire_init_lock( + &self, + sponsor_address: SuiAddress, + lock_duration_sec: u64, + ) -> anyhow::Result; - async fn release_init_lock(&self) -> anyhow::Result<()>; + async fn release_init_lock(&self, sponsor_address: SuiAddress) -> anyhow::Result<()>; async fn check_health(&self) -> anyhow::Result<()>; #[cfg(test)] async fn flush_db(&self); - async fn get_available_coin_count(&self) -> anyhow::Result; + async fn get_available_coin_count(&self, sponsor_address: SuiAddress) -> anyhow::Result; - async fn get_available_coin_total_balance(&self) -> u64; + async fn get_available_coin_total_balance(&self, sponsor_address: SuiAddress) -> u64; #[cfg(test)] - async fn get_reserved_coin_count(&self) -> usize; + async fn get_reserved_coin_count(&self, sponsor_address: SuiAddress) -> usize; } pub async fn connect_storage( config: &GasPoolStorageConfig, - sponsor_address: SuiAddress, + sponsor_addresses: Vec, metrics: Arc, ) -> Arc { let storage: Arc = match config { GasPoolStorageConfig::Redis { redis_url } => { - Arc::new(RedisStorage::new(redis_url, sponsor_address, metrics).await) + Arc::new(RedisStorage::new(redis_url, metrics).await) } }; storage .check_health() .await .expect("Unable to connect to the storage layer"); - storage.init_coin_stats_at_startup().await.unwrap(); + for address in sponsor_addresses { + storage.init_coin_stats_at_startup(address).await.unwrap(); + } storage } @@ -98,12 +116,20 @@ pub async fn connect_storage_for_testing_with_config( static IS_FIRST_CALL: AtomicBool = AtomicBool::new(true); let is_first_call = IS_FIRST_CALL.fetch_and(false, Ordering::SeqCst); - let storage = connect_storage(config, sponsor_address, StorageMetrics::new_for_testing()).await; + let storage = connect_storage( + config, + vec![sponsor_address], + StorageMetrics::new_for_testing(), + ) + .await; if is_first_call { // Make sure that we only flush the DB once at the beginning of each test run. storage.flush_db().await; // Re-init coin stats again since we just flushed. - storage.init_coin_stats_at_startup().await.unwrap(); + storage + .init_coin_stats_at_startup(sponsor_address) + .await + .unwrap(); } storage } @@ -124,9 +150,17 @@ mod tests { use sui_types::base_types::{random_object_ref, ObjectID, SequenceNumber, SuiAddress}; use sui_types::digests::ObjectDigest; - async fn assert_coin_count(storage: &Arc, available: usize, reserved: usize) { - assert_eq!(storage.get_available_coin_count().await.unwrap(), available); - assert_eq!(storage.get_reserved_coin_count().await, reserved); + async fn assert_coin_count( + sponsor: SuiAddress, + storage: &Arc, + available: usize, + reserved: usize, + ) { + assert_eq!( + storage.get_available_coin_count(sponsor).await.unwrap(), + available + ); + assert_eq!(storage.get_reserved_coin_count(sponsor).await, reserved); } async fn setup(sponsor: SuiAddress, init_balances: Vec) -> Arc { @@ -143,7 +177,10 @@ mod tests { }) .collect::>(); for chunk in gas_coins.chunks(5000) { - storage.add_new_coins(chunk.to_vec()).await.unwrap(); + storage + .add_new_coins(sponsor, chunk.to_vec()) + .await + .unwrap(); } storage } @@ -152,18 +189,21 @@ mod tests { async fn test_gas_pool_init() { let sponsor = SuiAddress::random_for_testing_only(); let storage = connect_storage_for_testing(sponsor).await; - assert!(!storage.is_initialized().await.unwrap()); - storage.add_new_coins(vec![]).await.unwrap(); + assert!(!storage.is_initialized(sponsor).await.unwrap()); + storage.add_new_coins(sponsor, vec![]).await.unwrap(); // Still not initialized because we are not adding any coins. - assert!(!storage.is_initialized().await.unwrap()); + assert!(!storage.is_initialized(sponsor).await.unwrap()); storage - .add_new_coins(vec![GasCoin { - object_ref: random_object_ref(), - balance: 1, - }]) + .add_new_coins( + sponsor, + vec![GasCoin { + object_ref: random_object_ref(), + balance: 1, + }], + ) .await .unwrap(); - assert!(storage.is_initialized().await.unwrap()); + assert!(storage.is_initialized(sponsor).await.unwrap()); } #[tokio::test] @@ -171,18 +211,20 @@ mod tests { // Create a gas pool of 100000 coins, each with balance of 1. let sponsor = SuiAddress::random_for_testing_only(); let storage = setup(sponsor, vec![1; 100000]).await; - assert_coin_count(&storage, 100000, 0).await; + assert_coin_count(sponsor, &storage, 100000, 0).await; let mut cur_available = 100000; let mut expected_res_id = 1; for i in 1..=MAX_GAS_PER_QUERY { - let (res_id, reserved_gas_coins) = - storage.reserve_gas_coins(i as u64, 1000).await.unwrap(); + let (res_id, reserved_gas_coins) = storage + .reserve_gas_coins(sponsor, i as u64, 1000) + .await + .unwrap(); assert_eq!(expected_res_id, res_id); assert_eq!(reserved_gas_coins.len(), i); expected_res_id += 1; cur_available -= i; } - assert_coin_count(&storage, cur_available, 100000 - cur_available).await; + assert_coin_count(sponsor, &storage, cur_available, 100000 - cur_available).await; } #[tokio::test] @@ -190,18 +232,18 @@ mod tests { let sponsor = SuiAddress::random_for_testing_only(); let storage = setup(sponsor, vec![1; MAX_GAS_PER_QUERY + 1]).await; assert!(storage - .reserve_gas_coins((MAX_GAS_PER_QUERY + 1) as u64, 1000) + .reserve_gas_coins(sponsor, (MAX_GAS_PER_QUERY + 1) as u64, 1000) .await .is_err()); - assert_coin_count(&storage, MAX_GAS_PER_QUERY + 1, 0).await; + assert_coin_count(sponsor, &storage, MAX_GAS_PER_QUERY + 1, 0).await; } #[tokio::test] async fn test_insufficient_pool_budget() { let sponsor = SuiAddress::random_for_testing_only(); let storage = setup(sponsor, vec![1; 100]).await; - assert!(storage.reserve_gas_coins(101, 1000).await.is_err()); - assert_coin_count(&storage, 100, 0).await; + assert!(storage.reserve_gas_coins(sponsor, 101, 1000).await.is_err()); + assert_coin_count(sponsor, &storage, 100, 0).await; } #[tokio::test] @@ -211,12 +253,16 @@ mod tests { for _ in 0..100 { // Keep reserving and putting them back. // Should be able to repeat this process indefinitely if balance are not changed. - let (res_id, reserved_gas_coins) = storage.reserve_gas_coins(99, 1000).await.unwrap(); + let (res_id, reserved_gas_coins) = + storage.reserve_gas_coins(sponsor, 99, 1000).await.unwrap(); assert_eq!(reserved_gas_coins.len(), 99); - assert_coin_count(&storage, 1, 99).await; - storage.ready_for_execution(res_id).await.unwrap(); - storage.add_new_coins(reserved_gas_coins).await.unwrap(); - assert_coin_count(&storage, 100, 0).await; + assert_coin_count(sponsor, &storage, 1, 99).await; + storage.ready_for_execution(sponsor, res_id).await.unwrap(); + storage + .add_new_coins(sponsor, reserved_gas_coins) + .await + .unwrap(); + assert_coin_count(sponsor, &storage, 100, 0).await; } } @@ -226,7 +272,7 @@ mod tests { let storage = setup(sponsor, vec![1; 100]).await; for _ in 0..10 { let (res_id, mut reserved_gas_coins) = - storage.reserve_gas_coins(10, 1000).await.unwrap(); + storage.reserve_gas_coins(sponsor, 10, 1000).await.unwrap(); assert_eq!( reserved_gas_coins.iter().map(|c| c.balance).sum::(), 10 @@ -236,46 +282,56 @@ mod tests { reserved_gas_coin.balance -= 1; } } - storage.ready_for_execution(res_id).await.unwrap(); - storage.add_new_coins(reserved_gas_coins).await.unwrap(); + storage.ready_for_execution(sponsor, res_id).await.unwrap(); + storage + .add_new_coins(sponsor, reserved_gas_coins) + .await + .unwrap(); } - assert_coin_count(&storage, 100, 0).await; - assert_eq!(storage.get_available_coin_total_balance().await, 0); - assert!(storage.reserve_gas_coins(1, 1000).await.is_err()); + assert_coin_count(sponsor, &storage, 100, 0).await; + assert_eq!(storage.get_available_coin_total_balance(sponsor).await, 0); + assert!(storage.reserve_gas_coins(sponsor, 1, 1000).await.is_err()); } #[tokio::test] async fn test_deleted_objects() { let sponsor = SuiAddress::random_for_testing_only(); let storage = setup(sponsor, vec![1; 100]).await; - let (res_id, mut reserved_gas_coins) = storage.reserve_gas_coins(100, 1000).await.unwrap(); + let (res_id, mut reserved_gas_coins) = + storage.reserve_gas_coins(sponsor, 100, 1000).await.unwrap(); assert_eq!(reserved_gas_coins.len(), 100); - storage.ready_for_execution(res_id).await.unwrap(); + storage.ready_for_execution(sponsor, res_id).await.unwrap(); reserved_gas_coins.drain(0..50); - storage.add_new_coins(reserved_gas_coins).await.unwrap(); - assert_coin_count(&storage, 50, 0).await; + storage + .add_new_coins(sponsor, reserved_gas_coins) + .await + .unwrap(); + assert_coin_count(sponsor, &storage, 50, 0).await; } #[tokio::test] async fn test_coin_expiration() { let sponsor = SuiAddress::random_for_testing_only(); let storage = setup(sponsor, vec![1; 100]).await; - let (_res_id1, reserved_gas_coins1) = storage.reserve_gas_coins(10, 900).await.unwrap(); + let (_res_id1, reserved_gas_coins1) = + storage.reserve_gas_coins(sponsor, 10, 900).await.unwrap(); assert_eq!(reserved_gas_coins1.len(), 10); - let (_res_id2, reserved_gas_coins2) = storage.reserve_gas_coins(30, 1900).await.unwrap(); + let (_res_id2, reserved_gas_coins2) = + storage.reserve_gas_coins(sponsor, 30, 1900).await.unwrap(); assert_eq!(reserved_gas_coins2.len(), 30); // Just to make sure these two reservations will have a different expiration timestamp. tokio::time::sleep(Duration::from_millis(1)).await; - let (_res_id3, reserved_gas_coins3) = storage.reserve_gas_coins(50, 1900).await.unwrap(); + let (_res_id3, reserved_gas_coins3) = + storage.reserve_gas_coins(sponsor, 50, 1900).await.unwrap(); assert_eq!(reserved_gas_coins3.len(), 50); - assert_coin_count(&storage, 10, 90).await; + assert_coin_count(sponsor, &storage, 10, 90).await; - assert!(storage.expire_coins().await.unwrap().is_empty()); - assert_coin_count(&storage, 10, 90).await; + assert!(storage.expire_coins(sponsor).await.unwrap().is_empty()); + assert_coin_count(sponsor, &storage, 10, 90).await; tokio::time::sleep(Duration::from_secs(1)).await; - let expired1 = storage.expire_coins().await.unwrap(); + let expired1 = storage.expire_coins(sponsor).await.unwrap(); assert_eq!(expired1.len(), 10); assert_eq!( expired1.iter().cloned().collect::>(), @@ -284,13 +340,13 @@ mod tests { .map(|coin| coin.object_ref.0) .collect::>() ); - assert_coin_count(&storage, 10, 80).await; + assert_coin_count(sponsor, &storage, 10, 80).await; - assert!(storage.expire_coins().await.unwrap().is_empty()); - assert_coin_count(&storage, 10, 80).await; + assert!(storage.expire_coins(sponsor).await.unwrap().is_empty()); + assert_coin_count(sponsor, &storage, 10, 80).await; tokio::time::sleep(Duration::from_secs(1)).await; - let expired2 = storage.expire_coins().await.unwrap(); + let expired2 = storage.expire_coins(sponsor).await.unwrap(); assert_eq!(expired2.len(), 80); assert_eq!( expired2.iter().cloned().collect::>(), @@ -300,7 +356,7 @@ mod tests { .map(|coin| coin.object_ref.0) .collect::>() ); - assert_coin_count(&storage, 10, 0).await; + assert_coin_count(sponsor, &storage, 10, 0).await; } #[tokio::test] @@ -309,13 +365,13 @@ mod tests { .map(|_| SuiAddress::random_for_testing_only()) .collect::>(); let mut storages = vec![]; - for sponsor in sponsors { - storages.push(setup(sponsor, vec![1; 100]).await); + for sponsor in &sponsors { + storages.push(setup(*sponsor, vec![1; 100]).await); } - for storage in storages { - let (_, gas_coins) = storage.reserve_gas_coins(50, 1000).await.unwrap(); + for (storage, sponsor) in storages.into_iter().zip(sponsors) { + let (_, gas_coins) = storage.reserve_gas_coins(sponsor, 50, 1000).await.unwrap(); assert_eq!(gas_coins.len(), 50); - assert_coin_count(&storage, 50, 50).await; + assert_coin_count(sponsor, &storage, 50, 50).await; } } @@ -329,7 +385,8 @@ mod tests { handles.push(tokio::spawn(async move { let mut reserved_gas_coins = vec![]; for _ in 0..100 { - let (_, newly_reserved) = storage.reserve_gas_coins(3, 1000).await.unwrap(); + let (_, newly_reserved) = + storage.reserve_gas_coins(sponsor, 3, 1000).await.unwrap(); reserved_gas_coins.extend(newly_reserved); } reserved_gas_coins @@ -344,17 +401,17 @@ mod tests { reserved_gas_coins.sort_by_key(|c| c.object_ref.0); reserved_gas_coins.dedup_by_key(|c| c.object_ref.0); assert_eq!(reserved_gas_coins.len(), count); - assert_coin_count(&storage, 100000 - count, count).await; + assert_coin_count(sponsor, &storage, 100000 - count, count).await; } #[tokio::test] async fn test_acquire_init_lock() { let sponsor = SuiAddress::random_for_testing_only(); let storage = setup(sponsor, vec![1; 100]).await; - assert!(storage.acquire_init_lock(5).await.unwrap()); - assert!(!storage.acquire_init_lock(1).await.unwrap()); + assert!(storage.acquire_init_lock(sponsor, 5).await.unwrap()); + assert!(!storage.acquire_init_lock(sponsor, 1).await.unwrap()); tokio::time::sleep(Duration::from_secs(6)).await; - assert!(storage.acquire_init_lock(5).await.unwrap()); + assert!(storage.acquire_init_lock(sponsor, 5).await.unwrap()); } #[tokio::test] @@ -363,7 +420,8 @@ mod tests { let storage = setup(sponsor, vec![1; 100]).await; // init_coin_stats_at_startup has already been called in setup. // Calling it again should not change anything. - let (coin_count, total_balance) = storage.init_coin_stats_at_startup().await.unwrap(); + let (coin_count, total_balance) = + storage.init_coin_stats_at_startup(sponsor).await.unwrap(); assert_eq!(coin_count, 100); assert_eq!(total_balance, 100); } diff --git a/src/storage/redis/mod.rs b/src/storage/redis/mod.rs index fa9f15a..904d034 100644 --- a/src/storage/redis/mod.rs +++ b/src/storage/redis/mod.rs @@ -18,22 +18,15 @@ use tracing::{debug, info}; pub struct RedisStorage { conn_manager: ConnectionManager, - // String format of the sponsor address to avoid converting it to string multiple times. - sponsor_str: String, metrics: Arc, } impl RedisStorage { - pub async fn new( - redis_url: &str, - sponsor_address: SuiAddress, - metrics: Arc, - ) -> Self { + pub async fn new(redis_url: &str, metrics: Arc) -> Self { let client = redis::Client::open(redis_url).unwrap(); let conn_manager = ConnectionManager::new(client).await.unwrap(); Self { conn_manager, - sponsor_str: sponsor_address.to_string(), metrics, } } @@ -43,10 +36,12 @@ impl RedisStorage { impl Storage for RedisStorage { async fn reserve_gas_coins( &self, + sponsor_address: SuiAddress, target_budget: u64, reserved_duration_ms: u64, ) -> anyhow::Result<(ReservationID, Vec)> { self.metrics.num_reserve_gas_coins_requests.inc(); + let sponsor_str = sponsor_address.to_string(); let expiration_time = Utc::now() .add(Duration::from_millis(reserved_duration_ms)) @@ -58,7 +53,7 @@ impl Storage for RedisStorage { i64, i64, ) = ScriptManager::reserve_gas_coins_script() - .arg(self.sponsor_str.clone()) + .arg(sponsor_str.clone()) .arg(target_budget) .arg(expiration_time) .invoke_async(&mut conn) @@ -89,22 +84,27 @@ impl Storage for RedisStorage { self.metrics .gas_pool_available_gas_coin_count - .with_label_values(&[&self.sponsor_str]) + .with_label_values(&[&sponsor_str]) .set(new_coin_count); self.metrics .gas_pool_available_gas_total_balance - .with_label_values(&[&self.sponsor_str]) + .with_label_values(&[&sponsor_str]) .set(new_total_balance); self.metrics.num_successful_reserve_gas_coins_requests.inc(); Ok((reservation_id, gas_coins)) } - async fn ready_for_execution(&self, reservation_id: ReservationID) -> anyhow::Result<()> { + async fn ready_for_execution( + &self, + sponsor_address: SuiAddress, + reservation_id: ReservationID, + ) -> anyhow::Result<()> { self.metrics.num_ready_for_execution_requests.inc(); + let sponsor_str = sponsor_address.to_string(); let mut conn = self.conn_manager.clone(); ScriptManager::ready_for_execution_script() - .arg(self.sponsor_str.clone()) + .arg(sponsor_str) .arg(reservation_id) .invoke_async::<_, ()>(&mut conn) .await?; @@ -115,8 +115,13 @@ impl Storage for RedisStorage { Ok(()) } - async fn add_new_coins(&self, new_coins: Vec) -> anyhow::Result<()> { + async fn add_new_coins( + &self, + sponsor_address: SuiAddress, + new_coins: Vec, + ) -> anyhow::Result<()> { self.metrics.num_add_new_coins_requests.inc(); + let sponsor_str = sponsor_address.to_string(); let formatted_coins = new_coins .iter() .map(|c| { @@ -135,7 +140,7 @@ impl Storage for RedisStorage { let mut conn = self.conn_manager.clone(); let (new_total_balance, new_coin_count): (i64, i64) = ScriptManager::add_new_coins_script() - .arg(self.sponsor_str.clone()) + .arg(sponsor_str.clone()) .arg(serde_json::to_string(&formatted_coins)?) .invoke_async(&mut conn) .await?; @@ -146,23 +151,24 @@ impl Storage for RedisStorage { ); self.metrics .gas_pool_available_gas_coin_count - .with_label_values(&[&self.sponsor_str]) + .with_label_values(&[&sponsor_str]) .set(new_coin_count); self.metrics .gas_pool_available_gas_total_balance - .with_label_values(&[&self.sponsor_str]) + .with_label_values(&[&sponsor_str]) .set(new_total_balance); self.metrics.num_successful_add_new_coins_requests.inc(); Ok(()) } - async fn expire_coins(&self) -> anyhow::Result> { + async fn expire_coins(&self, sponsor_address: SuiAddress) -> anyhow::Result> { self.metrics.num_expire_coins_requests.inc(); + let sponsor_str = sponsor_address.to_string(); let now = Utc::now().timestamp_millis() as u64; let mut conn = self.conn_manager.clone(); let expired_coin_strings: Vec = ScriptManager::expire_coins_script() - .arg(self.sponsor_str.clone()) + .arg(sponsor_str) .arg(now) .invoke_async(&mut conn) .await?; @@ -176,26 +182,30 @@ impl Storage for RedisStorage { Ok(expired_coin_ids) } - async fn init_coin_stats_at_startup(&self) -> anyhow::Result<(u64, u64)> { + async fn init_coin_stats_at_startup( + &self, + sponsor_address: SuiAddress, + ) -> anyhow::Result<(u64, u64)> { + let sponsor_str = sponsor_address.to_string(); let mut conn = self.conn_manager.clone(); let (available_coin_count, available_coin_total_balance): (i64, i64) = ScriptManager::init_coin_stats_at_startup_script() - .arg(self.sponsor_str.clone()) + .arg(sponsor_str.clone()) .invoke_async(&mut conn) .await?; info!( - sponsor_address=?self.sponsor_str, + sponsor_address=?sponsor_str, "Number of available gas coins in the pool: {}, total balance: {}", available_coin_count, available_coin_total_balance ); self.metrics .gas_pool_available_gas_coin_count - .with_label_values(&[&self.sponsor_str]) + .with_label_values(&[&sponsor_str]) .set(available_coin_count); self.metrics .gas_pool_available_gas_total_balance - .with_label_values(&[&self.sponsor_str]) + .with_label_values(&[&sponsor_str]) .set(available_coin_total_balance); Ok(( available_coin_count as u64, @@ -203,16 +213,22 @@ impl Storage for RedisStorage { )) } - async fn is_initialized(&self) -> anyhow::Result { + async fn is_initialized(&self, sponsor_address: SuiAddress) -> anyhow::Result { + let sponsor_str = sponsor_address.to_string(); let mut conn = self.conn_manager.clone(); let result = ScriptManager::get_is_initialized_script() - .arg(self.sponsor_str.clone()) + .arg(sponsor_str) .invoke_async::<_, bool>(&mut conn) .await?; Ok(result) } - async fn acquire_init_lock(&self, lock_duration_sec: u64) -> anyhow::Result { + async fn acquire_init_lock( + &self, + sponsor_address: SuiAddress, + lock_duration_sec: u64, + ) -> anyhow::Result { + let sponsor_str = sponsor_address.to_string(); let mut conn = self.conn_manager.clone(); let cur_timestamp = Utc::now().timestamp() as u64; debug!( @@ -220,7 +236,7 @@ impl Storage for RedisStorage { cur_timestamp, lock_duration_sec ); let result = ScriptManager::acquire_init_lock_script() - .arg(self.sponsor_str.clone()) + .arg(sponsor_str) .arg(cur_timestamp) .arg(lock_duration_sec) .invoke_async::<_, bool>(&mut conn) @@ -228,11 +244,12 @@ impl Storage for RedisStorage { Ok(result) } - async fn release_init_lock(&self) -> anyhow::Result<()> { + async fn release_init_lock(&self, sponsor_address: SuiAddress) -> anyhow::Result<()> { debug!("Releasing the init lock."); + let sponsor_str = sponsor_address.to_string(); let mut conn = self.conn_manager.clone(); ScriptManager::release_init_lock_script() - .arg(self.sponsor_str.clone()) + .arg(sponsor_str) .invoke_async::<_, ()>(&mut conn) .await?; Ok(()) @@ -253,29 +270,32 @@ impl Storage for RedisStorage { .unwrap(); } - async fn get_available_coin_count(&self) -> anyhow::Result { + async fn get_available_coin_count(&self, sponsor_address: SuiAddress) -> anyhow::Result { + let sponsor_str = sponsor_address.to_string(); let mut conn = self.conn_manager.clone(); let count = ScriptManager::get_available_coin_count_script() - .arg(self.sponsor_str.clone()) + .arg(sponsor_str) .invoke_async::<_, usize>(&mut conn) .await?; Ok(count) } - async fn get_available_coin_total_balance(&self) -> u64 { + async fn get_available_coin_total_balance(&self, sponsor_address: SuiAddress) -> u64 { + let sponsor_str = sponsor_address.to_string(); let mut conn = self.conn_manager.clone(); ScriptManager::get_available_coin_total_balance_script() - .arg(self.sponsor_str.clone()) + .arg(sponsor_str) .invoke_async::<_, u64>(&mut conn) .await .unwrap() } #[cfg(test)] - async fn get_reserved_coin_count(&self) -> usize { + async fn get_reserved_coin_count(&self, sponsor_address: SuiAddress) -> usize { + let sponsor_str = sponsor_address.to_string(); let mut conn = self.conn_manager.clone(); ScriptManager::get_reserved_coin_count_script() - .arg(self.sponsor_str.clone()) + .arg(sponsor_str) .invoke_async::<_, usize>(&mut conn) .await .unwrap() @@ -295,72 +315,83 @@ mod tests { #[tokio::test] async fn test_init_coin_stats_at_startup() { let storage = setup_storage().await; + let sponsor = SuiAddress::ZERO; storage - .add_new_coins(vec![ - GasCoin { - balance: 100, - object_ref: random_object_ref(), - }, - GasCoin { - balance: 200, - object_ref: random_object_ref(), - }, - ]) + .add_new_coins( + sponsor, + vec![ + GasCoin { + balance: 100, + object_ref: random_object_ref(), + }, + GasCoin { + balance: 200, + object_ref: random_object_ref(), + }, + ], + ) .await .unwrap(); - let (coin_count, total_balance) = storage.init_coin_stats_at_startup().await.unwrap(); + let (coin_count, total_balance) = + storage.init_coin_stats_at_startup(sponsor).await.unwrap(); assert_eq!(coin_count, 2); assert_eq!(total_balance, 300); } #[tokio::test] async fn test_add_new_coins() { + let sponsor = SuiAddress::ZERO; let storage = setup_storage().await; storage - .add_new_coins(vec![ - GasCoin { - balance: 100, - object_ref: random_object_ref(), - }, - GasCoin { - balance: 200, - object_ref: random_object_ref(), - }, - ]) + .add_new_coins( + sponsor, + vec![ + GasCoin { + balance: 100, + object_ref: random_object_ref(), + }, + GasCoin { + balance: 200, + object_ref: random_object_ref(), + }, + ], + ) .await .unwrap(); - let coin_count = storage.get_available_coin_count().await.unwrap(); + let coin_count = storage.get_available_coin_count(sponsor).await.unwrap(); assert_eq!(coin_count, 2); - let total_balance = storage.get_available_coin_total_balance().await; + let total_balance = storage.get_available_coin_total_balance(sponsor).await; assert_eq!(total_balance, 300); storage - .add_new_coins(vec![ - GasCoin { - balance: 300, - object_ref: random_object_ref(), - }, - GasCoin { - balance: 400, - object_ref: random_object_ref(), - }, - ]) + .add_new_coins( + sponsor, + vec![ + GasCoin { + balance: 300, + object_ref: random_object_ref(), + }, + GasCoin { + balance: 400, + object_ref: random_object_ref(), + }, + ], + ) .await .unwrap(); - let coin_count = storage.get_available_coin_count().await.unwrap(); + let coin_count = storage.get_available_coin_count(sponsor).await.unwrap(); assert_eq!(coin_count, 4); - let total_balance = storage.get_available_coin_total_balance().await; + let total_balance = storage.get_available_coin_total_balance(sponsor).await; assert_eq!(total_balance, 1000); } async fn setup_storage() -> RedisStorage { - let storage = RedisStorage::new( - "redis://127.0.0.1:6379", - SuiAddress::ZERO, - StorageMetrics::new_for_testing(), - ) - .await; + let storage = + RedisStorage::new("redis://127.0.0.1:6379", StorageMetrics::new_for_testing()).await; storage.flush_db().await; - let (coin_count, total_balance) = storage.init_coin_stats_at_startup().await.unwrap(); + let (coin_count, total_balance) = storage + .init_coin_stats_at_startup(SuiAddress::ZERO) + .await + .unwrap(); assert_eq!(coin_count, 0); assert_eq!(total_balance, 0); storage diff --git a/src/test_env.rs b/src/test_env.rs index a6d4bc6..7cab63e 100644 --- a/src/test_env.rs +++ b/src/test_env.rs @@ -8,7 +8,8 @@ use crate::metrics::{GasPoolCoreMetrics, GasPoolRpcMetrics}; use crate::rpc::GasPoolServer; use crate::storage::connect_storage_for_testing; use crate::sui_client::SuiClient; -use crate::tx_signer::{TestTxSigner, TxSigner}; +use crate::tx_signer::in_memory_signer::InMemoryTxSigner; +use crate::tx_signer::TxSigner; use crate::AUTH_ENV_NAME; use std::sync::Arc; use sui_config::local_ip_utils::{get_available_port, localhost_for_testing}; @@ -21,7 +22,7 @@ use sui_types::transaction::{TransactionData, TransactionDataAPI}; use test_cluster::{TestCluster, TestClusterBuilder}; use tracing::debug; -pub async fn start_sui_cluster(init_gas_amounts: Vec) -> (TestCluster, Arc) { +pub async fn start_sui_cluster(init_gas_amounts: Vec) -> (TestCluster, Arc) { let (sponsor, keypair) = get_account_key_pair(); let cluster = TestClusterBuilder::new() .with_accounts(vec![ @@ -37,7 +38,10 @@ pub async fn start_sui_cluster(init_gas_amounts: Vec) -> (TestCluster, Arc< ]) .build() .await; - (cluster, TestTxSigner::new(keypair.into())) + ( + cluster, + TxSigner::new(vec![InMemoryTxSigner::new(keypair.into())]), + ) } pub async fn start_gas_station( @@ -47,7 +51,7 @@ pub async fn start_gas_station( debug!("Starting Sui cluster.."); let (test_cluster, signer) = start_sui_cluster(init_gas_amounts).await; let fullnode_url = test_cluster.fullnode_handle.rpc_url.clone(); - let sponsor_address = signer.get_address(); + let sponsor_address = signer.get_one_address(); debug!("Starting storage. Sponsor address: {:?}", sponsor_address); let storage = connect_storage_for_testing(sponsor_address).await; let sui_client = SuiClient::new(&fullnode_url, None).await; diff --git a/src/tx_signer/in_memory_signer.rs b/src/tx_signer/in_memory_signer.rs new file mode 100644 index 0000000..160a32f --- /dev/null +++ b/src/tx_signer/in_memory_signer.rs @@ -0,0 +1,38 @@ +// Copyright (c) Mysten Labs, Inc. +// SPDX-License-Identifier: Apache-2.0 + +use std::sync::Arc; + +use shared_crypto::intent::{Intent, IntentMessage}; +use sui_types::base_types::SuiAddress; +use sui_types::crypto::{Signature, SuiKeyPair}; +use sui_types::signature::GenericSignature; +use sui_types::transaction::TransactionData; + +use super::TxSignerTrait; + +pub struct InMemoryTxSigner { + keypair: SuiKeyPair, +} + +impl InMemoryTxSigner { + pub fn new(keypair: SuiKeyPair) -> Arc { + Arc::new(Self { keypair }) + } +} + +#[async_trait::async_trait] +impl TxSignerTrait for InMemoryTxSigner { + async fn sign_transaction( + &self, + tx_data: &TransactionData, + ) -> anyhow::Result { + let intent_msg = IntentMessage::new(Intent::sui_transaction(), tx_data); + let sponsor_sig = Signature::new_secure(&intent_msg, &self.keypair).into(); + Ok(sponsor_sig) + } + + fn sui_address(&self) -> SuiAddress { + (&self.keypair.public()).into() + } +} diff --git a/src/tx_signer/mod.rs b/src/tx_signer/mod.rs new file mode 100644 index 0000000..055bf1a --- /dev/null +++ b/src/tx_signer/mod.rs @@ -0,0 +1,112 @@ +// Copyright (c) Mysten Labs, Inc. +// SPDX-License-Identifier: Apache-2.0 + +use std::collections::HashMap; +use std::sync::atomic::{self, AtomicUsize}; +use std::sync::Arc; +use sui_types::base_types::SuiAddress; +use sui_types::signature::GenericSignature; +use sui_types::transaction::{TransactionData, TransactionDataAPI}; + +pub mod in_memory_signer; +pub mod sidecar_signer; + +#[async_trait::async_trait] +pub trait TxSignerTrait: Send + Sync { + async fn sign_transaction(&self, tx_data: &TransactionData) + -> anyhow::Result; + fn sui_address(&self) -> SuiAddress; +} + +pub struct TxSigner { + signers: Vec>, + next_signer_idx: AtomicUsize, + address_index_map: HashMap, +} + +impl TxSigner { + pub fn new(signers: Vec>) -> Arc { + let address_index_map: HashMap<_, _> = signers + .iter() + .enumerate() + .map(|(i, s)| (s.sui_address(), i)) + .collect(); + Arc::new(Self { + signers, + next_signer_idx: AtomicUsize::new(0), + address_index_map, + }) + } + + pub fn get_all_addresses(&self) -> Vec { + self.signers.iter().map(|s| s.sui_address()).collect() + } + + pub fn is_valid_address(&self, address: &SuiAddress) -> bool { + self.address_index_map.contains_key(address) + } + + pub fn get_one_address(&self) -> SuiAddress { + let idx = self.next_signer_idx.fetch_add(1, atomic::Ordering::Relaxed); + self.signers[idx % self.signers.len()].sui_address() + } + + pub async fn sign_transaction( + &self, + tx_data: &TransactionData, + ) -> anyhow::Result { + let sponsor_address = tx_data.gas_data().owner; + let idx = *self + .address_index_map + .get(&sponsor_address) + .ok_or_else(|| anyhow::anyhow!("No signer found for address: {}", sponsor_address))?; + self.signers[idx].sign_transaction(tx_data).await + } +} + +#[cfg(test)] +mod tests { + use in_memory_signer::InMemoryTxSigner; + use sui_types::{ + crypto::get_account_key_pair, + programmable_transaction_builder::ProgrammableTransactionBuilder, + transaction::TransactionKind, + }; + + use super::*; + + #[tokio::test] + async fn test_multi_tx_signer() { + let (sender1, key1) = get_account_key_pair(); + let (sender2, key2) = get_account_key_pair(); + let (sender3, key3) = get_account_key_pair(); + let senders = vec![sender1, sender2, sender3]; + let signer1 = InMemoryTxSigner::new(key1.into()); + let signer2 = InMemoryTxSigner::new(key2.into()); + let signer3 = InMemoryTxSigner::new(key3.into()); + let tx_signer = TxSigner::new(vec![signer1, signer2, signer3]); + for sender in senders { + let tx_data = TransactionData::new_with_gas_coins( + TransactionKind::ProgrammableTransaction( + ProgrammableTransactionBuilder::new().finish(), + ), + sender, + vec![], + 0, + 0, + ); + tx_signer.sign_transaction(&tx_data).await.unwrap(); + } + let (sender4, _) = get_account_key_pair(); + let tx_data = TransactionData::new_with_gas_coins( + TransactionKind::ProgrammableTransaction( + ProgrammableTransactionBuilder::new().finish(), + ), + sender4, + vec![], + 0, + 0, + ); + assert!(tx_signer.sign_transaction(&tx_data).await.is_err()); + } +} diff --git a/src/tx_signer.rs b/src/tx_signer/sidecar_signer.rs similarity index 62% rename from src/tx_signer.rs rename to src/tx_signer/sidecar_signer.rs index 1c3ba8f..d5b08e1 100644 --- a/src/tx_signer.rs +++ b/src/tx_signer/sidecar_signer.rs @@ -6,23 +6,13 @@ use fastcrypto::encoding::{Base64, Encoding}; use reqwest::Client; use serde::Deserialize; use serde_json::json; -use shared_crypto::intent::{Intent, IntentMessage}; use std::str::FromStr; use std::sync::Arc; use sui_types::base_types::SuiAddress; -use sui_types::crypto::{Signature, SuiKeyPair}; use sui_types::signature::GenericSignature; use sui_types::transaction::TransactionData; -#[async_trait::async_trait] -pub trait TxSigner: Send + Sync { - async fn sign_transaction(&self, tx_data: &TransactionData) - -> anyhow::Result; - fn get_address(&self) -> SuiAddress; - fn is_valid_address(&self, address: &SuiAddress) -> bool { - self.get_address() == *address - } -} +use super::TxSignerTrait; #[derive(Deserialize)] #[serde(rename_all = "camelCase")] @@ -37,8 +27,8 @@ struct SuiAddressResponse { } pub struct SidecarTxSigner { - sidecar_url: String, client: Client, + sidecar_url: String, sui_address: SuiAddress, } @@ -46,7 +36,7 @@ impl SidecarTxSigner { pub async fn new(sidecar_url: String) -> Arc { let client = Client::new(); let resp = client - .get(format!("{}/{}", sidecar_url, "get-pubkey-address")) + .get(format!("{}/{}", &sidecar_url, "get-pubkey-address")) .send() .await .unwrap_or_else(|err| panic!("Failed to get pubkey address: {}", err)); @@ -56,15 +46,15 @@ impl SidecarTxSigner { .unwrap_or_else(|err| panic!("Failed to parse address response: {}", err)) .sui_pubkey_address; Arc::new(Self { - sidecar_url, client, + sidecar_url, sui_address, }) } } #[async_trait::async_trait] -impl TxSigner for SidecarTxSigner { +impl TxSignerTrait for SidecarTxSigner { async fn sign_transaction( &self, tx_data: &TransactionData, @@ -83,33 +73,7 @@ impl TxSigner for SidecarTxSigner { Ok(sig) } - fn get_address(&self) -> SuiAddress { + fn sui_address(&self) -> SuiAddress { self.sui_address } } - -pub struct TestTxSigner { - keypair: SuiKeyPair, -} - -impl TestTxSigner { - pub fn new(keypair: SuiKeyPair) -> Arc { - Arc::new(Self { keypair }) - } -} - -#[async_trait::async_trait] -impl TxSigner for TestTxSigner { - async fn sign_transaction( - &self, - tx_data: &TransactionData, - ) -> anyhow::Result { - let intent_msg = IntentMessage::new(Intent::sui_transaction(), tx_data); - let sponsor_sig = Signature::new_secure(&intent_msg, &self.keypair).into(); - Ok(sponsor_sig) - } - - fn get_address(&self) -> SuiAddress { - (&self.keypair.public()).into() - } -}