diff --git a/pkg/interfaces/contracts/pool-gyro/IGyroECLPPool.sol b/pkg/interfaces/contracts/pool-gyro/IGyroECLPPool.sol index 7d3b03510..94a9c02aa 100644 --- a/pkg/interfaces/contracts/pool-gyro/IGyroECLPPool.sol +++ b/pkg/interfaces/contracts/pool-gyro/IGyroECLPPool.sol @@ -49,6 +49,14 @@ interface IGyroECLPPool is IBasePool { * and increase the precision. Therefore, the numbers are stored with 38 decimals precision. Please refer to * https://docs.gyro.finance/gyroscope-protocol/technical-documents, document "E-CLP high-precision * calculations.pdf", for further explanations on how to obtain the parameters below. + * + * @param tauAlpha + * @param tauBeta + * @param u from (A chi)_y = lambda * u + v + * @param v from (A chi)_y = lambda * u + v + * @param w from (A chi)_x = w / lambda + z + * @param z from (A chi)_x = w / lambda + z + * @param dSq error in c^2 + s^2 = dSq, used to correct errors in c, s, tau, u,v,w,z calculations */ struct DerivedEclpParams { Vector2 tauAlpha; diff --git a/pkg/pool-gyro/contracts/GyroECLPPool.sol b/pkg/pool-gyro/contracts/GyroECLPPool.sol index c8b729073..6f7d4f8a7 100644 --- a/pkg/pool-gyro/contracts/GyroECLPPool.sol +++ b/pkg/pool-gyro/contracts/GyroECLPPool.sol @@ -2,7 +2,7 @@ // for information on licensing please see the README in the GitHub repository // . -pragma solidity ^0.8.24; +pragma solidity ^0.8.27; import { SafeCast } from "@openzeppelin/contracts/utils/math/SafeCast.sol"; import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; @@ -32,6 +32,8 @@ contract GyroECLPPool is IGyroECLPPool, BalancerPoolToken { using FixedPoint for uint256; using SafeCast for *; + bytes32 private constant _POOL_TYPE = "ECLP"; + /// @dev Parameters of the ECLP pool int256 internal immutable _paramsAlpha; int256 internal immutable _paramsBeta; @@ -53,8 +55,6 @@ contract GyroECLPPool is IGyroECLPPool, BalancerPoolToken { int256 internal immutable _z; int256 internal immutable _dSq; - bytes32 private constant _POOL_TYPE = "ECLP"; - constructor(GyroECLPPoolParams memory params, IVault vault) BalancerPoolToken(vault, params.name, params.symbol) { GyroECLPMath.validateParams(params.eclpParams); emit ECLPParamsValidated(true); @@ -97,7 +97,7 @@ contract GyroECLPPool is IGyroECLPPool, BalancerPoolToken { ); if (rounding == Rounding.ROUND_DOWN) { - return currentInvariant.toUint256(); + return (currentInvariant - invErr).toUint256(); } else { return (currentInvariant + invErr).toUint256(); } @@ -119,18 +119,17 @@ contract GyroECLPPool is IGyroECLPPool, BalancerPoolToken { derivedECLPParams ); - // invariant = overestimate in x-component, underestimate in y-component. + // The invariant vector contains the rounded up and rounded down invariant. Both are needed when computing + // the virtual offsets. Depending on tauAlpha and tauBeta values, we want to use the invariant rounded up + // or rounded down to make sure we're conservative in the output. invariant = Vector2( - (currentInvariant + 2 * invErr).toUint256().mulUp(invariantRatio).toInt256(), - currentInvariant.toUint256().mulUp(invariantRatio).toInt256() + (currentInvariant + invErr).toUint256().mulUp(invariantRatio).toInt256(), + (currentInvariant - invErr).toUint256().mulUp(invariantRatio).toInt256() ); - // Edge case check. Should never happen except for insane tokens. - // If this is hit, actually adding the tokens would lead to a revert or (if it - // went through) a deadlock downstream, so we catch it here. - if (invariant.y > GyroECLPMath._MAX_INVARIANT) { - revert GyroECLPMath.MaxInvariantExceeded(); - } + // Edge case check. Should never happen except for insane tokens. If this is hit, actually adding the + // tokens would lead to a revert or (if it went through) a deadlock downstream, so we catch it here. + require(invariant.x < GyroECLPMath._MAX_INVARIANT, GyroECLPMath.MaxInvariantExceeded()); } if (tokenInIndex == 0) { @@ -148,6 +147,7 @@ contract GyroECLPPool is IGyroECLPPool, BalancerPoolToken { /// @inheritdoc IBasePool function onSwap(PoolSwapParams memory request) external view onlyVault returns (uint256) { + // The Vault already checks that index in != index out. bool tokenInIsToken0 = request.indexIn == 0; (EclpParams memory eclpParams, DerivedEclpParams memory derivedECLPParams) = _reconstructECLPParams(); @@ -219,11 +219,11 @@ contract GyroECLPPool is IGyroECLPPool, BalancerPoolToken { /// @inheritdoc IUnbalancedLiquidityInvariantRatioBounds function getMinimumInvariantRatio() external pure returns (uint256) { - return 0; + return GyroECLPMath.MIN_INVARIANT_RATIO; } /// @inheritdoc IUnbalancedLiquidityInvariantRatioBounds function getMaximumInvariantRatio() external pure returns (uint256) { - return type(uint256).max; + return GyroECLPMath.MAX_INVARIANT_RATIO; } } diff --git a/pkg/pool-gyro/contracts/GyroECLPPoolFactory.sol b/pkg/pool-gyro/contracts/GyroECLPPoolFactory.sol index 039473f80..5041f44f5 100644 --- a/pkg/pool-gyro/contracts/GyroECLPPoolFactory.sol +++ b/pkg/pool-gyro/contracts/GyroECLPPoolFactory.sol @@ -1,6 +1,6 @@ // SPDX-License-Identifier: GPL-3.0-or-later -pragma solidity ^0.8.24; +pragma solidity ^0.8.27; import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; @@ -53,13 +53,8 @@ contract GyroECLPPoolFactory is BasePoolFactory { address poolHooksContract, bytes32 salt ) external returns (address pool) { - if (tokens.length != 2) { - revert SupportsOnlyTwoTokens(); - } - - if (roleAccounts.poolCreator != address(0)) { - revert StandardPoolWithCreator(); - } + require(tokens.length == 2, SupportsOnlyTwoTokens()); + require(roleAccounts.poolCreator == address(0), StandardPoolWithCreator()); pool = _create( abi.encode( diff --git a/pkg/pool-gyro/contracts/lib/GyroECLPMath.sol b/pkg/pool-gyro/contracts/lib/GyroECLPMath.sol index b5931d0fc..ae9778473 100644 --- a/pkg/pool-gyro/contracts/lib/GyroECLPMath.sol +++ b/pkg/pool-gyro/contracts/lib/GyroECLPMath.sol @@ -2,7 +2,7 @@ // for information on licensing please see the README in the GitHub repository // . -pragma solidity ^0.8.24; +pragma solidity ^0.8.27; import { SafeCast } from "@openzeppelin/contracts/utils/math/SafeCast.sol"; @@ -24,12 +24,17 @@ library GyroECLPMath { using SafeCast for uint256; using SafeCast for int256; - error RotationVectorWrong(); + error RotationVectorSWrong(); + error RotationVectorCWrong(); error RotationVectorNotNormalized(); error AssetBoundsExceeded(); - error DerivedTauNotNormalized(); + error DerivedTauAlphaNotNormalized(); + error DerivedTauBetaNotNormalized(); error StretchingFactorWrong(); - error DerivedUvwzWrong(); + error DerivedUWrong(); + error DerivedVWrong(); + error DerivedWWrong(); + error DerivedZWrong(); error InvariantDenominatorWrong(); error MaxAssetsExceeded(); error MaxInvariantExceeded(); @@ -50,6 +55,11 @@ library GyroECLPMath { int256 internal constant _MAX_BALANCES = 1e34; // 1e16 in normal precision int256 internal constant _MAX_INVARIANT = 3e37; // 3e19 in normal precision + // Invariant growth limit: non-proportional add cannot cause the invariant to increase by more than this ratio. + uint256 public constant MIN_INVARIANT_RATIO = 60e16; // 60% + // Invariant shrink limit: non-proportional remove cannot cause the invariant to decrease by less than this ratio. + uint256 public constant MAX_INVARIANT_RATIO = 500e16; // 500% + struct QParams { int256 a; int256 b; @@ -58,24 +68,17 @@ library GyroECLPMath { /// @dev Enforces limits and approximate normalization of the rotation vector. function validateParams(IGyroECLPPool.EclpParams memory params) internal pure { - if (0 > params.s || params.s > _ONE) { - revert RotationVectorWrong(); - } - - if (0 > params.c || params.c > _ONE) { - revert RotationVectorWrong(); - } + require(params.s > 0 && params.s < _ONE, RotationVectorSWrong()); + require(params.c > 0 && params.c < _ONE, RotationVectorCWrong()); IGyroECLPPool.Vector2 memory sc = IGyroECLPPool.Vector2(params.s, params.c); int256 scnorm2 = scalarProd(sc, sc); // squared norm - if (_ONE - _ROTATION_VECTOR_NORM_ACCURACY > scnorm2 || scnorm2 > _ONE + _ROTATION_VECTOR_NORM_ACCURACY) { - revert RotationVectorNotNormalized(); - } - - if (params.lambda < 0 || params.lambda > _MAX_STRETCH_FACTOR) { - revert StretchingFactorWrong(); - } + require( + scnorm2 > _ONE - _ROTATION_VECTOR_NORM_ACCURACY && scnorm2 < _ONE + _ROTATION_VECTOR_NORM_ACCURACY, + RotationVectorNotNormalized() + ); + require(params.lambda > 0 && params.lambda < _MAX_STRETCH_FACTOR, StretchingFactorWrong()); } /** @@ -89,34 +92,32 @@ library GyroECLPMath { int256 norm2; norm2 = scalarProdXp(derived.tauAlpha, derived.tauAlpha); - if (_ONE_XP - _DERIVED_TAU_NORM_ACCURACY_XP > norm2 || norm2 > _ONE_XP + _DERIVED_TAU_NORM_ACCURACY_XP) { - revert DerivedTauNotNormalized(); - } + require( + norm2 > _ONE_XP - _DERIVED_TAU_NORM_ACCURACY_XP && norm2 < _ONE_XP + _DERIVED_TAU_NORM_ACCURACY_XP, + DerivedTauAlphaNotNormalized() + ); norm2 = scalarProdXp(derived.tauBeta, derived.tauBeta); - if (_ONE_XP - _DERIVED_TAU_NORM_ACCURACY_XP > norm2 || norm2 > _ONE_XP + _DERIVED_TAU_NORM_ACCURACY_XP) { - revert DerivedTauNotNormalized(); - } - - if (derived.u > _ONE_XP) revert DerivedUvwzWrong(); - if (derived.v > _ONE_XP) revert DerivedUvwzWrong(); - if (derived.w > _ONE_XP) revert DerivedUvwzWrong(); - if (derived.z > _ONE_XP) revert DerivedUvwzWrong(); - - if ( - _ONE_XP - _DERIVED_DSQ_NORM_ACCURACY_XP > derived.dSq || - derived.dSq > _ONE_XP + _DERIVED_DSQ_NORM_ACCURACY_XP - ) { - revert DerivedDsqWrong(); - } + require( + norm2 > _ONE_XP - _DERIVED_TAU_NORM_ACCURACY_XP && norm2 < _ONE_XP + _DERIVED_TAU_NORM_ACCURACY_XP, + DerivedTauBetaNotNormalized() + ); + require(derived.u < _ONE_XP, DerivedUWrong()); + require(derived.v < _ONE_XP, DerivedVWrong()); + require(derived.w < _ONE_XP, DerivedWWrong()); + require(derived.z < _ONE_XP, DerivedZWrong()); + + require( + derived.dSq > _ONE_XP - _DERIVED_DSQ_NORM_ACCURACY_XP && + derived.dSq < _ONE_XP + _DERIVED_DSQ_NORM_ACCURACY_XP, + DerivedDsqWrong() + ); // NB No anti-overflow checks are required given the checks done above and in validateParams(). int256 mulDenominator = _ONE_XP.divXpU(calcAChiAChiInXp(params, derived) - _ONE_XP); - if (mulDenominator > _MAX_INV_INVARIANT_DENOMINATOR_XP) { - revert InvariantDenominatorWrong(); - } + require(mulDenominator < _MAX_INV_INVARIANT_DENOMINATOR_XP, InvariantDenominatorWrong()); } function scalarProd( @@ -252,9 +253,7 @@ library GyroECLPMath { ) internal pure returns (int256, int256) { (int256 x, int256 y) = (balances[0].toInt256(), balances[1].toInt256()); - if (x + y > _MAX_BALANCES) { - revert MaxAssetsExceeded(); - } + require(x + y < _MAX_BALANCES, MaxAssetsExceeded()); int256 atAChi = calcAtAChi(x, y, params, derived); (int256 sqrt, int256 err) = calcInvariantSqrt(x, y, params, derived); @@ -297,9 +296,7 @@ library GyroECLPMath { _ONE_XP + 1; - if (invariant + err > _MAX_INVARIANT) { - revert MaxInvariantExceeded(); - } + require(invariant + err < _MAX_INVARIANT, MaxInvariantExceeded()); return (invariant, err); } @@ -517,12 +514,10 @@ library GyroECLPMath { ) internal pure { if (assetIndex == 0) { int256 xPlus = maxBalances0(params, derived, invariant); - if (!(newBal <= _MAX_BALANCES && newBal <= xPlus)) revert AssetBoundsExceeded(); - return; - } - { + require(newBal <= _MAX_BALANCES && newBal <= xPlus, AssetBoundsExceeded()); + } else { int256 yPlus = maxBalances1(params, derived, invariant); - if (!(newBal <= _MAX_BALANCES && newBal <= yPlus)) revert AssetBoundsExceeded(); + require(newBal <= _MAX_BALANCES && newBal <= yPlus, AssetBoundsExceeded()); } } @@ -586,7 +581,7 @@ library GyroECLPMath { calcGiven = calcYGivenX; // this reverses compared to calcOutGivenIn } - if (!(amountOut <= balances[ixOut])) revert AssetBoundsExceeded(); + require(amountOut <= balances[ixOut], AssetBoundsExceeded()); int256 balOutNew = (balances[ixOut] - amountOut).toInt256(); int256 balInNew = calcGiven(balOutNew, params, derived, invariant); // The checks in the following two lines should really always succeed; we keep them as extra safety against diff --git a/pkg/pool-gyro/foundry.toml b/pkg/pool-gyro/foundry.toml index 8b4179b43..cf85ccaf9 100755 --- a/pkg/pool-gyro/foundry.toml +++ b/pkg/pool-gyro/foundry.toml @@ -23,7 +23,7 @@ remappings = [ ] optimizer = true optimizer_runs = 999 -solc_version = '0.8.26' +solc_version = '0.8.27' auto_detect_solc = false evm_version = 'cancun' ignored_error_codes = [2394, 5574, 3860] # Transient storage, code size diff --git a/pkg/pool-gyro/hardhat.config.ts b/pkg/pool-gyro/hardhat.config.ts index dfd54996c..3bbb6a085 100644 --- a/pkg/pool-gyro/hardhat.config.ts +++ b/pkg/pool-gyro/hardhat.config.ts @@ -1,4 +1,5 @@ import { HardhatUserConfig } from 'hardhat/config'; +import { name } from './package.json'; import { hardhatBaseConfig } from '@balancer-labs/v3-common'; @@ -14,6 +15,7 @@ import { warnings } from '@balancer-labs/v3-common/hardhat-base-config'; const config: HardhatUserConfig = { solidity: { compilers: hardhatBaseConfig.compilers, + overrides: { ...hardhatBaseConfig.overrides(name) }, }, warnings, }; diff --git a/pkg/pool-gyro/test/foundry/E2eBatchSwapECLP.t.sol b/pkg/pool-gyro/test/foundry/E2eBatchSwapECLP.t.sol index 00df5682a..ef96857df 100644 --- a/pkg/pool-gyro/test/foundry/E2eBatchSwapECLP.t.sol +++ b/pkg/pool-gyro/test/foundry/E2eBatchSwapECLP.t.sol @@ -35,9 +35,8 @@ contract E2eBatchSwapECLPTest is E2eBatchSwapTest, GyroEclpPoolDeployer { minSwapAmountTokenA = 10 * PRODUCTION_MIN_TRADE_AMOUNT; minSwapAmountTokenD = 10 * PRODUCTION_MIN_TRADE_AMOUNT; - // Divide init amount by 10 to make sure weighted math ratios are respected (Cannot trade more than 30% of pool - // balance). - maxSwapAmountTokenA = poolInitAmount / 10; - maxSwapAmountTokenD = poolInitAmount / 10; + // 25% of pool init amount, so MIN and MAX invariant ratios are not violated. + maxSwapAmountTokenA = poolInitAmount / 4; + maxSwapAmountTokenD = poolInitAmount / 4; } } diff --git a/pkg/pool-gyro/test/foundry/LiquidityApproximationECLP.t.sol b/pkg/pool-gyro/test/foundry/LiquidityApproximationECLP.t.sol index 6b6ec454f..c8eac604e 100644 --- a/pkg/pool-gyro/test/foundry/LiquidityApproximationECLP.t.sol +++ b/pkg/pool-gyro/test/foundry/LiquidityApproximationECLP.t.sol @@ -18,7 +18,7 @@ contract LiquidityApproximationECLPTest is LiquidityApproximationTest, GyroEclpP minSwapFeePercentage = IBasePool(swapPool).getMinimumSwapFeePercentage(); // The invariant of ECLP pools are smaller. - maxAmount = 1e6 * 1e18; + maxAmount = 1e5 * 1e18; } function _createPool( diff --git a/pkg/pool-hooks/contracts/StableSurgeHook.sol b/pkg/pool-hooks/contracts/StableSurgeHook.sol new file mode 100644 index 000000000..fbb742cf3 --- /dev/null +++ b/pkg/pool-hooks/contracts/StableSurgeHook.sol @@ -0,0 +1,309 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +pragma solidity ^0.8.24; + +import { SafeCast } from "@openzeppelin/contracts/utils/math/SafeCast.sol"; + +import { IBasePoolFactory } from "@balancer-labs/v3-interfaces/contracts/vault/IBasePoolFactory.sol"; +import { IHooks } from "@balancer-labs/v3-interfaces/contracts/vault/IHooks.sol"; +import { IVault } from "@balancer-labs/v3-interfaces/contracts/vault/IVault.sol"; +import { + LiquidityManagement, + TokenConfig, + PoolSwapParams, + HookFlags, + SwapKind, + Rounding +} from "@balancer-labs/v3-interfaces/contracts/vault/VaultTypes.sol"; + +import { VaultGuard } from "@balancer-labs/v3-vault/contracts/VaultGuard.sol"; +import { BaseHooks } from "@balancer-labs/v3-vault/contracts/BaseHooks.sol"; +import { SingletonAuthentication } from "@balancer-labs/v3-vault/contracts/SingletonAuthentication.sol"; + +import { FixedPoint } from "@balancer-labs/v3-solidity-utils/contracts/math/FixedPoint.sol"; +import { StableMath } from "@balancer-labs/v3-solidity-utils/contracts/math/StableMath.sol"; +import { ScalingHelpers } from "@balancer-labs/v3-solidity-utils/contracts/helpers/ScalingHelpers.sol"; + +import { StablePool } from "@balancer-labs/v3-pool-stable/contracts/StablePool.sol"; + +import { StableSurgeMedianMath } from "./utils/StableSurgeMedianMath.sol"; + +/** + * @notice Hook that charges a fee on trades that push a pool into an imbalanced state beyond a given threshold. + * @dev Uses the dynamic fee mechanism to apply a "surge" fee on trades that unbalance the pool beyond the threshold. + */ +contract StableSurgeHook is BaseHooks, VaultGuard, SingletonAuthentication { + using FixedPoint for uint256; + using SafeCast for *; + + // Only pools from the allowed factory are able to register and use this hook. + address private immutable _allowedPoolFactory; + + // Percentages are 18-decimal FP values, which fit in 64 bits (sized ensure a single slot). + struct SurgeFeeData { + uint64 thresholdPercentage; + uint64 maxSurgeFeePercentage; + } + + // The default threshold, above which surging will occur. + uint256 private immutable _defaultMaxSurgeFeePercentage; + + // The default threshold, above which surging will occur. + uint256 private immutable _defaultSurgeThresholdPercentage; + + // Store the current threshold and max fee for each pool. + mapping(address pool => SurgeFeeData data) private _surgeFeePoolData; + + /** + * @notice A new `StableSurgeHook` contract has been registered successfully. + * @dev If the registration fails the call will revert, so there will be no event. + * @param pool The pool on which the hook was registered + * @param factory The factory that registered the pool + */ + event StableSurgeHookRegistered(address indexed pool, address indexed factory); + + /** + * @notice The threshold percentage has been changed for a pool in a `StableSurgeHook` contract. + * @dev Note, the initial threshold percentage is set on deployment, and an event is emitted. + * @param pool The pool for which the threshold percentage has been changed + * @param newSurgeThresholdPercentage The new threshold percentage + */ + event ThresholdSurgePercentageChanged(address indexed pool, uint256 newSurgeThresholdPercentage); + + /** + * @notice The maximum surge fee percentage has been changed for a pool in a `StableSurgeHook` contract. + * @dev Note, the initial max surge fee percentage is set on deployment, and an event is emitted. + * @param pool The pool for which the max surge fee percentage has been changed + * @param newMaxSurgeFeePercentage The new max surge fee percentage + */ + event MaxSurgeFeePercentageChanged(address indexed pool, uint256 newMaxSurgeFeePercentage); + + /// @notice The max surge fee and threshold values must be valid percentages. + error InvalidPercentage(); + + modifier withValidPercentage(uint256 percentageValue) { + _ensureValidPercentage(percentageValue); + _; + } + + modifier withPermission(address pool) { + _ensureValidSender(pool); + _; + } + + // Store the current threshold for each pool. + mapping(address pool => uint256 threshold) private _surgeThresholdPercentage; + + constructor( + IVault vault, + uint256 defaultMaxSurgeFeePercentage, + uint256 defaultSurgeThresholdPercentage + ) SingletonAuthentication(vault) VaultGuard(vault) { + _ensureValidPercentage(defaultSurgeThresholdPercentage); + _ensureValidPercentage(defaultMaxSurgeFeePercentage); + + _defaultSurgeThresholdPercentage = defaultSurgeThresholdPercentage; + _defaultMaxSurgeFeePercentage = defaultMaxSurgeFeePercentage; + + // Assumes the hook is deployed by the same factory as the pool. + _allowedPoolFactory = msg.sender; + } + + /// @inheritdoc IHooks + function getHookFlags() public pure override returns (HookFlags memory hookFlags) { + hookFlags.shouldCallComputeDynamicSwapFee = true; + } + + /** + * @notice Getter for the allowed pool factory. + * @dev This will likely be a custom factory that deploys the standard Stable Pool with this hook contract. + */ + function getAllowedPoolFactory() external view returns (address) { + return _allowedPoolFactory; + } + + /** + * @notice Getter for the default maximum surge surge fee percentage. + * @return maxSurgeFeePercentage The default max surge fee percentage for this hook contract + */ + function getDefaultMaxSurgeFeePercentage() external view returns (uint256) { + return _defaultMaxSurgeFeePercentage; + } + + /** + * @notice Getter for the default surge threshold percentage. + * @return surgeThresholdPercentage The default surge threshold percentage for this hook contract + */ + function getDefaultSurgeThresholdPercentage() external view returns (uint256) { + return _defaultSurgeThresholdPercentage; + } + + /** + * @notice Getter for the maximum surge fee percentage for a pool. + * @param pool The pool for which the max surge fee percentage is requested + * @return maxSurgeFeePercentage The max surge fee percentage for the pool + */ + function getMaxSurgeFeePercentage(address pool) external view returns (uint256) { + return _surgeFeePoolData[pool].maxSurgeFeePercentage; + } + + /** + * @notice Getter for the surge threshold percentage for a pool. + * @param pool The pool for which the surge threshold percentage is requested + * @return surgeThresholdPercentage The surge threshold percentage for the pool + */ + function getSurgeThresholdPercentage(address pool) external view returns (uint256) { + return _surgeFeePoolData[pool].thresholdPercentage; + } + + /// @inheritdoc IHooks + function onRegister( + address factory, + address pool, + TokenConfig[] memory, + LiquidityManagement calldata + ) public override onlyVault returns (bool) { + bool isAllowedFactory = factory == _allowedPoolFactory && IBasePoolFactory(factory).isPoolFromFactory(pool); + + if (isAllowedFactory == false) { + return false; + } + + // Initially set the max pool surge percentage to the default (can be changed by the pool swapFeeManager + // in the future). + _setMaxSurgeFeePercentage(pool, _defaultMaxSurgeFeePercentage); + + // Initially set the pool threshold to the default (can be changed by the pool swapFeeManager in the future). + _setSurgeThresholdPercentage(pool, _defaultSurgeThresholdPercentage); + + emit StableSurgeHookRegistered(pool, factory); + + return true; + } + + /// @inheritdoc IHooks + function onComputeDynamicSwapFeePercentage( + PoolSwapParams calldata params, + address pool, + uint256 staticSwapFeePercentage + ) public view override onlyVault returns (bool, uint256) { + return (true, getSurgeFeePercentage(params, pool, staticSwapFeePercentage)); + } + + /** + * @notice Sets the max surge fee percentage. + * @dev This function must be permissioned. If the pool does not have a swap fee manager role set, the max surge + * fee can only be changed by governance. It is initially set to the default max surge fee for this hook contract. + */ + function setMaxSurgeFeePercentage( + address pool, + uint256 newMaxSurgeSurgeFeePercentage + ) external withValidPercentage(newMaxSurgeSurgeFeePercentage) withPermission(pool) { + _setMaxSurgeFeePercentage(pool, newMaxSurgeSurgeFeePercentage); + } + + /** + * @notice Sets the hook threshold percentage. + * @dev This function must be permissioned. If the pool does not have a swap fee manager role set, the surge + * threshold can only be changed by governance. It is initially set to the default threshold for this hook contract. + */ + function setSurgeThresholdPercentage( + address pool, + uint256 newSurgeThresholdPercentage + ) external withValidPercentage(newSurgeThresholdPercentage) withPermission(pool) { + _setSurgeThresholdPercentage(pool, newSurgeThresholdPercentage); + } + + /// @dev Ensure the sender is the swapFeeManager, or default to governance if there is no manager. + function _ensureValidSender(address pool) private view { + address swapFeeManager = _vault.getPoolRoleAccounts(pool).swapFeeManager; + + if (swapFeeManager == address(0)) { + if (_canPerform(getActionId(msg.sig), msg.sender, pool) == false) { + revert SenderNotAllowed(); + } + } else if (swapFeeManager != msg.sender) { + revert SenderNotAllowed(); + } + } + + /** + * @notice Calculate the surge fee percentage. If below threshold, return the standard static swap fee percentage. + * @dev It is public to allow it to be called off-chain. + * @param params Input parameters for the swap (balances needed) + * @param pool The pool we are computing the fee for + * @param staticFeePercentage The static fee percentage for the pool (default if there is no surge) + */ + function getSurgeFeePercentage( + PoolSwapParams calldata params, + address pool, + uint256 staticFeePercentage + ) public view returns (uint256) { + uint256 numTokens = params.balancesScaled18.length; + + uint256 amountCalculatedScaled18 = StablePool(pool).onSwap(params); + + uint256[] memory newBalances = new uint256[](numTokens); + ScalingHelpers.copyToArray(params.balancesScaled18, newBalances); + + if (params.kind == SwapKind.EXACT_IN) { + newBalances[params.indexIn] += params.amountGivenScaled18; + newBalances[params.indexOut] -= amountCalculatedScaled18; + } else { + newBalances[params.indexIn] += amountCalculatedScaled18; + newBalances[params.indexOut] -= params.amountGivenScaled18; + } + + uint256 newTotalImbalance = StableSurgeMedianMath.calculateImbalance(newBalances); + + // If we are balanced, or the balance has improved, do not surge: simply return the regular fee percentage. + if (newTotalImbalance == 0) { + return staticFeePercentage; + } + + uint256 oldTotalImbalance = StableSurgeMedianMath.calculateImbalance(params.balancesScaled18); + + SurgeFeeData storage surgeFeeData = _surgeFeePoolData[pool]; + if (newTotalImbalance <= oldTotalImbalance || newTotalImbalance <= surgeFeeData.thresholdPercentage) { + return staticFeePercentage; + } + + // surgeFee = staticFee + (maxFee - staticFee) * (pctImbalance - pctThreshold) / (1 - pctThreshold). + // + // As you can see from the formula, if it’s unbalanced exactly at the threshold, the last term is 0, + // and the fee is just: static + 0 = static fee. + // As the unbalanced proportion term approaches 1, the fee surge approaches: static + max - static ~= max fee. + // This formula linearly increases the fee from 0 at the threshold up to the maximum fee. + // At 35%, the fee would be 1% + (0.95 - 0.01) * ((0.35 - 0.3)/(0.95-0.3)) = 1% + 0.94 * 0.0769 ~ 8.2%. + // At 50% unbalanced, the fee would be 44%. At 99% unbalanced, the fee would be ~94%, very close to the maximum. + return + staticFeePercentage + + (surgeFeeData.maxSurgeFeePercentage - staticFeePercentage).mulDown( + (newTotalImbalance - surgeFeeData.thresholdPercentage).divDown( + uint256(surgeFeeData.thresholdPercentage).complement() + ) + ); + } + + /// @dev Assumes the percentage value and sender have been externally validated. + function _setMaxSurgeFeePercentage(address pool, uint256 newMaxSurgeFeePercentage) private { + // Still use SafeCast out of an abundance of caution. + _surgeFeePoolData[pool].maxSurgeFeePercentage = newMaxSurgeFeePercentage.toUint64(); + + emit MaxSurgeFeePercentageChanged(pool, newMaxSurgeFeePercentage); + } + + /// @dev Assumes the percentage value and sender have been externally validated. + function _setSurgeThresholdPercentage(address pool, uint256 newSurgeThresholdPercentage) private { + // Still use SafeCast out of an abundance of caution. + _surgeFeePoolData[pool].thresholdPercentage = newSurgeThresholdPercentage.toUint64(); + + emit ThresholdSurgePercentageChanged(pool, newSurgeThresholdPercentage); + } + + function _ensureValidPercentage(uint256 percentage) private pure { + if (percentage > FixedPoint.ONE) { + revert InvalidPercentage(); + } + } +} diff --git a/pkg/pool-hooks/contracts/StableSurgePoolFactory.sol b/pkg/pool-hooks/contracts/StableSurgePoolFactory.sol new file mode 100644 index 000000000..30ec6b597 --- /dev/null +++ b/pkg/pool-hooks/contracts/StableSurgePoolFactory.sol @@ -0,0 +1,115 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +pragma solidity ^0.8.24; + +import { IPoolVersion } from "@balancer-labs/v3-interfaces/contracts/solidity-utils/helpers/IPoolVersion.sol"; +import { IVaultErrors } from "@balancer-labs/v3-interfaces/contracts/vault/IVaultErrors.sol"; +import { IVault } from "@balancer-labs/v3-interfaces/contracts/vault/IVault.sol"; +import { + TokenConfig, + PoolRoleAccounts, + LiquidityManagement +} from "@balancer-labs/v3-interfaces/contracts/vault/VaultTypes.sol"; + +import { BasePoolFactory } from "@balancer-labs/v3-pool-utils/contracts/BasePoolFactory.sol"; +import { StableMath } from "@balancer-labs/v3-solidity-utils/contracts/math/StableMath.sol"; +import { Version } from "@balancer-labs/v3-solidity-utils/contracts/helpers/Version.sol"; +import { StablePool } from "@balancer-labs/v3-pool-stable/contracts/StablePool.sol"; + +import { StableSurgeHook } from "./StableSurgeHook.sol"; + +/// @notice Stable Pool factory that deploys a standard StablePool with a StableSurgeHook. +contract StableSurgePoolFactory is IPoolVersion, BasePoolFactory, Version { + address private immutable _stableSurgeHook; + + string private _poolVersion; + + constructor( + IVault vault, + uint32 pauseWindowDuration, + uint256 defaultMaxSurgeFeePercentage, + uint256 defaultSurgeThresholdPercentage, + string memory factoryVersion, + string memory poolVersion + ) BasePoolFactory(vault, pauseWindowDuration, type(StablePool).creationCode) Version(factoryVersion) { + _poolVersion = poolVersion; + _stableSurgeHook = address( + new StableSurgeHook(vault, defaultMaxSurgeFeePercentage, defaultSurgeThresholdPercentage) + ); + } + + /// @inheritdoc IPoolVersion + function getPoolVersion() external view returns (string memory) { + return _poolVersion; + } + + /** + * @notice Getter for the internally deployed stable surge hook contract. + * @dev This hook will be registered to every pool created by this factory. + * @return address stableSurgeHook Address of the deployed StableSurgeHook + */ + function getStableSurgeHook() external view returns (address) { + return _stableSurgeHook; + } + + /** + * @notice Deploys a new `StablePool`. + * @param name The name of the pool + * @param symbol The symbol of the pool + * @param tokens An array of descriptors for the tokens the pool will manage + * @param amplificationParameter Starting value of the amplificationParameter (see StablePool) + * @param roleAccounts Addresses the Vault will allow to change certain pool settings + * @param swapFeePercentage Initial swap fee percentage + * @param enableDonation If true, the pool will support the donation add liquidity mechanism + * @param disableUnbalancedLiquidity If true, only proportional add and remove liquidity are accepted + * @param salt The salt value that will be passed to create3 deployment + */ + function create( + string memory name, + string memory symbol, + TokenConfig[] memory tokens, + uint256 amplificationParameter, + PoolRoleAccounts memory roleAccounts, + uint256 swapFeePercentage, + bool enableDonation, + bool disableUnbalancedLiquidity, + bytes32 salt + ) external returns (address pool) { + if (roleAccounts.poolCreator != address(0)) { + revert StandardPoolWithCreator(); + } + + // As the Stable Pool deployment does not know about the tokens, and the registration doesn't know about the + // pool type, we enforce the token limit at the factory level. + if (tokens.length > StableMath.MAX_STABLE_TOKENS) { + revert IVaultErrors.MaxTokens(); + } + + LiquidityManagement memory liquidityManagement = getDefaultLiquidityManagement(); + liquidityManagement.enableDonation = enableDonation; + liquidityManagement.disableUnbalancedLiquidity = disableUnbalancedLiquidity; + + pool = _create( + abi.encode( + StablePool.NewPoolParams({ + name: name, + symbol: symbol, + amplificationParameter: amplificationParameter, + version: _poolVersion + }), + getVault() + ), + salt + ); + + _registerPoolWithVault( + pool, + tokens, + swapFeePercentage, + false, // not exempt from protocol fees + roleAccounts, + _stableSurgeHook, + liquidityManagement + ); + } +} diff --git a/pkg/pool-hooks/contracts/test/StableSurgeMedianMathMock.sol b/pkg/pool-hooks/contracts/test/StableSurgeMedianMathMock.sol new file mode 100644 index 000000000..d5198ef6d --- /dev/null +++ b/pkg/pool-hooks/contracts/test/StableSurgeMedianMathMock.sol @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +pragma solidity ^0.8.24; + +import { StableSurgeMedianMath } from "../utils/StableSurgeMedianMath.sol"; + +contract StableSurgeMedianMathMock { + function calculateImbalance(uint256[] memory balancesScaled18) public pure returns (uint) { + return StableSurgeMedianMath.calculateImbalance(balancesScaled18); + } + + function findMedian(uint256[] memory sortedBalancesScaled18) public pure returns (uint256) { + return StableSurgeMedianMath.findMedian(sortedBalancesScaled18); + } + + function absSub(uint256 a, uint256 b) public pure returns (uint256) { + return StableSurgeMedianMath.absSub(a, b); + } +} diff --git a/pkg/pool-hooks/contracts/utils/StableSurgeMedianMath.sol b/pkg/pool-hooks/contracts/utils/StableSurgeMedianMath.sol new file mode 100644 index 000000000..4b5281484 --- /dev/null +++ b/pkg/pool-hooks/contracts/utils/StableSurgeMedianMath.sol @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +pragma solidity ^0.8.24; + +import { Arrays } from "@balancer-labs/v3-solidity-utils/contracts/openzeppelin/Arrays.sol"; +import { FixedPoint } from "@balancer-labs/v3-solidity-utils/contracts/math/FixedPoint.sol"; + +library StableSurgeMedianMath { + using Arrays for uint256[]; + using FixedPoint for uint256; + + function calculateImbalance(uint256[] memory balances) internal pure returns (uint256) { + uint256 median = findMedian(balances); + + uint256 totalBalance = 0; + uint256 totalDiff = 0; + + for (uint i = 0; i < balances.length; i++) { + totalBalance += balances[i]; + totalDiff += absSub(balances[i], median); + } + + return totalDiff.divDown(totalBalance); + } + + function findMedian(uint256[] memory balances) internal pure returns (uint256) { + uint256[] memory sortedBalances = balances.sort(); + uint256 mid = sortedBalances.length / 2; + + if (sortedBalances.length % 2 == 0) { + return (sortedBalances[mid - 1] + sortedBalances[mid]) / 2; + } else { + return sortedBalances[mid]; + } + } + + function absSub(uint256 a, uint256 b) internal pure returns (uint256) { + unchecked { + return a > b ? a - b : b - a; + } + } +} diff --git a/pkg/pool-hooks/test/StableSurgeMedianMath.test.ts b/pkg/pool-hooks/test/StableSurgeMedianMath.test.ts new file mode 100644 index 000000000..91e7a4b80 --- /dev/null +++ b/pkg/pool-hooks/test/StableSurgeMedianMath.test.ts @@ -0,0 +1,63 @@ +import { expect } from 'chai'; +import { deploy } from '@balancer-labs/v3-helpers/src/contract'; +import { StableSurgeMedianMathMock } from '../typechain-types/contracts/test/StableSurgeMedianMathMock'; +import { findMedian } from '@balancer-labs/v3-helpers/src/math/surgeMedianMath'; + +describe('StableSurgeMedianMath', function () { + const MIN_TOKENS = 2; + const MAX_TOKENS = 8; + const TEST_ITERATIONS = 100; + const MAX_VALUE = 100000; + + let surgeMath: StableSurgeMedianMathMock; + + function getRandomInt(min: number, max: number) { + return Math.floor(Math.random() * (max - min + 1)) + min; + } + + before('deploy mock', async () => { + surgeMath = await deploy('v3-pool-hooks/test/StableSurgeMedianMathMock'); + }); + + it('absSub', async () => { + for (let i = 0; i < TEST_ITERATIONS; i++) { + const a = getRandomInt(0, MAX_VALUE); + const b = getRandomInt(0, MAX_VALUE); + const expectedResult = Math.abs(a - b); + expect(await surgeMath.absSub(a, b)).to.eq(expectedResult); + expect(await surgeMath.absSub(b, a)).to.eq(expectedResult); + } + }); + + it('findMedian', async () => { + const worthCaseOne = [800, 700, 600, 500, 400, 300, 200, 100]; + const worthCaseTwo = worthCaseOne.reverse(); + + expect(Number(await surgeMath.findMedian(worthCaseOne))).to.eq(450); + expect(Number(await surgeMath.findMedian(worthCaseTwo))).to.eq(450); + + for (let i = 0; i < TEST_ITERATIONS; i++) { + const randomCase = new Array(getRandomInt(MIN_TOKENS, MAX_TOKENS)).fill(0).map(() => getRandomInt(0, MAX_VALUE)); + expect(Number(await surgeMath.findMedian(randomCase))).to.eq(findMedian(randomCase)); + } + }); + + it('calculateImbalance', async () => { + for (let i = 0; i < TEST_ITERATIONS; i++) { + const randomBalances = new Array(getRandomInt(MIN_TOKENS, MAX_TOKENS)) + .fill(0) + .map(() => getRandomInt(0, MAX_VALUE)); + const median = findMedian(randomBalances); + + let totalDiff = 0; + let totalBalance = 0; + for (let i = 0; i < randomBalances.length; i++) { + totalBalance += randomBalances[i]; + totalDiff += Math.abs(randomBalances[i] - median); + } + + const expectedResult = (BigInt(totalDiff) * BigInt(1e18)) / BigInt(totalBalance); + expect(Number(await surgeMath.calculateImbalance(randomBalances))).to.eq(Number(expectedResult)); + } + }); +}); diff --git a/pkg/pool-hooks/test/foundry/StableSurgeHook.t.sol b/pkg/pool-hooks/test/foundry/StableSurgeHook.t.sol new file mode 100644 index 000000000..1dae972f6 --- /dev/null +++ b/pkg/pool-hooks/test/foundry/StableSurgeHook.t.sol @@ -0,0 +1,227 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +pragma solidity ^0.8.24; + +import { IVault } from "@balancer-labs/v3-interfaces/contracts/vault/IVault.sol"; +import { PoolRoleAccounts } from "@balancer-labs/v3-interfaces/contracts/vault/VaultTypes.sol"; + +import { StablePool } from "@balancer-labs/v3-pool-stable/contracts/StablePool.sol"; + +import { StablePoolFactory } from "@balancer-labs/v3-pool-stable/contracts/StablePoolFactory.sol"; +import { BaseVaultTest } from "@balancer-labs/v3-vault/test/foundry/utils/BaseVaultTest.sol"; + +import { StableMath } from "@balancer-labs/v3-solidity-utils/contracts/math/StableMath.sol"; +import { ArrayHelpers } from "@balancer-labs/v3-solidity-utils/contracts/test/ArrayHelpers.sol"; +import { CastingHelpers } from "@balancer-labs/v3-solidity-utils/contracts/helpers/CastingHelpers.sol"; +import { FixedPoint } from "@balancer-labs/v3-solidity-utils/contracts/math/FixedPoint.sol"; +import { ScalingHelpers } from "@balancer-labs/v3-solidity-utils/contracts/helpers/ScalingHelpers.sol"; +import { PoolSwapParams, SwapKind } from "@balancer-labs/v3-interfaces/contracts/vault/VaultTypes.sol"; + +import { StableSurgeHook } from "../../contracts/StableSurgeHook.sol"; +import { StableSurgeMedianMathMock } from "../../contracts/test/StableSurgeMedianMathMock.sol"; + +contract StableSurgeHookTest is BaseVaultTest { + using ArrayHelpers for *; + using CastingHelpers for *; + using FixedPoint for uint256; + + uint256 internal constant DEFAULT_AMP_FACTOR = 200; + uint256 constant DEFAULT_SURGE_THRESHOLD_PERCENTAGE = 30e16; // 30% + uint256 constant DEFAULT_MAX_SURGE_FEE_PERCENTAGE = 95e16; // 95% + uint256 constant DEFAULT_POOL_TOKEN_COUNT = 2; + + uint256 internal daiIdx; + uint256 internal usdcIdx; + + StablePoolFactory internal stablePoolFactory; + StableSurgeHook internal stableSurgeHook; + + StableSurgeMedianMathMock stableSurgeMedianMathMock = new StableSurgeMedianMathMock(); + + function setUp() public override { + super.setUp(); + + (daiIdx, usdcIdx) = getSortedIndexes(address(dai), address(usdc)); + } + + function createHook() internal override returns (address) { + stablePoolFactory = new StablePoolFactory(IVault(address(vault)), 365 days, "Factory v1", "Pool v1"); + + vm.prank(address(stablePoolFactory)); + stableSurgeHook = new StableSurgeHook( + vault, + DEFAULT_MAX_SURGE_FEE_PERCENTAGE, + DEFAULT_SURGE_THRESHOLD_PERCENTAGE + ); + vm.label(address(stableSurgeHook), "StableSurgeHook"); + return address(stableSurgeHook); + } + + function _createPool( + address[] memory tokens, + string memory label + ) internal override returns (address newPool, bytes memory poolArgs) { + PoolRoleAccounts memory roleAccounts; + + newPool = stablePoolFactory.create( + "Stable Pool", + "STABLEPOOL", + vault.buildTokenConfig(tokens.asIERC20()), + DEFAULT_AMP_FACTOR, + roleAccounts, + swapFeePercentage, + poolHooksContract, + false, + false, + ZERO_BYTES32 + ); + vm.label(address(newPool), label); + + return ( + address(newPool), + abi.encode( + StablePool.NewPoolParams({ + name: "Stable Pool", + symbol: "STABLEPOOL", + amplificationParameter: DEFAULT_AMP_FACTOR, + version: "Pool v1" + }), + vault + ) + ); + } + + function testSuccessfulRegistry() public view { + assertEq( + stableSurgeHook.getSurgeThresholdPercentage(pool), + DEFAULT_SURGE_THRESHOLD_PERCENTAGE, + "Surge threshold is wrong" + ); + } + + function testSwap__Fuzz(uint256 amountGivenScaled18, uint256 swapFeePercentageRaw, uint256 kindRaw) public { + amountGivenScaled18 = bound(amountGivenScaled18, 1e18, poolInitAmount / 2); + SwapKind kind = SwapKind(bound(kindRaw, 0, 1)); + + vault.manuallySetSwapFee(pool, bound(swapFeePercentageRaw, 0, 1e16)); + swapFeePercentage = vault.getStaticSwapFeePercentage(pool); + + BaseVaultTest.Balances memory balancesBefore = getBalances(alice); + + if (kind == SwapKind.EXACT_IN) { + vm.prank(alice); + router.swapSingleTokenExactIn(pool, usdc, dai, amountGivenScaled18, 0, MAX_UINT256, false, bytes("")); + } else { + vm.prank(alice); + router.swapSingleTokenExactOut( + pool, + usdc, + dai, + amountGivenScaled18, + MAX_UINT256, + MAX_UINT256, + false, + bytes("") + ); + } + + uint256 actualSwapFeePercentage = _calculateFee( + amountGivenScaled18, + kind, + [poolInitAmount, poolInitAmount].toMemoryArray() + ); + + BaseVaultTest.Balances memory balancesAfter = getBalances(alice); + + uint256 actualAmountOut = balancesAfter.aliceTokens[daiIdx] - balancesBefore.aliceTokens[daiIdx]; + uint256 actualAmountIn = balancesBefore.aliceTokens[usdcIdx] - balancesAfter.aliceTokens[usdcIdx]; + + uint256 expectedAmountOut; + uint256 expectedAmountIn; + if (kind == SwapKind.EXACT_IN) { + // extract swap fee + expectedAmountIn = amountGivenScaled18; + uint256 swapAmount = amountGivenScaled18.mulUp(actualSwapFeePercentage); + + uint256 amountCalculatedScaled18 = StablePool(pool).onSwap( + PoolSwapParams({ + kind: kind, + indexIn: usdcIdx, + indexOut: daiIdx, + amountGivenScaled18: expectedAmountIn - swapAmount, + balancesScaled18: [poolInitAmount, poolInitAmount].toMemoryArray(), + router: address(0), + userData: bytes("") + }) + ); + + expectedAmountOut = amountCalculatedScaled18; + } else { + expectedAmountOut = amountGivenScaled18; + uint256 amountCalculatedScaled18 = StablePool(pool).onSwap( + PoolSwapParams({ + kind: kind, + indexIn: usdcIdx, + indexOut: daiIdx, + amountGivenScaled18: expectedAmountOut, + balancesScaled18: [poolInitAmount, poolInitAmount].toMemoryArray(), + router: address(0), + userData: bytes("") + }) + ); + expectedAmountIn = + amountCalculatedScaled18 + + amountCalculatedScaled18.mulDivUp(actualSwapFeePercentage, actualSwapFeePercentage.complement()); + } + + assertEq(expectedAmountIn, actualAmountIn, "Amount in should be expectedAmountIn"); + assertEq(expectedAmountOut, actualAmountOut, "Amount out should be expectedAmountOut"); + } + + function _calculateFee( + uint256 amountGivenScaled18, + SwapKind kind, + uint256[] memory balances + ) internal view returns (uint256) { + uint256 amountCalculatedScaled18 = StablePool(pool).onSwap( + PoolSwapParams({ + kind: kind, + indexIn: usdcIdx, + indexOut: daiIdx, + amountGivenScaled18: amountGivenScaled18, + balancesScaled18: balances, + router: address(0), + userData: bytes("") + }) + ); + + uint256[] memory newBalances = new uint256[](balances.length); + ScalingHelpers.copyToArray(balances, newBalances); + + if (kind == SwapKind.EXACT_IN) { + newBalances[usdcIdx] += amountGivenScaled18; + newBalances[daiIdx] -= amountCalculatedScaled18; + } else { + newBalances[usdcIdx] += amountCalculatedScaled18; + newBalances[daiIdx] -= amountGivenScaled18; + } + + uint256 newTotalImbalance = stableSurgeMedianMathMock.calculateImbalance(newBalances); + uint256 oldTotalImbalance = stableSurgeMedianMathMock.calculateImbalance(balances); + + if ( + newTotalImbalance == 0 || + (newTotalImbalance <= oldTotalImbalance || newTotalImbalance <= DEFAULT_SURGE_THRESHOLD_PERCENTAGE) + ) { + return swapFeePercentage; + } + + return + swapFeePercentage + + (stableSurgeHook.getMaxSurgeFeePercentage(pool) - swapFeePercentage).mulDown( + (newTotalImbalance - DEFAULT_SURGE_THRESHOLD_PERCENTAGE).divDown( + DEFAULT_SURGE_THRESHOLD_PERCENTAGE.complement() + ) + ); + } +} diff --git a/pkg/pool-hooks/test/foundry/StableSurgeHookUnit.t.sol b/pkg/pool-hooks/test/foundry/StableSurgeHookUnit.t.sol new file mode 100644 index 000000000..8d2fabc80 --- /dev/null +++ b/pkg/pool-hooks/test/foundry/StableSurgeHookUnit.t.sol @@ -0,0 +1,403 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +pragma solidity ^0.8.24; + +import "forge-std/console.sol"; +import { BaseVaultTest } from "@balancer-labs/v3-vault/test/foundry/utils/BaseVaultTest.sol"; +import { + LiquidityManagement, + TokenConfig, + PoolSwapParams, + HookFlags, + SwapKind, + PoolRoleAccounts +} from "@balancer-labs/v3-interfaces/contracts/vault/VaultTypes.sol"; +import { IVaultExplorer } from "@balancer-labs/v3-interfaces/contracts/vault/IVaultExplorer.sol"; +import { IAuthorizer } from "@balancer-labs/v3-interfaces/contracts/vault/IAuthorizer.sol"; +import { IAuthentication } from "@balancer-labs/v3-interfaces/contracts/solidity-utils/helpers/IAuthentication.sol"; +import { FixedPoint } from "@balancer-labs/v3-solidity-utils/contracts/math/FixedPoint.sol"; +import { ScalingHelpers } from "@balancer-labs/v3-solidity-utils/contracts/helpers/ScalingHelpers.sol"; +import { StablePool } from "@balancer-labs/v3-pool-stable/contracts/StablePool.sol"; + +import { StableSurgeHook } from "../../contracts/StableSurgeHook.sol"; +import { StableSurgeMedianMathMock } from "../../contracts/test/StableSurgeMedianMathMock.sol"; + +contract StableSurgeHookUnitTest is BaseVaultTest { + using FixedPoint for uint256; + + uint256 constant MIN_TOKENS = 2; + uint256 constant MAX_TOKENS = 8; + + uint256 constant DEFAULT_SURGE_THRESHOLD_PERCENTAGE = 30e16; // 30% + uint256 constant DEFAULT_MAX_SURGE_FEE_PERCENTAGE = 95e16; // 95% + uint256 constant STATIC_FEE_PERCENTAGE = 1e16; + + StableSurgeMedianMathMock stableSurgeMedianMathMock = new StableSurgeMedianMathMock(); + StableSurgeHook stableSurgeHook; + LiquidityManagement defaultLiquidityManagement; + + function setUp() public override { + super.setUp(); + + vm.prank(address(factoryMock)); + stableSurgeHook = new StableSurgeHook( + vault, + DEFAULT_MAX_SURGE_FEE_PERCENTAGE, + DEFAULT_SURGE_THRESHOLD_PERCENTAGE + ); + } + + function testOnRegister() public { + assertEq(stableSurgeHook.getSurgeThresholdPercentage(pool), 0, "Surge threshold percentage should be 0"); + + vm.expectEmit(); + emit StableSurgeHook.StableSurgeHookRegistered(pool, address(factoryMock)); + _registerPool(); + + assertEq( + stableSurgeHook.getSurgeThresholdPercentage(pool), + DEFAULT_SURGE_THRESHOLD_PERCENTAGE, + "Surge threshold percentage should be DEFAULT_SURGE_THRESHOLD_PERCENTAGE" + ); + } + + function testOnRegisterWithIncorrectFactory() public { + assertEq(stableSurgeHook.getSurgeThresholdPercentage(pool), 0, "Surge threshold percentage should be 0"); + + vm.prank(address(vault)); + assertEq( + stableSurgeHook.onRegister(address(0), pool, new TokenConfig[](0), defaultLiquidityManagement), + false, + "onRegister should return false" + ); + + assertEq(stableSurgeHook.getSurgeThresholdPercentage(pool), 0, "Surge threshold percentage should be 0"); + } + + function _registerPool() private { + LiquidityManagement memory emptyLiquidityManagement; + + vm.prank(address(vault)); + stableSurgeHook.onRegister(address(factoryMock), pool, new TokenConfig[](0), emptyLiquidityManagement); + } + + function testGetHookFlags() public view { + HookFlags memory hookFlags = HookFlags({ + enableHookAdjustedAmounts: false, + shouldCallBeforeInitialize: false, + shouldCallAfterInitialize: false, + shouldCallComputeDynamicSwapFee: true, + shouldCallBeforeSwap: false, + shouldCallAfterSwap: false, + shouldCallBeforeAddLiquidity: false, + shouldCallAfterAddLiquidity: false, + shouldCallBeforeRemoveLiquidity: false, + shouldCallAfterRemoveLiquidity: false + }); + assertEq(abi.encode(stableSurgeHook.getHookFlags()), abi.encode(hookFlags), "Hook flags should be correct"); + } + + function testGetDefaultSurgeThresholdPercentage() public view { + assertEq( + stableSurgeHook.getDefaultSurgeThresholdPercentage(), + DEFAULT_SURGE_THRESHOLD_PERCENTAGE, + "Default surge threshold percentage should be correct" + ); + } + + function testChangeSurgeThresholdPercentage() public { + uint256 newSurgeThresholdPercentage = 0.5e18; + + vm.expectEmit(); + emit StableSurgeHook.ThresholdSurgePercentageChanged(pool, newSurgeThresholdPercentage); + + PoolRoleAccounts memory poolRoleAccounts = PoolRoleAccounts({ + pauseManager: address(this), + swapFeeManager: address(this), + poolCreator: address(this) + }); + vm.mockCall( + address(vault), + abi.encodeWithSelector(IVaultExplorer.getPoolRoleAccounts.selector, pool), + abi.encode(poolRoleAccounts) + ); + + vm.prank(address(this)); + stableSurgeHook.setSurgeThresholdPercentage(pool, newSurgeThresholdPercentage); + + assertEq( + stableSurgeHook.getSurgeThresholdPercentage(pool), + newSurgeThresholdPercentage, + "Surge threshold percentage should be newSurgeThresholdPercentage" + ); + } + + function testChangeSurgeThresholdPercentageRevertIfValueIsGreaterThanOne() public { + uint256 newSurgeThresholdPercentage = 1.1e18; + + PoolRoleAccounts memory poolRoleAccounts = PoolRoleAccounts({ + pauseManager: address(this), + swapFeeManager: address(this), + poolCreator: address(this) + }); + vm.mockCall( + address(vault), + abi.encodeWithSelector(IVaultExplorer.getPoolRoleAccounts.selector, pool), + abi.encode(poolRoleAccounts) + ); + + vm.expectRevert(StableSurgeHook.InvalidPercentage.selector); + stableSurgeHook.setSurgeThresholdPercentage(pool, newSurgeThresholdPercentage); + } + + function testChangeSurgeThresholdPercentageRevertIfSenderIsNotFeeManager() public { + PoolRoleAccounts memory poolRoleAccounts = PoolRoleAccounts({ + pauseManager: address(0x01), + swapFeeManager: address(0x01), + poolCreator: address(0x01) + }); + vm.mockCall( + address(vault), + abi.encodeWithSelector(IVaultExplorer.getPoolRoleAccounts.selector, pool), + abi.encode(poolRoleAccounts) + ); + + vm.expectRevert(IAuthentication.SenderNotAllowed.selector); + stableSurgeHook.setSurgeThresholdPercentage(pool, 1e18); + } + + function testChangeSurgeThresholdPercentageRevertIfFeeManagerIsZero() public { + PoolRoleAccounts memory poolRoleAccounts = PoolRoleAccounts({ + pauseManager: address(0x00), + swapFeeManager: address(0x00), + poolCreator: address(0x00) + }); + vm.mockCall( + address(vault), + abi.encodeWithSelector(IVaultExplorer.getPoolRoleAccounts.selector, pool), + abi.encode(poolRoleAccounts) + ); + + vm.expectRevert(IAuthentication.SenderNotAllowed.selector); + stableSurgeHook.setSurgeThresholdPercentage(pool, 1e18); + } + + function testGetSurgeFeePercentage__Fuzz( + uint256 length, + uint256 indexIn, + uint256 indexOut, + uint256 amountGivenScaled18, + uint256 kindRaw, + uint256[8] memory rawBalances + ) public { + _registerPool(); + + SwapKind kind; + uint256[] memory balances; + + (length, indexIn, indexOut, amountGivenScaled18, kind, balances) = _boundValues( + length, + indexIn, + indexOut, + amountGivenScaled18, + kindRaw, + rawBalances + ); + PoolSwapParams memory swapParams = _buildSwapParams(indexIn, indexOut, amountGivenScaled18, kind, balances); + uint256 surgeFeePercentage = stableSurgeHook.getSurgeFeePercentage(swapParams, pool, STATIC_FEE_PERCENTAGE); + uint256[] memory newBalances = _computeNewBalances(swapParams); + uint256 expectedFee = _calculateFee( + stableSurgeMedianMathMock.calculateImbalance(newBalances), + stableSurgeMedianMathMock.calculateImbalance(balances) + ); + assertEq(surgeFeePercentage, expectedFee, "Surge fee percentage should be expectedFee"); + } + + function testOnComputeDynamicSwapFeePercentage__Fuzz( + uint256 length, + uint256 indexIn, + uint256 indexOut, + uint256 amountGivenScaled18, + uint256 kindRaw, + uint256[8] memory rawBalances + ) public { + _registerPool(); + + SwapKind kind; + uint256[] memory balances; + + (length, indexIn, indexOut, amountGivenScaled18, kind, balances) = _boundValues( + length, + indexIn, + indexOut, + amountGivenScaled18, + kindRaw, + rawBalances + ); + + PoolSwapParams memory swapParams = _buildSwapParams(indexIn, indexOut, amountGivenScaled18, kind, balances); + vm.prank(address(vault)); + (bool success, uint256 surgeFeePercentage) = stableSurgeHook.onComputeDynamicSwapFeePercentage( + swapParams, + pool, + STATIC_FEE_PERCENTAGE + ); + assertTrue(success, "onComputeDynamicSwapFeePercentage should return true"); + + uint256[] memory newBalances = _computeNewBalances(swapParams); + uint256 expectedFee = _calculateFee( + stableSurgeMedianMathMock.calculateImbalance(newBalances), + stableSurgeMedianMathMock.calculateImbalance(balances) + ); + + assertEq(surgeFeePercentage, expectedFee, "Surge fee percentage should be expectedFee"); + } + + function testGetSurgeFeePercentageWhenNewTotalImbalanceIsZero() public { + _registerPool(); + + uint256 numTokens = 8; + uint256[] memory balances = new uint256[](numTokens); + for (uint256 i = 0; i < numTokens; ++i) { + balances[i] = 1e18; + } + uint256 indexIn = 0; + uint256 indexOut = MAX_TOKENS - 1; + balances[indexIn] = 0; + balances[indexOut] = 2e18; + + uint256 surgeFeePercentage = stableSurgeHook.getSurgeFeePercentage( + _buildSwapParams(indexIn, indexOut, 1e18, SwapKind.EXACT_IN, balances), + pool, + STATIC_FEE_PERCENTAGE + ); + + assertEq(surgeFeePercentage, STATIC_FEE_PERCENTAGE, "Surge fee percentage should be staticFeePercentage"); + } + + function testGetSurgeFeePercentageWhenNewTotalImbalanceLesOrEqOld() public { + _registerPool(); + + uint256 numTokens = 8; + uint256[] memory balances = new uint256[](numTokens); + for (uint256 i = 0; i < numTokens; ++i) { + balances[i] = 1e18; + } + balances[3] = 10000e18; + + uint256 surgeFeePercentage = stableSurgeHook.getSurgeFeePercentage( + _buildSwapParams(0, MAX_TOKENS - 1, 0, SwapKind.EXACT_IN, balances), + pool, + STATIC_FEE_PERCENTAGE + ); + assertEq(surgeFeePercentage, STATIC_FEE_PERCENTAGE, "Surge fee percentage should be staticFeePercentage"); + } + + function testGetSurgeFeePercentageWhenNewTotalImbalanceLessOrEqThreshold() public { + _registerPool(); + + uint256 numTokens = 8; + uint256[] memory balances = new uint256[](numTokens); + for (uint256 i = 0; i < numTokens; ++i) { + balances[i] = 1e18; + } + balances[4] = 2e18; + balances[5] = 2e18; + + uint256 surgeFeePercentage = stableSurgeHook.getSurgeFeePercentage( + _buildSwapParams(0, MAX_TOKENS - 1, 1, SwapKind.EXACT_IN, balances), + pool, + STATIC_FEE_PERCENTAGE + ); + assertEq(surgeFeePercentage, STATIC_FEE_PERCENTAGE, "Surge fee percentage should be staticFeePercentage"); + } + + function _boundValues( + uint256 lengthRaw, + uint256 indexInRaw, + uint256 indexOutRaw, + uint256 amountGivenScaled18Raw, + uint256 kindRaw, + uint256[8] memory balancesRaw + ) + internal + pure + returns ( + uint256 length, + uint256 indexIn, + uint256 indexOut, + uint256 amountGivenScaled18, + SwapKind kind, + uint256[] memory balances + ) + { + length = bound(lengthRaw, MIN_TOKENS, MAX_TOKENS); + balances = new uint256[](length); + for (uint256 i = 0; i < length; i++) { + balances[i] = bound(balancesRaw[i], 1, MAX_UINT128); + } + + indexIn = bound(indexInRaw, 0, length - 1); + indexOut = bound(indexOutRaw, 0, length - 1); + if (indexIn == indexOut) { + indexOut = (indexOut + 1) % length; + } + + kind = SwapKind(bound(kindRaw, 0, 1)); + + amountGivenScaled18 = bound(amountGivenScaled18Raw, 1, balances[indexOut]); + } + + function _buildSwapParams( + uint256 indexIn, + uint256 indexOut, + uint256 amountGivenScaled18, + SwapKind kind, + uint256[] memory balances + ) internal pure returns (PoolSwapParams memory) { + return + PoolSwapParams({ + kind: kind, + indexIn: indexIn, + indexOut: indexOut, + amountGivenScaled18: amountGivenScaled18, + balancesScaled18: balances, + router: address(0), + userData: bytes("") + }); + } + + function _computeNewBalances(PoolSwapParams memory params) internal view returns (uint256[] memory) { + uint256 amountCalculatedScaled18 = StablePool(pool).onSwap(params); + + uint256[] memory newBalances = new uint256[](params.balancesScaled18.length); + ScalingHelpers.copyToArray(params.balancesScaled18, newBalances); + + if (params.kind == SwapKind.EXACT_IN) { + newBalances[params.indexIn] += params.amountGivenScaled18; + newBalances[params.indexOut] -= amountCalculatedScaled18; + } else { + newBalances[params.indexIn] += amountCalculatedScaled18; + newBalances[params.indexOut] -= params.amountGivenScaled18; + } + + return newBalances; + } + + function _calculateFee(uint256 newTotalImbalance, uint256 oldTotalImbalance) internal view returns (uint256) { + if ( + newTotalImbalance == 0 || + (newTotalImbalance <= oldTotalImbalance || newTotalImbalance <= DEFAULT_SURGE_THRESHOLD_PERCENTAGE) + ) { + return STATIC_FEE_PERCENTAGE; + } + + return + STATIC_FEE_PERCENTAGE + + (stableSurgeHook.getMaxSurgeFeePercentage(pool) - STATIC_FEE_PERCENTAGE).mulDown( + (newTotalImbalance - DEFAULT_SURGE_THRESHOLD_PERCENTAGE).divDown( + DEFAULT_SURGE_THRESHOLD_PERCENTAGE.complement() + ) + ); + } +} diff --git a/pkg/pool-hooks/test/foundry/StableSurgeMedianMath.t.sol b/pkg/pool-hooks/test/foundry/StableSurgeMedianMath.t.sol new file mode 100644 index 000000000..3ec4070f9 --- /dev/null +++ b/pkg/pool-hooks/test/foundry/StableSurgeMedianMath.t.sol @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +pragma solidity ^0.8.24; + +import { BaseVaultTest } from "@balancer-labs/v3-vault/test/foundry/utils/BaseVaultTest.sol"; +import { Arrays } from "@balancer-labs/v3-solidity-utils/contracts/openzeppelin/Arrays.sol"; +import { FixedPoint } from "@balancer-labs/v3-solidity-utils/contracts/math/FixedPoint.sol"; +import { InputHelpers } from "@balancer-labs/v3-solidity-utils/contracts/helpers/InputHelpers.sol"; + +import { StableSurgeMedianMathMock } from "../../contracts/test/StableSurgeMedianMathMock.sol"; + +contract StableSurgeMedianMathTest is BaseVaultTest { + using Arrays for uint256[]; + using FixedPoint for uint256; + using InputHelpers for uint256[]; + + uint256 constant MIN_TOKENS = 2; + uint256 constant MAX_TOKENS = 8; + + StableSurgeMedianMathMock stableSurgeMedianMathMock = new StableSurgeMedianMathMock(); + + function testAbsSub__Fuzz(uint256 a, uint256 b) public view { + a = bound(a, 0, MAX_UINT256); + b = bound(b, 0, MAX_UINT256); + + uint256 result; + if (a > b) { + result = a - b; + } else { + result = b - a; + } + + assertEq(stableSurgeMedianMathMock.absSub(a, b), result, "absSub(a,b) has incorrect result"); + assertEq(stableSurgeMedianMathMock.absSub(b, a), result, "absSub(b, a) has incorrect result"); + } + + function testAbsSubWithMinAndMaxValues() public view { + assertEq(stableSurgeMedianMathMock.absSub(0, 0), 0, "abs(0 - 0) != 0"); + assertEq(stableSurgeMedianMathMock.absSub(0, 1), 1, "abs(0 - 1) != 1"); + assertEq(stableSurgeMedianMathMock.absSub(1, 0), 1, "abs(1 - 0) != 1"); + assertEq( + stableSurgeMedianMathMock.absSub(MAX_UINT256, 1), + MAX_UINT256 - 1, + "abs(MAX_UINT256 - 1) != MAX_UINT256 - 1" + ); + assertEq( + stableSurgeMedianMathMock.absSub(1, MAX_UINT256), + MAX_UINT256 - 1, + "abs(1 - MAX_UINT256) != MAX_UINT256 - 1" + ); + assertEq(stableSurgeMedianMathMock.absSub(MAX_UINT256, 0), MAX_UINT256, "abs(MAX_UINT256 - 0) != MAX_UINT256"); + assertEq(stableSurgeMedianMathMock.absSub(0, MAX_UINT256), MAX_UINT256, "abs(0 - MAX_UINT256) != MAX_UINT256"); + } + + function testFindMedian__Fuzz(uint256 length, uint256[8] memory rawBalances) public view { + length = bound(length, MIN_TOKENS, MAX_TOKENS); + uint256[] memory balances = new uint256[](length); + for (uint256 i = 0; i < length; i++) { + balances[i] = bound(rawBalances[i], 0, MAX_UINT128); + } + + uint256[] memory sortedBalances = balances.sort(); + sortedBalances.ensureSortedAmounts(); + + uint256 expectedMedian; + uint256 mid = length / 2; + if (length % 2 == 0) { + expectedMedian = (sortedBalances[mid - 1] + sortedBalances[mid]) / 2; + } else { + expectedMedian = sortedBalances[mid]; + } + + uint256 median = stableSurgeMedianMathMock.findMedian(balances); + assertEq(median, expectedMedian, "Median is not correct"); + } + + function testCalculateImbalance__Fuzz(uint256 length, uint256[8] memory rawBalances) public view { + length = bound(length, MIN_TOKENS, MAX_TOKENS); + uint256[] memory balances = new uint256[](length); + for (uint256 i = 0; i < length; i++) { + balances[i] = bound(rawBalances[i], 1, MAX_UINT128); + } + + uint256 median = stableSurgeMedianMathMock.findMedian(balances); + uint256 totalBalance = 0; + uint256 totalDiffs = 0; + + for (uint256 i = 0; i < balances.length; i++) { + totalBalance += balances[i]; + + totalDiffs += stableSurgeMedianMathMock.absSub(balances[i], median); + } + + uint256 expectedImbalance = totalDiffs.divDown(totalBalance); + + uint256 imbalance = stableSurgeMedianMathMock.calculateImbalance(balances); + assertEq(imbalance, expectedImbalance, "Imbalance is not correct"); + } +} diff --git a/pkg/pool-hooks/test/foundry/StableSurgePoolFactory.t.sol b/pkg/pool-hooks/test/foundry/StableSurgePoolFactory.t.sol new file mode 100644 index 000000000..011eff63a --- /dev/null +++ b/pkg/pool-hooks/test/foundry/StableSurgePoolFactory.t.sol @@ -0,0 +1,198 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +pragma solidity ^0.8.24; + +import "forge-std/Test.sol"; + +import { Strings } from "@openzeppelin/contracts/utils/Strings.sol"; + +import { IVersion } from "@balancer-labs/v3-interfaces/contracts/solidity-utils/helpers/IVersion.sol"; +import { IVaultErrors } from "@balancer-labs/v3-interfaces/contracts/vault/IVaultErrors.sol"; +import { IVault } from "@balancer-labs/v3-interfaces/contracts/vault/IVault.sol"; +import "@balancer-labs/v3-interfaces/contracts/vault/VaultTypes.sol"; + +import { CastingHelpers } from "@balancer-labs/v3-solidity-utils/contracts/helpers/CastingHelpers.sol"; +import { ArrayHelpers } from "@balancer-labs/v3-solidity-utils/contracts/test/ArrayHelpers.sol"; +import { BalancerPoolToken } from "@balancer-labs/v3-vault/contracts/BalancerPoolToken.sol"; +import { BaseVaultTest } from "@balancer-labs/v3-vault/test/foundry/utils/BaseVaultTest.sol"; +import { StableMath } from "@balancer-labs/v3-solidity-utils/contracts/math/StableMath.sol"; + +import { StableSurgePoolFactoryDeployer } from "./utils/StableSurgePoolFactoryDeployer.sol"; +import { StableSurgePoolFactory } from "../../contracts/StableSurgePoolFactory.sol"; + +contract StableSurgePoolFactoryTest is BaseVaultTest, StableSurgePoolFactoryDeployer { + using CastingHelpers for address[]; + using ArrayHelpers for *; + + string private constant FACTORY_VERSION = "Factory v1"; + string private constant POOL_VERSION = "Pool v1"; + + uint256 internal daiIdx; + uint256 internal usdcIdx; + + // Maximum swap fee of 10% + uint64 public constant MAX_SWAP_FEE_PERCENTAGE = 10e16; + + StableSurgePoolFactory internal stablePoolFactory; + + uint256 internal constant DEFAULT_AMP_FACTOR = 200; + + function setUp() public override { + super.setUp(); + + stablePoolFactory = deployStableSurgePoolFactory( + IVault(address(vault)), + 365 days, + FACTORY_VERSION, + POOL_VERSION + ); + vm.label(address(stablePoolFactory), "stable pool factory"); + + (daiIdx, usdcIdx) = getSortedIndexes(address(dai), address(usdc)); + } + + function testFactoryHasHook() public { + address surgeHook = stablePoolFactory.getStableSurgeHook(); + assertNotEq(surgeHook, address(0), "No surge hook deployed"); + + address stablePool = _deployAndInitializeStablePool(false); + HooksConfig memory config = vault.getHooksConfig(stablePool); + + assertEq(config.hooksContract, surgeHook, "Hook contract mismatch"); + } + + function testVersions() public { + address stablePool = _deployAndInitializeStablePool(false); + + assertEq(IVersion(stablePoolFactory).version(), FACTORY_VERSION, "Wrong factory version"); + assertEq(stablePoolFactory.getPoolVersion(), POOL_VERSION, "Wrong pool version in factory"); + assertEq(IVersion(stablePool).version(), POOL_VERSION, "Wrong pool version in pool"); + } + + function testFactoryPausedState() public view { + uint32 pauseWindowDuration = stablePoolFactory.getPauseWindowDuration(); + assertEq(pauseWindowDuration, 365 days); + } + + function testCreatePoolWithoutDonation() public { + address stablePool = _deployAndInitializeStablePool(false); + + // Try to donate but fails because pool does not support donations + vm.prank(bob); + vm.expectRevert(IVaultErrors.DoesNotSupportDonation.selector); + router.donate(stablePool, [poolInitAmount, poolInitAmount].toMemoryArray(), false, bytes("")); + } + + function testCreatePoolWithDonation() public { + uint256 amountToDonate = poolInitAmount; + + address stablePool = _deployAndInitializeStablePool(true); + + HookTestLocals memory vars = _createHookTestLocals(stablePool); + + // Donates to pool successfully + vm.prank(bob); + router.donate(stablePool, [amountToDonate, amountToDonate].toMemoryArray(), false, bytes("")); + + _fillAfterHookTestLocals(vars, stablePool); + + // Bob balances + assertEq(vars.bob.daiBefore - vars.bob.daiAfter, amountToDonate, "Bob DAI balance is wrong"); + assertEq(vars.bob.usdcBefore - vars.bob.usdcAfter, amountToDonate, "Bob USDC balance is wrong"); + assertEq(vars.bob.bptAfter, vars.bob.bptBefore, "Bob BPT balance is wrong"); + + // Pool balances + assertEq(vars.poolAfter[daiIdx] - vars.poolBefore[daiIdx], amountToDonate, "Pool DAI balance is wrong"); + assertEq(vars.poolAfter[usdcIdx] - vars.poolBefore[usdcIdx], amountToDonate, "Pool USDC balance is wrong"); + assertEq(vars.bptSupplyAfter, vars.bptSupplyBefore, "Pool BPT supply is wrong"); + + // Vault Balances + assertEq(vars.vault.daiAfter - vars.vault.daiBefore, amountToDonate, "Vault DAI balance is wrong"); + assertEq(vars.vault.usdcAfter - vars.vault.usdcBefore, amountToDonate, "Vault USDC balance is wrong"); + } + + function testCreatePoolWithTooManyTokens() public { + IERC20[] memory bigPoolTokens = new IERC20[](StableMath.MAX_STABLE_TOKENS + 1); + for (uint256 i = 0; i < bigPoolTokens.length; ++i) { + bigPoolTokens[i] = createERC20(string.concat("TKN", Strings.toString(i)), 18); + } + + TokenConfig[] memory tokenConfig = vault.buildTokenConfig(bigPoolTokens); + PoolRoleAccounts memory roleAccounts; + + vm.expectRevert(IVaultErrors.MaxTokens.selector); + stablePoolFactory.create( + "Big Pool", + "TOO_BIG", + tokenConfig, + DEFAULT_AMP_FACTOR, + roleAccounts, + MAX_SWAP_FEE_PERCENTAGE, + false, + false, + ZERO_BYTES32 + ); + } + + function _deployAndInitializeStablePool(bool supportsDonation) private returns (address) { + PoolRoleAccounts memory roleAccounts; + IERC20[] memory tokens = [address(dai), address(usdc)].toMemoryArray().asIERC20(); + + address stablePool = stablePoolFactory.create( + supportsDonation ? "Pool With Donation" : "Pool Without Donation", + supportsDonation ? "PwD" : "PwoD", + vault.buildTokenConfig(tokens), + DEFAULT_AMP_FACTOR, + roleAccounts, + MAX_SWAP_FEE_PERCENTAGE, + supportsDonation, + false, // Do not disable unbalanced add/remove liquidity + ZERO_BYTES32 + ); + + // Initialize pool + vm.prank(lp); + router.initialize(stablePool, tokens, [poolInitAmount, poolInitAmount].toMemoryArray(), 0, false, bytes("")); + + return stablePool; + } + + struct WalletState { + uint256 daiBefore; + uint256 daiAfter; + uint256 usdcBefore; + uint256 usdcAfter; + uint256 bptBefore; + uint256 bptAfter; + } + + struct HookTestLocals { + WalletState bob; + WalletState hook; + WalletState vault; + uint256[] poolBefore; + uint256[] poolAfter; + uint256 bptSupplyBefore; + uint256 bptSupplyAfter; + } + + function _createHookTestLocals(address pool) private view returns (HookTestLocals memory vars) { + vars.bob.daiBefore = dai.balanceOf(bob); + vars.bob.usdcBefore = usdc.balanceOf(bob); + vars.bob.bptBefore = IERC20(pool).balanceOf(bob); + vars.vault.daiBefore = dai.balanceOf(address(vault)); + vars.vault.usdcBefore = usdc.balanceOf(address(vault)); + vars.poolBefore = vault.getRawBalances(pool); + vars.bptSupplyBefore = BalancerPoolToken(pool).totalSupply(); + } + + function _fillAfterHookTestLocals(HookTestLocals memory vars, address pool) private view { + vars.bob.daiAfter = dai.balanceOf(bob); + vars.bob.usdcAfter = usdc.balanceOf(bob); + vars.bob.bptAfter = IERC20(pool).balanceOf(bob); + vars.vault.daiAfter = dai.balanceOf(address(vault)); + vars.vault.usdcAfter = usdc.balanceOf(address(vault)); + vars.poolAfter = vault.getRawBalances(pool); + vars.bptSupplyAfter = BalancerPoolToken(pool).totalSupply(); + } +} diff --git a/pkg/pool-hooks/test/foundry/utils/StableSurgePoolFactoryDeployer.sol b/pkg/pool-hooks/test/foundry/utils/StableSurgePoolFactoryDeployer.sol new file mode 100644 index 000000000..a19a9f08c --- /dev/null +++ b/pkg/pool-hooks/test/foundry/utils/StableSurgePoolFactoryDeployer.sol @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +pragma solidity ^0.8.24; + +import { Test } from "forge-std/Test.sol"; + +import { IVault } from "@balancer-labs/v3-interfaces/contracts/vault/IVault.sol"; + +import { BaseContractsDeployer } from "@balancer-labs/v3-solidity-utils/test/foundry/utils/BaseContractsDeployer.sol"; + +import { StableSurgePoolFactory } from "../../../contracts/StableSurgePoolFactory.sol"; + +/** + * @dev This contract contains functions for deploying mocks and contracts related to the "StablePool". These functions should have support for reusing artifacts from the hardhat compilation. + */ +contract StableSurgePoolFactoryDeployer is BaseContractsDeployer { + uint256 public constant DEFAULT_SURGE_THRESHOLD_PERCENTAGE = 30e16; // 30% + uint256 public constant DEFAULT_MAX_SURGE_FEE_PERCENTAGE = 95e16; // 95% + + string private artifactsRootDir = "artifacts/"; + + constructor() { + // if this external artifact path exists, it means we are running outside of this repo + if (vm.exists("artifacts/@balancer-labs/v3-pool-hooks/")) { + artifactsRootDir = "artifacts/@balancer-labs/v3-pool-hooks/"; + } + } + + function deployStableSurgePoolFactory( + IVault vault, + uint32 pauseWindowDuration, + string memory factoryVersion, + string memory poolVersion + ) internal returns (StableSurgePoolFactory) { + if (reusingArtifacts) { + return + StableSurgePoolFactory( + deployCode( + "artifacts/contracts/StableSurgePoolFactory.sol/StableSurgePoolFactory.json", + abi.encode( + vault, + pauseWindowDuration, + DEFAULT_MAX_SURGE_FEE_PERCENTAGE, + DEFAULT_SURGE_THRESHOLD_PERCENTAGE, + factoryVersion, + poolVersion + ) + ) + ); + } else { + return + new StableSurgePoolFactory( + vault, + pauseWindowDuration, + DEFAULT_MAX_SURGE_FEE_PERCENTAGE, + DEFAULT_SURGE_THRESHOLD_PERCENTAGE, + factoryVersion, + poolVersion + ); + } + } + + function _computeStablePoolPath(string memory name) private view returns (string memory) { + return string(abi.encodePacked(artifactsRootDir, "contracts/", name, ".sol/", name, ".json")); + } +} diff --git a/pkg/pool-stable/contracts/StablePool.sol b/pkg/pool-stable/contracts/StablePool.sol index 59d40a372..947375ef2 100644 --- a/pkg/pool-stable/contracts/StablePool.sol +++ b/pkg/pool-stable/contracts/StablePool.sol @@ -166,7 +166,7 @@ contract StablePool is IStablePool, BalancerPoolToken, BasePoolAuthentication, P } /// @inheritdoc IBasePool - function onSwap(PoolSwapParams memory request) public view virtual onlyVault returns (uint256) { + function onSwap(PoolSwapParams memory request) public view virtual returns (uint256) { uint256 invariant = computeInvariant(request.balancesScaled18, Rounding.ROUND_DOWN); (uint256 currentAmp, ) = _getAmplificationParameter(); diff --git a/pkg/pool-stable/contracts/StablePoolFactory.sol b/pkg/pool-stable/contracts/StablePoolFactory.sol index 24d5f9211..eaf685f51 100644 --- a/pkg/pool-stable/contracts/StablePoolFactory.sol +++ b/pkg/pool-stable/contracts/StablePoolFactory.sol @@ -18,7 +18,7 @@ import { Version } from "@balancer-labs/v3-solidity-utils/contracts/helpers/Vers import { StablePool } from "./StablePool.sol"; /** - * @notice General Stable Pool factory + * @notice General Stable Pool factory. * @dev This is the most general factory, which allows up to `StableMath.MAX_STABLE_TOKENS` (5) tokens. * Since this limit is less than Vault's maximum of 8 tokens, we need to enforce this at the factory level. */ diff --git a/pkg/pool-stable/test/foundry/E2eBatchSwap.t.sol b/pkg/pool-stable/test/foundry/E2eBatchSwap.t.sol index b40c18e2a..a70567325 100644 --- a/pkg/pool-stable/test/foundry/E2eBatchSwap.t.sol +++ b/pkg/pool-stable/test/foundry/E2eBatchSwap.t.sol @@ -36,10 +36,9 @@ contract E2eBatchSwapStableTest is E2eBatchSwapTest, StablePoolContractsDeployer minSwapAmountTokenA = 10 * PRODUCTION_MIN_TRADE_AMOUNT; minSwapAmountTokenD = 10 * PRODUCTION_MIN_TRADE_AMOUNT; - // Divide init amount by 10 to make sure weighted math ratios are respected (Cannot trade more than 30% of pool - // balance). - maxSwapAmountTokenA = poolInitAmount / 10; - maxSwapAmountTokenD = poolInitAmount / 10; + // 25% of pool init amount, so MIN and MAX invariant ratios are not violated. + maxSwapAmountTokenA = poolInitAmount / 4; + maxSwapAmountTokenD = poolInitAmount / 4; } /// @notice Overrides BaseVaultTest _createPool(). This pool is used by E2eBatchSwapTest tests. diff --git a/pkg/pool-stable/test/foundry/utils/StablePoolContractsDeployer.sol b/pkg/pool-stable/test/foundry/utils/StablePoolContractsDeployer.sol index 3c6576799..c7d6917cb 100644 --- a/pkg/pool-stable/test/foundry/utils/StablePoolContractsDeployer.sol +++ b/pkg/pool-stable/test/foundry/utils/StablePoolContractsDeployer.sol @@ -15,6 +15,9 @@ import { StablePool } from "../../../contracts/StablePool.sol"; * @dev This contract contains functions for deploying mocks and contracts related to the "StablePool". These functions should have support for reusing artifacts from the hardhat compilation. */ contract StablePoolContractsDeployer is BaseContractsDeployer { + uint256 public constant DEFAULT_SURGE_THRESHOLD_PERCENTAGE = 30e16; // 30% + uint256 public constant DEFAULT_MAX_SURGE_FEE_PERCENTAGE = 95e16; // 95% + string private artifactsRootDir = "artifacts/"; constructor() { diff --git a/pkg/solidity-utils/contracts/helpers/InputHelpers.sol b/pkg/solidity-utils/contracts/helpers/InputHelpers.sol index c8b8700f2..1c334199b 100644 --- a/pkg/solidity-utils/contracts/helpers/InputHelpers.sol +++ b/pkg/solidity-utils/contracts/helpers/InputHelpers.sol @@ -91,6 +91,10 @@ library InputHelpers { /// @dev Ensure an array of tokens is sorted. As above, does not validate length or uniqueness. function ensureSortedTokens(IERC20[] memory tokens) internal pure { + if (tokens.length < 2) { + return; + } + IERC20 previous = tokens[0]; for (uint256 i = 1; i < tokens.length; ++i) { @@ -103,4 +107,23 @@ library InputHelpers { previous = current; } } + + /// @dev Ensure an array of amounts is sorted. As above, does not validate length or uniqueness. + function ensureSortedAmounts(uint256[] memory amounts) internal pure { + if (amounts.length < 2) { + return; + } + + uint256 previous = amounts[0]; + + for (uint256 i = 1; i < amounts.length; ++i) { + uint256 current = amounts[i]; + + if (previous > current) { + revert TokensNotSorted(); + } + + previous = current; + } + } } diff --git a/pkg/solidity-utils/contracts/openzeppelin/Arrays.sol b/pkg/solidity-utils/contracts/openzeppelin/Arrays.sol new file mode 100644 index 000000000..897dc956e --- /dev/null +++ b/pkg/solidity-utils/contracts/openzeppelin/Arrays.sol @@ -0,0 +1,237 @@ +// SPDX-License-Identifier: MIT +// OpenZeppelin Contracts (last updated v5.1.0) (utils/Arrays.sol) +// This file was procedurally generated from scripts/generate/templates/Arrays.js. + +// NOTE: This file copied only the sort and necessary helper functions from Arrays.sol (OpenZeppelin Contracts v5.1.0) +pragma solidity ^0.8.20; + +/** + * @dev Standard math utilities missing in the Solidity language. + */ +library Math { + /** + * @dev Returns the average of two numbers. The result is rounded towards + * zero. + */ + function average(uint256 a, uint256 b) internal pure returns (uint256) { + // (a + b) / 2 can overflow. + return (a & b) + (a ^ b) / 2; + } +} + +/** + * @dev Provides a set of functions to compare values. + * + * _Available since v5.1._ + */ +library Comparators { + function lt(uint256 a, uint256 b) internal pure returns (bool) { + return a < b; + } + + function gt(uint256 a, uint256 b) internal pure returns (bool) { + return a > b; + } +} + +/** + * @dev Collection of functions related to array types. + */ +library Arrays { + /** + * @dev Sort an array of uint256 (in memory) following the provided comparator function. + * + * This function does the sorting "in place", meaning that it overrides the input. The object is returned for + * convenience, but that returned value can be discarded safely if the caller has a memory pointer to the array. + * + * NOTE: this function's cost is `O(n · log(n))` in average and `O(n²)` in the worst case, with n the length of the + * array. Using it in view functions that are executed through `eth_call` is safe, but one should be very careful + * when executing this as part of a transaction. If the array being sorted is too large, the sort operation may + * consume more gas than is available in a block, leading to potential DoS. + * + * IMPORTANT: Consider memory side-effects when using custom comparator functions that access memory in an + * unsafe way. + */ + function sort( + uint256[] memory array, + function(uint256, uint256) pure returns (bool) comp + ) internal pure returns (uint256[] memory) { + _quickSort(_begin(array), _end(array), comp); + return array; + } + + /** + * @dev Variant of {sort} that sorts an array of uint256 in increasing order. + */ + function sort(uint256[] memory array) internal pure returns (uint256[] memory) { + sort(array, Comparators.lt); + return array; + } + + /** + * @dev Sort an array of address (in memory) following the provided comparator function. + * + * This function does the sorting "in place", meaning that it overrides the input. The object is returned for + * convenience, but that returned value can be discarded safely if the caller has a memory pointer to the array. + * + * NOTE: this function's cost is `O(n · log(n))` in average and `O(n²)` in the worst case, with n the length of the + * array. Using it in view functions that are executed through `eth_call` is safe, but one should be very careful + * when executing this as part of a transaction. If the array being sorted is too large, the sort operation may + * consume more gas than is available in a block, leading to potential DoS. + * + * IMPORTANT: Consider memory side-effects when using custom comparator functions that access memory in an + * unsafe way. + */ + function sort( + address[] memory array, + function(address, address) pure returns (bool) comp + ) internal pure returns (address[] memory) { + sort(_castToUint256Array(array), _castToUint256Comp(comp)); + return array; + } + + /** + * @dev Variant of {sort} that sorts an array of address in increasing order. + */ + function sort(address[] memory array) internal pure returns (address[] memory) { + sort(_castToUint256Array(array), Comparators.lt); + return array; + } + + /** + * @dev Sort an array of bytes32 (in memory) following the provided comparator function. + * + * This function does the sorting "in place", meaning that it overrides the input. The object is returned for + * convenience, but that returned value can be discarded safely if the caller has a memory pointer to the array. + * + * NOTE: this function's cost is `O(n · log(n))` in average and `O(n²)` in the worst case, with n the length of the + * array. Using it in view functions that are executed through `eth_call` is safe, but one should be very careful + * when executing this as part of a transaction. If the array being sorted is too large, the sort operation may + * consume more gas than is available in a block, leading to potential DoS. + * + * IMPORTANT: Consider memory side-effects when using custom comparator functions that access memory in an + * unsafe way. + */ + function sort( + bytes32[] memory array, + function(bytes32, bytes32) pure returns (bool) comp + ) internal pure returns (bytes32[] memory) { + sort(_castToUint256Array(array), _castToUint256Comp(comp)); + return array; + } + + /** + * @dev Variant of {sort} that sorts an array of bytes32 in increasing order. + */ + function sort(bytes32[] memory array) internal pure returns (bytes32[] memory) { + sort(_castToUint256Array(array), Comparators.lt); + return array; + } + + /** + * @dev Performs a quick sort of a segment of memory. The segment sorted starts at `begin` (inclusive), and stops + * at end (exclusive). Sorting follows the `comp` comparator. + * + * Invariant: `begin <= end`. This is the case when initially called by {sort} and is preserved in sub-calls. + * + * IMPORTANT: Memory locations between `begin` and `end` are not validated/zeroed. This function should + * be used only if the limits are within a memory array. + */ + function _quickSort(uint256 begin, uint256 end, function(uint256, uint256) pure returns (bool) comp) private pure { + unchecked { + if (end - begin < 0x40) return; + + // Use first element as pivot + uint256 pivot = _mload(begin); + // Position where the pivot should be at the end of the loop + uint256 pos = begin; + + for (uint256 it = begin + 0x20; it < end; it += 0x20) { + if (comp(_mload(it), pivot)) { + // If the value stored at the iterator's position comes before the pivot, we increment the + // position of the pivot and move the value there. + pos += 0x20; + _swap(pos, it); + } + } + + _swap(begin, pos); // Swap pivot into place + _quickSort(begin, pos, comp); // Sort the left side of the pivot + _quickSort(pos + 0x20, end, comp); // Sort the right side of the pivot + } + } + + // solhint-disable no-inline-assembly + + /** + * @dev Pointer to the memory location of the first element of `array`. + */ + function _begin(uint256[] memory array) private pure returns (uint256 ptr) { + assembly ("memory-safe") { + ptr := add(array, 0x20) + } + } + + /** + * @dev Pointer to the memory location of the first memory word (32bytes) after `array`. This is the memory word + * that comes just after the last element of the array. + */ + function _end(uint256[] memory array) private pure returns (uint256 ptr) { + unchecked { + return _begin(array) + array.length * 0x20; + } + } + + /** + * @dev Load memory word (as a uint256) at location `ptr`. + */ + function _mload(uint256 ptr) private pure returns (uint256 value) { + assembly { + value := mload(ptr) + } + } + + /** + * @dev Swaps the elements memory location `ptr1` and `ptr2`. + */ + function _swap(uint256 ptr1, uint256 ptr2) private pure { + assembly { + let value1 := mload(ptr1) + let value2 := mload(ptr2) + mstore(ptr1, value2) + mstore(ptr2, value1) + } + } + + /// @dev Helper: low level cast address memory array to uint256 memory array + function _castToUint256Array(address[] memory input) private pure returns (uint256[] memory output) { + assembly { + output := input + } + } + + /// @dev Helper: low level cast bytes32 memory array to uint256 memory array + function _castToUint256Array(bytes32[] memory input) private pure returns (uint256[] memory output) { + assembly { + output := input + } + } + + /// @dev Helper: low level cast address comp function to uint256 comp function + function _castToUint256Comp( + function(address, address) pure returns (bool) input + ) private pure returns (function(uint256, uint256) pure returns (bool) output) { + assembly { + output := input + } + } + + /// @dev Helper: low level cast bytes32 comp function to uint256 comp function + function _castToUint256Comp( + function(bytes32, bytes32) pure returns (bool) input + ) private pure returns (function(uint256, uint256) pure returns (bool) output) { + assembly { + output := input + } + } +} diff --git a/pvt/common/hardhat-base-config.ts b/pvt/common/hardhat-base-config.ts index 1181d76a3..9ad9ece0c 100644 --- a/pvt/common/hardhat-base-config.ts +++ b/pvt/common/hardhat-base-config.ts @@ -62,6 +62,21 @@ const contractSettings: ContractSettings = { runs: 500, viaIR, }, + '@balancer-labs/v3-pool-gyro/contracts/GyroECLPPool.sol': { + version: '0.8.27', + runs: 9999, + viaIR, + }, + '@balancer-labs/v3-pool-gyro/contracts/lib/GyroECLPMath.sol': { + version: '0.8.27', + runs: 9999, + viaIR, + }, + '@balancer-labs/v3-pool-gyro/contracts/GyroECLPPoolFactory.sol': { + version: '0.8.27', + runs: 9999, + viaIR, + }, '@balancer-labs/v3-vault/contracts/VaultExtension.sol': { version: '0.8.26', runs: 500, diff --git a/pvt/helpers/src/math/surgeMedianMath.ts b/pvt/helpers/src/math/surgeMedianMath.ts new file mode 100644 index 000000000..2aa23e656 --- /dev/null +++ b/pvt/helpers/src/math/surgeMedianMath.ts @@ -0,0 +1,5 @@ +export function findMedian(arr: number[]): number { + const mid = Math.floor(arr.length / 2), + nums = [...arr].sort((a, b) => a - b); + return arr.length % 2 !== 0 ? nums[mid] : Math.floor((nums[mid - 1] + nums[mid]) / 2); +}