diff --git a/smart-wallets/src/token/YieldDistributionToken.sol b/smart-wallets/src/token/YieldDistributionToken.sol index 34e709e..472662d 100644 --- a/smart-wallets/src/token/YieldDistributionToken.sol +++ b/smart-wallets/src/token/YieldDistributionToken.sol @@ -25,7 +25,8 @@ abstract contract YieldDistributionToken is ERC20, Ownable, IYieldDistributionTo * @param yieldAccrued Total amount of yield that has ever been accrued to the user * @param yieldWithdrawn Total amount of yield that has ever been withdrawn by the user * @param lastBalanceTimestamp Timestamp of the most recent balance update for the user - * @param lastAmountSeconds AmountSeconds of the user at the time of the most recent deposit + * @param lastDepositAmountSeconds AmountSeconds of the user at the time of the + * most recent deposit that was successfully processed by calling accrueYield */ struct UserState { uint256 amount; @@ -33,7 +34,7 @@ abstract contract YieldDistributionToken is ERC20, Ownable, IYieldDistributionTo uint256 yieldAccrued; uint256 yieldWithdrawn; uint256 lastBalanceTimestamp; - uint256 lastAmountSeconds; + uint256 lastDepositAmountSeconds; } /** @@ -307,7 +308,6 @@ abstract contract YieldDistributionToken is ERC20, Ownable, IYieldDistributionTo UserState memory userState = $.userStates[user]; uint256 depositTimestamp = depositHistory.lastTimestamp; uint256 lastBalanceTimestamp = userState.lastBalanceTimestamp; - uint256 lastAmountSeconds = userState.lastAmountSeconds; /** * There is a race condition in the current implementation that occurs when @@ -330,23 +330,34 @@ abstract contract YieldDistributionToken is ERC20, Ownable, IYieldDistributionTo // Iterate through depositHistory and accrue yield for the user at each deposit timestamp Deposit storage deposit = depositHistory.deposits[depositTimestamp]; uint256 yieldAccrued = 0; + uint256 amountSeconds = userState.amountSeconds; uint256 depositAmount = deposit.currencyTokenAmount; while (depositAmount > 0 && depositTimestamp > lastBalanceTimestamp) { uint256 previousDepositTimestamp = deposit.previousTimestamp; - uint256 previousTotalAmountSeconds = depositHistory.deposits[previousDepositTimestamp].totalAmountSeconds; + uint256 intervalTotalAmountSeconds = + deposit.totalAmountSeconds - depositHistory.deposits[previousDepositTimestamp].totalAmountSeconds; if (previousDepositTimestamp > lastBalanceTimestamp) { - yieldAccrued += _BASE * depositAmount * userState.amount * (depositTimestamp - previousDepositTimestamp) - / (deposit.totalAmountSeconds - previousTotalAmountSeconds); + /** + * There can be a sequence of deposits made while the user balance remains the same throughout. + * Subtract the amountSeconds in this interval to get the total amountSeconds at the previous deposit. + */ + uint256 intervalAmountSeconds = userState.amount * (depositTimestamp - previousDepositTimestamp); + amountSeconds -= intervalAmountSeconds; + yieldAccrued += _BASE * depositAmount * intervalAmountSeconds / intervalTotalAmountSeconds; } else { - yieldAccrued += _BASE * depositAmount * (userState.amountSeconds - lastAmountSeconds) - / (deposit.totalAmountSeconds - previousTotalAmountSeconds); + /** + * At the very end, there can be a sequence of balance updates made right after + * the most recent previously processed deposit and before any other deposits. + */ + yieldAccrued += _BASE * depositAmount * (amountSeconds - userState.lastDepositAmountSeconds) + / intervalTotalAmountSeconds; } depositTimestamp = previousDepositTimestamp; deposit = depositHistory.deposits[depositTimestamp]; depositAmount = deposit.currencyTokenAmount; } - userState.lastAmountSeconds = userState.amountSeconds; + userState.lastDepositAmountSeconds = userState.amountSeconds; userState.amountSeconds += userState.amount * (depositHistory.lastTimestamp - lastBalanceTimestamp); userState.lastBalanceTimestamp = depositHistory.lastTimestamp; userState.yieldAccrued += yieldAccrued / _BASE;