Skip to content

Commit

Permalink
Introduce IPGraph Access Control Support (#208)
Browse files Browse the repository at this point in the history
* Add IPGraphACL contract
* Implement IPGraphACL to hold a control flag
  • Loading branch information
kingster-will authored Aug 26, 2024
1 parent bd2b989 commit 47d8fd1
Show file tree
Hide file tree
Showing 9 changed files with 200 additions and 7 deletions.
77 changes: 77 additions & 0 deletions contracts/access/IPGraphACL.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// SPDX-License-Identifier: BUSL-1.1
pragma solidity 0.8.23;

import { AccessManaged } from "@openzeppelin/contracts/access/manager/AccessManaged.sol";
import { Errors } from "../lib/Errors.sol";

/// @title IPGraphACL
/// @notice This contract is used to manage access to the IPGraph contract.
/// It allows the access manager to whitelist addresses that can allow or disallow access to the IPGraph contract.
/// It allows whitelisted addresses to allow or disallow access to the IPGraph contract.
/// IPGraph precompiled check if the IPGraphACL contract allows access to the IPGraph.
contract IPGraphACL is AccessManaged {
// keccak256(abi.encode(uint256(keccak256("story-protocol.IPGraphACL")) - 1)) & ~bytes32(uint256(0xff));
bytes32 private constant IP_GRAPH_ACL_SLOT = 0xaf99b37fdaacca72ee7240cb1435cc9e498aee6ef4edc19c8cc0cd787f4e6800;

/// @notice Whitelisted addresses that can allow or disallow access to the IPGraph contract.
mapping(address => bool) public whitelist;

modifier onlyWhitelisted() {
if (!whitelist[msg.sender]) {
revert Errors.IPGraphACL__NotWhitelisted(msg.sender);
}
_;
}

constructor(address accessManager) AccessManaged(accessManager) {}

/// @notice Allow access to the IPGraph contract.
function allow() external onlyWhitelisted {
bytes32 slot = IP_GRAPH_ACL_SLOT;
bool value = true;

assembly {
sstore(slot, value)
}
}

/// @notice Disallow access to the IPGraph contract.
function disallow() external onlyWhitelisted {
bytes32 slot = IP_GRAPH_ACL_SLOT;
bool value = false;

assembly {
sstore(slot, value)
}
}

/// @notice Check if access to the IPGraph contract is allowed.
function isAllowed() external view returns (bool) {
bytes32 slot = IP_GRAPH_ACL_SLOT;
bool value;

assembly {
value := sload(slot)
}

return value;
}

/// @notice Whitelist an address that can allow or disallow access to the IPGraph contract.
/// @param addr The address to whitelist.
function whitelistAddress(address addr) external restricted {
whitelist[addr] = true;
}

/// @notice Revoke whitelisted address.
/// @param addr The address to revoke.
function revokeWhitelistedAddress(address addr) external restricted {
whitelist[addr] = false;
}

/// @notice Check if an address is whitelisted.
/// @param addr The address to check.
function isWhitelisted(address addr) external view returns (bool) {
return whitelist[addr];
}
}
13 changes: 13 additions & 0 deletions contracts/lib/Errors.sol
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ library Errors {
/// @notice Failed to add parent IPs to IP graph.
error LicenseRegistry__AddParentIpToIPGraphFailed(address childIpId, address[] parentIpIds);

/// @notice Zero address provided for IP Graph ACL.
error LicenseRegistry__ZeroIPGraphACL();

////////////////////////////////////////////////////////////////////////////
// License Token //
////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -416,6 +419,9 @@ library Errors {
/// @notice Zero address provided for Licensing Module.
error RoyaltyPolicyLAP__ZeroLicensingModule();

/// @notice Zero address provided for IP Graph ACL.
error RoyaltyPolicyLAP__ZeroIPGraphACL();

/// @notice Caller is not the Royalty Module.
error RoyaltyPolicyLAP__NotRoyaltyModule();

Expand Down Expand Up @@ -583,4 +589,11 @@ library Errors {

/// @notice Removing a contract that is not in the pausable list.
error ProtocolPauseAdmin__PausableNotFound();

////////////////////////////////////////////////////////////////////////////
// IPGraphACL //
////////////////////////////////////////////////////////////////////////////

/// @notice The address is not whitelisted.
error IPGraphACL__NotWhitelisted(address addr);
}
13 changes: 12 additions & 1 deletion contracts/modules/royalty/policies/RoyaltyPolicyLAP.sol
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import { IRoyaltyPolicyLAP } from "../../../interfaces/modules/royalty/policies/
import { ArrayUtils } from "../../../lib/ArrayUtils.sol";
import { Errors } from "../../../lib/Errors.sol";
import { ProtocolPausableUpgradeable } from "../../../pause/ProtocolPausableUpgradeable.sol";
import { IPGraphACL } from "../../../access/IPGraphACL.sol";

/// @title Liquid Absolute Percentage Royalty Policy
/// @notice Defines the logic for splitting royalties for a given ipId using a liquid absolute percentage mechanism
Expand Down Expand Up @@ -60,6 +61,9 @@ contract RoyaltyPolicyLAP is
/// @custom:oz-upgrades-unsafe-allow state-variable-immutable
address public immutable LICENSING_MODULE;

/// @custom:oz-upgrades-unsafe-allow state-variable-immutable
IPGraphACL public immutable IP_GRAPH_ACL;

/// @dev Restricts the calls to the royalty module
modifier onlyRoyaltyModule() {
if (msg.sender != ROYALTY_MODULE) revert Errors.RoyaltyPolicyLAP__NotRoyaltyModule();
Expand All @@ -69,13 +73,17 @@ contract RoyaltyPolicyLAP is
/// @notice Constructor
/// @param royaltyModule The RoyaltyModule address
/// @param licensingModule The LicensingModule address
/// @param ipGraphAcl The IPGraphACL address
/// @custom:oz-upgrades-unsafe-allow constructor
constructor(address royaltyModule, address licensingModule) {
constructor(address royaltyModule, address licensingModule, address ipGraphAcl) {
if (royaltyModule == address(0)) revert Errors.RoyaltyPolicyLAP__ZeroRoyaltyModule();
if (licensingModule == address(0)) revert Errors.RoyaltyPolicyLAP__ZeroLicensingModule();
if (ipGraphAcl == address(0)) revert Errors.RoyaltyPolicyLAP__ZeroIPGraphACL();

ROYALTY_MODULE = royaltyModule;
LICENSING_MODULE = licensingModule;
IP_GRAPH_ACL = IPGraphACL(ipGraphAcl);

_disableInitializers();
}

Expand Down Expand Up @@ -220,6 +228,8 @@ contract RoyaltyPolicyLAP is
uint32[] memory royaltiesGroupByParent = new uint32[](parentIpIds.length);
address[] memory uniqueParents = new address[](parentIpIds.length);
uint256 uniqueParentCount;

IP_GRAPH_ACL.allow();
for (uint256 i = 0; i < parentIpIds.length; i++) {
(uint256 index, bool exists) = ArrayUtils.indexOf(uniqueParents, parentIpIds[i]);
if (!exists) {
Expand All @@ -230,6 +240,7 @@ contract RoyaltyPolicyLAP is
uniqueParents[index] = parentIpIds[i];
_setRoyalty(ipId, parentIpIds[i], royaltiesGroupByParent[index]);
}
IP_GRAPH_ACL.disallow();

// calculate new royalty stack
uint32 royaltyStack = _getRoyaltyStack(ipId);
Expand Down
10 changes: 9 additions & 1 deletion contracts/registries/LicenseRegistry.sol
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import { ExpiringOps } from "../lib/ExpiringOps.sol";
import { ILicenseTemplate } from "../interfaces/modules/licensing/ILicenseTemplate.sol";
import { IPAccountStorageOps } from "../lib/IPAccountStorageOps.sol";
import { IIPAccount } from "../interfaces/IIPAccount.sol";
import { IPGraphACL } from "../access/IPGraphACL.sol";

/// @title LicenseRegistry aka LNFT
/// @notice Registry of License NFTs, which represent licenses granted by IP ID licensors to create derivative IPs.
Expand All @@ -33,6 +34,8 @@ contract LicenseRegistry is ILicenseRegistry, AccessManagedUpgradeable, UUPSUpgr
ILicensingModule public immutable LICENSING_MODULE;
/// @custom:oz-upgrades-unsafe-allow state-variable-immutable
IDisputeModule public immutable DISPUTE_MODULE;
/// @custom:oz-upgrades-unsafe-allow state-variable-immutable
IPGraphACL public immutable IP_GRAPH_ACL;

/// @dev Storage of the LicenseRegistry
/// @param defaultLicenseTemplate The default license template address
Expand Down Expand Up @@ -76,11 +79,13 @@ contract LicenseRegistry is ILicenseRegistry, AccessManagedUpgradeable, UUPSUpgr
}

/// @custom:oz-upgrades-unsafe-allow constructor
constructor(address licensingModule, address disputeModule) {
constructor(address licensingModule, address disputeModule, address ipGraphAcl) {
if (licensingModule == address(0)) revert Errors.LicenseRegistry__ZeroLicensingModule();
if (disputeModule == address(0)) revert Errors.LicenseRegistry__ZeroDisputeModule();
if (ipGraphAcl == address(0)) revert Errors.LicenseRegistry__ZeroIPGraphACL();
LICENSING_MODULE = ILicensingModule(licensingModule);
DISPUTE_MODULE = IDisputeModule(disputeModule);
IP_GRAPH_ACL = IPGraphACL(ipGraphAcl);
_disableInitializers();
}

Expand Down Expand Up @@ -242,9 +247,12 @@ contract LicenseRegistry is ILicenseRegistry, AccessManagedUpgradeable, UUPSUpgr
}
}

IP_GRAPH_ACL.allow();
(bool success, ) = IP_GRAPH_CONTRACT.call(
abi.encodeWithSignature("addParentIp(address,address[])", childIpId, parentIpIds)
);
IP_GRAPH_ACL.disallow();

if (!success) {
revert Errors.LicenseRegistry__AddParentIpToIPGraphFailed(childIpId, parentIpIds);
}
Expand Down
28 changes: 26 additions & 2 deletions script/foundry/utils/DeployHelper.sol
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ import { CoreMetadataModule } from "contracts/modules/metadata/CoreMetadataModul
import { CoreMetadataViewModule } from "contracts/modules/metadata/CoreMetadataViewModule.sol";
import { PILicenseTemplate, PILTerms } from "contracts/modules/licensing/PILicenseTemplate.sol";
import { LicenseToken } from "contracts/LicenseToken.sol";
import { PILFlavors } from "contracts/lib/PILFlavors.sol";
import { IPGraphACL } from "contracts/access/IPGraphACL.sol";

// script
import { StringUtil } from "./StringUtil.sol";
Expand Down Expand Up @@ -92,6 +94,7 @@ contract DeployHelper is Script, BroadcastManager, JsonDeploymentHandler, Storag
// Access Control
AccessManager internal protocolAccessManager; // protocol roles
AccessController internal accessController; // per IPA roles
IPGraphACL internal ipGraphACL;

// Pause
ProtocolPauseAdmin internal protocolPauser;
Expand Down Expand Up @@ -132,6 +135,7 @@ contract DeployHelper is Script, BroadcastManager, JsonDeploymentHandler, Storag
if (block.chainid == 1) erc20 = ERC20(0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48);
else if (block.chainid == 11155111) erc20 = ERC20(0x1c7D4B196Cb0C7B01d743Fbc6116a902379C7238);
else if (block.chainid == 1513) erc20 = ERC20(0xDE51BB12D5cef80ff2334fe1019089363F80b46e);
else if (block.chainid == 1337) erc20 = ERC20(0xDE51BB12D5cef80ff2334fe1019089363F80b46e);
}

/// @dev To use, run the following command (e.g. for Sepolia):
Expand Down Expand Up @@ -281,7 +285,8 @@ contract DeployHelper is Script, BroadcastManager, JsonDeploymentHandler, Storag
impl = address(
new LicenseRegistry(
_getDeployedAddress(type(LicensingModule).name),
_getDeployedAddress(type(DisputeModule).name)
_getDeployedAddress(type(DisputeModule).name),
_getDeployedAddress(type(IPGraphACL).name)
)
);
licenseRegistry = LicenseRegistry(
Expand Down Expand Up @@ -446,7 +451,11 @@ contract DeployHelper is Script, BroadcastManager, JsonDeploymentHandler, Storag
_postdeploy("ArbitrationPolicySP", address(arbitrationPolicySP));

_predeploy("RoyaltyPolicyLAP");
impl = address(new RoyaltyPolicyLAP(address(royaltyModule), address(licensingModule)));
impl = address(new RoyaltyPolicyLAP(
address(royaltyModule),
address(licensingModule),
_getDeployedAddress(type(IPGraphACL).name)
));
royaltyPolicyLAP = RoyaltyPolicyLAP(
TestProxyHelper.deployUUPSProxy(
create3Deployer,
Expand Down Expand Up @@ -546,6 +555,17 @@ contract DeployHelper is Script, BroadcastManager, JsonDeploymentHandler, Storag
);
_postdeploy("CoreMetadataViewModule", address(coreMetadataViewModule));

_predeploy("IPGraphACL");
ipGraphACL = IPGraphACL(
create3Deployer.deploy(
_getSalt(type(IPGraphACL).name),
abi.encodePacked(
type(IPGraphACL).creationCode,
abi.encode(address(protocolAccessManager))
)
)
);
_postdeploy("IPGraphACL", address(ipGraphACL));
}

function _predeploy(string memory contractKey) private view {
Expand Down Expand Up @@ -596,6 +616,10 @@ contract DeployHelper is Script, BroadcastManager, JsonDeploymentHandler, Storag

// License Template
licenseRegistry.registerLicenseTemplate(address(pilTemplate));

// IPGraphACL
ipGraphACL.whitelistAddress(address(licenseRegistry));
ipGraphACL.whitelistAddress(address(royaltyPolicyLAP));
}

function _configureRoles() private {
Expand Down
52 changes: 52 additions & 0 deletions test/foundry/access/IPGraphACL.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// SPDX-License-Identifier: BUSL-1.1
pragma solidity 0.8.23;

import { Errors } from "../../../contracts/lib/Errors.sol";
import { BaseTest } from "../utils/BaseTest.t.sol";

contract IPGraphACLTest is BaseTest {
function setUp() public override {
super.setUp();
}

// test allow/disallow
// test add/remove whitelist
// onlyWhitelisted modifier

function test_IPGraphACL_initialized_not_allow() public {
assertFalse(ipGraphACL.isAllowed());
}

function test_IPGraphACL_allow() public {
vm.prank(address(licenseRegistry));
ipGraphACL.allow();
assertTrue(ipGraphACL.isAllowed());
}

function test_IPGraphACL_disallow() public {
vm.prank(address(licenseRegistry));
ipGraphACL.disallow();
assertFalse(ipGraphACL.isAllowed());
}

function test_IPGraphACL_addToWhitelist() public {
vm.prank(admin);
ipGraphACL.whitelistAddress(address(0x123));
vm.prank(address(0x123));
ipGraphACL.allow();
assertTrue(ipGraphACL.isAllowed());
}

function test_IPGraphACL_revert_removeFromWhitelist() public {
vm.prank(admin);
ipGraphACL.whitelistAddress(address(0x123));
vm.prank(address(0x123));
ipGraphACL.allow();
assertTrue(ipGraphACL.isAllowed());
vm.prank(admin);
ipGraphACL.revokeWhitelistedAddress(address(0x123));
vm.prank(address(0x123));
vm.expectRevert(abi.encodeWithSelector(Errors.IPGraphACL__NotWhitelisted.selector, address(0x123)));
ipGraphACL.disallow();
}
}
6 changes: 5 additions & 1 deletion test/foundry/mocks/module/LicenseRegistryHarness.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ pragma solidity 0.8.23;
import { LicenseRegistry } from "../../../../contracts/registries/LicenseRegistry.sol";

contract LicenseRegistryHarness is LicenseRegistry {
constructor(address _erc721Registry, address _erc1155Registry) LicenseRegistry(_erc721Registry, _erc1155Registry) {}
constructor(
address _erc721Registry,
address _erc1155Registry,
address _ipGraphAcl
) LicenseRegistry(_erc721Registry, _erc1155Registry, _ipGraphAcl) {}

function setExpirationTime(address ipId, uint256 expireTime) external {
_setExpirationTime(ipId, expireTime);
Expand Down
4 changes: 3 additions & 1 deletion test/foundry/modules/royalty/RoyaltyModule.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ contract TestRoyaltyModule is BaseTest {

USDC.mint(ipAccount2, 1000 * 10 ** 6); // 1000 USDC

address impl = address(new RoyaltyPolicyLAP(address(royaltyModule), address(licensingModule)));
address impl = address(
new RoyaltyPolicyLAP(address(royaltyModule), address(licensingModule), address(ipGraphACL))
);
royaltyPolicyLAP2 = RoyaltyPolicyLAP(
TestProxyHelper.deployUUPSProxy(
impl,
Expand Down
4 changes: 3 additions & 1 deletion test/foundry/utils/BaseTest.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ contract BaseTest is Test, DeployHelper, LicensingHelper {
dealMockAssets();

ipAccountRegistry = IPAccountRegistry(ipAssetRegistry);
lrHarnessImpl = address(new LicenseRegistryHarness(address(licensingModule), address(disputeModule)));
lrHarnessImpl = address(
new LicenseRegistryHarness(address(licensingModule), address(disputeModule), address(ipGraphACL))
);
}

function dealMockAssets() public {
Expand Down

0 comments on commit 47d8fd1

Please sign in to comment.