Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(reward): enable empty reward coins for MsgSetRewards #540

Merged
merged 10 commits into from
Feb 23, 2022
14 changes: 10 additions & 4 deletions x/reward/keeper/msg_server_set_reward.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,16 @@ func (k msgServer) SetRewards(goCtx context.Context, msg *types.MsgSetRewards) (
return nil, err
}
}
rewardPool.Coins = msg.Coins
rewardPool.Provider = msg.Provider
rewardPool.LastRewardHeight = msg.LastRewardHeight
k.SetRewardPool(ctx, rewardPool)
if msg.Coins.Empty() || msg.LastRewardHeight == 0 {
rewardPool.Coins = sdk.NewCoins()
rewardPool.LastRewardHeight = 0
k.RemoveRewardPool(ctx, msg.LaunchID)
} else {
rewardPool.Coins = msg.Coins
rewardPool.Provider = msg.Provider
rewardPool.LastRewardHeight = msg.LastRewardHeight
k.SetRewardPool(ctx, rewardPool)
}

return &types.MsgSetRewardsResponse{
PreviousCoins: previousCoins,
Expand Down
56 changes: 51 additions & 5 deletions x/reward/keeper/msg_server_set_reward_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,17 @@ func TestMsgSetRewards(t *testing.T) {
k, lk, _, bk, srv, psrv, _, sdkCtx = setupMsgServer(t)

ctx = sdk.WrapSDKContext(sdkCtx)
moduleBalance = sample.Coins()
newBalance = sample.Coins()
provider = sample.AccAddress()
invalidCoord = sample.Address()
noBalanceCoord = sample.Address()

emptyCoinsBalance = sample.Coins()
lumtis marked this conversation as resolved.
Show resolved Hide resolved
zeroRewarHeightBalance = sample.Coins()
lumtis marked this conversation as resolved.
Show resolved Hide resolved
newBalance = sample.Coins()
moduleBalance = sample.Coins().
Add(emptyCoinsBalance...).
Add(zeroRewarHeightBalance...).
Add(newBalance...)
)
coordMsg := sample.MsgCreateCoordinator(invalidCoord)
res, err := psrv.CreateCoordinator(ctx, &coordMsg)
Expand All @@ -44,7 +50,22 @@ func TestMsgSetRewards(t *testing.T) {
launchTriggeredChain.LaunchTriggered = true
launchTriggeredChainID := lk.AppendChain(sdkCtx, launchTriggeredChain)

err = bk.MintCoins(sdkCtx, types.ModuleName, moduleBalance.Add(newBalance...))
emptyBalanceLaunchID := lk.AppendChain(sdkCtx, sample.Chain(4, res.CoordinatorID))
k.SetRewardPool(sdkCtx, types.RewardPool{
LaunchID: emptyBalanceLaunchID,
Coins: emptyCoinsBalance,
LastRewardHeight: 100,
CurrentRewardHeight: 30,
})
zeroRewardHeightLaunchID := lk.AppendChain(sdkCtx, sample.Chain(5, res.CoordinatorID))
k.SetRewardPool(sdkCtx, types.RewardPool{
LaunchID: zeroRewardHeightLaunchID,
Coins: zeroRewarHeightBalance,
LastRewardHeight: 100,
CurrentRewardHeight: 30,
})

err = bk.MintCoins(sdkCtx, types.ModuleName, moduleBalance)
lumtis marked this conversation as resolved.
Show resolved Hide resolved
require.NoError(t, err)
err = bk.SendCoinsFromModuleToAccount(sdkCtx, types.ModuleName, provider, newBalance)
require.NoError(t, err)
Expand Down Expand Up @@ -114,6 +135,24 @@ func TestMsgSetRewards(t *testing.T) {
},
err: sdkerrors.ErrInsufficientFunds,
},
{
name: "empty coins",
msg: types.MsgSetRewards{
Provider: provider.String(),
LaunchID: emptyBalanceLaunchID,
Coins: emptyCoinsBalance,
LastRewardHeight: 1000,
},
},
{
name: "zero reward height",
msg: types.MsgSetRewards{
Provider: provider.String(),
LaunchID: zeroRewardHeightLaunchID,
Coins: zeroRewarHeightBalance,
LastRewardHeight: 0,
},
},
{
name: "valid message",
msg: types.MsgSetRewards{
Expand All @@ -134,16 +173,23 @@ func TestMsgSetRewards(t *testing.T) {
}
require.NoError(t, err)

require.Equal(t, previusRewardPool.Coins, got.PreviousCoins)
require.Equal(t, previusRewardPool.LastRewardHeight, got.PreviousLastRewardHeight)

rewardPool, found := k.GetRewardPool(sdkCtx, tt.msg.LaunchID)
if tt.msg.Coins.Empty() || tt.msg.LastRewardHeight == 0 {
require.False(t, found)
require.Equal(t, uint64(0), got.NewLastRewardHeight)
require.Equal(t, sdk.NewCoins(), got.NewCoins)
return
}
require.True(t, found)
require.Equal(t, tt.msg.Coins, rewardPool.Coins)
require.Equal(t, tt.msg.Provider, rewardPool.Provider)
require.Equal(t, tt.msg.LastRewardHeight, rewardPool.LastRewardHeight)

require.Equal(t, tt.msg.Coins, got.NewCoins)
require.Equal(t, tt.msg.LastRewardHeight, got.NewLastRewardHeight)
require.Equal(t, previusRewardPool.Coins, got.PreviousCoins)
require.Equal(t, previusRewardPool.LastRewardHeight, got.PreviousLastRewardHeight)
})
}
}
Expand Down
3 changes: 0 additions & 3 deletions x/reward/types/message_set_reward.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@ func (msg *MsgSetRewards) ValidateBasic() error {
if _, err := sdk.AccAddressFromBech32(msg.Provider); err != nil {
return sdkerrors.Wrapf(sdkerrors.ErrInvalidAddress, "invalid provider address (%s)", err)
}
if msg.Coins.Empty() {
return sdkerrors.Wrap(ErrInvalidRewardPoolCoins, "empty reward pool coins")
}
if err := msg.Coins.Validate(); err != nil {
return sdkerrors.Wrapf(ErrInvalidRewardPoolCoins, "invalid reward pool coins (%s)", err)
}
Expand Down
19 changes: 9 additions & 10 deletions x/reward/types/message_set_reward_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,6 @@ func TestMsgSetRewards_ValidateBasic(t *testing.T) {
},
err: sdkerrors.ErrInvalidAddress,
},
{
name: "empty coins",
msg: types.MsgSetRewards{
LaunchID: 1,
Provider: sample.Address(),
Coins: sdk.NewCoins(),
LastRewardHeight: 50,
},
err: types.ErrInvalidRewardPoolCoins,
},
{
name: "invalid coins",
msg: types.MsgSetRewards{
Expand All @@ -59,6 +49,15 @@ func TestMsgSetRewards_ValidateBasic(t *testing.T) {
LastRewardHeight: 50,
},
},
{
name: "valid reward pool message with empty coins",
msg: types.MsgSetRewards{
LaunchID: 1,
Provider: sample.Address(),
Coins: sdk.NewCoins(),
LastRewardHeight: 50,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down