From 67af2945f30905ecc52c1b2fddc273920b66d551 Mon Sep 17 00:00:00 2001 From: Jake Hartnell Date: Sat, 16 Sep 2023 20:58:35 -0700 Subject: [PATCH] Reuse validation logic, and fix bug that slipped through audit : ) --- .../dao-voting-cw721-staked/src/contract.rs | 56 ++++++++++--------- .../dao-voting-cw721-staked/src/error.rs | 17 ++---- .../src/testing/tests.rs | 16 ++++-- 3 files changed, 46 insertions(+), 43 deletions(-) diff --git a/contracts/voting/dao-voting-cw721-staked/src/contract.rs b/contracts/voting/dao-voting-cw721-staked/src/contract.rs index c3834f4a9..22f477573 100644 --- a/contracts/voting/dao-voting-cw721-staked/src/contract.rs +++ b/contracts/voting/dao-voting-cw721-staked/src/contract.rs @@ -1,8 +1,8 @@ #[cfg(not(feature = "library"))] use cosmwasm_std::entry_point; use cosmwasm_std::{ - from_binary, to_binary, Addr, Binary, CosmosMsg, Decimal, Deps, DepsMut, Empty, Env, - MessageInfo, Reply, Response, StdError, StdResult, SubMsg, Uint128, Uint256, WasmMsg, + from_binary, to_binary, Addr, Binary, CosmosMsg, Deps, DepsMut, Empty, Env, MessageInfo, Reply, + Response, StdError, StdResult, SubMsg, Uint128, Uint256, WasmMsg, }; use cw2::{get_contract_version, set_contract_version, ContractVersion}; use cw721::{Cw721QueryMsg, Cw721ReceiveMsg, NumTokensResponse}; @@ -10,7 +10,10 @@ use cw_storage_plus::Bound; use cw_utils::{parse_reply_execute_data, parse_reply_instantiate_data, Duration}; use dao_hooks::nft_stake::{stake_nft_hook_msgs, unstake_nft_hook_msgs}; use dao_interface::{nft::NftFactoryCallback, voting::IsActiveResponse}; -use dao_voting::threshold::{ActiveThreshold, ActiveThresholdResponse}; +use dao_voting::threshold::{ + assert_valid_absolute_count_threshold, assert_valid_percentage_threshold, ActiveThreshold, + ActiveThresholdResponse, +}; use crate::msg::{ExecuteMsg, InstantiateMsg, MigrateMsg, NftContract, QueryMsg}; use crate::state::{ @@ -74,24 +77,21 @@ pub fn instantiate( if let Some(active_threshold) = msg.active_threshold.as_ref() { match active_threshold { ActiveThreshold::Percentage { percent } => { - if percent > &Decimal::percent(100) || percent.is_zero() { - return Err(ContractError::InvalidActivePercentage {}); - } + assert_valid_percentage_threshold(*percent)?; } ActiveThreshold::AbsoluteCount { count } => { - // Check Absolute count is not zero - if count.is_zero() { - return Err(ContractError::ZeroActiveCount {}); - } - - // Check Absolute count is less than the supply of NFTs for existing NFT contracts + // Check Absolute count is less than the supply of NFTs for existing + // NFT contracts. For new NFT contracts, we will check this in the reply. if let NftContract::Existing { ref address } = msg.nft_contract { let nft_supply: NumTokensResponse = deps .querier .query_wasm_smart(address, &Cw721QueryMsg::NumTokens {})?; - if count > &Uint128::new(nft_supply.count.into()) { - return Err(ContractError::InvalidActiveCount {}); - } + // Check the absolute count is less than the supply of NFTs and + // greater than zero. + assert_valid_absolute_count_threshold( + *count, + Uint128::new(nft_supply.count.into()), + )?; } } } @@ -441,17 +441,20 @@ pub fn execute_update_active_threshold( return Err(ContractError::Unauthorized {}); } + let config = CONFIG.load(deps.storage)?; if let Some(active_threshold) = new_active_threshold { match active_threshold { ActiveThreshold::Percentage { percent } => { - if percent > Decimal::percent(100) || percent.is_zero() { - return Err(ContractError::InvalidActivePercentage {}); - } + assert_valid_percentage_threshold(percent)?; } ActiveThreshold::AbsoluteCount { count } => { - if count.is_zero() { - return Err(ContractError::ZeroActiveCount {}); - } + let nft_supply: NumTokensResponse = deps + .querier + .query_wasm_smart(config.nft_address, &Cw721QueryMsg::NumTokens {})?; + assert_valid_absolute_count_threshold( + count, + Uint128::new(nft_supply.count.into()), + )?; } } ACTIVE_THRESHOLD.save(deps.storage, &active_threshold)?; @@ -708,14 +711,15 @@ pub fn reply(deps: DepsMut, _env: Env, msg: Reply) -> Result Uint128::new(supply.count.into()) { - return Err(ContractError::InvalidActiveCount {}); - } + // Check the count is not greater than supply and is not zero + assert_valid_absolute_count_threshold( + count, + Uint128::new(nft_supply.count.into()), + )?; } Ok(Response::new()) } diff --git a/contracts/voting/dao-voting-cw721-staked/src/error.rs b/contracts/voting/dao-voting-cw721-staked/src/error.rs index 017041216..ec5af5eeb 100644 --- a/contracts/voting/dao-voting-cw721-staked/src/error.rs +++ b/contracts/voting/dao-voting-cw721-staked/src/error.rs @@ -1,5 +1,6 @@ use cosmwasm_std::{Addr, StdError}; use cw_utils::ParseReplyError; +use dao_voting::threshold::ActiveThresholdError; use thiserror::Error; #[derive(Error, Debug, PartialEq)] @@ -8,19 +9,16 @@ pub enum ContractError { Std(#[from] StdError), #[error(transparent)] - ParseReplyError(#[from] ParseReplyError), - - #[error("Can not stake that which has already been staked")] - AlreadyStaked {}, + ActiveThresholdError(#[from] ActiveThresholdError), #[error(transparent)] HookError(#[from] cw_hooks::HookError), - #[error("Active threshold count is greater than supply")] - InvalidActiveCount {}, + #[error(transparent)] + ParseReplyError(#[from] ParseReplyError), - #[error("Active threshold percentage must be greater than 0 and less than 1")] - InvalidActivePercentage {}, + #[error("Can not stake that which has already been staked")] + AlreadyStaked {}, #[error("Invalid token. Got ({received}), expected ({expected})")] InvalidToken { received: Addr, expected: Addr }, @@ -55,9 +53,6 @@ pub enum ContractError { #[error("Factory message must serialize to WasmMsg::Execute")] UnsupportedFactoryMsg {}, - #[error("Active threshold count must be greater than zero")] - ZeroActiveCount {}, - #[error("Can't unstake zero NFTs.")] ZeroUnstake {}, } diff --git a/contracts/voting/dao-voting-cw721-staked/src/testing/tests.rs b/contracts/voting/dao-voting-cw721-staked/src/testing/tests.rs index df3aeb22f..db7978a39 100644 --- a/contracts/voting/dao-voting-cw721-staked/src/testing/tests.rs +++ b/contracts/voting/dao-voting-cw721-staked/src/testing/tests.rs @@ -452,7 +452,7 @@ fn test_instantiate_zero_active_threshold_count() { } #[test] -#[should_panic(expected = "Active threshold count is greater than supply")] +#[should_panic(expected = "Absolute count threshold cannot be greater than the total token supply")] fn test_instantiate_invalid_active_threshold_count_new_nft() { let mut app = App::default(); let cw721_id = app.store_code(cw721_base_contract()); @@ -492,7 +492,7 @@ fn test_instantiate_invalid_active_threshold_count_new_nft() { } #[test] -#[should_panic(expected = "Active threshold count is greater than supply")] +#[should_panic(expected = "Absolute count threshold cannot be greater than the total token supply")] fn test_instantiate_invalid_active_threshold_count_existing_nft() { let mut app = App::default(); let module_id = app.store_code(voting_cw721_staked_contract()); @@ -805,7 +805,7 @@ fn test_update_active_threshold() { let msg = ExecuteMsg::UpdateActiveThreshold { new_threshold: Some(ActiveThreshold::AbsoluteCount { - count: Uint128::new(100), + count: Uint128::new(1), }), }; @@ -829,13 +829,15 @@ fn test_update_active_threshold() { assert_eq!( resp.active_threshold, Some(ActiveThreshold::AbsoluteCount { - count: Uint128::new(100) + count: Uint128::new(1) }) ); } #[test] -#[should_panic(expected = "Active threshold percentage must be greater than 0 and less than 1")] +#[should_panic( + expected = "Active threshold percentage must be greater than 0 and not greater than 1" +)] fn test_active_threshold_percentage_gt_100() { let mut app = App::default(); let cw721_id = app.store_code(cw721_base_contract()); @@ -875,7 +877,9 @@ fn test_active_threshold_percentage_gt_100() { } #[test] -#[should_panic(expected = "Active threshold percentage must be greater than 0 and less than 1")] +#[should_panic( + expected = "Active threshold percentage must be greater than 0 and not greater than 1" +)] fn test_active_threshold_percentage_lte_0() { let mut app = App::default(); let cw721_id = app.store_code(cw721_base_contract());