Skip to content

Commit

Permalink
feat: threshold set as constructor for aggregation.cairo
Browse files Browse the repository at this point in the history
  • Loading branch information
JordyRo1 committed Jul 17, 2024
1 parent 8902554 commit d34a151
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 35 deletions.
18 changes: 6 additions & 12 deletions contracts/src/contracts/isms/aggregation/aggregation.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,16 @@ pub mod aggregation {
pub const THRESHOLD_NOT_SET: felt252 = 'Threshold not set';
pub const MODULES_ALREADY_STORED: felt252 = 'Modules already stored';
pub const NO_MODULES_PROVIDED: felt252 = 'No modules provided';
pub const THRESHOLD_TOO_HIGH: felt252 = 'Threshold too high';
}

#[constructor]
fn constructor(ref self: ContractState, _owner: ContractAddress, _modules: Span<felt252>) {
fn constructor(
ref self: ContractState, _owner: ContractAddress, _modules: Span<felt252>, _threshold: u8
) {
self.ownable.initializer(_owner);
assert(_threshold <= 255, Errors::THRESHOLD_TOO_HIGH);
self.threshold.write(_threshold);
self.set_modules(_modules);
}

Expand Down Expand Up @@ -130,17 +135,6 @@ pub mod aggregation {
fn get_threshold(self: @ContractState) -> u8 {
self.threshold.read()
}

/// Sets the threshold for validation
/// Dev: callable only by the owner
///
/// # Arguments
///
/// * - `_threshold` - The number of validator signatures needed
fn set_threshold(ref self: ContractState, _threshold: u8) {
self.ownable.assert_only_owner();
self.threshold.write(_threshold);
}
}
#[generate_trait]
impl InternalImpl of InternalTrait {
Expand Down
2 changes: 0 additions & 2 deletions contracts/src/interfaces.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,6 @@ pub trait IAggregation<TContractState> {
fn get_modules(self: @TContractState) -> Span<ContractAddress>;

fn get_threshold(self: @TContractState) -> u8;

fn set_threshold(ref self: TContractState, _threshold: u8);
}


Expand Down
28 changes: 8 additions & 20 deletions contracts/src/tests/isms/test_aggregation.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -20,39 +20,27 @@ use starknet::ContractAddress;

#[test]
fn test_aggregation_module_type() {
let aggregation = setup_aggregation(MODULES());
let threshold = 2;
let aggregation = setup_aggregation(MODULES(), threshold);
assert(
aggregation.module_type() == ModuleType::AGGREGATION(aggregation.contract_address),
'Aggregation: Wrong module type'
);
}

#[test]
fn test_aggregation_set_threshold() {
let threshold = 3;
let aggregation = setup_aggregation(MODULES());
let ownable = IOwnableDispatcher { contract_address: aggregation.contract_address };
start_prank(CheatTarget::One(ownable.contract_address), OWNER().try_into().unwrap());
aggregation.set_threshold(threshold);
}

#[test]
#[should_panic(expected: ('Threshold not set',))]
fn test_aggregation_verify_fails_if_treshold_not_set() {
let aggregation = setup_aggregation(MODULES());
aggregation.verify(BytesTrait::new(42, array![]), MessageTrait::default());
}

#[test]
#[should_panic]
fn test_setup_aggregation_with_null_module_address() {
let threshold = 2;
let modules: Span<felt252> = array![0, 'module_1'].span();
setup_aggregation(modules);
setup_aggregation(modules, threshold);
}

#[test]
fn test_get_modules() {
let aggregation = setup_aggregation(MODULES());
let threshold = 2;
let aggregation = setup_aggregation(MODULES(), threshold);
let ownable = IOwnableDispatcher { contract_address: aggregation.contract_address };
start_prank(CheatTarget::One(ownable.contract_address), OWNER().try_into().unwrap());
assert(aggregation.get_modules() == CONTRACT_MODULES(), 'set modules failed');
Expand Down Expand Up @@ -96,11 +84,11 @@ fn test_aggregation_verify() {
// Noop ism
let noop_ism = setup_noop_ism();
let aggregation = setup_aggregation(
array![messageid.contract_address.into(), noop_ism.contract_address.into(),].span()
array![messageid.contract_address.into(), noop_ism.contract_address.into(),].span(),
threshold
);
let ownable = IOwnableDispatcher { contract_address: aggregation.contract_address };
start_prank(CheatTarget::One(ownable.contract_address), OWNER().try_into().unwrap());
aggregation.set_threshold(threshold);
let mut concat_metadata = BytesTrait::new_empty();
concat_metadata.append_u128(0x00000010000001A0000001A0000001A9);
concat_metadata.concat(@message_id_metadata);
Expand Down
3 changes: 2 additions & 1 deletion contracts/src/tests/setup.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,13 @@ pub fn setup_mock_validator_announce(
IMockValidatorAnnounceDispatcher { contract_address: validator_announce_addr }
}

pub fn setup_aggregation(modules: Span<felt252>) -> IAggregationDispatcher {
pub fn setup_aggregation(modules: Span<felt252>, threshold: u8) -> IAggregationDispatcher {
let aggregation_class = declare("aggregation").unwrap();
let mut parameters = Default::default();
let owner: felt252 = OWNER().try_into().unwrap();
Serde::serialize(@owner, ref parameters);
Serde::serialize(@modules, ref parameters);
Serde::serialize(@threshold, ref parameters);
let (aggregation_addr, _) = aggregation_class.deploy(@parameters).unwrap();
IAggregationDispatcher { contract_address: aggregation_addr }
}
Expand Down

0 comments on commit d34a151

Please sign in to comment.