From d34a151424ab34fb16b59113b25a5a4298a3ecaa Mon Sep 17 00:00:00 2001 From: JordyRo1 Date: Wed, 17 Jul 2024 21:49:42 +0200 Subject: [PATCH] feat: threshold set as constructor for aggregation.cairo --- .../isms/aggregation/aggregation.cairo | 18 ++++-------- contracts/src/interfaces.cairo | 2 -- .../src/tests/isms/test_aggregation.cairo | 28 ++++++------------- contracts/src/tests/setup.cairo | 3 +- 4 files changed, 16 insertions(+), 35 deletions(-) diff --git a/contracts/src/contracts/isms/aggregation/aggregation.cairo b/contracts/src/contracts/isms/aggregation/aggregation.cairo index 6455cdc..56d5eae 100644 --- a/contracts/src/contracts/isms/aggregation/aggregation.cairo +++ b/contracts/src/contracts/isms/aggregation/aggregation.cairo @@ -46,11 +46,16 @@ pub mod aggregation { pub const THRESHOLD_NOT_SET: felt252 = 'Threshold not set'; pub const MODULES_ALREADY_STORED: felt252 = 'Modules already stored'; pub const NO_MODULES_PROVIDED: felt252 = 'No modules provided'; + pub const THRESHOLD_TOO_HIGH: felt252 = 'Threshold too high'; } #[constructor] - fn constructor(ref self: ContractState, _owner: ContractAddress, _modules: Span) { + fn constructor( + ref self: ContractState, _owner: ContractAddress, _modules: Span, _threshold: u8 + ) { self.ownable.initializer(_owner); + assert(_threshold <= 255, Errors::THRESHOLD_TOO_HIGH); + self.threshold.write(_threshold); self.set_modules(_modules); } @@ -130,17 +135,6 @@ pub mod aggregation { fn get_threshold(self: @ContractState) -> u8 { self.threshold.read() } - - /// Sets the threshold for validation - /// Dev: callable only by the owner - /// - /// # Arguments - /// - /// * - `_threshold` - The number of validator signatures needed - fn set_threshold(ref self: ContractState, _threshold: u8) { - self.ownable.assert_only_owner(); - self.threshold.write(_threshold); - } } #[generate_trait] impl InternalImpl of InternalTrait { diff --git a/contracts/src/interfaces.cairo b/contracts/src/interfaces.cairo index f1c9444..93c764a 100644 --- a/contracts/src/interfaces.cairo +++ b/contracts/src/interfaces.cairo @@ -276,8 +276,6 @@ pub trait IAggregation { fn get_modules(self: @TContractState) -> Span; fn get_threshold(self: @TContractState) -> u8; - - fn set_threshold(ref self: TContractState, _threshold: u8); } diff --git a/contracts/src/tests/isms/test_aggregation.cairo b/contracts/src/tests/isms/test_aggregation.cairo index f59ade0..9af797f 100644 --- a/contracts/src/tests/isms/test_aggregation.cairo +++ b/contracts/src/tests/isms/test_aggregation.cairo @@ -20,39 +20,27 @@ use starknet::ContractAddress; #[test] fn test_aggregation_module_type() { - let aggregation = setup_aggregation(MODULES()); + let threshold = 2; + let aggregation = setup_aggregation(MODULES(), threshold); assert( aggregation.module_type() == ModuleType::AGGREGATION(aggregation.contract_address), 'Aggregation: Wrong module type' ); } -#[test] -fn test_aggregation_set_threshold() { - let threshold = 3; - let aggregation = setup_aggregation(MODULES()); - let ownable = IOwnableDispatcher { contract_address: aggregation.contract_address }; - start_prank(CheatTarget::One(ownable.contract_address), OWNER().try_into().unwrap()); - aggregation.set_threshold(threshold); -} - -#[test] -#[should_panic(expected: ('Threshold not set',))] -fn test_aggregation_verify_fails_if_treshold_not_set() { - let aggregation = setup_aggregation(MODULES()); - aggregation.verify(BytesTrait::new(42, array![]), MessageTrait::default()); -} #[test] #[should_panic] fn test_setup_aggregation_with_null_module_address() { + let threshold = 2; let modules: Span = array![0, 'module_1'].span(); - setup_aggregation(modules); + setup_aggregation(modules, threshold); } #[test] fn test_get_modules() { - let aggregation = setup_aggregation(MODULES()); + let threshold = 2; + let aggregation = setup_aggregation(MODULES(), threshold); let ownable = IOwnableDispatcher { contract_address: aggregation.contract_address }; start_prank(CheatTarget::One(ownable.contract_address), OWNER().try_into().unwrap()); assert(aggregation.get_modules() == CONTRACT_MODULES(), 'set modules failed'); @@ -96,11 +84,11 @@ fn test_aggregation_verify() { // Noop ism let noop_ism = setup_noop_ism(); let aggregation = setup_aggregation( - array![messageid.contract_address.into(), noop_ism.contract_address.into(),].span() + array![messageid.contract_address.into(), noop_ism.contract_address.into(),].span(), + threshold ); let ownable = IOwnableDispatcher { contract_address: aggregation.contract_address }; start_prank(CheatTarget::One(ownable.contract_address), OWNER().try_into().unwrap()); - aggregation.set_threshold(threshold); let mut concat_metadata = BytesTrait::new_empty(); concat_metadata.append_u128(0x00000010000001A0000001A0000001A9); concat_metadata.concat(@message_id_metadata); diff --git a/contracts/src/tests/setup.cairo b/contracts/src/tests/setup.cairo index 5f224e4..81ff9d6 100644 --- a/contracts/src/tests/setup.cairo +++ b/contracts/src/tests/setup.cairo @@ -298,12 +298,13 @@ pub fn setup_mock_validator_announce( IMockValidatorAnnounceDispatcher { contract_address: validator_announce_addr } } -pub fn setup_aggregation(modules: Span) -> IAggregationDispatcher { +pub fn setup_aggregation(modules: Span, threshold: u8) -> IAggregationDispatcher { let aggregation_class = declare("aggregation").unwrap(); let mut parameters = Default::default(); let owner: felt252 = OWNER().try_into().unwrap(); Serde::serialize(@owner, ref parameters); Serde::serialize(@modules, ref parameters); + Serde::serialize(@threshold, ref parameters); let (aggregation_addr, _) = aggregation_class.deploy(@parameters).unwrap(); IAggregationDispatcher { contract_address: aggregation_addr } }