Skip to content

Commit

Permalink
feat(reward): enable empty reward coins for MsgSetRewards (#540)
Browse files Browse the repository at this point in the history
* enable empty reward coins for `MsgSetRewards`

* remove reward if empty coins or last reward height

* create method to init the rewards

* improve the readabillity
  • Loading branch information
Pantani authored Feb 23, 2022
1 parent dfeb8aa commit 9b803b7
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 62 deletions.
14 changes: 10 additions & 4 deletions x/reward/keeper/msg_server_set_reward.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,16 @@ func (k msgServer) SetRewards(goCtx context.Context, msg *types.MsgSetRewards) (
return nil, err
}
}
rewardPool.Coins = msg.Coins
rewardPool.Provider = msg.Provider
rewardPool.LastRewardHeight = msg.LastRewardHeight
k.SetRewardPool(ctx, rewardPool)
if msg.Coins.Empty() || msg.LastRewardHeight == 0 {
rewardPool.Coins = sdk.NewCoins()
rewardPool.LastRewardHeight = 0
k.RemoveRewardPool(ctx, msg.LaunchID)
} else {
rewardPool.Coins = msg.Coins
rewardPool.Provider = msg.Provider
rewardPool.LastRewardHeight = msg.LastRewardHeight
k.SetRewardPool(ctx, rewardPool)
}

return &types.MsgSetRewardsResponse{
PreviousCoins: previousCoins,
Expand Down
130 changes: 85 additions & 45 deletions x/reward/keeper/msg_server_set_reward_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,50 +5,75 @@ import (

sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
bankkeeper "github.com/cosmos/cosmos-sdk/x/bank/keeper"
"github.com/stretchr/testify/require"

tc "github.com/tendermint/spn/testutil/constructor"
"github.com/tendermint/spn/testutil/sample"
launchkeeper "github.com/tendermint/spn/x/launch/keeper"
launchtypes "github.com/tendermint/spn/x/launch/types"
profiletypes "github.com/tendermint/spn/x/profile/types"
"github.com/tendermint/spn/x/reward/keeper"
"github.com/tendermint/spn/x/reward/types"
)

func TestMsgSetRewards(t *testing.T) {
func initRewardPool(
t *testing.T,
k *keeper.Keeper,
bk bankkeeper.Keeper,
lk *launchkeeper.Keeper,
sdkCtx sdk.Context,
psrv profiletypes.MsgServer,
) types.RewardPool {
var (
k, lk, _, bk, srv, psrv, _, sdkCtx = setupMsgServer(t)

ctx = sdk.WrapSDKContext(sdkCtx)
moduleBalance = sample.Coins()
newBalance = sample.Coins()
provider = sample.AccAddress()
invalidCoord = sample.Address()
noBalanceCoord = sample.Address()
provider = sample.AccAddress()
coins = sample.Coins()
ctx = sdk.WrapSDKContext(sdkCtx)
coordMsg = sample.MsgCreateCoordinator(provider.String())
)
coordMsg := sample.MsgCreateCoordinator(invalidCoord)
res, err := psrv.CreateCoordinator(ctx, &coordMsg)
require.NoError(t, err)
launchID := lk.AppendChain(sdkCtx, sample.Chain(1, res.CoordinatorID))
rewardPool := types.RewardPool{
Provider: provider.String(),
LaunchID: launchID,
Coins: coins,
LastRewardHeight: 100,
CurrentRewardHeight: 30,
}
k.SetRewardPool(sdkCtx, rewardPool)

coordMsg = sample.MsgCreateCoordinator(noBalanceCoord)
res, err = psrv.CreateCoordinator(ctx, &coordMsg)
err = bk.MintCoins(sdkCtx, types.ModuleName, coins.Add(coins...))
require.NoError(t, err)
noBalancelaunchID := lk.AppendChain(sdkCtx, sample.Chain(1, res.CoordinatorID))

coordMsg = sample.MsgCreateCoordinator(provider.String())
res, err = psrv.CreateCoordinator(ctx, &coordMsg)
err = bk.SendCoinsFromModuleToAccount(sdkCtx, types.ModuleName, provider, coins)
require.NoError(t, err)
launchID := lk.AppendChain(sdkCtx, sample.Chain(1, res.CoordinatorID))

launchTriggeredChain := sample.Chain(1, res.CoordinatorID)
launchTriggeredChain.LaunchTriggered = true
launchTriggeredChainID := lk.AppendChain(sdkCtx, launchTriggeredChain)
return rewardPool
}

err = bk.MintCoins(sdkCtx, types.ModuleName, moduleBalance.Add(newBalance...))
require.NoError(t, err)
err = bk.SendCoinsFromModuleToAccount(sdkCtx, types.ModuleName, provider, newBalance)
func TestMsgSetRewards(t *testing.T) {
var (
k, lk, _, bk, srv, psrv, _, sdkCtx = setupMsgServer(t)

ctx = sdk.WrapSDKContext(sdkCtx)
invalidCoord = sample.Address()
)
invalidCoordMsg := sample.MsgCreateCoordinator(invalidCoord)
_, err := psrv.CreateCoordinator(ctx, &invalidCoordMsg)
require.NoError(t, err)

var (
rewardPool = initRewardPool(t, k, bk, lk, sdkCtx, psrv)
noBalanceRewadPool = initRewardPool(t, k, bk, lk, sdkCtx, psrv)
emptyCoinsRewadPool = initRewardPool(t, k, bk, lk, sdkCtx, psrv)
zeroRewarHeightRewadPool = initRewardPool(t, k, bk, lk, sdkCtx, psrv)
launchedRewadPool = initRewardPool(t, k, bk, lk, sdkCtx, psrv)
)
launchTriggeredChain, found := lk.GetChain(sdkCtx, launchedRewadPool.LaunchID)
require.True(t, found)
launchTriggeredChain.LaunchTriggered = true
lk.SetChain(sdkCtx, launchTriggeredChain)

tests := []struct {
name string
msg types.MsgSetRewards
Expand All @@ -57,9 +82,9 @@ func TestMsgSetRewards(t *testing.T) {
{
name: "invalid chain",
msg: types.MsgSetRewards{
Provider: provider.String(),
Provider: rewardPool.Provider,
LaunchID: 9999,
Coins: newBalance,
Coins: rewardPool.Coins,
LastRewardHeight: 1000,
},
err: launchtypes.ErrChainNotFound,
Expand All @@ -68,8 +93,8 @@ func TestMsgSetRewards(t *testing.T) {
name: "coordinator address not found",
msg: types.MsgSetRewards{
Provider: sample.Address(),
LaunchID: launchID,
Coins: newBalance,
LaunchID: rewardPool.LaunchID,
Coins: rewardPool.Coins,
LastRewardHeight: 1000,
},
err: profiletypes.ErrCoordAddressNotFound,
Expand All @@ -78,48 +103,56 @@ func TestMsgSetRewards(t *testing.T) {
name: "invalid coordinator id",
msg: types.MsgSetRewards{
Provider: invalidCoord,
LaunchID: launchID,
Coins: newBalance,
LaunchID: rewardPool.LaunchID,
Coins: rewardPool.Coins,
LastRewardHeight: 1000,
},
err: types.ErrInvalidCoordinatorID,
},
{
name: "launch triggered chain",
msg: types.MsgSetRewards{
Provider: provider.String(),
LaunchID: launchTriggeredChainID,
Coins: newBalance,
Provider: launchedRewadPool.Provider,
LaunchID: launchedRewadPool.LaunchID,
Coins: launchedRewadPool.Coins,
LastRewardHeight: 1000,
},
err: launchtypes.ErrTriggeredLaunch,
},
{
name: "coordinator with insufficient funds",
msg: types.MsgSetRewards{
Provider: noBalanceCoord,
LaunchID: noBalancelaunchID,
Coins: newBalance,
Provider: noBalanceRewadPool.Provider,
LaunchID: noBalanceRewadPool.LaunchID,
Coins: sample.Coins(),
LastRewardHeight: 1000,
},
err: sdkerrors.ErrInsufficientFunds,
},
{
name: "coordinator with insufficient funds",
name: "empty coins",
msg: types.MsgSetRewards{
Provider: noBalanceCoord,
LaunchID: noBalancelaunchID,
Coins: newBalance,
Provider: emptyCoinsRewadPool.Provider,
LaunchID: emptyCoinsRewadPool.LaunchID,
Coins: sdk.NewCoins(),
LastRewardHeight: 1000,
},
err: sdkerrors.ErrInsufficientFunds,
},
{
name: "zero reward height",
msg: types.MsgSetRewards{
Provider: zeroRewarHeightRewadPool.Provider,
LaunchID: zeroRewarHeightRewadPool.LaunchID,
Coins: zeroRewarHeightRewadPool.Coins,
LastRewardHeight: 0,
},
},
{
name: "valid message",
msg: types.MsgSetRewards{
Provider: provider.String(),
LaunchID: launchID,
Coins: newBalance,
Provider: rewardPool.Provider,
LaunchID: rewardPool.LaunchID,
Coins: rewardPool.Coins,
LastRewardHeight: 1000,
},
},
Expand All @@ -134,16 +167,23 @@ func TestMsgSetRewards(t *testing.T) {
}
require.NoError(t, err)

require.Equal(t, previusRewardPool.Coins, got.PreviousCoins)
require.Equal(t, previusRewardPool.LastRewardHeight, got.PreviousLastRewardHeight)

rewardPool, found := k.GetRewardPool(sdkCtx, tt.msg.LaunchID)
if tt.msg.Coins.Empty() || tt.msg.LastRewardHeight == 0 {
require.False(t, found)
require.Equal(t, uint64(0), got.NewLastRewardHeight)
require.Equal(t, sdk.NewCoins(), got.NewCoins)
return
}
require.True(t, found)
require.Equal(t, tt.msg.Coins, rewardPool.Coins)
require.Equal(t, tt.msg.Provider, rewardPool.Provider)
require.Equal(t, tt.msg.LastRewardHeight, rewardPool.LastRewardHeight)

require.Equal(t, tt.msg.Coins, got.NewCoins)
require.Equal(t, tt.msg.LastRewardHeight, got.NewLastRewardHeight)
require.Equal(t, previusRewardPool.Coins, got.PreviousCoins)
require.Equal(t, previusRewardPool.LastRewardHeight, got.PreviousLastRewardHeight)
})
}
}
Expand Down
3 changes: 0 additions & 3 deletions x/reward/types/message_set_reward.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@ func (msg *MsgSetRewards) ValidateBasic() error {
if _, err := sdk.AccAddressFromBech32(msg.Provider); err != nil {
return sdkerrors.Wrapf(sdkerrors.ErrInvalidAddress, "invalid provider address (%s)", err)
}
if msg.Coins.Empty() {
return sdkerrors.Wrap(ErrInvalidRewardPoolCoins, "empty reward pool coins")
}
if err := msg.Coins.Validate(); err != nil {
return sdkerrors.Wrapf(ErrInvalidRewardPoolCoins, "invalid reward pool coins (%s)", err)
}
Expand Down
19 changes: 9 additions & 10 deletions x/reward/types/message_set_reward_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,6 @@ func TestMsgSetRewards_ValidateBasic(t *testing.T) {
},
err: sdkerrors.ErrInvalidAddress,
},
{
name: "empty coins",
msg: types.MsgSetRewards{
LaunchID: 1,
Provider: sample.Address(),
Coins: sdk.NewCoins(),
LastRewardHeight: 50,
},
err: types.ErrInvalidRewardPoolCoins,
},
{
name: "invalid coins",
msg: types.MsgSetRewards{
Expand All @@ -59,6 +49,15 @@ func TestMsgSetRewards_ValidateBasic(t *testing.T) {
LastRewardHeight: 50,
},
},
{
name: "valid reward pool message with empty coins",
msg: types.MsgSetRewards{
LaunchID: 1,
Provider: sample.Address(),
Coins: sdk.NewCoins(),
LastRewardHeight: 50,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down

0 comments on commit 9b803b7

Please sign in to comment.