diff --git a/contracts/interfaces/modules/royalty/policies/IIpRoyaltyVault.sol b/contracts/interfaces/modules/royalty/policies/IIpRoyaltyVault.sol index cef1ef6d..ecfed792 100644 --- a/contracts/interfaces/modules/royalty/policies/IIpRoyaltyVault.sol +++ b/contracts/interfaces/modules/royalty/policies/IIpRoyaltyVault.sol @@ -64,7 +64,43 @@ interface IIpRoyaltyVault { /// @param ancestorIpId The ip id of the ancestor to whom the royalty tokens belong to function collectRoyaltyTokens(address ancestorIpId) external; - /// @notice Returns the list of revenue tokens in the vault - /// @return The list of revenue tokens - function getVaultTokens() external view returns (address[] memory); + /// @notice The ip id to whom this royalty vault belongs to + /// @return The ip id address + function ipId() external view returns (address); + + /// @notice The amount of unclaimed royalty tokens + function unclaimedRoyaltyTokens() external view returns (uint32); + + /// @notice The last snapshotted timestamp + function lastSnapshotTimestamp() external view returns (uint256); + + /// @notice The amount of revenue token in the ancestors vault + /// @param token The address of the revenue token + function ancestorsVaultAmount(address token) external view returns (uint256); + + /// @notice Indicates whether the ancestor has collected the royalty tokens + /// @param ancestorIpId The ancestor ipId address + function isCollectedByAncestor(address ancestorIpId) external view returns (bool); + + /// @notice Amount of revenue token in the claim vault + /// @param token The address of the revenue token + function claimVaultAmount(address token) external view returns (uint256); + + /// @notice Amount of revenue token claimable at a given snapshot + /// @param snapshotId The snapshot id + /// @param token The address of the revenue token + function claimableAtSnapshot(uint256 snapshotId, address token) external view returns (uint256); + + /// @notice Amount of unclaimed revenue tokens at the snapshot + /// @param snapshotId The snapshot id + function unclaimedAtSnapshot(uint256 snapshotId) external view returns (uint32); + + /// @notice Indicates whether the claimer has claimed the revenue tokens at a given snapshot + /// @param snapshotId The snapshot id + /// @param claimer The address of the claimer + /// @param token The address of the revenue token + function isClaimedAtSnapshot(uint256 snapshotId, address claimer, address token) external view returns (bool); + + /// @notice The list of revenue tokens in the vault + function tokens() external view returns (address[] memory); } diff --git a/contracts/modules/royalty/policies/IpRoyaltyVault.sol b/contracts/modules/royalty/policies/IpRoyaltyVault.sol index b6eaf96d..d8f57977 100644 --- a/contracts/modules/royalty/policies/IpRoyaltyVault.sol +++ b/contracts/modules/royalty/policies/IpRoyaltyVault.sol @@ -22,6 +22,35 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy using EnumerableSet for EnumerableSet.AddressSet; using SafeERC20Upgradeable for IERC20Upgradeable; + /// @dev Storage structure for the IpRoyaltyVault + /// @param ipId The ip id to whom this royalty vault belongs to + /// @param unclaimedRoyaltyTokens The amount of unclaimed royalty tokens + /// @param lastSnapshotTimestamp The last snapshotted timestamp + /// @param ancestorsVaultAmount The amount of revenue token in the ancestors vault + /// @param isCollectedByAncestor Indicates whether the ancestor has collected the royalty tokens + /// @param claimVaultAmount Amount of revenue token in the claim vault + /// @param claimableAtSnapshot Amount of revenue token claimable at a given snapshot + /// @param unclaimedAtSnapshot Amount of unclaimed revenue tokens at the snapshot + /// @param isClaimedAtSnapshot Indicates whether the claimer has claimed the revenue tokens at a given snapshot + /// @param tokens The list of revenue tokens in the vault + /// @custom:storage-location erc7201:story-protocol.IpRoyaltyVault + struct IpRoyaltyVaultStorage { + address ipId; + uint32 unclaimedRoyaltyTokens; + uint256 lastSnapshotTimestamp; + mapping(address token => uint256 amount) ancestorsVaultAmount; + mapping(address ancestorIpId => bool) isCollectedByAncestor; + mapping(address token => uint256 amount) claimVaultAmount; + mapping(uint256 snapshotId => mapping(address token => uint256 amount)) claimableAtSnapshot; + mapping(uint256 snapshotId => uint32 tokenAmount) unclaimedAtSnapshot; + mapping(uint256 snapshotId => mapping(address claimer => mapping(address token => bool))) isClaimedAtSnapshot; + EnumerableSet.AddressSet tokens; + } + + // keccak256(abi.encode(uint256(keccak256("story-protocol.IpRoyaltyVault")) - 1)) & ~bytes32(uint256(0xff)); + bytes32 private constant IpRoyaltyVaultStorageLocation = + 0xe1c3e3b0c445d504edb1b9e6fa2ca4fab60584208a4bc973fe2db2b554d1df00; + /// @notice LAP royalty policy address /// @custom:oz-upgrades-unsafe-allow state-variable-immutable IRoyaltyPolicyLAP public immutable ROYALTY_POLICY_LAP; @@ -30,37 +59,6 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy /// @custom:oz-upgrades-unsafe-allow state-variable-immutable IDisputeModule public immutable DISPUTE_MODULE; - /// @notice Ip id to whom this royalty vault belongs to - address public ipId; - - /// @notice Amount of unclaimed royalty tokens - uint32 public unclaimedRoyaltyTokens; - - /// @notice Last snapshotted timestamp - uint256 public lastSnapshotTimestamp; - - /// @notice Amount of revenue token in the ancestors vault - mapping(address token => uint256 amount) public ancestorsVaultAmount; - - /// @notice Indicates if a given ancestor address has already claimed - mapping(address ancestorIpId => bool) public isClaimedByAncestor; - - /// @notice Amount of revenue token in the claim vault - mapping(address token => uint256 amount) public claimVaultAmount; - - /// @notice Amount of tokens claimable at a given snapshot - mapping(uint256 snapshotId => mapping(address token => uint256 amount)) public claimableAtSnapshot; - - /// @notice Amount of unclaimed tokens at the snapshot - mapping(uint256 snapshotId => uint32 tokenAmount) public unclaimedAtSnapshot; - - /// @notice Indicates whether the claimer has claimed the revenue tokens at a given snapshot - mapping(uint256 snapshotId => mapping(address claimer => mapping(address token => bool))) - public isClaimedAtSnapshot; - - /// @notice Royalty tokens of the vault - EnumerableSet.AddressSet private _tokens; - /// @notice Constructor /// @param royaltyPolicyLAP The address of the royalty policy LAP /// @param disputeModule The address of the dispute module @@ -75,7 +73,6 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy _disableInitializers(); } - // TODO: adjust/review for upgradeability /// @notice Initializer for this implementation contract /// @param name The name of the royalty token /// @param symbol The symbol of the royalty token @@ -91,9 +88,11 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy ) external initializer { if (ipIdAddress == address(0)) revert Errors.IpRoyaltyVault__ZeroIpId(); - ipId = ipIdAddress; - lastSnapshotTimestamp = block.timestamp; - unclaimedRoyaltyTokens = unclaimedTokens; + IpRoyaltyVaultStorage storage $ = _getIpRoyaltyVaultStorage(); + + $.ipId = ipIdAddress; + $.lastSnapshotTimestamp = block.timestamp; + $.unclaimedRoyaltyTokens = unclaimedTokens; _mint(address(this), unclaimedTokens); _mint(ipIdAddress, supply - unclaimedTokens); @@ -108,39 +107,42 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy /// @dev Only callable by the royalty policy LAP function addIpRoyaltyVaultTokens(address token) external { if (msg.sender != address(ROYALTY_POLICY_LAP)) revert Errors.IpRoyaltyVault__NotRoyaltyPolicyLAP(); - _tokens.add(token); + IpRoyaltyVaultStorage storage $ = _getIpRoyaltyVaultStorage(); + $.tokens.add(token); } /// @notice Snapshots the claimable revenue and royalty token amounts /// @return snapshotId The snapshot id function snapshot() external returns (uint256) { - if (block.timestamp - lastSnapshotTimestamp < ROYALTY_POLICY_LAP.getSnapshotInterval()) + IpRoyaltyVaultStorage storage $ = _getIpRoyaltyVaultStorage(); + + if (block.timestamp - $.lastSnapshotTimestamp < ROYALTY_POLICY_LAP.getSnapshotInterval()) revert Errors.IpRoyaltyVault__SnapshotIntervalTooShort(); uint256 snapshotId = _snapshot(); - lastSnapshotTimestamp = block.timestamp; + $.lastSnapshotTimestamp = block.timestamp; - uint32 unclaimedTokens = unclaimedRoyaltyTokens; - unclaimedAtSnapshot[snapshotId] = unclaimedTokens; + uint32 unclaimedTokens = $.unclaimedRoyaltyTokens; + $.unclaimedAtSnapshot[snapshotId] = unclaimedTokens; - address[] memory tokens = _tokens.values(); + address[] memory tokens = $.tokens.values(); for (uint256 i = 0; i < tokens.length; i++) { uint256 tokenBalance = IERC20Upgradeable(tokens[i]).balanceOf(address(this)); if (tokenBalance == 0) { - _tokens.remove(tokens[i]); + $.tokens.remove(tokens[i]); continue; } - uint256 newRevenue = tokenBalance - claimVaultAmount[tokens[i]] - ancestorsVaultAmount[tokens[i]]; + uint256 newRevenue = tokenBalance - $.claimVaultAmount[tokens[i]] - $.ancestorsVaultAmount[tokens[i]]; if (newRevenue == 0) continue; uint256 ancestorsTokens = (newRevenue * unclaimedTokens) / totalSupply(); - ancestorsVaultAmount[tokens[i]] += ancestorsTokens; + $.ancestorsVaultAmount[tokens[i]] += ancestorsTokens; uint256 claimableTokens = newRevenue - ancestorsTokens; - claimableAtSnapshot[snapshotId][tokens[i]] = claimableTokens; - claimVaultAmount[tokens[i]] += claimableTokens; + $.claimableAtSnapshot[snapshotId][tokens[i]] = claimableTokens; + $.claimVaultAmount[tokens[i]] += claimableTokens; } emit SnapshotCompleted(snapshotId, block.timestamp, unclaimedTokens); @@ -161,12 +163,14 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy /// @param snapshotId The snapshot id /// @param tokens The list of revenue tokens to claim function claimRevenueByTokenBatch(uint256 snapshotId, address[] calldata tokens) external nonReentrant { + IpRoyaltyVaultStorage storage $ = _getIpRoyaltyVaultStorage(); + for (uint256 i = 0; i < tokens.length; i++) { uint256 claimableToken = _claimableRevenue(msg.sender, snapshotId, tokens[i]); if (claimableToken == 0) continue; - isClaimedAtSnapshot[snapshotId][msg.sender][tokens[i]] = true; - claimVaultAmount[tokens[i]] -= claimableToken; + $.isClaimedAtSnapshot[snapshotId][msg.sender][tokens[i]] = true; + $.claimVaultAmount[tokens[i]] -= claimableToken; IERC20Upgradeable(tokens[i]).safeTransfer(msg.sender, claimableToken); emit RevenueTokenClaimed(msg.sender, tokens[i], claimableToken); @@ -177,13 +181,15 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy /// @param snapshotIds The list of snapshot ids /// @param token The revenue token to claim function claimRevenueBySnapshotBatch(uint256[] memory snapshotIds, address token) external { + IpRoyaltyVaultStorage storage $ = _getIpRoyaltyVaultStorage(); + uint256 claimableToken; for (uint256 i = 0; i < snapshotIds.length; i++) { claimableToken += _claimableRevenue(msg.sender, snapshotIds[i], token); - isClaimedAtSnapshot[snapshotIds[i]][msg.sender][token] = true; + $.isClaimedAtSnapshot[snapshotIds[i]][msg.sender][token] = true; } - claimVaultAmount[token] -= claimableToken; + $.claimVaultAmount[token] -= claimableToken; IERC20Upgradeable(token).safeTransfer(msg.sender, claimableToken); emit RevenueTokenClaimed(msg.sender, token, claimableToken); @@ -192,12 +198,14 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy /// @notice Allows ancestors to claim the royalty tokens and any accrued revenue tokens /// @param ancestorIpId The ip id of the ancestor to whom the royalty tokens belong to function collectRoyaltyTokens(address ancestorIpId) external nonReentrant { + IpRoyaltyVaultStorage storage $ = _getIpRoyaltyVaultStorage(); + (, , , address[] memory ancestors, uint32[] memory ancestorsRoyalties) = ROYALTY_POLICY_LAP.getRoyaltyData( - ipId + $.ipId ); - if (DISPUTE_MODULE.isIpTagged(ipId)) revert Errors.IpRoyaltyVault__IpTagged(); - if (isClaimedByAncestor[ancestorIpId]) revert Errors.IpRoyaltyVault__AlreadyClaimed(); + if (DISPUTE_MODULE.isIpTagged($.ipId)) revert Errors.IpRoyaltyVault__IpTagged(); + if ($.isCollectedByAncestor[ancestorIpId]) revert Errors.IpRoyaltyVault__AlreadyClaimed(); // check if the address being claimed to is an ancestor (uint32 index, bool isIn) = ArrayUtils.indexOf(ancestors, ancestorIpId); @@ -209,49 +217,115 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy // collect accrued revenue tokens (if any) _collectAccruedTokens(ancestorsRoyalties[index], ancestorIpId); - isClaimedByAncestor[ancestorIpId] = true; - unclaimedRoyaltyTokens -= ancestorsRoyalties[index]; + $.isCollectedByAncestor[ancestorIpId] = true; + $.unclaimedRoyaltyTokens -= ancestorsRoyalties[index]; emit RoyaltyTokensCollected(ancestorIpId, ancestorsRoyalties[index]); } - /// @notice Returns the list of revenue tokens in the vault - /// @return The list of revenue tokens - function getVaultTokens() external view returns (address[] memory) { - return _tokens.values(); - } - /// @notice A function to calculate the amount of revenue token claimable by a token holder at certain snapshot /// @param account The address of the token holder /// @param snapshotId The snapshot id /// @param token The revenue token to claim /// @return The amount of revenue token claimable function _claimableRevenue(address account, uint256 snapshotId, address token) internal view returns (uint256) { + IpRoyaltyVaultStorage storage $ = _getIpRoyaltyVaultStorage(); + // if the ip is tagged, then the unclaimed royalties are lost - if (DISPUTE_MODULE.isIpTagged(ipId)) return 0; + if (DISPUTE_MODULE.isIpTagged($.ipId)) return 0; uint256 balance = balanceOfAt(account, snapshotId); - uint256 totalSupply = totalSupplyAt(snapshotId) - unclaimedAtSnapshot[snapshotId]; - uint256 claimableToken = claimableAtSnapshot[snapshotId][token]; - return isClaimedAtSnapshot[snapshotId][account][token] ? 0 : (balance * claimableToken) / totalSupply; + uint256 totalSupply = totalSupplyAt(snapshotId) - $.unclaimedAtSnapshot[snapshotId]; + uint256 claimableToken = $.claimableAtSnapshot[snapshotId][token]; + return $.isClaimedAtSnapshot[snapshotId][account][token] ? 0 : (balance * claimableToken) / totalSupply; } /// @dev Collect the accrued tokens (if any) /// @param royaltyTokensToClaim The amount of royalty tokens being claimed by the ancestor /// @param ancestorIpId The ip id of the ancestor to whom the royalty tokens belong to function _collectAccruedTokens(uint256 royaltyTokensToClaim, address ancestorIpId) internal { - address[] memory tokens = _tokens.values(); + IpRoyaltyVaultStorage storage $ = _getIpRoyaltyVaultStorage(); + + address[] memory tokens = $.tokens.values(); for (uint256 i = 0; i < tokens.length; ++i) { // the only case in which unclaimedRoyaltyTokens can be 0 is when the vault is empty and everyone claimed // in which case the call will revert upstream with IpRoyaltyVault__AlreadyClaimed error - uint256 collectAmount = (ancestorsVaultAmount[tokens[i]] * royaltyTokensToClaim) / unclaimedRoyaltyTokens; + uint256 collectAmount = ($.ancestorsVaultAmount[tokens[i]] * royaltyTokensToClaim) / + $.unclaimedRoyaltyTokens; if (collectAmount == 0) continue; - ancestorsVaultAmount[tokens[i]] -= collectAmount; + $.ancestorsVaultAmount[tokens[i]] -= collectAmount; IERC20Upgradeable(tokens[i]).safeTransfer(ancestorIpId, collectAmount); emit RevenueTokenClaimed(ancestorIpId, tokens[i], collectAmount); } } + + /// @notice The ip id to whom this royalty vault belongs to + /// @return The ip id address + function ipId() external view returns (address) { + return _getIpRoyaltyVaultStorage().ipId; + } + + /// @notice The amount of unclaimed royalty tokens + function unclaimedRoyaltyTokens() external view returns (uint32) { + return _getIpRoyaltyVaultStorage().unclaimedRoyaltyTokens; + } + + /// @notice The last snapshotted timestamp + function lastSnapshotTimestamp() external view returns (uint256) { + return _getIpRoyaltyVaultStorage().lastSnapshotTimestamp; + } + + /// @notice The amount of revenue token in the ancestors vault + /// @param token The address of the revenue token + function ancestorsVaultAmount(address token) external view returns (uint256) { + return _getIpRoyaltyVaultStorage().ancestorsVaultAmount[token]; + } + + /// @notice Indicates whether the ancestor has collected the royalty tokens + /// @param ancestorIpId The ancestor ipId address + function isCollectedByAncestor(address ancestorIpId) external view returns (bool) { + return _getIpRoyaltyVaultStorage().isCollectedByAncestor[ancestorIpId]; + } + + /// @notice Amount of revenue token in the claim vault + /// @param token The address of the revenue token + function claimVaultAmount(address token) external view returns (uint256) { + return _getIpRoyaltyVaultStorage().claimVaultAmount[token]; + } + + /// @notice Amount of revenue token claimable at a given snapshot + /// @param snapshotId The snapshot id + /// @param token The address of the revenue token + function claimableAtSnapshot(uint256 snapshotId, address token) external view returns (uint256) { + return _getIpRoyaltyVaultStorage().claimableAtSnapshot[snapshotId][token]; + } + + /// @notice Amount of unclaimed revenue tokens at the snapshot + /// @param snapshotId The snapshot id + function unclaimedAtSnapshot(uint256 snapshotId) external view returns (uint32) { + return _getIpRoyaltyVaultStorage().unclaimedAtSnapshot[snapshotId]; + } + + /// @notice Indicates whether the claimer has claimed the revenue tokens at a given snapshot + /// @param snapshotId The snapshot id + /// @param claimer The address of the claimer + /// @param token The address of the revenue token + function isClaimedAtSnapshot(uint256 snapshotId, address claimer, address token) external view returns (bool) { + return _getIpRoyaltyVaultStorage().isClaimedAtSnapshot[snapshotId][claimer][token]; + } + + /// @notice The list of revenue tokens in the vault + function tokens() external view returns (address[] memory) { + return _getIpRoyaltyVaultStorage().tokens.values(); + } + + /// @dev Returns the storage struct of the IpRoyaltyVault + function _getIpRoyaltyVaultStorage() private pure returns (IpRoyaltyVaultStorage storage $) { + assembly { + $.slot := IpRoyaltyVaultStorageLocation + } + } } diff --git a/script/foundry/utils/upgrades/ERC7201Helper.s.sol b/script/foundry/utils/upgrades/ERC7201Helper.s.sol index 18923128..170f8137 100644 --- a/script/foundry/utils/upgrades/ERC7201Helper.s.sol +++ b/script/foundry/utils/upgrades/ERC7201Helper.s.sol @@ -12,7 +12,7 @@ import { console2 } from "forge-std/console2.sol"; contract ERC7201HelperScript is Script { string constant NAMESPACE = "story-protocol"; - string constant CONTRACT_NAME = "RoyaltyPolicyLAP"; + string constant CONTRACT_NAME = "IpRoyaltyVault"; function run() external { bytes memory erc7201Key = abi.encodePacked(NAMESPACE, ".", CONTRACT_NAME); diff --git a/test/foundry/modules/royalty/IpRoyaltyVault.t.sol b/test/foundry/modules/royalty/IpRoyaltyVault.t.sol index ed0e9952..4e3b91cd 100644 --- a/test/foundry/modules/royalty/IpRoyaltyVault.t.sol +++ b/test/foundry/modules/royalty/IpRoyaltyVault.t.sol @@ -141,7 +141,7 @@ contract TestIpRoyaltyVault is BaseTest { vm.startPrank(address(royaltyPolicyLAP)); ipRoyaltyVault.addIpRoyaltyVaultTokens(address(1)); - address[] memory tokens = ipRoyaltyVault.getVaultTokens(); + address[] memory tokens = ipRoyaltyVault.tokens(); assertEq(tokens.length, 1); assertEq(tokens[0], address(1)); @@ -350,7 +350,7 @@ contract TestIpRoyaltyVault is BaseTest { ipRoyaltyVault.snapshot(); // all USDC was claimed but LINK was not - assertEq(ipRoyaltyVault.getVaultTokens().length, 1); + assertEq(ipRoyaltyVault.tokens().length, 1); } function test_IpRoyaltyVault_CollectRoyaltyTokens_AlreadyClaimed() public { @@ -397,7 +397,7 @@ contract TestIpRoyaltyVault is BaseTest { assertEq(USDC.balanceOf(address(5)) - userUsdcBalanceBefore, accruedCollectableRevenue); assertEq(contractUsdcBalanceBefore - USDC.balanceOf(address(ipRoyaltyVault)), accruedCollectableRevenue); - assertEq(ipRoyaltyVault.isClaimedByAncestor(address(5)), true); + assertEq(ipRoyaltyVault.isCollectedByAncestor(address(5)), true); assertEq( contractRTBalBefore - IERC20(address(ipRoyaltyVault)).balanceOf(address(ipRoyaltyVault)), parentRoyalty