diff --git a/src/staking/SnapshotStakingPool.sol b/src/staking/SnapshotStakingPool.sol index d6d07e6..7831055 100644 --- a/src/staking/SnapshotStakingPool.sol +++ b/src/staking/SnapshotStakingPool.sol @@ -15,6 +15,25 @@ import {ReentrancyGuard} from "@openzeppelin/contracts/security/ReentrancyGuard. /// on snapshots taken when rewards are accrued. contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, ReentrancyGuard { + /* ERRORS */ + + /// @notice Error when accrue is called by non-distributor + error MustBeDistributor(); + /// @notice Error when trying to accrue zero rewards + error CannotAccrueZero(); + /// @notice Error when trying to accrue rewards with zero staked supply + error CannotAccrueWithZeroStakedSupply(); + /// @notice Error when trying to accrue rewards before snapshot delay + error SnapshotDelayNotPassed(); + /// @notice Error when trying to claim rewards from past snapshots + error CannotClaimFromPastSnapshots(); + /// @notice Error when snapshot id is invalid + error InvalidSnapshotId(); + /// @notice Error when snapshot id does not exist + error NonExistentSnapshotId(); + /// @notice Error when transfers are attempted + error TransfersNotAllowed(); + /* EVENTS */ /// @notice Emitted when the reward distributor is changed. @@ -70,7 +89,7 @@ contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, Re /// @dev Reverts if the caller is not the distributor. modifier onlyDistributor() { - require(msg.sender == distributor, "Must be distributor"); + if (msg.sender != distributor) revert MustBeDistributor(); _; } @@ -83,16 +102,15 @@ contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, Re /// @inheritdoc ISnapshotStakingPool function unstake(uint256 amount) public nonReentrant { - require(amount > 0, "Cannot unstake 0"); super._burn(msg.sender, amount); stakeToken.transfer(msg.sender, amount); } /// @inheritdoc ISnapshotStakingPool function accrue(uint256 amount) external nonReentrant onlyDistributor { - require(amount > 0, "Cannot accrue 0"); - require(totalSupply() > 0, "Cannot accrue with 0 staked supply"); - require(canAccrue(), "Snapshot delay not passed"); + if (amount == 0) revert CannotAccrueZero(); + if (totalSupply() == 0) revert CannotAccrueWithZeroStakedSupply(); + if (!canAccrue()) revert SnapshotDelayNotPassed(); rewardToken.transferFrom(msg.sender, address(this), amount); lastSnapshotTime = block.timestamp; rewardSnapshots.push(amount); @@ -103,19 +121,13 @@ contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, Re function claim() public nonReentrant { uint256 currentId = _getCurrentSnapshotId(); uint256 lastId = nextClaimId[msg.sender]; - uint256 amount = rewardOfInRange(msg.sender, lastId, currentId); - require(amount > 0, "No rewards to claim"); - nextClaimId[msg.sender] = currentId + 1; - rewardToken.transfer(msg.sender, amount); + _claim(msg.sender, lastId, currentId); } /// @inheritdoc ISnapshotStakingPool function claimPartial(uint256 startSnapshotId, uint256 endSnapshotId) public nonReentrant { - require(startSnapshotId >= nextClaimId[msg.sender], "Cannot claim from past snapshots"); - uint256 amount = rewardOfInRange(msg.sender, startSnapshotId, endSnapshotId); - require(amount > 0, "No rewards to claim"); - nextClaimId[msg.sender] = endSnapshotId + 1; - rewardToken.transfer(msg.sender, amount); + if (startSnapshotId < nextClaimId[msg.sender]) revert CannotClaimFromPastSnapshots(); + _claim(msg.sender, startSnapshotId, endSnapshotId); } /* ADMIN FUNCTIONS */ @@ -136,12 +148,12 @@ contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, Re /// @notice Prevents transfers of the staked token. function transfer(address /*recipient*/, uint256 /*amount*/) public pure override(ERC20, IERC20) returns (bool) { - revert("Transfers not allowed"); + revert TransfersNotAllowed(); } /// @notice Prevents transfers of the staked token. function transferFrom(address /*sender*/, address /*recipient*/, uint256 /*amount*/) public pure override(ERC20, IERC20) returns (bool) { - revert("Transfers not allowed"); + revert TransfersNotAllowed(); } /* VIEW FUNCTIONS */ @@ -160,8 +172,8 @@ contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, Re /// @inheritdoc ISnapshotStakingPool function rewardOfInRange(address account, uint256 startSnapshotId, uint256 endSnapshotId) public view returns (uint256) { - require(startSnapshotId > 0, "ERC20Snapshot: id is 0"); - require(startSnapshotId <= endSnapshotId && endSnapshotId <= _getCurrentSnapshotId(), "ERC20Snapshot: nonexistent id"); + if (startSnapshotId == 0) revert InvalidSnapshotId(); + if (startSnapshotId > endSnapshotId || endSnapshotId > _getCurrentSnapshotId()) revert NonExistentSnapshotId(); uint256 rewards = 0; for (uint256 i = startSnapshotId; i <= endSnapshotId; i++) { @@ -172,15 +184,15 @@ contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, Re /// @inheritdoc ISnapshotStakingPool function rewardOfAt(address account, uint256 snapshotId) public view virtual returns (uint256) { - require(snapshotId > 0, "ERC20Snapshot: id is 0"); - require(snapshotId <= _getCurrentSnapshotId(), "ERC20Snapshot: nonexistent id"); + if (snapshotId == 0) revert InvalidSnapshotId(); + if (snapshotId > _getCurrentSnapshotId()) revert NonExistentSnapshotId(); return _rewardOfAt(account, snapshotId); } /// @inheritdoc ISnapshotStakingPool function rewardAt(uint256 snapshotId) public view virtual returns (uint256) { - require(snapshotId > 0, "ERC20Snapshot: id is 0"); - require(snapshotId <= _getCurrentSnapshotId(), "ERC20Snapshot: nonexistent id"); + if (snapshotId == 0) revert InvalidSnapshotId(); + if (snapshotId > _getCurrentSnapshotId()) revert NonExistentSnapshotId(); return _rewardAt(snapshotId); } @@ -205,7 +217,6 @@ contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, Re /* INTERNAL FUNCTIONS */ function _stake(address account, uint256 amount) internal { - require(amount > 0, "Cannot stake 0"); if (nextClaimId[account] == 0) { uint256 currentId = _getCurrentSnapshotId(); nextClaimId[account] = currentId > 0 ? currentId : 1; @@ -214,6 +225,12 @@ contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, Re super._mint(msg.sender, amount); } + function _claim(address account, uint256 startSnapshotId, uint256 endSnapshotId) internal { + uint256 amount = rewardOfInRange(account, startSnapshotId, endSnapshotId); + nextClaimId[account] = endSnapshotId + 1; + rewardToken.transfer(account, amount); + } + function _rewardAt(uint256 snapshotId) internal view returns (uint256) { return rewardSnapshots[snapshotId - 1]; } diff --git a/test/staking/SnapshotStakingPool.t.sol b/test/staking/SnapshotStakingPool.t.sol index 32ea9a6..28f76a0 100644 --- a/test/staking/SnapshotStakingPool.t.sol +++ b/test/staking/SnapshotStakingPool.t.sol @@ -88,10 +88,6 @@ contract SnapshotStakingPoolTest is Test { assertEq(snapshotStakingPool.balanceOf(bob.addr), amount); assertEq(snapshotStakingPool.nextClaimId(bob.addr), 1); assertEq(snapshotStakingPool.getCurrentSnapshotId(), 0); - - vm.prank(bob.addr); - vm.expectRevert("Cannot stake 0"); - snapshotStakingPool.stake(0); } function testUnstake() public { @@ -110,10 +106,6 @@ contract SnapshotStakingPoolTest is Test { assertEq(stakeToken.balanceOf(bob.addr), amount); assertEq(snapshotStakingPool.balanceOf(bob.addr), 0); - - vm.prank(bob.addr); - vm.expectRevert("Cannot unstake 0"); - snapshotStakingPool.unstake(0); } function testAccrue() public { @@ -142,22 +134,22 @@ contract SnapshotStakingPoolTest is Test { assertEq(snapshotStakingPool.getCurrentSnapshotId(), 1); vm.prank(distributor); - vm.expectRevert("Snapshot delay not passed"); + vm.expectRevert(SnapshotStakingPool.SnapshotDelayNotPassed.selector); snapshotStakingPool.accrue(amount); vm.prank(distributor); - vm.expectRevert("Cannot accrue 0"); + vm.expectRevert(SnapshotStakingPool.CannotAccrueZero.selector); snapshotStakingPool.accrue(0); vm.prank(bob.addr); snapshotStakingPool.unstake(amount); vm.prank(distributor); - vm.expectRevert("Cannot accrue with 0 staked supply"); + vm.expectRevert(SnapshotStakingPool.CannotAccrueWithZeroStakedSupply.selector); snapshotStakingPool.accrue(amount); vm.prank(bob.addr); - vm.expectRevert("Must be distributor"); + vm.expectRevert(SnapshotStakingPool.MustBeDistributor.selector); snapshotStakingPool.accrue(amount); } @@ -171,10 +163,10 @@ contract SnapshotStakingPoolTest is Test { } function testRewardAt() public { - vm.expectRevert("ERC20Snapshot: id is 0"); + vm.expectRevert(SnapshotStakingPool.InvalidSnapshotId.selector); snapshotStakingPool.rewardAt(0); - vm.expectRevert("ERC20Snapshot: nonexistent id"); + vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector); snapshotStakingPool.rewardAt(1); _stake(bob.addr, 1 ether); @@ -182,15 +174,15 @@ contract SnapshotStakingPoolTest is Test { assertEq(snapshotStakingPool.rewardAt(1), 1 ether); - vm.expectRevert("ERC20Snapshot: nonexistent id"); + vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector); snapshotStakingPool.rewardAt(2); } function testRewardOfAt() public { - vm.expectRevert("ERC20Snapshot: id is 0"); + vm.expectRevert(SnapshotStakingPool.InvalidSnapshotId.selector); snapshotStakingPool.rewardOfAt(bob.addr, 0); - vm.expectRevert("ERC20Snapshot: nonexistent id"); + vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector); snapshotStakingPool.rewardOfAt(bob.addr, 1); _stake(bob.addr, 1 ether); @@ -206,20 +198,20 @@ contract SnapshotStakingPoolTest is Test { assertEq(snapshotStakingPool.rewardOfAt(bob.addr, 2), 0); assertEq(snapshotStakingPool.rewardOfAt(alice.addr, 2), 1 ether); - vm.expectRevert("ERC20Snapshot: nonexistent id"); + vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector); snapshotStakingPool.rewardOfAt(bob.addr, 3); } function testRewardOfInRange() public { - vm.expectRevert("ERC20Snapshot: id is 0"); + vm.expectRevert(SnapshotStakingPool.InvalidSnapshotId.selector); snapshotStakingPool.rewardOfInRange(bob.addr, 0, 0); vm.prank(bob.addr); - vm.expectRevert("ERC20Snapshot: id is 0"); + vm.expectRevert(SnapshotStakingPool.InvalidSnapshotId.selector); snapshotStakingPool.rewardOfInRange(bob.addr, 0, 1); vm.prank(bob.addr); - vm.expectRevert("ERC20Snapshot: nonexistent id"); + vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector); snapshotStakingPool.rewardOfInRange(bob.addr, 1, 1); _stake(bob.addr, 1 ether); @@ -235,17 +227,17 @@ contract SnapshotStakingPoolTest is Test { assertEq(snapshotStakingPool.rewardOfInRange(bob.addr, 1, 2), 2 ether); assertEq(snapshotStakingPool.rewardOfInRange(alice.addr, 1, 3), 3 ether); - vm.expectRevert("ERC20Snapshot: nonexistent id"); + vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector); snapshotStakingPool.rewardOfInRange(bob.addr, 1, 4); } function testGetPendingRewards() public { - vm.expectRevert("ERC20Snapshot: id is 0"); + vm.expectRevert(SnapshotStakingPool.InvalidSnapshotId.selector); snapshotStakingPool.getPendingRewards(bob.addr); _stake(bob.addr, 1 ether); _stake(alice.addr, 1 ether); - vm.expectRevert("ERC20Snapshot: nonexistent id"); + vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector); snapshotStakingPool.getPendingRewards(bob.addr); _snapshot(2 ether); @@ -254,7 +246,7 @@ contract SnapshotStakingPoolTest is Test { vm.prank(bob.addr); snapshotStakingPool.claim(); - vm.expectRevert("ERC20Snapshot: nonexistent id"); + vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector); assertEq(snapshotStakingPool.getPendingRewards(bob.addr), 0); _snapshot(1 ether); @@ -305,13 +297,13 @@ contract SnapshotStakingPoolTest is Test { function testClaim() public { vm.prank(bob.addr); - vm.expectRevert("ERC20Snapshot: id is 0"); + vm.expectRevert(SnapshotStakingPool.InvalidSnapshotId.selector); snapshotStakingPool.claim(); _stake(bob.addr, 1 ether); _stake(alice.addr, 1 ether); vm.prank(bob.addr); - vm.expectRevert("ERC20Snapshot: nonexistent id"); + vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector); snapshotStakingPool.claim(); _snapshot(2 ether); @@ -333,10 +325,10 @@ contract SnapshotStakingPoolTest is Test { assertEq(snapshotStakingPool.nextClaimId(alice.addr), 3); vm.prank(bob.addr); - vm.expectRevert("ERC20Snapshot: nonexistent id"); + vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector); snapshotStakingPool.claim(); vm.prank(alice.addr); - vm.expectRevert("ERC20Snapshot: nonexistent id"); + vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector); snapshotStakingPool.claim(); _snapshot(1 ether); @@ -359,15 +351,15 @@ contract SnapshotStakingPoolTest is Test { function testClaimPartial() public { vm.prank(bob.addr); - vm.expectRevert("ERC20Snapshot: id is 0"); + vm.expectRevert(SnapshotStakingPool.InvalidSnapshotId.selector); snapshotStakingPool.claimPartial(0, 0); vm.prank(bob.addr); - vm.expectRevert("ERC20Snapshot: id is 0"); + vm.expectRevert(SnapshotStakingPool.InvalidSnapshotId.selector); snapshotStakingPool.claimPartial(0, 1); vm.prank(bob.addr); - vm.expectRevert("ERC20Snapshot: nonexistent id"); + vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector); snapshotStakingPool.claimPartial(1, 1); _stake(bob.addr, 1 ether); @@ -381,7 +373,7 @@ contract SnapshotStakingPoolTest is Test { assertEq(snapshotStakingPool.nextClaimId(bob.addr), 2); vm.prank(bob.addr); - vm.expectRevert("Cannot claim from past snapshots"); + vm.expectRevert(SnapshotStakingPool.CannotClaimFromPastSnapshots.selector); snapshotStakingPool.claimPartial(1, 1); _snapshot(2 ether); @@ -400,11 +392,11 @@ contract SnapshotStakingPoolTest is Test { assertEq(snapshotStakingPool.nextClaimId(alice.addr), 4); vm.prank(alice.addr); - vm.expectRevert("Cannot claim from past snapshots"); + vm.expectRevert(SnapshotStakingPool.CannotClaimFromPastSnapshots.selector); snapshotStakingPool.claimPartial(1, 3); vm.prank(alice.addr); - vm.expectRevert("ERC20Snapshot: nonexistent id"); + vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector); snapshotStakingPool.claimPartial(4, 4); _unstake(bob.addr, 1 ether); @@ -422,7 +414,7 @@ contract SnapshotStakingPoolTest is Test { _stake(bob.addr, 1 ether); vm.prank(bob.addr); - vm.expectRevert("Transfers not allowed"); + vm.expectRevert(SnapshotStakingPool.TransfersNotAllowed.selector); snapshotStakingPool.transfer(alice.addr, 1 ether); } @@ -433,7 +425,7 @@ contract SnapshotStakingPoolTest is Test { snapshotStakingPool.approve(alice.addr, 1 ether); vm.prank(alice.addr); - vm.expectRevert("Transfers not allowed"); + vm.expectRevert(SnapshotStakingPool.TransfersNotAllowed.selector); snapshotStakingPool.transferFrom(bob.addr, alice.addr, 1 ether); }