Skip to content

Commit

Permalink
Add mock iterator and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
p-offtermatt committed May 17, 2024
1 parent fbcbd72 commit 569c3b7
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 8 deletions.
48 changes: 48 additions & 0 deletions testutil/store/mock_iterator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package store

// MockIterator is a mock implementation of Iterator that simply iterates an underlying slice.
type MockIterator struct {
keys [][]byte
values [][]byte

curIndex int
}

func (MockIterator) Domain() ([]byte, []byte) {
panic("not implemented")
}

func (m MockIterator) Valid() bool {
return m.curIndex < len(m.keys)
}

func (m *MockIterator) Next() {
m.curIndex++
}

func (m MockIterator) Key() []byte {
return m.keys[m.curIndex]
}

func (m MockIterator) Value() []byte {
return m.values[m.curIndex]
}

func (m MockIterator) Error() error {
panic("not implemented")
}

func (m MockIterator) Close() error {
return nil
}

// NewMockIterator creates a new MockIterator.
func NewMockIterator(
keys [][]byte, values [][]byte,
) *MockIterator {
return &MockIterator{
keys: keys,
values: values,
curIndex: 0,
}
}
18 changes: 11 additions & 7 deletions x/ccv/provider/keeper/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,24 +224,28 @@ func (k Keeper) QueueVSCPackets(ctx sdk.Context) {
// TODO make this a param
maxTotalValidators := 500

validators := make([]stakingtypes.Validator, 0, maxTotalValidators)
allValidators := make([]stakingtypes.Validator, maxTotalValidators)
defer validatorIterator.Close()

i := 0
for ; validatorIterator.Valid() && i < int(maxTotalValidators); validatorIterator.Next() {
address := validatorIterator.Value()
address := sdk.ConsAddress(validatorIterator.Value())
validator, found := k.stakingKeeper.GetValidatorByConsAddr(ctx, address)
if !found {
k.Logger(ctx).Error("validator not found", "address", address.String())
continue
}

if validator.IsBonded() {
validators[i] = validator
i++
}
allValidators[i] = validator
i++
}

// truncate all validators
allValidators = allValidators[:i]

// get the bonded validators to compute the top N to opt in
bondedValidators := k.stakingKeeper.GetLastValidators(ctx)

for _, chain := range k.GetAllConsumerChains(ctx) {
currentValidators := k.GetConsumerValSet(ctx, chain.ChainId)

Expand All @@ -253,7 +257,7 @@ func (k Keeper) QueueVSCPackets(ctx sdk.Context) {
}
}

nextValidators := k.ComputeNextValidators(ctx, chain.ChainId, bondedValidators)
nextValidators := k.ComputeNextValidators(ctx, chain.ChainId, allValidators)

valUpdates := DiffValidators(currentValidators, nextValidators)
k.SetConsumerValSet(ctx, chain.ChainId, nextValidators)
Expand Down
29 changes: 28 additions & 1 deletion x/ccv/provider/keeper/relay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/cosmos/interchain-security/v4/testutil/crypto"
cryptotestutil "github.com/cosmos/interchain-security/v4/testutil/crypto"
testkeeper "github.com/cosmos/interchain-security/v4/testutil/keeper"
teststore "github.com/cosmos/interchain-security/v4/testutil/store"
"github.com/cosmos/interchain-security/v4/x/ccv/provider/keeper"
"github.com/cosmos/interchain-security/v4/x/ccv/provider/types"
providertypes "github.com/cosmos/interchain-security/v4/x/ccv/provider/types"
Expand Down Expand Up @@ -71,6 +72,7 @@ func TestQueueVSCPackets(t *testing.T) {
defer ctrl.Finish()
mocks := testkeeper.NewMockedKeepers(ctrl)
mocks.MockStakingKeeper.EXPECT().GetLastValidators(ctx).Times(1)
mocks.MockStakingKeeper.EXPECT().ValidatorsPowerStoreIterator(ctx).Return(teststore.NewMockIterator([][]byte{}, [][]byte{})).Times(1)

pk := testkeeper.NewInMemProviderKeeper(keeperParams, mocks)
// no-op if tc.packets is empty
Expand Down Expand Up @@ -769,6 +771,9 @@ func TestEndBlockVSU(t *testing.T) {

mocks.MockStakingKeeper.EXPECT().GetLastValidators(gomock.Any()).Return(lastValidators).AnyTimes()

// set the mock to return the last validators we built also as the set of all validators
MockAllValidatorsAsLastValidators(mocks, lastValidators)

// set a sample client for a consumer chain so that `GetAllConsumerChains` in `QueueVSCPackets` iterates at least once
providerKeeper.SetConsumerClientId(ctx, chainID, "clientID")

Expand All @@ -795,6 +800,24 @@ func TestEndBlockVSU(t *testing.T) {
require.Equal(t, 1, len(providerKeeper.GetPendingVSCPackets(ctx, chainID)))
}

// MockAllValidatorsAsLastValidators mocks the staking keeper to return the given validators
// when prompted for the whole validator set.
// The mocks are set up to mock the calls made during the `EndBlockVSU` method.
func MockAllValidatorsAsLastValidators(mocks testkeeper.MockedKeepers, validators []stakingtypes.Validator) {
keySlice := make([][]byte, len(validators))
valsSlice := make([][]byte, len(validators))
for i, val := range validators {
keySlice[i] = []byte{}
valsSlice[i] = val.GetOperator()
}

mocks.MockStakingKeeper.EXPECT().ValidatorsPowerStoreIterator(gomock.Any()).Return(teststore.NewMockIterator(keySlice, valsSlice)).AnyTimes()

for i, val := range validators {
mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(gomock.Any(), sdk.ConsAddress(val.GetOperator())).Return(validators[i], true).AnyTimes()
}
}

// TestQueueVSCPacketsWithPowerCapping tests queueing validator set updates with power capping
func TestQueueVSCPacketsWithPowerCapping(t *testing.T) {
providerKeeper, ctx, ctrl, mocks := testkeeper.GetProviderKeeperAndCtx(t, testkeeper.NewInMemKeeperParams(t))
Expand Down Expand Up @@ -822,7 +845,11 @@ func TestQueueVSCPacketsWithPowerCapping(t *testing.T) {
valEPubKey, _ := valE.TmConsPublicKey()
mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(ctx, valEConsAddr).Return(valE, true).AnyTimes()

mocks.MockStakingKeeper.EXPECT().GetLastValidators(ctx).Return([]stakingtypes.Validator{valA, valB, valC, valD, valE}).AnyTimes()
vals := []stakingtypes.Validator{valA, valB, valC, valD, valE}
mocks.MockStakingKeeper.EXPECT().GetLastValidators(ctx).Return(vals).AnyTimes()

// set the mock to return the last validators we built also as the set of all validators
MockAllValidatorsAsLastValidators(mocks, vals)

// add a consumer chain
providerKeeper.SetConsumerClientId(ctx, "chainID", "clientID")
Expand Down

0 comments on commit 569c3b7

Please sign in to comment.