Skip to content

Commit

Permalink
fix: Update to totalAssets (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucas-manuel authored Jul 5, 2024
1 parent a1dafe1 commit 74cec31
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 44 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ For detailed implementation, refer to the contract code and `IPSM` interface doc
On the deployment of the PSM, the deployer **MUST make an initial deposit to get AT LEAST 1e18 shares in order to protect the first depositor from getting attacked with a share inflation attack or DOS attack**. Share inflation attack is outlined further [here](https://github.com/marsfoundation/spark-automations/assets/44272939/9472a6d2-0361-48b0-b534-96a0614330d3). Technical details related to this can be found in `test/InflationAttack.t.sol`.

The DOS attack is performed by:
1. Attacker sends funds directly to the PSM. `getPsmTotalValue` now returns a non-zero value.
1. Attacker sends funds directly to the PSM. `totalAssets` now returns a non-zero value.
2. Victim calls deposit. `convertToShares` returns `amount * totalShares / totalValue`. In this case, `totalValue` is non-zero and `totalShares` is zero, so it performs `amount * 0 / totalValue` and returns zero.
3. The victim has `transferFrom` called moving their funds into the PSM, but they receive zero shares so they cannot recover any of their underlying assets. This renders the PSM unusable for all users since this issue will persist. `totalShares` can never be increased in this state.

Expand Down Expand Up @@ -73,7 +73,7 @@ NOTE: These functions do not round in the same way as preview functions, so they

#### Asset Value Functions

- **`getPsmTotalValue`**: Returns the total value of all assets held by the PSM denominated in the base asset with 18 decimal precision. (e.g., USD).
- **`totalAssets`**: Returns the total value of all assets held by the PSM denominated in the base asset with 18 decimal precision. (e.g., USD).

### Events

Expand Down
12 changes: 6 additions & 6 deletions src/PSM3.sol
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,15 @@ contract PSM3 is IPSM3 {
uint256 totalShares_ = totalShares;

if (totalShares_ != 0) {
return numShares * getPsmTotalValue() / totalShares_;
return numShares * totalAssets() / totalShares_;
}
return numShares;
}

function convertToShares(uint256 assetValue) public view override returns (uint256) {
uint256 totalValue = getPsmTotalValue();
if (totalValue != 0) {
return assetValue * totalShares / totalValue;
uint256 totalAssets_ = totalAssets();
if (totalAssets_ != 0) {
return assetValue * totalShares / totalAssets_;
}
return assetValue;
}
Expand All @@ -236,7 +236,7 @@ contract PSM3 is IPSM3 {
/*** Asset value functions ***/
/**********************************************************************************************/

function getPsmTotalValue() public view override returns (uint256) {
function totalAssets() public view override returns (uint256) {
return _getAsset0Value(asset0.balanceOf(address(this)))
+ _getAsset1Value(asset1.balanceOf(address(this)))
+ _getAsset2Value(asset2.balanceOf(address(this)));
Expand Down Expand Up @@ -332,7 +332,7 @@ contract PSM3 is IPSM3 {
/**********************************************************************************************/

function _convertToSharesRoundUp(uint256 assetValue) internal view returns (uint256) {
uint256 totalValue = getPsmTotalValue();
uint256 totalValue = totalAssets();
if (totalValue != 0) {
return _divUp(assetValue * totalShares, totalValue);
}
Expand Down
2 changes: 1 addition & 1 deletion src/interfaces/IPSM3.sol
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,6 @@ interface IPSM3 {
* @dev View function that returns the total value of the balance of all assets in the PSM
* converted to asset0/asset1 terms denominated in 18 decimal precision.
*/
function getPsmTotalValue() external view returns (uint256);
function totalAssets() external view returns (uint256);

}
14 changes: 7 additions & 7 deletions test/invariant/Invariants.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ abstract contract PSMInvariantTestBase is PSMTestBase {

function _checkInvariant_B() public view {
assertApproxEqAbs(
psm.getPsmTotalValue(),
psm.totalAssets(),
psm.convertToAssetValue(psm.totalShares()),
4
);
Expand All @@ -65,7 +65,7 @@ abstract contract PSMInvariantTestBase is PSMTestBase {
psm.convertToAssetValue(psm.shares(address(lpHandler.lps(1)))) +
psm.convertToAssetValue(psm.shares(address(lpHandler.lps(2)))) +
psm.convertToAssetValue(1e18), // Seed amount
psm.getPsmTotalValue(),
psm.totalAssets(),
4
);
}
Expand Down Expand Up @@ -117,7 +117,7 @@ abstract contract PSMInvariantTestBase is PSMTestBase {
uint256 lp1WithdrawsValue = _getLpTokenValue(lp1);
uint256 lp2WithdrawsValue = _getLpTokenValue(lp2);

uint256 psmTotalValue = psm.getPsmTotalValue();
uint256 psmTotalValue = psm.totalAssets();

uint256 startingSeedValue = psm.convertToAssetValue(1e18);

Expand All @@ -142,8 +142,8 @@ abstract contract PSMInvariantTestBase is PSMTestBase {
uint256 seedValue = psm.convertToAssetValue(1e18);

// PSM is empty (besides seed amount).
assertEq(psm.totalShares(), 1e18);
assertEq(psm.getPsmTotalValue(), seedValue);
assertEq(psm.totalShares(), 1e18);
assertEq(psm.totalAssets(), seedValue);

// Tokens held by LPs are equal to the sum of their previous balance
// plus the amount of value originally represented in the PSM's shares.
Expand Down Expand Up @@ -194,8 +194,8 @@ abstract contract PSMInvariantTestBase is PSMTestBase {
);

// All funds can always be withdrawn completely.
assertEq(psm.totalShares(), 0);
assertEq(psm.getPsmTotalValue(), 0);
assertEq(psm.totalShares(), 0);
assertEq(psm.totalAssets(), 0);
}

}
Expand Down
10 changes: 5 additions & 5 deletions test/unit/Deposit.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ contract PSMDepositTests is PSMTestBase {
assertEq(psm.convertToAssetValue(psm.shares(receiver1)), 250e18);
assertEq(psm.convertToAssetValue(psm.shares(receiver2)), 0);

assertEq(psm.getPsmTotalValue(), 250e18);
assertEq(psm.totalAssets(), 250e18);

newShares = psm.deposit(address(sDai), receiver2, 100e18);

Expand All @@ -325,7 +325,7 @@ contract PSMDepositTests is PSMTestBase {
assertEq(psm.convertToAssetValue(psm.shares(receiver1)), 250e18);
assertEq(psm.convertToAssetValue(psm.shares(receiver2)), 150e18);

assertEq(psm.getPsmTotalValue(), 400e18);
assertEq(psm.totalAssets(), 400e18);
}

function testFuzz_deposit_multiUser_changeConversionRate(
Expand Down Expand Up @@ -398,14 +398,14 @@ contract PSMDepositTests is PSMTestBase {

assertEq(psm.convertToAssetValue(psm.shares(receiver2)), 0);

assertApproxEqAbs(psm.getPsmTotalValue(), receiver1NewValue, 1);
assertApproxEqAbs(psm.totalAssets(), receiver1NewValue, 1);

newShares = psm.deposit(address(sDai), receiver2, sDaiAmount2);

// Using queried values here instead of derived to avoid larger errors getting introduced
// Assertions above prove that these values are as expected.
uint256 receiver2Shares
= (sDaiAmount2 * newRate / 1e27) * psm.totalShares() / psm.getPsmTotalValue();
= (sDaiAmount2 * newRate / 1e27) * psm.totalShares() / psm.totalAssets();

assertApproxEqAbs(newShares, receiver2Shares, 2);

Expand All @@ -426,7 +426,7 @@ contract PSMDepositTests is PSMTestBase {
assertApproxEqAbs(psm.convertToAssetValue(psm.shares(receiver1)), receiver1NewValue, 1000);
assertApproxEqAbs(psm.convertToAssetValue(psm.shares(receiver2)), receiver2NewValue, 1000);

assertApproxEqAbs(psm.getPsmTotalValue(), receiver1NewValue + receiver2NewValue, 1000);
assertApproxEqAbs(psm.totalAssets(), receiver1NewValue + receiver2NewValue, 1000);
}

}
38 changes: 19 additions & 19 deletions test/unit/Getters.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -149,69 +149,69 @@ contract PSMHarnessTests is PSMTestBase {

contract GetPsmTotalValueTests is PSMTestBase {

function test_getPsmTotalValue_balanceChanges() public {
function test_totalAssets_balanceChanges() public {
dai.mint(address(psm), 1e18);

assertEq(psm.getPsmTotalValue(), 1e18);
assertEq(psm.totalAssets(), 1e18);

usdc.mint(address(psm), 1e6);

assertEq(psm.getPsmTotalValue(), 2e18);
assertEq(psm.totalAssets(), 2e18);

sDai.mint(address(psm), 1e18);

assertEq(psm.getPsmTotalValue(), 3.25e18);
assertEq(psm.totalAssets(), 3.25e18);

dai.burn(address(psm), 1e18);

assertEq(psm.getPsmTotalValue(), 2.25e18);
assertEq(psm.totalAssets(), 2.25e18);

usdc.burn(address(psm), 1e6);

assertEq(psm.getPsmTotalValue(), 1.25e18);
assertEq(psm.totalAssets(), 1.25e18);

sDai.burn(address(psm), 1e18);

assertEq(psm.getPsmTotalValue(), 0);
assertEq(psm.totalAssets(), 0);
}

function test_getPsmTotalValue_conversionRateChanges() public {
assertEq(psm.getPsmTotalValue(), 0);
function test_totalAssets_conversionRateChanges() public {
assertEq(psm.totalAssets(), 0);

dai.mint(address(psm), 1e18);
usdc.mint(address(psm), 1e6);
sDai.mint(address(psm), 1e18);

assertEq(psm.getPsmTotalValue(), 3.25e18);
assertEq(psm.totalAssets(), 3.25e18);

mockRateProvider.__setConversionRate(1.5e27);

assertEq(psm.getPsmTotalValue(), 3.5e18);
assertEq(psm.totalAssets(), 3.5e18);

mockRateProvider.__setConversionRate(0.8e27);

assertEq(psm.getPsmTotalValue(), 2.8e18);
assertEq(psm.totalAssets(), 2.8e18);
}

function test_getPsmTotalValue_bothChange() public {
assertEq(psm.getPsmTotalValue(), 0);
function test_totalAssets_bothChange() public {
assertEq(psm.totalAssets(), 0);

dai.mint(address(psm), 1e18);
usdc.mint(address(psm), 1e6);
sDai.mint(address(psm), 1e18);

assertEq(psm.getPsmTotalValue(), 3.25e18);
assertEq(psm.totalAssets(), 3.25e18);

mockRateProvider.__setConversionRate(1.5e27);

assertEq(psm.getPsmTotalValue(), 3.5e18);
assertEq(psm.totalAssets(), 3.5e18);

sDai.mint(address(psm), 1e18);

assertEq(psm.getPsmTotalValue(), 5e18);
assertEq(psm.totalAssets(), 5e18);
}

function testFuzz_getPsmTotalValue(
function testFuzz_totalAssets(
uint256 daiAmount,
uint256 usdcAmount,
uint256 sDaiAmount,
Expand All @@ -231,7 +231,7 @@ contract GetPsmTotalValueTests is PSMTestBase {
mockRateProvider.__setConversionRate(conversionRate);

assertEq(
psm.getPsmTotalValue(),
psm.totalAssets(),
daiAmount + (usdcAmount * 1e12) + (sDaiAmount * conversionRate / 1e27)
);
}
Expand Down
8 changes: 4 additions & 4 deletions test/unit/Withdraw.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ contract PSMWithdrawTests is PSMTestBase {
_checkPsmInvariant();

assertEq(
usdc.balanceOf(receiver1) * 1e12 + psm.getPsmTotalValue(),
usdc.balanceOf(receiver1) * 1e12 + psm.totalAssets(),
vars.totalValue
);

Expand All @@ -366,7 +366,7 @@ contract PSMWithdrawTests is PSMTestBase {
_checkPsmInvariant();

assertEq(
(usdc.balanceOf(receiver1) + usdc.balanceOf(receiver2)) * 1e12 + psm.getPsmTotalValue(),
(usdc.balanceOf(receiver1) + usdc.balanceOf(receiver2)) * 1e12 + psm.totalAssets(),
vars.totalValue
);

Expand Down Expand Up @@ -406,7 +406,7 @@ contract PSMWithdrawTests is PSMTestBase {
assertApproxEqAbs(
(usdc.balanceOf(receiver1) + usdc.balanceOf(receiver2)) * 1e12
+ (sDai.balanceOf(receiver2) * rateProvider.getConversionRate() / 1e27)
+ psm.getPsmTotalValue(),
+ psm.totalAssets(),
vars.totalValue,
1
);
Expand Down Expand Up @@ -539,7 +539,7 @@ contract PSMWithdrawTests is PSMTestBase {
uint256 totalShares = user1Shares + user2Shares;
uint256 totalValue = usdcAmount * 1e12 + sDaiAmount * conversionRate / 1e27;

assertEq(psm.getPsmTotalValue(), totalValue);
assertEq(psm.totalAssets(), totalValue);

assertEq(psm.totalShares(), totalShares);
assertEq(psm.shares(user1), user1Shares);
Expand Down

0 comments on commit 74cec31

Please sign in to comment.