diff --git a/pallets/liquidity-pools-gateway/src/lib.rs b/pallets/liquidity-pools-gateway/src/lib.rs index 169c7ab601..7190b36424 100644 --- a/pallets/liquidity-pools-gateway/src/lib.rs +++ b/pallets/liquidity-pools-gateway/src/lib.rs @@ -73,6 +73,14 @@ pub enum InboundEntry { }, } +#[derive(Clone)] +pub struct InboundProcessingInfo { + domain_address: DomainAddress, + inbound_routers: BoundedVec, + current_session_id: T::SessionId, + expected_proof_count_per_message: u32, +} + #[frame_support::pallet] pub mod pallet { use super::*; @@ -141,9 +149,9 @@ pub mod pallet { Message = GatewayMessage, >; - /// Number of routers for a domain. + /// Maximum number of routers allowed for a domain. #[pallet::constant] - type MultiRouterCount: Get; + type MaxRouterCount: Get; /// Type for identifying sessions of inbound routers. type SessionId: Parameter @@ -177,13 +185,13 @@ pub mod pallet { /// The outbound routers for a given domain were set. OutboundRoutersSet { domain: Domain, - routers: BoundedVec, + routers: BoundedVec, }, /// Inbound routers were set. InboundRoutersSet { domain: Domain, - router_hashes: BoundedVec, + router_hashes: BoundedVec, }, } @@ -232,7 +240,7 @@ pub mod pallet { #[pallet::storage] #[pallet::getter(fn outbound_domain_routers)] pub type OutboundDomainRouters = - StorageMap<_, Blake2_128Concat, Domain, BoundedVec>; + StorageMap<_, Blake2_128Concat, Domain, BoundedVec>; /// Storage for pending inbound messages. #[pallet::storage] @@ -252,7 +260,7 @@ pub mod pallet { #[pallet::storage] #[pallet::getter(fn inbound_routers)] pub type InboundRouters = - StorageMap<_, Blake2_128Concat, Domain, BoundedVec>; + StorageMap<_, Blake2_128Concat, Domain, BoundedVec>; /// Storage for the session ID of an inbound domain. #[pallet::storage] @@ -462,13 +470,13 @@ pub mod pallet { pub fn set_outbound_routers( origin: OriginFor, domain: Domain, - routers: BoundedVec, + routers: BoundedVec, ) -> DispatchResult { T::AdminOrigin::ensure_origin(origin)?; ensure!(domain != Domain::Centrifuge, Error::::DomainNotSupported); ensure!( - routers.len() == T::MultiRouterCount::get() as usize, + routers.len() == T::MaxRouterCount::get() as usize, Error::::InvalidMultiRouter ); @@ -500,12 +508,12 @@ pub mod pallet { pub fn set_inbound_routers( origin: OriginFor, domain: Domain, - router_hashes: BoundedVec, + router_hashes: BoundedVec, ) -> DispatchResult { T::AdminOrigin::ensure_origin(origin)?; ensure!( - router_hashes.len() == T::MultiRouterCount::get() as usize, + router_hashes.len() == T::MaxRouterCount::get() as usize, Error::::InvalidMultiRouter ); @@ -543,8 +551,13 @@ pub mod pallet { impl Pallet { //TODO(cdamian): Use safe math - fn get_expected_message_proof_count() -> u32 { - T::MultiRouterCount::get() - 1 + fn get_expected_proof_count(domain: &Domain) -> Result { + let routers = + InboundRouters::::get(domain).ok_or(Error::::MultiRouterNotFound)?; + + let expected_proof_count = routers.len().ensure_sub(1)?; + + Ok(expected_proof_count as u32) } fn get_message_proof(message: T::Message) -> Proof { @@ -560,12 +573,13 @@ pub mod pallet { fn create_inbound_entry( domain_address: DomainAddress, message: T::Message, + expected_proof_count: u32, ) -> InboundEntry { match message.get_message_proof() { None => InboundEntry::Message { domain_address, message, - expected_proof_count: Self::get_expected_message_proof_count(), + expected_proof_count, }, Some(_) => InboundEntry::Proof { current_count: 1 }, } @@ -614,15 +628,21 @@ pub mod pallet { } } - fn update_storage_entry(old: &mut InboundEntry, new: InboundEntry) -> DispatchResult { + fn update_storage_entry( + domain: Domain, + old: &mut InboundEntry, + new: InboundEntry, + ) -> DispatchResult { match old { InboundEntry::Message { - expected_proof_count, + expected_proof_count: stored_expected_proof_count, .. } => match new { InboundEntry::Message { .. } => { - expected_proof_count - .ensure_add_assign(Self::get_expected_message_proof_count())?; + let expected_message_proof_count = Self::get_expected_proof_count(&domain)?; + + stored_expected_proof_count + .ensure_add_assign(expected_message_proof_count)?; Ok(()) } @@ -688,24 +708,26 @@ pub mod pallet { } fn validate_and_update_pending_entries( - session_id: T::SessionId, + inbound_processing_info: &InboundProcessingInfo, + message: T::Message, message_proof: Proof, router_hash: T::Hash, - domain_address: DomainAddress, - message: T::Message, weight: &mut Weight, ) -> DispatchResult { - let session_id = InboundDomainSessions::::get(domain_address.domain()) - .ok_or(Error::::InvalidMultiRouter)?; - - let message_proof = Self::get_message_proof(message.clone()); - - let inbound_entry = Self::create_inbound_entry(domain_address.clone(), message); + let inbound_entry = Self::create_inbound_entry( + inbound_processing_info.domain_address.clone(), + message, + inbound_processing_info.expected_proof_count_per_message, + ); - Self::validate_inbound_entry(domain_address.domain(), router_hash, &inbound_entry)?; + Self::validate_inbound_entry( + inbound_processing_info.domain_address.domain(), + router_hash, + &inbound_entry, + )?; Self::update_pending_entry( - session_id, + inbound_processing_info.current_session_id, message_proof, router_hash, inbound_entry, @@ -716,15 +738,17 @@ pub mod pallet { } fn get_executable_message( - inbound_routers: BoundedVec, - session_id: T::SessionId, + inbound_processing_info: &InboundProcessingInfo, message_proof: Proof, ) -> Option { let mut message = None; - let mut proof_count = 0; + let mut votes = 0; - for inbound_router in inbound_routers { - match PendingInboundEntries::::get(session_id, (message_proof, inbound_router)) { + for inbound_router in &inbound_processing_info.inbound_routers { + match PendingInboundEntries::::get( + inbound_processing_info.current_session_id, + (message_proof, inbound_router), + ) { // We expected one InboundEntry for each router, if that's not the case, // we can return. None => return None, @@ -735,14 +759,14 @@ pub mod pallet { } => message = Some(stored_message), InboundEntry::Proof { current_count } => { if current_count > 0 { - proof_count += 1; + votes += 1; } } }, }; } - if proof_count == Self::get_expected_message_proof_count() { + if votes == inbound_processing_info.expected_proof_count_per_message { return message; } @@ -750,13 +774,12 @@ pub mod pallet { } fn decrease_pending_entries_counts( - inbound_routers: BoundedVec, - session_id: T::SessionId, + inbound_processing_info: &InboundProcessingInfo, message_proof: Proof, ) -> DispatchResult { - for inbound_router in inbound_routers { + for inbound_router in &inbound_processing_info.inbound_routers { match PendingInboundEntries::::try_mutate( - session_id, + inbound_processing_info.current_session_id, (message_proof, inbound_router), |storage_entry| match storage_entry { // TODO(cdamian): Add new error @@ -766,8 +789,9 @@ pub mod pallet { expected_proof_count, .. } => { - let updated_count = (*expected_proof_count) - .ensure_sub(Self::get_expected_message_proof_count())?; + let updated_count = (*expected_proof_count).ensure_sub( + inbound_processing_info.expected_proof_count_per_message, + )?; if updated_count == 0 { *storage_entry = None; @@ -799,27 +823,47 @@ pub mod pallet { Ok(()) } + fn get_inbound_processing_info( + domain_address: DomainAddress, + weight: &mut Weight, + ) -> Result, DispatchError> { + let inbound_routers = + //TODO(cdamian): Add new error + InboundRouters::::get(domain_address.domain()).ok_or(Error::::InvalidMultiRouter)?; + + weight.saturating_accrue(T::DbWeight::get().reads(1)); + + let current_session_id = + //TODO(cdamian): Add new error + InboundDomainSessions::::get(domain_address.domain()).ok_or(Error::::InvalidMultiRouter)?; + + weight.saturating_accrue(T::DbWeight::get().reads(1)); + + let expected_proof_count = Self::get_expected_proof_count(&domain_address.domain())?; + + weight.saturating_accrue(T::DbWeight::get().reads(1)); + + Ok(InboundProcessingInfo { + domain_address, + inbound_routers, + current_session_id, + expected_proof_count_per_message: expected_proof_count, + }) + } + /// Give the message to the `InboundMessageHandler` to be processed. fn process_inbound_message( domain_address: DomainAddress, message: T::Message, router_hash: T::Hash, ) -> (DispatchResult, Weight) { - let mut weight = T::DbWeight::get().reads(1); + let mut weight = Default::default(); - let Some(inbound_routers) = InboundRouters::::get(domain_address.domain()) else { - //TODO(cdamian): Add new error - return (Err(Error::::InvalidMultiRouter.into()), weight); - }; - - if inbound_routers.len() == 0 {} - - let Some(session_id) = InboundDomainSessions::::get(domain_address.domain()) else { - //TODO(cdamian): Add error - return (Err(Error::::InvalidMultiRouter.into()), weight); - }; - - let message_proof = Self::get_message_proof(message.clone()); + let inbound_processing_info = + match Self::get_inbound_processing_info(domain_address.clone(), &mut weight) { + Ok(i) => i, + Err(e) => return (Err(e), weight), + }; weight.saturating_accrue( Weight::from_parts(0, T::Message::max_encoded_len() as u64) @@ -831,26 +875,22 @@ pub mod pallet { for submessage in message.submessages() { count += 1; + let message_proof = Self::get_message_proof(message.clone()); + if let Err(e) = Self::validate_and_update_pending_entries( - session_id, + &inbound_processing_info, + submessage.clone(), message_proof, router_hash, - domain_address.clone(), - submessage.clone(), &mut weight, ) { return (Err(e), weight); } - match Self::get_executable_message( - inbound_routers.clone(), - session_id, - message_proof, - ) { + match Self::get_executable_message(&inbound_processing_info, message_proof) { Some(m) => { if let Err(e) = Self::decrease_pending_entries_counts( - inbound_routers.clone(), - session_id, + &inbound_processing_info, message_proof, ) { return (Err(e), weight.saturating_mul(count)); diff --git a/pallets/liquidity-pools-gateway/src/mock.rs b/pallets/liquidity-pools-gateway/src/mock.rs index 6406b673f8..428d21427d 100644 --- a/pallets/liquidity-pools-gateway/src/mock.rs +++ b/pallets/liquidity-pools-gateway/src/mock.rs @@ -138,7 +138,7 @@ frame_support::parameter_types! { pub Sender: AccountId32 = AccountId32::from(H256::from_low_u64_be(1).to_fixed_bytes()); pub const MaxIncomingMessageSize: u32 = 1024; pub const LpAdminAccount: AccountId32 = LP_ADMIN_ACCOUNT; - pub const MultiRouterCount: u32 = 3; + pub const MaxRouterCount: u32 = 8; } impl pallet_liquidity_pools_gateway::Config for Runtime { @@ -146,6 +146,7 @@ impl pallet_liquidity_pools_gateway::Config for Runtime { type InboundMessageHandler = MockLiquidityPools; type LocalEVMOrigin = EnsureLocal; type MaxIncomingMessageSize = MaxIncomingMessageSize; + type MaxRouterCount = MaxRouterCount; type Message = Message; type MessageQueue = MockLiquidityPoolsGatewayQueue; type MultiRouterCount = MultiRouterCount; diff --git a/pallets/liquidity-pools-gateway/src/tests.rs b/pallets/liquidity-pools-gateway/src/tests.rs index 3c9f056071..a66cc89115 100644 --- a/pallets/liquidity-pools-gateway/src/tests.rs +++ b/pallets/liquidity-pools-gateway/src/tests.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use cfg_mocks::*; +use cfg_primitives::LP_DEFENSIVE_WEIGHT; use cfg_traits::liquidity_pools::{LPEncoding, MessageProcessor, OutboundMessageHandler, Proof}; use cfg_types::domain_address::*; use frame_support::{