diff --git a/contracts/hooks/TokenGatedHook.sol b/contracts/hooks/TokenGatedHook.sol new file mode 100644 index 00000000..4035c2d0 --- /dev/null +++ b/contracts/hooks/TokenGatedHook.sol @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: BUSL-1.1 +pragma solidity ^0.8.19; + +import { HookResult } from "contracts/interfaces/hooks/base/IHook.sol"; +import { SyncBaseHook } from "contracts/hooks/base/SyncBaseHook.sol"; +import { Errors } from "contracts/lib/Errors.sol"; +import { TokenGated } from "contracts/lib/hooks/TokenGated.sol"; +import { ERC165Checker } from "@openzeppelin/contracts/utils/introspection/ERC165Checker.sol"; +import { IERC721 } from "@openzeppelin/contracts/token/ERC721/IERC721.sol"; + +/// @title TokenGatedHook +/// @notice This contract is a hook that ensures the user is the owner of a specific NFT token. +/// @dev It extends SyncBaseHook and provides the implementation for validating the hook configuration and executing the hook. +contract TokenGatedHook is SyncBaseHook { + using ERC165Checker for address; + + /// @notice Constructs the TokenGatedHook contract. + /// @param accessControl_ The address of the access control contract. + constructor(address accessControl_) SyncBaseHook(accessControl_) {} + + /// @notice Validates the configuration for the hook. + /// @dev This function checks if the tokenAddress is a valid ERC721 contract. + /// @param hookConfig_ The configuration data for the hook. + function _validateConfig(bytes memory hookConfig_) internal view override { + TokenGated.Config memory config = abi.decode(hookConfig_, (TokenGated.Config)); + address tokenAddress = config.tokenAddress; + if (tokenAddress == address(0)) { + revert Errors.ZeroAddress(); + } + // Check if the configured token address is a valid ERC 721 contract + if ( + !tokenAddress.supportsInterface( + type(IERC721).interfaceId + ) + ) { + revert Errors.UnsupportedInterface("IERC721"); + } + } + + /// @notice Executes token gated check in a synchronous manner. + /// @dev This function checks if the "tokenOwner" owns a token of the specified ERC721 token contract. + /// @param hookConfig_ The configuration of the hook. + /// @param hookParams_ The parameters for the hook. + /// @return hookData always return empty string as no return data from this hook. + function _executeSyncCall( + bytes memory hookConfig_, + bytes memory hookParams_ + ) internal virtual override returns (bytes memory) { + TokenGated.Config memory config = abi.decode(hookConfig_, (TokenGated.Config)); + TokenGated.Params memory params = abi.decode(hookParams_, (TokenGated.Params)); + + if (params.tokenOwner == address(0)) { + revert Errors.ZeroAddress(); + } + // check if tokenOwner own any required token + if (IERC721(config.tokenAddress).balanceOf(params.tokenOwner) == 0) { + revert Errors.TokenGatedHook_NotTokenOwner(config.tokenAddress, params.tokenOwner); + } + + return ""; + } +} diff --git a/contracts/lib/Errors.sol b/contracts/lib/Errors.sol index 5b969a6a..0733e4bc 100644 --- a/contracts/lib/Errors.sol +++ b/contracts/lib/Errors.sol @@ -337,4 +337,7 @@ library Errors { /// @notice Invalid async request ID. error Hook_InvalidAsyncRequestId(bytes32 invalidRequestId); + + /// @notice The address is not the owner of the token. + error TokenGatedHook_NotTokenOwner(address tokenAddress, address ownerAddress); } diff --git a/contracts/lib/hooks/TokenGated.sol b/contracts/lib/hooks/TokenGated.sol new file mode 100644 index 00000000..67b02d41 --- /dev/null +++ b/contracts/lib/hooks/TokenGated.sol @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: BUSL-1.1 +pragma solidity ^0.8.19; + +/// @title TokenGated +/// @notice This library defines the Config and Params structs used in the TokenGatedHook. +/// @dev The Config struct contains the tokenAddress field, and the Params struct contains the tokenOwner field. +library TokenGated { + /// @notice Defines the required configuration information for the TokenGatedHook. + /// @dev The Config struct contains a single field: tokenAddress. + struct Config { + /// @notice The address of the ERC721 token contract. + /// @dev This address is used to check if the tokenOwner owns a token of the specified ERC721 token contract. + address tokenAddress; + } + + /// @notice Defines the required parameter information for executing the TokenGatedHook. + /// @dev The Params struct contains a single field: tokenOwner. + struct Params { + /// @notice The address of the token owner. + /// @dev This address is checked against the tokenAddress in the Config struct to ensure the owner has a token. + address tokenOwner; + } +} diff --git a/test/foundry/hooks/TestTokenGatedHook.t.sol b/test/foundry/hooks/TestTokenGatedHook.t.sol new file mode 100644 index 00000000..41b30826 --- /dev/null +++ b/test/foundry/hooks/TestTokenGatedHook.t.sol @@ -0,0 +1,229 @@ +// SPDX-License-Identifier: BUSL-1.1 +pragma solidity ^0.8.19; + +import "forge-std/Test.sol"; + +import { BaseTest } from "test/foundry/utils/BaseTest.sol"; +import { TokenGatedHook } from "contracts/hooks/TokenGatedHook.sol"; +import { HookResult } from "contracts/interfaces/hooks/base/IHook.sol"; +import { MockSyncHook } from "test/foundry/mocks/MockSyncHook.sol"; +import { Errors } from "contracts/lib/Errors.sol"; +import { AccessControl } from "contracts/lib/AccessControl.sol"; +import { Hook } from "contracts/lib/hooks/Hook.sol"; +import { MockERC721 } from "test/foundry/mocks/MockERC721.sol"; +import { MockERC721Receiver } from "test/foundry/mocks/MockERC721Receiver.sol"; +import { TokenGated } from "contracts/lib/hooks/TokenGated.sol"; + +contract TestTokenGatedHook is BaseTest { + TokenGatedHook hook; + MockERC721 tokenContract; + MockERC721Receiver tokenOwner; + + event SyncHookExecuted( + address indexed hookAddress, + HookResult indexed result, + bytes contextData, + bytes returnData + ); + + function setUp() public override { + super.setUp(); + + vm.prank(admin); + accessControl.grantRole(AccessControl.HOOK_CALLER_ROLE, address(this)); + + hook = new TokenGatedHook(address(accessControl)); + tokenContract = new MockERC721(); + tokenOwner = new MockERC721Receiver(MockERC721Receiver.onERC721Received.selector, false); + // Simulate user has ownership of the NFT + tokenContract.mint(address(tokenOwner), 1); + } + + function test_tokenGatedHook_hasOwnership() public { + // create configuration of hook + TokenGated.Config memory hookConfig = TokenGated.Config({ + tokenAddress: address(tokenContract) + }); + bytes memory encodedConfig = abi.encode(hookConfig); + // Hook validating the configuration + hook.validateConfig(encodedConfig); + + // create parameters of executing the hook + TokenGated.Params memory hookParams = TokenGated.Params({ + tokenOwner: address(tokenOwner) + }); + bytes memory encodedParams = abi.encode(hookParams); + + // Create Hook execution context which has hook's config and current parameters + bytes memory context = _getExecutionContext(encodedConfig, encodedParams); + + bytes memory expectedHookData = ""; + + HookResult result; + bytes memory hookData; + + // Execute the sync hook + (result, hookData) = hook.executeSync(context); + + // Check the result + assertEq(uint(result), uint(HookResult.Completed)); + + // Check the hook data + assertEq0(hookData, expectedHookData); + } + + function test_tokenGatedHook_hasOwnershipVerifyEvent() public { + // create configuration of hook + TokenGated.Config memory hookConfig = TokenGated.Config({ + tokenAddress: address(tokenContract) + }); + bytes memory encodedConfig = abi.encode(hookConfig); + // Hook validating the configuration + hook.validateConfig(encodedConfig); + + // create parameters of executing the hook + TokenGated.Params memory hookParams = TokenGated.Params({ + tokenOwner: address(tokenOwner) + }); + bytes memory encodedParams = abi.encode(hookParams); + + // Create Hook execution context which has hook's config and current parameters + bytes memory context = _getExecutionContext(encodedConfig, encodedParams); + + bytes memory expectedHookData = ""; + + vm.expectEmit(address(hook)); + emit SyncHookExecuted( + address(hook), + HookResult.Completed, + context, + expectedHookData + ); + // Execute the sync hook + hook.executeSync(context); + } + + function test_tokenGatedHook_revert_hasNoOwnership() public { + MockERC721Receiver nonTokenOwner = new MockERC721Receiver(MockERC721Receiver.onERC721Received.selector, false); + // create configuration of hook + TokenGated.Config memory hookConfig = TokenGated.Config({ + tokenAddress: address(tokenContract) + }); + bytes memory encodedConfig = abi.encode(hookConfig); + // Hook validating the configuration + hook.validateConfig(encodedConfig); + + // create parameters of executing the hook + TokenGated.Params memory hookParams = TokenGated.Params({ + tokenOwner: address(nonTokenOwner) + }); + bytes memory encodedParams = abi.encode(hookParams); + + // Create Hook execution context which has hook's config and current parameters + bytes memory context = _getExecutionContext(encodedConfig, encodedParams); + + // Try to execute the hook without token ownership + vm.expectRevert( + abi.encodeWithSelector( + Errors.TokenGatedHook_NotTokenOwner.selector, + address(tokenContract), + address(nonTokenOwner) + ) + ); + hook.executeSync(context); + } + + function test_tokenGatedHook_revert_ZeroTokenAddress() public { + // create configuration of hook + TokenGated.Config memory hookConfig = TokenGated.Config({ + // Invalid token address + tokenAddress: address(0) + }); + bytes memory encodedConfig = abi.encode(hookConfig); + + // create parameters of executing the hook + TokenGated.Params memory hookParams = TokenGated.Params({ + tokenOwner: address(tokenOwner) + }); + bytes memory encodedParams = abi.encode(hookParams); + + // Create Hook execution context which has hook's config and current parameters + bytes memory context = _getExecutionContext(encodedConfig, encodedParams); + + // Try to execute the hook with invalid token contract address + vm.expectRevert(Errors.ZeroAddress.selector); + hook.executeSync(context); + } + + function test_tokenGatedHook_revert_NonERC721Address() public { + // create configuration of hook + TokenGated.Config memory hookConfig = TokenGated.Config({ + // Invalid token address + tokenAddress: address(0x77777) + }); + bytes memory encodedConfig = abi.encode(hookConfig); + + // create parameters of executing the hook + TokenGated.Params memory hookParams = TokenGated.Params({ + tokenOwner: address(tokenOwner) + }); + bytes memory encodedParams = abi.encode(hookParams); + + // Create Hook execution context which has hook's config and current parameters + bytes memory context = _getExecutionContext(encodedConfig, encodedParams); + + // Try to execute the hook with invalid token contract address + vm.expectRevert( + abi.encodeWithSelector( + Errors.UnsupportedInterface.selector, + "IERC721" + ) + ); + + hook.executeSync(context); + } + + function test_syncBaseHook_revert_InvalidOwnerAddress() public { + // create configuration of hook + TokenGated.Config memory hookConfig = TokenGated.Config({ + // Invalid token address + tokenAddress: address(tokenContract) + }); + bytes memory encodedConfig = abi.encode(hookConfig); + + // create parameters of executing the hook + TokenGated.Params memory hookParams = TokenGated.Params({ + tokenOwner: address(0) + }); + bytes memory encodedParams = abi.encode(hookParams); + + // Create Hook execution context which has hook's config and current parameters + bytes memory context = _getExecutionContext(encodedConfig, encodedParams); + + // Try to execute the hook with invalid contract address + vm.expectRevert(Errors.ZeroAddress.selector); + + hook.executeSync(context); + } + + function test_tokenGatedHook_revert_InvalidConfig() public { + // create configuration of hook + TokenGated.Config memory hookConfig = TokenGated.Config({ + // Invalid token address + tokenAddress: address(0) + }); + bytes memory encodedConfig = abi.encode(hookConfig); + + vm.expectRevert(Errors.ZeroAddress.selector); + hook.validateConfig(encodedConfig); + } + + function _getExecutionContext(bytes memory hookConfig_, bytes memory hookParams_) internal pure returns (bytes memory) { + Hook.ExecutionContext memory context = Hook.ExecutionContext({ + config: hookConfig_, + params: hookParams_ + }); + return abi.encode(context); + } + +}