diff --git a/foundry.toml b/foundry.toml index 2828806..d0cff7b 100644 --- a/foundry.toml +++ b/foundry.toml @@ -12,7 +12,7 @@ runs = 1000 [invariant] runs = 20 depth = 1000 -shrink_run_limit = 1000 +shrink_run_limit = 100 fail_on_revert = true [profile.pr.invariant] diff --git a/src/PSM3.sol b/src/PSM3.sol index 6107799..b8d395c 100644 --- a/src/PSM3.sol +++ b/src/PSM3.sol @@ -54,6 +54,10 @@ contract PSM3 is IPSM3 { _asset0Precision = 10 ** IERC20(asset0_).decimals(); _asset1Precision = 10 ** IERC20(asset1_).decimals(); _asset2Precision = 10 ** IERC20(asset2_).decimals(); + + // Necessary to ensure rounding works as expected + require(_asset0Precision <= 1e18, "PSM3/asset0-precision-too-high"); + require(_asset1Precision <= 1e18, "PSM3/asset1-precision-too-high"); } /**********************************************************************************************/ @@ -155,7 +159,7 @@ contract PSM3 is IPSM3 { require(_isValidAsset(asset), "PSM3/invalid-asset"); // Convert amount to 1e18 precision denominated in value of asset0 then convert to shares. - return convertToShares(_getAssetValue(asset, assetsToDeposit)); + return convertToShares(_getAssetValue(asset, assetsToDeposit, false)); // Round down } function previewWithdraw(address asset, uint256 maxAssetsToWithdraw) @@ -169,7 +173,8 @@ contract PSM3 is IPSM3 { ? assetBalance : maxAssetsToWithdraw; - sharesToBurn = _convertToSharesRoundUp(_getAssetValue(asset, assetsWithdrawn)); + // Get shares to burn, rounding up for both calculations + sharesToBurn = _convertToSharesRoundUp(_getAssetValue(asset, assetsWithdrawn, true)); uint256 userShares = shares[msg.sender]; @@ -237,7 +242,7 @@ contract PSM3 is IPSM3 { function convertToShares(address asset, uint256 assets) public view override returns (uint256) { require(_isValidAsset(asset), "PSM3/invalid-asset"); - return convertToShares(_getAssetValue(asset, assets)); + return convertToShares(_getAssetValue(asset, assets, false)); // Round down } /**********************************************************************************************/ @@ -247,17 +252,17 @@ contract PSM3 is IPSM3 { function totalAssets() public view override returns (uint256) { return _getAsset0Value(asset0.balanceOf(address(this))) + _getAsset1Value(asset1.balanceOf(address(this))) - + _getAsset2Value(asset2.balanceOf(address(this))); + + _getAsset2Value(asset2.balanceOf(address(this)), false); // Round down } /**********************************************************************************************/ /*** Internal valuation functions (deposit/withdraw) ***/ /**********************************************************************************************/ - function _getAssetValue(address asset, uint256 amount) internal view returns (uint256) { + function _getAssetValue(address asset, uint256 amount, bool roundUp) internal view returns (uint256) { if (asset == address(asset0)) return _getAsset0Value(amount); else if (asset == address(asset1)) return _getAsset1Value(amount); - else if (asset == address(asset2)) return _getAsset2Value(amount); + else if (asset == address(asset2)) return _getAsset2Value(amount, roundUp); else revert("PSM3/invalid-asset"); } @@ -269,12 +274,17 @@ contract PSM3 is IPSM3 { return amount * 1e18 / _asset1Precision; } - function _getAsset2Value(uint256 amount) internal view returns (uint256) { + function _getAsset2Value(uint256 amount, bool roundUp) internal view returns (uint256) { // NOTE: Multiplying by 1e18 and dividing by 1e27 cancels to 1e9 in denominator - return amount + if (!roundUp) return amount * IRateProviderLike(rateProvider).getConversionRate() / 1e9 / _asset2Precision; + + return Math.ceilDiv( + Math.ceilDiv(amount * IRateProviderLike(rateProvider).getConversionRate(), 1e9), + _asset2Precision + ); } /**********************************************************************************************/ diff --git a/test/invariant/Invariants.t.sol b/test/invariant/Invariants.t.sol index 47a2be1..30a9041 100644 --- a/test/invariant/Invariants.t.sol +++ b/test/invariant/Invariants.t.sol @@ -310,9 +310,9 @@ abstract contract PSMInvariantTestBase is PSMTestBase { 6 ); - // All funds can always be withdrawn completely. + // All funds can always be withdrawn completely (rounding in withdrawal against users). assertEq(psm.totalShares(), 0); - assertEq(psm.totalAssets(), 0); + assertLe(psm.totalAssets(), 5); } function _warpAndAssertConsistentValueAccrual() public { diff --git a/test/invariant/handlers/LpHandler.sol b/test/invariant/handlers/LpHandler.sol index 4c47205..ac92662 100644 --- a/test/invariant/handlers/LpHandler.sol +++ b/test/invariant/handlers/LpHandler.sol @@ -49,7 +49,7 @@ contract LpHandler is HandlerBase { amount = _bound(amount, 1, 1e12 * 10 ** asset.decimals()); // 2. Cache starting state - uint256 startingConversion = psm.convertToShares(1e18); + uint256 startingConversion = psm.convertToAssetValue(1e18); uint256 startingValue = psm.totalAssets(); // 3. Perform action against protocol @@ -63,13 +63,22 @@ contract LpHandler is HandlerBase { lpDeposits[lp][address(asset)] += amount; // 5. Perform action-specific assertions + + // Larger tolerance for rounding errors because of asset valuation changing assertApproxEqAbs( - psm.convertToShares(1e18), + psm.convertToAssetValue(1e18), startingConversion, - 2, + 1e12, "LpHandler/deposit/conversion-rate-change" ); + // Exchange rate always increases, never decreases from rounding + assertGe( + psm.convertToAssetValue(1e18), + startingConversion, + "LpHandler/deposit/conversion-rate-decrease" + ); + assertGe( psm.totalAssets() + 1, startingValue, @@ -88,7 +97,7 @@ contract LpHandler is HandlerBase { amount = _bound(amount, 1, 1e12 * 10 ** asset.decimals()); // 2. Cache starting state - uint256 startingConversion = psm.convertToShares(1e18); + uint256 startingConversion = psm.convertToAssetValue(1e18); uint256 startingValue = psm.totalAssets(); // 3. Perform action against protocol @@ -103,12 +112,19 @@ contract LpHandler is HandlerBase { // Larger tolerance for rounding errors because of burning more shares on USDC withdraw assertApproxEqAbs( - psm.convertToShares(1e18), + psm.convertToAssetValue(1e18), startingConversion, 1e12, "LpHandler/withdraw/conversion-rate-change" ); + // Exchange rate always increases, never decreases from rounding + assertGe( + psm.convertToAssetValue(1e18), + startingConversion, + "LpHandler/withdraw/conversion-rate-decrease" + ); + assertLe( psm.totalAssets(), startingValue + 1, diff --git a/test/unit/Constructor.t.sol b/test/unit/Constructor.t.sol index e6d31cf..b99cdb7 100644 --- a/test/unit/Constructor.t.sol +++ b/test/unit/Constructor.t.sol @@ -3,6 +3,8 @@ pragma solidity ^0.8.13; import "forge-std/Test.sol"; +import { MockERC20 } from "erc20-helpers/MockERC20.sol"; + import { PSM3 } from "src/PSM3.sol"; import { PSMTestBase } from "test/PSMTestBase.sol"; @@ -52,6 +54,28 @@ contract PSMConstructorTests is PSMTestBase { new PSM3(address(dai), address(usdc), address(sDai), address(rateProvider)); } + function test_constructor_asset0DecimalsToHighBoundary() public { + MockERC20 asset0 = new MockERC20("Asset0", "A0", 19); + + vm.expectRevert("PSM3/asset0-precision-too-high"); + new PSM3(address(asset0), address(usdc), address(sDai), address(rateProvider)); + + asset0 = new MockERC20("Asset0", "A0", 18); + + new PSM3(address(asset0), address(usdc), address(sDai), address(rateProvider)); + } + + function test_constructor_asset1DecimalsToHighBoundary() public { + MockERC20 asset1 = new MockERC20("Asset1", "A1", 19); + + vm.expectRevert("PSM3/asset1-precision-too-high"); + new PSM3(address(dai), address(asset1), address(sDai), address(rateProvider)); + + asset1 = new MockERC20("Asset1", "A1", 18); + + new PSM3(address(dai), address(asset1), address(sDai), address(rateProvider)); + } + function test_constructor() public { // Deploy new PSM to get test coverage psm = new PSM3(address(dai), address(usdc), address(sDai), address(rateProvider)); diff --git a/test/unit/Getters.t.sol b/test/unit/Getters.t.sol index f64727e..06502c6 100644 --- a/test/unit/Getters.t.sol +++ b/test/unit/Getters.t.sol @@ -62,87 +62,153 @@ contract PSMHarnessTests is PSMTestBase { } function test_getAsset2Value() public { - assertEq(psmHarness.getAsset2Value(1), 1); - assertEq(psmHarness.getAsset2Value(2), 2); - assertEq(psmHarness.getAsset2Value(3), 3); - assertEq(psmHarness.getAsset2Value(4), 5); - - assertEq(psmHarness.getAsset2Value(1e18), 1.25e18); - assertEq(psmHarness.getAsset2Value(2e18), 2.5e18); - assertEq(psmHarness.getAsset2Value(3e18), 3.75e18); - assertEq(psmHarness.getAsset2Value(4e18), 5e18); + assertEq(psmHarness.getAsset2Value(1, false), 1); + assertEq(psmHarness.getAsset2Value(2, false), 2); + assertEq(psmHarness.getAsset2Value(3, false), 3); + assertEq(psmHarness.getAsset2Value(4, false), 5); + + // Rounding up + assertEq(psmHarness.getAsset2Value(1, true), 2); + assertEq(psmHarness.getAsset2Value(2, true), 3); + assertEq(psmHarness.getAsset2Value(3, true), 4); + assertEq(psmHarness.getAsset2Value(4, true), 5); + + assertEq(psmHarness.getAsset2Value(1e18, false), 1.25e18); + assertEq(psmHarness.getAsset2Value(2e18, false), 2.5e18); + assertEq(psmHarness.getAsset2Value(3e18, false), 3.75e18); + assertEq(psmHarness.getAsset2Value(4e18, false), 5e18); + + // No rounding but shows why rounding occurred at lower values + assertEq(psmHarness.getAsset2Value(1e18, true), 1.25e18); + assertEq(psmHarness.getAsset2Value(2e18, true), 2.5e18); + assertEq(psmHarness.getAsset2Value(3e18, true), 3.75e18); + assertEq(psmHarness.getAsset2Value(4e18, true), 5e18); mockRateProvider.__setConversionRate(1.6e27); - assertEq(psmHarness.getAsset2Value(1), 1); - assertEq(psmHarness.getAsset2Value(2), 3); - assertEq(psmHarness.getAsset2Value(3), 4); - assertEq(psmHarness.getAsset2Value(4), 6); + assertEq(psmHarness.getAsset2Value(1, false), 1); + assertEq(psmHarness.getAsset2Value(2, false), 3); + assertEq(psmHarness.getAsset2Value(3, false), 4); + assertEq(psmHarness.getAsset2Value(4, false), 6); - assertEq(psmHarness.getAsset2Value(1e18), 1.6e18); - assertEq(psmHarness.getAsset2Value(2e18), 3.2e18); - assertEq(psmHarness.getAsset2Value(3e18), 4.8e18); - assertEq(psmHarness.getAsset2Value(4e18), 6.4e18); + // Rounding up + assertEq(psmHarness.getAsset2Value(1, true), 2); + assertEq(psmHarness.getAsset2Value(2, true), 4); + assertEq(psmHarness.getAsset2Value(3, true), 5); + assertEq(psmHarness.getAsset2Value(4, true), 7); - mockRateProvider.__setConversionRate(0.8e27); + assertEq(psmHarness.getAsset2Value(1e18, false), 1.6e18); + assertEq(psmHarness.getAsset2Value(2e18, false), 3.2e18); + assertEq(psmHarness.getAsset2Value(3e18, false), 4.8e18); + assertEq(psmHarness.getAsset2Value(4e18, false), 6.4e18); + + // No rounding but shows why rounding occurred at lower values + assertEq(psmHarness.getAsset2Value(1e18, true), 1.6e18); + assertEq(psmHarness.getAsset2Value(2e18, true), 3.2e18); + assertEq(psmHarness.getAsset2Value(3e18, true), 4.8e18); + assertEq(psmHarness.getAsset2Value(4e18, true), 6.4e18); - assertEq(psmHarness.getAsset2Value(1), 0); - assertEq(psmHarness.getAsset2Value(2), 1); - assertEq(psmHarness.getAsset2Value(3), 2); - assertEq(psmHarness.getAsset2Value(4), 3); + mockRateProvider.__setConversionRate(0.8e27); - assertEq(psmHarness.getAsset2Value(1e18), 0.8e18); - assertEq(psmHarness.getAsset2Value(2e18), 1.6e18); - assertEq(psmHarness.getAsset2Value(3e18), 2.4e18); - assertEq(psmHarness.getAsset2Value(4e18), 3.2e18); + assertEq(psmHarness.getAsset2Value(1, false), 0); + assertEq(psmHarness.getAsset2Value(2, false), 1); + assertEq(psmHarness.getAsset2Value(3, false), 2); + assertEq(psmHarness.getAsset2Value(4, false), 3); + + // Rounding up + assertEq(psmHarness.getAsset2Value(1, true), 1); + assertEq(psmHarness.getAsset2Value(2, true), 2); + assertEq(psmHarness.getAsset2Value(3, true), 3); + assertEq(psmHarness.getAsset2Value(4, true), 4); + + assertEq(psmHarness.getAsset2Value(1e18, false), 0.8e18); + assertEq(psmHarness.getAsset2Value(2e18, false), 1.6e18); + assertEq(psmHarness.getAsset2Value(3e18, false), 2.4e18); + assertEq(psmHarness.getAsset2Value(4e18, false), 3.2e18); + + // No rounding but shows why rounding occurred at lower values + assertEq(psmHarness.getAsset2Value(1e18, true), 0.8e18); + assertEq(psmHarness.getAsset2Value(2e18, true), 1.6e18); + assertEq(psmHarness.getAsset2Value(3e18, true), 2.4e18); + assertEq(psmHarness.getAsset2Value(4e18, true), 3.2e18); } - function testFuzz_getAsset2Value(uint256 conversionRate, uint256 amount) public { + function testFuzz_getAsset2Value_roundDown(uint256 conversionRate, uint256 amount) public { conversionRate = _bound(conversionRate, 0, 1000e27); amount = _bound(amount, 0, SDAI_TOKEN_MAX); mockRateProvider.__setConversionRate(conversionRate); - assertEq(psmHarness.getAsset2Value(amount), amount * conversionRate / 1e27); + assertEq(psmHarness.getAsset2Value(amount, false), amount * conversionRate / 1e27); } function test_getAssetValue() public view { - assertEq(psmHarness.getAssetValue(address(dai), 1), psmHarness.getAsset0Value(1)); - assertEq(psmHarness.getAssetValue(address(dai), 2), psmHarness.getAsset0Value(2)); - assertEq(psmHarness.getAssetValue(address(dai), 3), psmHarness.getAsset0Value(3)); + assertEq(psmHarness.getAssetValue(address(dai), 1, false), psmHarness.getAsset0Value(1)); + assertEq(psmHarness.getAssetValue(address(dai), 2, false), psmHarness.getAsset0Value(2)); + assertEq(psmHarness.getAssetValue(address(dai), 3, false), psmHarness.getAsset0Value(3)); + + assertEq(psmHarness.getAssetValue(address(dai), 1, true), psmHarness.getAsset0Value(1)); + assertEq(psmHarness.getAssetValue(address(dai), 2, true), psmHarness.getAsset0Value(2)); + assertEq(psmHarness.getAssetValue(address(dai), 3, true), psmHarness.getAsset0Value(3)); + + assertEq(psmHarness.getAssetValue(address(dai), 1e18, false), psmHarness.getAsset0Value(1e18)); + assertEq(psmHarness.getAssetValue(address(dai), 2e18, false), psmHarness.getAsset0Value(2e18)); + assertEq(psmHarness.getAssetValue(address(dai), 3e18, false), psmHarness.getAsset0Value(3e18)); - assertEq(psmHarness.getAssetValue(address(dai), 1e18), psmHarness.getAsset0Value(1e18)); - assertEq(psmHarness.getAssetValue(address(dai), 2e18), psmHarness.getAsset0Value(2e18)); - assertEq(psmHarness.getAssetValue(address(dai), 3e18), psmHarness.getAsset0Value(3e18)); + assertEq(psmHarness.getAssetValue(address(dai), 1e18, true), psmHarness.getAsset0Value(1e18)); + assertEq(psmHarness.getAssetValue(address(dai), 2e18, true), psmHarness.getAsset0Value(2e18)); + assertEq(psmHarness.getAssetValue(address(dai), 3e18, true), psmHarness.getAsset0Value(3e18)); - assertEq(psmHarness.getAssetValue(address(usdc), 1), psmHarness.getAsset1Value(1)); - assertEq(psmHarness.getAssetValue(address(usdc), 2), psmHarness.getAsset1Value(2)); - assertEq(psmHarness.getAssetValue(address(usdc), 3), psmHarness.getAsset1Value(3)); + assertEq(psmHarness.getAssetValue(address(usdc), 1, false), psmHarness.getAsset1Value(1)); + assertEq(psmHarness.getAssetValue(address(usdc), 2, false), psmHarness.getAsset1Value(2)); + assertEq(psmHarness.getAssetValue(address(usdc), 3, false), psmHarness.getAsset1Value(3)); - assertEq(psmHarness.getAssetValue(address(usdc), 1e6), psmHarness.getAsset1Value(1e6)); - assertEq(psmHarness.getAssetValue(address(usdc), 2e6), psmHarness.getAsset1Value(2e6)); - assertEq(psmHarness.getAssetValue(address(usdc), 3e6), psmHarness.getAsset1Value(3e6)); + assertEq(psmHarness.getAssetValue(address(usdc), 1, true), psmHarness.getAsset1Value(1)); + assertEq(psmHarness.getAssetValue(address(usdc), 2, true), psmHarness.getAsset1Value(2)); + assertEq(psmHarness.getAssetValue(address(usdc), 3, true), psmHarness.getAsset1Value(3)); - assertEq(psmHarness.getAssetValue(address(sDai), 1), psmHarness.getAsset2Value(1)); - assertEq(psmHarness.getAssetValue(address(sDai), 2), psmHarness.getAsset2Value(2)); - assertEq(psmHarness.getAssetValue(address(sDai), 3), psmHarness.getAsset2Value(3)); + assertEq(psmHarness.getAssetValue(address(usdc), 1e6, false), psmHarness.getAsset1Value(1e6)); + assertEq(psmHarness.getAssetValue(address(usdc), 2e6, false), psmHarness.getAsset1Value(2e6)); + assertEq(psmHarness.getAssetValue(address(usdc), 3e6, false), psmHarness.getAsset1Value(3e6)); - assertEq(psmHarness.getAssetValue(address(sDai), 1e18), psmHarness.getAsset2Value(1e18)); - assertEq(psmHarness.getAssetValue(address(sDai), 2e18), psmHarness.getAsset2Value(2e18)); - assertEq(psmHarness.getAssetValue(address(sDai), 3e18), psmHarness.getAsset2Value(3e18)); + assertEq(psmHarness.getAssetValue(address(usdc), 1e6, true), psmHarness.getAsset1Value(1e6)); + assertEq(psmHarness.getAssetValue(address(usdc), 2e6, true), psmHarness.getAsset1Value(2e6)); + assertEq(psmHarness.getAssetValue(address(usdc), 3e6, true), psmHarness.getAsset1Value(3e6)); + + assertEq(psmHarness.getAssetValue(address(sDai), 1, false), psmHarness.getAsset2Value(1, false)); + assertEq(psmHarness.getAssetValue(address(sDai), 2, false), psmHarness.getAsset2Value(2, false)); + assertEq(psmHarness.getAssetValue(address(sDai), 3, false), psmHarness.getAsset2Value(3, false)); + + assertEq(psmHarness.getAssetValue(address(sDai), 1e18, false), psmHarness.getAsset2Value(1e18, false)); + assertEq(psmHarness.getAssetValue(address(sDai), 2e18, false), psmHarness.getAsset2Value(2e18, false)); + assertEq(psmHarness.getAssetValue(address(sDai), 3e18, false), psmHarness.getAsset2Value(3e18, false)); + + assertEq(psmHarness.getAssetValue(address(sDai), 1, true), psmHarness.getAsset2Value(1, true)); + assertEq(psmHarness.getAssetValue(address(sDai), 2, true), psmHarness.getAsset2Value(2, true)); + assertEq(psmHarness.getAssetValue(address(sDai), 3, true), psmHarness.getAsset2Value(3, true)); + + assertEq(psmHarness.getAssetValue(address(sDai), 1e18, true), psmHarness.getAsset2Value(1e18, true)); + assertEq(psmHarness.getAssetValue(address(sDai), 2e18, true), psmHarness.getAsset2Value(2e18, true)); + assertEq(psmHarness.getAssetValue(address(sDai), 3e18, true), psmHarness.getAsset2Value(3e18, true)); } function testFuzz_getAssetValue(uint256 amount) public view { amount = _bound(amount, 0, SDAI_TOKEN_MAX); - assertEq(psmHarness.getAssetValue(address(dai), amount), psmHarness.getAsset0Value(amount)); - assertEq(psmHarness.getAssetValue(address(usdc), amount), psmHarness.getAsset1Value(amount)); - assertEq(psmHarness.getAssetValue(address(sDai), amount), psmHarness.getAsset2Value(amount)); + // `asset0` and `asset1` return the same values whether `roundUp` is true or false + assertEq(psmHarness.getAssetValue(address(dai), amount, false), psmHarness.getAsset0Value(amount)); + assertEq(psmHarness.getAssetValue(address(dai), amount, false), psmHarness.getAsset0Value(amount)); + assertEq(psmHarness.getAssetValue(address(usdc), amount, true), psmHarness.getAsset1Value(amount)); + assertEq(psmHarness.getAssetValue(address(usdc), amount, true), psmHarness.getAsset1Value(amount)); + + // `asset2` returns different values depending on the value of `roundUp`, but always same as underlying function + assertEq(psmHarness.getAssetValue(address(sDai), amount, false), psmHarness.getAsset2Value(amount, false)); + assertEq(psmHarness.getAssetValue(address(sDai), amount, true), psmHarness.getAsset2Value(amount, true)); } function test_getAssetValue_zeroAddress() public { vm.expectRevert("PSM3/invalid-asset"); - psmHarness.getAssetValue(address(0), 1); + psmHarness.getAssetValue(address(0), 1, false); } } diff --git a/test/unit/PreviewWIthdraw.t.sol b/test/unit/PreviewWIthdraw.t.sol index 64777da..5bbfcc1 100644 --- a/test/unit/PreviewWIthdraw.t.sol +++ b/test/unit/PreviewWIthdraw.t.sol @@ -94,7 +94,7 @@ contract PSMPreviewWithdraw_SuccessTests is PSMTestBase { function test_previewWithdraw_sdai_amountLtUnderlyingBalanceAndLtPsmBalance() public view { ( uint256 shares, uint256 assets ) = psm.previewWithdraw(address(sDai), 1e18 - 1); - assertEq(shares, 1.25e18 - 2); + assertEq(shares, 1.25e18 - 1); assertEq(assets, 1e18 - 1); } @@ -146,9 +146,10 @@ contract PSMPreviewWithdraw_SuccessFuzzTests is PSMTestBase { uint256 totalSharesMinted = params.amount1 + params.amount2 * 1e12 + params.amount3 * 1.25e27 / 1e27; uint256 totalValue = totalSharesMinted; - assertEq(shares1, params.previewAmount1 * totalSharesMinted / totalValue); - assertEq(shares2, params.previewAmount2 * 1e12 * totalSharesMinted / totalValue); - assertEq(shares3, params.previewAmount3 * 1.25e27 / 1e27 * totalSharesMinted / totalValue); + // Assert shares are always rounded up, max of 1 wei difference except for sUSDS + assertLe(shares1 - (params.previewAmount1 * totalSharesMinted / totalValue), 1); + assertLe(shares2 - (params.previewAmount2 * 1e12 * totalSharesMinted / totalValue), 1); + assertLe(shares3 - (params.previewAmount3 * 1.25e27 / 1e27 * totalSharesMinted / totalValue), 3); assertEq(assets1, params.previewAmount1); assertEq(assets2, params.previewAmount2); @@ -166,14 +167,11 @@ contract PSMPreviewWithdraw_SuccessFuzzTests is PSMTestBase { uint256 sDaiConvertedAmount = params.previewAmount3 * params.conversionRate / 1e27; - assertApproxEqAbs(shares1, params.previewAmount1 * totalSharesMinted / totalValue, 1); - assertApproxEqAbs(shares2, params.previewAmount2 * 1e12 * totalSharesMinted / totalValue, 1); - assertApproxEqAbs(shares3, sDaiConvertedAmount * totalSharesMinted / totalValue, 1); - - // Assert shares are always rounded up - assertGe(shares1, params.previewAmount1 * totalSharesMinted / totalValue); - assertGe(shares2, params.previewAmount2 * 1e12 * totalSharesMinted / totalValue); - assertGe(shares3, sDaiConvertedAmount * totalSharesMinted / totalValue); + // Assert shares are always rounded up, max of 1 wei difference except for sUSDS + // totalSharesMinted / totalValue is an integer amount that scales as the rate scales by orders of magnitude + assertLe(shares1 - (params.previewAmount1 * totalSharesMinted / totalValue), 1); + assertLe(shares2 - (params.previewAmount2 * 1e12 * totalSharesMinted / totalValue), 1); + assertLe(shares3 - (sDaiConvertedAmount * totalSharesMinted / totalValue), 3 + totalSharesMinted / totalValue); assertApproxEqAbs(assets1, params.previewAmount1, 1); assertApproxEqAbs(assets2, params.previewAmount2, 1); diff --git a/test/unit/Withdraw.t.sol b/test/unit/Withdraw.t.sol index 4274cc7..d014bec 100644 --- a/test/unit/Withdraw.t.sol +++ b/test/unit/Withdraw.t.sol @@ -487,11 +487,11 @@ contract PSMWithdrawTests is PSMTestBase { vm.prank(user2); amount = psm.withdraw(address(sDai), user2, type(uint256).max); - assertEq(amount, 100e18 - user1SDai); // Remaining funds in PSM + assertEq(amount, 100e18 - user1SDai - 1); // Remaining funds in PSM (rounding) assertEq(sDai.balanceOf(user1), user1SDai); - assertEq(sDai.balanceOf(user2), 100e18 - user1SDai); - assertEq(sDai.balanceOf(address(psm)), 0); + assertEq(sDai.balanceOf(user2), 100e18 - user1SDai - 1); // Rounding + assertEq(sDai.balanceOf(address(psm)), 1); // Rounding assertEq(psm.totalShares(), 0); assertEq(psm.shares(user1), 0); @@ -501,9 +501,9 @@ contract PSMWithdrawTests is PSMTestBase { uint256 user2ResultingValue = sDai.balanceOf(user2) * 150/100; // Use 1.5 conversion rate assertEq(user1ResultingValue, 111.111111111111111110e18); - assertEq(user2ResultingValue, 138.888888888888888889e18); + assertEq(user2ResultingValue, 138.888888888888888888e18); - assertEq(user1ResultingValue + user2ResultingValue, 249.999999999999999999e18); + assertEq(user1ResultingValue + user2ResultingValue, 249.999999999999999998e18); // Value gains are the same for both users assertEq((user1ResultingValue - 100e18) * 1e18 / 100e18, 0.111111111111111111e18); @@ -518,9 +518,12 @@ contract PSMWithdrawTests is PSMTestBase { public { // Use higher lower bounds to get returns at the end to be more accurate + // Always increase exchange rate so accrual of value can be checked. + // Since rounding is against user if it stays the same the value can decrease and + // the check will underflow usdcAmount = _bound(usdcAmount, 1e6, USDC_TOKEN_MAX); sDaiAmount = _bound(sDaiAmount, 1e18, SDAI_TOKEN_MAX); - conversionRate = _bound(conversionRate, 1.25e27, 1000e27); + conversionRate = _bound(conversionRate, 1.26e27, 1000e27); _deposit(address(usdc), user1, usdcAmount); _deposit(address(sDai), user2, sDaiAmount); @@ -585,17 +588,19 @@ contract PSMWithdrawTests is PSMTestBase { assertApproxEqAbs(sDai.balanceOf(address(psm)), 0, 2); } - assertLe(psm.totalShares(), 1); - assertLe(psm.shares(user1), 1); - assertLe(psm.shares(user2), 1); + assertEq(psm.totalShares(), 0); + assertEq(psm.shares(user1), 0); + assertEq(psm.shares(user2), 0); uint256 user1ResultingValue = usdc.balanceOf(user1) * 1e12 + sDai.balanceOf(user1) * conversionRate / 1e27; uint256 user2ResultingValue = sDai.balanceOf(user2) * conversionRate / 1e27; // Use 1.5 conversion rate + assertLe(psm.totalAssets(), 1000); + // Equal to starting value - assertApproxEqAbs(user1ResultingValue + user2ResultingValue, totalValue, 2); + assertApproxEqAbs(user1ResultingValue + user2ResultingValue, totalValue - psm.totalAssets(), 2); // Value gains are the same for both users, accurate to 0.02% assertApproxEqRel( diff --git a/test/unit/harnesses/PSM3Harness.sol b/test/unit/harnesses/PSM3Harness.sol index 7e1a60b..4835781 100644 --- a/test/unit/harnesses/PSM3Harness.sol +++ b/test/unit/harnesses/PSM3Harness.sol @@ -8,8 +8,10 @@ contract PSM3Harness is PSM3 { constructor(address asset0_, address asset1_, address asset2_, address rateProvider_) PSM3(asset0_, asset1_, asset2_, rateProvider_) {} - function getAssetValue(address asset, uint256 amount) external view returns (uint256) { - return _getAssetValue(asset, amount); + function getAssetValue(address asset, uint256 amount, bool roundUp) + external view returns (uint256) + { + return _getAssetValue(asset, amount, roundUp); } function getAsset0Value(uint256 amount) external view returns (uint256) { @@ -20,8 +22,8 @@ contract PSM3Harness is PSM3 { return _getAsset1Value(amount); } - function getAsset2Value(uint256 amount) external view returns (uint256) { - return _getAsset2Value(amount); + function getAsset2Value(uint256 amount, bool roundUp) external view returns (uint256) { + return _getAsset2Value(amount, roundUp); } }