Skip to content

Commit

Permalink
refactor: use IterateLastValidatorPowers instead of GetLastValidators (
Browse files Browse the repository at this point in the history
…#1953)

* Add skeleton for GetLastValidators wrapper

* Fix unit tests

* Correct comment

* Log error messages if validators are not found

* Change AnyTimes to more specific Times(1)

* Instantiate slices with their max length and truncate

* Remove GetLastValidators from expectation

* Remove GetLastValidators call in consumer

* Move GetLastBondedValidators to validator_set_updates

* Add comment on iteration loop
  • Loading branch information
p-offtermatt authored Jun 12, 2024
1 parent 477cce2 commit 8955fcb
Show file tree
Hide file tree
Showing 20 changed files with 143 additions and 86 deletions.
11 changes: 6 additions & 5 deletions tests/integration/unbonding.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,13 +467,14 @@ func (s *CCVTestSuite) TestRedelegationProviderFirst() {
// during the staking module EndBlock.
func (s *CCVTestSuite) TestTooManyLastValidators() {
sk := s.providerApp.GetTestStakingKeeper()
pk := s.providerApp.GetProviderKeeper()

// get current staking params
p := sk.GetParams(s.providerCtx())

// get validators, which are all active at the moment
vals := sk.GetAllValidators(s.providerCtx())
s.Require().Equal(len(vals), len(sk.GetLastValidators(s.providerCtx())))
s.Require().Equal(len(vals), len(pk.GetLastBondedValidators(s.providerCtx())))

// jail a validator
val := vals[0]
Expand All @@ -482,17 +483,17 @@ func (s *CCVTestSuite) TestTooManyLastValidators() {
sk.Jail(s.providerCtx(), consAddr)

// save the current number of bonded vals
lastVals := sk.GetLastValidators(s.providerCtx())
lastVals := pk.GetLastBondedValidators(s.providerCtx())

// pass one block to apply the validator set changes
// (calls ApplyAndReturnValidatorSetUpdates in the the staking module EndBlock)
s.providerChain.NextBlock()

// verify that the number of bonded validators is decreased by one
s.Require().Equal(len(lastVals)-1, len(sk.GetLastValidators(s.providerCtx())))
s.Require().Equal(len(lastVals)-1, len(pk.GetLastBondedValidators(s.providerCtx())))

// update maximum validator to equal the number of bonded validators
p.MaxValidators = uint32(len(sk.GetLastValidators(s.providerCtx())))
p.MaxValidators = uint32(len(pk.GetLastBondedValidators(s.providerCtx())))
sk.SetParams(s.providerCtx(), p)

// pass one block to apply validator set changes
Expand All @@ -508,5 +509,5 @@ func (s *CCVTestSuite) TestTooManyLastValidators() {
// ApplyAndReturnValidatorSetUpdates where the staking module has a inconsistent state
s.Require().NotPanics(s.providerChain.NextBlock)
s.Require().NotPanics(func() { sk.ApplyAndReturnValidatorSetUpdates(s.providerCtx()) })
s.Require().NotPanics(func() { sk.GetLastValidators(s.providerCtx()) })
s.Require().NotPanics(func() { pk.GetLastBondedValidators(s.providerCtx()) })
}
2 changes: 1 addition & 1 deletion testutil/ibc_testing/generic_setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func AddConsumer[Tp testutil.ProviderApp, Tc testutil.ConsumerApp](
prop.Top_N = consumerTopNParams[index] // isn't used in CreateConsumerClient

// opt-in all validators
for _, v := range providerApp.GetTestStakingKeeper().GetLastValidators(providerChain.GetContext()) {
for _, v := range providerKeeper.GetLastBondedValidators(providerChain.GetContext()) {
consAddr, _ := v.GetConsAddr()
providerKeeper.SetOptedIn(providerChain.GetContext(), chainID, providertypes.NewProviderConsAddress(consAddr))
}
Expand Down
33 changes: 33 additions & 0 deletions testutil/keeper/expectations.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,36 @@ func GetMocksForSlashValidator(
Times(1),
}
}

// SetupMocksForLastBondedValidatorsExpectation sets up the expectation for the `IterateLastValidatorPowers` `MaxValidators`, and `GetValidator` methods of the `mockStakingKeeper` object.
// These are needed in particular when calling `GetLastBondedValidators` from the provider keeper.
// Times is the number of times the expectation should be called. Provide -1 for `AnyTimes“.
func SetupMocksForLastBondedValidatorsExpectation(mockStakingKeeper *MockStakingKeeper, maxValidators uint32, vals []stakingtypes.Validator, powers []int64, times int) {
iteratorCall := mockStakingKeeper.EXPECT().IterateLastValidatorPowers(gomock.Any(), gomock.Any()).DoAndReturn(
func(ctx sdk.Context, cb func(sdk.ValAddress, int64) bool) {
for i, val := range vals {
if stop := cb(sdk.ValAddress(val.OperatorAddress), powers[i]); stop {
break
}
}
})
maxValidatorsCall := mockStakingKeeper.EXPECT().MaxValidators(gomock.Any()).Return(maxValidators)

if times == -1 {
iteratorCall.AnyTimes()
maxValidatorsCall.AnyTimes()
} else {
iteratorCall.Times(times)
maxValidatorsCall.Times(times)
}

// set up mocks for GetValidator calls
for _, val := range vals {
getValCall := mockStakingKeeper.EXPECT().GetValidator(gomock.Any(), sdk.ValAddress(val.OperatorAddress)).Return(val, true)
if times == -1 {
getValCall.AnyTimes()
} else {
getValCall.Times(times)
}
}
}
42 changes: 0 additions & 42 deletions testutil/keeper/mocks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion testutil/keeper/unit_test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,9 @@ func SetupForStoppingConsumerChain(t *testing.T, ctx sdk.Context,
providerKeeper *providerkeeper.Keeper, mocks MockedKeepers,
) {
t.Helper()
mocks.MockStakingKeeper.EXPECT().GetLastValidators(gomock.Any()).Times(1)

SetupMocksForLastBondedValidatorsExpectation(mocks.MockStakingKeeper, 1, []stakingtypes.Validator{}, []int64{}, 1)

expectations := GetMocksForCreateConsumerClient(ctx, &mocks,
"chainID", clienttypes.NewHeight(4, 5))
expectations = append(expectations, GetMocksForSetConsumerChain(ctx, &mocks, "chainID")...)
Expand Down
16 changes: 10 additions & 6 deletions x/ccv/consumer/keeper/changeover_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package keeper_test
import (
"testing"

"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"

sdkcryptocodec "github.com/cosmos/cosmos-sdk/crypto/codec"
Expand All @@ -30,6 +29,8 @@ func TestChangeoverToConsumer(t *testing.T) {
cIds[4].SDKStakingValidator(),
}

powers := []int64{55, 87324, 2, 42389479, 9089080}

// Instantiate 5 ics val updates for use in test
initialValUpdates := []abci.ValidatorUpdate{
{Power: 55, PubKey: cIds[5].TMProtoCryptoPublicKey()},
Expand All @@ -41,7 +42,7 @@ func TestChangeoverToConsumer(t *testing.T) {

testCases := []struct {
name string
// Last standalone validators that will be mock returned from stakingKeeper.GetLastValidators()
// Last standalone validators that will be mock returned from consumerKeeper.GetLastBondedValidators()
lastSovVals []stakingtypes.Validator
// Val updates corresponding to initial valset set for ccv set initGenesis
initialValUpdates []abci.ValidatorUpdate
Expand Down Expand Up @@ -100,10 +101,13 @@ func TestChangeoverToConsumer(t *testing.T) {
// Set initial valset, as would be done in InitGenesis
consumerKeeper.SetInitialValSet(ctx, tc.initialValUpdates)

// Setup mocked return value for stakingKeeper.GetLastValidators()
gomock.InOrder(
mocks.MockStakingKeeper.EXPECT().GetLastValidators(ctx).Return(tc.lastSovVals),
)
// Setup mocked return value for consumerkeeper.GetLastBondedValidators()
uthelpers.SetupMocksForLastBondedValidatorsExpectation(
mocks.MockStakingKeeper,
180, // max validators
tc.lastSovVals,
powers,
-1) // any times

// Add ref to standalone staking keeper
consumerKeeper.SetStandaloneStakingKeeper(mocks.MockStakingKeeper)
Expand Down
8 changes: 7 additions & 1 deletion x/ccv/consumer/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ func (k Keeper) GetLastStandaloneValidators(ctx sdk.Context) []stakingtypes.Vali
if !k.IsPreCCV(ctx) || k.standaloneStakingKeeper == nil {
panic("cannot get last standalone validators if not in pre-ccv state, or if standalone staking keeper is nil")
}
return k.standaloneStakingKeeper.GetLastValidators(ctx)
return k.GetLastBondedValidators(ctx)
}

// GetElapsedPacketMaturityTimes returns a slice of already elapsed PacketMaturityTimes, sorted by maturity times,
Expand Down Expand Up @@ -690,3 +690,9 @@ func (k Keeper) IsPrevStandaloneChain(ctx sdk.Context) bool {
store := ctx.KVStore(k.storeKey)
return store.Has(types.PrevStandaloneChainKey())
}

// GetLastBondedValidators iterates the last validator powers in the staking module
// and returns the first MaxValidators many validators with the largest powers.
func (k Keeper) GetLastBondedValidators(ctx sdk.Context) []stakingtypes.Validator {
return ccv.GetLastBondedValidatorsUtil(ctx, k.standaloneStakingKeeper, k.Logger(ctx))
}
11 changes: 7 additions & 4 deletions x/ccv/consumer/keeper/keeper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,14 @@ func TestGetLastSovereignValidators(t *testing.T) {
cId1 := crypto.NewCryptoIdentityFromIntSeed(11)
val := cId1.SDKStakingValidator()
val.Description.Moniker = "sanity check this is the correctly serialized val"
gomock.InOrder(
mocks.MockStakingKeeper.EXPECT().GetLastValidators(ctx).Return([]stakingtypes.Validator{
val,
}),
testkeeper.SetupMocksForLastBondedValidatorsExpectation(
mocks.MockStakingKeeper,
180,
[]stakingtypes.Validator{val},
[]int64{1000},
1,
)

lastSovVals := ck.GetLastStandaloneValidators(ctx)
require.Equal(t, []stakingtypes.Validator{val}, lastSovVals)
require.Equal(t, "sanity check this is the correctly serialized val",
Expand Down
6 changes: 3 additions & 3 deletions x/ccv/provider/keeper/grpc_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (k Keeper) GetConsumerChain(ctx sdk.Context, chainID string) (types.Chain,
// Get MinPowerInTop_N
var minPowerInTopN int64
if found && topN > 0 {
res, err := k.ComputeMinPowerToOptIn(ctx, k.stakingKeeper.GetLastValidators(ctx), topN)
res, err := k.ComputeMinPowerToOptIn(ctx, k.GetLastBondedValidators(ctx), topN)
if err != nil {
return types.Chain{}, fmt.Errorf("failed to compute min power to opt in for chain (%s): %w", chainID, err)
}
Expand Down Expand Up @@ -381,7 +381,7 @@ func (k Keeper) hasToValidate(
}

// if the validator was not part of the last epoch, check if the validator is going to be part of te next epoch
bondedValidators := k.stakingKeeper.GetLastValidators(ctx)
bondedValidators := k.GetLastBondedValidators(ctx)
if topN, found := k.GetTopN(ctx, chainID); found && topN > 0 {
// in a Top-N chain, we automatically opt in all validators that belong to the top N
minPower, err := k.ComputeMinPowerToOptIn(ctx, bondedValidators, topN)
Expand All @@ -395,7 +395,7 @@ func (k Keeper) hasToValidate(
// if the validator is opted in and belongs to the validators of the next epoch, then if nothing changes
// the validator would have to validate in the next epoch
if k.IsOptedIn(ctx, chainID, provAddr) {
nextValidators := k.ComputeNextValidators(ctx, chainID, k.stakingKeeper.GetLastValidators(ctx))
nextValidators := k.ComputeNextValidators(ctx, chainID, bondedValidators)
for _, v := range nextValidators {
consAddr := sdk.ConsAddress(v.ProviderConsAddr)
if provAddr.ToSdkConsAddr().Equals(consAddr) {
Expand Down
12 changes: 7 additions & 5 deletions x/ccv/provider/keeper/grpc_query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func TestQueryConsumerChainsValidatorHasToValidate(t *testing.T) {
valConsAddr, _ := val.GetConsAddr()
providerAddr := types.NewProviderConsAddress(valConsAddr)
mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(ctx, valConsAddr).Return(val, true).AnyTimes()
mocks.MockStakingKeeper.EXPECT().GetLastValidators(ctx).Return([]stakingtypes.Validator{val}).AnyTimes()
testkeeper.SetupMocksForLastBondedValidatorsExpectation(mocks.MockStakingKeeper, 1, []stakingtypes.Validator{val}, []int64{1}, -1) // -1 to allow the calls "AnyTimes"

req := types.QueryConsumerChainsValidatorHasToValidateRequest{
ProviderAddress: providerAddr.String(),
Expand All @@ -203,14 +203,15 @@ func TestQueryConsumerChainsValidatorHasToValidate(t *testing.T) {
ConsumerPublicKey: &crypto.PublicKey{
Sum: &crypto.PublicKey_Ed25519{
Ed25519: []byte{1},
}}})
},
},
})

// set `providerAddr` as an opted-in validator on "chain3"
pk.SetOptedIn(ctx, "chain3", providerAddr)

// `providerAddr` has to validate "chain1" because it is a consumer validator in this chain, as well as "chain3"
// because it opted in, in "chain3" and `providerAddr` belongs to the bonded validators (see the mocking of `GetLastValidators`
// above)
// because it opted in, in "chain3" and `providerAddr` belongs to the bonded validators
expectedChains := []string{"chain1", "chain3"}

res, err := pk.QueryConsumerChainsValidatorHasToValidate(ctx, &req)
Expand Down Expand Up @@ -268,7 +269,8 @@ func TestGetConsumerChain(t *testing.T) {
{OperatorAddress: "cosmosvaloper1tflk30mq5vgqjdly92kkhhq3raev2hnz6eete3"}, // 500 power
}
powers := []int64{50, 150, 300, 500} // sum = 1000
mocks.MockStakingKeeper.EXPECT().GetLastValidators(gomock.Any()).Return(vals).AnyTimes()
maxValidators := uint32(180)
testkeeper.SetupMocksForLastBondedValidatorsExpectation(mocks.MockStakingKeeper, maxValidators, vals, powers, -1) // -1 to allow the calls "AnyTimes"

for i, val := range vals {
mocks.MockStakingKeeper.EXPECT().GetLastValidatorPower(gomock.Any(), val.GetOperator()).Return(powers[i]).AnyTimes()
Expand Down
3 changes: 1 addition & 2 deletions x/ccv/provider/keeper/partial_set_security.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ func (k Keeper) HandleOptOut(ctx sdk.Context, chainID string, providerAddr types
"validator with consensus address %s could not be found", providerAddr.ToSdkConsAddr())
}
power := k.stakingKeeper.GetLastValidatorPower(ctx, validator.GetOperator())
minPowerToOptIn, err := k.ComputeMinPowerToOptIn(ctx, k.stakingKeeper.GetLastValidators(ctx), topN)

minPowerToOptIn, err := k.ComputeMinPowerToOptIn(ctx, k.GetLastBondedValidators(ctx), topN)
if err != nil {
k.Logger(ctx).Error("failed to compute min power to opt in for chain", "chain", chainID, "error", err)
return errorsmod.Wrapf(
Expand Down
2 changes: 1 addition & 1 deletion x/ccv/provider/keeper/partial_set_security_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func TestHandleOptOutFromTopNChain(t *testing.T) {
valDConsAddr, _ := valD.GetConsAddr()
mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(ctx, valDConsAddr).Return(valD, true).AnyTimes()

mocks.MockStakingKeeper.EXPECT().GetLastValidators(ctx).Return([]stakingtypes.Validator{valA, valB, valC, valD}).AnyTimes()
testkeeper.SetupMocksForLastBondedValidatorsExpectation(mocks.MockStakingKeeper, 4, []stakingtypes.Validator{valA, valB, valC, valD}, []int64{1, 2, 3, 4}, -1) // -1 to allow mocks AnyTimes

// opt in all validators
providerKeeper.SetOptedIn(ctx, chainID, types.NewProviderConsAddress(valAConsAddr))
Expand Down
2 changes: 1 addition & 1 deletion x/ccv/provider/keeper/proposal.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ func (k Keeper) MakeConsumerGenesis(
}

// get the bonded validators from the staking module
bondedValidators := k.stakingKeeper.GetLastValidators(ctx)
bondedValidators := k.GetLastBondedValidators(ctx)

if topN, found := k.GetTopN(ctx, chainID); found && topN > 0 {
// in a Top-N chain, we automatically opt in all validators that belong to the top N
Expand Down
Loading

0 comments on commit 8955fcb

Please sign in to comment.