diff --git a/src/base/Roles/CrossChain/Bridges/LayerZero/LayerZeroTeller.sol b/src/base/Roles/CrossChain/Bridges/LayerZero/LayerZeroTeller.sol index bdd12a2a..fa302af7 100644 --- a/src/base/Roles/CrossChain/Bridges/LayerZero/LayerZeroTeller.sol +++ b/src/base/Roles/CrossChain/Bridges/LayerZero/LayerZeroTeller.sol @@ -22,10 +22,12 @@ contract LayerZeroTeller is CrossChainTellerWithGenericBridge, OAppAuth { * @dev Sender is stored in OAppAuthCore `peers` mapping. * @param allowMessagesFrom Whether to allow messages from this chain. * @param allowMessagesTo Whether to allow messages to this chain. + * @param messageGasLimit The gas limit for messages to this chain. */ struct Chain { bool allowMessagesFrom; bool allowMessagesTo; + uint128 messageGasLimit; } // ========================================= STATE ========================================= @@ -41,18 +43,23 @@ contract LayerZeroTeller is CrossChainTellerWithGenericBridge, OAppAuth { error LayerZeroTeller__MessagesNotAllowedTo(uint256 chainSelector); error LayerZeroTeller__FeeExceedsMax(uint256 chainSelector, uint256 fee, uint256 maxFee); error LayerZeroTeller__BadFeeToken(); + error LayerZeroTeller__ZeroMessageGasLimit(); //============================== EVENTS =============================== - event ChainAdded(uint256 chainSelector, bool allowMessagesFrom, bool allowMessagesTo, address targetTeller); - event ChainRemoved(uint256 chainSelector); - event ChainAllowMessagesFrom(uint256 chainSelector, address targetTeller); - event ChainAllowMessagesTo(uint256 chainSelector, address targetTeller); - event ChainStopMessagesFrom(uint256 chainSelector); - event ChainStopMessagesTo(uint256 chainSelector); + event ChainAdded(uint256 chainId, bool allowMessagesFrom, bool allowMessagesTo, address targetTeller); + event ChainRemoved(uint256 chainId); + event ChainAllowMessagesFrom(uint256 chainId, address targetTeller); + event ChainAllowMessagesTo(uint256 chainId, address targetTeller); + event ChainStopMessagesFrom(uint256 chainId); + event ChainStopMessagesTo(uint256 chainId); + event ChainSetGasLimit(uint256 chainId, uint128 messageGasLimit); //============================== IMMUTABLES =============================== + /** + * @notice The LayerZero token. + */ address internal immutable lzToken; constructor( @@ -75,12 +82,19 @@ contract LayerZeroTeller is CrossChainTellerWithGenericBridge, OAppAuth { * @param allowMessagesFrom Whether to allow messages from this chain. * @param allowMessagesTo Whether to allow messages to this chain. * @param targetTeller The address of the target teller on the other chain. + * @param messageGasLimit The gas limit for messages to this chain. */ - function addChain(uint32 chainId, bool allowMessagesFrom, bool allowMessagesTo, address targetTeller) - external - requiresAuth - { - idToChains[chainId] = Chain(allowMessagesFrom, allowMessagesTo); + function addChain( + uint32 chainId, + bool allowMessagesFrom, + bool allowMessagesTo, + address targetTeller, + uint128 messageGasLimit + ) external requiresAuth { + if (allowMessagesTo && messageGasLimit == 0) { + revert LayerZeroTeller__ZeroMessageGasLimit(); + } + idToChains[chainId] = Chain(allowMessagesFrom, allowMessagesTo, messageGasLimit); _setPeer(chainId, targetTeller.toBytes32()); emit ChainAdded(chainId, allowMessagesFrom, allowMessagesTo, targetTeller); @@ -113,9 +127,16 @@ contract LayerZeroTeller is CrossChainTellerWithGenericBridge, OAppAuth { * @notice Allow messages to a chain. * @dev Callable by OWNER_ROLE. */ - function allowMessagesToChain(uint32 chainId, address targetTeller) external requiresAuth { + function allowMessagesToChain(uint32 chainId, address targetTeller, uint128 messageGasLimit) + external + requiresAuth + { + if (messageGasLimit == 0) { + revert LayerZeroTeller__ZeroMessageGasLimit(); + } Chain storage chain = idToChains[chainId]; chain.allowMessagesTo = true; + chain.messageGasLimit = messageGasLimit; _setPeer(chainId, targetTeller.toBytes32()); emit ChainAllowMessagesTo(chainId, targetTeller); @@ -143,6 +164,20 @@ contract LayerZeroTeller is CrossChainTellerWithGenericBridge, OAppAuth { emit ChainStopMessagesTo(chainId); } + /** + * @notice Set the gas limit for messages to a chain. + * @dev Callable by OWNER_ROLE. + */ + function setChainGasLimit(uint32 chainId, uint128 messageGasLimit) external requiresAuth { + if (messageGasLimit == 0) { + revert LayerZeroTeller__ZeroMessageGasLimit(); + } + Chain storage chain = idToChains[chainId]; + chain.messageGasLimit = messageGasLimit; + + emit ChainSetGasLimit(chainId, messageGasLimit); + } + // ========================================= OAppAuthReceiver ========================================= /** @@ -166,9 +201,9 @@ contract LayerZeroTeller is CrossChainTellerWithGenericBridge, OAppAuth { // ========================================= INTERNAL BRIDGE FUNCTIONS ========================================= /** - * @notice Sends messages using CCIP router. + * @notice Sends messages using Layer Zero end point. * @dev This function does NOT revert if the `feeToken` is invalid, - * rather the CCIP bridge will revert. + * rather the Layer Zero end point will revert. * @dev This function will revert if maxFee is exceeded. * @dev This function will revert if destination chain does not allow messages. * @param message The message to send. @@ -187,8 +222,7 @@ contract LayerZeroTeller is CrossChainTellerWithGenericBridge, OAppAuth { revert LayerZeroTeller__MessagesNotAllowedTo(destinationId); } bytes memory m = abi.encode(message); - // TODO need to add in the gaslimit as a saved value in the Chain struct. - bytes memory options = OptionsBuilder.newOptions().addExecutorLzReceiveOption(1_000_000, 0); + bytes memory options = OptionsBuilder.newOptions().addExecutorLzReceiveOption(chain.messageGasLimit, 0); MessagingFee memory fee = _quote(destinationId, m, options, address(feeToken) != NATIVE); if (address(feeToken) == NATIVE) { if (fee.nativeFee > maxFee) { @@ -228,7 +262,7 @@ contract LayerZeroTeller is CrossChainTellerWithGenericBridge, OAppAuth { revert LayerZeroTeller__MessagesNotAllowedTo(destinationId); } bytes memory m = abi.encode(message); - bytes memory options = OptionsBuilder.newOptions().addExecutorLzReceiveOption(1_000_000, 0); + bytes memory options = OptionsBuilder.newOptions().addExecutorLzReceiveOption(chain.messageGasLimit, 0); MessagingFee memory messageFee = _quote(destinationId, m, options, address(feeToken) != NATIVE); fee = address(feeToken) == NATIVE ? messageFee.nativeFee : messageFee.lzTokenFee; diff --git a/src/base/Roles/CrossChain/CrossChainTellerWithGenericBridge.sol b/src/base/Roles/CrossChain/CrossChainTellerWithGenericBridge.sol index ac576db4..af7f9a26 100644 --- a/src/base/Roles/CrossChain/CrossChainTellerWithGenericBridge.sol +++ b/src/base/Roles/CrossChain/CrossChainTellerWithGenericBridge.sol @@ -40,7 +40,7 @@ abstract contract CrossChainTellerWithGenericBridge is TellerWithMultiAssetSuppo ERC20 feeToken, uint256 maxFee ) external payable requiresAuth nonReentrant returns (uint256 sharesBridged) { - if (address(depositAsset) == NATIVE) { + if (address(depositAsset) == NATIVE && address(feeToken) == NATIVE) { revert CrossChainTellerWithGenericBridge__CannotDepositWithNativeAndPayBridgeFeeInNative(); } sharesBridged = deposit(depositAsset, depositAmount, minimumMint); diff --git a/test/LayerZeroTeller.t.sol b/test/LayerZeroTeller.t.sol index 5e1d538a..f37ec25d 100644 --- a/test/LayerZeroTeller.t.sol +++ b/test/LayerZeroTeller.t.sol @@ -131,8 +131,8 @@ contract LayerZeroTellerTest is Test, MerkleTreeHelper { deal(address(boringVault), address(this), 1_000e18, true); // Setup chains on bridge. - sourceTeller.addChain(DESTINATION_ID, true, true, address(destinationTeller)); - destinationTeller.addChain(SOURCE_ID, true, true, address(sourceTeller)); + sourceTeller.addChain(DESTINATION_ID, true, true, address(destinationTeller), 1_000_000); + destinationTeller.addChain(SOURCE_ID, true, true, address(sourceTeller), 1_000_000); } function testBridgingShares(uint96 sharesToBridge) external { @@ -161,47 +161,73 @@ contract LayerZeroTellerTest is Test, MerkleTreeHelper { assertEq(previewedFee, fee, "Previewed fee should match set fee."); } - function testAdminFunctions() external { + function testAdminFunctions(uint128 msgGas) external { uint32 newSelector = 3; address targetTeller = vm.addr(1); + msgGas = uint128(bound(msgGas, 1, 1_000_000)); - sourceTeller.addChain(newSelector, true, true, targetTeller); + sourceTeller.addChain(newSelector, true, true, targetTeller, msgGas); - (bool allowMessagesFrom, bool allowMessagesTo) = sourceTeller.idToChains(newSelector); + (bool allowMessagesFrom, bool allowMessagesTo, uint128 messageGasLimit) = sourceTeller.idToChains(newSelector); assertEq(allowMessagesFrom, true, "Should allow messages from new chain."); assertEq(allowMessagesTo, true, "Should allow messages to new chain."); + assertEq(messageGasLimit, msgGas, "Should have set message gas limit."); sourceTeller.stopMessagesFromChain(newSelector); - (allowMessagesFrom, allowMessagesTo) = sourceTeller.idToChains(newSelector); + (allowMessagesFrom, allowMessagesTo, messageGasLimit) = sourceTeller.idToChains(newSelector); assertEq(allowMessagesFrom, false, "Should not allow messages from destination chain."); assertEq(allowMessagesTo, true, "Should still allow messages to destination chain."); + assertEq(messageGasLimit, msgGas, "Should have not changed message gas limit."); sourceTeller.stopMessagesToChain(newSelector); - (allowMessagesFrom, allowMessagesTo) = sourceTeller.idToChains(newSelector); + (allowMessagesFrom, allowMessagesTo, messageGasLimit) = sourceTeller.idToChains(newSelector); assertEq(allowMessagesFrom, false, "Should not allow messages from destination chain."); assertEq(allowMessagesTo, false, "Should not allow messages to destination chain."); + assertEq(messageGasLimit, msgGas, "Should have not changed message gas limit."); address newTargetTeller = vm.addr(2); - sourceTeller.allowMessagesToChain(newSelector, newTargetTeller); - (allowMessagesFrom, allowMessagesTo) = sourceTeller.idToChains(newSelector); + msgGas += 2; + sourceTeller.allowMessagesToChain(newSelector, newTargetTeller, msgGas); + (allowMessagesFrom, allowMessagesTo, messageGasLimit) = sourceTeller.idToChains(newSelector); assertEq(allowMessagesFrom, false, "Should allow messages from new chain."); assertEq(allowMessagesTo, true, "Should not allow messages to new chain."); + assertEq(messageGasLimit, msgGas, "Should have changed message gas limit."); address anotherNewTargetTeller = vm.addr(3); sourceTeller.allowMessagesFromChain(newSelector, anotherNewTargetTeller); - (allowMessagesFrom, allowMessagesTo) = sourceTeller.idToChains(newSelector); + (allowMessagesFrom, allowMessagesTo, messageGasLimit) = sourceTeller.idToChains(newSelector); assertEq(allowMessagesFrom, true, "Should allow messages from new chain."); assertEq(allowMessagesTo, true, "Should allow messages to new chain."); + assertEq(messageGasLimit, msgGas, "Should have not changed message gas limit."); sourceTeller.removeChain(newSelector); - (allowMessagesFrom, allowMessagesTo) = sourceTeller.idToChains(newSelector); + (allowMessagesFrom, allowMessagesTo, messageGasLimit) = sourceTeller.idToChains(newSelector); assertEq(allowMessagesFrom, false, "Should not allow messages from new chain."); assertEq(allowMessagesTo, false, "Should not allow messages to new chain."); + assertEq(messageGasLimit, 0, "Should have zeroed message gas limit."); + + sourceTeller.setChainGasLimit(newSelector, msgGas + 1); + (allowMessagesFrom, allowMessagesTo, messageGasLimit) = sourceTeller.idToChains(newSelector); + assertEq(allowMessagesFrom, false, "Should not allow messages from new chain."); + assertEq(allowMessagesTo, false, "Should not allow messages to new chain."); + assertEq(messageGasLimit, msgGas + 1, "Should have changed message gas limit."); } function testReverts() external { + // Adding a chain with a zero message gas limit should revert. + vm.expectRevert(bytes(abi.encodeWithSelector(LayerZeroTeller.LayerZeroTeller__ZeroMessageGasLimit.selector))); + sourceTeller.addChain(DESTINATION_ID, true, true, address(destinationTeller), 0); + + // Allowing messages to a chain with a zero message gas limit should revert. + vm.expectRevert(bytes(abi.encodeWithSelector(LayerZeroTeller.LayerZeroTeller__ZeroMessageGasLimit.selector))); + sourceTeller.allowMessagesToChain(DESTINATION_ID, address(destinationTeller), 0); + + // Changing the gas limit to zero should revert. + vm.expectRevert(bytes(abi.encodeWithSelector(LayerZeroTeller.LayerZeroTeller__ZeroMessageGasLimit.selector))); + sourceTeller.setChainGasLimit(DESTINATION_ID, 0); + // If teller is paused bridging is not allowed. sourceTeller.pause(); vm.expectRevert( @@ -222,8 +248,8 @@ contract LayerZeroTellerTest is Test, MerkleTreeHelper { sourceTeller.bridge(1e18, address(this), abi.encode(DESTINATION_ID), NATIVE_ERC20, expectedFee); // setup chains. - sourceTeller.addChain(DESTINATION_ID, true, true, address(destinationTeller)); - destinationTeller.addChain(SOURCE_ID, true, true, address(sourceTeller)); + sourceTeller.addChain(DESTINATION_ID, true, true, address(destinationTeller), 1_000_000); + destinationTeller.addChain(SOURCE_ID, true, true, address(sourceTeller), 1_000_000); // If the max fee is exceeded the transaction should revert. uint256 newFee = 1.01e18; diff --git a/test/LayerZeroTellerNoMock.t.sol b/test/LayerZeroTellerNoMock.t.sol index 0e55bf04..736d510b 100644 --- a/test/LayerZeroTellerNoMock.t.sol +++ b/test/LayerZeroTellerNoMock.t.sol @@ -105,7 +105,7 @@ contract LayerZeroTellerNoMockTest is Test, MerkleTreeHelper { deal(address(boringVault), address(this), 1_000e18, true); // Setup chains on bridge. - sourceTeller.addChain(layerZeroArbitrumEndpointId, true, true, address(sourceTeller)); + sourceTeller.addChain(layerZeroArbitrumEndpointId, true, true, address(sourceTeller), 1_000_000); } function testBridgingShares(uint96 sharesToBridge) external { @@ -114,7 +114,6 @@ contract LayerZeroTellerNoMockTest is Test, MerkleTreeHelper { // Get fee. address to = vm.addr(1); uint256 fee = sourceTeller.previewFee(sharesToBridge, to, abi.encode(layerZeroArbitrumEndpointId), NATIVE_ERC20); - console.log("Fee: ", fee); uint256 expectedFee = 1e18; sourceTeller.bridge{value: fee}( sharesToBridge, to, abi.encode(layerZeroArbitrumEndpointId), NATIVE_ERC20, expectedFee