diff --git a/x/ccv/provider/keeper/msg_server.go b/x/ccv/provider/keeper/msg_server.go index 0fccc43575..d55a8a1201 100644 --- a/x/ccv/provider/keeper/msg_server.go +++ b/x/ccv/provider/keeper/msg_server.go @@ -474,7 +474,13 @@ func (k msgServer) UpdateConsumer(goCtx context.Context, msg *types.MsgUpdateCon return &resp, errorsmod.Wrapf(ccvtypes.ErrInvalidConsumerState, "cannot get consumer chain ID: %s", err.Error()) } + // We only validate and use `NewChainId` if it is not empty (because `NewChainId` is an optional argument) + // or `NewChainId` is different from the current chain id of the consumer chain. if strings.TrimSpace(msg.NewChainId) != "" && msg.NewChainId != chainId { + if err = types.ValidateChainId("NewChainId", msg.NewChainId); err != nil { + return &resp, errorsmod.Wrapf(types.ErrInvalidMsgUpdateConsumer, "invalid new chain id: %s", err.Error()) + } + if k.IsConsumerPrelaunched(ctx, consumerId) { chainId = msg.NewChainId k.SetConsumerChainId(ctx, consumerId, chainId) @@ -521,14 +527,13 @@ func (k msgServer) UpdateConsumer(goCtx context.Context, msg *types.MsgUpdateCon previousSpawnTime := previousInitializationParameters.SpawnTime if msg.InitializationParameters != nil { - phase := k.GetConsumerPhase(ctx, consumerId) - - if phase == types.CONSUMER_PHASE_LAUNCHED { + if !k.IsConsumerPrelaunched(ctx, consumerId) { return &resp, errorsmod.Wrap(types.ErrInvalidMsgUpdateConsumer, "cannot update the initialization parameters of an an already launched chain; "+ "do not provide any initialization parameters when updating a launched chain") } + phase := k.GetConsumerPhase(ctx, consumerId) if msg.InitializationParameters.SpawnTime.IsZero() { if phase == types.CONSUMER_PHASE_INITIALIZED { // chain was previously ready to launch at `previousSpawnTime` so we remove the diff --git a/x/ccv/provider/keeper/msg_server_test.go b/x/ccv/provider/keeper/msg_server_test.go index 6fe265eeea..9e60ab484f 100644 --- a/x/ccv/provider/keeper/msg_server_test.go +++ b/x/ccv/provider/keeper/msg_server_test.go @@ -122,6 +122,17 @@ func TestUpdateConsumer(t *testing.T) { require.NoError(t, err) require.Equal(t, expectedChainId, chainId) + // assert that we cannot change the chain to that of a reserved chain id + _, err = msgServer.UpdateConsumer(ctx, + &providertypes.MsgUpdateConsumer{ + Owner: "submitter", ConsumerId: consumerId, + Metadata: nil, + InitializationParameters: nil, + PowerShapingParameters: nil, + NewChainId: "stride-1", // reversed chain id + }) + require.ErrorContains(t, err, "cannot use a reserved chain id") + expectedConsumerMetadata := providertypes.ConsumerMetadata{ Name: "name2", Description: "description2", diff --git a/x/ccv/provider/types/msg.go b/x/ccv/provider/types/msg.go index 0128cc9f5f..e5f36d2bc6 100644 --- a/x/ccv/provider/types/msg.go +++ b/x/ccv/provider/types/msg.go @@ -295,20 +295,35 @@ func NewMsgCreateConsumer(submitter, chainId string, metadata ConsumerMetadata, }, nil } -// ValidateBasic implements the sdk.HasValidateBasic interface. -func (msg MsgCreateConsumer) ValidateBasic() error { - if err := ValidateStringField("ChainId", msg.ChainId, cmttypes.MaxChainIDLen); err != nil { - return errorsmod.Wrapf(ErrInvalidMsgCreateConsumer, "ChainId: %s", err.Error()) - } - +// IsReservedChainId returns true if the specific chain id is reserved and cannot be used by other consumer chains +func IsReservedChainId(chainId string) bool { // With permissionless ICS, we can have multiple consumer chains with the exact same chain id. // However, as we already have the Neutron and Stride Top N chains running, as a first step we would like to // prevent permissionless chains from re-using the chain ids of Neutron and Stride. Note that this is just a // preliminary measure that will be removed later on as part of: // TODO (#2242): find a better way of ignoring past misbehaviors - if msg.ChainId == "neutron-1" || msg.ChainId == "stride-1" { - return errorsmod.Wrapf(ErrInvalidMsgCreateConsumer, - "cannot reuse chain ids of existing Neutron and Stride Top N consumer chains") + return chainId == "neutron-1" || chainId == "stride-1" +} + +// ValidateChainId validates that the chain id is valid and is not reserved. +// Can be called for the `MsgUpdateConsumer.NewChainId` field as well, so this method takes the `field` as an argument +// to return more appropriate error messages in case the validation fails. +func ValidateChainId(field string, chainId string) error { + if err := ValidateStringField(field, chainId, cmttypes.MaxChainIDLen); err != nil { + return errorsmod.Wrapf(ErrInvalidMsgCreateConsumer, "%s: %s", field, err.Error()) + } + + if IsReservedChainId(chainId) { + return errorsmod.Wrapf(ErrInvalidMsgCreateConsumer, "cannot use a reserved chain id") + } + + return nil +} + +// ValidateBasic implements the sdk.HasValidateBasic interface. +func (msg MsgCreateConsumer) ValidateBasic() error { + if err := ValidateChainId("ChainId", msg.ChainId); err != nil { + return errorsmod.Wrapf(ErrInvalidMsgCreateConsumer, "ChainId: %s", err.Error()) } if err := ValidateConsumerMetadata(msg.Metadata); err != nil { @@ -389,9 +404,10 @@ func (msg MsgUpdateConsumer) ValidateBasic() error { } } - if msg.NewChainId != "" && len(msg.NewChainId) > cmttypes.MaxChainIDLen { - return errorsmod.Wrapf(ErrInvalidMsgUpdateConsumer, "NewChainId (%s) is too long; got: %d, max: %d", - msg.NewChainId, len(msg.NewChainId), cmttypes.MaxChainIDLen) + if strings.TrimSpace(msg.NewChainId) != "" { + if err := ValidateStringField("NewChainId", msg.NewChainId, cmttypes.MaxChainIDLen); err != nil { + return errorsmod.Wrapf(ErrInvalidMsgUpdateConsumer, "NewChainId: %s", err.Error()) + } } return nil diff --git a/x/ccv/provider/types/msg_test.go b/x/ccv/provider/types/msg_test.go index 87ee0bb0fd..2a0adc70a1 100644 --- a/x/ccv/provider/types/msg_test.go +++ b/x/ccv/provider/types/msg_test.go @@ -1,6 +1,7 @@ package types_test import ( + "strings" "testing" "time" @@ -554,7 +555,7 @@ func TestMsgUpdateConsumerValidateBasic(t *testing.T) { { "too long new chain id", types.PowerShapingParameters{}, - "this is an extremely long chain id that is so long that the validation would fail", + strings.Repeat("thisIsAnExtremelyLongChainId", 2), false, }, } @@ -725,3 +726,50 @@ func TestValidateInitialHeight(t *testing.T) { } } } + +func TestValidateChainId(t *testing.T) { + testCases := []struct { + name string + chainId string + expPass bool + }{ + { + name: "valid chain id", + chainId: "chain-1", + expPass: true, + }, + { + name: "valid chain id with no revision", + chainId: "chainId", + expPass: true, + }, + { + name: "invalid (too long) chain id", + chainId: strings.Repeat("thisIsAnExtremelyLongChainId", 2), + expPass: false, + }, + { + name: "reserved chain id", + chainId: "stride-1", + expPass: false, + }, + { + name: "reserved chain id", + chainId: "neutron-1", + expPass: false, + }, + { + name: "empty chain id", + chainId: " ", + expPass: false, + }, + } + for _, tc := range testCases { + err := types.ValidateChainId("ChainId", tc.chainId) + if tc.expPass { + require.NoError(t, err, "valid case: '%s' should not return error. got %w", tc.name, err) + } else { + require.Error(t, err, "invalid case: '%s' must return error but got none", tc.name) + } + } +}