Skip to content

Commit

Permalink
add UT + doc
Browse files Browse the repository at this point in the history
  • Loading branch information
sainoe committed Feb 20, 2024
1 parent 3d38245 commit ed4233d
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 32 deletions.
10 changes: 0 additions & 10 deletions x/ccv/provider/ibc_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,6 @@ func TestGetProviderDenom(t *testing.T) {
},
}

channeltypes.NewPacket(
[]byte{},
0,
"srcPort",
"srcChannel",
"dstPort",
"dstChannel",
clienttypes.NewHeight(1, 1),
0,
)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
res := provider.GetProviderDenom(
Expand Down
26 changes: 26 additions & 0 deletions x/ccv/provider/keeper/distribution.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package keeper

import (
errorsmod "cosmossdk.io/errors"
"cosmossdk.io/math"
abci "github.com/cometbft/cometbft/abci/types"
sdk "github.com/cosmos/cosmos-sdk/types"
channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types"

distrtypes "github.com/cosmos/cosmos-sdk/x/distribution/types"
"github.com/cosmos/interchain-security/v4/x/ccv/provider/types"
Expand Down Expand Up @@ -242,3 +244,27 @@ func (k Keeper) ComputeConsumerTotalVotingPower(ctx sdk.Context, chainID string,

return totalPower
}

// IdentifyConsumerChainIDFromIBCPacket checks if the packet destination matches a registered consumer chain.
// If so, it returns the consumer chain ID, otherwise an error.
func (k Keeper) IdentifyConsumerChainIDFromIBCPacket(ctx sdk.Context, packet channeltypes.Packet) (string, error) {
channel, ok := k.channelKeeper.GetChannel(ctx, packet.DestinationPort, packet.DestinationChannel)
if !ok {
return "", errorsmod.Wrapf(channeltypes.ErrChannelNotFound, "channel not found for channel ID: %s", packet.DestinationChannel)
}
if len(channel.ConnectionHops) != 1 {
return "", errorsmod.Wrap(channeltypes.ErrTooManyConnectionHops, "must have direct connection to consumer chain")
}
connectionID := channel.ConnectionHops[0]
_, tmClient, err := k.getUnderlyingClient(ctx, connectionID)
if err != nil {
return "", err
}

chainID := tmClient.ChainId
if _, ok := k.GetChainToChannel(ctx, chainID); !ok {
return "", errorsmod.Wrapf(types.ErrUnknownConsumerChannelId, "no CCV channel found for chain with ID: %s", chainID)
}

return chainID, nil
}
178 changes: 178 additions & 0 deletions x/ccv/provider/keeper/distribution_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@ import (
abci "github.com/cometbft/cometbft/abci/types"
tmtypes "github.com/cometbft/cometbft/types"
sdk "github.com/cosmos/cosmos-sdk/types"
clienttypes "github.com/cosmos/ibc-go/v7/modules/core/02-client/types"
conntypes "github.com/cosmos/ibc-go/v7/modules/core/03-connection/types"
channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types"
ibctmtypes "github.com/cosmos/ibc-go/v7/modules/light-clients/07-tendermint"
testkeeper "github.com/cosmos/interchain-security/v4/testutil/keeper"
"github.com/cosmos/interchain-security/v4/x/ccv/provider/types"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -58,3 +63,176 @@ func TestComputeConsumerTotalVotingPower(t *testing.T) {

require.Equal(t, expTotalPower, res)
}

func TestIdentifyConsumerChainIDFromIBCPacket(t *testing.T) {

var (
chainID = "consumer"
ccvChannel = "channel-0"
)

testCases := []struct {
name string
packet channeltypes.Packet
expectedCalls func(sdk.Context, testkeeper.MockedKeepers, channeltypes.Packet) []*gomock.Call
expCCVChannel bool
expErr bool
}{
{
"channel not found",
channeltypes.NewPacket(
[]byte{},
0,
"srcPort",
"srcChannel",
"dstPort",
"dstChannel",
clienttypes.NewHeight(1, 1),
0,
),
func(ctx sdk.Context, mocks testkeeper.MockedKeepers, packet channeltypes.Packet) []*gomock.Call {
return []*gomock.Call{
mocks.MockChannelKeeper.EXPECT().GetChannel(
ctx,
packet.DestinationPort,
packet.DestinationChannel,
).Return(channeltypes.Channel{}, false).Times(1),
}
},
false,
true,
},
{
"connection hops can't be greater than 1",
channeltypes.NewPacket(
[]byte{},
0,
"srcPort",
"srcChannel",
"dstPort",
"dstChannel",
clienttypes.NewHeight(1, 1),
0,
),
func(ctx sdk.Context, mocks testkeeper.MockedKeepers, packet channeltypes.Packet) []*gomock.Call {
return []*gomock.Call{
mocks.MockChannelKeeper.EXPECT().GetChannel(
ctx,
packet.DestinationPort,
packet.DestinationChannel,
).Return(channeltypes.Channel{ConnectionHops: []string{"conn1", "conn2"}}, true).Times(1),
}
},
false,
true,
},
{
"underlying client isn't found",
channeltypes.NewPacket(
[]byte{},
0,
"srcPort",
"srcChannel",
"dstPort",
"dstChannel",
clienttypes.NewHeight(1, 1),
0,
),
func(ctx sdk.Context, mocks testkeeper.MockedKeepers, packet channeltypes.Packet) []*gomock.Call {
return []*gomock.Call{
mocks.MockChannelKeeper.EXPECT().GetChannel(
ctx,
packet.DestinationPort,
packet.DestinationChannel,
).Return(channeltypes.Channel{ConnectionHops: []string{"connectionID"}}, true).Times(1),
mocks.MockConnectionKeeper.EXPECT().GetConnection(ctx, "connectionID").Return(
conntypes.ConnectionEnd{ClientId: "clientID"}, true,
).Times(1),
mocks.MockClientKeeper.EXPECT().GetClientState(ctx, "clientID").Return(
&ibctmtypes.ClientState{ChainId: ""}, false,
).Times(1),
}
},
false,
true,
},
{
"no CCV channel registered",
channeltypes.NewPacket(
[]byte{},
0,
"srcPort",
"srcChannel",
"dstPort",
"dstChannel",
clienttypes.NewHeight(1, 1),
0,
),
func(ctx sdk.Context, mocks testkeeper.MockedKeepers, packet channeltypes.Packet) []*gomock.Call {
return []*gomock.Call{
mocks.MockChannelKeeper.EXPECT().GetChannel(
ctx,
packet.DestinationPort,
packet.DestinationChannel,
).Return(channeltypes.Channel{ConnectionHops: []string{"connectionID"}}, true).Times(1),
mocks.MockConnectionKeeper.EXPECT().GetConnection(ctx, "connectionID").Return(
conntypes.ConnectionEnd{ClientId: "clientID"}, true,
).Times(1),
mocks.MockClientKeeper.EXPECT().GetClientState(ctx, "clientID").Return(
&ibctmtypes.ClientState{ChainId: chainID}, true,
).Times(1),
}
},
false,
true,
},
{
"consumer chain identified",
channeltypes.NewPacket(
[]byte{},
0,
"srcPort",
"srcChannel",
"dstPort",
"dstChannel",
clienttypes.NewHeight(1, 1),
0,
),
func(ctx sdk.Context, mocks testkeeper.MockedKeepers, packet channeltypes.Packet) []*gomock.Call {
return []*gomock.Call{
mocks.MockChannelKeeper.EXPECT().GetChannel(
ctx,
packet.DestinationPort,
packet.DestinationChannel,
),
}
},
false,
true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {

keeper, ctx, ctrl, mocks := testkeeper.GetProviderKeeperAndCtx(t, testkeeper.NewInMemKeeperParams(t))
defer ctrl.Finish()

tc.expectedCalls(ctx, mocks, tc.packet)
_, err := keeper.IdentifyConsumerChainIDFromIBCPacket(
ctx,
tc.packet,
)

if tc.expCCVChannel {
keeper.SetChainToChannel(ctx, chainID, ccvChannel)
}

if !tc.expErr {
require.NoError(t, err)
} else {
require.Error(t, err)
}
})
}
}
22 changes: 0 additions & 22 deletions x/ccv/provider/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -1324,25 +1324,3 @@ func (k Keeper) GetToBeOptedOut(

return addresses
}

func (k Keeper) IdentifyConsumerChainIDFromIBCPacket(ctx sdk.Context, packet channeltypes.Packet) (string, error) {
channel, ok := k.channelKeeper.GetChannel(ctx, packet.DestinationPort, packet.DestinationChannel)
if !ok {
return "", errorsmod.Wrapf(channeltypes.ErrChannelNotFound, "channel not found for channel ID: %s", packet.DestinationChannel)
}
if len(channel.ConnectionHops) != 1 {
return "", errorsmod.Wrap(channeltypes.ErrTooManyConnectionHops, "must have direct connection to consumer chain")
}
connectionID := channel.ConnectionHops[0]
_, tmClient, err := k.getUnderlyingClient(ctx, connectionID)
if err != nil {
return "", err
}

chainID := tmClient.ChainId
if _, ok := k.GetChainToChannel(ctx, chainID); !ok {
return "", errorsmod.Wrapf(types.ErrUnknownConsumerChannelId, "no CCV channel found for chain with ID: %s", chainID)
}

return chainID, nil
}

0 comments on commit ed4233d

Please sign in to comment.