Skip to content

Commit

Permalink
add custom errors, cleanup checks
Browse files Browse the repository at this point in the history
  • Loading branch information
pblivin0x committed Jul 3, 2024
1 parent e68a09d commit 0697f39
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 60 deletions.
63 changes: 40 additions & 23 deletions src/staking/SnapshotStakingPool.sol
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,25 @@ import {ReentrancyGuard} from "@openzeppelin/contracts/security/ReentrancyGuard.
/// on snapshots taken when rewards are accrued.
contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, ReentrancyGuard {

/* ERRORS */

/// @notice Error when accrue is called by non-distributor
error MustBeDistributor();
/// @notice Error when trying to accrue zero rewards
error CannotAccrueZero();
/// @notice Error when trying to accrue rewards with zero staked supply
error CannotAccrueWithZeroStakedSupply();
/// @notice Error when trying to accrue rewards before snapshot delay
error SnapshotDelayNotPassed();
/// @notice Error when trying to claim rewards from past snapshots
error CannotClaimFromPastSnapshots();
/// @notice Error when snapshot id is invalid
error InvalidSnapshotId();
/// @notice Error when snapshot id does not exist
error NonExistentSnapshotId();
/// @notice Error when transfers are attempted
error TransfersNotAllowed();

/* EVENTS */

/// @notice Emitted when the reward distributor is changed.
Expand Down Expand Up @@ -70,7 +89,7 @@ contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, Re

/// @dev Reverts if the caller is not the distributor.
modifier onlyDistributor() {
require(msg.sender == distributor, "Must be distributor");
if (msg.sender != distributor) revert MustBeDistributor();
_;
}

Expand All @@ -83,16 +102,15 @@ contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, Re

/// @inheritdoc ISnapshotStakingPool
function unstake(uint256 amount) public nonReentrant {
require(amount > 0, "Cannot unstake 0");
super._burn(msg.sender, amount);
stakeToken.transfer(msg.sender, amount);
}

/// @inheritdoc ISnapshotStakingPool
function accrue(uint256 amount) external nonReentrant onlyDistributor {
require(amount > 0, "Cannot accrue 0");
require(totalSupply() > 0, "Cannot accrue with 0 staked supply");
require(canAccrue(), "Snapshot delay not passed");
if (amount == 0) revert CannotAccrueZero();
if (totalSupply() == 0) revert CannotAccrueWithZeroStakedSupply();
if (!canAccrue()) revert SnapshotDelayNotPassed();
rewardToken.transferFrom(msg.sender, address(this), amount);
lastSnapshotTime = block.timestamp;
rewardSnapshots.push(amount);
Expand All @@ -103,19 +121,13 @@ contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, Re
function claim() public nonReentrant {
uint256 currentId = _getCurrentSnapshotId();
uint256 lastId = nextClaimId[msg.sender];
uint256 amount = rewardOfInRange(msg.sender, lastId, currentId);
require(amount > 0, "No rewards to claim");
nextClaimId[msg.sender] = currentId + 1;
rewardToken.transfer(msg.sender, amount);
_claim(msg.sender, lastId, currentId);
}

/// @inheritdoc ISnapshotStakingPool
function claimPartial(uint256 startSnapshotId, uint256 endSnapshotId) public nonReentrant {
require(startSnapshotId >= nextClaimId[msg.sender], "Cannot claim from past snapshots");
uint256 amount = rewardOfInRange(msg.sender, startSnapshotId, endSnapshotId);
require(amount > 0, "No rewards to claim");
nextClaimId[msg.sender] = endSnapshotId + 1;
rewardToken.transfer(msg.sender, amount);
if (startSnapshotId < nextClaimId[msg.sender]) revert CannotClaimFromPastSnapshots();
_claim(msg.sender, startSnapshotId, endSnapshotId);
}

/* ADMIN FUNCTIONS */
Expand All @@ -136,12 +148,12 @@ contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, Re

/// @notice Prevents transfers of the staked token.
function transfer(address /*recipient*/, uint256 /*amount*/) public pure override(ERC20, IERC20) returns (bool) {
revert("Transfers not allowed");
revert TransfersNotAllowed();
}

/// @notice Prevents transfers of the staked token.
function transferFrom(address /*sender*/, address /*recipient*/, uint256 /*amount*/) public pure override(ERC20, IERC20) returns (bool) {
revert("Transfers not allowed");
revert TransfersNotAllowed();
}

/* VIEW FUNCTIONS */
Expand All @@ -160,8 +172,8 @@ contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, Re

/// @inheritdoc ISnapshotStakingPool
function rewardOfInRange(address account, uint256 startSnapshotId, uint256 endSnapshotId) public view returns (uint256) {
require(startSnapshotId > 0, "ERC20Snapshot: id is 0");
require(startSnapshotId <= endSnapshotId && endSnapshotId <= _getCurrentSnapshotId(), "ERC20Snapshot: nonexistent id");
if (startSnapshotId == 0) revert InvalidSnapshotId();
if (startSnapshotId > endSnapshotId || endSnapshotId > _getCurrentSnapshotId()) revert NonExistentSnapshotId();

uint256 rewards = 0;
for (uint256 i = startSnapshotId; i <= endSnapshotId; i++) {
Expand All @@ -172,15 +184,15 @@ contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, Re

/// @inheritdoc ISnapshotStakingPool
function rewardOfAt(address account, uint256 snapshotId) public view virtual returns (uint256) {
require(snapshotId > 0, "ERC20Snapshot: id is 0");
require(snapshotId <= _getCurrentSnapshotId(), "ERC20Snapshot: nonexistent id");
if (snapshotId == 0) revert InvalidSnapshotId();
if (snapshotId > _getCurrentSnapshotId()) revert NonExistentSnapshotId();
return _rewardOfAt(account, snapshotId);
}

/// @inheritdoc ISnapshotStakingPool
function rewardAt(uint256 snapshotId) public view virtual returns (uint256) {
require(snapshotId > 0, "ERC20Snapshot: id is 0");
require(snapshotId <= _getCurrentSnapshotId(), "ERC20Snapshot: nonexistent id");
if (snapshotId == 0) revert InvalidSnapshotId();
if (snapshotId > _getCurrentSnapshotId()) revert NonExistentSnapshotId();
return _rewardAt(snapshotId);
}

Expand All @@ -205,7 +217,6 @@ contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, Re
/* INTERNAL FUNCTIONS */

function _stake(address account, uint256 amount) internal {
require(amount > 0, "Cannot stake 0");
if (nextClaimId[account] == 0) {
uint256 currentId = _getCurrentSnapshotId();
nextClaimId[account] = currentId > 0 ? currentId : 1;
Expand All @@ -214,6 +225,12 @@ contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, Re
super._mint(msg.sender, amount);
}

function _claim(address account, uint256 startSnapshotId, uint256 endSnapshotId) internal {
uint256 amount = rewardOfInRange(account, startSnapshotId, endSnapshotId);
nextClaimId[account] = endSnapshotId + 1;
rewardToken.transfer(account, amount);
}

function _rewardAt(uint256 snapshotId) internal view returns (uint256) {
return rewardSnapshots[snapshotId - 1];
}
Expand Down
66 changes: 29 additions & 37 deletions test/staking/SnapshotStakingPool.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,6 @@ contract SnapshotStakingPoolTest is Test {
assertEq(snapshotStakingPool.balanceOf(bob.addr), amount);
assertEq(snapshotStakingPool.nextClaimId(bob.addr), 1);
assertEq(snapshotStakingPool.getCurrentSnapshotId(), 0);

vm.prank(bob.addr);
vm.expectRevert("Cannot stake 0");
snapshotStakingPool.stake(0);
}

function testUnstake() public {
Expand All @@ -110,10 +106,6 @@ contract SnapshotStakingPoolTest is Test {

assertEq(stakeToken.balanceOf(bob.addr), amount);
assertEq(snapshotStakingPool.balanceOf(bob.addr), 0);

vm.prank(bob.addr);
vm.expectRevert("Cannot unstake 0");
snapshotStakingPool.unstake(0);
}

function testAccrue() public {
Expand Down Expand Up @@ -142,22 +134,22 @@ contract SnapshotStakingPoolTest is Test {
assertEq(snapshotStakingPool.getCurrentSnapshotId(), 1);

vm.prank(distributor);
vm.expectRevert("Snapshot delay not passed");
vm.expectRevert(SnapshotStakingPool.SnapshotDelayNotPassed.selector);
snapshotStakingPool.accrue(amount);

vm.prank(distributor);
vm.expectRevert("Cannot accrue 0");
vm.expectRevert(SnapshotStakingPool.CannotAccrueZero.selector);
snapshotStakingPool.accrue(0);

vm.prank(bob.addr);
snapshotStakingPool.unstake(amount);

vm.prank(distributor);
vm.expectRevert("Cannot accrue with 0 staked supply");
vm.expectRevert(SnapshotStakingPool.CannotAccrueWithZeroStakedSupply.selector);
snapshotStakingPool.accrue(amount);

vm.prank(bob.addr);
vm.expectRevert("Must be distributor");
vm.expectRevert(SnapshotStakingPool.MustBeDistributor.selector);
snapshotStakingPool.accrue(amount);
}

Expand All @@ -171,26 +163,26 @@ contract SnapshotStakingPoolTest is Test {
}

function testRewardAt() public {
vm.expectRevert("ERC20Snapshot: id is 0");
vm.expectRevert(SnapshotStakingPool.InvalidSnapshotId.selector);
snapshotStakingPool.rewardAt(0);

vm.expectRevert("ERC20Snapshot: nonexistent id");
vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector);
snapshotStakingPool.rewardAt(1);

_stake(bob.addr, 1 ether);
_snapshot(1 ether);

assertEq(snapshotStakingPool.rewardAt(1), 1 ether);

vm.expectRevert("ERC20Snapshot: nonexistent id");
vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector);
snapshotStakingPool.rewardAt(2);
}

function testRewardOfAt() public {
vm.expectRevert("ERC20Snapshot: id is 0");
vm.expectRevert(SnapshotStakingPool.InvalidSnapshotId.selector);
snapshotStakingPool.rewardOfAt(bob.addr, 0);

vm.expectRevert("ERC20Snapshot: nonexistent id");
vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector);
snapshotStakingPool.rewardOfAt(bob.addr, 1);

_stake(bob.addr, 1 ether);
Expand All @@ -206,20 +198,20 @@ contract SnapshotStakingPoolTest is Test {
assertEq(snapshotStakingPool.rewardOfAt(bob.addr, 2), 0);
assertEq(snapshotStakingPool.rewardOfAt(alice.addr, 2), 1 ether);

vm.expectRevert("ERC20Snapshot: nonexistent id");
vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector);
snapshotStakingPool.rewardOfAt(bob.addr, 3);
}

function testRewardOfInRange() public {
vm.expectRevert("ERC20Snapshot: id is 0");
vm.expectRevert(SnapshotStakingPool.InvalidSnapshotId.selector);
snapshotStakingPool.rewardOfInRange(bob.addr, 0, 0);

vm.prank(bob.addr);
vm.expectRevert("ERC20Snapshot: id is 0");
vm.expectRevert(SnapshotStakingPool.InvalidSnapshotId.selector);
snapshotStakingPool.rewardOfInRange(bob.addr, 0, 1);

vm.prank(bob.addr);
vm.expectRevert("ERC20Snapshot: nonexistent id");
vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector);
snapshotStakingPool.rewardOfInRange(bob.addr, 1, 1);

_stake(bob.addr, 1 ether);
Expand All @@ -235,17 +227,17 @@ contract SnapshotStakingPoolTest is Test {
assertEq(snapshotStakingPool.rewardOfInRange(bob.addr, 1, 2), 2 ether);
assertEq(snapshotStakingPool.rewardOfInRange(alice.addr, 1, 3), 3 ether);

vm.expectRevert("ERC20Snapshot: nonexistent id");
vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector);
snapshotStakingPool.rewardOfInRange(bob.addr, 1, 4);
}

function testGetPendingRewards() public {
vm.expectRevert("ERC20Snapshot: id is 0");
vm.expectRevert(SnapshotStakingPool.InvalidSnapshotId.selector);
snapshotStakingPool.getPendingRewards(bob.addr);

_stake(bob.addr, 1 ether);
_stake(alice.addr, 1 ether);
vm.expectRevert("ERC20Snapshot: nonexistent id");
vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector);
snapshotStakingPool.getPendingRewards(bob.addr);

_snapshot(2 ether);
Expand All @@ -254,7 +246,7 @@ contract SnapshotStakingPoolTest is Test {

vm.prank(bob.addr);
snapshotStakingPool.claim();
vm.expectRevert("ERC20Snapshot: nonexistent id");
vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector);
assertEq(snapshotStakingPool.getPendingRewards(bob.addr), 0);

_snapshot(1 ether);
Expand Down Expand Up @@ -305,13 +297,13 @@ contract SnapshotStakingPoolTest is Test {

function testClaim() public {
vm.prank(bob.addr);
vm.expectRevert("ERC20Snapshot: id is 0");
vm.expectRevert(SnapshotStakingPool.InvalidSnapshotId.selector);
snapshotStakingPool.claim();

_stake(bob.addr, 1 ether);
_stake(alice.addr, 1 ether);
vm.prank(bob.addr);
vm.expectRevert("ERC20Snapshot: nonexistent id");
vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector);
snapshotStakingPool.claim();

_snapshot(2 ether);
Expand All @@ -333,10 +325,10 @@ contract SnapshotStakingPoolTest is Test {
assertEq(snapshotStakingPool.nextClaimId(alice.addr), 3);

vm.prank(bob.addr);
vm.expectRevert("ERC20Snapshot: nonexistent id");
vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector);
snapshotStakingPool.claim();
vm.prank(alice.addr);
vm.expectRevert("ERC20Snapshot: nonexistent id");
vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector);
snapshotStakingPool.claim();

_snapshot(1 ether);
Expand All @@ -359,15 +351,15 @@ contract SnapshotStakingPoolTest is Test {

function testClaimPartial() public {
vm.prank(bob.addr);
vm.expectRevert("ERC20Snapshot: id is 0");
vm.expectRevert(SnapshotStakingPool.InvalidSnapshotId.selector);
snapshotStakingPool.claimPartial(0, 0);

vm.prank(bob.addr);
vm.expectRevert("ERC20Snapshot: id is 0");
vm.expectRevert(SnapshotStakingPool.InvalidSnapshotId.selector);
snapshotStakingPool.claimPartial(0, 1);

vm.prank(bob.addr);
vm.expectRevert("ERC20Snapshot: nonexistent id");
vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector);
snapshotStakingPool.claimPartial(1, 1);

_stake(bob.addr, 1 ether);
Expand All @@ -381,7 +373,7 @@ contract SnapshotStakingPoolTest is Test {
assertEq(snapshotStakingPool.nextClaimId(bob.addr), 2);

vm.prank(bob.addr);
vm.expectRevert("Cannot claim from past snapshots");
vm.expectRevert(SnapshotStakingPool.CannotClaimFromPastSnapshots.selector);
snapshotStakingPool.claimPartial(1, 1);

_snapshot(2 ether);
Expand All @@ -400,11 +392,11 @@ contract SnapshotStakingPoolTest is Test {
assertEq(snapshotStakingPool.nextClaimId(alice.addr), 4);

vm.prank(alice.addr);
vm.expectRevert("Cannot claim from past snapshots");
vm.expectRevert(SnapshotStakingPool.CannotClaimFromPastSnapshots.selector);
snapshotStakingPool.claimPartial(1, 3);

vm.prank(alice.addr);
vm.expectRevert("ERC20Snapshot: nonexistent id");
vm.expectRevert(SnapshotStakingPool.NonExistentSnapshotId.selector);
snapshotStakingPool.claimPartial(4, 4);

_unstake(bob.addr, 1 ether);
Expand All @@ -422,7 +414,7 @@ contract SnapshotStakingPoolTest is Test {
_stake(bob.addr, 1 ether);

vm.prank(bob.addr);
vm.expectRevert("Transfers not allowed");
vm.expectRevert(SnapshotStakingPool.TransfersNotAllowed.selector);
snapshotStakingPool.transfer(alice.addr, 1 ether);
}

Expand All @@ -433,7 +425,7 @@ contract SnapshotStakingPoolTest is Test {
snapshotStakingPool.approve(alice.addr, 1 ether);

vm.prank(alice.addr);
vm.expectRevert("Transfers not allowed");
vm.expectRevert(SnapshotStakingPool.TransfersNotAllowed.selector);
snapshotStakingPool.transferFrom(bob.addr, alice.addr, 1 ether);
}

Expand Down

0 comments on commit 0697f39

Please sign in to comment.