From 9b803b7972cc3d040ed0aa7d2612936b2616546a Mon Sep 17 00:00:00 2001 From: Danilo Pantani Date: Wed, 23 Feb 2022 15:10:13 -0300 Subject: [PATCH] feat(reward): enable empty reward coins for `MsgSetRewards` (#540) * enable empty reward coins for `MsgSetRewards` * remove reward if empty coins or last reward height * create method to init the rewards * improve the readabillity --- x/reward/keeper/msg_server_set_reward.go | 14 +- x/reward/keeper/msg_server_set_reward_test.go | 130 ++++++++++++------ x/reward/types/message_set_reward.go | 3 - x/reward/types/message_set_reward_test.go | 19 ++- 4 files changed, 104 insertions(+), 62 deletions(-) diff --git a/x/reward/keeper/msg_server_set_reward.go b/x/reward/keeper/msg_server_set_reward.go index 79c5a0011..6771bdd26 100644 --- a/x/reward/keeper/msg_server_set_reward.go +++ b/x/reward/keeper/msg_server_set_reward.go @@ -56,10 +56,16 @@ func (k msgServer) SetRewards(goCtx context.Context, msg *types.MsgSetRewards) ( return nil, err } } - rewardPool.Coins = msg.Coins - rewardPool.Provider = msg.Provider - rewardPool.LastRewardHeight = msg.LastRewardHeight - k.SetRewardPool(ctx, rewardPool) + if msg.Coins.Empty() || msg.LastRewardHeight == 0 { + rewardPool.Coins = sdk.NewCoins() + rewardPool.LastRewardHeight = 0 + k.RemoveRewardPool(ctx, msg.LaunchID) + } else { + rewardPool.Coins = msg.Coins + rewardPool.Provider = msg.Provider + rewardPool.LastRewardHeight = msg.LastRewardHeight + k.SetRewardPool(ctx, rewardPool) + } return &types.MsgSetRewardsResponse{ PreviousCoins: previousCoins, diff --git a/x/reward/keeper/msg_server_set_reward_test.go b/x/reward/keeper/msg_server_set_reward_test.go index cd547449a..c7ba30336 100644 --- a/x/reward/keeper/msg_server_set_reward_test.go +++ b/x/reward/keeper/msg_server_set_reward_test.go @@ -5,50 +5,75 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" + bankkeeper "github.com/cosmos/cosmos-sdk/x/bank/keeper" "github.com/stretchr/testify/require" tc "github.com/tendermint/spn/testutil/constructor" "github.com/tendermint/spn/testutil/sample" + launchkeeper "github.com/tendermint/spn/x/launch/keeper" launchtypes "github.com/tendermint/spn/x/launch/types" profiletypes "github.com/tendermint/spn/x/profile/types" "github.com/tendermint/spn/x/reward/keeper" "github.com/tendermint/spn/x/reward/types" ) -func TestMsgSetRewards(t *testing.T) { +func initRewardPool( + t *testing.T, + k *keeper.Keeper, + bk bankkeeper.Keeper, + lk *launchkeeper.Keeper, + sdkCtx sdk.Context, + psrv profiletypes.MsgServer, +) types.RewardPool { var ( - k, lk, _, bk, srv, psrv, _, sdkCtx = setupMsgServer(t) - - ctx = sdk.WrapSDKContext(sdkCtx) - moduleBalance = sample.Coins() - newBalance = sample.Coins() - provider = sample.AccAddress() - invalidCoord = sample.Address() - noBalanceCoord = sample.Address() + provider = sample.AccAddress() + coins = sample.Coins() + ctx = sdk.WrapSDKContext(sdkCtx) + coordMsg = sample.MsgCreateCoordinator(provider.String()) ) - coordMsg := sample.MsgCreateCoordinator(invalidCoord) res, err := psrv.CreateCoordinator(ctx, &coordMsg) require.NoError(t, err) + launchID := lk.AppendChain(sdkCtx, sample.Chain(1, res.CoordinatorID)) + rewardPool := types.RewardPool{ + Provider: provider.String(), + LaunchID: launchID, + Coins: coins, + LastRewardHeight: 100, + CurrentRewardHeight: 30, + } + k.SetRewardPool(sdkCtx, rewardPool) - coordMsg = sample.MsgCreateCoordinator(noBalanceCoord) - res, err = psrv.CreateCoordinator(ctx, &coordMsg) + err = bk.MintCoins(sdkCtx, types.ModuleName, coins.Add(coins...)) require.NoError(t, err) - noBalancelaunchID := lk.AppendChain(sdkCtx, sample.Chain(1, res.CoordinatorID)) - - coordMsg = sample.MsgCreateCoordinator(provider.String()) - res, err = psrv.CreateCoordinator(ctx, &coordMsg) + err = bk.SendCoinsFromModuleToAccount(sdkCtx, types.ModuleName, provider, coins) require.NoError(t, err) - launchID := lk.AppendChain(sdkCtx, sample.Chain(1, res.CoordinatorID)) - launchTriggeredChain := sample.Chain(1, res.CoordinatorID) - launchTriggeredChain.LaunchTriggered = true - launchTriggeredChainID := lk.AppendChain(sdkCtx, launchTriggeredChain) + return rewardPool +} - err = bk.MintCoins(sdkCtx, types.ModuleName, moduleBalance.Add(newBalance...)) - require.NoError(t, err) - err = bk.SendCoinsFromModuleToAccount(sdkCtx, types.ModuleName, provider, newBalance) +func TestMsgSetRewards(t *testing.T) { + var ( + k, lk, _, bk, srv, psrv, _, sdkCtx = setupMsgServer(t) + + ctx = sdk.WrapSDKContext(sdkCtx) + invalidCoord = sample.Address() + ) + invalidCoordMsg := sample.MsgCreateCoordinator(invalidCoord) + _, err := psrv.CreateCoordinator(ctx, &invalidCoordMsg) require.NoError(t, err) + var ( + rewardPool = initRewardPool(t, k, bk, lk, sdkCtx, psrv) + noBalanceRewadPool = initRewardPool(t, k, bk, lk, sdkCtx, psrv) + emptyCoinsRewadPool = initRewardPool(t, k, bk, lk, sdkCtx, psrv) + zeroRewarHeightRewadPool = initRewardPool(t, k, bk, lk, sdkCtx, psrv) + launchedRewadPool = initRewardPool(t, k, bk, lk, sdkCtx, psrv) + ) + launchTriggeredChain, found := lk.GetChain(sdkCtx, launchedRewadPool.LaunchID) + require.True(t, found) + launchTriggeredChain.LaunchTriggered = true + lk.SetChain(sdkCtx, launchTriggeredChain) + tests := []struct { name string msg types.MsgSetRewards @@ -57,9 +82,9 @@ func TestMsgSetRewards(t *testing.T) { { name: "invalid chain", msg: types.MsgSetRewards{ - Provider: provider.String(), + Provider: rewardPool.Provider, LaunchID: 9999, - Coins: newBalance, + Coins: rewardPool.Coins, LastRewardHeight: 1000, }, err: launchtypes.ErrChainNotFound, @@ -68,8 +93,8 @@ func TestMsgSetRewards(t *testing.T) { name: "coordinator address not found", msg: types.MsgSetRewards{ Provider: sample.Address(), - LaunchID: launchID, - Coins: newBalance, + LaunchID: rewardPool.LaunchID, + Coins: rewardPool.Coins, LastRewardHeight: 1000, }, err: profiletypes.ErrCoordAddressNotFound, @@ -78,8 +103,8 @@ func TestMsgSetRewards(t *testing.T) { name: "invalid coordinator id", msg: types.MsgSetRewards{ Provider: invalidCoord, - LaunchID: launchID, - Coins: newBalance, + LaunchID: rewardPool.LaunchID, + Coins: rewardPool.Coins, LastRewardHeight: 1000, }, err: types.ErrInvalidCoordinatorID, @@ -87,9 +112,9 @@ func TestMsgSetRewards(t *testing.T) { { name: "launch triggered chain", msg: types.MsgSetRewards{ - Provider: provider.String(), - LaunchID: launchTriggeredChainID, - Coins: newBalance, + Provider: launchedRewadPool.Provider, + LaunchID: launchedRewadPool.LaunchID, + Coins: launchedRewadPool.Coins, LastRewardHeight: 1000, }, err: launchtypes.ErrTriggeredLaunch, @@ -97,29 +122,37 @@ func TestMsgSetRewards(t *testing.T) { { name: "coordinator with insufficient funds", msg: types.MsgSetRewards{ - Provider: noBalanceCoord, - LaunchID: noBalancelaunchID, - Coins: newBalance, + Provider: noBalanceRewadPool.Provider, + LaunchID: noBalanceRewadPool.LaunchID, + Coins: sample.Coins(), LastRewardHeight: 1000, }, err: sdkerrors.ErrInsufficientFunds, }, { - name: "coordinator with insufficient funds", + name: "empty coins", msg: types.MsgSetRewards{ - Provider: noBalanceCoord, - LaunchID: noBalancelaunchID, - Coins: newBalance, + Provider: emptyCoinsRewadPool.Provider, + LaunchID: emptyCoinsRewadPool.LaunchID, + Coins: sdk.NewCoins(), LastRewardHeight: 1000, }, - err: sdkerrors.ErrInsufficientFunds, + }, + { + name: "zero reward height", + msg: types.MsgSetRewards{ + Provider: zeroRewarHeightRewadPool.Provider, + LaunchID: zeroRewarHeightRewadPool.LaunchID, + Coins: zeroRewarHeightRewadPool.Coins, + LastRewardHeight: 0, + }, }, { name: "valid message", msg: types.MsgSetRewards{ - Provider: provider.String(), - LaunchID: launchID, - Coins: newBalance, + Provider: rewardPool.Provider, + LaunchID: rewardPool.LaunchID, + Coins: rewardPool.Coins, LastRewardHeight: 1000, }, }, @@ -134,7 +167,16 @@ func TestMsgSetRewards(t *testing.T) { } require.NoError(t, err) + require.Equal(t, previusRewardPool.Coins, got.PreviousCoins) + require.Equal(t, previusRewardPool.LastRewardHeight, got.PreviousLastRewardHeight) + rewardPool, found := k.GetRewardPool(sdkCtx, tt.msg.LaunchID) + if tt.msg.Coins.Empty() || tt.msg.LastRewardHeight == 0 { + require.False(t, found) + require.Equal(t, uint64(0), got.NewLastRewardHeight) + require.Equal(t, sdk.NewCoins(), got.NewCoins) + return + } require.True(t, found) require.Equal(t, tt.msg.Coins, rewardPool.Coins) require.Equal(t, tt.msg.Provider, rewardPool.Provider) @@ -142,8 +184,6 @@ func TestMsgSetRewards(t *testing.T) { require.Equal(t, tt.msg.Coins, got.NewCoins) require.Equal(t, tt.msg.LastRewardHeight, got.NewLastRewardHeight) - require.Equal(t, previusRewardPool.Coins, got.PreviousCoins) - require.Equal(t, previusRewardPool.LastRewardHeight, got.PreviousLastRewardHeight) }) } } diff --git a/x/reward/types/message_set_reward.go b/x/reward/types/message_set_reward.go index e57e0c9fa..0578b34c8 100644 --- a/x/reward/types/message_set_reward.go +++ b/x/reward/types/message_set_reward.go @@ -43,9 +43,6 @@ func (msg *MsgSetRewards) ValidateBasic() error { if _, err := sdk.AccAddressFromBech32(msg.Provider); err != nil { return sdkerrors.Wrapf(sdkerrors.ErrInvalidAddress, "invalid provider address (%s)", err) } - if msg.Coins.Empty() { - return sdkerrors.Wrap(ErrInvalidRewardPoolCoins, "empty reward pool coins") - } if err := msg.Coins.Validate(); err != nil { return sdkerrors.Wrapf(ErrInvalidRewardPoolCoins, "invalid reward pool coins (%s)", err) } diff --git a/x/reward/types/message_set_reward_test.go b/x/reward/types/message_set_reward_test.go index 29e353c97..67e6ef8e7 100644 --- a/x/reward/types/message_set_reward_test.go +++ b/x/reward/types/message_set_reward_test.go @@ -27,16 +27,6 @@ func TestMsgSetRewards_ValidateBasic(t *testing.T) { }, err: sdkerrors.ErrInvalidAddress, }, - { - name: "empty coins", - msg: types.MsgSetRewards{ - LaunchID: 1, - Provider: sample.Address(), - Coins: sdk.NewCoins(), - LastRewardHeight: 50, - }, - err: types.ErrInvalidRewardPoolCoins, - }, { name: "invalid coins", msg: types.MsgSetRewards{ @@ -59,6 +49,15 @@ func TestMsgSetRewards_ValidateBasic(t *testing.T) { LastRewardHeight: 50, }, }, + { + name: "valid reward pool message with empty coins", + msg: types.MsgSetRewards{ + LaunchID: 1, + Provider: sample.Address(), + Coins: sdk.NewCoins(), + LastRewardHeight: 50, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {