From 73fb892d24f715247cb24790a58945504818930b Mon Sep 17 00:00:00 2001 From: Philip Offtermatt <57488781+p-offtermatt@users.noreply.github.com> Date: Wed, 20 Dec 2023 10:27:26 +0100 Subject: [PATCH] test: Upgrade tests/mbt and testutil/simibc (#1527) Upgrade tests/mbt and testutil/simibc --- tests/mbt/driver/common.go | 4 ++- tests/mbt/driver/core.go | 34 +++++++++--------- tests/mbt/driver/mbt_test.go | 13 ++++--- tests/mbt/driver/setup.go | 49 +++++++++++-------------- testutil/simibc/chain_util.go | 68 ++++++++++++++++------------------- testutil/simibc/relay_util.go | 30 +++++++++------- 6 files changed, 98 insertions(+), 100 deletions(-) diff --git a/tests/mbt/driver/common.go b/tests/mbt/driver/common.go index 47d82833cc..1b648c6736 100644 --- a/tests/mbt/driver/common.go +++ b/tests/mbt/driver/common.go @@ -2,6 +2,8 @@ package main import ( sdk "github.com/cosmos/cosmos-sdk/types" + + "cosmossdk.io/math" ) const ( @@ -21,5 +23,5 @@ func getIndexOfString(s string, slice []string) int { func init() { // tokens === power - sdk.DefaultPowerReduction = sdk.NewInt(1) + sdk.DefaultPowerReduction = math.NewInt(1) } diff --git a/tests/mbt/driver/core.go b/tests/mbt/driver/core.go index 051f6c90ae..f86a1bf6c9 100644 --- a/tests/mbt/driver/core.go +++ b/tests/mbt/driver/core.go @@ -3,7 +3,6 @@ package main import ( "fmt" "log" - "math" "strings" "testing" "time" @@ -18,7 +17,6 @@ import ( stakingkeeper "github.com/cosmos/cosmos-sdk/x/staking/keeper" stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" - abcitypes "github.com/cometbft/cometbft/abci/types" cmttypes "github.com/cometbft/cometbft/types" appConsumer "github.com/cosmos/interchain-security/v3/app/consumer" @@ -28,6 +26,10 @@ import ( consumertypes "github.com/cosmos/interchain-security/v3/x/ccv/consumer/types" providerkeeper "github.com/cosmos/interchain-security/v3/x/ccv/provider/keeper" providertypes "github.com/cosmos/interchain-security/v3/x/ccv/provider/types" + + "cosmossdk.io/math" + + gomath "math" ) // Define a new type for ChainIds to be more explicit @@ -126,15 +128,15 @@ func (s *Driver) consumerPower(i int64, chain ChainId) (int64, error) { // providerPower returns the power(=number of bonded tokens) of the i-th validator on the provider. func (s *Driver) providerPower(i int64) (int64, error) { - v, found := s.providerStakingKeeper().GetValidator(s.ctx(PROVIDER), s.validator(i)) - if !found { - return 0, fmt.Errorf("validator with id %v not found on provider", i) + v, err := s.providerStakingKeeper().GetValidator(s.ctx(PROVIDER), s.validator(i)) + if err != nil { + return 0, fmt.Errorf("validator with id %v not found on provider, error was %v", i, err) } else { return v.BondedTokens().Int64(), nil } } -func (s *Driver) providerValidatorSet() []stakingtypes.Validator { +func (s *Driver) providerValidatorSet() ([]stakingtypes.Validator, error) { return s.providerStakingKeeper().GetAllValidators(s.ctx(PROVIDER)) } @@ -146,9 +148,9 @@ func (s *Driver) consumerValidatorSet(chain ChainId) []consumertypes.CrossChainV func (s *Driver) delegate(val, amt int64) { providerStaking := s.providerStakingKeeper() server := stakingkeeper.NewMsgServerImpl(&providerStaking) - coin := sdk.NewCoin(sdk.DefaultBondDenom, sdk.NewInt(amt)) - d := s.delegator() - v := s.validator(val) + coin := sdk.NewCoin(sdk.DefaultBondDenom, math.NewInt(amt)) + d := s.delegator().String() + v := s.validator(val).String() msg := stakingtypes.NewMsgDelegate(d, v, coin) server.Delegate(sdk.WrapSDKContext(s.ctx(PROVIDER)), msg) } @@ -157,9 +159,9 @@ func (s *Driver) delegate(val, amt int64) { func (s *Driver) undelegate(val, amt int64) { providerStaking := s.providerStakingKeeper() server := stakingkeeper.NewMsgServerImpl(&providerStaking) - coin := sdk.NewCoin(sdk.DefaultBondDenom, sdk.NewInt(amt)) - d := s.delegator() - v := s.validator(val) + coin := sdk.NewCoin(sdk.DefaultBondDenom, math.NewInt(amt)) + d := s.delegator().String() + v := s.validator(val).String() msg := stakingtypes.NewMsgUndelegate(d, v, coin) server.Undelegate(sdk.WrapSDKContext(s.ctx(PROVIDER)), msg) } @@ -312,7 +314,7 @@ func (s *Driver) endAndBeginBlock(chain ChainId, timeAdvancement time.Duration) testChain, found := s.coordinator.Chains[string(chain)] require.True(s.t, found, "chain %s not found", chain) - header, packets := simibc.EndBlock(testChain, func() {}) + header, packets := simibc.FinalizeBlock(testChain, timeAdvancement) s.DriverStats.numSentPackets += len(packets) s.DriverStats.numBlocks += 1 @@ -339,7 +341,6 @@ func (s *Driver) endAndBeginBlock(chain ChainId, timeAdvancement time.Duration) } } - simibc.BeginBlock(testChain, timeAdvancement) return header } @@ -362,7 +363,6 @@ func (s *Driver) setTime(chain ChainId, newTime time.Time) { require.True(s.t, found, "chain %s not found", chain) testChain.CurrentHeader.Time = newTime - testChain.App.BeginBlock(abcitypes.RequestBeginBlock{Header: testChain.CurrentHeader}) } // DeliverPacketToConsumer delivers a packet from the provider to the given consumer recipient. @@ -384,8 +384,8 @@ func (s *Driver) DeliverPacketFromConsumer(sender ChainId, expectError bool) { func (s *Driver) DeliverAcks() { for _, chain := range s.runningConsumers() { path := s.path(ChainId(chain.ChainId)) - path.DeliverAcks(path.Path.EndpointA.Chain.ChainID, math.MaxInt) - path.DeliverAcks(path.Path.EndpointB.Chain.ChainID, math.MaxInt) + path.DeliverAcks(path.Path.EndpointA.Chain.ChainID, gomath.MaxInt) + path.DeliverAcks(path.Path.EndpointB.Chain.ChainID, gomath.MaxInt) } } diff --git a/tests/mbt/driver/mbt_test.go b/tests/mbt/driver/mbt_test.go index 9135e7cc0a..761aa60bca 100644 --- a/tests/mbt/driver/mbt_test.go +++ b/tests/mbt/driver/mbt_test.go @@ -387,7 +387,8 @@ func CompareValidatorSets(t *testing.T, driver *Driver, currentModelState map[st t.Helper() modelValSet := ValidatorSet(currentModelState, "provider") - rawActualValSet := driver.providerValidatorSet() + rawActualValSet, err := driver.providerValidatorSet() + require.NoError(t, err, "Error getting provider validator set") actualValSet := make(map[string]int64, len(rawActualValSet)) @@ -419,8 +420,8 @@ func CompareValidatorSets(t *testing.T, driver *Driver, currentModelState map[st } // get the validator for that address on the provider - providerVal, found := driver.providerStakingKeeper().GetValidatorByConsAddr(driver.providerCtx(), providerConsAddr.Address) - require.True(t, found, "Error getting provider validator") + providerVal, err := driver.providerStakingKeeper().GetValidatorByConsAddr(driver.providerCtx(), providerConsAddr.Address) + require.Nil(t, err, "Error getting provider validator") // use the moniker of that validator consumerCurValSet[providerVal.GetMoniker()] = val.Power @@ -571,7 +572,11 @@ func CompareSentPacketsOnProvider(driver *Driver, currentModelState map[string]i func (s *Stats) EnterStats(driver *Driver) { // highest observed voting power - for _, val := range driver.providerValidatorSet() { + valSet, err := driver.providerValidatorSet() + if err != nil { + log.Fatalf("error getting validator set on provider: %v", err) + } + for _, val := range valSet { if val.Tokens.Int64() > s.highestObservedValPower { s.highestObservedValPower = val.Tokens.Int64() } diff --git a/tests/mbt/driver/setup.go b/tests/mbt/driver/setup.go index 218fd2fd37..db9dd945b9 100644 --- a/tests/mbt/driver/setup.go +++ b/tests/mbt/driver/setup.go @@ -114,7 +114,7 @@ func getAppBytesAndSenders( bal := banktypes.Balance{ Address: acc.GetAddress().String(), Coins: sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, - sdk.NewIntFromUint64(INITIAL_ACCOUNT_BALANCE))), + math.NewIntFromUint64(INITIAL_ACCOUNT_BALANCE))), } accounts = append(accounts, acc) @@ -137,21 +137,21 @@ func getAppBytesAndSenders( delegations := make([]stakingtypes.Delegation, 0, len(nodes)) // Sum bonded is needed for BondedPool account - sumBonded := sdk.NewInt(0) + sumBonded := math.NewInt(0) initValPowers := []abcitypes.ValidatorUpdate{} for i, val := range nodes { _, valSetVal := initialValSet.GetByAddress(val.Address.Bytes()) var tokens math.Int if valSetVal == nil { - tokens = sdk.NewInt(0) + tokens = math.NewInt(0) } else { - tokens = sdk.NewInt(valSetVal.VotingPower) + tokens = math.NewInt(valSetVal.VotingPower) } sumBonded = sumBonded.Add(tokens) - pk, err := cryptocodec.FromTmPubKeyInterface(val.PubKey) + pk, err := cryptocodec.FromCmtPubKeyInterface(val.PubKey) if err != nil { log.Panicf("error getting pubkey for val %v", val) } @@ -160,7 +160,7 @@ func getAppBytesAndSenders( log.Panicf("error getting pubkeyAny for val %v", val) } - delShares := sdk.NewDec(tokens.Int64()) // as many shares as tokens + delShares := math.LegacyNewDec(tokens.Int64()) // as many shares as tokens validator := stakingtypes.Validator{ OperatorAddress: sdk.ValAddress(val.Address).String(), @@ -174,14 +174,14 @@ func getAppBytesAndSenders( }, UnbondingHeight: int64(0), UnbondingTime: time.Unix(0, 0).UTC(), - Commission: stakingtypes.NewCommission(sdk.ZeroDec(), sdk.ZeroDec(), sdk.ZeroDec()), - MinSelfDelegation: sdk.ZeroInt(), + Commission: stakingtypes.NewCommission(math.LegacyZeroDec(), math.LegacyZeroDec(), math.LegacyZeroDec()), + MinSelfDelegation: math.ZeroInt(), } stakingValidators = append(stakingValidators, validator) // Store delegation from the model delegator account - delegations = append(delegations, stakingtypes.NewDelegation(senderAccounts[0].SenderAccount.GetAddress(), val.Address.Bytes(), delShares)) + delegations = append(delegations, stakingtypes.NewDelegation(senderAccounts[0].SenderAccount.GetAddress().String(), val.Address.String(), delShares)) // add initial validator powers so consumer InitGenesis runs correctly pub, _ := val.ToProto() @@ -224,7 +224,7 @@ func getAppBytesAndSenders( // add unbonded amount balances = append(balances, banktypes.Balance{ Address: authtypes.NewModuleAddress(stakingtypes.NotBondedPoolName).String(), - Coins: sdk.Coins{sdk.NewCoin(bondDenom, sdk.ZeroInt())}, + Coins: sdk.Coins{sdk.NewCoin(bondDenom, math.ZeroInt())}, }) // update total funds supply @@ -259,7 +259,7 @@ func newChain( protoConsParams := CONSENSUS_PARAMS.ToProto() app.InitChain( - abcitypes.RequestInitChain{ + &abcitypes.RequestInitChain{ ChainId: chainID, Validators: cmttypes.TM2PB.ValidatorUpdates(validators), ConsensusParams: &protoConsParams, @@ -269,20 +269,16 @@ func newChain( app.Commit() - app.BeginBlock( - abcitypes.RequestBeginBlock{ - Header: cmtproto.Header{ - ChainID: chainID, - Height: app.LastBlockHeight() + 1, - AppHash: app.LastCommitID().Hash, - ValidatorsHash: validators.Hash(), - NextValidatorsHash: validators.Hash(), - }, + app.FinalizeBlock( + &abcitypes.RequestFinalizeBlock{ + Hash: app.LastCommitID().Hash, + Height: app.LastBlockHeight() + 1, + NextValidatorsHash: validators.Hash(), }, ) chain := &ibctesting.TestChain{ - T: t, + TB: t, Coordinator: coord, ChainID: chainID, App: app, @@ -379,12 +375,8 @@ func (s *Driver) ConfigureNewPath(consumerChain, providerChain *ibctesting.TestC // Commit a block on both chains, giving us two committed headers from // the same time and height. This is the starting point for all our // data driven testing. - lastConsumerHeader, _ := simibc.EndBlock(consumerChain, func() {}) - lastProviderHeader, _ := simibc.EndBlock(providerChain, func() {}) - - // Get ready to update clients. - simibc.BeginBlock(providerChain, 5) - simibc.BeginBlock(consumerChain, 5) + lastConsumerHeader, _ := simibc.FinalizeBlock(consumerChain, 5) + lastProviderHeader, _ := simibc.FinalizeBlock(providerChain, 5) // Update clients to the latest header. err = simibc.UpdateReceiverClient(consumerEndPoint, providerEndPoint, lastConsumerHeader, false) @@ -420,8 +412,7 @@ func (s *Driver) setupProvider( s.providerKeeper().SetParams(s.ctx("provider"), providerParams) // produce a first block - simibc.EndBlock(providerChain, func() {}) - simibc.BeginBlock(providerChain, 0) + simibc.FinalizeBlock(providerChain, 0) } func (s *Driver) setupConsumer( diff --git a/testutil/simibc/chain_util.go b/testutil/simibc/chain_util.go index bfd45fa9ac..002ed0a8ab 100644 --- a/testutil/simibc/chain_util.go +++ b/testutil/simibc/chain_util.go @@ -6,6 +6,7 @@ import ( channeltypes "github.com/cosmos/ibc-go/v8/modules/core/04-channel/types" ibctmtypes "github.com/cosmos/ibc-go/v8/modules/light-clients/07-tendermint" ibctesting "github.com/cosmos/ibc-go/v8/testing" + "github.com/stretchr/testify/require" sdk "github.com/cosmos/cosmos-sdk/types" @@ -13,59 +14,52 @@ import ( tmproto "github.com/cometbft/cometbft/proto/tendermint/types" ) -// BeginBlock updates the current header and calls the app.BeginBlock method. -// The new block height is the previous block height + 1. -// The new block time is the previous block time + dt. +// FinalizeBlock calls app.FinalizeBlock and app.Commit. +// It sets the next block time to currentBlockTime + dt. +// This function returns the TMHeader of the block that was just ended, // // NOTE: this method may be used independently of the rest of simibc. -func BeginBlock(c *ibctesting.TestChain, dt time.Duration) { +func FinalizeBlock(c *ibctesting.TestChain, dt time.Duration) (*ibctmtypes.Header, []channeltypes.Packet) { + res, err := c.App.FinalizeBlock(&abci.RequestFinalizeBlock{ + Height: c.CurrentHeader.Height, + Time: c.CurrentHeader.GetTime(), + NextValidatorsHash: c.NextVals.Hash(), + }) + require.NoError(c.TB, err) + + _, err = c.App.Commit() + require.NoError(c.TB, err) + + // set the last header to the current header + // use nil trusted fields + c.LastHeader = c.CurrentTMClientHeader() + + // val set changes returned from previous block get applied to the next validators + // of this block. See tendermint spec for details. + c.Vals = c.NextVals + c.NextVals = ibctesting.ApplyValSetChanges(c, c.Vals, res.ValidatorUpdates) + + // increment the current header c.CurrentHeader = tmproto.Header{ ChainID: c.ChainID, Height: c.App.LastBlockHeight() + 1, AppHash: c.App.LastCommitID().Hash, - Time: c.CurrentHeader.Time.Add(dt), + Time: c.CurrentHeader.Time, ValidatorsHash: c.Vals.Hash(), NextValidatorsHash: c.NextVals.Hash(), + ProposerAddress: c.CurrentHeader.ProposerAddress, } - _ = c.App.BeginBlock(abci.RequestBeginBlock{Header: c.CurrentHeader}) -} - -// EndBlock calls app.EndBlock and executes preCommitCallback BEFORE calling app.Commit -// The callback is useful for testing purposes to execute arbitrary code before the -// chain sdk context is cleared in .Commit(). -// For example, app.EndBlock may lead to a new state, which you would like to query -// to check that it is correct. However, the sdk context is cleared after .Commit(), -// so you can query the state inside the callback. -// -// NOTE: this method may be used independently of the rest of simibc. -func EndBlock( - c *ibctesting.TestChain, - preCommitCallback func(), -) (*ibctmtypes.Header, []channeltypes.Packet) { - ebRes := c.App.EndBlock(abci.RequestEndBlock{Height: c.CurrentHeader.Height}) - - /* - It is useful to call arbitrary code after ending the block but before - committing the block because the sdk.Context is cleared after committing. - */ - - c.App.Commit() - - c.Vals = c.NextVals - - c.NextVals = ibctesting.ApplyValSetChanges(c.T, c.Vals, ebRes.ValidatorUpdates) - - c.LastHeader = c.CurrentTMClientHeader() + // set the new time + c.CurrentHeader.Time = c.CurrentHeader.Time.Add(dt) - sdkEvts := ABCIToSDKEvents(ebRes.Events) - packets := ParsePacketsFromEvents(sdkEvts) + packets := ParsePacketsFromEvents(res.Events) return c.LastHeader, packets } // ParsePacketsFromEvents returns all packets found in events. -func ParsePacketsFromEvents(events []sdk.Event) (packets []channeltypes.Packet) { +func ParsePacketsFromEvents(events []abci.Event) (packets []channeltypes.Packet) { for i, ev := range events { if ev.Type == channeltypes.EventTypeSendPacket { packet, err := ibctesting.ParsePacketFromEvents(events[i:]) diff --git a/testutil/simibc/relay_util.go b/testutil/simibc/relay_util.go index fefa2075ed..8fb93462c2 100644 --- a/testutil/simibc/relay_util.go +++ b/testutil/simibc/relay_util.go @@ -37,16 +37,18 @@ func UpdateReceiverClient(sender, receiver *ibctesting.Endpoint, header *ibctmty return err } - _, _, err = simapp.SignAndDeliver( - receiver.Chain.T, + _, err = simapp.SignAndDeliver( + receiver.Chain.TB, receiver.Chain.TxConfig, receiver.Chain.App.GetBaseApp(), - receiver.Chain.GetContext().BlockHeader(), []sdk.Msg{msg}, receiver.Chain.ChainID, []uint64{receiver.Chain.SenderAccount.GetAccountNumber()}, []uint64{receiver.Chain.SenderAccount.GetSequence()}, - true, !expectExpiration, receiver.Chain.SenderPrivKey, + !expectExpiration, + receiver.Chain.GetContext().BlockHeader().Time, + receiver.Chain.GetContext().BlockHeader().NextValidatorsHash, + receiver.Chain.SenderPrivKey, ) setSequenceErr := receiver.Chain.SenderAccount.SetSequence(receiver.Chain.SenderAccount.GetSequence() + 1) @@ -73,16 +75,18 @@ func TryRecvPacket(sender, receiver *ibctesting.Endpoint, packet channeltypes.Pa RPmsg := channeltypes.NewMsgRecvPacket(packet, proof, proofHeight, receiver.Chain.SenderAccount.GetAddress().String()) - _, resWithAck, err := simapp.SignAndDeliver( - receiver.Chain.T, + resWithAck, err := simapp.SignAndDeliver( + receiver.Chain.TB, receiver.Chain.TxConfig, receiver.Chain.App.GetBaseApp(), - receiver.Chain.GetContext().BlockHeader(), []sdk.Msg{RPmsg}, receiver.Chain.ChainID, []uint64{receiver.Chain.SenderAccount.GetAccountNumber()}, []uint64{receiver.Chain.SenderAccount.GetSequence()}, - true, !expectError, receiver.Chain.SenderPrivKey, + !expectError, + receiver.Chain.GetContext().BlockHeader().Time, + receiver.Chain.GetContext().BlockHeader().NextValidatorsHash, + receiver.Chain.SenderPrivKey, ) // need to set the sequence even if there was an error in delivery setSequenceErr := receiver.Chain.SenderAccount.SetSequence(receiver.Chain.SenderAccount.GetSequence() + 1) @@ -116,16 +120,18 @@ func TryRecvAck(sender, receiver *ibctesting.Endpoint, packet channeltypes.Packe ackMsg := channeltypes.NewMsgAcknowledgement(p, ack, proof, proofHeight, receiver.Chain.SenderAccount.GetAddress().String()) - _, _, err = simapp.SignAndDeliver( - receiver.Chain.T, + _, err = simapp.SignAndDeliver( + receiver.Chain.TB, receiver.Chain.TxConfig, receiver.Chain.App.GetBaseApp(), - receiver.Chain.GetContext().BlockHeader(), []sdk.Msg{ackMsg}, receiver.Chain.ChainID, []uint64{receiver.Chain.SenderAccount.GetAccountNumber()}, []uint64{receiver.Chain.SenderAccount.GetSequence()}, - true, true, receiver.Chain.SenderPrivKey, + true, + receiver.Chain.GetContext().BlockHeader().Time, + receiver.Chain.GetContext().BlockHeader().NextValidatorsHash, + receiver.Chain.SenderPrivKey, ) setSequenceErr := receiver.Chain.SenderAccount.SetSequence(receiver.Chain.SenderAccount.GetSequence() + 1)