Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Base module #149

Merged
merged 12 commits into from
Nov 1, 2023
14 changes: 14 additions & 0 deletions contracts/interfaces/modules/base/IModule.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// SPDX-License-Identifier: BUSL-1.1
pragma solidity ^0.8.13;

import { IModule } from "./IModule.sol";

interface IModule {

event RequestPending(address indexed sender);
event RequestCompleted(address indexed sender);

function execute(address caller, bytes calldata selfParams, bytes[] calldata preHooksParams, bytes[] calldata postHooksParams) external;
function configure(address caller_, bytes calldata params_) external;

}
26 changes: 23 additions & 3 deletions contracts/lib/Errors.sol
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import { IPAsset } from "contracts/lib/IPAsset.sol";
/// @title Errors
/// @notice Library for all contract errors, including a set of global errors.
library Errors {

////////////////////////////////////////////////////////////////////////////
// Globals //
////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -35,6 +34,24 @@ library Errors {
/// @notice The amount specified may not be zero.
error ZeroAmount();

////////////////////////////////////////////////////////////////////////////
// BaseModule //
////////////////////////////////////////////////////////////////////////////

error BaseModule_HooksParamsLengthMismatch(uint8 hookType);
error BaseModule_ZeroIpaRegistry();
error BaseModule_ZeroModuleRegistry();

////////////////////////////////////////////////////////////////////////////
// HookRegistry //
////////////////////////////////////////////////////////////////////////////

/// @notice The hook is already registered.
error HookRegistry_RegisteringDuplicatedHook();
error HookRegistry_RegisteringZeroAddressHook();
error HookRegistry_CallerNotAdmin();
error HookRegistry_MaxHooksExceeded();

////////////////////////////////////////////////////////////////////////////
// BaseRelationshipProcessor //
////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -67,7 +84,7 @@ library Errors {
////////////////////////////////////////////////////////////////////////////
// CollectPaymentModule //
////////////////////////////////////////////////////////////////////////////
//

/// @notice The configured collect module payment amount is invalid.
error CollectPaymentModule_AmountInvalid();

Expand Down Expand Up @@ -279,7 +296,10 @@ library Errors {
////////////////////////////////////////////////////////////////////////////

/// @notice Mismatch between parity of accounts and their respective allocations.
error RoyaltyNFT_AccountsAndAllocationsMismatch(uint256 accountsLength, uint256 allocationsLength);
error RoyaltyNFT_AccountsAndAllocationsMismatch(
uint256 accountsLength,
uint256 allocationsLength
);

/// @notice Invalid summation for royalty NFT allocations.
error RoyaltyNFT_InvalidAllocationsSum(uint32 allocationsSum);
Expand Down
69 changes: 69 additions & 0 deletions contracts/modules/base/BaseModule.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// SPDX-License-Identifier: BUSL-1.1
pragma solidity ^0.8.13;

import { IModule } from "contracts/interfaces/modules/base/IModule.sol";
import { IERC721 } from "@openzeppelin/contracts/token/ERC721/IERC721.sol";
import { HookRegistry } from "./HookRegistry.sol";
import { Errors } from "contracts/lib/Errors.sol";

abstract contract BaseModule is IModule, HookRegistry {
Ramarti marked this conversation as resolved.
Show resolved Hide resolved

struct ModuleConstruction {
address ipaRegistry;
address moduleRegistry;
}

address public immutable IPA_REGISTRY;
address public immutable MODULE_REGISTRY;

constructor(ModuleConstruction memory params_) {
if (params_.ipaRegistry == address(0)) {
revert Errors.BaseModule_ZeroIpaRegistry();
}
IPA_REGISTRY = params_.ipaRegistry;
if (params_.moduleRegistry == address(0)) {
revert Errors.BaseModule_ZeroModuleRegistry();
}
MODULE_REGISTRY = params_.moduleRegistry;
}

// TODO access control on sender
function execute(
address caller_,
bytes calldata selfParams_,
bytes[] calldata preHookParams_,
bytes[] calldata postHookParams_
) external {
_verifyExecution(caller_, selfParams_);
if (!_executeHooks(preHookParams_, HookType.PreAction)) {
emit RequestPending(caller_);
return;
}
_performAction(caller_, selfParams_);
_executeHooks(postHookParams_, HookType.PostAction);
emit RequestCompleted(caller_);
}

// TODO access control on sender
function configure(address caller_, bytes calldata params_) external {
_configure(caller_, params_);
}

function _executeHooks(bytes[] calldata params_, HookRegistry.HookType hType_) virtual internal returns (bool) {
address[] memory hooks = _hooksForType(hType_);
uint256 hooksLength = hooks.length;
if (params_.length != hooksLength) {
revert Errors.BaseModule_HooksParamsLengthMismatch(uint8(hType_));
}
for (uint256 i = 0; i < hooksLength; i++) {
// TODO: hook execution and return false if a hook returns false
}
return true;
}

function _hookRegistryAdmin() virtual override internal view returns (address);
function _configure(address caller_, bytes calldata params_) virtual internal;
function _verifyExecution(address caller_, bytes calldata params_) virtual internal {}
function _performAction(address caller_, bytes calldata params_) virtual internal {}

}
122 changes: 122 additions & 0 deletions contracts/modules/base/HookRegistry.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// SPDX-License-Identifier: BUSL-1.1
pragma solidity ^0.8.13;

import { Errors } from "contracts/lib/Errors.sol";

abstract contract HookRegistry {
enum HookType {
PreAction,
PostAction
}

address[] private _preActionHooks;
address[] private _postActionHooks;

uint256 public constant INDEX_NOT_FOUND = type(uint256).max;
uint256 public constant MAX_HOOKS = 10;

event HooksRegistered(HookType indexed hType, address[] indexed hook);
event HooksCleared(HookType indexed hType);

modifier onlyHookRegistryAdmin() {
if (msg.sender != _hookRegistryAdmin())
revert Errors.HookRegistry_CallerNotAdmin();
_;
}

function registerHooks(
HookType hType_,
address[] calldata hooks_
) external onlyHookRegistryAdmin {
clearHooks(hType_);
_registerHooks(_hooksForType(hType_), hooks_);
emit HooksRegistered(hType_, hooks_);
}

function isRegistered(
HookType hType_,
address hook_
) external view returns (bool) {
return hookIndex(hType_, hook_) != INDEX_NOT_FOUND;
}

function hookAt(
HookType hType_,
uint256 index_
) external view returns (address) {
return _hooksForType(hType_)[index_];
}

function totalHooks(
HookType hType_
) external view returns (uint256) {
return _hooksForType(hType_).length;
}

function clearHooks(
HookType hType_
) public onlyHookRegistryAdmin {
if (hType_ == HookType.PreAction && _preActionHooks.length > 0) {
delete _preActionHooks;
} else if (_postActionHooks.length > 0) {
delete _postActionHooks;
}
emit HooksCleared(hType_);
}

function hookIndex(
HookType hType_,
address hook_
) public view returns (uint256) {
return _hookIndex(_hooksForType(hType_), hook_);
}

function _hookRegistryAdmin() internal view virtual returns (address);

function _hooksForType(
HookType hType_
) internal view returns (address[] storage) {
if (hType_ == HookType.PreAction) {
return _preActionHooks;
} else {
return _postActionHooks;
}
}

function _registerHooks(
address[] storage hooks_,
address[] memory newHooks_
) private {
uint256 newLength = newHooks_.length;
if (newLength > MAX_HOOKS) {
revert Errors.HookRegistry_MaxHooksExceeded();
}
unchecked {
for (uint256 i = 0; i < newLength; i++) {
if (newHooks_[i] == address(0)) {
revert Errors.HookRegistry_RegisteringZeroAddressHook();
}
if (i > 0 && newHooks_[i] == newHooks_[i - 1]) {
revert Errors.HookRegistry_RegisteringDuplicatedHook();
}
hooks_.push(newHooks_[i]);
}
}
}

function _hookIndex(
address[] storage hooks,
address hook_
) private view returns (uint256) {
uint256 length = hooks.length;
for (uint256 i = 0; i < length; ) {
if (hooks[i] == hook_) {
return i;
}
unchecked {
i++;
}
}
return INDEX_NOT_FOUND;
}
}
84 changes: 84 additions & 0 deletions test/foundry/_temp_modules/base/BaseModule.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// SPDX-License-Identifier: BUSDL-1.1
pragma solidity ^0.8.13;

import "forge-std/Test.sol";

import "contracts/modules/base/BaseModule.sol";
import "test/foundry/mocks/MockBaseModule.sol";
import "contracts/lib/Errors.sol";

contract BaseModuleTest is Test {
MockBaseModule module;
address admin = address(123);
address ipaRegistry = address(456);
address moduleRegistry = address(789);

event RequestPending(address indexed sender);
event RequestCompleted(address indexed sender);

function setUp() public {
vm.prank(admin);
module = new MockBaseModule(admin, BaseModule.ModuleConstruction(ipaRegistry, moduleRegistry));
}

function test_baseModule_revert_constructorIpaRegistryIsZero() public {
vm.prank(admin);
vm.expectRevert(Errors.BaseModule_ZeroIpaRegistry.selector);
new MockBaseModule(admin, BaseModule.ModuleConstruction(address(0), moduleRegistry));
}

function test_baseModule_revert_constructorModuleRegistryIsZero() public {
vm.prank(admin);
vm.expectRevert(Errors.BaseModule_ZeroModuleRegistry.selector);
new MockBaseModule(admin, BaseModule.ModuleConstruction(ipaRegistry, address(0)));
}

function test_baseModule_setup() public {
assertEq(module.IPA_REGISTRY(), ipaRegistry);
assertEq(module.MODULE_REGISTRY(), moduleRegistry);
}

function test_baseModule_passesConfigParams() public {
bytes memory params = abi.encode(uint256(123));
module.configure(address(123), params);
assertEq(module.callStackAt(0).caller, address(123));
assertEq(module.callStackAt(0).params, params);
}

function test_baseModule_correctExecutionOrderAndParams() public {
bytes memory params = abi.encode(uint256(123));
vm.expectEmit(true, true, true, true);
emit RequestCompleted(address(123));
module.execute(address(123), params, new bytes[](0), new bytes[](0));
assertEq(module.callStackAt(0).caller, address(123));
assertEq(module.callStackAt(0).params, params);
assertEq(module.callStackAt(1).caller, address(123));
assertEq(module.callStackAt(1).params, params);
}

function test_baseModule_revertPreHookWrongParamsLength() public {
bytes memory params = abi.encode(uint256(123));
vm.expectRevert(
abi.encodeWithSelector(
Errors.BaseModule_HooksParamsLengthMismatch.selector,
uint8(HookRegistry.HookType.PreAction)
)
);
module.execute(address(123), params, new bytes[](1), new bytes[](0));
}

function test_baseModule_revertPostHookWrongParamsLength() public {
bytes memory params = abi.encode(uint256(123));
vm.expectRevert(
abi.encodeWithSelector(
Errors.BaseModule_HooksParamsLengthMismatch.selector,
uint8(HookRegistry.HookType.PostAction)
)
);
module.execute(address(123), params, new bytes[](0), new bytes[](1));
}

// TODO: hook execution tests, waiting for base hook


}
Loading