Skip to content

Commit

Permalink
fix!: improve message validation (#1460)
Browse files Browse the repository at this point in the history
* validate ValidatorSetChangePacketData

* update ValidateBasic for ValidatorSetChangePacketData

* update ConsumerPacketData validation

* fix TestConsumerPacketSendExpiredClient

* update TestOnRecvSlashPacketErrors

* fix TestQueueAndSendSlashPacket

* remove TODO

* nit: validate MsgAssignConsumerKey

* add changelog entries

* fix linter

* fix gosec

* rename ValidateBasic to Validate (IBC packets)

* Update x/ccv/types/wire.go

Co-authored-by: Simon Noetzlin <[email protected]>

* revert SlashPacketData address validation

---------

Co-authored-by: Simon Noetzlin <[email protected]>
  • Loading branch information
2 people authored and MSalopek committed Dec 1, 2023
1 parent b6c8f31 commit 0d82030
Show file tree
Hide file tree
Showing 15 changed files with 372 additions and 221 deletions.
3 changes: 3 additions & 0 deletions .changelog/unreleased/state-breaking/1460-msg-validation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- Improve validation of IBC packet data and provider messages. Also,
enable the provider to validate consumer packets before handling them.
([\#1460](https://github.com/cosmos/interchain-security/pull/1460))
2 changes: 1 addition & 1 deletion tests/integration/expired_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func (s *CCVTestSuite) TestConsumerPacketSendExpiredClient() {

// try to send slash packet for downtime infraction
addr := ed25519.GenPrivKey().PubKey().Address()
val := abci.Validator{Address: addr}
val := abci.Validator{Address: addr, Power: 1}
consumerKeeper.QueueSlashPacket(s.consumerCtx(), val, 2, stakingtypes.Infraction_INFRACTION_DOWNTIME)
// try to send slash packet for the same downtime infraction
consumerKeeper.QueueSlashPacket(s.consumerCtx(), val, 3, stakingtypes.Infraction_INFRACTION_DOWNTIME)
Expand Down
169 changes: 73 additions & 96 deletions tests/integration/slashing.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

cryptocodec "github.com/cosmos/cosmos-sdk/crypto/codec"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkaddress "github.com/cosmos/cosmos-sdk/types/address"
evidencetypes "github.com/cosmos/cosmos-sdk/x/evidence/types"
slashingtypes "github.com/cosmos/cosmos-sdk/x/slashing/types"
stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types"
Expand Down Expand Up @@ -269,12 +270,14 @@ func (s *CCVTestSuite) TestSlashPacketAcknowledgement() {
// Map infraction height on provider so validation passes and provider returns valid ack result
providerKeeper.SetValsetUpdateBlockHeight(s.providerCtx(), spd.ValsetUpdateId, 47923)

exportedAck := providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, spd)
s.Require().NotNil(exportedAck)
ackResult, err := providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, spd)
s.Require().NotNil(ackResult)
s.Require().NoError(err)
exportedAck := channeltypes.NewResultAcknowledgement(ackResult)

// Unmarshal ack to struct that's compatible with consumer. IBC does this automatically
ack := channeltypes.Acknowledgement{}
err := channeltypes.SubModuleCdc.UnmarshalJSON(exportedAck.Acknowledgement(), &ack)
err = channeltypes.SubModuleCdc.UnmarshalJSON(exportedAck.Acknowledgement(), &ack)
s.Require().NoError(err)

err = consumerKeeper.OnAcknowledgementPacket(s.consumerCtx(), packet, ack)
Expand Down Expand Up @@ -325,12 +328,11 @@ func (suite *CCVTestSuite) TestHandleSlashPacketDowntime() {
suite.Require().Equal(suite.providerCtx().BlockTime().Add(jailDuration), signingInfo.JailedUntil)
}

// TODO: had conflicts
// TestOnRecvSlashPacketErrors tests errors for the OnRecvSlashPacket method in an integration testing setting
func (suite *CCVTestSuite) TestOnRecvSlashPacketErrors() {
providerKeeper := suite.providerApp.GetProviderKeeper()
providerSlashingKeeper := suite.providerApp.GetTestSlashingKeeper()
firstBundle := suite.getFirstBundle()
consumerChainID := firstBundle.Chain.ChainID

suite.SetupAllCCVChannels()

Expand All @@ -339,106 +341,80 @@ func (suite *CCVTestSuite) TestOnRecvSlashPacketErrors() {

// Expect panic if ccv channel is not established via dest channel of packet
suite.Panics(func() {
providerKeeper.OnRecvSlashPacket(ctx, channeltypes.Packet{}, ccv.SlashPacketData{})
_, _ = providerKeeper.OnRecvSlashPacket(ctx, channeltypes.Packet{}, ccv.SlashPacketData{})
})

// Add correct channelID to packet. Now we will not panic anymore.
packet := channeltypes.Packet{DestinationChannel: firstBundle.Path.EndpointB.ChannelID}
suite.NotPanics(func() {
_, _ = providerKeeper.OnRecvSlashPacket(ctx, packet, ccv.SlashPacketData{})
})

// Init chain height is set by established CCV channel
// Delete init chain height and confirm expected error
initChainHeight, found := providerKeeper.GetInitChainHeight(ctx, consumerChainID)
suite.Require().True(found)
providerKeeper.DeleteInitChainHeight(ctx, consumerChainID)

packetData := ccv.SlashPacketData{ValsetUpdateId: 0}
errAck := providerKeeper.OnRecvSlashPacket(ctx, packet, packetData)
suite.Require().False(errAck.Success())
errAckCast := errAck.(channeltypes.Acknowledgement)
// Error strings in err acks are now thrown out by IBC core to prevent app hash.
// Hence a generic error string is expected.
suite.Require().Equal("ABCI code: 1: error handling packet: see events for details", errAckCast.GetError())

// Restore init chain height
providerKeeper.SetInitChainHeight(ctx, consumerChainID, initChainHeight)

// now the method will fail at infraction height check.
packetData.Infraction = stakingtypes.Infraction_INFRACTION_UNSPECIFIED
errAck = providerKeeper.OnRecvSlashPacket(ctx, packet, packetData)
suite.Require().False(errAck.Success())
errAckCast = errAck.(channeltypes.Acknowledgement)
suite.Require().Equal("ABCI code: 1: error handling packet: see events for details", errAckCast.GetError())

// save current VSC ID
vscID := providerKeeper.GetValidatorSetUpdateId(ctx)

// remove block height value mapped to current VSC ID
providerKeeper.DeleteValsetUpdateBlockHeight(ctx, vscID)

// Instantiate packet data with current VSC ID
packetData = ccv.SlashPacketData{ValsetUpdateId: vscID}

// expect an error if mapped block height is not found
errAck = providerKeeper.OnRecvSlashPacket(ctx, packet, packetData)
suite.Require().False(errAck.Success())
errAckCast = errAck.(channeltypes.Acknowledgement)
suite.Require().Equal("ABCI code: 1: error handling packet: see events for details", errAckCast.GetError())

// construct slashing packet with non existing validator
slashingPkt := ccv.NewSlashPacketData(
// Check Validate for SlashPacket data
validAddress := ed25519.GenPrivKey().PubKey().Address()
slashPacketData := ccv.NewSlashPacketData(
abci.Validator{
Address: ed25519.GenPrivKey().PubKey().Address(),
Power: int64(0),
Address: validAddress,
Power: int64(1),
}, uint64(0), stakingtypes.Infraction_INFRACTION_DOWNTIME,
)

// Set initial block height for consumer chain
providerKeeper.SetInitChainHeight(ctx, consumerChainID, uint64(ctx.BlockHeight()))

// Expect no error ack if validator does not exist
// TODO: this behavior should be changed to return an error ack,
// see: https://github.com/cosmos/interchain-security/issues/546
ack := providerKeeper.OnRecvSlashPacket(ctx, packet, *slashingPkt)
suite.Require().True(ack.Success())

val := suite.providerChain.Vals.Validators[0]

// commit block to set VSC ID
suite.coordinator.CommitBlock(suite.providerChain)
// Update suite.ctx bc CommitBlock updates only providerChain's current header block height
ctx = suite.providerChain.GetContext()
suite.Require().NotZero(providerKeeper.GetValsetUpdateBlockHeight(ctx, vscID))

// create validator signing info
valInfo := slashingtypes.NewValidatorSigningInfo(sdk.ConsAddress(val.Address), ctx.BlockHeight(),
ctx.BlockHeight()-1, time.Time{}.UTC(), false, int64(0))
providerSlashingKeeper.SetValidatorSigningInfo(ctx, sdk.ConsAddress(val.Address), valInfo)

// update validator address and VSC ID
slashingPkt.Validator.Address = val.Address
slashingPkt.ValsetUpdateId = vscID

// expect error ack when infraction type in unspecified
tmAddr := suite.providerChain.Vals.Validators[1].Address
slashingPkt.Validator.Address = tmAddr
slashingPkt.Infraction = stakingtypes.Infraction_INFRACTION_UNSPECIFIED

valInfo.Address = sdk.ConsAddress(tmAddr).String()
providerSlashingKeeper.SetValidatorSigningInfo(ctx, sdk.ConsAddress(tmAddr), valInfo)

errAck = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashingPkt)
suite.Require().False(errAck.Success())

// Expect nothing was queued
suite.Require().Equal(0, len(providerKeeper.GetAllGlobalSlashEntries(ctx)))
suite.Require().Equal(uint64(0), (providerKeeper.GetThrottledPacketDataSize(ctx, consumerChainID)))

// expect to queue entries for the slash request
slashingPkt.Infraction = stakingtypes.Infraction_INFRACTION_DOWNTIME
ack = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashingPkt)
suite.Require().True(ack.Success())
suite.Require().Equal(1, len(providerKeeper.GetAllGlobalSlashEntries(ctx)))
suite.Require().Equal(uint64(1), (providerKeeper.GetThrottledPacketDataSize(ctx, consumerChainID)))
// Expect an error if validator address is too long
slashPacketData.Validator.Address = make([]byte, sdkaddress.MaxAddrLen+1)
_, err := providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData)
suite.Require().Error(err, "validating SlashPacket data should fail - invalid validator address")

// Expect an error if validator power is zero
slashPacketData.Validator.Address = validAddress
slashPacketData.Validator.Power = 0
_, err = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData)
suite.Require().Error(err, "validating SlashPacket data should fail - invalid validator power")

// Expect an error if the infraction type is unspecified
slashPacketData.Validator.Power = 1
slashPacketData.Infraction = stakingtypes.Infraction_INFRACTION_UNSPECIFIED
_, err = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData)
suite.Require().Error(err, "validating SlashPacket data should fail - invalid infraction type")

// Restore slashPacketData to be valid
slashPacketData.Infraction = stakingtypes.Infraction_INFRACTION_DOWNTIME

// Check ValidateSlashPacket
// Expect an error if a mapping of the infraction height cannot be found;
// just set the vscID of the slash packet to the latest mapped vscID +1
valsetUpdateBlockHeights := providerKeeper.GetAllValsetUpdateBlockHeights(ctx)
latestMappedValsetUpdateId := valsetUpdateBlockHeights[len(valsetUpdateBlockHeights)-1].ValsetUpdateId
slashPacketData.ValsetUpdateId = latestMappedValsetUpdateId + 1
_, err = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData)
suite.Require().Error(err, "ValidateSlashPacket should fail - no infraction height mapping")

// Restore slashPacketData to be valid
slashPacketData.ValsetUpdateId = latestMappedValsetUpdateId

// Expect no error if validator does not exist
_, err = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData)
suite.Require().NoError(err, "no error expected")

// Check expected behavior for handling SlashPackets for double signing infractions
slashPacketData.Infraction = stakingtypes.Infraction_INFRACTION_DOUBLE_SIGN
ackResult, err := providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData)
suite.Require().NoError(err, "no error expected")
suite.Require().Equal(ccv.V1Result, ackResult, "expected successful ack")

// Check expected behavior for handling SlashPackets for downtime infractions
slashPacketData.Infraction = stakingtypes.Infraction_INFRACTION_DOWNTIME

// Expect the packet to bounce if the slash meter is negative
providerKeeper.SetSlashMeter(ctx, sdk.NewInt(-1))
ackResult, err = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData)
suite.Require().NoError(err, "no error expected")
suite.Require().Equal(ccv.SlashPacketBouncedResult, ackResult, "expected successful ack")

// Expect the packet to be handled if the slash meter is positive
providerKeeper.SetSlashMeter(ctx, sdk.NewInt(0))
ackResult, err = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData)
suite.Require().NoError(err, "no error expected")
suite.Require().Equal(ccv.SlashPacketHandledResult, ackResult, "expected successful ack")
}

// TestValidatorDowntime tests if a slash packet is sent
Expand Down Expand Up @@ -654,6 +630,7 @@ func (suite *CCVTestSuite) TestQueueAndSendSlashPacket() {
addr := ed25519.GenPrivKey().PubKey().Address()
val := abci.Validator{
Address: addr,
Power: int64(1),
}
consumerKeeper.QueueSlashPacket(ctx, val, 0, infraction)
slashedVals = append(slashedVals, slashedVal{validator: val, infraction: infraction})
Expand Down
9 changes: 6 additions & 3 deletions tests/integration/throttle.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,8 @@ func (s *CCVTestSuite) TestPacketSpam() {
consumerPacketData, err := provider.UnmarshalConsumerPacketData(data) // Same func used by provider's OnRecvPacket
s.Require().NoError(err)
packet := s.newPacketFromConsumer(data, uint64(sequence), firstBundle.Path, timeoutHeight, timeoutTimestamp)
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData())
_, err = providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData())
s.Require().NoError(err)
}

// Execute block to handle packets in endblock
Expand Down Expand Up @@ -416,7 +417,8 @@ func (s *CCVTestSuite) TestDoubleSignDoesNotAffectThrottling() {
consumerPacketData, err := provider.UnmarshalConsumerPacketData(data) // Same func used by provider's OnRecvPacket
s.Require().NoError(err)
packet := s.newPacketFromConsumer(data, uint64(sequence), firstBundle.Path, timeoutHeight, timeoutTimestamp)
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData())
_, err = providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData())
s.Require().NoError(err)
}

// Execute block to handle packets in endblock
Expand Down Expand Up @@ -806,7 +808,8 @@ func (s CCVTestSuite) TestSlashAllValidators() { //nolint:govet // this is a tes
consumerPacketData, err := provider.UnmarshalConsumerPacketData(data) // Same func used by provider's OnRecvPacket
s.Require().NoError(err)
packet := s.newPacketFromConsumer(data, ibcSeqNum, s.getFirstBundle().Path, timeoutHeight, timeoutTimestamp)
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData())
_, err = providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData())
s.Require().NoError(err)
}

// We should have 24 pending slash packet entries queued.
Expand Down
15 changes: 6 additions & 9 deletions tests/integration/valset_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,8 @@ func (suite *CCVTestSuite) TestQueueAndSendVSCMaturedPackets() {

// send first packet
packet := suite.newPacketFromProvider(pd.GetBytes(), 1, suite.path, clienttypes.NewHeight(1, 0), 0)
ack := consumerKeeper.OnRecvVSCPacket(suite.consumerChain.GetContext(), packet, pd)
suite.Require().NotNil(ack, "OnRecvVSCPacket did not return ack")
suite.Require().True(ack.Success(), "OnRecvVSCPacket did not return a Success Acknowledgment")
err = consumerKeeper.OnRecvVSCPacket(suite.consumerChain.GetContext(), packet, pd)
suite.Require().Nil(err, "OnRecvVSCPacket did return non-nil error")

// increase time
incrementTime(suite, time.Hour)
Expand All @@ -83,9 +82,8 @@ func (suite *CCVTestSuite) TestQueueAndSendVSCMaturedPackets() {
pd.ValsetUpdateId = 2
packet.Data = pd.GetBytes()
packet.Sequence = 2
ack = consumerKeeper.OnRecvVSCPacket(suite.consumerChain.GetContext(), packet, pd)
suite.Require().NotNil(ack, "OnRecvVSCPacket did not return ack")
suite.Require().True(ack.Success(), "OnRecvVSCPacket did not return a Success Acknowledgment")
err = consumerKeeper.OnRecvVSCPacket(suite.consumerChain.GetContext(), packet, pd)
suite.Require().Nil(err, "OnRecvVSCPacket did return non-nil error")

// increase time
incrementTime(suite, 24*time.Hour)
Expand All @@ -95,9 +93,8 @@ func (suite *CCVTestSuite) TestQueueAndSendVSCMaturedPackets() {
pd.ValsetUpdateId = 3
packet.Data = pd.GetBytes()
packet.Sequence = 3
ack = consumerKeeper.OnRecvVSCPacket(suite.consumerChain.GetContext(), packet, pd)
suite.Require().NotNil(ack, "OnRecvVSCPacket did not return ack")
suite.Require().True(ack.Success(), "OnRecvVSCPacket did not return a Success Acknowledgment")
err = consumerKeeper.OnRecvVSCPacket(suite.consumerChain.GetContext(), packet, pd)
suite.Require().Nil(err, "OnRecvVSCPacket did return non-nil error")

packetMaturities := consumerKeeper.GetAllPacketMaturityTimes(suite.consumerChain.GetContext())

Expand Down
44 changes: 34 additions & 10 deletions x/ccv/consumer/ibc_module.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package consumer

import (
"fmt"
"strconv"
"strings"

transfertypes "github.com/cosmos/ibc-go/v7/modules/apps/transfer/types"
Expand Down Expand Up @@ -225,25 +226,48 @@ func (am AppModule) OnRecvPacket(
packet channeltypes.Packet,
_ sdk.AccAddress,
) ibcexported.Acknowledgement {
var (
ack ibcexported.Acknowledgement
data types.ValidatorSetChangePacketData
)
logger := am.keeper.Logger(ctx)
ack := channeltypes.NewResultAcknowledgement([]byte{byte(1)})

var data types.ValidatorSetChangePacketData
var ackErr error
if err := types.ModuleCdc.UnmarshalJSON(packet.GetData(), &data); err != nil {
errAck := types.NewErrorAcknowledgementWithLog(ctx, fmt.Errorf("cannot unmarshal CCV packet data"))
ack = &errAck
} else {
ack = am.keeper.OnRecvVSCPacket(ctx, packet, data)
ackErr = errorsmod.Wrapf(sdkerrors.ErrInvalidType, "cannot unmarshal VSCPacket data")
logger.Error(fmt.Sprintf("%s sequence %d", ackErr.Error(), packet.Sequence))
ack = channeltypes.NewErrorAcknowledgement(ackErr)
}

// only attempt the application logic if the packet data
// was successfully decoded
if ack.Success() {
err := am.keeper.OnRecvVSCPacket(ctx, packet, data)
if err != nil {
ack = channeltypes.NewErrorAcknowledgement(err)
ackErr = err
logger.Error(fmt.Sprintf("%s sequence %d", ackErr.Error(), packet.Sequence))
} else {
logger.Info("successfully handled VSCPacket sequence: %d", packet.Sequence)
}
}

eventAttributes := []sdk.Attribute{
sdk.NewAttribute(sdk.AttributeKeyModule, types.ModuleName),
sdk.NewAttribute(types.AttributeValSetUpdateID, strconv.Itoa(int(data.ValsetUpdateId))),
sdk.NewAttribute(types.AttributeKeyAckSuccess, fmt.Sprintf("%t", ack.Success())),
}

if ackErr != nil {
eventAttributes = append(eventAttributes, sdk.NewAttribute(types.AttributeKeyAckError, ackErr.Error()))
}

ctx.EventManager().EmitEvent(
sdk.NewEvent(
types.EventTypePacket,
sdk.NewAttribute(sdk.AttributeKeyModule, consumertypes.ModuleName),
sdk.NewAttribute(types.AttributeKeyAckSuccess, fmt.Sprintf("%t", ack != nil)),
eventAttributes...,
),
)

// NOTE: acknowledgement will be written synchronously during IBC handler execution.
return ack
}

Expand Down
11 changes: 7 additions & 4 deletions x/ccv/consumer/keeper/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (

clienttypes "github.com/cosmos/ibc-go/v7/modules/core/02-client/types"
channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types"
"github.com/cosmos/ibc-go/v7/modules/core/exported"

errorsmod "cosmossdk.io/errors"

Expand All @@ -25,7 +24,12 @@ import (
//
// Note: CCV uses an ordered IBC channel, meaning VSC packet changes will be accumulated (and later
// processed by ApplyCCValidatorChanges) s.t. more recent val power changes overwrite older ones.
func (k Keeper) OnRecvVSCPacket(ctx sdk.Context, packet channeltypes.Packet, newChanges ccv.ValidatorSetChangePacketData) exported.Acknowledgement {
func (k Keeper) OnRecvVSCPacket(ctx sdk.Context, packet channeltypes.Packet, newChanges ccv.ValidatorSetChangePacketData) error {
// validate packet data upon receiving
if err := newChanges.Validate(); err != nil {
return errorsmod.Wrapf(err, "error validating VSCPacket data")
}

// get the provider channel
providerChannel, found := k.GetProviderChannel(ctx)
if found && providerChannel != packet.DestinationChannel {
Expand Down Expand Up @@ -87,8 +91,7 @@ func (k Keeper) OnRecvVSCPacket(ctx sdk.Context, packet channeltypes.Packet, new
"len updates", len(newChanges.ValidatorUpdates),
"len slash acks", len(newChanges.SlashAcks),
)
ack := channeltypes.NewResultAcknowledgement([]byte{byte(1)})
return ack
return nil
}

// QueueVSCMaturedPackets appends matured VSCs to an internal queue.
Expand Down
Loading

0 comments on commit 0d82030

Please sign in to comment.