From 0d82030c4300a213d66433ace85151201dd5026c Mon Sep 17 00:00:00 2001 From: Marius Poke Date: Fri, 1 Dec 2023 12:26:39 +0100 Subject: [PATCH] fix!: improve message validation (#1460) * validate ValidatorSetChangePacketData * update ValidateBasic for ValidatorSetChangePacketData * update ConsumerPacketData validation * fix TestConsumerPacketSendExpiredClient * update TestOnRecvSlashPacketErrors * fix TestQueueAndSendSlashPacket * remove TODO * nit: validate MsgAssignConsumerKey * add changelog entries * fix linter * fix gosec * rename ValidateBasic to Validate (IBC packets) * Update x/ccv/types/wire.go Co-authored-by: Simon Noetzlin * revert SlashPacketData address validation --------- Co-authored-by: Simon Noetzlin --- .../state-breaking/1460-msg-validation.md | 3 + tests/integration/expired_client.go | 2 +- tests/integration/slashing.go | 169 ++++++++---------- tests/integration/throttle.go | 9 +- tests/integration/valset_update.go | 15 +- x/ccv/consumer/ibc_module.go | 44 +++-- x/ccv/consumer/keeper/relay.go | 11 +- x/ccv/consumer/keeper/relay_test.go | 55 ++++-- x/ccv/provider/ibc_module.go | 68 +++++-- x/ccv/provider/keeper/relay.go | 38 ++-- x/ccv/provider/keeper/relay_test.go | 54 +++--- x/ccv/provider/types/errors.go | 2 +- x/ccv/provider/types/msg.go | 5 +- x/ccv/types/wire.go | 50 ++++-- x/ccv/types/wire_test.go | 68 ++++++- 15 files changed, 372 insertions(+), 221 deletions(-) create mode 100644 .changelog/unreleased/state-breaking/1460-msg-validation.md diff --git a/.changelog/unreleased/state-breaking/1460-msg-validation.md b/.changelog/unreleased/state-breaking/1460-msg-validation.md new file mode 100644 index 0000000000..46d18bd4c9 --- /dev/null +++ b/.changelog/unreleased/state-breaking/1460-msg-validation.md @@ -0,0 +1,3 @@ +- Improve validation of IBC packet data and provider messages. Also, + enable the provider to validate consumer packets before handling them. + ([\#1460](https://github.com/cosmos/interchain-security/pull/1460)) \ No newline at end of file diff --git a/tests/integration/expired_client.go b/tests/integration/expired_client.go index f196f16ce1..81654895eb 100644 --- a/tests/integration/expired_client.go +++ b/tests/integration/expired_client.go @@ -129,7 +129,7 @@ func (s *CCVTestSuite) TestConsumerPacketSendExpiredClient() { // try to send slash packet for downtime infraction addr := ed25519.GenPrivKey().PubKey().Address() - val := abci.Validator{Address: addr} + val := abci.Validator{Address: addr, Power: 1} consumerKeeper.QueueSlashPacket(s.consumerCtx(), val, 2, stakingtypes.Infraction_INFRACTION_DOWNTIME) // try to send slash packet for the same downtime infraction consumerKeeper.QueueSlashPacket(s.consumerCtx(), val, 3, stakingtypes.Infraction_INFRACTION_DOWNTIME) diff --git a/tests/integration/slashing.go b/tests/integration/slashing.go index f1be51d33c..bb5206d94e 100644 --- a/tests/integration/slashing.go +++ b/tests/integration/slashing.go @@ -9,6 +9,7 @@ import ( cryptocodec "github.com/cosmos/cosmos-sdk/crypto/codec" sdk "github.com/cosmos/cosmos-sdk/types" + sdkaddress "github.com/cosmos/cosmos-sdk/types/address" evidencetypes "github.com/cosmos/cosmos-sdk/x/evidence/types" slashingtypes "github.com/cosmos/cosmos-sdk/x/slashing/types" stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" @@ -269,12 +270,14 @@ func (s *CCVTestSuite) TestSlashPacketAcknowledgement() { // Map infraction height on provider so validation passes and provider returns valid ack result providerKeeper.SetValsetUpdateBlockHeight(s.providerCtx(), spd.ValsetUpdateId, 47923) - exportedAck := providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, spd) - s.Require().NotNil(exportedAck) + ackResult, err := providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, spd) + s.Require().NotNil(ackResult) + s.Require().NoError(err) + exportedAck := channeltypes.NewResultAcknowledgement(ackResult) // Unmarshal ack to struct that's compatible with consumer. IBC does this automatically ack := channeltypes.Acknowledgement{} - err := channeltypes.SubModuleCdc.UnmarshalJSON(exportedAck.Acknowledgement(), &ack) + err = channeltypes.SubModuleCdc.UnmarshalJSON(exportedAck.Acknowledgement(), &ack) s.Require().NoError(err) err = consumerKeeper.OnAcknowledgementPacket(s.consumerCtx(), packet, ack) @@ -325,12 +328,11 @@ func (suite *CCVTestSuite) TestHandleSlashPacketDowntime() { suite.Require().Equal(suite.providerCtx().BlockTime().Add(jailDuration), signingInfo.JailedUntil) } +// TODO: had conflicts // TestOnRecvSlashPacketErrors tests errors for the OnRecvSlashPacket method in an integration testing setting func (suite *CCVTestSuite) TestOnRecvSlashPacketErrors() { providerKeeper := suite.providerApp.GetProviderKeeper() - providerSlashingKeeper := suite.providerApp.GetTestSlashingKeeper() firstBundle := suite.getFirstBundle() - consumerChainID := firstBundle.Chain.ChainID suite.SetupAllCCVChannels() @@ -339,106 +341,80 @@ func (suite *CCVTestSuite) TestOnRecvSlashPacketErrors() { // Expect panic if ccv channel is not established via dest channel of packet suite.Panics(func() { - providerKeeper.OnRecvSlashPacket(ctx, channeltypes.Packet{}, ccv.SlashPacketData{}) + _, _ = providerKeeper.OnRecvSlashPacket(ctx, channeltypes.Packet{}, ccv.SlashPacketData{}) }) // Add correct channelID to packet. Now we will not panic anymore. packet := channeltypes.Packet{DestinationChannel: firstBundle.Path.EndpointB.ChannelID} + suite.NotPanics(func() { + _, _ = providerKeeper.OnRecvSlashPacket(ctx, packet, ccv.SlashPacketData{}) + }) - // Init chain height is set by established CCV channel - // Delete init chain height and confirm expected error - initChainHeight, found := providerKeeper.GetInitChainHeight(ctx, consumerChainID) - suite.Require().True(found) - providerKeeper.DeleteInitChainHeight(ctx, consumerChainID) - - packetData := ccv.SlashPacketData{ValsetUpdateId: 0} - errAck := providerKeeper.OnRecvSlashPacket(ctx, packet, packetData) - suite.Require().False(errAck.Success()) - errAckCast := errAck.(channeltypes.Acknowledgement) - // Error strings in err acks are now thrown out by IBC core to prevent app hash. - // Hence a generic error string is expected. - suite.Require().Equal("ABCI code: 1: error handling packet: see events for details", errAckCast.GetError()) - - // Restore init chain height - providerKeeper.SetInitChainHeight(ctx, consumerChainID, initChainHeight) - - // now the method will fail at infraction height check. - packetData.Infraction = stakingtypes.Infraction_INFRACTION_UNSPECIFIED - errAck = providerKeeper.OnRecvSlashPacket(ctx, packet, packetData) - suite.Require().False(errAck.Success()) - errAckCast = errAck.(channeltypes.Acknowledgement) - suite.Require().Equal("ABCI code: 1: error handling packet: see events for details", errAckCast.GetError()) - - // save current VSC ID - vscID := providerKeeper.GetValidatorSetUpdateId(ctx) - - // remove block height value mapped to current VSC ID - providerKeeper.DeleteValsetUpdateBlockHeight(ctx, vscID) - - // Instantiate packet data with current VSC ID - packetData = ccv.SlashPacketData{ValsetUpdateId: vscID} - - // expect an error if mapped block height is not found - errAck = providerKeeper.OnRecvSlashPacket(ctx, packet, packetData) - suite.Require().False(errAck.Success()) - errAckCast = errAck.(channeltypes.Acknowledgement) - suite.Require().Equal("ABCI code: 1: error handling packet: see events for details", errAckCast.GetError()) - - // construct slashing packet with non existing validator - slashingPkt := ccv.NewSlashPacketData( + // Check Validate for SlashPacket data + validAddress := ed25519.GenPrivKey().PubKey().Address() + slashPacketData := ccv.NewSlashPacketData( abci.Validator{ - Address: ed25519.GenPrivKey().PubKey().Address(), - Power: int64(0), + Address: validAddress, + Power: int64(1), }, uint64(0), stakingtypes.Infraction_INFRACTION_DOWNTIME, ) - // Set initial block height for consumer chain - providerKeeper.SetInitChainHeight(ctx, consumerChainID, uint64(ctx.BlockHeight())) - - // Expect no error ack if validator does not exist - // TODO: this behavior should be changed to return an error ack, - // see: https://github.com/cosmos/interchain-security/issues/546 - ack := providerKeeper.OnRecvSlashPacket(ctx, packet, *slashingPkt) - suite.Require().True(ack.Success()) - - val := suite.providerChain.Vals.Validators[0] - - // commit block to set VSC ID - suite.coordinator.CommitBlock(suite.providerChain) - // Update suite.ctx bc CommitBlock updates only providerChain's current header block height - ctx = suite.providerChain.GetContext() - suite.Require().NotZero(providerKeeper.GetValsetUpdateBlockHeight(ctx, vscID)) - - // create validator signing info - valInfo := slashingtypes.NewValidatorSigningInfo(sdk.ConsAddress(val.Address), ctx.BlockHeight(), - ctx.BlockHeight()-1, time.Time{}.UTC(), false, int64(0)) - providerSlashingKeeper.SetValidatorSigningInfo(ctx, sdk.ConsAddress(val.Address), valInfo) - - // update validator address and VSC ID - slashingPkt.Validator.Address = val.Address - slashingPkt.ValsetUpdateId = vscID - - // expect error ack when infraction type in unspecified - tmAddr := suite.providerChain.Vals.Validators[1].Address - slashingPkt.Validator.Address = tmAddr - slashingPkt.Infraction = stakingtypes.Infraction_INFRACTION_UNSPECIFIED - - valInfo.Address = sdk.ConsAddress(tmAddr).String() - providerSlashingKeeper.SetValidatorSigningInfo(ctx, sdk.ConsAddress(tmAddr), valInfo) - - errAck = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashingPkt) - suite.Require().False(errAck.Success()) - - // Expect nothing was queued - suite.Require().Equal(0, len(providerKeeper.GetAllGlobalSlashEntries(ctx))) - suite.Require().Equal(uint64(0), (providerKeeper.GetThrottledPacketDataSize(ctx, consumerChainID))) - - // expect to queue entries for the slash request - slashingPkt.Infraction = stakingtypes.Infraction_INFRACTION_DOWNTIME - ack = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashingPkt) - suite.Require().True(ack.Success()) - suite.Require().Equal(1, len(providerKeeper.GetAllGlobalSlashEntries(ctx))) - suite.Require().Equal(uint64(1), (providerKeeper.GetThrottledPacketDataSize(ctx, consumerChainID))) + // Expect an error if validator address is too long + slashPacketData.Validator.Address = make([]byte, sdkaddress.MaxAddrLen+1) + _, err := providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData) + suite.Require().Error(err, "validating SlashPacket data should fail - invalid validator address") + + // Expect an error if validator power is zero + slashPacketData.Validator.Address = validAddress + slashPacketData.Validator.Power = 0 + _, err = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData) + suite.Require().Error(err, "validating SlashPacket data should fail - invalid validator power") + + // Expect an error if the infraction type is unspecified + slashPacketData.Validator.Power = 1 + slashPacketData.Infraction = stakingtypes.Infraction_INFRACTION_UNSPECIFIED + _, err = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData) + suite.Require().Error(err, "validating SlashPacket data should fail - invalid infraction type") + + // Restore slashPacketData to be valid + slashPacketData.Infraction = stakingtypes.Infraction_INFRACTION_DOWNTIME + + // Check ValidateSlashPacket + // Expect an error if a mapping of the infraction height cannot be found; + // just set the vscID of the slash packet to the latest mapped vscID +1 + valsetUpdateBlockHeights := providerKeeper.GetAllValsetUpdateBlockHeights(ctx) + latestMappedValsetUpdateId := valsetUpdateBlockHeights[len(valsetUpdateBlockHeights)-1].ValsetUpdateId + slashPacketData.ValsetUpdateId = latestMappedValsetUpdateId + 1 + _, err = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData) + suite.Require().Error(err, "ValidateSlashPacket should fail - no infraction height mapping") + + // Restore slashPacketData to be valid + slashPacketData.ValsetUpdateId = latestMappedValsetUpdateId + + // Expect no error if validator does not exist + _, err = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData) + suite.Require().NoError(err, "no error expected") + + // Check expected behavior for handling SlashPackets for double signing infractions + slashPacketData.Infraction = stakingtypes.Infraction_INFRACTION_DOUBLE_SIGN + ackResult, err := providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData) + suite.Require().NoError(err, "no error expected") + suite.Require().Equal(ccv.V1Result, ackResult, "expected successful ack") + + // Check expected behavior for handling SlashPackets for downtime infractions + slashPacketData.Infraction = stakingtypes.Infraction_INFRACTION_DOWNTIME + + // Expect the packet to bounce if the slash meter is negative + providerKeeper.SetSlashMeter(ctx, sdk.NewInt(-1)) + ackResult, err = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData) + suite.Require().NoError(err, "no error expected") + suite.Require().Equal(ccv.SlashPacketBouncedResult, ackResult, "expected successful ack") + + // Expect the packet to be handled if the slash meter is positive + providerKeeper.SetSlashMeter(ctx, sdk.NewInt(0)) + ackResult, err = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData) + suite.Require().NoError(err, "no error expected") + suite.Require().Equal(ccv.SlashPacketHandledResult, ackResult, "expected successful ack") } // TestValidatorDowntime tests if a slash packet is sent @@ -654,6 +630,7 @@ func (suite *CCVTestSuite) TestQueueAndSendSlashPacket() { addr := ed25519.GenPrivKey().PubKey().Address() val := abci.Validator{ Address: addr, + Power: int64(1), } consumerKeeper.QueueSlashPacket(ctx, val, 0, infraction) slashedVals = append(slashedVals, slashedVal{validator: val, infraction: infraction}) diff --git a/tests/integration/throttle.go b/tests/integration/throttle.go index a63b8e9ba7..6da70f16ca 100644 --- a/tests/integration/throttle.go +++ b/tests/integration/throttle.go @@ -356,7 +356,8 @@ func (s *CCVTestSuite) TestPacketSpam() { consumerPacketData, err := provider.UnmarshalConsumerPacketData(data) // Same func used by provider's OnRecvPacket s.Require().NoError(err) packet := s.newPacketFromConsumer(data, uint64(sequence), firstBundle.Path, timeoutHeight, timeoutTimestamp) - providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData()) + _, err = providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData()) + s.Require().NoError(err) } // Execute block to handle packets in endblock @@ -416,7 +417,8 @@ func (s *CCVTestSuite) TestDoubleSignDoesNotAffectThrottling() { consumerPacketData, err := provider.UnmarshalConsumerPacketData(data) // Same func used by provider's OnRecvPacket s.Require().NoError(err) packet := s.newPacketFromConsumer(data, uint64(sequence), firstBundle.Path, timeoutHeight, timeoutTimestamp) - providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData()) + _, err = providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData()) + s.Require().NoError(err) } // Execute block to handle packets in endblock @@ -806,7 +808,8 @@ func (s CCVTestSuite) TestSlashAllValidators() { //nolint:govet // this is a tes consumerPacketData, err := provider.UnmarshalConsumerPacketData(data) // Same func used by provider's OnRecvPacket s.Require().NoError(err) packet := s.newPacketFromConsumer(data, ibcSeqNum, s.getFirstBundle().Path, timeoutHeight, timeoutTimestamp) - providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData()) + _, err = providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData()) + s.Require().NoError(err) } // We should have 24 pending slash packet entries queued. diff --git a/tests/integration/valset_update.go b/tests/integration/valset_update.go index b12066afa8..89236dbd0c 100644 --- a/tests/integration/valset_update.go +++ b/tests/integration/valset_update.go @@ -71,9 +71,8 @@ func (suite *CCVTestSuite) TestQueueAndSendVSCMaturedPackets() { // send first packet packet := suite.newPacketFromProvider(pd.GetBytes(), 1, suite.path, clienttypes.NewHeight(1, 0), 0) - ack := consumerKeeper.OnRecvVSCPacket(suite.consumerChain.GetContext(), packet, pd) - suite.Require().NotNil(ack, "OnRecvVSCPacket did not return ack") - suite.Require().True(ack.Success(), "OnRecvVSCPacket did not return a Success Acknowledgment") + err = consumerKeeper.OnRecvVSCPacket(suite.consumerChain.GetContext(), packet, pd) + suite.Require().Nil(err, "OnRecvVSCPacket did return non-nil error") // increase time incrementTime(suite, time.Hour) @@ -83,9 +82,8 @@ func (suite *CCVTestSuite) TestQueueAndSendVSCMaturedPackets() { pd.ValsetUpdateId = 2 packet.Data = pd.GetBytes() packet.Sequence = 2 - ack = consumerKeeper.OnRecvVSCPacket(suite.consumerChain.GetContext(), packet, pd) - suite.Require().NotNil(ack, "OnRecvVSCPacket did not return ack") - suite.Require().True(ack.Success(), "OnRecvVSCPacket did not return a Success Acknowledgment") + err = consumerKeeper.OnRecvVSCPacket(suite.consumerChain.GetContext(), packet, pd) + suite.Require().Nil(err, "OnRecvVSCPacket did return non-nil error") // increase time incrementTime(suite, 24*time.Hour) @@ -95,9 +93,8 @@ func (suite *CCVTestSuite) TestQueueAndSendVSCMaturedPackets() { pd.ValsetUpdateId = 3 packet.Data = pd.GetBytes() packet.Sequence = 3 - ack = consumerKeeper.OnRecvVSCPacket(suite.consumerChain.GetContext(), packet, pd) - suite.Require().NotNil(ack, "OnRecvVSCPacket did not return ack") - suite.Require().True(ack.Success(), "OnRecvVSCPacket did not return a Success Acknowledgment") + err = consumerKeeper.OnRecvVSCPacket(suite.consumerChain.GetContext(), packet, pd) + suite.Require().Nil(err, "OnRecvVSCPacket did return non-nil error") packetMaturities := consumerKeeper.GetAllPacketMaturityTimes(suite.consumerChain.GetContext()) diff --git a/x/ccv/consumer/ibc_module.go b/x/ccv/consumer/ibc_module.go index 99f0141e64..d8f25b2173 100644 --- a/x/ccv/consumer/ibc_module.go +++ b/x/ccv/consumer/ibc_module.go @@ -2,6 +2,7 @@ package consumer import ( "fmt" + "strconv" "strings" transfertypes "github.com/cosmos/ibc-go/v7/modules/apps/transfer/types" @@ -225,25 +226,48 @@ func (am AppModule) OnRecvPacket( packet channeltypes.Packet, _ sdk.AccAddress, ) ibcexported.Acknowledgement { - var ( - ack ibcexported.Acknowledgement - data types.ValidatorSetChangePacketData - ) + logger := am.keeper.Logger(ctx) + ack := channeltypes.NewResultAcknowledgement([]byte{byte(1)}) + + var data types.ValidatorSetChangePacketData + var ackErr error if err := types.ModuleCdc.UnmarshalJSON(packet.GetData(), &data); err != nil { - errAck := types.NewErrorAcknowledgementWithLog(ctx, fmt.Errorf("cannot unmarshal CCV packet data")) - ack = &errAck - } else { - ack = am.keeper.OnRecvVSCPacket(ctx, packet, data) + ackErr = errorsmod.Wrapf(sdkerrors.ErrInvalidType, "cannot unmarshal VSCPacket data") + logger.Error(fmt.Sprintf("%s sequence %d", ackErr.Error(), packet.Sequence)) + ack = channeltypes.NewErrorAcknowledgement(ackErr) + } + + // only attempt the application logic if the packet data + // was successfully decoded + if ack.Success() { + err := am.keeper.OnRecvVSCPacket(ctx, packet, data) + if err != nil { + ack = channeltypes.NewErrorAcknowledgement(err) + ackErr = err + logger.Error(fmt.Sprintf("%s sequence %d", ackErr.Error(), packet.Sequence)) + } else { + logger.Info("successfully handled VSCPacket sequence: %d", packet.Sequence) + } + } + + eventAttributes := []sdk.Attribute{ + sdk.NewAttribute(sdk.AttributeKeyModule, types.ModuleName), + sdk.NewAttribute(types.AttributeValSetUpdateID, strconv.Itoa(int(data.ValsetUpdateId))), + sdk.NewAttribute(types.AttributeKeyAckSuccess, fmt.Sprintf("%t", ack.Success())), + } + + if ackErr != nil { + eventAttributes = append(eventAttributes, sdk.NewAttribute(types.AttributeKeyAckError, ackErr.Error())) } ctx.EventManager().EmitEvent( sdk.NewEvent( types.EventTypePacket, - sdk.NewAttribute(sdk.AttributeKeyModule, consumertypes.ModuleName), - sdk.NewAttribute(types.AttributeKeyAckSuccess, fmt.Sprintf("%t", ack != nil)), + eventAttributes..., ), ) + // NOTE: acknowledgement will be written synchronously during IBC handler execution. return ack } diff --git a/x/ccv/consumer/keeper/relay.go b/x/ccv/consumer/keeper/relay.go index 3afeb33188..68c911f255 100644 --- a/x/ccv/consumer/keeper/relay.go +++ b/x/ccv/consumer/keeper/relay.go @@ -6,7 +6,6 @@ import ( clienttypes "github.com/cosmos/ibc-go/v7/modules/core/02-client/types" channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types" - "github.com/cosmos/ibc-go/v7/modules/core/exported" errorsmod "cosmossdk.io/errors" @@ -25,7 +24,12 @@ import ( // // Note: CCV uses an ordered IBC channel, meaning VSC packet changes will be accumulated (and later // processed by ApplyCCValidatorChanges) s.t. more recent val power changes overwrite older ones. -func (k Keeper) OnRecvVSCPacket(ctx sdk.Context, packet channeltypes.Packet, newChanges ccv.ValidatorSetChangePacketData) exported.Acknowledgement { +func (k Keeper) OnRecvVSCPacket(ctx sdk.Context, packet channeltypes.Packet, newChanges ccv.ValidatorSetChangePacketData) error { + // validate packet data upon receiving + if err := newChanges.Validate(); err != nil { + return errorsmod.Wrapf(err, "error validating VSCPacket data") + } + // get the provider channel providerChannel, found := k.GetProviderChannel(ctx) if found && providerChannel != packet.DestinationChannel { @@ -87,8 +91,7 @@ func (k Keeper) OnRecvVSCPacket(ctx sdk.Context, packet channeltypes.Packet, new "len updates", len(newChanges.ValidatorUpdates), "len slash acks", len(newChanges.SlashAcks), ) - ack := channeltypes.NewResultAcknowledgement([]byte{byte(1)}) - return ack + return nil } // QueueVSCMaturedPackets appends matured VSCs to an internal queue. diff --git a/x/ccv/consumer/keeper/relay_test.go b/x/ccv/consumer/keeper/relay_test.go index d681a2fdc4..f9d03c68f0 100644 --- a/x/ccv/consumer/keeper/relay_test.go +++ b/x/ccv/consumer/keeper/relay_test.go @@ -74,31 +74,59 @@ func TestOnRecvVSCPacket(t *testing.T) { nil, ) + pd3 := types.NewValidatorSetChangePacketData( + []abci.ValidatorUpdate{}, + 3, + []string{ + "invalid_slash_ack", + }, + ) + testCases := []struct { name string + expError bool packet channeltypes.Packet - newChanges types.ValidatorSetChangePacketData expectedPendingChanges types.ValidatorSetChangePacketData }{ { "success on first packet", + false, channeltypes.NewPacket(pd.GetBytes(), 1, types.ProviderPortID, providerCCVChannelID, types.ConsumerPortID, consumerCCVChannelID, clienttypes.NewHeight(1, 0), 0), types.ValidatorSetChangePacketData{ValidatorUpdates: changes1}, - types.ValidatorSetChangePacketData{ValidatorUpdates: changes1}, }, { "success on subsequent packet", + false, channeltypes.NewPacket(pd.GetBytes(), 2, types.ProviderPortID, providerCCVChannelID, types.ConsumerPortID, consumerCCVChannelID, clienttypes.NewHeight(1, 0), 0), types.ValidatorSetChangePacketData{ValidatorUpdates: changes1}, - types.ValidatorSetChangePacketData{ValidatorUpdates: changes1}, }, { "success on packet with more changes", + false, channeltypes.NewPacket(pd2.GetBytes(), 3, types.ProviderPortID, providerCCVChannelID, types.ConsumerPortID, consumerCCVChannelID, clienttypes.NewHeight(1, 0), 0), - types.ValidatorSetChangePacketData{ValidatorUpdates: changes2}, + types.ValidatorSetChangePacketData{ValidatorUpdates: []abci.ValidatorUpdate{ + { + PubKey: pk1, + Power: 30, + }, + { + PubKey: pk2, + Power: 40, + }, + { + PubKey: pk3, + Power: 10, + }, + }}, + }, + { + "failure on packet with invalid slash acks", + true, + channeltypes.NewPacket(pd3.GetBytes(), 4, types.ProviderPortID, providerCCVChannelID, types.ConsumerPortID, consumerCCVChannelID, + clienttypes.NewHeight(1, 0), 0), types.ValidatorSetChangePacketData{ValidatorUpdates: []abci.ValidatorUpdate{ { PubKey: pk1, @@ -128,10 +156,16 @@ func TestOnRecvVSCPacket(t *testing.T) { consumerKeeper.SetParams(ctx, moduleParams) for _, tc := range testCases { - ack := consumerKeeper.OnRecvVSCPacket(ctx, tc.packet, tc.newChanges) + var newChanges types.ValidatorSetChangePacketData + err := types.ModuleCdc.UnmarshalJSON(tc.packet.GetData(), &newChanges) + require.Nil(t, err, "invalid test case: %s - cannot unmarshal VSCPacket data", tc.name) + err = consumerKeeper.OnRecvVSCPacket(ctx, tc.packet, newChanges) + if tc.expError { + require.Error(t, err, "%s - invalid but OnRecvVSCPacket did not return error", tc.name) + continue + } + require.NoError(t, err, "%s - valid but OnRecvVSCPacket returned error: %w", tc.name, err) - require.NotNil(t, ack, "invalid test case: %s did not return ack", tc.name) - require.True(t, ack.Success(), "invalid test case: %s did not return a Success Acknowledgment", tc.name) providerChannel, ok := consumerKeeper.GetProviderChannel(ctx) require.True(t, ok) require.Equal(t, tc.packet.DestinationChannel, providerChannel, @@ -152,7 +186,7 @@ func TestOnRecvVSCPacket(t *testing.T) { expectedTime := ctx.BlockTime().Add(consumerKeeper.GetUnbondingPeriod(ctx)) require.True( t, - consumerKeeper.PacketMaturityTimeExists(ctx, tc.newChanges.ValsetUpdateId, expectedTime), + consumerKeeper.PacketMaturityTimeExists(ctx, newChanges.ValsetUpdateId, expectedTime), "no packet maturity time for case: %s", tc.name, ) } @@ -199,9 +233,8 @@ func TestOnRecvVSCPacketDuplicateUpdates(t *testing.T) { require.False(t, ok) // Execute OnRecvVSCPacket - ack := consumerKeeper.OnRecvVSCPacket(ctx, packet, vscData) - require.NotNil(t, ack) - require.True(t, ack.Success()) + err := consumerKeeper.OnRecvVSCPacket(ctx, packet, vscData) + require.Nil(t, err) // Confirm pending changes are queued by OnRecvVSCPacket gotPendingChanges, ok := consumerKeeper.GetPendingChanges(ctx) diff --git a/x/ccv/provider/ibc_module.go b/x/ccv/provider/ibc_module.go index 75da9588d6..8d342b0c6f 100644 --- a/x/ccv/provider/ibc_module.go +++ b/x/ccv/provider/ibc_module.go @@ -2,6 +2,7 @@ package provider import ( "fmt" + "strconv" channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types" porttypes "github.com/cosmos/ibc-go/v7/modules/core/05-port/types" @@ -175,36 +176,67 @@ func (am AppModule) OnRecvPacket( packet channeltypes.Packet, _ sdk.AccAddress, ) ibcexported.Acknowledgement { + logger := am.keeper.Logger(ctx) + ack := channeltypes.NewResultAcknowledgement([]byte{byte(1)}) + + var ackErr error consumerPacket, err := UnmarshalConsumerPacket(packet) if err != nil { - errAck := ccv.NewErrorAcknowledgementWithLog(ctx, err) - return &errAck + ackErr = errorsmod.Wrapf(sdkerrors.ErrInvalidType, "cannot unmarshal ConsumerPacket data") + logger.Error(fmt.Sprintf("%s sequence %d", ackErr.Error(), packet.Sequence)) + ack = channeltypes.NewErrorAcknowledgement(ackErr) + } + + eventAttributes := []sdk.Attribute{ + sdk.NewAttribute(sdk.AttributeKeyModule, providertypes.ModuleName), + } + + // only attempt the application logic if the packet data + // was successfully decoded + if ack.Success() { + var err error + switch consumerPacket.Type { + case ccv.VscMaturedPacket: + // handle VSCMaturedPacket + data := *consumerPacket.GetVscMaturedPacketData() + err = am.keeper.OnRecvVSCMaturedPacket(ctx, packet, data) + if err == nil { + logger.Info("successfully handled VSCMaturedPacket sequence: %d", packet.Sequence) + eventAttributes = append(eventAttributes, sdk.NewAttribute(ccv.AttributeValSetUpdateID, strconv.Itoa(int(data.ValsetUpdateId)))) + } + case ccv.SlashPacket: + // handle SlashPacket + var ackResult ccv.PacketAckResult + data := *consumerPacket.GetSlashPacketData() + ackResult, err = am.keeper.OnRecvSlashPacket(ctx, packet, data) + if err == nil { + ack = channeltypes.NewResultAcknowledgement(ackResult) + logger.Info("successfully handled SlashPacket sequence: %d", packet.Sequence) + eventAttributes = append(eventAttributes, sdk.NewAttribute(ccv.AttributeValSetUpdateID, strconv.Itoa(int(data.ValsetUpdateId)))) + } + default: + err = fmt.Errorf("invalid consumer packet type: %q", consumerPacket.Type) + } + if err != nil { + ack = channeltypes.NewErrorAcknowledgement(err) + ackErr = err + logger.Error(fmt.Sprintf("%s sequence %d", ackErr.Error(), packet.Sequence)) + } } - // TODO: call ValidateBasic method on consumer packet data - // See: https://github.com/cosmos/interchain-security/issues/634 - - var ack ibcexported.Acknowledgement - switch consumerPacket.Type { - case ccv.VscMaturedPacket: - // handle VSCMaturedPacket - ack = am.keeper.OnRecvVSCMaturedPacket(ctx, packet, *consumerPacket.GetVscMaturedPacketData()) - case ccv.SlashPacket: - // handle SlashPacket - ack = am.keeper.OnRecvSlashPacket(ctx, packet, *consumerPacket.GetSlashPacketData()) - default: - errAck := ccv.NewErrorAcknowledgementWithLog(ctx, fmt.Errorf("invalid consumer packet type: %q", consumerPacket.Type)) - ack = &errAck + eventAttributes = append(eventAttributes, sdk.NewAttribute(ccv.AttributeKeyAckSuccess, fmt.Sprintf("%t", ack.Success()))) + if ackErr != nil { + eventAttributes = append(eventAttributes, sdk.NewAttribute(ccv.AttributeKeyAckError, ackErr.Error())) } ctx.EventManager().EmitEvent( sdk.NewEvent( ccv.EventTypePacket, - sdk.NewAttribute(sdk.AttributeKeyModule, providertypes.ModuleName), - sdk.NewAttribute(ccv.AttributeKeyAckSuccess, fmt.Sprintf("%t", ack != nil)), + eventAttributes..., ), ) + // NOTE: acknowledgement will be written synchronously during IBC handler execution. return ack } diff --git a/x/ccv/provider/keeper/relay.go b/x/ccv/provider/keeper/relay.go index 95782587aa..07c246ee12 100644 --- a/x/ccv/provider/keeper/relay.go +++ b/x/ccv/provider/keeper/relay.go @@ -6,7 +6,6 @@ import ( clienttypes "github.com/cosmos/ibc-go/v7/modules/core/02-client/types" channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types" - "github.com/cosmos/ibc-go/v7/modules/core/exported" errorsmod "cosmossdk.io/errors" @@ -22,7 +21,7 @@ func (k Keeper) OnRecvVSCMaturedPacket( ctx sdk.Context, packet channeltypes.Packet, data ccv.VSCMaturedPacketData, -) exported.Acknowledgement { +) error { // check that the channel is established, panic if not chainID, found := k.GetChannelToChain(ctx, packet.DestinationChannel) if !found { @@ -34,9 +33,14 @@ func (k Keeper) OnRecvVSCMaturedPacket( panic(fmt.Errorf("VSCMaturedPacket received on unknown channel %s", packet.DestinationChannel)) } + // validate packet data upon receiving + if err := data.Validate(); err != nil { + return errorsmod.Wrapf(err, "error validating VSCMaturedPacket data") + } + if err := k.QueueThrottledVSCMaturedPacketData(ctx, chainID, packet.Sequence, data); err != nil { - return ccv.NewErrorAcknowledgementWithLog(ctx, fmt.Errorf( - "failed to queue VSCMatured packet data: %s", err.Error())) + return errorsmod.Wrapf(err, + "failed to queue VSCMatured packet data") } k.Logger(ctx).Info("VSCMaturedPacket received and enqueued", @@ -44,8 +48,7 @@ func (k Keeper) OnRecvVSCMaturedPacket( "vscID", data.ValsetUpdateId, ) - ack := channeltypes.NewResultAcknowledgement(ccv.V1Result) - return ack + return nil } // HandleLeadingVSCMaturedPackets handles all VSCMatured packet data that has been queued this block, @@ -312,7 +315,11 @@ func (k Keeper) EndBlockCIS(ctx sdk.Context) { // OnRecvSlashPacket delivers a received slash packet, validates it and // then queues the slash packet as pending if valid. -func (k Keeper) OnRecvSlashPacket(ctx sdk.Context, packet channeltypes.Packet, data ccv.SlashPacketData) exported.Acknowledgement { +func (k Keeper) OnRecvSlashPacket( + ctx sdk.Context, + packet channeltypes.Packet, + data ccv.SlashPacketData, +) (ccv.PacketAckResult, error) { // check that the channel is established, panic if not chainID, found := k.GetChannelToChain(ctx, packet.DestinationChannel) if !found { @@ -324,6 +331,11 @@ func (k Keeper) OnRecvSlashPacket(ctx sdk.Context, packet channeltypes.Packet, d panic(fmt.Errorf("SlashPacket received on unknown channel %s", packet.DestinationChannel)) } + // validate packet data upon receiving + if err := data.Validate(); err != nil { + return nil, errorsmod.Wrapf(err, "error validating SlashPacket data") + } + if err := k.ValidateSlashPacket(ctx, chainID, packet, data); err != nil { k.Logger(ctx).Error("invalid slash packet", "error", err.Error(), @@ -332,7 +344,7 @@ func (k Keeper) OnRecvSlashPacket(ctx sdk.Context, packet channeltypes.Packet, d "vscID", data.ValsetUpdateId, "infractionType", data.Infraction, ) - return ccv.NewErrorAcknowledgementWithLog(ctx, err) + return nil, err } // The slash packet validator address may be known only on the consumer chain, @@ -355,7 +367,7 @@ func (k Keeper) OnRecvSlashPacket(ctx sdk.Context, packet channeltypes.Packet, d // return successful ack, as an error would result // in the consumer closing the CCV channel - return channeltypes.NewResultAcknowledgement(ccv.V1Result) + return ccv.V1Result, nil } // Queue a slash entry to the global queue, which will be seen by the throttling logic @@ -368,7 +380,7 @@ func (k Keeper) OnRecvSlashPacket(ctx sdk.Context, packet channeltypes.Packet, d // Queue slash packet data in the same (consumer chain specific) queue as vsc matured packet data, // to enforce order of handling between the two packet data types. if err := k.QueueThrottledSlashPacketData(ctx, chainID, packet.Sequence, data); err != nil { - return ccv.NewErrorAcknowledgementWithLog(ctx, fmt.Errorf("failed to queue slash packet data: %s", err.Error())) + return nil, err } k.Logger(ctx).Info("slash packet received and enqueued", @@ -379,7 +391,7 @@ func (k Keeper) OnRecvSlashPacket(ctx sdk.Context, packet channeltypes.Packet, d "infractionType", data.Infraction, ) - return channeltypes.NewResultAcknowledgement(ccv.V1Result) + return ccv.V1Result, nil } // ValidateSlashPacket validates a recv slash packet before it is @@ -395,10 +407,6 @@ func (k Keeper) ValidateSlashPacket(ctx sdk.Context, chainID string, "the validator update id %d for chain %s", data.ValsetUpdateId, chainID) } - if data.Infraction != stakingtypes.Infraction_INFRACTION_DOUBLE_SIGN && data.Infraction != stakingtypes.Infraction_INFRACTION_DOWNTIME { - return fmt.Errorf("invalid infraction type: %s", data.Infraction) - } - return nil } diff --git a/x/ccv/provider/keeper/relay_test.go b/x/ccv/provider/keeper/relay_test.go index d3fdcaa21e..7fb5d7157d 100644 --- a/x/ccv/provider/keeper/relay_test.go +++ b/x/ccv/provider/keeper/relay_test.go @@ -7,7 +7,6 @@ import ( clienttypes "github.com/cosmos/ibc-go/v7/modules/core/02-client/types" channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types" - exported "github.com/cosmos/ibc-go/v7/modules/core/exported" ibctesting "github.com/cosmos/ibc-go/v7/testing" "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" @@ -107,8 +106,8 @@ func TestOnRecvVSCMaturedPacket(t *testing.T) { providerKeeper.SetChannelToChain(ctx, "channel-2", "chain-2") // Execute on recv for chain-1 - ack := executeOnRecvVSCMaturedPacket(t, &providerKeeper, ctx, "channel-1", 1) - require.Equal(t, channeltypes.NewResultAcknowledgement([]byte{byte(1)}), ack) + err := executeOnRecvVSCMaturedPacket(t, &providerKeeper, ctx, "channel-1", 1) + require.NoError(t, err) // Assert that the packet data was queued for chain-1 require.Equal(t, uint64(1), providerKeeper.GetThrottledPacketDataSize(ctx, "chain-1")) @@ -118,10 +117,10 @@ func TestOnRecvVSCMaturedPacket(t *testing.T) { // Now queue a slash packet data instance for chain-2, then confirm the on recv method // queues the vsc matured behind the slash packet data - err := providerKeeper.QueueThrottledSlashPacketData(ctx, "chain-2", 1, testkeeper.GetNewSlashPacketData()) + err = providerKeeper.QueueThrottledSlashPacketData(ctx, "chain-2", 1, testkeeper.GetNewSlashPacketData()) + require.NoError(t, err) + err = executeOnRecvVSCMaturedPacket(t, &providerKeeper, ctx, "channel-2", 2) require.NoError(t, err) - ack = executeOnRecvVSCMaturedPacket(t, &providerKeeper, ctx, "channel-2", 2) - require.Equal(t, channeltypes.NewResultAcknowledgement([]byte{byte(1)}), ack) require.Equal(t, uint64(2), providerKeeper.GetThrottledPacketDataSize(ctx, "chain-2")) // Chain-1 still has 1 packet data queued @@ -129,8 +128,8 @@ func TestOnRecvVSCMaturedPacket(t *testing.T) { // Receive 5 more vsc matured packets for chain-2, then confirm chain-2 queue size is 7, chain-1 still size 1 for i := 0; i < 5; i++ { - ack = executeOnRecvVSCMaturedPacket(t, &providerKeeper, ctx, "channel-2", uint64(i+3)) - require.Equal(t, channeltypes.NewResultAcknowledgement([]byte{byte(1)}), ack) + err := executeOnRecvVSCMaturedPacket(t, &providerKeeper, ctx, "channel-2", uint64(i+3)) + require.NoError(t, err) } require.Equal(t, uint64(7), providerKeeper.GetThrottledPacketDataSize(ctx, "chain-2")) require.Equal(t, uint64(1), providerKeeper.GetThrottledPacketDataSize(ctx, "chain-1")) @@ -138,6 +137,15 @@ func TestOnRecvVSCMaturedPacket(t *testing.T) { // Delete chain-2's data from its queue, then confirm the queue size is 0 providerKeeper.DeleteThrottledPacketData(ctx, "chain-2", []uint64{1, 2, 3, 4, 5, 6, 7}...) require.Equal(t, uint64(0), providerKeeper.GetThrottledPacketDataSize(ctx, "chain-2")) + + // Execute on recv for chain-1, confirm v1 result ack is returned + err = executeOnRecvVSCMaturedPacket(t, &providerKeeper, ctx, "channel-1", 1) + require.NoError(t, err) + + // Now queue a slash packet data instance for chain-2, confirm v1 result ack is returned + err = executeOnRecvVSCMaturedPacket(t, &providerKeeper, ctx, "channel-2", 2) + require.NoError(t, err) + } func TestHandleLeadingVSCMaturedPackets(t *testing.T) { @@ -182,6 +190,7 @@ func TestHandleLeadingVSCMaturedPackets(t *testing.T) { err = providerKeeper.QueueThrottledSlashPacketData(ctx, "chain-2", 3, testkeeper.GetNewSlashPacketData()) require.NoError(t, err) err = providerKeeper.QueueThrottledSlashPacketData(ctx, "chain-2", 4, testkeeper.GetNewSlashPacketData()) + require.NoError(t, err) // And one more trailing vsc matured packet for chain-2 @@ -248,8 +257,9 @@ func TestOnRecvDoubleSignSlashPacket(t *testing.T) { providerKeeper.SetValsetUpdateBlockHeight(ctx, packetData.ValsetUpdateId, uint64(15)) // Receive the double-sign slash packet for chain-1 and confirm the expected acknowledgement - ack := executeOnRecvSlashPacket(t, &providerKeeper, ctx, "channel-1", 1, packetData) - require.Equal(t, channeltypes.NewResultAcknowledgement([]byte{byte(1)}), ack) + ackResult, err := executeOnRecvSlashPacket(t, &providerKeeper, ctx, "channel-1", 1, packetData) + require.Equal(t, ccv.V1Result, ackResult) + require.NoError(t, err) // Nothing should be queued require.Equal(t, uint64(0), providerKeeper.GetThrottledPacketDataSize(ctx, "chain-1")) @@ -283,8 +293,9 @@ func TestOnRecvDowntimeSlashPacket(t *testing.T) { // Receive the downtime slash packet for chain-1 at time.Now() ctx = ctx.WithBlockTime(time.Now()) - ack := executeOnRecvSlashPacket(t, &providerKeeper, ctx, "channel-1", 1, packetData) - require.Equal(t, channeltypes.NewResultAcknowledgement([]byte{byte(1)}), ack) + ackResult, err := executeOnRecvSlashPacket(t, &providerKeeper, ctx, "channel-1", 1, packetData) + require.Equal(t, ccv.V1Result, ackResult) + require.NoError(t, err) // Confirm an entry was added to the global queue, and pending packet data was added to the per-chain queue globalEntries := providerKeeper.GetAllGlobalSlashEntries(ctx) // parent queue @@ -301,8 +312,9 @@ func TestOnRecvDowntimeSlashPacket(t *testing.T) { // Receive a downtime slash packet for chain-2 at time.Now(Add(1 *time.Hour)) ctx = ctx.WithBlockTime(time.Now().Add(1 * time.Hour)) - ack = executeOnRecvSlashPacket(t, &providerKeeper, ctx, "channel-2", 2, packetData) - require.Equal(t, channeltypes.NewResultAcknowledgement([]byte{byte(1)}), ack) + ackResult, err = executeOnRecvSlashPacket(t, &providerKeeper, ctx, "channel-2", 2, packetData) + require.Equal(t, ccv.V1Result, ackResult) + require.NoError(t, err) // Confirm sizes of parent queue and both per-chain queues globalEntries = providerKeeper.GetAllGlobalSlashEntries(ctx) @@ -315,7 +327,7 @@ func TestOnRecvDowntimeSlashPacket(t *testing.T) { func executeOnRecvVSCMaturedPacket(t *testing.T, providerKeeper *keeper.Keeper, ctx sdk.Context, channelID string, ibcSeqNum uint64, -) exported.Acknowledgement { +) error { t.Helper() // Instantiate vsc matured packet data and bytes data := testkeeper.GetNewVSCMaturedPacketData() @@ -331,7 +343,7 @@ func executeOnRecvVSCMaturedPacket(t *testing.T, providerKeeper *keeper.Keeper, func executeOnRecvSlashPacket(t *testing.T, providerKeeper *keeper.Keeper, ctx sdk.Context, channelID string, ibcSeqNum uint64, packetData ccv.SlashPacketData, -) exported.Acknowledgement { +) (ccv.PacketAckResult, error) { t.Helper() // Instantiate slash packet data and bytes dataBz, err := packetData.Marshal() @@ -358,16 +370,6 @@ func TestValidateSlashPacket(t *testing.T) { ccv.SlashPacketData{ValsetUpdateId: 61}, true, }, - { - "non-set infraction type", - ccv.SlashPacketData{ValsetUpdateId: validVscID}, - true, - }, - { - "invalid infraction type", - ccv.SlashPacketData{ValsetUpdateId: validVscID, Infraction: stakingtypes.MaxMonikerLength}, - true, - }, { "valid double sign packet with non-zero vscID", ccv.SlashPacketData{ValsetUpdateId: validVscID, Infraction: stakingtypes.Infraction_INFRACTION_DOUBLE_SIGN}, diff --git a/x/ccv/provider/types/errors.go b/x/ccv/provider/types/errors.go index 66d7a9a3b8..6c19a7b396 100644 --- a/x/ccv/provider/types/errors.go +++ b/x/ccv/provider/types/errors.go @@ -11,7 +11,7 @@ var ( ErrUnknownConsumerChainId = errorsmod.Register(ModuleName, 3, "no consumer chain with this chain id") ErrUnknownConsumerChannelId = errorsmod.Register(ModuleName, 4, "no consumer chain with this channel id") ErrInvalidConsumerConsensusPubKey = errorsmod.Register(ModuleName, 5, "empty consumer consensus public key") - ErrBlankConsumerChainID = errorsmod.Register(ModuleName, 6, "consumer chain id must not be blank") + ErrInvalidConsumerChainID = errorsmod.Register(ModuleName, 6, "invalid consumer chain id") ErrConsumerKeyNotFound = errorsmod.Register(ModuleName, 7, "consumer key not found") ErrNoValidatorConsumerAddress = errorsmod.Register(ModuleName, 8, "error getting validator consumer address") ErrNoValidatorProviderAddress = errorsmod.Register(ModuleName, 9, "error getting validator provider address") diff --git a/x/ccv/provider/types/msg.go b/x/ccv/provider/types/msg.go index f7ee11325c..73fb7f5b74 100644 --- a/x/ccv/provider/types/msg.go +++ b/x/ccv/provider/types/msg.go @@ -72,14 +72,13 @@ func (msg MsgAssignConsumerKey) GetSignBytes() []byte { // ValidateBasic implements the sdk.Msg interface. func (msg MsgAssignConsumerKey) ValidateBasic() error { if strings.TrimSpace(msg.ChainId) == "" { - return ErrBlankConsumerChainID + return errorsmod.Wrapf(ErrInvalidConsumerChainID, "chainId cannot be blank") } // It is possible to assign keys for consumer chains that are not yet approved. // This can only be done by a signing validator, but it is still sensible // to limit the chainID size to prevent abuse. - if 128 < len(msg.ChainId) { - return ErrBlankConsumerChainID + return errorsmod.Wrapf(ErrInvalidConsumerChainID, "chainId cannot exceed 128 length") } _, err := sdk.ValAddressFromBech32(msg.ProviderAddr) if err != nil { diff --git a/x/ccv/types/wire.go b/x/ccv/types/wire.go index 5b4e57994f..c7cbe9e126 100644 --- a/x/ccv/types/wire.go +++ b/x/ccv/types/wire.go @@ -5,6 +5,7 @@ import ( errorsmod "cosmossdk.io/errors" + sdk "github.com/cosmos/cosmos-sdk/types" stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" abci "github.com/cometbft/cometbft/abci/types" @@ -18,14 +19,24 @@ func NewValidatorSetChangePacketData(valUpdates []abci.ValidatorUpdate, valUpdat } } -// ValidateBasic is used for validating the CCV packet data. -func (vsc ValidatorSetChangePacketData) ValidateBasic() error { - if len(vsc.ValidatorUpdates) == 0 { - return errorsmod.Wrap(ErrInvalidPacketData, "validator updates cannot be empty") +// Validate is used for validating the CCV packet data. +func (vsc ValidatorSetChangePacketData) Validate() error { + // Note that vsc.ValidatorUpdates can be empty in the case of unbonding + // operations w/o changes in the voting power of the validators in the validator set + if vsc.ValidatorUpdates == nil { + return errorsmod.Wrap(ErrInvalidPacketData, "validator updates cannot be nil") } + // ValsetUpdateId is strictly positive if vsc.ValsetUpdateId == 0 { return errorsmod.Wrap(ErrInvalidPacketData, "valset update id cannot be equal to zero") } + // Validate the slash acks - must be consensus addresses + for _, ack := range vsc.SlashAcks { + _, err := sdk.ConsAddressFromBech32(ack) + if err != nil { + return err + } + } return nil } @@ -42,8 +53,9 @@ func NewVSCMaturedPacketData(valUpdateID uint64) *VSCMaturedPacketData { } } -// ValidateBasic is used for validating the VSCMatured packet data. -func (mat VSCMaturedPacketData) ValidateBasic() error { +// Validate is used for validating the VSCMatured packet data. +func (mat VSCMaturedPacketData) Validate() error { + // ValsetUpdateId is strictly positive if mat.ValsetUpdateId == 0 { return errorsmod.Wrap(ErrInvalidPacketData, "vscId cannot be equal to zero") } @@ -75,13 +87,19 @@ func NewSlashPacketDataV1(validator abci.Validator, valUpdateId uint64, infracti } } -func (vdt SlashPacketData) ValidateBasic() error { - if len(vdt.Validator.Address) == 0 || vdt.Validator.Power == 0 { - return errorsmod.Wrap(ErrInvalidPacketData, "validator fields cannot be empty") +func (vdt SlashPacketData) Validate() error { + // vdt.Validator.Address must be a consensus address + if err := sdk.VerifyAddressFormat(vdt.Validator.Address); err != nil { + return errorsmod.Wrap(ErrInvalidPacketData, fmt.Sprintf("invalid validator: %s", err.Error())) + } + // vdt.Validator.Power must be positive + if vdt.Validator.Power == 0 { + return errorsmod.Wrap(ErrInvalidPacketData, "validator power cannot be zero") } + // Note that ValsetUpdateId can be zero due to the vscID mapping - if vdt.Infraction == stakingtypes.Infraction_INFRACTION_UNSPECIFIED { - return errorsmod.Wrap(ErrInvalidPacketData, "invalid infraction type") + if vdt.Infraction != stakingtypes.Infraction_INFRACTION_DOUBLE_SIGN && vdt.Infraction != stakingtypes.Infraction_INFRACTION_DOWNTIME { + return errorsmod.Wrap(ErrInvalidPacketData, fmt.Sprintf("invalid infraction type: %s", vdt.Infraction.String())) } return nil @@ -91,22 +109,22 @@ func (vdt SlashPacketData) ToV1() *SlashPacketDataV1 { return NewSlashPacketDataV1(vdt.Validator, vdt.ValsetUpdateId, vdt.Infraction) } -func (cp ConsumerPacketData) ValidateBasic() (err error) { +func (cp ConsumerPacketData) Validate() (err error) { switch cp.Type { case VscMaturedPacket: // validate VSCMaturedPacket - vscPacket := cp.GetVscMaturedPacketData() - if vscPacket == nil { + vscMaturedPacket := cp.GetVscMaturedPacketData() + if vscMaturedPacket == nil { return fmt.Errorf("invalid consumer packet data: VscMaturePacketData data cannot be empty") } - err = vscPacket.ValidateBasic() + err = vscMaturedPacket.Validate() case SlashPacket: // validate SlashPacket slashPacket := cp.GetSlashPacketData() if slashPacket == nil { return fmt.Errorf("invalid consumer packet data: SlashPacketData data cannot be empty") } - err = slashPacket.ValidateBasic() + err = slashPacket.Validate() default: err = fmt.Errorf("invalid consumer packet type: %q", cp.Type) } diff --git a/x/ccv/types/wire_test.go b/x/ccv/types/wire_test.go index 50164fcf73..f97d1af1d2 100644 --- a/x/ccv/types/wire_test.go +++ b/x/ccv/types/wire_test.go @@ -22,23 +22,75 @@ func TestPacketDataValidateBasic(t *testing.T) { pk2, err := cryptocodec.ToTmProtoPublicKey(ed25519.GenPrivKey().PubKey()) require.NoError(t, err) + cId := crypto.NewCryptoIdentityFromIntSeed(4732894342) + validSlashAck := cId.SDKValConsAddress().String() + tooLongSlashAck := string(make([]byte, 1024)) + cases := []struct { name string expError bool packetData types.ValidatorSetChangePacketData }{ { - "nil packet data", + "invalid: nil packet data", true, types.NewValidatorSetChangePacketData(nil, 1, nil), }, { - "empty packet data", - true, + "valid: empty packet data", + false, types.NewValidatorSetChangePacketData([]abci.ValidatorUpdate{}, 2, nil), }, { - "valid packet data", + "invalid: slash ack not consensus address", + true, + types.NewValidatorSetChangePacketData( + []abci.ValidatorUpdate{ + { + PubKey: pk1, + Power: 30, + }, + }, + 3, + []string{ + "some_string", + }, + ), + }, + { + "valid: packet data with valid slash ack", + false, + types.NewValidatorSetChangePacketData( + []abci.ValidatorUpdate{ + { + PubKey: pk2, + Power: 20, + }, + }, + 4, + []string{ + validSlashAck, + }, + ), + }, + { + "invalid: slash ack is too long", + true, + types.NewValidatorSetChangePacketData( + []abci.ValidatorUpdate{ + { + PubKey: pk2, + Power: 20, + }, + }, + 5, + []string{ + tooLongSlashAck, + }, + ), + }, + { + "valid: packet data with nil slash ack", false, types.NewValidatorSetChangePacketData( []abci.ValidatorUpdate{ @@ -51,18 +103,18 @@ func TestPacketDataValidateBasic(t *testing.T) { Power: 20, }, }, - 3, + 6, nil, ), }, } for _, c := range cases { - err := c.packetData.ValidateBasic() + err := c.packetData.Validate() if c.expError { - require.Error(t, err, "%s invalid but passed ValidateBasic", c.name) + require.Error(t, err, "%s invalid but passed Validate", c.name) } else { - require.NoError(t, err, "%s valid but ValidateBasic returned error: %w", c.name, err) + require.NoError(t, err, "%s valid but Validate returned error: %w", c.name, err) } } }