diff --git a/src/contracts/Grateful.sol b/src/contracts/Grateful.sol index 4f26852..6f9c629 100644 --- a/src/contracts/Grateful.sol +++ b/src/contracts/Grateful.sol @@ -494,7 +494,7 @@ contract Grateful is IGrateful, Ownable2Step, ReentrancyGuard { } sharesToWithdraw = vault.previewWithdraw(_assets); if (sharesToWithdraw > totalShares) { - revert Grateful_WithdrawExceedsShares(); + revert Grateful_WithdrawExceedsShares(totalShares, sharesToWithdraw); } assetsToWithdraw = _assets; } diff --git a/src/interfaces/IGrateful.sol b/src/interfaces/IGrateful.sol index d82ff56..fda64b2 100644 --- a/src/interfaces/IGrateful.sol +++ b/src/interfaces/IGrateful.sol @@ -138,7 +138,7 @@ interface IGrateful { error Grateful_PaymentIdAlreadyUsed(); /// @notice Thrown when the user tries to withdraw more shares than they have. - error Grateful_WithdrawExceedsShares(); + error Grateful_WithdrawExceedsShares(uint256 totalShares, uint256 sharesToWithdraw); /// @notice Thrown when attempting to remove a token or vault that does not exist. error Grateful_TokenOrVaultNotFound(); diff --git a/test/integration/Grateful.t.sol b/test/integration/Grateful.t.sol index 0b8c7e9..6059f24 100644 --- a/test/integration/Grateful.t.sol +++ b/test/integration/Grateful.t.sol @@ -6,7 +6,7 @@ import {SafeERC20} from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol import {OneTime} from "contracts/OneTime.sol"; import {IntegrationBase} from "test/integration/IntegrationBase.sol"; -contract IntegrationGreeter is IntegrationBase { +contract IntegrationGrateful is IntegrationBase { using SafeERC20 for IERC20; /*////////////////////////////////////////////////////////////// @@ -14,11 +14,16 @@ contract IntegrationGreeter is IntegrationBase { //////////////////////////////////////////////////////////////*/ // Tests for Standard Payments - function test_Payment() public { + function test_Payment( + uint256 amountMultiplier + ) public { + vm.assume(amountMultiplier > 10); + vm.assume(amountMultiplier < 1000); + for (uint256 i = 0; i < _tokens.length; i++) { address tokenAddr = _tokens[i]; string memory symbol = _tokenSymbols[tokenAddr]; - uint256 amount = _tokenAmounts[tokenAddr]; + uint256 amount = _tokenAmounts[tokenAddr] * amountMultiplier; _approveAndPay(_user, _merchant, tokenAddr, amount, _NOT_YIELDING_FUNDS); @@ -50,9 +55,11 @@ contract IntegrationGreeter is IntegrationBase { // Calculate profit before withdrawal uint256 profit = _grateful.calculateProfit(_merchant, tokenAddr); + uint256 assetsToWithdraw = grateful.calculateAssets(_merchant, tokenAddr); + // Merchant withdraws funds vm.prank(_merchant); - _grateful.withdraw(tokenAddr); + _grateful.withdraw(tokenAddr, assetsToWithdraw); // Calculate performance fee after withdrawal uint256 performanceFee = _grateful.calculatePerformanceFee(profit); @@ -80,6 +87,76 @@ contract IntegrationGreeter is IntegrationBase { } } + function test_PaymentYieldingFundsPartialWithdrawal( + uint256 amountMultiplier + ) public { + vm.assume(amountMultiplier > 10); + vm.assume(amountMultiplier < 1000); + + for (uint256 i = 0; i < _tokens.length; i++) { + address tokenAddr = _tokens[i]; + uint256 amount = _tokenAmounts[tokenAddr] * amountMultiplier; + + // Capture owner's initial balance before payment + uint256 ownerInitialBalance = IERC20(tokenAddr).balanceOf(_owner); + + _approveAndPay(_user, _merchant, tokenAddr, amount, _YIELDING_FUNDS); + + // Advance time to accrue yield + vm.warp(block.timestamp + 1 days); + + // Calculate total assets and profit before withdrawal + uint256 totalAssets = _grateful.calculateAssets(_merchant, tokenAddr); + uint256 profitBeforeWithdrawal = _grateful.calculateProfit(_merchant, tokenAddr); + + // Decide on a partial withdrawal amount (e.g., withdraw half of the assets) + uint256 withdrawalAmount = totalAssets / 2; + + // Merchant withdraws partial funds + vm.prank(_merchant); + _grateful.withdraw(tokenAddr, withdrawalAmount); + + // Expected remaining assets after partial withdrawal + uint256 expectedRemainingAssets = totalAssets - withdrawalAmount; + + // Calculate performance fee for the withdrawn amount + uint256 performanceFee = + _grateful.calculatePerformanceFee((profitBeforeWithdrawal * withdrawalAmount) / totalAssets); + + // Verify merchant's balance + assertApproxEqAbs( + IERC20(tokenAddr).balanceOf(_merchant), + withdrawalAmount - performanceFee, + (withdrawalAmount - performanceFee) / 1000, // Precission loss tolerance of 0.1% + string(abi.encodePacked(_tokenSymbols[tokenAddr], ": Merchant balance mismatch after partial withdrawal")) + ); + + // Verify that the merchant still has some assets in the vault + assertApproxEqAbs( + _grateful.calculateAssets(_merchant, tokenAddr), + expectedRemainingAssets, + expectedRemainingAssets / 1000, // Precission loss tolerance of 0.1% + string(abi.encodePacked(_tokenSymbols[tokenAddr], ": Remaining assets mismatch after partial withdrawal")) + ); + + // Verify owner's balance (owner should have received performance fee from the withdrawn profit) + uint256 ownerFinalBalance = IERC20(tokenAddr).balanceOf(_owner); + + uint256 ownerExpectedBalanceIncrease = (amount - _grateful.applyFee(_merchant, amount)) + performanceFee; + + assertApproxEqAbs( + ownerFinalBalance - ownerInitialBalance, + ownerExpectedBalanceIncrease, + ownerExpectedBalanceIncrease / 1000, // Precission loss tolerance of 0.1% + string( + abi.encodePacked( + _tokenSymbols[tokenAddr], ": Owner did not receive correct performance fee after partial withdrawal" + ) + ) + ); + } + } + // Tests for One-Time Payments function test_OneTimePayment() public { for (uint256 i = 0; i < _tokens.length; i++) {