diff --git a/README.md b/README.md index 61d1a87..aed42d4 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ For detailed implementation, refer to the contract code and `IPSM` interface doc On the deployment of the PSM, the deployer **MUST make an initial deposit to get AT LEAST 1e18 shares in order to protect the first depositor from getting attacked with a share inflation attack or DOS attack**. Share inflation attack is outlined further [here](https://github.com/marsfoundation/spark-automations/assets/44272939/9472a6d2-0361-48b0-b534-96a0614330d3). Technical details related to this can be found in `test/InflationAttack.t.sol`. The DOS attack is performed by: -1. Attacker sends funds directly to the PSM. `getPsmTotalValue` now returns a non-zero value. +1. Attacker sends funds directly to the PSM. `totalAssets` now returns a non-zero value. 2. Victim calls deposit. `convertToShares` returns `amount * totalShares / totalValue`. In this case, `totalValue` is non-zero and `totalShares` is zero, so it performs `amount * 0 / totalValue` and returns zero. 3. The victim has `transferFrom` called moving their funds into the PSM, but they receive zero shares so they cannot recover any of their underlying assets. This renders the PSM unusable for all users since this issue will persist. `totalShares` can never be increased in this state. @@ -73,7 +73,7 @@ NOTE: These functions do not round in the same way as preview functions, so they #### Asset Value Functions -- **`getPsmTotalValue`**: Returns the total value of all assets held by the PSM denominated in the base asset with 18 decimal precision. (e.g., USD). +- **`totalAssets`**: Returns the total value of all assets held by the PSM denominated in the base asset with 18 decimal precision. (e.g., USD). ### Events diff --git a/src/PSM3.sol b/src/PSM3.sol index 27c3f19..51b6867 100644 --- a/src/PSM3.sol +++ b/src/PSM3.sol @@ -214,15 +214,15 @@ contract PSM3 is IPSM3 { uint256 totalShares_ = totalShares; if (totalShares_ != 0) { - return numShares * getPsmTotalValue() / totalShares_; + return numShares * totalAssets() / totalShares_; } return numShares; } function convertToShares(uint256 assetValue) public view override returns (uint256) { - uint256 totalValue = getPsmTotalValue(); - if (totalValue != 0) { - return assetValue * totalShares / totalValue; + uint256 totalAssets_ = totalAssets(); + if (totalAssets_ != 0) { + return assetValue * totalShares / totalAssets_; } return assetValue; } @@ -236,7 +236,7 @@ contract PSM3 is IPSM3 { /*** Asset value functions ***/ /**********************************************************************************************/ - function getPsmTotalValue() public view override returns (uint256) { + function totalAssets() public view override returns (uint256) { return _getAsset0Value(asset0.balanceOf(address(this))) + _getAsset1Value(asset1.balanceOf(address(this))) + _getAsset2Value(asset2.balanceOf(address(this))); @@ -332,7 +332,7 @@ contract PSM3 is IPSM3 { /**********************************************************************************************/ function _convertToSharesRoundUp(uint256 assetValue) internal view returns (uint256) { - uint256 totalValue = getPsmTotalValue(); + uint256 totalValue = totalAssets(); if (totalValue != 0) { return _divUp(assetValue * totalShares, totalValue); } diff --git a/src/interfaces/IPSM3.sol b/src/interfaces/IPSM3.sol index 4f02ca4..f54f54a 100644 --- a/src/interfaces/IPSM3.sol +++ b/src/interfaces/IPSM3.sol @@ -292,6 +292,6 @@ interface IPSM3 { * @dev View function that returns the total value of the balance of all assets in the PSM * converted to asset0/asset1 terms denominated in 18 decimal precision. */ - function getPsmTotalValue() external view returns (uint256); + function totalAssets() external view returns (uint256); } diff --git a/test/invariant/Invariants.t.sol b/test/invariant/Invariants.t.sol index 95df7a7..02d1585 100644 --- a/test/invariant/Invariants.t.sol +++ b/test/invariant/Invariants.t.sol @@ -53,7 +53,7 @@ abstract contract PSMInvariantTestBase is PSMTestBase { function _checkInvariant_B() public view { assertApproxEqAbs( - psm.getPsmTotalValue(), + psm.totalAssets(), psm.convertToAssetValue(psm.totalShares()), 4 ); @@ -65,7 +65,7 @@ abstract contract PSMInvariantTestBase is PSMTestBase { psm.convertToAssetValue(psm.shares(address(lpHandler.lps(1)))) + psm.convertToAssetValue(psm.shares(address(lpHandler.lps(2)))) + psm.convertToAssetValue(1e18), // Seed amount - psm.getPsmTotalValue(), + psm.totalAssets(), 4 ); } @@ -117,7 +117,7 @@ abstract contract PSMInvariantTestBase is PSMTestBase { uint256 lp1WithdrawsValue = _getLpTokenValue(lp1); uint256 lp2WithdrawsValue = _getLpTokenValue(lp2); - uint256 psmTotalValue = psm.getPsmTotalValue(); + uint256 psmTotalValue = psm.totalAssets(); uint256 startingSeedValue = psm.convertToAssetValue(1e18); @@ -142,8 +142,8 @@ abstract contract PSMInvariantTestBase is PSMTestBase { uint256 seedValue = psm.convertToAssetValue(1e18); // PSM is empty (besides seed amount). - assertEq(psm.totalShares(), 1e18); - assertEq(psm.getPsmTotalValue(), seedValue); + assertEq(psm.totalShares(), 1e18); + assertEq(psm.totalAssets(), seedValue); // Tokens held by LPs are equal to the sum of their previous balance // plus the amount of value originally represented in the PSM's shares. @@ -194,8 +194,8 @@ abstract contract PSMInvariantTestBase is PSMTestBase { ); // All funds can always be withdrawn completely. - assertEq(psm.totalShares(), 0); - assertEq(psm.getPsmTotalValue(), 0); + assertEq(psm.totalShares(), 0); + assertEq(psm.totalAssets(), 0); } } diff --git a/test/unit/Deposit.t.sol b/test/unit/Deposit.t.sol index 32f56c6..0ce2b64 100644 --- a/test/unit/Deposit.t.sol +++ b/test/unit/Deposit.t.sol @@ -300,7 +300,7 @@ contract PSMDepositTests is PSMTestBase { assertEq(psm.convertToAssetValue(psm.shares(receiver1)), 250e18); assertEq(psm.convertToAssetValue(psm.shares(receiver2)), 0); - assertEq(psm.getPsmTotalValue(), 250e18); + assertEq(psm.totalAssets(), 250e18); newShares = psm.deposit(address(sDai), receiver2, 100e18); @@ -325,7 +325,7 @@ contract PSMDepositTests is PSMTestBase { assertEq(psm.convertToAssetValue(psm.shares(receiver1)), 250e18); assertEq(psm.convertToAssetValue(psm.shares(receiver2)), 150e18); - assertEq(psm.getPsmTotalValue(), 400e18); + assertEq(psm.totalAssets(), 400e18); } function testFuzz_deposit_multiUser_changeConversionRate( @@ -398,14 +398,14 @@ contract PSMDepositTests is PSMTestBase { assertEq(psm.convertToAssetValue(psm.shares(receiver2)), 0); - assertApproxEqAbs(psm.getPsmTotalValue(), receiver1NewValue, 1); + assertApproxEqAbs(psm.totalAssets(), receiver1NewValue, 1); newShares = psm.deposit(address(sDai), receiver2, sDaiAmount2); // Using queried values here instead of derived to avoid larger errors getting introduced // Assertions above prove that these values are as expected. uint256 receiver2Shares - = (sDaiAmount2 * newRate / 1e27) * psm.totalShares() / psm.getPsmTotalValue(); + = (sDaiAmount2 * newRate / 1e27) * psm.totalShares() / psm.totalAssets(); assertApproxEqAbs(newShares, receiver2Shares, 2); @@ -426,7 +426,7 @@ contract PSMDepositTests is PSMTestBase { assertApproxEqAbs(psm.convertToAssetValue(psm.shares(receiver1)), receiver1NewValue, 1000); assertApproxEqAbs(psm.convertToAssetValue(psm.shares(receiver2)), receiver2NewValue, 1000); - assertApproxEqAbs(psm.getPsmTotalValue(), receiver1NewValue + receiver2NewValue, 1000); + assertApproxEqAbs(psm.totalAssets(), receiver1NewValue + receiver2NewValue, 1000); } } diff --git a/test/unit/Getters.t.sol b/test/unit/Getters.t.sol index 408d046..f64727e 100644 --- a/test/unit/Getters.t.sol +++ b/test/unit/Getters.t.sol @@ -149,69 +149,69 @@ contract PSMHarnessTests is PSMTestBase { contract GetPsmTotalValueTests is PSMTestBase { - function test_getPsmTotalValue_balanceChanges() public { + function test_totalAssets_balanceChanges() public { dai.mint(address(psm), 1e18); - assertEq(psm.getPsmTotalValue(), 1e18); + assertEq(psm.totalAssets(), 1e18); usdc.mint(address(psm), 1e6); - assertEq(psm.getPsmTotalValue(), 2e18); + assertEq(psm.totalAssets(), 2e18); sDai.mint(address(psm), 1e18); - assertEq(psm.getPsmTotalValue(), 3.25e18); + assertEq(psm.totalAssets(), 3.25e18); dai.burn(address(psm), 1e18); - assertEq(psm.getPsmTotalValue(), 2.25e18); + assertEq(psm.totalAssets(), 2.25e18); usdc.burn(address(psm), 1e6); - assertEq(psm.getPsmTotalValue(), 1.25e18); + assertEq(psm.totalAssets(), 1.25e18); sDai.burn(address(psm), 1e18); - assertEq(psm.getPsmTotalValue(), 0); + assertEq(psm.totalAssets(), 0); } - function test_getPsmTotalValue_conversionRateChanges() public { - assertEq(psm.getPsmTotalValue(), 0); + function test_totalAssets_conversionRateChanges() public { + assertEq(psm.totalAssets(), 0); dai.mint(address(psm), 1e18); usdc.mint(address(psm), 1e6); sDai.mint(address(psm), 1e18); - assertEq(psm.getPsmTotalValue(), 3.25e18); + assertEq(psm.totalAssets(), 3.25e18); mockRateProvider.__setConversionRate(1.5e27); - assertEq(psm.getPsmTotalValue(), 3.5e18); + assertEq(psm.totalAssets(), 3.5e18); mockRateProvider.__setConversionRate(0.8e27); - assertEq(psm.getPsmTotalValue(), 2.8e18); + assertEq(psm.totalAssets(), 2.8e18); } - function test_getPsmTotalValue_bothChange() public { - assertEq(psm.getPsmTotalValue(), 0); + function test_totalAssets_bothChange() public { + assertEq(psm.totalAssets(), 0); dai.mint(address(psm), 1e18); usdc.mint(address(psm), 1e6); sDai.mint(address(psm), 1e18); - assertEq(psm.getPsmTotalValue(), 3.25e18); + assertEq(psm.totalAssets(), 3.25e18); mockRateProvider.__setConversionRate(1.5e27); - assertEq(psm.getPsmTotalValue(), 3.5e18); + assertEq(psm.totalAssets(), 3.5e18); sDai.mint(address(psm), 1e18); - assertEq(psm.getPsmTotalValue(), 5e18); + assertEq(psm.totalAssets(), 5e18); } - function testFuzz_getPsmTotalValue( + function testFuzz_totalAssets( uint256 daiAmount, uint256 usdcAmount, uint256 sDaiAmount, @@ -231,7 +231,7 @@ contract GetPsmTotalValueTests is PSMTestBase { mockRateProvider.__setConversionRate(conversionRate); assertEq( - psm.getPsmTotalValue(), + psm.totalAssets(), daiAmount + (usdcAmount * 1e12) + (sDaiAmount * conversionRate / 1e27) ); } diff --git a/test/unit/Withdraw.t.sol b/test/unit/Withdraw.t.sol index 927deaa..cc054ba 100644 --- a/test/unit/Withdraw.t.sol +++ b/test/unit/Withdraw.t.sol @@ -339,7 +339,7 @@ contract PSMWithdrawTests is PSMTestBase { _checkPsmInvariant(); assertEq( - usdc.balanceOf(receiver1) * 1e12 + psm.getPsmTotalValue(), + usdc.balanceOf(receiver1) * 1e12 + psm.totalAssets(), vars.totalValue ); @@ -366,7 +366,7 @@ contract PSMWithdrawTests is PSMTestBase { _checkPsmInvariant(); assertEq( - (usdc.balanceOf(receiver1) + usdc.balanceOf(receiver2)) * 1e12 + psm.getPsmTotalValue(), + (usdc.balanceOf(receiver1) + usdc.balanceOf(receiver2)) * 1e12 + psm.totalAssets(), vars.totalValue ); @@ -406,7 +406,7 @@ contract PSMWithdrawTests is PSMTestBase { assertApproxEqAbs( (usdc.balanceOf(receiver1) + usdc.balanceOf(receiver2)) * 1e12 + (sDai.balanceOf(receiver2) * rateProvider.getConversionRate() / 1e27) - + psm.getPsmTotalValue(), + + psm.totalAssets(), vars.totalValue, 1 ); @@ -539,7 +539,7 @@ contract PSMWithdrawTests is PSMTestBase { uint256 totalShares = user1Shares + user2Shares; uint256 totalValue = usdcAmount * 1e12 + sDaiAmount * conversionRate / 1e27; - assertEq(psm.getPsmTotalValue(), totalValue); + assertEq(psm.totalAssets(), totalValue); assertEq(psm.totalShares(), totalShares); assertEq(psm.shares(user1), user1Shares);