Skip to content

Commit

Permalink
feat: Add asset value roundup for withdrawals, update invariants (#36)
Browse files Browse the repository at this point in the history
* fix: update tests to work

* feat: add more exchange rate assertions

* feat: add test coverage for new rounding functions

* fix: add comments

* fix: remove stack ver to fix coverage

* fix: add requires to constructor
  • Loading branch information
lucas-manuel authored Aug 28, 2024
1 parent 94e48b2 commit 6324e26
Show file tree
Hide file tree
Showing 9 changed files with 213 additions and 92 deletions.
2 changes: 1 addition & 1 deletion foundry.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ runs = 1000
[invariant]
runs = 20
depth = 1000
shrink_run_limit = 1000
shrink_run_limit = 100
fail_on_revert = true

[profile.pr.invariant]
Expand Down
26 changes: 18 additions & 8 deletions src/PSM3.sol
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ contract PSM3 is IPSM3 {
_asset0Precision = 10 ** IERC20(asset0_).decimals();
_asset1Precision = 10 ** IERC20(asset1_).decimals();
_asset2Precision = 10 ** IERC20(asset2_).decimals();

// Necessary to ensure rounding works as expected
require(_asset0Precision <= 1e18, "PSM3/asset0-precision-too-high");
require(_asset1Precision <= 1e18, "PSM3/asset1-precision-too-high");
}

/**********************************************************************************************/
Expand Down Expand Up @@ -155,7 +159,7 @@ contract PSM3 is IPSM3 {
require(_isValidAsset(asset), "PSM3/invalid-asset");

// Convert amount to 1e18 precision denominated in value of asset0 then convert to shares.
return convertToShares(_getAssetValue(asset, assetsToDeposit));
return convertToShares(_getAssetValue(asset, assetsToDeposit, false)); // Round down
}

function previewWithdraw(address asset, uint256 maxAssetsToWithdraw)
Expand All @@ -169,7 +173,8 @@ contract PSM3 is IPSM3 {
? assetBalance
: maxAssetsToWithdraw;

sharesToBurn = _convertToSharesRoundUp(_getAssetValue(asset, assetsWithdrawn));
// Get shares to burn, rounding up for both calculations
sharesToBurn = _convertToSharesRoundUp(_getAssetValue(asset, assetsWithdrawn, true));

uint256 userShares = shares[msg.sender];

Expand Down Expand Up @@ -237,7 +242,7 @@ contract PSM3 is IPSM3 {

function convertToShares(address asset, uint256 assets) public view override returns (uint256) {
require(_isValidAsset(asset), "PSM3/invalid-asset");
return convertToShares(_getAssetValue(asset, assets));
return convertToShares(_getAssetValue(asset, assets, false)); // Round down
}

/**********************************************************************************************/
Expand All @@ -247,17 +252,17 @@ contract PSM3 is IPSM3 {
function totalAssets() public view override returns (uint256) {
return _getAsset0Value(asset0.balanceOf(address(this)))
+ _getAsset1Value(asset1.balanceOf(address(this)))
+ _getAsset2Value(asset2.balanceOf(address(this)));
+ _getAsset2Value(asset2.balanceOf(address(this)), false); // Round down
}

/**********************************************************************************************/
/*** Internal valuation functions (deposit/withdraw) ***/
/**********************************************************************************************/

function _getAssetValue(address asset, uint256 amount) internal view returns (uint256) {
function _getAssetValue(address asset, uint256 amount, bool roundUp) internal view returns (uint256) {
if (asset == address(asset0)) return _getAsset0Value(amount);
else if (asset == address(asset1)) return _getAsset1Value(amount);
else if (asset == address(asset2)) return _getAsset2Value(amount);
else if (asset == address(asset2)) return _getAsset2Value(amount, roundUp);
else revert("PSM3/invalid-asset");
}

Expand All @@ -269,12 +274,17 @@ contract PSM3 is IPSM3 {
return amount * 1e18 / _asset1Precision;
}

function _getAsset2Value(uint256 amount) internal view returns (uint256) {
function _getAsset2Value(uint256 amount, bool roundUp) internal view returns (uint256) {
// NOTE: Multiplying by 1e18 and dividing by 1e27 cancels to 1e9 in denominator
return amount
if (!roundUp) return amount
* IRateProviderLike(rateProvider).getConversionRate()
/ 1e9
/ _asset2Precision;

return Math.ceilDiv(
Math.ceilDiv(amount * IRateProviderLike(rateProvider).getConversionRate(), 1e9),
_asset2Precision
);
}

/**********************************************************************************************/
Expand Down
4 changes: 2 additions & 2 deletions test/invariant/Invariants.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,9 @@ abstract contract PSMInvariantTestBase is PSMTestBase {
6
);

// All funds can always be withdrawn completely.
// All funds can always be withdrawn completely (rounding in withdrawal against users).
assertEq(psm.totalShares(), 0);
assertEq(psm.totalAssets(), 0);
assertLe(psm.totalAssets(), 5);
}

function _warpAndAssertConsistentValueAccrual() public {
Expand Down
26 changes: 21 additions & 5 deletions test/invariant/handlers/LpHandler.sol
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ contract LpHandler is HandlerBase {
amount = _bound(amount, 1, 1e12 * 10 ** asset.decimals());

// 2. Cache starting state
uint256 startingConversion = psm.convertToShares(1e18);
uint256 startingConversion = psm.convertToAssetValue(1e18);
uint256 startingValue = psm.totalAssets();

// 3. Perform action against protocol
Expand All @@ -63,13 +63,22 @@ contract LpHandler is HandlerBase {
lpDeposits[lp][address(asset)] += amount;

// 5. Perform action-specific assertions

// Larger tolerance for rounding errors because of asset valuation changing
assertApproxEqAbs(
psm.convertToShares(1e18),
psm.convertToAssetValue(1e18),
startingConversion,
2,
1e12,
"LpHandler/deposit/conversion-rate-change"
);

// Exchange rate always increases, never decreases from rounding
assertGe(
psm.convertToAssetValue(1e18),
startingConversion,
"LpHandler/deposit/conversion-rate-decrease"
);

assertGe(
psm.totalAssets() + 1,
startingValue,
Expand All @@ -88,7 +97,7 @@ contract LpHandler is HandlerBase {
amount = _bound(amount, 1, 1e12 * 10 ** asset.decimals());

// 2. Cache starting state
uint256 startingConversion = psm.convertToShares(1e18);
uint256 startingConversion = psm.convertToAssetValue(1e18);
uint256 startingValue = psm.totalAssets();

// 3. Perform action against protocol
Expand All @@ -103,12 +112,19 @@ contract LpHandler is HandlerBase {

// Larger tolerance for rounding errors because of burning more shares on USDC withdraw
assertApproxEqAbs(
psm.convertToShares(1e18),
psm.convertToAssetValue(1e18),
startingConversion,
1e12,
"LpHandler/withdraw/conversion-rate-change"
);

// Exchange rate always increases, never decreases from rounding
assertGe(
psm.convertToAssetValue(1e18),
startingConversion,
"LpHandler/withdraw/conversion-rate-decrease"
);

assertLe(
psm.totalAssets(),
startingValue + 1,
Expand Down
24 changes: 24 additions & 0 deletions test/unit/Constructor.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ pragma solidity ^0.8.13;

import "forge-std/Test.sol";

import { MockERC20 } from "erc20-helpers/MockERC20.sol";

import { PSM3 } from "src/PSM3.sol";

import { PSMTestBase } from "test/PSMTestBase.sol";
Expand Down Expand Up @@ -52,6 +54,28 @@ contract PSMConstructorTests is PSMTestBase {
new PSM3(address(dai), address(usdc), address(sDai), address(rateProvider));
}

function test_constructor_asset0DecimalsToHighBoundary() public {
MockERC20 asset0 = new MockERC20("Asset0", "A0", 19);

vm.expectRevert("PSM3/asset0-precision-too-high");
new PSM3(address(asset0), address(usdc), address(sDai), address(rateProvider));

asset0 = new MockERC20("Asset0", "A0", 18);

new PSM3(address(asset0), address(usdc), address(sDai), address(rateProvider));
}

function test_constructor_asset1DecimalsToHighBoundary() public {
MockERC20 asset1 = new MockERC20("Asset1", "A1", 19);

vm.expectRevert("PSM3/asset1-precision-too-high");
new PSM3(address(dai), address(asset1), address(sDai), address(rateProvider));

asset1 = new MockERC20("Asset1", "A1", 18);

new PSM3(address(dai), address(asset1), address(sDai), address(rateProvider));
}

function test_constructor() public {
// Deploy new PSM to get test coverage
psm = new PSM3(address(dai), address(usdc), address(sDai), address(rateProvider));
Expand Down
Loading

0 comments on commit 6324e26

Please sign in to comment.