diff --git a/contracts/lib/Errors.sol b/contracts/lib/Errors.sol index 3d2c2607..dcc8f480 100644 --- a/contracts/lib/Errors.sol +++ b/contracts/lib/Errors.sol @@ -136,6 +136,9 @@ library Errors { /// @notice The disputed IP is not allowed to be added to the group. error GroupingModule__CannotAddDisputedIpToGroup(address ipId); + /// @notice The group reward pool is not whitelisted. + error GroupingModule__GroupRewardPoolNotWhitelisted(address groupId, address groupRewardPool); + //////////////////////////////////////////////////////////////////////////// // IP Asset Registry // //////////////////////////////////////////////////////////////////////////// @@ -607,6 +610,9 @@ library Errors { /// @notice Call failed. error RoyaltyModule__CallFailed(); + /// @notice The group pool is not whitelisted. + error RoyaltyModule__GroupRewardPoolNotWhitelisted(address groupId, address rewardPool); + //////////////////////////////////////////////////////////////////////////// // Royalty Policy LAP // //////////////////////////////////////////////////////////////////////////// diff --git a/contracts/modules/grouping/GroupingModule.sol b/contracts/modules/grouping/GroupingModule.sol index c4e4b641..172e7e6b 100644 --- a/contracts/modules/grouping/GroupingModule.sol +++ b/contracts/modules/grouping/GroupingModule.sol @@ -226,6 +226,9 @@ contract GroupingModule is /// @param ipIds The IP IDs. function claimReward(address groupId, address token, address[] calldata ipIds) external nonReentrant whenNotPaused { IGroupRewardPool pool = IGroupRewardPool(GROUP_IP_ASSET_REGISTRY.getGroupRewardPool(groupId)); + if (!GROUP_IP_ASSET_REGISTRY.isWhitelistedGroupRewardPool(address(pool))) { + revert Errors.GroupingModule__GroupRewardPoolNotWhitelisted(groupId, address(pool)); + } // trigger group pool to distribute rewards to group members vault uint256[] memory rewards = pool.distributeRewards(groupId, token, ipIds); emit ClaimedReward(groupId, token, ipIds, rewards); @@ -239,6 +242,9 @@ contract GroupingModule is address token ) external nonReentrant whenNotPaused returns (uint256 royalties) { IGroupRewardPool pool = IGroupRewardPool(GROUP_IP_ASSET_REGISTRY.getGroupRewardPool(groupId)); + if (!GROUP_IP_ASSET_REGISTRY.isWhitelistedGroupRewardPool(address(pool))) { + revert Errors.GroupingModule__GroupRewardPoolNotWhitelisted(groupId, address(pool)); + } IIpRoyaltyVault vault = IIpRoyaltyVault(ROYALTY_MODULE.ipRoyaltyVaults(groupId)); if (address(vault) == address(0)) revert Errors.GroupingModule__GroupRoyaltyVaultNotCreated(groupId); diff --git a/contracts/modules/royalty/RoyaltyModule.sol b/contracts/modules/royalty/RoyaltyModule.sol index 7ae7b1c9..d47db19e 100644 --- a/contracts/modules/royalty/RoyaltyModule.sol +++ b/contracts/modules/royalty/RoyaltyModule.sol @@ -261,9 +261,13 @@ contract RoyaltyModule is IRoyaltyModule, VaultController, ReentrancyGuardUpgrad // deploy ipRoyaltyVault for the ipId given in case it does not exist yet if ($.ipRoyaltyVaults[ipId] == address(0)) { - address receiver = IP_ASSET_REGISTRY.isRegisteredGroup(ipId) - ? IP_ASSET_REGISTRY.getGroupRewardPool(ipId) - : ipId; + address receiver = ipId; + if (IP_ASSET_REGISTRY.isRegisteredGroup(ipId)) { + receiver = IP_ASSET_REGISTRY.getGroupRewardPool(ipId); + if (!IP_ASSET_REGISTRY.isWhitelistedGroupRewardPool(receiver)) { + revert Errors.RoyaltyModule__GroupRewardPoolNotWhitelisted(ipId, receiver); + } + } _deployIpRoyaltyVault(ipId, receiver); } @@ -521,6 +525,7 @@ contract RoyaltyModule is IRoyaltyModule, VaultController, ReentrancyGuardUpgrad /// @param licensesPercent The license percentage of the licenses being minted /// @param ipRoyaltyVault The address of the ipRoyaltyVault /// @param maxRts The maximum number of royalty tokens that can be distributed to the external royalty policies + // solhint-disable code-complexity function _distributeRoyaltyTokensToPolicies( address ipId, address[] calldata parentIpIds, @@ -575,9 +580,13 @@ contract RoyaltyModule is IRoyaltyModule, VaultController, ReentrancyGuardUpgrad // sends remaining royalty tokens to the ipId address or // in the case the ipId is a group then send to the group reward pool - address receiver = IP_ASSET_REGISTRY.isRegisteredGroup(ipId) - ? IP_ASSET_REGISTRY.getGroupRewardPool(ipId) - : ipId; + address receiver = ipId; + if (IP_ASSET_REGISTRY.isRegisteredGroup(ipId)) { + receiver = IP_ASSET_REGISTRY.getGroupRewardPool(ipId); + if (!IP_ASSET_REGISTRY.isWhitelistedGroupRewardPool(receiver)) { + revert Errors.RoyaltyModule__GroupRewardPoolNotWhitelisted(ipId, receiver); + } + } IERC20(ipRoyaltyVault).safeTransfer(receiver, MAX_PERCENT - totalRtsRequiredToLink); } diff --git a/test/foundry/modules/grouping/GroupingModule.t.sol b/test/foundry/modules/grouping/GroupingModule.t.sol index 29c3012a..d1152976 100644 --- a/test/foundry/modules/grouping/GroupingModule.t.sol +++ b/test/foundry/modules/grouping/GroupingModule.t.sol @@ -296,6 +296,112 @@ contract GroupingModuleTest is BaseTest, ERC721Holder { assertEq(erc20.balanceOf(royaltyModule.ipRoyaltyVaults(ipId1)), 50); } + function test_GroupingModule_claimReward_revert_notWhitelistedPool() public { + vm.warp(100); + vm.prank(alice); + address groupId = groupingModule.registerGroup(address(rewardPool)); + + uint256 termsId = pilTemplate.registerLicenseTerms( + PILFlavors.commercialRemix({ + mintingFee: 0, + commercialRevShare: 10_000_000, + currencyToken: address(erc20), + royaltyPolicy: address(royaltyPolicyLAP) + }) + ); + + Licensing.LicensingConfig memory licensingConfig = Licensing.LicensingConfig({ + isSet: true, + mintingFee: 0, + licensingHook: address(0), + hookData: "", + commercialRevShare: 10 * 10 ** 6, + disabled: false, + expectMinimumGroupRewardShare: 10 * 10 ** 6, + expectGroupRewardPool: address(evenSplitGroupPool) + }); + + vm.startPrank(ipOwner1); + licensingModule.attachLicenseTerms(ipId1, address(pilTemplate), termsId); + licensingModule.setLicensingConfig(ipId1, address(pilTemplate), termsId, licensingConfig); + vm.stopPrank(); + licensingModule.mintLicenseTokens(ipId1, address(pilTemplate), termsId, 1, address(this), "", 0); + vm.startPrank(ipOwner2); + licensingModule.attachLicenseTerms(ipId2, address(pilTemplate), termsId); + licensingModule.setLicensingConfig(ipId2, address(pilTemplate), termsId, licensingConfig); + licensingModule.mintLicenseTokens(ipId2, address(pilTemplate), termsId, 1, address(this), "", 0); + vm.stopPrank(); + + vm.startPrank(alice); + licensingModule.attachLicenseTerms(groupId, address(pilTemplate), termsId); + address[] memory ipIds = new address[](2); + ipIds[0] = ipId1; + ipIds[1] = ipId2; + groupingModule.addIp(groupId, ipIds); + assertEq(ipAssetRegistry.totalMembers(groupId), 2); + assertEq(rewardPool.getTotalIps(groupId), 2); + vm.stopPrank(); + + address[] memory parentIpIds = new address[](1); + parentIpIds[0] = groupId; + uint256[] memory licenseTermsIds = new uint256[](1); + licenseTermsIds[0] = termsId; + vm.prank(ipOwner3); + licensingModule.registerDerivative(ipId3, parentIpIds, licenseTermsIds, address(pilTemplate), "", 0, 100e6); + + erc20.mint(ipOwner3, 1000); + vm.startPrank(ipOwner3); + erc20.approve(address(royaltyModule), 1000); + royaltyModule.payRoyaltyOnBehalf(ipId3, ipOwner3, address(erc20), 1000); + vm.stopPrank(); + royaltyPolicyLAP.transferToVault(ipId3, groupId, address(erc20)); + vm.warp(vm.getBlockTimestamp() + 7 days); + + vm.prank(address(groupingModule)); + ipAssetRegistry.whitelistGroupRewardPool(address(rewardPool), false); + vm.expectRevert( + abi.encodeWithSelector( + Errors.GroupingModule__GroupRewardPoolNotWhitelisted.selector, + groupId, + address(rewardPool) + ) + ); + groupingModule.collectRoyalties(groupId, address(erc20)); + + vm.prank(address(groupingModule)); + ipAssetRegistry.whitelistGroupRewardPool(address(rewardPool), true); + + vm.expectEmit(); + emit IGroupingModule.CollectedRoyaltiesToGroupPool(groupId, address(erc20), address(rewardPool), 100); + groupingModule.collectRoyalties(groupId, address(erc20)); + + address[] memory claimIpIds = new address[](1); + claimIpIds[0] = ipId1; + + uint256[] memory claimAmounts = new uint256[](1); + claimAmounts[0] = 50; + + vm.prank(address(groupingModule)); + ipAssetRegistry.whitelistGroupRewardPool(address(rewardPool), false); + vm.expectRevert( + abi.encodeWithSelector( + Errors.GroupingModule__GroupRewardPoolNotWhitelisted.selector, + groupId, + address(rewardPool) + ) + ); + groupingModule.claimReward(groupId, address(erc20), claimIpIds); + + vm.prank(address(groupingModule)); + ipAssetRegistry.whitelistGroupRewardPool(address(rewardPool), true); + + vm.expectEmit(); + emit IGroupingModule.ClaimedReward(groupId, address(erc20), claimIpIds, claimAmounts); + groupingModule.claimReward(groupId, address(erc20), claimIpIds); + assertEq(erc20.balanceOf(address(rewardPool)), 50); + assertEq(erc20.balanceOf(royaltyModule.ipRoyaltyVaults(ipId1)), 50); + } + function test_GroupingModule_addIp_revert_addGroupToGroup() public { uint256 termsId = pilTemplate.registerLicenseTerms( PILFlavors.commercialRemix({