From 569c3b72c35500e182ea0603a4c3db0f8315f312 Mon Sep 17 00:00:00 2001
From: Philip Offtermatt
Date: Fri, 17 May 2024 11:16:57 +0200
Subject: [PATCH] Add mock iterator and tests
---
testutil/store/mock_iterator.go | 48 +++++++++++++++++++++++++++++
x/ccv/provider/keeper/relay.go | 18 ++++++-----
x/ccv/provider/keeper/relay_test.go | 29 ++++++++++++++++-
3 files changed, 87 insertions(+), 8 deletions(-)
create mode 100644 testutil/store/mock_iterator.go
diff --git a/testutil/store/mock_iterator.go b/testutil/store/mock_iterator.go
new file mode 100644
index 0000000000..3f18bca64c
--- /dev/null
+++ b/testutil/store/mock_iterator.go
@@ -0,0 +1,48 @@
+package store
+
+// MockIterator is a mock implementation of Iterator that simply iterates an underlying slice.
+type MockIterator struct {
+ keys [][]byte
+ values [][]byte
+
+ curIndex int
+}
+
+func (MockIterator) Domain() ([]byte, []byte) {
+ panic("not implemented")
+}
+
+func (m MockIterator) Valid() bool {
+ return m.curIndex < len(m.keys)
+}
+
+func (m *MockIterator) Next() {
+ m.curIndex++
+}
+
+func (m MockIterator) Key() []byte {
+ return m.keys[m.curIndex]
+}
+
+func (m MockIterator) Value() []byte {
+ return m.values[m.curIndex]
+}
+
+func (m MockIterator) Error() error {
+ panic("not implemented")
+}
+
+func (m MockIterator) Close() error {
+ return nil
+}
+
+// NewMockIterator creates a new MockIterator.
+func NewMockIterator(
+ keys [][]byte, values [][]byte,
+) *MockIterator {
+ return &MockIterator{
+ keys: keys,
+ values: values,
+ curIndex: 0,
+ }
+}
diff --git a/x/ccv/provider/keeper/relay.go b/x/ccv/provider/keeper/relay.go
index 15dcfcbcb7..c5772a9687 100644
--- a/x/ccv/provider/keeper/relay.go
+++ b/x/ccv/provider/keeper/relay.go
@@ -224,24 +224,28 @@ func (k Keeper) QueueVSCPackets(ctx sdk.Context) {
// TODO make this a param
maxTotalValidators := 500
- validators := make([]stakingtypes.Validator, 0, maxTotalValidators)
+ allValidators := make([]stakingtypes.Validator, maxTotalValidators)
defer validatorIterator.Close()
i := 0
for ; validatorIterator.Valid() && i < int(maxTotalValidators); validatorIterator.Next() {
- address := validatorIterator.Value()
+ address := sdk.ConsAddress(validatorIterator.Value())
validator, found := k.stakingKeeper.GetValidatorByConsAddr(ctx, address)
if !found {
k.Logger(ctx).Error("validator not found", "address", address.String())
continue
}
- if validator.IsBonded() {
- validators[i] = validator
- i++
- }
+ allValidators[i] = validator
+ i++
}
+ // truncate all validators
+ allValidators = allValidators[:i]
+
+ // get the bonded validators to compute the top N to opt in
+ bondedValidators := k.stakingKeeper.GetLastValidators(ctx)
+
for _, chain := range k.GetAllConsumerChains(ctx) {
currentValidators := k.GetConsumerValSet(ctx, chain.ChainId)
@@ -253,7 +257,7 @@ func (k Keeper) QueueVSCPackets(ctx sdk.Context) {
}
}
- nextValidators := k.ComputeNextValidators(ctx, chain.ChainId, bondedValidators)
+ nextValidators := k.ComputeNextValidators(ctx, chain.ChainId, allValidators)
valUpdates := DiffValidators(currentValidators, nextValidators)
k.SetConsumerValSet(ctx, chain.ChainId, nextValidators)
diff --git a/x/ccv/provider/keeper/relay_test.go b/x/ccv/provider/keeper/relay_test.go
index 5f74d93424..1b155efd0b 100644
--- a/x/ccv/provider/keeper/relay_test.go
+++ b/x/ccv/provider/keeper/relay_test.go
@@ -23,6 +23,7 @@ import (
"github.com/cosmos/interchain-security/v4/testutil/crypto"
cryptotestutil "github.com/cosmos/interchain-security/v4/testutil/crypto"
testkeeper "github.com/cosmos/interchain-security/v4/testutil/keeper"
+ teststore "github.com/cosmos/interchain-security/v4/testutil/store"
"github.com/cosmos/interchain-security/v4/x/ccv/provider/keeper"
"github.com/cosmos/interchain-security/v4/x/ccv/provider/types"
providertypes "github.com/cosmos/interchain-security/v4/x/ccv/provider/types"
@@ -71,6 +72,7 @@ func TestQueueVSCPackets(t *testing.T) {
defer ctrl.Finish()
mocks := testkeeper.NewMockedKeepers(ctrl)
mocks.MockStakingKeeper.EXPECT().GetLastValidators(ctx).Times(1)
+ mocks.MockStakingKeeper.EXPECT().ValidatorsPowerStoreIterator(ctx).Return(teststore.NewMockIterator([][]byte{}, [][]byte{})).Times(1)
pk := testkeeper.NewInMemProviderKeeper(keeperParams, mocks)
// no-op if tc.packets is empty
@@ -769,6 +771,9 @@ func TestEndBlockVSU(t *testing.T) {
mocks.MockStakingKeeper.EXPECT().GetLastValidators(gomock.Any()).Return(lastValidators).AnyTimes()
+ // set the mock to return the last validators we built also as the set of all validators
+ MockAllValidatorsAsLastValidators(mocks, lastValidators)
+
// set a sample client for a consumer chain so that `GetAllConsumerChains` in `QueueVSCPackets` iterates at least once
providerKeeper.SetConsumerClientId(ctx, chainID, "clientID")
@@ -795,6 +800,24 @@ func TestEndBlockVSU(t *testing.T) {
require.Equal(t, 1, len(providerKeeper.GetPendingVSCPackets(ctx, chainID)))
}
+// MockAllValidatorsAsLastValidators mocks the staking keeper to return the given validators
+// when prompted for the whole validator set.
+// The mocks are set up to mock the calls made during the `EndBlockVSU` method.
+func MockAllValidatorsAsLastValidators(mocks testkeeper.MockedKeepers, validators []stakingtypes.Validator) {
+ keySlice := make([][]byte, len(validators))
+ valsSlice := make([][]byte, len(validators))
+ for i, val := range validators {
+ keySlice[i] = []byte{}
+ valsSlice[i] = val.GetOperator()
+ }
+
+ mocks.MockStakingKeeper.EXPECT().ValidatorsPowerStoreIterator(gomock.Any()).Return(teststore.NewMockIterator(keySlice, valsSlice)).AnyTimes()
+
+ for i, val := range validators {
+ mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(gomock.Any(), sdk.ConsAddress(val.GetOperator())).Return(validators[i], true).AnyTimes()
+ }
+}
+
// TestQueueVSCPacketsWithPowerCapping tests queueing validator set updates with power capping
func TestQueueVSCPacketsWithPowerCapping(t *testing.T) {
providerKeeper, ctx, ctrl, mocks := testkeeper.GetProviderKeeperAndCtx(t, testkeeper.NewInMemKeeperParams(t))
@@ -822,7 +845,11 @@ func TestQueueVSCPacketsWithPowerCapping(t *testing.T) {
valEPubKey, _ := valE.TmConsPublicKey()
mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(ctx, valEConsAddr).Return(valE, true).AnyTimes()
- mocks.MockStakingKeeper.EXPECT().GetLastValidators(ctx).Return([]stakingtypes.Validator{valA, valB, valC, valD, valE}).AnyTimes()
+ vals := []stakingtypes.Validator{valA, valB, valC, valD, valE}
+ mocks.MockStakingKeeper.EXPECT().GetLastValidators(ctx).Return(vals).AnyTimes()
+
+ // set the mock to return the last validators we built also as the set of all validators
+ MockAllValidatorsAsLastValidators(mocks, vals)
// add a consumer chain
providerKeeper.SetConsumerClientId(ctx, "chainID", "clientID")