diff --git a/testutil/store/mock_iterator.go b/testutil/store/mock_iterator.go new file mode 100644 index 0000000000..3f18bca64c --- /dev/null +++ b/testutil/store/mock_iterator.go @@ -0,0 +1,48 @@ +package store + +// MockIterator is a mock implementation of Iterator that simply iterates an underlying slice. +type MockIterator struct { + keys [][]byte + values [][]byte + + curIndex int +} + +func (MockIterator) Domain() ([]byte, []byte) { + panic("not implemented") +} + +func (m MockIterator) Valid() bool { + return m.curIndex < len(m.keys) +} + +func (m *MockIterator) Next() { + m.curIndex++ +} + +func (m MockIterator) Key() []byte { + return m.keys[m.curIndex] +} + +func (m MockIterator) Value() []byte { + return m.values[m.curIndex] +} + +func (m MockIterator) Error() error { + panic("not implemented") +} + +func (m MockIterator) Close() error { + return nil +} + +// NewMockIterator creates a new MockIterator. +func NewMockIterator( + keys [][]byte, values [][]byte, +) *MockIterator { + return &MockIterator{ + keys: keys, + values: values, + curIndex: 0, + } +} diff --git a/x/ccv/provider/keeper/relay.go b/x/ccv/provider/keeper/relay.go index 15dcfcbcb7..c5772a9687 100644 --- a/x/ccv/provider/keeper/relay.go +++ b/x/ccv/provider/keeper/relay.go @@ -224,24 +224,28 @@ func (k Keeper) QueueVSCPackets(ctx sdk.Context) { // TODO make this a param maxTotalValidators := 500 - validators := make([]stakingtypes.Validator, 0, maxTotalValidators) + allValidators := make([]stakingtypes.Validator, maxTotalValidators) defer validatorIterator.Close() i := 0 for ; validatorIterator.Valid() && i < int(maxTotalValidators); validatorIterator.Next() { - address := validatorIterator.Value() + address := sdk.ConsAddress(validatorIterator.Value()) validator, found := k.stakingKeeper.GetValidatorByConsAddr(ctx, address) if !found { k.Logger(ctx).Error("validator not found", "address", address.String()) continue } - if validator.IsBonded() { - validators[i] = validator - i++ - } + allValidators[i] = validator + i++ } + // truncate all validators + allValidators = allValidators[:i] + + // get the bonded validators to compute the top N to opt in + bondedValidators := k.stakingKeeper.GetLastValidators(ctx) + for _, chain := range k.GetAllConsumerChains(ctx) { currentValidators := k.GetConsumerValSet(ctx, chain.ChainId) @@ -253,7 +257,7 @@ func (k Keeper) QueueVSCPackets(ctx sdk.Context) { } } - nextValidators := k.ComputeNextValidators(ctx, chain.ChainId, bondedValidators) + nextValidators := k.ComputeNextValidators(ctx, chain.ChainId, allValidators) valUpdates := DiffValidators(currentValidators, nextValidators) k.SetConsumerValSet(ctx, chain.ChainId, nextValidators) diff --git a/x/ccv/provider/keeper/relay_test.go b/x/ccv/provider/keeper/relay_test.go index 5f74d93424..1b155efd0b 100644 --- a/x/ccv/provider/keeper/relay_test.go +++ b/x/ccv/provider/keeper/relay_test.go @@ -23,6 +23,7 @@ import ( "github.com/cosmos/interchain-security/v4/testutil/crypto" cryptotestutil "github.com/cosmos/interchain-security/v4/testutil/crypto" testkeeper "github.com/cosmos/interchain-security/v4/testutil/keeper" + teststore "github.com/cosmos/interchain-security/v4/testutil/store" "github.com/cosmos/interchain-security/v4/x/ccv/provider/keeper" "github.com/cosmos/interchain-security/v4/x/ccv/provider/types" providertypes "github.com/cosmos/interchain-security/v4/x/ccv/provider/types" @@ -71,6 +72,7 @@ func TestQueueVSCPackets(t *testing.T) { defer ctrl.Finish() mocks := testkeeper.NewMockedKeepers(ctrl) mocks.MockStakingKeeper.EXPECT().GetLastValidators(ctx).Times(1) + mocks.MockStakingKeeper.EXPECT().ValidatorsPowerStoreIterator(ctx).Return(teststore.NewMockIterator([][]byte{}, [][]byte{})).Times(1) pk := testkeeper.NewInMemProviderKeeper(keeperParams, mocks) // no-op if tc.packets is empty @@ -769,6 +771,9 @@ func TestEndBlockVSU(t *testing.T) { mocks.MockStakingKeeper.EXPECT().GetLastValidators(gomock.Any()).Return(lastValidators).AnyTimes() + // set the mock to return the last validators we built also as the set of all validators + MockAllValidatorsAsLastValidators(mocks, lastValidators) + // set a sample client for a consumer chain so that `GetAllConsumerChains` in `QueueVSCPackets` iterates at least once providerKeeper.SetConsumerClientId(ctx, chainID, "clientID") @@ -795,6 +800,24 @@ func TestEndBlockVSU(t *testing.T) { require.Equal(t, 1, len(providerKeeper.GetPendingVSCPackets(ctx, chainID))) } +// MockAllValidatorsAsLastValidators mocks the staking keeper to return the given validators +// when prompted for the whole validator set. +// The mocks are set up to mock the calls made during the `EndBlockVSU` method. +func MockAllValidatorsAsLastValidators(mocks testkeeper.MockedKeepers, validators []stakingtypes.Validator) { + keySlice := make([][]byte, len(validators)) + valsSlice := make([][]byte, len(validators)) + for i, val := range validators { + keySlice[i] = []byte{} + valsSlice[i] = val.GetOperator() + } + + mocks.MockStakingKeeper.EXPECT().ValidatorsPowerStoreIterator(gomock.Any()).Return(teststore.NewMockIterator(keySlice, valsSlice)).AnyTimes() + + for i, val := range validators { + mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(gomock.Any(), sdk.ConsAddress(val.GetOperator())).Return(validators[i], true).AnyTimes() + } +} + // TestQueueVSCPacketsWithPowerCapping tests queueing validator set updates with power capping func TestQueueVSCPacketsWithPowerCapping(t *testing.T) { providerKeeper, ctx, ctrl, mocks := testkeeper.GetProviderKeeperAndCtx(t, testkeeper.NewInMemKeeperParams(t)) @@ -822,7 +845,11 @@ func TestQueueVSCPacketsWithPowerCapping(t *testing.T) { valEPubKey, _ := valE.TmConsPublicKey() mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(ctx, valEConsAddr).Return(valE, true).AnyTimes() - mocks.MockStakingKeeper.EXPECT().GetLastValidators(ctx).Return([]stakingtypes.Validator{valA, valB, valC, valD, valE}).AnyTimes() + vals := []stakingtypes.Validator{valA, valB, valC, valD, valE} + mocks.MockStakingKeeper.EXPECT().GetLastValidators(ctx).Return(vals).AnyTimes() + + // set the mock to return the last validators we built also as the set of all validators + MockAllValidatorsAsLastValidators(mocks, vals) // add a consumer chain providerKeeper.SetConsumerClientId(ctx, "chainID", "clientID")