From 909425ea376ee23519936731b303c5b214815294 Mon Sep 17 00:00:00 2001 From: Oleksandr Date: Wed, 16 Oct 2024 00:32:46 +0300 Subject: [PATCH 1/9] base sc refactoring --- smart-contracts/.eslintrc.json | 22 + .../contracts/diamond/facets/Marketplace.sol | 121 ++- .../diamond/facets/ModelRegistry.sol | 82 +- .../diamond/facets/ProviderRegistry.sol | 154 ++-- .../diamond/facets/SessionRouter.sol | 627 ++++++------- .../diamond/presets/OwnableDiamondStorage.sol | 8 +- .../contracts/diamond/storages/BidStorage.sol | 78 +- .../diamond/storages/MarketplaceStorage.sol | 19 +- .../diamond/storages/ModelStorage.sol | 47 +- .../diamond/storages/ProviderStorage.sol | 38 +- .../diamond/storages/SessionStorage.sol | 148 ++-- .../diamond/storages/StatsStorage.sol | 38 +- .../interfaces/facets/IMarketplace.sol | 26 +- .../interfaces/facets/IModelRegistry.sol | 16 +- .../interfaces/facets/IProviderRegistry.sol | 32 +- .../interfaces/facets/ISessionRouter.sol | 116 +-- .../interfaces/storage/IBidStorage.sol | 18 +- .../storage/IMarketplaceStorage.sol | 2 + .../interfaces/storage/IModelStorage.sol | 16 +- .../interfaces/storage/IProviderStorage.sol | 16 +- .../interfaces/storage/ISessionStorage.sol | 46 +- .../interfaces/storage/IStatsStorage.sol | 22 +- .../contracts/mock/tokens/MorpheusToken.sol | 4 +- .../contracts/tokens/LumerinToken.sol | 4 +- smart-contracts/hardhat.config.ts | 2 +- smart-contracts/package.json | 2 +- .../test/diamond/LumerinDiamond.test.ts | 30 +- .../test/diamond/Marketplace.test.ts | 476 ---------- .../test/diamond/ModelRegistry.test.ts | 508 ----------- .../test/diamond/ProviderRegistry.test.ts | 444 ---------- .../SessionRouter/closeSession.test.ts | 512 ----------- .../diamond/SessionRouter/openSession.test.ts | 612 ------------- .../SessionRouter/readFunctions.test.ts | 341 ------- .../test/diamond/SessionRouter/stats.test.ts | 320 ------- .../diamond/SessionRouter/userOnHold.test.ts | 337 ------- .../SessionRouter/writeFunctions.test.ts | 348 -------- .../test/diamond/facets/Marketplace.test.ts | 269 ++++++ .../test/diamond/facets/ModelRegistry.test.ts | 224 +++++ .../diamond/facets/ProviderRegistry.test.ts | 202 +++++ .../test/diamond/facets/SessionRouter.test.ts | 830 ++++++++++++++++++ .../deployers/diamond/facets/marketplace.ts | 40 + .../diamond/facets/model-registry.ts | 27 + .../diamond/facets/provider-registry.ts | 27 + .../diamond/facets/session-router.ts | 51 ++ .../test/helpers/deployers/diamond/index.ts | 5 + .../deployers/diamond/lumerin-diamond.ts | 17 + .../test/helpers/deployers/index.ts | 2 + .../test/helpers/deployers/mock/index.ts | 1 + .../deployers/mock/tokens/morpheus-token.ts | 10 + smart-contracts/test/helpers/enums.ts | 5 - smart-contracts/test/helpers/pool-helper.ts | 10 +- smart-contracts/test/helpers/reverter.ts | 1 + smart-contracts/test/helpers/time.ts | 2 + smart-contracts/utils/provider-helper.ts | 23 +- smart-contracts/utils/time.ts | 10 +- smart-contracts/wagmi.config.ts | 14 +- 56 files changed, 2588 insertions(+), 4814 deletions(-) create mode 100644 smart-contracts/.eslintrc.json delete mode 100644 smart-contracts/test/diamond/Marketplace.test.ts delete mode 100644 smart-contracts/test/diamond/ModelRegistry.test.ts delete mode 100644 smart-contracts/test/diamond/ProviderRegistry.test.ts delete mode 100644 smart-contracts/test/diamond/SessionRouter/closeSession.test.ts delete mode 100644 smart-contracts/test/diamond/SessionRouter/openSession.test.ts delete mode 100644 smart-contracts/test/diamond/SessionRouter/readFunctions.test.ts delete mode 100644 smart-contracts/test/diamond/SessionRouter/stats.test.ts delete mode 100644 smart-contracts/test/diamond/SessionRouter/userOnHold.test.ts delete mode 100644 smart-contracts/test/diamond/SessionRouter/writeFunctions.test.ts create mode 100644 smart-contracts/test/diamond/facets/Marketplace.test.ts create mode 100644 smart-contracts/test/diamond/facets/ModelRegistry.test.ts create mode 100644 smart-contracts/test/diamond/facets/ProviderRegistry.test.ts create mode 100644 smart-contracts/test/diamond/facets/SessionRouter.test.ts create mode 100644 smart-contracts/test/helpers/deployers/diamond/facets/marketplace.ts create mode 100644 smart-contracts/test/helpers/deployers/diamond/facets/model-registry.ts create mode 100644 smart-contracts/test/helpers/deployers/diamond/facets/provider-registry.ts create mode 100644 smart-contracts/test/helpers/deployers/diamond/facets/session-router.ts create mode 100644 smart-contracts/test/helpers/deployers/diamond/index.ts create mode 100644 smart-contracts/test/helpers/deployers/diamond/lumerin-diamond.ts create mode 100644 smart-contracts/test/helpers/deployers/index.ts create mode 100644 smart-contracts/test/helpers/deployers/mock/index.ts create mode 100644 smart-contracts/test/helpers/deployers/mock/tokens/morpheus-token.ts delete mode 100644 smart-contracts/test/helpers/enums.ts create mode 100644 smart-contracts/test/helpers/time.ts diff --git a/smart-contracts/.eslintrc.json b/smart-contracts/.eslintrc.json new file mode 100644 index 00000000..402a5fba --- /dev/null +++ b/smart-contracts/.eslintrc.json @@ -0,0 +1,22 @@ +{ + "parser": "@typescript-eslint/parser", + "parserOptions": { + "ecmaVersion": 12, + "sourceType": "module", + "project": "./tsconfig.json" + }, + "plugins": ["@typescript-eslint"], + "extends": [ + "eslint:recommended", + "plugin:@typescript-eslint/recommended", + "plugin:prettier/recommended" + ], + "rules": { + "@typescript-eslint/no-unused-vars": "error", + "@typescript-eslint/no-floating-promises": "error" + }, + "env": { + "browser": true, + "es2021": true + } +} diff --git a/smart-contracts/contracts/diamond/facets/Marketplace.sol b/smart-contracts/contracts/diamond/facets/Marketplace.sol index 4e4a79d9..0a190a46 100644 --- a/smart-contracts/contracts/diamond/facets/Marketplace.sol +++ b/smart-contracts/contracts/diamond/facets/Marketplace.sol @@ -22,80 +22,48 @@ contract Marketplace is { using SafeERC20 for IERC20; - function __Marketplace_init( - address token_ - ) external initializer(MARKETPLACE_STORAGE_SLOT) initializer(BID_STORAGE_SLOT) { - _getBidStorage().token = IERC20(token_); + function __Marketplace_init(address token_) external initializer(MARKETPLACE_STORAGE_SLOT) { + setToken(IERC20(token_)); } - /// @notice sets a bid fee - function setBidFee(uint256 bidFee_) external onlyOwner { - _getMarketplaceStorage().bidFee = bidFee_; - emit FeeUpdated(bidFee_); - } - - /// @notice posts a new bid for a model - function postModelBid( - address provider_, - bytes32 modelId_, - uint256 pricePerSecond_ - ) external returns (bytes32 bidId) { - if (!_ownerOrProvider(provider_)) { - revert NotOwnerOrProvider(); - } - if (!isProviderActive(provider_)) { - revert ProviderNotFound(); - } - if (!isModelActive(modelId_)) { - revert ModelNotFound(); - } + function setMarketplaceBidFee(uint256 bidFee_) external onlyOwner { + setBidFee(bidFee_); - return _postModelBid(provider_, modelId_, pricePerSecond_); + emit MaretplaceFeeUpdated(bidFee_); } - /// @notice deletes a bid - function deleteModelBid(bytes32 bidId_) external { - if (!_isBidActive(bidId_)) { - revert ActiveBidNotFound(); - } - if (!_ownerOrProvider(getBid(bidId_).provider)) { - revert NotOwnerOrProvider(); - } - - _deleteBid(bidId_); - } + function postModelBid(bytes32 modelId_, uint256 pricePerSecond_) external returns (bytes32 bidId) { + address provider_ = _msgSender(); - /// @notice withdraws the fee balance - function withdraw(address recipient_, uint256 amount_) external onlyOwner { - if (amount_ > getFeeBalance()) { - revert NotEnoughBalance(); + if (!getIsProviderActive(provider_)) { + revert MarketplaceProviderNotFound(); + } + if (!getIsModelActive(modelId_)) { + revert MarketplaceModelNotFound(); } - decreaseFeeBalance(amount_); - getToken().safeTransfer(recipient_, amount_); - } - - function _incrementBidNonce(address provider_, bytes32 modelId_) private returns (uint256) { - return _incrementBidNonce(getProviderModelId(provider_, modelId_)); - } - - function _postModelBid(address provider_, bytes32 modelId_, uint256 pricePerSecond_) private returns (bytes32) { uint256 fee_ = getBidFee(); getToken().safeTransferFrom(_msgSender(), address(this), fee_); - increaseFeeBalance(fee_); - // TEST IT if it increments nonce correctly - uint256 nonce_ = _incrementBidNonce(provider_, modelId_); - if (nonce_ != 0) { - bytes32 oldBidId_ = getBidId(provider_, modelId_, nonce_ - 1); - if (_isBidActive(oldBidId_)) { + setFeeBalance(getFeeBalance() + fee_); + + bytes32 providerModelId_ = getProviderModelId(provider_, modelId_); + uint256 providerModelNonce_ = incrementBidNonce(providerModelId_); + bytes32 bidId_ = getBidId(provider_, modelId_, providerModelNonce_); + + if (providerModelNonce_ != 0) { + bytes32 oldBidId_ = getBidId(provider_, modelId_, providerModelNonce_ - 1); + if (isBidActive(oldBidId_)) { _deleteBid(oldBidId_); } } - bytes32 bidId_ = getBidId(provider_, modelId_, nonce_); - - setBid(bidId_, Bid(provider_, modelId_, pricePerSecond_, nonce_, uint128(block.timestamp), 0)); + Bid storage bid = bids(bidId_); + bid.provider = provider_; + bid.modelId = modelId_; + bid.pricePerSecond = pricePerSecond_; + bid.nonce = providerModelNonce_; + bid.createdAt = uint128(block.timestamp); addProviderBid(provider_, bidId_); addModelBid(modelId_, bidId_); @@ -103,20 +71,39 @@ contract Marketplace is addProviderActiveBids(provider_, bidId_); addModelActiveBids(modelId_, bidId_); - emit BidPosted(provider_, modelId_, nonce_); + emit MarketplaceBidPosted(provider_, modelId_, providerModelNonce_); return bidId_; } + function deleteModelBid(bytes32 bidId_) external { + _onlyAccount(bids(bidId_).provider); + + if (!isBidActive(bidId_)) { + revert MarketplaceActiveBidNotFound(); + } + + _deleteBid(bidId_); + } + + function withdraw(address recipient_, uint256 amount_) external onlyOwner { + uint256 feeBalance_ = getFeeBalance(); + amount_ = amount_ > feeBalance_ ? feeBalance_ : amount_; + + setFeeBalance(getFeeBalance() - amount_); + + getToken().safeTransfer(recipient_, amount_); + } + function _deleteBid(bytes32 bidId_) private { - Bid storage bid = getBid(bidId_); + Bid storage bid = bids(bidId_); bid.deletedAt = uint128(block.timestamp); removeProviderActiveBids(bid.provider, bidId_); removeModelActiveBids(bid.modelId, bidId_); - emit BidDeleted(bid.provider, bid.modelId, bid.nonce); + emit MarketplaceBidDeleted(bid.provider, bid.modelId, bid.nonce); } function getBidId(address provider_, bytes32 modelId_, uint256 nonce_) public pure returns (bytes32) { @@ -126,14 +113,4 @@ contract Marketplace is function getProviderModelId(address provider_, bytes32 modelId_) public pure returns (bytes32) { return keccak256(abi.encodePacked(provider_, modelId_)); } - - function _ownerOrProvider(address provider_) private view returns (bool) { - return _msgSender() == owner() || _msgSender() == provider_; - } - - function _isBidActive(bytes32 bidId_) private view returns (bool) { - Bid memory bid_ = getBid(bidId_); - - return bid_.createdAt != 0 && bid_.deletedAt == 0; - } } diff --git a/smart-contracts/contracts/diamond/facets/ModelRegistry.sol b/smart-contracts/contracts/diamond/facets/ModelRegistry.sol index ec2037c2..ae0c8be5 100644 --- a/smart-contracts/contracts/diamond/facets/ModelRegistry.sol +++ b/smart-contracts/contracts/diamond/facets/ModelRegistry.sol @@ -15,87 +15,69 @@ contract ModelRegistry is IModelRegistry, OwnableDiamondStorage, ModelStorage, B function __ModelRegistry_init() external initializer(MODEL_STORAGE_SLOT) {} - function setModelMinimumStake(uint256 modelMinimumStake_) external onlyOwner { - _setModelMinimumStake(modelMinimumStake_); - emit ModelMinimumStakeSet(modelMinimumStake_); + function modelSetMinStake(uint256 modelMinimumStake_) external onlyOwner { + setModelMinimumStake(modelMinimumStake_); + + emit ModelMinStakeUpdated(modelMinimumStake_); } - /// @notice Registers or updates existing model function modelRegister( - // TODO: it is not secure (frontrunning) to take the modelId as key bytes32 modelId_, bytes32 ipfsCID_, uint256 fee_, - uint256 addStake_, - address owner_, + uint256 amount_, string calldata name_, - string[] calldata tags_ + string[] memory tags_ ) external { - if (!_isOwnerOrModelOwner(owner_)) { - revert NotOwnerOrModelOwner(); - } + Model storage model = models(modelId_); - Model memory model_ = models(modelId_); - // TODO: there is no way to decrease the stake - uint256 newStake_ = model_.stake + addStake_; - if (newStake_ < modelMinimumStake()) { - revert StakeTooLow(); + uint256 newStake_ = model.stake + amount_; + uint256 minStake_ = getModelMinimumStake(); + if (newStake_ < minStake_) { + revert ModelStakeTooLow(newStake_, minStake_); } - if (addStake_ > 0) { - getToken().safeTransferFrom(_msgSender(), address(this), addStake_); + if (amount_ > 0) { + getToken().safeTransferFrom(_msgSender(), address(this), amount_); } - uint128 createdAt_ = model_.createdAt; - if (createdAt_ == 0) { - // model never existed - addModel(modelId_); - setModelActive(modelId_, true); - createdAt_ = uint128(block.timestamp); + if (model.createdAt == 0) { + addModelId(modelId_); + + model.createdAt = uint128(block.timestamp); + model.owner = _msgSender(); } else { - if (!_isOwnerOrModelOwner(model_.owner)) { - revert NotOwnerOrModelOwner(); - } - if (model_.isDeleted) { - setModelActive(modelId_, true); - } + _onlyAccount(model.owner); } - setModel(modelId_, Model(ipfsCID_, fee_, newStake_, owner_, name_, tags_, createdAt_, false)); + model.stake = newStake_; + model.ipfsCID = ipfsCID_; + model.fee = fee_; // TODO: validate fee and get usage places + model.name = name_; + model.tags = tags_; + model.isDeleted = false; - emit ModelRegisteredUpdated(owner_, modelId_); + emit ModelRegisteredUpdated(_msgSender(), modelId_); } function modelDeregister(bytes32 modelId_) external { Model storage model = models(modelId_); - if (!isModelExists(modelId_)) { - revert ModelNotFound(); - } - if (!_isOwnerOrModelOwner(model.owner)) { - revert NotOwnerOrModelOwner(); - } + _onlyAccount(model.owner); if (!isModelActiveBidsEmpty(modelId_)) { revert ModelHasActiveBids(); } + if (model.isDeleted) { + revert ModelHasAlreadyDeregistered(); + } - uint256 stake_ = model.stake; + uint256 withdrawAmount_ = model.stake; model.stake = 0; model.isDeleted = true; - setModelActive(modelId_, false); - - getToken().safeTransfer(model.owner, stake_); + getToken().safeTransfer(model.owner, withdrawAmount_); emit ModelDeregistered(model.owner, modelId_); } - - function isModelExists(bytes32 modelId_) public view returns (bool) { - return models(modelId_).createdAt != 0; - } - - function _isOwnerOrModelOwner(address modelOwner_) internal view returns (bool) { - return _msgSender() == owner() || _msgSender() == modelOwner_; - } } diff --git a/smart-contracts/contracts/diamond/facets/ProviderRegistry.sol b/smart-contracts/contracts/diamond/facets/ProviderRegistry.sol index 8ff515a7..996a0afd 100644 --- a/smart-contracts/contracts/diamond/facets/ProviderRegistry.sol +++ b/smart-contracts/contracts/diamond/facets/ProviderRegistry.sol @@ -15,120 +15,102 @@ contract ProviderRegistry is IProviderRegistry, OwnableDiamondStorage, ProviderS function __ProviderRegistry_init() external initializer(PROVIDER_STORAGE_SLOT) {} - /// @notice Sets the minimum stake required for a provider function providerSetMinStake(uint256 providerMinimumStake_) external onlyOwner { setProviderMinimumStake(providerMinimumStake_); + emit ProviderMinStakeUpdated(providerMinimumStake_); } - /// @notice Registers a provider - /// @param providerAddress_ provider address - /// @param amount_ amount of stake to add - /// @param endpoint_ provider endpoint (host.com:1234) - function providerRegister(address providerAddress_, uint256 amount_, string calldata endpoint_) external { - if (!_ownerOrProvider(providerAddress_)) { - // TODO: such that we cannon create a provider with the owner as another address - // Do we need this check? - revert NotOwnerOrProvider(); + function providerRegister(uint256 amount_, string calldata endpoint_) external { + if (amount_ > 0) { + getToken().safeTransferFrom(_msgSender(), address(this), amount_); } - Provider memory provider_ = providers(providerAddress_); - uint256 newStake_ = provider_.stake + amount_; - if (newStake_ < providerMinimumStake()) { - revert StakeTooLow(); - } + Provider storage provider = providers(_msgSender()); - if (amount_ > 0) { - getToken().safeTransferFrom(_msgSender(), address(this), amount_); + uint256 newStake_ = provider.stake + amount_; + uint256 minStake_ = getProviderMinimumStake(); + if (newStake_ < minStake_) { + revert ProviderStakeTooLow(newStake_, minStake_); } - // if we add stake to an existing provider the limiter period is not reset - uint128 createdAt_ = provider_.createdAt; - uint128 periodEnd_ = provider_.limitPeriodEnd; - if (createdAt_ == 0) { - setProviderActive(providerAddress_, true); - createdAt_ = uint128(block.timestamp); - periodEnd_ = createdAt_ + PROVIDER_REWARD_LIMITER_PERIOD; - } else if (provider_.isDeleted) { - setProviderActive(providerAddress_, true); + if (provider.createdAt == 0) { + provider.endpoint = endpoint_; + provider.createdAt = uint128(block.timestamp); + provider.limitPeriodEnd = uint128(block.timestamp) + PROVIDER_REWARD_LIMITER_PERIOD; + } else if (provider.isDeleted) { + provider.isDeleted = false; } - setProvider( - providerAddress_, - Provider(endpoint_, newStake_, createdAt_, periodEnd_, provider_.limitPeriodEarned, false) - ); + provider.endpoint = endpoint_; + provider.stake = newStake_; - emit ProviderRegisteredUpdated(providerAddress_); + emit ProviderRegistered(_msgSender()); } - /// @notice Deregisters a provider - function providerDeregister(address provider_) external { - if (!_ownerOrProvider(provider_)) { - revert NotOwnerOrProvider(); - } - if (!isProviderExists(provider_)) { + function providerDeregister() external { + Provider storage provider = providers(_msgSender()); + + if (provider.createdAt == 0) { revert ProviderNotFound(); } - if (!isProviderActiveBidsEmpty(provider_)) { + if (!isProviderActiveBidsEmpty(_msgSender())) { revert ProviderHasActiveBids(); } - - setProviderActive(provider_, false); - - Provider storage provider = providers(provider_); - uint256 withdrawable_ = _getWithdrawableStake(provider); - - provider.stake -= withdrawable_; - provider.isDeleted = true; - - if (withdrawable_ > 0) { - getToken().safeTransfer(_msgSender(), withdrawable_); + if (provider.isDeleted) { + revert ProviderHasAlreadyDeregistered(); } - emit ProviderDeregistered(provider_); - } + uint256 withdrawAmount_ = _getWithdrawAmount(provider); - /// @notice Withdraws stake from a provider after it has been deregistered - /// Allows to withdraw the stake after provider reward period has ended - function providerWithdrawStake(address provider_) external { - Provider storage provider = providers(provider_); - if (!provider.isDeleted) { - revert ErrProviderNotDeleted(); - } - if (provider.stake == 0) { - revert ErrNoStake(); - } + provider.stake -= withdrawAmount_; + provider.isDeleted = true; - uint256 withdrawable_ = _getWithdrawableStake(provider); - if (withdrawable_ == 0) { - revert ErrNoWithdrawableStake(); + if (withdrawAmount_ > 0) { + getToken().safeTransfer(_msgSender(), withdrawAmount_); } - provider.stake -= withdrawable_; - - getToken().safeTransfer(provider_, withdrawable_); - - emit ProviderWithdrawnStake(provider_, withdrawable_); - } - - function isProviderExists(address provider_) public view returns (bool) { - return providers(provider_).createdAt != 0; + emit ProviderDeregistered(_msgSender()); } - /// @notice Returns the withdrawable stake for a provider - /// @dev If the provider already earned this period then withdrawable stake - /// is limited by the amount earning that remains in the current period. - /// It is done to prevent the provider from withdrawing and then staking - /// again from a different address, which bypasses the limitation. - function _getWithdrawableStake(Provider memory provider_) private view returns (uint256) { - if (uint128(block.timestamp) > provider_.limitPeriodEnd) { - return provider_.stake; + // /** + // * + // * @notice Withdraws stake from a provider after it has been deregistered + // * Allows to withdraw the stake after provider reward period has ended + // */ + // function providerWithdrawStake() external { + // Provider storage provider = providers(_msgSender()); + + // if (!provider.isDeleted) { + // revert ProviderNotDeregistered(); + // } + // if (provider.stake == 0) { + // revert ProviderNoStake(); + // } + + // uint256 withdrawAmount_ = _getWithdrawAmount(provider); + // if (withdrawAmount_ == 0) { + // revert ProviderNothingToWithdraw(); + // } + + // provider.stake -= withdrawAmount_; + // getToken().safeTransfer(_msgSender(), withdrawAmount_); + + // emit ProviderWithdrawn(_msgSender(), withdrawAmount_); + // } + + /** + * @notice Returns the withdrawable stake for a provider + * @dev If the provider already earned this period then withdrawable stake + * is limited by the amount earning that remains in the current period. + * It is done to prevent the provider from withdrawing and then staking + * again from a different address, which bypasses the limitation. + */ + function _getWithdrawAmount(Provider storage provider) private view returns (uint256) { + if (block.timestamp > provider.limitPeriodEnd) { + return provider.stake; } - return provider_.stake - provider_.limitPeriodEarned; - } - - function _ownerOrProvider(address provider_) internal view returns (bool) { - return _msgSender() == owner() || _msgSender() == provider_; + return provider.stake - provider.limitPeriodEarned; } } diff --git a/smart-contracts/contracts/diamond/facets/SessionRouter.sol b/smart-contracts/contracts/diamond/facets/SessionRouter.sol index 9f9b1ed4..331c6beb 100644 --- a/smart-contracts/contracts/diamond/facets/SessionRouter.sol +++ b/smart-contracts/contracts/diamond/facets/SessionRouter.sol @@ -5,6 +5,8 @@ import {Math} from "@openzeppelin/contracts/utils/math/Math.sol"; import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; import {SafeERC20, IERC20} from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; +import {LinearDistributionIntervalDecrease} from "morpheus-smart-contracts/contracts/libs/LinearDistributionIntervalDecrease.sol"; + import {OwnableDiamondStorage} from "../presets/OwnableDiamondStorage.sol"; import {BidStorage, EnumerableSet} from "../storages/BidStorage.sol"; @@ -16,8 +18,6 @@ import {LibSD} from "../../libs/LibSD.sol"; import {ISessionRouter} from "../../interfaces/facets/ISessionRouter.sol"; -import {LinearDistributionIntervalDecrease} from "morpheus-smart-contracts/contracts/libs/LinearDistributionIntervalDecrease.sol"; - contract SessionRouter is ISessionRouter, OwnableDiamondStorage, @@ -26,376 +26,358 @@ contract SessionRouter is BidStorage, StatsStorage { - using Math for uint256; + using Math for *; using LibSD for LibSD.SD; using SafeERC20 for IERC20; - using EnumerableSet for EnumerableSet.Bytes32Set; - - uint32 public constant MIN_SESSION_DURATION = 5 minutes; - uint32 public constant MAX_SESSION_DURATION = 1 days; - uint32 public constant SIGNATURE_TTL = 10 minutes; - uint256 public constant COMPUTE_POOL_INDEX = 3; function __SessionRouter_init( address fundingAccount_, Pool[] calldata pools_ ) external initializer(SESSION_STORAGE_SLOT) { - SNStorage storage s = _getSessionStorage(); - - s.fundingAccount = fundingAccount_; + setFundingAccount(fundingAccount_); + setPools(pools_); + } - for (uint256 i = 0; i < pools_.length; i++) { - s.pools.push(pools_[i]); - } + //////////////////////////// + /// CONTRACT CONFIGS /// + //////////////////////////// + /** + * @notice Sets distibution pool configuration + * @dev parameters should be the same as in Ethereum L1 Distribution contract + * @dev at address 0x47176B2Af9885dC6C4575d4eFd63895f7Aaa4790 + * @dev call 'Distribution.pools(3)' where '3' is a poolId + */ + function setPoolConfig(uint256 index_, Pool calldata pool_) external onlyOwner { + if (index_ >= pools().length) { + revert SessionPoolIndexOutOfBounds(); + } + + setPool(index_, pool_); } + //////////////////////// + /// OPEN SESSION /// + //////////////////////// function openSession( uint256 amount_, - bytes calldata providerApproval_, + bytes calldata approvalEncoded_, bytes calldata signature_ ) external returns (bytes32) { - // should a user pass the bidId to compare with a providerApproval? - bytes32 bidId_ = _extractProviderApproval(providerApproval_); - - Bid memory bid_ = getBid(bidId_); - if (bid_.deletedAt != 0 || bid_.createdAt == 0) { - // wtf? - revert BidNotFound(); + bytes32 bidId_ = _extractProviderApproval(approvalEncoded_); + if (!isBidActive(bidId_)) { + revert SessionBidNotFound(); } - if (!_isValidReceipt(bid_.provider, providerApproval_, signature_)) { - revert ProviderSignatureMismatch(); + + Bid storage bid = bids(bidId_); + if (!_isValidProviderReceipt(bid.provider, approvalEncoded_, signature_)) { + revert SessionProviderSignatureMismatch(); } - if (isApprovalUsed(providerApproval_)) { - revert DuplicateApproval(); + if (getIsProviderApprovalUsed(approvalEncoded_)) { + revert SessionDuplicateApproval(); } - setApprovalUsed(providerApproval_); - uint256 endsAt_ = whenSessionEnds(amount_, bid_.pricePerSecond, block.timestamp); + uint128 endsAt_ = getSessionEnd(amount_, bid.pricePerSecond, uint128(block.timestamp)); + bytes32 sessionId_ = getSessionId(_msgSender(), bid.provider, bidId_, incrementSessionNonce()); + if (endsAt_ - block.timestamp < MIN_SESSION_DURATION) { revert SessionTooShort(); } - // do we need to specify the amount in id? - bytes32 sessionId_ = getSessionId(_msgSender(), bid_.provider, amount_, incrementSessionNonce()); - setSession( - sessionId_, - Session({ - id: sessionId_, - user: _msgSender(), - provider: bid_.provider, - modelId: bid_.modelId, - bidId: bidId_, - stake: amount_, - pricePerSecond: bid_.pricePerSecond, - closeoutReceipt: "", - closeoutType: 0, - providerWithdrawnAmount: 0, - openedAt: block.timestamp, - endsAt: endsAt_, - closedAt: 0 - }) - ); - - addUserSessionId(_msgSender(), sessionId_); - addProviderSessionId(bid_.provider, sessionId_); - addModelSessionId(bid_.modelId, sessionId_); + Session storage session = sessions(sessionId_); - setUserSessionActive(_msgSender(), sessionId_, true); - setProviderSessionActive(bid_.provider, sessionId_, true); + session.user = _msgSender(); + session.stake = amount_; + session.bidId = bidId_; + session.openedAt = uint128(block.timestamp); + session.endsAt = endsAt_; + session.isActive = true; - // try to use locked stake first, but limit iterations to 20 - // if user has more than 20 onHold entries, they will have to use withdrawUserStake separately - amount_ -= _removeUserStake(amount_, 10); + addUserSessionId(_msgSender(), sessionId_); + addProviderSessionId(bid.provider, sessionId_); + addModelSessionId(bid.modelId, sessionId_); + setIsProviderApprovalUsed(approvalEncoded_, true); getToken().safeTransferFrom(_msgSender(), address(this), amount_); - emit SessionOpened(_msgSender(), sessionId_, bid_.provider); + emit SessionOpened(_msgSender(), sessionId_, bid.provider); return sessionId_; } - function closeSession(bytes calldata receiptEncoded_, bytes calldata signature_) external { - (bytes32 sessionId_, uint32 tpsScaled1000_, uint32 ttftMs_) = _extractReceipt(receiptEncoded_); + function getSessionId( + address user_, + address provider_, + bytes32 bidId_, + uint256 sessionNonce_ + ) public pure returns (bytes32) { + return keccak256(abi.encodePacked(user_, provider_, bidId_, sessionNonce_)); + } + + function getSessionEnd(uint256 amount_, uint256 pricePerSecond_, uint128 openedAt_) public view returns (uint128) { + uint128 duration_ = uint128(stakeToStipend(amount_, openedAt_) / pricePerSecond_); + + if (duration_ > MAX_SESSION_DURATION) { + duration_ = MAX_SESSION_DURATION; + } + + return openedAt_ + duration_; + } + + /** + * @dev Returns stipend of user based on their stake + * (User session stake amount / MOR Supply without Compute) * (MOR Compute Supply / 100) + */ + function stakeToStipend(uint256 amount_, uint128 timestamp_) public view returns (uint256) { + uint256 totalMorSupply_ = totalMORSupply(timestamp_); + if (totalMorSupply_ == 0) { + return 0; + } + + return (amount_ * getComputeBalance(timestamp_)) / (totalMORSupply(timestamp_) * 100); + } - Session storage session = _getSession(sessionId_); - if (session.openedAt == 0) { - revert SessionNotFound(); + function _extractProviderApproval(bytes calldata providerApproval_) private view returns (bytes32) { + (bytes32 bidId_, uint256 chainId_, address user_, uint128 timestamp_) = abi.decode( + providerApproval_, + (bytes32, uint256, address, uint128) + ); + + if (user_ != _msgSender()) { + revert SessionApprovedForAnotherUser(); } - if (!_ownerOrUser(session.user)) { - revert NotOwnerOrUser(); + if (chainId_ != block.chainid) { + revert SesssionApprovedForAnotherChainId(); } + if (block.timestamp > timestamp_ + SIGNATURE_TTL) { + revert SesssionApproveExpired(); + } + + return bidId_; + } + + /////////////////////////////////// + /// CLOSE SESSION, WITHDRAW /// + /////////////////////////////////// + function closeSession(bytes calldata receiptEncoded_, bytes calldata signature_) external { + (bytes32 sessionId_, uint32 tpsScaled1000_, uint32 ttftMs_) = _extractProviderReceipt(receiptEncoded_); + + Session storage session = sessions(sessionId_); + Bid storage bid = bids(session.bidId); + + _onlyAccount(session.user); if (session.closedAt != 0) { revert SessionAlreadyClosed(); } - // update indexes - setUserSessionActive(session.user, sessionId_, false); - setProviderSessionActive(session.provider, sessionId_, false); - - // update session record - session.closeoutReceipt = receiptEncoded_; //TODO: remove that field in favor of tps and ttftMs - session.closedAt = block.timestamp; + session.isActive = false; + session.closeoutReceipt = receiptEncoded_; // TODO: Remove that field in favor of tps and ttftMs + session.closedAt = uint128(block.timestamp); - // calculate provider withdraw - uint256 providerWithdraw_; - uint256 startOfToday_ = startOfTheDay(block.timestamp); + //// PROVIDER REWARDS + uint128 startOfToday_ = startOfTheDay(uint128(block.timestamp)); + // The session should be closed the day after the end of the session to prevent provider rewards locking bool isClosingLate_ = startOfToday_ > startOfTheDay(session.endsAt); - bool noDispute_ = _isValidReceipt(session.provider, receiptEncoded_, signature_); + bool noDispute_ = _isValidProviderReceipt(bid.provider, receiptEncoded_, signature_); + uint128 duration_; if (noDispute_ || isClosingLate_) { - // session was closed without dispute or next day after it expected to end - uint256 duration_ = session.endsAt.min(block.timestamp) - session.openedAt; - uint256 cost_ = duration_ * session.pricePerSecond; - providerWithdraw_ = cost_ - session.providerWithdrawnAmount; + // Session was closed without dispute or next day after it expected to end + duration_ = uint128(session.endsAt.min(session.closedAt)) - session.openedAt; } else { - // session was closed on the same day or earlier with dispute + // Session was closed on the same day or earlier with dispute // withdraw all funds except for today's session cost - uint256 durationTillToday_ = startOfToday_ - session.openedAt.min(startOfToday_); - uint256 costTillToday_ = durationTillToday_ * session.pricePerSecond; - providerWithdraw_ = costTillToday_ - session.providerWithdrawnAmount; + duration_ = startOfToday_ - uint128(session.openedAt.min(uint256(startOfToday_))); + } + uint256 providerAmountToWithdraw_ = (duration_ * bid.pricePerSecond) - session.providerWithdrawnAmount; + _claimForProvider(session, providerAmountToWithdraw_); + //// END + + //// USER REWARDS + // We have to lock today's stake so the user won't get the reward twice + uint256 userStakeToLock_ = 0; + if (!isClosingLate_) { + // Session was closed on the same day, lock today's stake + uint256 userDuration_ = session.endsAt.min(session.closedAt) - session.openedAt.max(startOfToday_); + uint256 userInitialLock_ = userDuration_ * bid.pricePerSecond; + userStakeToLock_ = session.stake.min(stipendToStake(userInitialLock_, startOfToday_)); + + addUserStakeOnHold(session.user, OnHold(userStakeToLock_, uint128(startOfToday_ + 1 days))); } + uint256 userAmountToWithdraw_ = session.stake - userStakeToLock_; + getToken().safeTransfer(session.user, userAmountToWithdraw_); + //// END - // updating provider stats - ProviderModelStats storage prStats = _getProviderModelStats(session.modelId, session.provider); - ModelStats storage modelStats = _getModelStats(session.modelId); + //// STATS + ProviderModelStats storage prStats = providerModelStats(bid.modelId, bid.provider); + ModelStats storage modelStats = modelStats(bid.modelId); prStats.totalCount++; if (noDispute_) { if (prStats.successCount > 0) { - // stats for this provider-model pair already contribute to average model stats + // Stats for this provider-model pair already contribute to average model stats modelStats.tpsScaled1000.remove(int32(prStats.tpsScaled1000.mean), int32(modelStats.count - 1)); modelStats.ttftMs.remove(int32(prStats.ttftMs.mean), int32(modelStats.count - 1)); } else { - // stats for this provider-model pair do not contribute + // Stats for this provider-model pair do not contribute modelStats.count++; } - // update provider-model stats + // Update provider model stats prStats.successCount++; prStats.totalDuration += uint32(session.closedAt - session.openedAt); prStats.tpsScaled1000.add(int32(tpsScaled1000_), int32(prStats.successCount)); prStats.ttftMs.add(int32(ttftMs_), int32(prStats.successCount)); - // update model stats + // Update model stats modelStats.totalDuration.add(int32(prStats.totalDuration), int32(modelStats.count)); modelStats.tpsScaled1000.add(int32(prStats.tpsScaled1000.mean), int32(modelStats.count)); modelStats.ttftMs.add(int32(prStats.ttftMs.mean), int32(modelStats.count)); } else { session.closeoutType = 1; } + //// END - // we have to lock today's stake so the user won't get the reward twice - uint256 userStakeToLock_ = 0; - if (!isClosingLate_) { - // session was closed on the same day - // lock today's stake - uint256 todaysDuration_ = session.endsAt.min(block.timestamp) - session.openedAt.max(startOfToday_); - uint256 todaysCost_ = todaysDuration_ * session.pricePerSecond; - userStakeToLock_ = session.stake.min(stipendToStake(todaysCost_, startOfToday_)); - addOnHold(session.user, OnHold(userStakeToLock_, uint128(startOfToday_ + 1 days))); - } - uint256 userWithdraw_ = session.stake - userStakeToLock_; - - emit SessionClosed(session.user, sessionId_, session.provider); - - // withdraw provider - _rewardProvider(session, providerWithdraw_, false); - - getToken().safeTransfer(session.user, userWithdraw_); + emit SessionClosed(session.user, sessionId_, bid.provider); } - /// @notice allows provider to claim their funds - function claimProviderBalance(bytes32 sessionId_, uint256 amountToWithdraw_) external { - Session storage session = _getSession(sessionId_); - if (session.openedAt == 0) { - revert SessionNotFound(); - } - if (!_ownerOrProvider(session.provider)) { - revert NotOwnerOrProvider(); - } + function _extractProviderReceipt(bytes calldata receiptEncoded_) private view returns (bytes32, uint32, uint32) { + (bytes32 sessionId_, uint256 chainId_, uint128 timestamp_, uint32 tpsScaled1000_, uint32 ttftMs_) = abi.decode( + receiptEncoded_, + (bytes32, uint256, uint128, uint32, uint32) + ); - uint256 withdrawableAmount = _getProviderClaimableBalance(session); - if (amountToWithdraw_ > withdrawableAmount) { - revert NotEnoughWithdrawableBalance(); + if (chainId_ != block.chainid) { + revert SesssionReceiptForAnotherChainId(); + } + if (block.timestamp > timestamp_ + SIGNATURE_TTL) { + revert SesssionReceiptExpired(); } - _rewardProvider(session, amountToWithdraw_, true); + return (sessionId_, tpsScaled1000_, ttftMs_); } - /// @notice deletes session from the history - function deleteHistory(bytes32 sessionId_) external { - // Why do we need this function? - Session storage session = _getSession(sessionId_); - if (!_ownerOrUser(session.user)) { - revert NotOwnerOrUser(); - } - if (session.closedAt == 0) { - revert SessionNotClosed(); - } + /** + * @dev Allows providers to receive their funds after the end or closure of the session + */ + function claimForProvider(bytes32 sessionId_) external { + Session storage session = sessions(sessionId_); + Bid storage bid = bids(session.bidId); - session.user = address(0); - } + _onlyAccount(bid.provider); - /// @notice withdraws user stake - /// @param amountToWithdraw_ amount of funds to withdraw, maxUint256 means all available - /// @param iterations_ number of entries to process - function withdrawUserStake(uint256 amountToWithdraw_, uint8 iterations_) external { - // withdraw all available funds if amountToWithdraw is 0 - if (amountToWithdraw_ == 0) { - revert AmountToWithdrawIsZero(); + uint256 sessionEnd_ = session.closedAt == 0 ? session.endsAt : session.closedAt; + if (sessionEnd_ > block.timestamp) { + revert SessionNotEndedOrNotExist(); } - uint256 removed_ = _removeUserStake(amountToWithdraw_, iterations_); - if (removed_ < amountToWithdraw_) { - revert NotEnoughWithdrawableBalance(); - } + uint256 amount_ = (sessionEnd_ - session.openedAt) * bid.pricePerSecond - session.providerWithdrawnAmount; - getToken().safeTransfer(_msgSender(), amountToWithdraw_); + _claimForProvider(session, amount_); } - /// @dev removes user stake amount from onHold entries - function _removeUserStake(uint256 amountToRemove_, uint8 iterations_) private returns (uint256) { - uint256 balance_ = 0; - - OnHold[] storage onHoldEntries = getOnHold(_msgSender()); - iterations_ = iterations_ > onHoldEntries.length ? uint8(onHoldEntries.length) : iterations_; - - // the only loop that is not avoidable - for (uint256 i = 0; i < onHoldEntries.length && iterations_ > 0; i++) { - if (block.timestamp < onHoldEntries[i].releaseAt) { - continue; - } - - balance_ += onHoldEntries[i].amount; - - if (balance_ >= amountToRemove_) { - onHoldEntries[i].amount = balance_ - amountToRemove_; - return amountToRemove_; - } - - // Remove entry by swapping with last element and popping - uint256 lastIndex_ = onHoldEntries.length - 1; - if (i < lastIndex_) { - onHoldEntries[i] = onHoldEntries[lastIndex_]; - i--; // TODO: is it correct? - } - onHoldEntries.pop(); + /** + * @dev Sends provider reward considering stake as the limit for the reward + * @param session Storage session object + * @param amount_ Amount of reward to send + */ + function _claimForProvider(Session storage session, uint256 amount_) private { + Bid storage bid = bids(session.bidId); + Provider storage provider = providers(bid.provider); - iterations_--; + if (block.timestamp > provider.limitPeriodEnd) { + provider.limitPeriodEnd = uint128(block.timestamp) + PROVIDER_REWARD_LIMITER_PERIOD; + provider.limitPeriodEarned = 0; } - return balance_; - } + uint256 providerClaimLimit_ = provider.stake - provider.limitPeriodEarned; - ///////////////////////// - // STATS FUNCTIONS // - ///////////////////////// - - /// @notice sets distibution pool configuration - /// @dev parameters should be the same as in Ethereum L1 Distribution contract - /// @dev at address 0x47176B2Af9885dC6C4575d4eFd63895f7Aaa4790 - /// @dev call 'Distribution.pools(3)' where '3' is a poolId - function setPoolConfig(uint256 index, Pool calldata pool) external onlyOwner { - if (index >= getPools().length) { - revert PoolIndexOutOfBounds(); - } - _getSessionStorage().pools[index] = pool; - } + amount_ = amount_.min(providerClaimLimit_); - function _maybeResetProviderRewardLimiter(Provider storage provider) private { - if (block.timestamp > provider.limitPeriodEnd) { - provider.limitPeriodEnd += PROVIDER_REWARD_LIMITER_PERIOD; - provider.limitPeriodEarned = 0; + if (amount_ == 0) { + return; } - } - /// @notice sends provider reward considering stake as the limit for the reward - /// @param session session storage object - /// @param reward_ amount of reward to send - /// @param revertOnReachingLimit_ if true function will revert if reward is more than stake, otherwise just limit the reward - function _rewardProvider(Session storage session, uint256 reward_, bool revertOnReachingLimit_) private { - Provider storage provider = providers(session.provider); - _maybeResetProviderRewardLimiter(provider); - uint256 limit_ = provider.stake - provider.limitPeriodEarned; - - if (reward_ > limit_) { - if (revertOnReachingLimit_) { - revert WithdrawableBalanceLimitByStakeReached(); - } - reward_ = limit_; - } + session.providerWithdrawnAmount += amount_; + provider.limitPeriodEarned += amount_; + increaseProvidersTotalClaimed(amount_); - getToken().safeTransferFrom(getFundingAccount(), session.provider, reward_); + getToken().safeTransferFrom(getFundingAccount(), bid.provider, amount_); + } - session.providerWithdrawnAmount += reward_; - increaseTotalClaimed(reward_); - provider.limitPeriodEarned += reward_; + /** + * @notice Returns stake of user based on their stipend + */ + function stipendToStake(uint256 stipend_, uint128 timestamp_) public view returns (uint256) { + return (stipend_ * totalMORSupply(timestamp_) * 100) / getComputeBalance(timestamp_); } - /// @notice returns amount of withdrawable user stake and one on hold - function withdrawableUserStake( + function getUserStakesOnHold( address user_, uint8 iterations_ - ) external view returns (uint256 avail_, uint256 hold_) { - OnHold[] memory onHold = getOnHold(user_); + ) external view returns (uint256 available_, uint256 hold_) { + OnHold[] memory onHold = userStakesOnHold(user_); iterations_ = iterations_ > onHold.length ? uint8(onHold.length) : iterations_; for (uint256 i = 0; i < onHold.length; i++) { uint256 amount = onHold[i].amount; + if (block.timestamp < onHold[i].releaseAt) { hold_ += amount; } else { - avail_ += amount; + available_ += amount; } } } - function getSessionId( - address user_, - address provider_, - uint256 stake_, - uint256 sessionNonce_ - ) public pure returns (bytes32) { - return keccak256(abi.encodePacked(user_, provider_, stake_, sessionNonce_)); - } + function withdrawUserStakes(uint8 iterations_) external { + uint256 amount_ = 0; - /// @notice returns stipend of user based on their stake - function stakeToStipend(uint256 sessionStake_, uint256 timestamp_) public view returns (uint256) { - // inlined getTodaysBudget call to get a better precision - return (sessionStake_ * getComputeBalance(timestamp_)) / (totalMORSupply(timestamp_) * 100); - } + OnHold[] storage onHoldEntries = userStakesOnHold(_msgSender()); + uint8 i = iterations_ >= onHoldEntries.length ? uint8(onHoldEntries.length) : iterations_; + i--; - /// @notice returns stake of user based on their stipend - function stipendToStake(uint256 stipend_, uint256 timestamp_) public view returns (uint256) { - // inlined getTodaysBudget call to get a better precision - // return (stipend * totalMORSupply(timestamp)) / getTodaysBudget(timestamp); - return (stipend_ * totalMORSupply(timestamp_) * 100) / getComputeBalance(timestamp_); - } + while (i >= 0) { + if (block.timestamp < onHoldEntries[i].releaseAt) { + if (i == 0) break; + i--; + + continue; + } + + amount_ += onHoldEntries[i].amount; + onHoldEntries.pop(); + + if (i == 0) break; + i--; + } - /// @dev make it pure - function whenSessionEnds( - uint256 sessionStake_, - uint256 pricePerSecond_, - uint256 openedAt_ - ) public view returns (uint256) { - // if session stake is more than daily price then session will last for its max duration - uint256 duration = stakeToStipend(sessionStake_, openedAt_) / pricePerSecond_; - if (duration >= MAX_SESSION_DURATION) { - return openedAt_ + MAX_SESSION_DURATION; + if (amount_ == 0) { + revert SessionUserAmountToWithdrawIsZero(); } - return openedAt_ + duration; + getToken().safeTransfer(_msgSender(), amount_); + + emit UserWithdrawn(_msgSender(), amount_); } - /// @notice returns today's budget in MOR - function getTodaysBudget(uint256 timestamp_) public view returns (uint256) { - return getComputeBalance(timestamp_) / 100; // 1% of Compute Balance + //////////////////////// + /// GLOBAL PUBLIC /// + //////////////////////// + + /** + * @dev Returns today's budget in MOR. 1% + */ + function getTodaysBudget(uint128 timestamp_) external view returns (uint256) { + return getComputeBalance(timestamp_) / 100; } - /// @notice returns today's compute balance in MOR - function getComputeBalance(uint256 timestamp_) public view returns (uint256) { - Pool memory pool = getPool(COMPUTE_POOL_INDEX); + /** + * @dev Returns today's compute balance in MOR without claimed amount + */ + function getComputeBalance(uint128 timestamp_) public view returns (uint256) { + Pool storage pool = pool(COMPUTE_POOL_INDEX); + uint256 periodReward = LinearDistributionIntervalDecrease.getPeriodReward( pool.initialReward, pool.rewardDecrease, @@ -405,140 +387,49 @@ contract SessionRouter is uint128(startOfTheDay(timestamp_)) ); - return periodReward - totalClaimed(); + return periodReward - getProvidersTotalClaimed(); } - // returns total amount of MOR tokens that were distributed across all pools - function totalMORSupply(uint256 timestamp_) public view returns (uint256) { + /** + * @dev Total amount of MOR tokens that were distributed across all pools + * without compute pool rewards and with compute claimed rewards + */ + function totalMORSupply(uint128 timestamp_) public view returns (uint256) { uint256 startOfTheDay_ = startOfTheDay(timestamp_); uint256 totalSupply_ = 0; - Pool[] memory pools = getPools(); + Pool[] storage pools = pools(); for (uint256 i = 0; i < pools.length; i++) { - if (i == COMPUTE_POOL_INDEX) continue; // skip compute pool (it's calculated separately) - - Pool memory pool = pools[i]; + if (i == COMPUTE_POOL_INDEX) continue; totalSupply_ += LinearDistributionIntervalDecrease.getPeriodReward( - pool.initialReward, - pool.rewardDecrease, - pool.payoutStart, - pool.decreaseInterval, - pool.payoutStart, + pools[i].initialReward, + pools[i].rewardDecrease, + pools[i].payoutStart, + pools[i].decreaseInterval, + pools[i].payoutStart, uint128(startOfTheDay_) ); } - return totalSupply_ + totalClaimed(); + return totalSupply_ + getProvidersTotalClaimed(); } - function getActiveBidsRatingByModel( - bytes32 modelId_, - uint256 offset_, - uint8 limit_ - ) external view returns (bytes32[] memory, Bid[] memory, ProviderModelStats[] memory) { - bytes32[] memory modelBidsSet_ = modelActiveBids(modelId_, offset_, limit_); - uint256 length_ = modelBidsSet_.length; - - Bid[] memory bids_ = new Bid[](length_); - bytes32[] memory bidIds_ = new bytes32[](length_); - ProviderModelStats[] memory stats_ = new ProviderModelStats[](length_); - - for (uint i = 0; i < length_; i++) { - bytes32 id_ = modelBidsSet_[i]; - bidIds_[i] = id_; - Bid memory bid_ = getBid(id_); - bids_[i] = bid_; - stats_[i] = _getProviderModelStats(modelId_, bid_.provider); - } - - return (bidIds_, bids_, stats_); - } - - function startOfTheDay(uint256 timestamp_) public pure returns (uint256) { + function startOfTheDay(uint128 timestamp_) public pure returns (uint128) { return timestamp_ - (timestamp_ % 1 days); } - function _extractProviderApproval(bytes calldata providerApproval_) private view returns (bytes32) { - (bytes32 bidId_, uint256 chainId_, address user_, uint128 timestamp_) = abi.decode( - providerApproval_, - (bytes32, uint256, address, uint128) - ); - - if (user_ != _msgSender()) { - revert ApprovedForAnotherUser(); - } - if (chainId_ != block.chainid) { - revert WrongChainId(); - } - if (timestamp_ < block.timestamp - SIGNATURE_TTL) { - revert SignatureExpired(); - } - - return bidId_; - } - - function _extractReceipt(bytes calldata receiptEncoded_) private view returns (bytes32, uint32, uint32) { - (bytes32 sessionId_, uint256 chainId_, uint128 timestamp_, uint32 tpsScaled1000_, uint32 ttftMs_) = abi.decode( - receiptEncoded_, - (bytes32, uint256, uint128, uint32, uint32) - ); - - if (chainId_ != block.chainid) { - revert WrongChainId(); - } - if (timestamp_ < block.timestamp - SIGNATURE_TTL) { - revert SignatureExpired(); - } - - return (sessionId_, tpsScaled1000_, ttftMs_); - } - - function _getProviderClaimableBalance(Session memory session_) private view returns (uint256) { - // if session was closed with no dispute - provider already got all funds - // - // if session was closed with dispute - - // if session was ended but not closed - - // if session was not ended - provider can claim all funds except for today's session cost - - uint256 claimIntervalEnd_ = session_.closedAt.min(session_.endsAt.min(startOfTheDay(block.timestamp))); - uint256 claimableDuration_ = claimIntervalEnd_.max(session_.openedAt) - session_.openedAt; - uint256 totalCost_ = claimableDuration_ * session_.pricePerSecond; - uint256 withdrawableAmount_ = totalCost_ - session_.providerWithdrawnAmount; - - return withdrawableAmount_; - } - - /// @notice returns total claimanble balance for the provider for particular session - function getProviderClaimableBalance(bytes32 sessionId_) public view returns (uint256) { - Session memory session_ = _getSession(sessionId_); - if (session_.openedAt == 0) { - revert SessionNotFound(); - } - - return _getProviderClaimableBalance(session_); - } + ////////////////////////// + /// GLOBAL PRIVATE /// + ////////////////////////// - /// @notice checks if receipt is valid - function _isValidReceipt( - address signer_, + function _isValidProviderReceipt( + address provider_, bytes calldata receipt_, bytes calldata signature_ ) private pure returns (bool) { - if (signature_.length == 0) { - return false; - } - bytes32 receiptHash_ = ECDSA.toEthSignedMessageHash(keccak256(receipt_)); - return ECDSA.recover(receiptHash_, signature_) == signer_; - } - - function _ownerOrProvider(address provider_) private view returns (bool) { - return _msgSender() == owner() || _msgSender() == provider_; - } - - function _ownerOrUser(address user_) private view returns (bool) { - return _msgSender() == owner() || _msgSender() == user_; + return ECDSA.recover(receiptHash_, signature_) == provider_; } } diff --git a/smart-contracts/contracts/diamond/presets/OwnableDiamondStorage.sol b/smart-contracts/contracts/diamond/presets/OwnableDiamondStorage.sol index e03351a5..d671a880 100644 --- a/smart-contracts/contracts/diamond/presets/OwnableDiamondStorage.sol +++ b/smart-contracts/contracts/diamond/presets/OwnableDiamondStorage.sol @@ -7,7 +7,7 @@ import {Context} from "@openzeppelin/contracts/utils/Context.sol"; abstract contract OwnableDiamondStorage is DiamondOwnableStorage, Context { /** - * @dev The caller account is not authorized to perform an operation. + * @dev The caller account is not authorized to perform an operation as owner. */ error OwnableUnauthorizedAccount(address account_); @@ -16,4 +16,10 @@ abstract contract OwnableDiamondStorage is DiamondOwnableStorage, Context { revert OwnableUnauthorizedAccount(_msgSender()); } } + + function _onlyAccount(address sender_) internal view { + if (_msgSender() != sender_) { + revert OwnableUnauthorizedAccount(_msgSender()); + } + } } diff --git a/smart-contracts/contracts/diamond/storages/BidStorage.sol b/smart-contracts/contracts/diamond/storages/BidStorage.sol index 8c418ace..667b1712 100644 --- a/smart-contracts/contracts/diamond/storages/BidStorage.sol +++ b/smart-contracts/contracts/diamond/storages/BidStorage.sol @@ -24,11 +24,12 @@ contract BidStorage is IBidStorage { bytes32 public constant BID_STORAGE_SLOT = keccak256("diamond.standard.bid.storage"); - function bids(bytes32 bidId) external view returns (Bid memory) { - return _getBidStorage().bids[bidId]; + /** PUBLIC, GETTERS */ + function getBid(bytes32 bidId_) external view returns (Bid memory) { + return _getBidStorage().bids[bidId_]; } - function providerActiveBids( + function getProviderActiveBids( address provider_, uint256 offset_, uint256 limit_ @@ -36,15 +37,23 @@ contract BidStorage is IBidStorage { return _getBidStorage().providerActiveBids[provider_].part(offset_, limit_); } - function modelActiveBids(bytes32 modelId_, uint256 offset_, uint256 limit_) public view returns (bytes32[] memory) { + function getModelActiveBids( + bytes32 modelId_, + uint256 offset_, + uint256 limit_ + ) public view returns (bytes32[] memory) { return _getBidStorage().modelActiveBids[modelId_].part(offset_, limit_); } - function providerBids(address provider_, uint256 offset_, uint256 limit_) external view returns (bytes32[] memory) { + function getProviderBids( + address provider_, + uint256 offset_, + uint256 limit_ + ) external view returns (bytes32[] memory) { return _getBidStorage().providerBids[provider_].part(offset_, limit_); } - function modelBids(bytes32 modelId_, uint256 offset_, uint256 limit_) external view returns (bytes32[] memory) { + function getModelBids(bytes32 modelId_, uint256 offset_, uint256 limit_) external view returns (bytes32[] memory) { return _getBidStorage().modelBids[modelId_].part(offset_, limit_); } @@ -52,59 +61,64 @@ contract BidStorage is IBidStorage { return _getBidStorage().token; } - function getBid(bytes32 bidId) internal view returns (Bid storage) { - return _getBidStorage().bids[bidId]; + function isBidActive(bytes32 bidId_) public view returns (bool) { + Bid storage bid = _getBidStorage().bids[bidId_]; + + return bid.createdAt != 0 && bid.deletedAt == 0; } - function addProviderActiveBids(address provider, bytes32 bidId) internal { - _getBidStorage().providerActiveBids[provider].add(bidId); + /** INTERNAL, GETTERS */ + function bids(bytes32 bidId_) internal view returns (Bid storage) { + return _getBidStorage().bids[bidId_]; } - function addModelActiveBids(bytes32 modelId, bytes32 bidId) internal { - _getBidStorage().modelActiveBids[modelId].add(bidId); + function isModelActiveBidsEmpty(bytes32 modelId) internal view returns (bool) { + return _getBidStorage().modelActiveBids[modelId].length() == 0; } - function removeProviderActiveBids(address provider, bytes32 bidId) internal { - _getBidStorage().providerActiveBids[provider].remove(bidId); + function isProviderActiveBidsEmpty(address provider) internal view returns (bool) { + return _getBidStorage().providerActiveBids[provider].length() == 0; } - function getModelActiveBids(bytes32 modelId) internal view returns (EnumerableSet.Bytes32Set storage) { - return _getBidStorage().modelActiveBids[modelId]; + /** INTERNAL, SETTERS */ + function addProviderActiveBids(address provider_, bytes32 bidId_) internal { + _getBidStorage().providerActiveBids[provider_].add(bidId_); } - function removeModelActiveBids(bytes32 modelId, bytes32 bidId) internal { - _getBidStorage().modelActiveBids[modelId].remove(bidId); + function removeProviderActiveBids(address provider_, bytes32 bidId_) internal { + _getBidStorage().providerActiveBids[provider_].remove(bidId_); } - function isModelActiveBidsEmpty(bytes32 modelId) internal view returns (bool) { - return _getBidStorage().modelActiveBids[modelId].length() == 0; + function addModelActiveBids(bytes32 modelId_, bytes32 bidId_) internal { + _getBidStorage().modelActiveBids[modelId_].add(bidId_); } - function isProviderActiveBidsEmpty(address provider) internal view returns (bool) { - return _getBidStorage().providerActiveBids[provider].length() == 0; + function removeModelActiveBids(bytes32 modelId_, bytes32 bidId_) internal { + _getBidStorage().modelActiveBids[modelId_].remove(bidId_); } - function addProviderBid(address provider, bytes32 bidId) internal { - _getBidStorage().providerBids[provider].push(bidId); + function addProviderBid(address provider_, bytes32 bidId_) internal { + _getBidStorage().providerBids[provider_].push(bidId_); } - function addModelBid(bytes32 modelId, bytes32 bidId) internal { - _getBidStorage().modelBids[modelId].push(bidId); + function addModelBid(bytes32 modelId_, bytes32 bidId_) internal { + _getBidStorage().modelBids[modelId_].push(bidId_); } - function setBid(bytes32 bidId, Bid memory bid) internal { - _getBidStorage().bids[bidId] = bid; + function setToken(IERC20 token_) internal { + _getBidStorage().token = token_; } - function _incrementBidNonce(bytes32 providerModelId) internal returns (uint256) { - return _getBidStorage().providerModelNonce[providerModelId]++; + function incrementBidNonce(bytes32 providerModelId_) internal returns (uint256) { + return _getBidStorage().providerModelNonce[providerModelId_]++; } - function _getBidStorage() internal pure returns (BDStorage storage _ds) { + /** PRIVATE */ + function _getBidStorage() private pure returns (BDStorage storage ds) { bytes32 slot_ = BID_STORAGE_SLOT; assembly { - _ds.slot := slot_ + ds.slot := slot_ } } } diff --git a/smart-contracts/contracts/diamond/storages/MarketplaceStorage.sol b/smart-contracts/contracts/diamond/storages/MarketplaceStorage.sol index a7230d1d..b5445cf6 100644 --- a/smart-contracts/contracts/diamond/storages/MarketplaceStorage.sol +++ b/smart-contracts/contracts/diamond/storages/MarketplaceStorage.sol @@ -5,33 +5,36 @@ import {IMarketplaceStorage} from "../../interfaces/storage/IMarketplaceStorage. contract MarketplaceStorage is IMarketplaceStorage { struct MPStorage { - uint256 feeBalance; // total fees balance of the contract + uint256 feeBalance; // Total fees balance of the contract uint256 bidFee; } bytes32 public constant MARKETPLACE_STORAGE_SLOT = keccak256("diamond.standard.marketplace.storage"); + /** PUBLIC, GETTERS */ function getBidFee() public view returns (uint256) { return _getMarketplaceStorage().bidFee; } - function getFeeBalance() internal view returns (uint256) { + function getFeeBalance() public view returns (uint256) { return _getMarketplaceStorage().feeBalance; } - function increaseFeeBalance(uint256 amount) internal { - _getMarketplaceStorage().feeBalance += amount; + /** INTERNAL, SETTERS */ + function setBidFee(uint256 bidFee_) internal { + _getMarketplaceStorage().bidFee = bidFee_; } - function decreaseFeeBalance(uint256 amount) internal { - _getMarketplaceStorage().feeBalance -= amount; + function setFeeBalance(uint256 feeBalance_) internal { + _getMarketplaceStorage().feeBalance = feeBalance_; } - function _getMarketplaceStorage() internal pure returns (MPStorage storage _ds) { + /** PRIVATE */ + function _getMarketplaceStorage() private pure returns (MPStorage storage ds) { bytes32 slot_ = MARKETPLACE_STORAGE_SLOT; assembly { - _ds.slot := slot_ + ds.slot := slot_ } } } diff --git a/smart-contracts/contracts/diamond/storages/ModelStorage.sol b/smart-contracts/contracts/diamond/storages/ModelStorage.sol index 22e8fc8f..9600e900 100644 --- a/smart-contracts/contracts/diamond/storages/ModelStorage.sol +++ b/smart-contracts/contracts/diamond/storages/ModelStorage.sol @@ -1,59 +1,58 @@ // SPDX-License-Identifier: MIT pragma solidity ^0.8.24; +import {Paginator} from "@solarity/solidity-lib/libs/arrays/Paginator.sol"; + import {IModelStorage} from "../../interfaces/storage/IModelStorage.sol"; contract ModelStorage is IModelStorage { + using Paginator for *; + struct MDLStorage { uint256 modelMinimumStake; - bytes32[] modelIds; // all model ids + bytes32[] modelIds; mapping(bytes32 modelId => Model) models; - mapping(bytes32 modelId => bool) isModelActive; } bytes32 public constant MODEL_STORAGE_SLOT = keccak256("diamond.standard.model.storage"); - function getModel(bytes32 modelId) external view returns (Model memory) { - return _getModelStorage().models[modelId]; + /** PUBLIC, GETTERS */ + function getModel(bytes32 modelId_) external view returns (Model memory) { + return _getModelStorage().models[modelId_]; } - function models(uint256 index) external view returns (bytes32) { - return _getModelStorage().modelIds[index]; + function getModelIds(uint256 offset_, uint256 limit_) external view returns (bytes32[] memory) { + return _getModelStorage().modelIds.part(offset_, limit_); } - function modelMinimumStake() public view returns (uint256) { + function getModelMinimumStake() public view returns (uint256) { return _getModelStorage().modelMinimumStake; } - function setModelActive(bytes32 modelId, bool isActive) internal { - _getModelStorage().isModelActive[modelId] = isActive; + function getIsModelActive(bytes32 modelId_) public view returns (bool) { + return !_getModelStorage().models[modelId_].isDeleted; } - function addModel(bytes32 modelId) internal { - _getModelStorage().modelIds.push(modelId); + /** INTERNAL, GETTERS */ + function models(bytes32 modelId_) internal view returns (Model storage) { + return _getModelStorage().models[modelId_]; } - function setModel(bytes32 modelId, Model memory model) internal { - _getModelStorage().models[modelId] = model; + /** INTERNAL, SETTERS */ + function addModelId(bytes32 modelId_) internal { + _getModelStorage().modelIds.push(modelId_); } - function _setModelMinimumStake(uint256 modelMinimumStake_) internal { + function setModelMinimumStake(uint256 modelMinimumStake_) internal { _getModelStorage().modelMinimumStake = modelMinimumStake_; } - function models(bytes32 id) internal view returns (Model storage) { - return _getModelStorage().models[id]; - } - - function isModelActive(bytes32 modelId) internal view returns (bool) { - return _getModelStorage().isModelActive[modelId]; - } - - function _getModelStorage() internal pure returns (MDLStorage storage _ds) { + /** PRIVATE */ + function _getModelStorage() private pure returns (MDLStorage storage ds) { bytes32 slot_ = MODEL_STORAGE_SLOT; assembly { - _ds.slot := slot_ + ds.slot := slot_ } } } diff --git a/smart-contracts/contracts/diamond/storages/ProviderStorage.sol b/smart-contracts/contracts/diamond/storages/ProviderStorage.sol index f5db0e3f..4318bd29 100644 --- a/smart-contracts/contracts/diamond/storages/ProviderStorage.sol +++ b/smart-contracts/contracts/diamond/storages/ProviderStorage.sol @@ -7,46 +7,42 @@ contract ProviderStorage is IProviderStorage { struct PRVDRStorage { uint256 providerMinimumStake; mapping(address => Provider) providers; - mapping(address => bool) isProviderActive; } - uint128 constant PROVIDER_REWARD_LIMITER_PERIOD = 365 days; // reward for this period will be limited by the stake + // Reward for this period will be limited by the stake + uint128 constant PROVIDER_REWARD_LIMITER_PERIOD = 365 days; bytes32 public constant PROVIDER_STORAGE_SLOT = keccak256("diamond.standard.provider.storage"); - function getProvider(address provider) external view returns (Provider memory) { - return _getProviderStorage().providers[provider]; + /** PUBLIC, GETTERS */ + function getProvider(address provider_) external view returns (Provider memory) { + return providers(provider_); } - function providerMinimumStake() public view returns (uint256) { + function getProviderMinimumStake() public view returns (uint256) { return _getProviderStorage().providerMinimumStake; } - function setProviderActive(address provider, bool isActive) internal { - _getProviderStorage().isProviderActive[provider] = isActive; + function getIsProviderActive(address provider_) public view returns (bool) { + return !providers(provider_).isDeleted; } - function setProvider(address provider, Provider memory provider_) internal { - _getProviderStorage().providers[provider] = provider_; + /** INTERNAL, GETTERS */ + function providers(address provider_) internal view returns (Provider storage) { + return _getProviderStorage().providers[provider_]; } - function setProviderMinimumStake(uint256 _providerMinimumStake) internal { - _getProviderStorage().providerMinimumStake = _providerMinimumStake; + /** INTERNAL, SETTERS */ + function setProviderMinimumStake(uint256 providerMinimumStake_) internal { + _getProviderStorage().providerMinimumStake = providerMinimumStake_; } - function providers(address addr) internal view returns (Provider storage) { - return _getProviderStorage().providers[addr]; - } - - function isProviderActive(address provider) internal view returns (bool) { - return _getProviderStorage().isProviderActive[provider]; - } - - function _getProviderStorage() internal pure returns (PRVDRStorage storage _ds) { + /** PRIVATE */ + function _getProviderStorage() private pure returns (PRVDRStorage storage ds) { bytes32 slot_ = PROVIDER_STORAGE_SLOT; assembly { - _ds.slot := slot_ + ds.slot := slot_ } } } diff --git a/smart-contracts/contracts/diamond/storages/SessionStorage.sol b/smart-contracts/contracts/diamond/storages/SessionStorage.sol index 1c469c1c..20ab07e5 100644 --- a/smart-contracts/contracts/diamond/storages/SessionStorage.sol +++ b/smart-contracts/contracts/diamond/storages/SessionStorage.sol @@ -1,122 +1,156 @@ // SPDX-License-Identifier: MIT pragma solidity ^0.8.24; -import {ISessionStorage} from "../../interfaces/storage/ISessionStorage.sol"; - +import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import {Paginator} from "@solarity/solidity-lib/libs/arrays/Paginator.sol"; +import {ISessionStorage} from "../../interfaces/storage/ISessionStorage.sol"; + contract SessionStorage is ISessionStorage { using Paginator for *; + using EnumerableSet for EnumerableSet.Bytes32Set; struct SNStorage { - // all sessions - uint256 sessionNonce; // used to generate unique session id + // Account which stores the MOR tokens with infinite allowance for this contract + address fundingAccount; + // Distribution pools configuration that mirrors L1 contract + Pool[] pools; + // Total amount of MOR claimed by providers + uint256 providersTotalClaimed; + // Used to generate unique session ID + uint256 sessionNonce; mapping(bytes32 sessionId => Session) sessions; - mapping(address user => bytes32[]) userSessionIds; - mapping(address provider => bytes32[]) providerSessionIds; - mapping(bytes32 modelId => bytes32[]) modelSessionIds; - // active sessions - uint256 totalClaimed; // total amount of MOR claimed by providers - mapping(address user => mapping(bytes32 => bool)) isUserSessionActive; // user address => active session indexes - mapping(address provider => mapping(bytes32 => bool)) isProviderSessionActive; // provider address => active session indexes - mapping(address user => OnHold[]) userOnHold; // user address => balance - mapping(bytes providerApproval => bool) isApprovalUsed; - // other - address fundingAccount; // account which stores the MOR tokens with infinite allowance for this contract - Pool[] pools; // distribution pools configuration that mirrors L1 contract + // Session registry for providers, users and models + mapping(address user => EnumerableSet.Bytes32Set) userSessions; + mapping(address provider => EnumerableSet.Bytes32Set) providerSessions; + mapping(bytes32 modelId => EnumerableSet.Bytes32Set) modelSessions; + mapping(address user => OnHold[]) userStakesOnHold; + mapping(bytes providerApproval => bool) isProviderApprovalUsed; } bytes32 public constant SESSION_STORAGE_SLOT = keccak256("diamond.standard.session.storage"); + uint32 public constant MIN_SESSION_DURATION = 5 minutes; + uint32 public constant MAX_SESSION_DURATION = 1 days; + uint32 public constant SIGNATURE_TTL = 10 minutes; + uint256 public constant COMPUTE_POOL_INDEX = 3; - function sessions(bytes32 sessionId) external view returns (Session memory) { - return _getSessionStorage().sessions[sessionId]; + /** PUBLIC, GETTERS */ + function getSession(bytes32 sessionId_) external view returns (Session memory) { + return _getSessionStorage().sessions[sessionId_]; } - function getSessionsByUser(address user, uint256 offset_, uint256 limit_) external view returns (bytes32[] memory) { - return _getSessionStorage().userSessionIds[user].part(offset_, limit_); + function getUserSessions(address user_, uint256 offset_, uint256 limit_) external view returns (bytes32[] memory) { + return _getSessionStorage().userSessions[user_].part(offset_, limit_); } - function pools() external view returns (Pool[] memory) { - return _getSessionStorage().pools; + function getProviderSessions( + address provider_, + uint256 offset_, + uint256 limit_ + ) external view returns (bytes32[] memory) { + return _getSessionStorage().providerSessions[provider_].part(offset_, limit_); } - function getPools() internal view returns (Pool[] storage) { + function getModelSessions( + bytes32 modelId_, + uint256 offset_, + uint256 limit_ + ) external view returns (bytes32[] memory) { + return _getSessionStorage().modelSessions[modelId_].part(offset_, limit_); + } + + function getPools() external view returns (Pool[] memory) { return _getSessionStorage().pools; } - function getPool(uint256 poolIndex) internal view returns (Pool storage) { - return _getSessionStorage().pools[poolIndex]; + function getPool(uint256 index_) external view returns (Pool memory) { + return _getSessionStorage().pools[index_]; } function getFundingAccount() public view returns (address) { return _getSessionStorage().fundingAccount; } - function setSession(bytes32 sessionId, Session memory session) internal { - _getSessionStorage().sessions[sessionId] = session; + function getTotalSessions(address providerAddr_) public view returns (uint256) { + return _getSessionStorage().providerSessions[providerAddr_].length(); } - function setUserSessionActive(address user, bytes32 sessionId, bool active) internal { - _getSessionStorage().isUserSessionActive[user][sessionId] = active; + function getProvidersTotalClaimed() public view returns (uint256) { + return _getSessionStorage().providersTotalClaimed; } - function setProviderSessionActive(address provider, bytes32 sessionId, bool active) internal { - _getSessionStorage().isProviderSessionActive[provider][sessionId] = active; + function getIsProviderApprovalUsed(bytes memory approval_) public view returns (bool) { + return _getSessionStorage().isProviderApprovalUsed[approval_]; } - function addUserSessionId(address user, bytes32 sessionId) internal { - _getSessionStorage().userSessionIds[user].push(sessionId); + /** INTERNAL, GETTERS */ + function pools() internal view returns (Pool[] storage) { + return _getSessionStorage().pools; } - function addProviderSessionId(address provider, bytes32 sessionId) internal { - _getSessionStorage().providerSessionIds[provider].push(sessionId); + function pool(uint256 poolIndex_) internal view returns (Pool storage) { + return _getSessionStorage().pools[poolIndex_]; } - function totalSessions(address providerAddr) internal view returns (uint256) { - return _getSessionStorage().providerSessionIds[providerAddr].length; + function userStakesOnHold(address user_) internal view returns (OnHold[] storage) { + return _getSessionStorage().userStakesOnHold[user_]; } - function addModelSessionId(bytes32 modelId, bytes32 sessionId) internal { - _getSessionStorage().modelSessionIds[modelId].push(sessionId); + function sessions(bytes32 sessionId_) internal view returns (Session storage) { + return _getSessionStorage().sessions[sessionId_]; + } + + /** INTERNAL, SETTERS */ + function setFundingAccount(address fundingAccount_) internal { + _getSessionStorage().fundingAccount = fundingAccount_; } - function addOnHold(address user, OnHold memory onHold) internal { - _getSessionStorage().userOnHold[user].push(onHold); + function setPools(Pool[] calldata pools_) internal { + SNStorage storage s = _getSessionStorage(); + + for (uint256 i = 0; i < pools_.length; i++) { + s.pools.push(pools_[i]); + } } - function increaseTotalClaimed(uint256 amount) internal { - _getSessionStorage().totalClaimed += amount; + function setPool(uint256 index_, Pool calldata pool_) internal { + _getSessionStorage().pools[index_] = pool_; } - function totalClaimed() internal view returns (uint256) { - return _getSessionStorage().totalClaimed; + function addUserSessionId(address user_, bytes32 sessionId_) internal { + _getSessionStorage().userSessions[user_].add(sessionId_); } - function getOnHold(address user) internal view returns (OnHold[] storage) { - return _getSessionStorage().userOnHold[user]; + function addProviderSessionId(address provider_, bytes32 sessionId_) internal { + _getSessionStorage().providerSessions[provider_].add(sessionId_); } - function _getSession(bytes32 sessionId) internal view returns (Session storage) { - return _getSessionStorage().sessions[sessionId]; + function addModelSessionId(bytes32 modelId, bytes32 sessionId) internal { + _getSessionStorage().modelSessions[modelId].add(sessionId); } - function incrementSessionNonce() internal returns (uint256) { - return _getSessionStorage().sessionNonce++; + function addUserStakeOnHold(address user, OnHold memory onHold) internal { + _getSessionStorage().userStakesOnHold[user].push(onHold); + } + + function increaseProvidersTotalClaimed(uint256 amount) internal { + _getSessionStorage().providersTotalClaimed += amount; } - function isApprovalUsed(bytes memory approval) internal view returns (bool) { - return _getSessionStorage().isApprovalUsed[approval]; + function incrementSessionNonce() internal returns (uint256) { + return _getSessionStorage().sessionNonce++; } - function setApprovalUsed(bytes memory approval) internal { - _getSessionStorage().isApprovalUsed[approval] = true; + function setIsProviderApprovalUsed(bytes memory approval_, bool isUsed_) internal { + _getSessionStorage().isProviderApprovalUsed[approval_] = isUsed_; } - function _getSessionStorage() internal pure returns (SNStorage storage _ds) { + /** PRIVATE */ + function _getSessionStorage() private pure returns (SNStorage storage ds) { bytes32 slot_ = SESSION_STORAGE_SLOT; assembly { - _ds.slot := slot_ + ds.slot := slot_ } } } diff --git a/smart-contracts/contracts/diamond/storages/StatsStorage.sol b/smart-contracts/contracts/diamond/storages/StatsStorage.sol index c86751da..bfa3c6f8 100644 --- a/smart-contracts/contracts/diamond/storages/StatsStorage.sol +++ b/smart-contracts/contracts/diamond/storages/StatsStorage.sol @@ -3,39 +3,41 @@ pragma solidity ^0.8.24; import {IStatsStorage} from "../../interfaces/storage/IStatsStorage.sol"; -import {LibSD} from "../../libs/LibSD.sol"; - contract StatsStorage is IStatsStorage { - struct ModelStats { - LibSD.SD tpsScaled1000; - LibSD.SD ttftMs; - LibSD.SD totalDuration; - uint32 count; - } - struct STTSStorage { - mapping(bytes32 => mapping(address => ProviderModelStats)) stats; // modelId => provider => stats + mapping(bytes32 => mapping(address => ProviderModelStats)) providerModelStats; // modelId => provider => stats mapping(bytes32 => ModelStats) modelStats; } bytes32 public constant STATS_STORAGE_SLOT = keccak256("diamond.stats.storage"); - function _getModelStats(bytes32 modelId) internal view returns (ModelStats storage) { + /** PUBLIC, GETTERS */ + function getProviderModelStats( + bytes32 modelId_, + address provider_ + ) external view returns (ProviderModelStats memory) { + return _getStatsStorage().providerModelStats[modelId_][provider_]; + } + + function getModelStats(bytes32 modelId_) external view returns (ModelStats memory) { + return _getStatsStorage().modelStats[modelId_]; + } + + /** INTERNAL, GETTERS */ + function modelStats(bytes32 modelId) internal view returns (ModelStats storage) { return _getStatsStorage().modelStats[modelId]; } - function _getProviderModelStats( - bytes32 modelId, - address provider - ) internal view returns (ProviderModelStats storage) { - return _getStatsStorage().stats[modelId][provider]; + function providerModelStats(bytes32 modelId, address provider) internal view returns (ProviderModelStats storage) { + return _getStatsStorage().providerModelStats[modelId][provider]; } - function _getStatsStorage() internal pure returns (STTSStorage storage _ds) { + /** PRIVATE */ + function _getStatsStorage() private pure returns (STTSStorage storage ds) { bytes32 slot_ = STATS_STORAGE_SLOT; assembly { - _ds.slot := slot_ + ds.slot := slot_ } } } diff --git a/smart-contracts/contracts/interfaces/facets/IMarketplace.sol b/smart-contracts/contracts/interfaces/facets/IMarketplace.sol index bd68ca44..ab4ab154 100644 --- a/smart-contracts/contracts/interfaces/facets/IMarketplace.sol +++ b/smart-contracts/contracts/interfaces/facets/IMarketplace.sol @@ -1,30 +1,22 @@ // SPDX-License-Identifier: MIT pragma solidity ^0.8.24; -import {IBidStorage} from "../storage/IBidStorage.sol"; import {IMarketplaceStorage} from "../storage/IMarketplaceStorage.sol"; -interface IMarketplace is IBidStorage, IMarketplaceStorage { - event BidPosted(address indexed provider, bytes32 indexed modelId, uint256 nonce); - event BidDeleted(address indexed provider, bytes32 indexed modelId, uint256 nonce); - event FeeUpdated(uint256 bidFee); +interface IMarketplace is IMarketplaceStorage { + event MarketplaceBidPosted(address indexed provider, bytes32 indexed modelId, uint256 nonce); + event MarketplaceBidDeleted(address indexed provider, bytes32 indexed modelId, uint256 nonce); + event MaretplaceFeeUpdated(uint256 bidFee); - error ProviderNotFound(); - error ModelNotFound(); - error ActiveBidNotFound(); - error BidTaken(); - error NotEnoughBalance(); - error NotOwnerOrProvider(); + error MarketplaceProviderNotFound(); + error MarketplaceModelNotFound(); + error MarketplaceActiveBidNotFound(); function __Marketplace_init(address token_) external; - function setBidFee(uint256 bidFee_) external; + function setMarketplaceBidFee(uint256 bidFee_) external; - function postModelBid( - address provider_, - bytes32 modelId_, - uint256 pricePerSecond_ - ) external returns (bytes32 bidId); + function postModelBid(bytes32 modelId_, uint256 pricePerSecond_) external returns (bytes32); function deleteModelBid(bytes32 bidId_) external; diff --git a/smart-contracts/contracts/interfaces/facets/IModelRegistry.sol b/smart-contracts/contracts/interfaces/facets/IModelRegistry.sol index 61cad56a..de19f793 100644 --- a/smart-contracts/contracts/interfaces/facets/IModelRegistry.sol +++ b/smart-contracts/contracts/interfaces/facets/IModelRegistry.sol @@ -6,25 +6,23 @@ import {IModelStorage} from "../storage/IModelStorage.sol"; interface IModelRegistry is IModelStorage { event ModelRegisteredUpdated(address indexed owner, bytes32 indexed modelId); event ModelDeregistered(address indexed owner, bytes32 indexed modelId); - event ModelMinimumStakeSet(uint256 newStake); - + event ModelMinStakeUpdated(uint256 newStake); + error ModelStakeTooLow(uint256 amount, uint256 minAmount); + error ModelHasAlreadyDeregistered(); error ModelNotFound(); - error StakeTooLow(); error ModelHasActiveBids(); - error NotOwnerOrModelOwner(); function __ModelRegistry_init() external; - function setModelMinimumStake(uint256 modelMinimumStake_) external; + function modelSetMinStake(uint256 modelMinimumStake_) external; function modelRegister( bytes32 modelId_, bytes32 ipfsCID_, uint256 fee_, - uint256 stake_, - address owner_, - string memory name_, - string[] memory tags_ + uint256 amount_, + string calldata name_, + string[] calldata tags_ ) external; function modelDeregister(bytes32 modelId_) external; diff --git a/smart-contracts/contracts/interfaces/facets/IProviderRegistry.sol b/smart-contracts/contracts/interfaces/facets/IProviderRegistry.sol index c950237c..35c3e392 100644 --- a/smart-contracts/contracts/interfaces/facets/IProviderRegistry.sol +++ b/smart-contracts/contracts/interfaces/facets/IProviderRegistry.sol @@ -4,25 +4,35 @@ pragma solidity ^0.8.24; import {IProviderStorage} from "../storage/IProviderStorage.sol"; interface IProviderRegistry is IProviderStorage { - event ProviderRegisteredUpdated(address indexed provider); + event ProviderRegistered(address indexed provider); event ProviderDeregistered(address indexed provider); event ProviderMinStakeUpdated(uint256 newStake); - event ProviderWithdrawnStake(address indexed provider, uint256 amount); - error StakeTooLow(); - error ErrProviderNotDeleted(); - error ErrNoStake(); - error ErrNoWithdrawableStake(); + event ProviderWithdrawn(address indexed provider, uint256 amount); + error ProviderStakeTooLow(uint256 amount, uint256 minAmount); + error ProviderNotDeregistered(); + error ProviderNoStake(); + error ProviderNothingToWithdraw(); error ProviderHasActiveBids(); - error NotOwnerOrProvider(); error ProviderNotFound(); + error ProviderHasAlreadyDeregistered(); function __ProviderRegistry_init() external; + /** + * @notice Sets the minimum stake required for a provider + * @param providerMinimumStake_ The minimal stake + */ function providerSetMinStake(uint256 providerMinimumStake_) external; - function providerRegister(address providerAddress_, uint256 amount_, string memory endpoint_) external; + /** + * @notice Register a provider + * @param amount_ The amount of stake to add + * @param endpoint_ The provider endpoint (host.com:1234) + */ + function providerRegister(uint256 amount_, string calldata endpoint_) external; - function providerDeregister(address provider_) external; - - function providerWithdrawStake(address provider_) external; + /** + * @notice Deregister a provider + */ + function providerDeregister() external; } diff --git a/smart-contracts/contracts/interfaces/facets/ISessionRouter.sol b/smart-contracts/contracts/interfaces/facets/ISessionRouter.sol index 513dd717..9a342996 100644 --- a/smart-contracts/contracts/interfaces/facets/ISessionRouter.sol +++ b/smart-contracts/contracts/interfaces/facets/ISessionRouter.sol @@ -8,83 +8,87 @@ import {IStatsStorage} from "../storage/IStatsStorage.sol"; interface ISessionRouter is ISessionStorage { event SessionOpened(address indexed user, bytes32 indexed sessionId, address indexed providerId); event SessionClosed(address indexed user, bytes32 indexed sessionId, address indexed providerId); - - error NotEnoughWithdrawableBalance(); // means that there is not enough funds at all or some funds are still locked - error WithdrawableBalanceLimitByStakeReached(); // means that user can't withdraw more funds because of the limit which equals to the stake - error ProviderSignatureMismatch(); - error SignatureExpired(); - error WrongChainId(); - error DuplicateApproval(); - error ApprovedForAnotherUser(); // means that approval generated for another user address, protection from front-running - + event UserWithdrawn(address indexed user, uint256 amount_); + error SessionProviderSignatureMismatch(); + error SesssionApproveExpired(); + error SesssionApprovedForAnotherChainId(); + error SessionDuplicateApproval(); + error SessionApprovedForAnotherUser(); // Means that approval generated for another user address, protection from front-running + error SesssionReceiptForAnotherChainId(); + error SesssionReceiptExpired(); error SessionTooShort(); - error SessionNotFound(); error SessionAlreadyClosed(); - error SessionNotClosed(); - - error BidNotFound(); - error CannotDecodeAbi(); - - error AmountToWithdrawIsZero(); - error NotOwnerOrProvider(); - error NotOwnerOrUser(); - error PoolIndexOutOfBounds(); - - function __SessionRouter_init(address fundingAccount_, Pool[] memory pools_) external; + error SessionNotEndedOrNotExist(); + error SessionProviderNothingToClaimInThisPeriod(); + error SessionBidNotFound(); + error SessionPoolIndexOutOfBounds(); + error SessionUserAmountToWithdrawIsZero(); + + function __SessionRouter_init(address fundingAccount_, Pool[] calldata pools_) external; + + /** + * @notice Sets distibution pool configuration + * @dev parameters should be the same as in Ethereum L1 Distribution contract + * @dev at address 0x47176B2Af9885dC6C4575d4eFd63895f7Aaa4790 + * @dev call 'Distribution.pools(3)' where '3' is a poolId + */ + function setPoolConfig(uint256 index_, Pool calldata pool_) external; function openSession( uint256 amount_, - bytes memory providerApproval_, - bytes memory signature_ + bytes calldata approvalEncoded_, + bytes calldata signature_ ) external returns (bytes32); - function closeSession(bytes memory receiptEncoded_, bytes memory signature_) external; - - function claimProviderBalance(bytes32 sessionId_, uint256 amountToWithdraw_) external; - - function deleteHistory(bytes32 sessionId_) external; - - function withdrawUserStake(uint256 amountToWithdraw_, uint8 iterations_) external; - - function withdrawableUserStake( - address user_, - uint8 iterations_ - ) external view returns (uint256 avail_, uint256 hold_); - function getSessionId( address user_, address provider_, - uint256 stake_, + bytes32 bidId_, uint256 sessionNonce_ ) external pure returns (bytes32); - function stakeToStipend(uint256 sessionStake_, uint256 timestamp_) external view returns (uint256); + function getSessionEnd(uint256 amount_, uint256 pricePerSecond_, uint128 openedAt_) external view returns (uint128); - function stipendToStake(uint256 stipend_, uint256 timestamp_) external view returns (uint256); + /** + * @dev Returns stipend of user based on their stake + * (User session stake amount / MOR Supply without Compute) * (MOR Compute Supply / 100) + */ + function stakeToStipend(uint256 amount_, uint128 timestamp_) external view returns (uint256); - function whenSessionEnds( - uint256 sessionStake_, - uint256 pricePerSecond_, - uint256 openedAt_ - ) external view returns (uint256); + function closeSession(bytes calldata receiptEncoded_, bytes calldata signature_) external; - function getTodaysBudget(uint256 timestamp_) external view returns (uint256); + /** + * @dev Allows providers to receive their funds after the end or closure of the session + */ + function claimForProvider(bytes32 sessionId_) external; - function getComputeBalance(uint256 timestamp_) external view returns (uint256); + /** + * @notice Returns stake of user based on their stipend + */ + function stipendToStake(uint256 stipend_, uint128 timestamp_) external view returns (uint256); - function totalMORSupply(uint256 timestamp_) external view returns (uint256); + function getUserStakesOnHold( + address user_, + uint8 iterations_ + ) external view returns (uint256 available_, uint256 hold_); - function startOfTheDay(uint256 timestamp_) external pure returns (uint256); + function withdrawUserStakes(uint8 iterations_) external; - function getProviderClaimableBalance(bytes32 sessionId_) external view returns (uint256); + /** + * @dev Returns today's budget in MOR. 1% + */ + function getTodaysBudget(uint128 timestamp_) external view returns (uint256); - function setPoolConfig(uint256 index, Pool calldata pool) external; + /** + * @dev Returns today's compute balance in MOR without claimed amount + */ + function getComputeBalance(uint128 timestamp_) external view returns (uint256); - function SIGNATURE_TTL() external view returns (uint32); + /** + * @dev Total amount of MOR tokens that were distributed across all pools + * without compute pool rewards and with compute claimed rewards + */ + function totalMORSupply(uint128 timestamp_) external view returns (uint256); - function getActiveBidsRatingByModel( - bytes32 modelId_, - uint256 offset_, - uint8 limit_ - ) external view returns (bytes32[] memory, IBidStorage.Bid[] memory, IStatsStorage.ProviderModelStats[] memory); + function startOfTheDay(uint128 timestamp_) external pure returns (uint128); } diff --git a/smart-contracts/contracts/interfaces/storage/IBidStorage.sol b/smart-contracts/contracts/interfaces/storage/IBidStorage.sol index 5282bd46..d70d056d 100644 --- a/smart-contracts/contracts/interfaces/storage/IBidStorage.sol +++ b/smart-contracts/contracts/interfaces/storage/IBidStorage.sol @@ -7,29 +7,35 @@ interface IBidStorage { struct Bid { address provider; bytes32 modelId; - uint256 pricePerSecond; // hourly price + uint256 pricePerSecond; // Hourly price uint256 nonce; uint128 createdAt; uint128 deletedAt; } - function bids(bytes32 bidId) external view returns (Bid memory); + function getBid(bytes32 bidId_) external view returns (Bid memory); - function providerActiveBids( + function getProviderActiveBids( address provider_, uint256 offset_, uint256 limit_ ) external view returns (bytes32[] memory); - function modelActiveBids( + function getModelActiveBids( bytes32 modelId_, uint256 offset_, uint256 limit_ ) external view returns (bytes32[] memory); - function providerBids(address provider_, uint256 offset_, uint256 limit_) external view returns (bytes32[] memory); + function getProviderBids( + address provider_, + uint256 offset_, + uint256 limit_ + ) external view returns (bytes32[] memory); - function modelBids(bytes32 modelId_, uint256 offset_, uint256 limit_) external view returns (bytes32[] memory); + function getModelBids(bytes32 modelId_, uint256 offset_, uint256 limit_) external view returns (bytes32[] memory); function getToken() external view returns (IERC20); + + function isBidActive(bytes32 bidId_) external view returns (bool); } diff --git a/smart-contracts/contracts/interfaces/storage/IMarketplaceStorage.sol b/smart-contracts/contracts/interfaces/storage/IMarketplaceStorage.sol index e7d4987a..197000ee 100644 --- a/smart-contracts/contracts/interfaces/storage/IMarketplaceStorage.sol +++ b/smart-contracts/contracts/interfaces/storage/IMarketplaceStorage.sol @@ -3,4 +3,6 @@ pragma solidity ^0.8.24; interface IMarketplaceStorage { function getBidFee() external view returns (uint256); + + function getFeeBalance() external view returns (uint256); } diff --git a/smart-contracts/contracts/interfaces/storage/IModelStorage.sol b/smart-contracts/contracts/interfaces/storage/IModelStorage.sol index 190f9260..ec76fef7 100644 --- a/smart-contracts/contracts/interfaces/storage/IModelStorage.sol +++ b/smart-contracts/contracts/interfaces/storage/IModelStorage.sol @@ -3,19 +3,21 @@ pragma solidity ^0.8.24; interface IModelStorage { struct Model { - bytes32 ipfsCID; // https://docs.ipfs.tech/concepts/content-addressing/#what-is-a-cid - uint256 fee; + bytes32 ipfsCID; // https://docs.ipfs.tech/concepts/content-addressing/#what-is-a-cid. Up to the model maintainer to keep up to date + uint256 fee; // The fee is a royalty placeholder that isn't currently used uint256 stake; address owner; - string name; // limit name length - string[] tags; // TODO: limit tags amount + string name; // TODO: Limit name length. Up to the model maintainer to keep up to date + string[] tags; // TODO: Limit tags amount. Up to the model maintainer to keep up to date uint128 createdAt; bool isDeleted; } - function getModel(bytes32 modelId) external view returns (Model memory); + function getModel(bytes32 modelId_) external view returns (Model memory); - function models(uint256 index) external view returns (bytes32); + function getModelIds(uint256 offset_, uint256 limit_) external view returns (bytes32[] memory); - function modelMinimumStake() external view returns (uint256); + function getModelMinimumStake() external view returns (uint256); + + function getIsModelActive(bytes32 modelId_) external view returns (bool); } diff --git a/smart-contracts/contracts/interfaces/storage/IProviderStorage.sol b/smart-contracts/contracts/interfaces/storage/IProviderStorage.sol index db0c9b38..75935a00 100644 --- a/smart-contracts/contracts/interfaces/storage/IProviderStorage.sol +++ b/smart-contracts/contracts/interfaces/storage/IProviderStorage.sol @@ -3,15 +3,17 @@ pragma solidity ^0.8.24; interface IProviderStorage { struct Provider { - string endpoint; // example 'domain.com:1234' - uint256 stake; // stake amount, which also server as a reward limiter - uint128 createdAt; // timestamp of the registration - uint128 limitPeriodEnd; // timestamp of the limiter period end - uint256 limitPeriodEarned; // total earned during the last limiter period + string endpoint; // Example 'domain.com:1234' + uint256 stake; // Stake amount, which also server as a reward limiter + uint128 createdAt; // Timestamp of the registration + uint128 limitPeriodEnd; // Timestamp of the limiter period end + uint256 limitPeriodEarned; // Total earned during the last limiter period bool isDeleted; } - function getProvider(address provider) external view returns (Provider memory); + function getProvider(address provider_) external view returns (Provider memory); - function providerMinimumStake() external view returns (uint256); + function getProviderMinimumStake() external view returns (uint256); + + function getIsProviderActive(address provider_) external view returns (bool); } diff --git a/smart-contracts/contracts/interfaces/storage/ISessionStorage.sol b/smart-contracts/contracts/interfaces/storage/ISessionStorage.sol index 6843c89e..62810691 100644 --- a/smart-contracts/contracts/interfaces/storage/ISessionStorage.sol +++ b/smart-contracts/contracts/interfaces/storage/ISessionStorage.sol @@ -3,25 +3,25 @@ pragma solidity ^0.8.24; interface ISessionStorage { struct Session { - bytes32 id; address user; - address provider; - bytes32 modelId; bytes32 bidId; uint256 stake; - uint256 pricePerSecond; bytes closeoutReceipt; - uint256 closeoutType; // use enum ?? - // amount of funds that was already withdrawn by provider (we allow to withdraw for the previous day) + // TODO: Use enum? + uint256 closeoutType; + // Amount of funds that was already withdrawn by provider (we allow to withdraw for the previous day) uint256 providerWithdrawnAmount; - uint256 openedAt; - uint256 endsAt; // expected end time considering the stake provided - uint256 closedAt; + uint128 openedAt; + // Expected end time considering the stake provided + uint128 endsAt; + uint128 closedAt; + bool isActive; } struct OnHold { uint256 amount; - uint128 releaseAt; // in epoch seconds TODO: consider using hours to reduce storage cost + // In epoch seconds. TODO: consider using hours to reduce storage cost + uint128 releaseAt; } struct Pool { @@ -31,11 +31,31 @@ interface ISessionStorage { uint128 decreaseInterval; } - function sessions(bytes32 sessionId) external view returns (Session memory); + function getSession(bytes32 sessionId_) external view returns (Session memory); - function getSessionsByUser(address user, uint256 offset_, uint256 limit_) external view returns (bytes32[] memory); + function getUserSessions(address user, uint256 offset_, uint256 limit_) external view returns (bytes32[] memory); + + function getProviderSessions( + address provider_, + uint256 offset_, + uint256 limit_ + ) external view returns (bytes32[] memory); + + function getModelSessions( + bytes32 modelId_, + uint256 offset_, + uint256 limit_ + ) external view returns (bytes32[] memory); + + function getPools() external view returns (Pool[] memory); + + function getPool(uint256 index_) external view returns (Pool memory); function getFundingAccount() external view returns (address); - function pools() external view returns (Pool[] memory); + function getTotalSessions(address providerAddr_) external view returns (uint256); + + function getProvidersTotalClaimed() external view returns (uint256); + + function getIsProviderApprovalUsed(bytes memory approval_) external view returns (bool); } diff --git a/smart-contracts/contracts/interfaces/storage/IStatsStorage.sol b/smart-contracts/contracts/interfaces/storage/IStatsStorage.sol index 7e25b9b3..1eeb6bde 100644 --- a/smart-contracts/contracts/interfaces/storage/IStatsStorage.sol +++ b/smart-contracts/contracts/interfaces/storage/IStatsStorage.sol @@ -4,12 +4,26 @@ pragma solidity ^0.8.24; import {LibSD} from "../../libs/LibSD.sol"; interface IStatsStorage { + struct ModelStats { + LibSD.SD tpsScaled1000; + LibSD.SD ttftMs; + LibSD.SD totalDuration; + uint32 count; + } + struct ProviderModelStats { - LibSD.SD tpsScaled1000; // tokens per second running average - LibSD.SD ttftMs; // time to first token running average in milliseconds - uint32 totalDuration; // total duration of sessions - uint32 successCount; // number of observations + LibSD.SD tpsScaled1000; // Tokens per second running average + LibSD.SD ttftMs; // Time to first token running average in milliseconds + uint32 totalDuration; // Total duration of sessions + uint32 successCount; // Number of observations uint32 totalCount; // TODO: consider adding SD with weldford algorithm } + + function getProviderModelStats( + bytes32 modelId_, + address provider_ + ) external view returns (ProviderModelStats memory); + + function getModelStats(bytes32 modelId_) external view returns (ModelStats memory); } diff --git a/smart-contracts/contracts/mock/tokens/MorpheusToken.sol b/smart-contracts/contracts/mock/tokens/MorpheusToken.sol index 8b5241dc..fe5b5110 100644 --- a/smart-contracts/contracts/mock/tokens/MorpheusToken.sol +++ b/smart-contracts/contracts/mock/tokens/MorpheusToken.sol @@ -5,9 +5,9 @@ import {ERC20} from "@openzeppelin/contracts/token/ERC20/ERC20.sol"; contract MorpheusToken is ERC20 { // set the initial supply to 42 million like in whitepaper - uint256 constant initialSupply = 42_000_000 * (10 ** 18); + uint256 public constant INITIAL_SUPPLUY = 42_000_000 ether; constructor() ERC20("Morpheus dev", "MOR") { - _mint(_msgSender(), initialSupply); + _mint(_msgSender(), INITIAL_SUPPLUY); } } diff --git a/smart-contracts/contracts/tokens/LumerinToken.sol b/smart-contracts/contracts/tokens/LumerinToken.sol index ede14ac5..e08234f4 100644 --- a/smart-contracts/contracts/tokens/LumerinToken.sol +++ b/smart-contracts/contracts/tokens/LumerinToken.sol @@ -4,10 +4,10 @@ pragma solidity ^0.8.24; import {ERC20} from "@openzeppelin/contracts/token/ERC20/ERC20.sol"; contract LumerinToken is ERC20 { - uint256 constant initialSupply = 1_000_000_000 * (10 ** 8); + uint256 public constant INITIAL_SUPPLUY = 1_000_000_000 * (10 ** 8); constructor(string memory name_, string memory symbol_) ERC20(name_, symbol_) { - _mint(_msgSender(), initialSupply); + _mint(_msgSender(), INITIAL_SUPPLUY); } function decimals() public pure override returns (uint8) { diff --git a/smart-contracts/hardhat.config.ts b/smart-contracts/hardhat.config.ts index 41c1d579..a040ce23 100644 --- a/smart-contracts/hardhat.config.ts +++ b/smart-contracts/hardhat.config.ts @@ -25,7 +25,7 @@ function forceTypechain() { const config: HardhatUserConfig = { networks: { hardhat: { - initialDate: '2024-07-15T01:00:00.000Z', + initialDate: '1970-01-01T00:00:00Z', gas: 'auto', // required for tests where two transactions should be mined in the same block // loggingEnabled: true, // mining: { diff --git a/smart-contracts/package.json b/smart-contracts/package.json index c02a5dc1..788f380c 100644 --- a/smart-contracts/package.json +++ b/smart-contracts/package.json @@ -5,7 +5,7 @@ "license": "MIT", "scripts": { "compile": "hardhat compile", - "test": "hardhat test", + "test": "npm run generate-types && hardhat test", "test:gas": "REPORT_GAS=true hardhat test", "coverage": "hardhat coverage --solcoverjs ./.solcover.ts", "report": "open ./coverage/index.html", diff --git a/smart-contracts/test/diamond/LumerinDiamond.test.ts b/smart-contracts/test/diamond/LumerinDiamond.test.ts index 1e9c804c..a4b8e134 100644 --- a/smart-contracts/test/diamond/LumerinDiamond.test.ts +++ b/smart-contracts/test/diamond/LumerinDiamond.test.ts @@ -3,6 +3,7 @@ import { SignerWithAddress } from '@nomicfoundation/hardhat-ethers/signers'; import { expect } from 'chai'; import { ethers } from 'hardhat'; +import { deployLumerinDiamond } from '../helpers/deployers'; import { Reverter } from '../helpers/reverter'; describe('LumerinDiamond', () => { @@ -12,30 +13,27 @@ describe('LumerinDiamond', () => { let diamond: LumerinDiamond; - before('setup', async () => { + before(async () => { [OWNER] = await ethers.getSigners(); - const [LumerinDiamond] = await Promise.all([ethers.getContractFactory('LumerinDiamond')]); - - [diamond] = await Promise.all([LumerinDiamond.deploy()]); - - await diamond.__LumerinDiamond_init(); + diamond = await deployLumerinDiamond(); await reverter.snapshot(); }); afterEach(reverter.revert); - describe('Diamond functionality', () => { - describe('#__LumerinDiamond_init', () => { - it('should set correct data after creation', async () => { - expect(await diamond.owner()).to.eq(await OWNER.getAddress()); - }); - it('should revert if try to call init function twice', async () => { - const reason = 'Initializable: contract is already initialized'; - - await expect(diamond.__LumerinDiamond_init()).to.be.rejectedWith(reason); - }); + describe('#__LumerinDiamond_init', () => { + it('should set correct data after creation', async () => { + expect(await diamond.owner()).to.eq(await OWNER.getAddress()); + }); + it('should revert if try to call init function twice', async () => { + await expect(diamond.__LumerinDiamond_init()).to.be.rejectedWith( + 'Initializable: contract is already initialized', + ); }); }); }); + +// npx hardhat test "test/diamond/LumerinDiamond.test.ts" +// npx hardhat coverage --solcoverjs ./.solcover.ts --testfiles "test/diamond/LumerinDiamond.test.ts" diff --git a/smart-contracts/test/diamond/Marketplace.test.ts b/smart-contracts/test/diamond/Marketplace.test.ts deleted file mode 100644 index 5016f1fe..00000000 --- a/smart-contracts/test/diamond/Marketplace.test.ts +++ /dev/null @@ -1,476 +0,0 @@ -import { - IBidStorage, - IMarketplace__factory, - IModelRegistry__factory, - IModelStorage, - IProviderRegistry__factory, - IProviderStorage, - ISessionRouter__factory, - LumerinDiamond, - Marketplace, - ModelRegistry, - MorpheusToken, - ProviderRegistry, - SessionRouter, -} from '@ethers-v6'; -import { SignerWithAddress } from '@nomicfoundation/hardhat-ethers/signers'; -import { expect } from 'chai'; -import { Addressable, Fragment, resolveAddress } from 'ethers'; -import { ethers } from 'hardhat'; - -import { FacetAction } from '../helpers/enums'; -import { getDefaultPools } from '../helpers/pool-helper'; -import { Reverter } from '../helpers/reverter'; - -import { getHex, randomBytes32, wei } from '@/scripts/utils/utils'; -import { getCurrentBlockTime } from '@/utils/block-helper'; -import { HOUR, YEAR } from '@/utils/time'; - -describe('Marketplace', () => { - const reverter = new Reverter(); - - let OWNER: SignerWithAddress; - let SECOND: SignerWithAddress; - let THIRD: SignerWithAddress; - let PROVIDER: SignerWithAddress; - - let diamond: LumerinDiamond; - let marketplace: Marketplace; - let modelRegistry: ModelRegistry; - let providerRegistry: ProviderRegistry; - let sessionRouter: SessionRouter; - - let MOR: MorpheusToken; - - async function deployProvider(): Promise< - IProviderStorage.ProviderStruct & { - address: Addressable; - } - > { - const expectedProvider = { - endpoint: 'localhost:3334', - stake: wei(100), - createdAt: 0n, - limitPeriodEnd: 0n, - limitPeriodEarned: 0n, - isDeleted: false, - address: PROVIDER, - }; - - await MOR.transfer(PROVIDER, expectedProvider.stake * 100n); - await MOR.connect(PROVIDER).approve(sessionRouter, expectedProvider.stake); - - await providerRegistry - .connect(PROVIDER) - .providerRegister(expectedProvider.address, expectedProvider.stake, expectedProvider.endpoint); - expectedProvider.createdAt = await getCurrentBlockTime(); - expectedProvider.limitPeriodEnd = expectedProvider.createdAt + YEAR; - - return expectedProvider; - } - - async function deployModel(): Promise< - IModelStorage.ModelStruct & { - modelId: string; - } - > { - const expectedModel = { - modelId: randomBytes32(), - ipfsCID: getHex(Buffer.from('ipfs://ipfsaddress')), - fee: 100, - stake: 100, - owner: OWNER, - name: 'Llama 2.0', - tags: ['llama', 'animal', 'cute'], - createdAt: 0n, - isDeleted: false, - }; - - await MOR.approve(modelRegistry, expectedModel.stake); - - await modelRegistry.modelRegister( - expectedModel.modelId, - expectedModel.ipfsCID, - expectedModel.fee, - expectedModel.stake, - expectedModel.owner, - expectedModel.name, - expectedModel.tags, - ); - expectedModel.createdAt = await getCurrentBlockTime(); - - return expectedModel; - } - - async function deployBid(model: any): Promise< - IBidStorage.BidStruct & { - id: string; - modelId: string; - } - > { - let bid = { - id: '', - modelId: model.modelId, - pricePerSecond: wei(0.0001), - nonce: 0, - createdAt: 0n, - deletedAt: 0, - provider: PROVIDER, - }; - - await MOR.approve(modelRegistry, 10000n * 10n ** 18n); - - bid.id = await marketplace.connect(PROVIDER).postModelBid.staticCall(bid.provider, bid.modelId, bid.pricePerSecond); - await marketplace.connect(PROVIDER).postModelBid(bid.provider, bid.modelId, bid.pricePerSecond); - - bid.createdAt = await getCurrentBlockTime(); - - // generating data for sample session - const durationSeconds = HOUR; - const totalCost = bid.pricePerSecond * durationSeconds; - const totalSupply = await sessionRouter.totalMORSupply(await getCurrentBlockTime()); - const todaysBudget = await sessionRouter.getTodaysBudget(await getCurrentBlockTime()); - - const expectedSession = { - durationSeconds, - totalCost, - pricePerSecond: bid.pricePerSecond, - user: SECOND, - provider: bid.provider, - modelId: bid.modelId, - bidId: bid.id, - stake: (totalCost * totalSupply) / todaysBudget, - }; - - await MOR.transfer(SECOND, expectedSession.stake); - await MOR.connect(SECOND).approve(modelRegistry, expectedSession.stake); - - return bid; - } - before('setup', async () => { - [OWNER, SECOND, THIRD, PROVIDER] = await ethers.getSigners(); - - const LinearDistributionIntervalDecrease = await ethers.getContractFactory('LinearDistributionIntervalDecrease'); - const linearDistributionIntervalDecrease = await LinearDistributionIntervalDecrease.deploy(); - - const [LumerinDiamond, Marketplace, ModelRegistry, ProviderRegistry, SessionRouter, MorpheusToken] = - await Promise.all([ - ethers.getContractFactory('LumerinDiamond'), - ethers.getContractFactory('Marketplace'), - ethers.getContractFactory('ModelRegistry'), - ethers.getContractFactory('ProviderRegistry'), - ethers.getContractFactory('SessionRouter', { - libraries: { - LinearDistributionIntervalDecrease: linearDistributionIntervalDecrease, - }, - }), - ethers.getContractFactory('MorpheusToken'), - ]); - - [diamond, marketplace, modelRegistry, providerRegistry, sessionRouter, MOR] = await Promise.all([ - LumerinDiamond.deploy(), - Marketplace.deploy(), - ModelRegistry.deploy(), - ProviderRegistry.deploy(), - SessionRouter.deploy(), - MorpheusToken.deploy(), - ]); - - await diamond.__LumerinDiamond_init(); - - await diamond['diamondCut((address,uint8,bytes4[])[])']([ - { - facetAddress: marketplace, - action: FacetAction.Add, - functionSelectors: IMarketplace__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: providerRegistry, - action: FacetAction.Add, - functionSelectors: IProviderRegistry__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: sessionRouter, - action: FacetAction.Add, - functionSelectors: ISessionRouter__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: modelRegistry, - action: FacetAction.Add, - functionSelectors: IModelRegistry__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - ]); - - marketplace = marketplace.attach(diamond) as Marketplace; - providerRegistry = providerRegistry.attach(diamond) as ProviderRegistry; - modelRegistry = modelRegistry.attach(diamond) as ModelRegistry; - sessionRouter = sessionRouter.attach(diamond) as SessionRouter; - - await marketplace.__Marketplace_init(MOR); - await modelRegistry.__ModelRegistry_init(); - await providerRegistry.__ProviderRegistry_init(); - - await sessionRouter.__SessionRouter_init(OWNER, getDefaultPools()); - - await reverter.snapshot(); - }); - - afterEach(reverter.revert); - - describe('Diamond functionality', () => { - describe('#__Marketplace_init', () => { - it('should set correct data after creation', async () => { - expect(await marketplace.getToken()).to.eq(await MOR.getAddress()); - }); - it('should revert if try to call init function twice', async () => { - const reason = 'Initializable: contract is already initialized'; - - await expect(marketplace.__Marketplace_init(MOR)).to.be.rejectedWith(reason); - }); - }); - }); - - describe('bid actions', () => { - let provider: IProviderStorage.ProviderStruct; - let model: IModelStorage.ModelStruct & { - modelId: string; - }; - let bid: IBidStorage.BidStruct & { - id: string; - modelId: string; - }; - - beforeEach(async () => { - provider = await deployProvider(); - model = await deployModel(); - bid = await deployBid(model); - }); - - it('Should create a bid and query by id', async () => { - const data = await marketplace.bids(bid.id); - - expect(data).to.be.deep.equal([ - await resolveAddress(bid.provider), - bid.modelId, - bid.pricePerSecond, - bid.nonce, - bid.createdAt, - bid.deletedAt, - ]); - }); - - it("Should error if provider doesn't exist", async () => { - await expect( - marketplace.connect(SECOND).postModelBid(SECOND, bid.modelId, bid.pricePerSecond), - ).to.be.revertedWithCustomError(marketplace, 'ProviderNotFound'); - }); - - it("Should error if model doesn't exist", async () => { - const unknownModel = randomBytes32(); - - await expect( - marketplace.connect(PROVIDER).postModelBid(bid.provider, unknownModel, bid.pricePerSecond), - ).to.be.revertedWithCustomError(marketplace, 'ModelNotFound'); - }); - - it('should error if caller is not an owner or provider', async () => { - await expect( - marketplace.connect(SECOND).postModelBid(bid.provider, bid.modelId, bid.pricePerSecond), - ).to.be.revertedWithCustomError(marketplace, 'NotOwnerOrProvider'); - }); - - it('Should create second bid', async () => { - // create new bid with same provider and modelId - await marketplace.connect(PROVIDER).postModelBid(bid.provider, bid.modelId, bid.pricePerSecond); - const timestamp = await getCurrentBlockTime(); - - // check indexes are updated - const newBids1 = await marketplace.providerActiveBids(bid.provider, 0, 10); - const newBids2 = await marketplace.modelActiveBids(bid.modelId, 0, 10); - - expect(newBids1).to.be.deep.equal(newBids2); - expect(await marketplace.bids(newBids1[0])).to.be.deep.equal([ - await resolveAddress(bid.provider), - bid.modelId, - bid.pricePerSecond, - BigInt(bid.nonce) + 1n, - timestamp, - bid.deletedAt, - ]); - - // check old bid is deleted - const oldBid = await marketplace.bids(bid.id); - expect(oldBid).to.be.deep.equal([ - await resolveAddress(bid.provider), - bid.modelId, - bid.pricePerSecond, - bid.nonce, - bid.createdAt, - timestamp, - ]); - - // check old bid is still queried - const oldBids1 = await marketplace.providerBids(bid.provider, 0, 100); - const oldBids2 = await marketplace.modelBids(bid.modelId, 0, 100); - expect(oldBids1).to.be.deep.equal(oldBids2); - expect(oldBids1.length).to.be.equal(2); - expect(await marketplace.bids(oldBids1[0])).to.be.deep.equal([ - await resolveAddress(bid.provider), - bid.modelId, - bid.pricePerSecond, - bid.nonce, - bid.createdAt, - timestamp, - ]); - }); - - it('Should query by provider', async () => { - const activeBidIds = await marketplace.providerActiveBids(bid.provider, 0, 10); - - expect(activeBidIds.length).to.equal(1); - expect(activeBidIds[0]).to.equal(bid.id); - expect(await marketplace.bids(activeBidIds[0])).to.deep.equal([ - await resolveAddress(bid.provider), - bid.modelId, - bid.pricePerSecond, - bid.nonce, - bid.createdAt, - bid.deletedAt, - ]); - }); - - describe('delete bid', () => { - it('Should delete a bid', async () => { - // delete bid - await marketplace.connect(PROVIDER).deleteModelBid(bid.id); - - // check indexes are updated - const activeBidIds1 = await marketplace.providerActiveBids(bid.provider, 0, 10); - const activeBidIds2 = await marketplace.modelActiveBids(bid.modelId, 0, 10); - - expect(activeBidIds1.length).to.be.equal(0); - expect(activeBidIds2.length).to.be.equal(0); - - // check bid is deleted - const data = await marketplace.bids(bid.id); - expect(data).to.be.deep.equal([ - await resolveAddress(bid.provider), - bid.modelId, - bid.pricePerSecond, - bid.nonce, - bid.createdAt, - await getCurrentBlockTime(), - ]); - }); - - it("Should error if bid doesn't exist", async () => { - const unknownBid = randomBytes32(); - - await expect(marketplace.connect(PROVIDER).deleteModelBid(unknownBid)).to.be.revertedWithCustomError( - marketplace, - 'ActiveBidNotFound', - ); - }); - - it('Should error if not owner', async () => { - await expect(marketplace.connect(THIRD).deleteModelBid(bid.id)).to.be.revertedWithCustomError( - marketplace, - 'NotOwnerOrProvider', - ); - }); - - it('Should allow bid owner to delete bid', async () => { - // delete bid - await marketplace.connect(PROVIDER).deleteModelBid(bid.id); - }); - - it('Should allow contract owner to delete bid', async () => { - // delete bid - await marketplace.deleteModelBid(bid.id); - }); - - it('Should allow to create bid after it was deleted [H-1]', async () => { - // delete bid - await marketplace.connect(PROVIDER).deleteModelBid(bid.id); - - // create new bid with same provider and modelId - await marketplace.connect(PROVIDER).postModelBid(bid.provider, bid.modelId, bid.pricePerSecond); - }); - }); - - describe('bid fee', () => { - it('should set bid fee', async () => { - const newFee = 100; - await marketplace.setBidFee(newFee); - - const modelBidFee = await marketplace.getBidFee(); - expect(modelBidFee).to.be.equal(newFee); - }); - - it('should collect bid fee', async () => { - const newFee = 100; - await marketplace.setBidFee(newFee); - await MOR.transfer(bid.provider, 100); - - // check balance before - const balanceBefore = await MOR.balanceOf(marketplace); - // add bid - await MOR.connect(PROVIDER).approve(marketplace, Number(bid.pricePerSecond) + newFee); - await marketplace.connect(PROVIDER).postModelBid(bid.provider, bid.modelId, bid.pricePerSecond); - // check balance after - const balanceAfter = await MOR.balanceOf(marketplace); - expect(balanceAfter - balanceBefore).to.be.equal(newFee); - }); - - it('should allow withdrawal by owner', async () => { - const newFee = 100; - await marketplace.setBidFee(newFee); - await MOR.transfer(bid.provider, 100); - // add bid - await MOR.connect(PROVIDER).approve(marketplace, Number(bid.pricePerSecond) + newFee); - await marketplace.connect(PROVIDER).postModelBid(bid.provider, bid.modelId, bid.pricePerSecond); - // check balance after - const balanceBefore = await MOR.balanceOf(OWNER); - await marketplace.withdraw(OWNER, newFee); - const balanceAfter = await MOR.balanceOf(OWNER); - expect(balanceAfter - balanceBefore).to.be.equal(newFee); - }); - - it('should not allow withdrawal by any other account except owner', async () => { - const newFee = 100; - await marketplace.setBidFee(newFee); - await MOR.transfer(bid.provider, 100); - // add bid - await MOR.connect(PROVIDER).approve(marketplace, Number(bid.pricePerSecond) + newFee); - await marketplace.connect(PROVIDER).postModelBid(bid.provider, bid.modelId, bid.pricePerSecond); - // check balance after - await expect(marketplace.connect(PROVIDER).withdraw(bid.provider, newFee)).to.be.revertedWithCustomError( - diamond, - 'OwnableUnauthorizedAccount', - ); - }); - - it('should not allow withdrawal if not enough balance', async () => { - await expect(marketplace.withdraw(OWNER, 100000000)).to.be.revertedWithCustomError( - marketplace, - 'NotEnoughBalance', - ); - }); - - it('should revert if caller is not an owner', async () => { - await expect(marketplace.connect(SECOND).setBidFee(1)).to.be.revertedWithCustomError( - diamond, - 'OwnableUnauthorizedAccount', - ); - }); - }); - }); -}); diff --git a/smart-contracts/test/diamond/ModelRegistry.test.ts b/smart-contracts/test/diamond/ModelRegistry.test.ts deleted file mode 100644 index 5a43b624..00000000 --- a/smart-contracts/test/diamond/ModelRegistry.test.ts +++ /dev/null @@ -1,508 +0,0 @@ -import { - IBidStorage, - IMarketplace__factory, - IModelRegistry__factory, - IModelStorage, - IProviderRegistry__factory, - IProviderStorage, - ISessionRouter__factory, - LumerinDiamond, - Marketplace, - ModelRegistry, - MorpheusToken, - ProviderRegistry, - SessionRouter, -} from '@ethers-v6'; -import { SignerWithAddress } from '@nomicfoundation/hardhat-ethers/signers'; -import { expect } from 'chai'; -import { Addressable, Fragment, resolveAddress } from 'ethers'; -import { ethers } from 'hardhat'; - -import { FacetAction } from '../helpers/enums'; -import { getDefaultPools } from '../helpers/pool-helper'; -import { Reverter } from '../helpers/reverter'; - -import { getHex, randomBytes32, wei } from '@/scripts/utils/utils'; -import { getCurrentBlockTime } from '@/utils/block-helper'; -import { HOUR, YEAR } from '@/utils/time'; - -describe('Model registry', () => { - const reverter = new Reverter(); - - let OWNER: SignerWithAddress; - let SECOND: SignerWithAddress; - let THIRD: SignerWithAddress; - let PROVIDER: SignerWithAddress; - - let diamond: LumerinDiamond; - let marketplace: Marketplace; - let modelRegistry: ModelRegistry; - let providerRegistry: ProviderRegistry; - let sessionRouter: SessionRouter; - - let MOR: MorpheusToken; - - async function deployProvider(): Promise< - IProviderStorage.ProviderStruct & { - address: Addressable; - } - > { - const expectedProvider = { - endpoint: 'localhost:3334', - stake: wei(100), - createdAt: 0n, - limitPeriodEnd: 0n, - limitPeriodEarned: 0n, - isDeleted: false, - address: PROVIDER, - }; - - await MOR.transfer(PROVIDER, expectedProvider.stake * 100n); - await MOR.connect(PROVIDER).approve(sessionRouter, expectedProvider.stake); - - await providerRegistry - .connect(PROVIDER) - .providerRegister(expectedProvider.address, expectedProvider.stake, expectedProvider.endpoint); - expectedProvider.createdAt = await getCurrentBlockTime(); - expectedProvider.limitPeriodEnd = expectedProvider.createdAt + YEAR; - - return expectedProvider; - } - - async function deployModel(): Promise< - IModelStorage.ModelStruct & { - modelId: string; - } - > { - const model = { - modelId: randomBytes32(), - ipfsCID: getHex(Buffer.from('ipfs://ipfsaddress')), - fee: 100, - stake: 100, - owner: OWNER, - name: 'Llama 2.0', - tags: ['llama', 'animal', 'cute'], - createdAt: 0n, - isDeleted: false, - }; - - await MOR.approve(modelRegistry, model.stake); - - await modelRegistry.modelRegister( - model.modelId, - model.ipfsCID, - model.fee, - model.stake, - model.owner, - model.name, - model.tags, - ); - model.createdAt = await getCurrentBlockTime(); - - return model; - } - - async function deployBid(model: any): Promise< - IBidStorage.BidStruct & { - id: string; - modelId: string; - } - > { - let bid = { - id: '', - modelId: model.modelId, - pricePerSecond: wei(0.0001), - nonce: 0, - createdAt: 0n, - deletedAt: 0, - provider: PROVIDER, - }; - - await MOR.approve(modelRegistry, 10000n * 10n ** 18n); - - bid.id = await marketplace.connect(PROVIDER).postModelBid.staticCall(bid.provider, bid.modelId, bid.pricePerSecond); - await marketplace.connect(PROVIDER).postModelBid(bid.provider, bid.modelId, bid.pricePerSecond); - - bid.createdAt = await getCurrentBlockTime(); - - // generating data for sample session - const durationSeconds = HOUR; - const totalCost = bid.pricePerSecond * durationSeconds; - const totalSupply = await sessionRouter.totalMORSupply(await getCurrentBlockTime()); - const todaysBudget = await sessionRouter.getTodaysBudget(await getCurrentBlockTime()); - - const expectedSession = { - durationSeconds, - totalCost, - pricePerSecond: bid.pricePerSecond, - user: SECOND, - provider: bid.provider, - modelId: bid.modelId, - bidId: bid.id, - stake: (totalCost * totalSupply) / todaysBudget, - }; - - await MOR.transfer(SECOND, expectedSession.stake); - await MOR.connect(SECOND).approve(modelRegistry, expectedSession.stake); - - return bid; - } - - before('setup', async () => { - [OWNER, SECOND, THIRD, PROVIDER] = await ethers.getSigners(); - - const LinearDistributionIntervalDecrease = await ethers.getContractFactory('LinearDistributionIntervalDecrease'); - const linearDistributionIntervalDecrease = await LinearDistributionIntervalDecrease.deploy(); - - const [LumerinDiamond, Marketplace, ModelRegistry, ProviderRegistry, SessionRouter, MorpheusToken] = - await Promise.all([ - ethers.getContractFactory('LumerinDiamond'), - ethers.getContractFactory('Marketplace'), - ethers.getContractFactory('ModelRegistry'), - ethers.getContractFactory('ProviderRegistry'), - ethers.getContractFactory('SessionRouter', { - libraries: { - LinearDistributionIntervalDecrease: linearDistributionIntervalDecrease, - }, - }), - ethers.getContractFactory('MorpheusToken'), - ]); - - [diamond, marketplace, modelRegistry, providerRegistry, sessionRouter, MOR] = await Promise.all([ - LumerinDiamond.deploy(), - Marketplace.deploy(), - ModelRegistry.deploy(), - ProviderRegistry.deploy(), - SessionRouter.deploy(), - MorpheusToken.deploy(), - ]); - - await diamond.__LumerinDiamond_init(); - - await diamond['diamondCut((address,uint8,bytes4[])[])']([ - { - facetAddress: marketplace, - action: FacetAction.Add, - functionSelectors: IMarketplace__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: providerRegistry, - action: FacetAction.Add, - functionSelectors: IProviderRegistry__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: sessionRouter, - action: FacetAction.Add, - functionSelectors: ISessionRouter__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: modelRegistry, - action: FacetAction.Add, - functionSelectors: IModelRegistry__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - ]); - - marketplace = marketplace.attach(diamond.target) as Marketplace; - providerRegistry = providerRegistry.attach(diamond.target) as ProviderRegistry; - modelRegistry = modelRegistry.attach(diamond.target) as ModelRegistry; - sessionRouter = sessionRouter.attach(diamond.target) as SessionRouter; - - await marketplace.__Marketplace_init(MOR); - await modelRegistry.__ModelRegistry_init(); - await providerRegistry.__ProviderRegistry_init(); - - await sessionRouter.__SessionRouter_init(OWNER, getDefaultPools()); - - await reverter.snapshot(); - }); - - afterEach(reverter.revert); - - describe('Diamond functionality', () => { - describe('#__ModelRegistry_init', () => { - it('should revert if try to call init function twice', async () => { - const reason = 'Initializable: contract is already initialized'; - - await expect(modelRegistry.__ModelRegistry_init()).to.be.rejectedWith(reason); - }); - }); - }); - - describe('Actions', () => { - let provider: IProviderStorage.ProviderStruct; - let model: IModelStorage.ModelStruct & { - modelId: string; - }; - let bid: IBidStorage.BidStruct & { - id: string; - modelId: string; - }; - - beforeEach(async () => { - provider = await deployProvider(); - model = await deployModel(); - bid = await deployBid(model); - }); - - it('Should register', async () => { - const data = await modelRegistry.getModel(model.modelId); - - expect(await modelRegistry.models(0)).eq(model.modelId); - expect(data).deep.equal([ - model.ipfsCID, - model.fee, - model.stake, - await resolveAddress(model.owner), - model.name, - model.tags, - model.createdAt, - model.isDeleted, - ]); - }); - - it('Should error when registering with insufficient stake', async () => { - const minStake = 100n; - await modelRegistry.setModelMinimumStake(minStake); - - await expect( - modelRegistry.modelRegister(randomBytes32(), randomBytes32(), 0n, 0n, OWNER, 'a', []), - ).revertedWithCustomError(modelRegistry, 'StakeTooLow'); - }); - - it('Should error when registering with insufficient allowance', async () => { - await expect( - modelRegistry.connect(THIRD).modelRegister(randomBytes32(), randomBytes32(), 0n, 100n, THIRD, 'a', []), - ).to.rejectedWith('ERC20: insufficient allowance'); - }); - - it('Should error when register account doesnt match sender account', async () => { - await MOR.approve(modelRegistry, 100n); - - await expect( - modelRegistry.connect(THIRD).modelRegister(randomBytes32(), randomBytes32(), 0n, 100n, SECOND, 'a', []), - ).to.revertedWithCustomError(modelRegistry, 'NotOwnerOrModelOwner'); - }); - - it('Should deregister by owner', async () => { - await marketplace.connect(PROVIDER).deleteModelBid(bid.id); - - await modelRegistry.modelDeregister(model.modelId); - - expect((await modelRegistry.getModel(model.modelId)).isDeleted).to.equal(true); - expect(await modelRegistry.models(0n)).equals(model.modelId); - }); - - it('Should error if model not known by admin', async () => { - await expect(modelRegistry.modelDeregister(randomBytes32())).to.revertedWithCustomError( - modelRegistry, - 'ModelNotFound', - ); - }); - - it('Should error if caller is not owner or model owner', async () => { - await expect(modelRegistry.connect(SECOND).modelDeregister(model.modelId)).to.revertedWithCustomError( - modelRegistry, - 'NotOwnerOrModelOwner', - ); - }); - - it('Should return stake on deregister', async () => { - await marketplace.connect(PROVIDER).deleteModelBid(bid.id); - - const balanceBefore = await MOR.balanceOf(model.owner); - await modelRegistry.modelDeregister(model.modelId); - const balanceAfter = await MOR.balanceOf(model.owner); - - expect(balanceAfter - balanceBefore).eq(model.stake); - }); - - it('should error when deregistering a model that has bids', async () => { - // try deregistering model - await expect(modelRegistry.modelDeregister(model.modelId)).to.revertedWithCustomError( - modelRegistry, - 'ModelHasActiveBids', - ); - - // remove bid - await marketplace.connect(PROVIDER).deleteModelBid(bid.id); - - // deregister model - await modelRegistry.modelDeregister(model.modelId); - }); - - it('Should update existing model', async () => { - const updates = { - ipfsCID: getHex(Buffer.from('ipfs://new-ipfsaddress')), - fee: BigInt(model.fee) * 2n, - addStake: BigInt(model.stake) * 2n, - owner: PROVIDER, - name: 'Llama 3.0', - tags: ['llama', 'smart', 'angry'], - }; - await MOR.approve(modelRegistry, updates.addStake); - - await modelRegistry.modelRegister( - model.modelId, - updates.ipfsCID, - updates.fee, - updates.addStake, - updates.owner, - updates.name, - updates.tags, - ); - const providerData = await modelRegistry.getModel(model.modelId); - - expect(providerData).deep.equal([ - updates.ipfsCID, - updates.fee, - BigInt(model.stake) + updates.addStake, - await resolveAddress(updates.owner), - updates.name, - updates.tags, - model.createdAt, - model.isDeleted, - ]); - }); - - it('Should emit event on update', async () => { - const updates = { - ipfsCID: getHex(Buffer.from('ipfs://new-ipfsaddress')), - fee: BigInt(model.fee) * 2n, - addStake: BigInt(model.stake) * 2n, - owner: PROVIDER, - name: 'Llama 3.0', - tags: ['llama', 'smart', 'angry'], - }; - - await MOR.approve(modelRegistry, updates.addStake); - - await expect( - modelRegistry.modelRegister( - model.modelId, - updates.ipfsCID, - updates.fee, - updates.addStake, - updates.owner, - updates.name, - updates.tags, - ), - ).to.emit(modelRegistry, 'ModelRegisteredUpdated'); - }); - - it('should reregister model', async () => { - await marketplace.connect(PROVIDER).deleteModelBid(bid.id); - - // check indexes - expect(await modelRegistry.models(0)).eq(model.modelId); - - // deregister - await modelRegistry.modelDeregister(model.modelId); - - // check indexes - expect(await modelRegistry.models(0)).eq(model.modelId); - - // reregister - const modelId = model.modelId; - const model2 = { - ipfsCID: randomBytes32(), - fee: 100n, - stake: 100n, - owner: await resolveAddress(OWNER), - name: 'model2', - tags: ['model', '2'], - createdAt: model.createdAt, - }; - await MOR.transfer(OWNER, model2.stake); - await MOR.approve(modelRegistry, model2.stake); - await modelRegistry.modelRegister( - modelId, - model2.ipfsCID, - model2.fee, - model2.stake, - model2.owner, - model2.name, - model2.tags, - ); - // check indexes - expect(await modelRegistry.models(0)).eq(modelId); - expect(await modelRegistry.getModel(modelId)).deep.equal([ - model2.ipfsCID, - model2.fee, - model2.stake, - model2.owner, - model2.name, - model2.tags, - model2.createdAt, - false, - ]); - }); - - it('Should error if reregister model by caller is not owner or model owner', async () => { - await expect( - modelRegistry - .connect(SECOND) - .modelRegister(model.modelId, model.ipfsCID, model.fee, model.stake, model.owner, model.name, model.tags), - ).to.revertedWithCustomError(modelRegistry, 'NotOwnerOrModelOwner'); - }); - - describe('Getters', () => { - it('Should get by address', async () => { - const providerData = await modelRegistry.getModel(model.modelId); - expect(providerData).deep.equal([ - model.ipfsCID, - model.fee, - model.stake, - await resolveAddress(model.owner), - model.name, - model.tags, - model.createdAt, - model.isDeleted, - ]); - }); - }); - - describe('Min stake', () => { - it('Should set min stake', async () => { - const minStake = 100n; - await expect(modelRegistry.setModelMinimumStake(minStake)) - .to.emit(modelRegistry, 'ModelMinimumStakeSet') - .withArgs(minStake); - - expect(await modelRegistry.modelMinimumStake()).eq(minStake); - }); - it('Should error when not owner is setting min stake', async () => { - await expect(modelRegistry.connect(THIRD).setModelMinimumStake(0)).to.revertedWithCustomError( - diamond, - 'OwnableUnauthorizedAccount', - ); - }); - // it("Should get model stats", async () => { - // const stats = await modelRegistry.modelStats([model.modelId]); - - // expect(stats).deep.equal({ - // count: 0, - // totalDuration: { - // mean: 0n, - // sqSum: 0n, - // }, - // tpsScaled1000: { - // mean: 0n, - // sqSum: 0n, - // }, - // ttftMs: { - // mean: 0n, - // sqSum: 0n, - // }, - // }); - // }); - }); - }); -}); diff --git a/smart-contracts/test/diamond/ProviderRegistry.test.ts b/smart-contracts/test/diamond/ProviderRegistry.test.ts deleted file mode 100644 index 061f3dcd..00000000 --- a/smart-contracts/test/diamond/ProviderRegistry.test.ts +++ /dev/null @@ -1,444 +0,0 @@ -import { - IBidStorage, - IMarketplace__factory, - IModelRegistry__factory, - IModelStorage, - IProviderRegistry__factory, - IProviderStorage, - ISessionRouter__factory, - LumerinDiamond, - Marketplace, - ModelRegistry, - MorpheusToken, - ProviderRegistry, - SessionRouter, -} from '@ethers-v6'; -import { SignerWithAddress } from '@nomicfoundation/hardhat-ethers/signers'; -import { expect } from 'chai'; -import { Addressable, Fragment, resolveAddress } from 'ethers'; -import { ethers } from 'hardhat'; - -import { FacetAction } from '../helpers/enums'; -import { getDefaultPools } from '../helpers/pool-helper'; -import { Reverter } from '../helpers/reverter'; - -import { getHex, randomBytes32, wei } from '@/scripts/utils/utils'; -import { getCurrentBlockTime } from '@/utils/block-helper'; -import { HOUR, YEAR } from '@/utils/time'; - -describe('Provider registry', () => { - const reverter = new Reverter(); - - let OWNER: SignerWithAddress; - let SECOND: SignerWithAddress; - let THIRD: SignerWithAddress; - let PROVIDER: SignerWithAddress; - - let diamond: LumerinDiamond; - let marketplace: Marketplace; - let modelRegistry: ModelRegistry; - let providerRegistry: ProviderRegistry; - let sessionRouter: SessionRouter; - - let MOR: MorpheusToken; - - async function deployProvider(): Promise< - IProviderStorage.ProviderStruct & { - address: Addressable; - } - > { - const provider = { - endpoint: 'localhost:3334', - stake: wei(100), - createdAt: 0n, - limitPeriodEnd: 0n, - limitPeriodEarned: 0n, - isDeleted: false, - address: PROVIDER, - }; - - await MOR.transfer(PROVIDER, provider.stake * 100n); - await MOR.connect(PROVIDER).approve(sessionRouter, provider.stake); - - await providerRegistry.connect(PROVIDER).providerRegister(provider.address, provider.stake, provider.endpoint); - provider.createdAt = await getCurrentBlockTime(); - provider.limitPeriodEnd = provider.createdAt + YEAR; - - return provider; - } - - async function deployModel(): Promise< - IModelStorage.ModelStruct & { - modelId: string; - } - > { - const model = { - modelId: randomBytes32(), - ipfsCID: getHex(Buffer.from('ipfs://ipfsaddress')), - fee: 100, - stake: 100, - owner: OWNER, - name: 'Llama 2.0', - tags: ['llama', 'animal', 'cute'], - createdAt: 0n, - isDeleted: false, - }; - - await MOR.approve(modelRegistry, model.stake); - - await modelRegistry.modelRegister( - model.modelId, - model.ipfsCID, - model.fee, - model.stake, - model.owner, - model.name, - model.tags, - ); - model.createdAt = await getCurrentBlockTime(); - - return model; - } - - async function deployBid(model: any): Promise< - IBidStorage.BidStruct & { - id: string; - modelId: string; - } - > { - let bid = { - id: '', - modelId: model.modelId, - pricePerSecond: wei(0.0001), - nonce: 0, - createdAt: 0n, - deletedAt: 0, - provider: PROVIDER, - }; - - await MOR.approve(modelRegistry, 10000n * 10n ** 18n); - - bid.id = await marketplace.connect(PROVIDER).postModelBid.staticCall(bid.provider, bid.modelId, bid.pricePerSecond); - await marketplace.connect(PROVIDER).postModelBid(bid.provider, bid.modelId, bid.pricePerSecond); - - bid.createdAt = await getCurrentBlockTime(); - - // generating data for sample session - const durationSeconds = HOUR; - const totalCost = bid.pricePerSecond * durationSeconds; - const totalSupply = await sessionRouter.totalMORSupply(await getCurrentBlockTime()); - const todaysBudget = await sessionRouter.getTodaysBudget(await getCurrentBlockTime()); - - const expectedSession = { - durationSeconds, - totalCost, - pricePerSecond: bid.pricePerSecond, - user: SECOND, - provider: bid.provider, - modelId: bid.modelId, - bidId: bid.id, - stake: (totalCost * totalSupply) / todaysBudget, - }; - - await MOR.transfer(SECOND, expectedSession.stake); - await MOR.connect(SECOND).approve(modelRegistry, expectedSession.stake); - - return bid; - } - - before('setup', async () => { - [OWNER, SECOND, THIRD, PROVIDER] = await ethers.getSigners(); - - const LinearDistributionIntervalDecrease = await ethers.getContractFactory('LinearDistributionIntervalDecrease'); - const linearDistributionIntervalDecrease = await LinearDistributionIntervalDecrease.deploy(); - - const [LumerinDiamond, Marketplace, ModelRegistry, ProviderRegistry, SessionRouter, MorpheusToken] = - await Promise.all([ - ethers.getContractFactory('LumerinDiamond'), - ethers.getContractFactory('Marketplace'), - ethers.getContractFactory('ModelRegistry'), - ethers.getContractFactory('ProviderRegistry'), - ethers.getContractFactory('SessionRouter', { - libraries: { - LinearDistributionIntervalDecrease: linearDistributionIntervalDecrease, - }, - }), - ethers.getContractFactory('MorpheusToken'), - ]); - - [diamond, marketplace, modelRegistry, providerRegistry, sessionRouter, MOR] = await Promise.all([ - LumerinDiamond.deploy(), - Marketplace.deploy(), - ModelRegistry.deploy(), - ProviderRegistry.deploy(), - SessionRouter.deploy(), - MorpheusToken.deploy(), - ]); - - await diamond.__LumerinDiamond_init(); - - await diamond['diamondCut((address,uint8,bytes4[])[])']([ - { - facetAddress: marketplace, - action: FacetAction.Add, - functionSelectors: IMarketplace__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: providerRegistry, - action: FacetAction.Add, - functionSelectors: IProviderRegistry__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: sessionRouter, - action: FacetAction.Add, - functionSelectors: ISessionRouter__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: modelRegistry, - action: FacetAction.Add, - functionSelectors: IModelRegistry__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - ]); - - marketplace = marketplace.attach(diamond.target) as Marketplace; - providerRegistry = providerRegistry.attach(diamond.target) as ProviderRegistry; - modelRegistry = modelRegistry.attach(diamond.target) as ModelRegistry; - sessionRouter = sessionRouter.attach(diamond.target) as SessionRouter; - - await marketplace.__Marketplace_init(MOR); - await modelRegistry.__ModelRegistry_init(); - await providerRegistry.__ProviderRegistry_init(); - - await sessionRouter.__SessionRouter_init(OWNER, getDefaultPools()); - - await reverter.snapshot(); - }); - - afterEach(reverter.revert); - - describe('Diamond functionality', () => { - describe('#__ProviderRegistry_init', () => { - it('should revert if try to call init function twice', async () => { - const reason = 'Initializable: contract is already initialized'; - - await expect(providerRegistry.__ProviderRegistry_init()).to.be.rejectedWith(reason); - }); - }); - }); - - describe('Actions', () => { - let provider: IProviderStorage.ProviderStruct; - let model: IModelStorage.ModelStruct & { - modelId: string; - }; - let bid: IBidStorage.BidStruct & { - id: string; - modelId: string; - }; - - beforeEach(async () => { - provider = await deployProvider(); - model = await deployModel(); - bid = await deployBid(model); - }); - - it('Should register', async () => { - await providerRegistry.connect(SECOND).providerRegister(SECOND, provider.stake, provider.endpoint); - - const data = await providerRegistry.getProvider(SECOND); - - expect(data).deep.equal([ - provider.endpoint, - provider.stake, - await getCurrentBlockTime(), - (await getCurrentBlockTime()) + YEAR, - provider.limitPeriodEarned, - false, - ]); - }); - - it('Should error when registering with insufficient stake', async () => { - const minStake = 100; - await providerRegistry.providerSetMinStake(minStake); - - await expect(providerRegistry.providerRegister(SECOND, minStake - 1, 'endpoint')).to.be.revertedWithCustomError( - providerRegistry, - 'StakeTooLow', - ); - }); - - it('Should error when registering with insufficient allowance', async () => { - // await catchError(MOR.abi, "ERC20InsufficientAllowance", async () => { - // await providerRegistry.providerRegister([PROVIDER, 100n, "endpoint"]); - // }); - await expect(providerRegistry.connect(THIRD).providerRegister(THIRD, 100n, 'endpoint')).to.be.revertedWith( - 'ERC20: insufficient allowance', - ); - }); - - it('Should error when register account doesnt match sender account', async () => { - await expect( - providerRegistry.connect(PROVIDER).providerRegister(THIRD, 100n, 'endpoint'), - ).to.be.revertedWithCustomError(providerRegistry, 'NotOwnerOrProvider'); - }); - - describe('Deregister', () => { - it('Should deregister by provider', async () => { - await marketplace.connect(PROVIDER).deleteModelBid(bid.id); - - await expect(providerRegistry.connect(PROVIDER).providerDeregister(PROVIDER)) - .to.emit(providerRegistry, 'ProviderDeregistered') - .withArgs(PROVIDER); - - expect((await providerRegistry.getProvider(PROVIDER)).isDeleted).to.equal(true); - }); - - it('Should deregister by admin', async () => { - await marketplace.connect(PROVIDER).deleteModelBid(bid.id); - - await expect(providerRegistry.providerDeregister(PROVIDER)) - .to.emit(providerRegistry, 'ProviderDeregistered') - .withArgs(PROVIDER); - }); - - it('Should return stake on deregister', async () => { - await marketplace.connect(PROVIDER).deleteModelBid(bid.id); - - const balanceBefore = await MOR.balanceOf(PROVIDER); - await providerRegistry.connect(PROVIDER).providerDeregister(PROVIDER); - const balanceAfter = await MOR.balanceOf(PROVIDER); - expect(balanceAfter - balanceBefore).eq(provider.stake); - }); - - it('should error when deregistering a model that has bids', async () => { - // try deregistering model - await expect(providerRegistry.connect(PROVIDER).providerDeregister(PROVIDER)).to.be.revertedWithCustomError( - providerRegistry, - 'ProviderHasActiveBids', - ); - - // remove bid - await marketplace.connect(PROVIDER).deleteModelBid(bid.id); - // deregister model - await providerRegistry.connect(PROVIDER).providerDeregister(PROVIDER); - }); - - it('Should correctly reregister provider', async () => { - await marketplace.connect(PROVIDER).deleteModelBid(bid.id); - - // deregister - await providerRegistry.connect(PROVIDER).providerDeregister(PROVIDER); - // check indexes - const provider2 = { - endpoint: 'new-endpoint-2', - stake: 123n, - limitPeriodEarned: provider.limitPeriodEarned, - limitPeriodEnd: provider.limitPeriodEnd, - createdAt: provider.createdAt, - }; - // register again - await MOR.transfer(PROVIDER, provider2.stake); - await MOR.connect(PROVIDER).approve(providerRegistry, provider2.stake); - await providerRegistry.connect(PROVIDER).providerRegister(PROVIDER, provider2.stake, provider2.endpoint); - // check record - expect(await providerRegistry.getProvider(PROVIDER)).deep.equal([ - provider2.endpoint, - provider2.stake, - provider2.createdAt, - provider2.limitPeriodEnd, - provider2.limitPeriodEarned, - false, - ]); - }); - - it('should error if caller is not an owner or provider', async () => { - await marketplace.connect(PROVIDER).deleteModelBid(bid.id); - - await expect(providerRegistry.connect(SECOND).providerDeregister(PROVIDER)).to.revertedWithCustomError( - providerRegistry, - 'NotOwnerOrProvider', - ); - }); - - it('should error if provider is not exists', async () => { - await marketplace.connect(PROVIDER).deleteModelBid(bid.id); - - await expect(providerRegistry.connect(SECOND).providerDeregister(SECOND)).to.revertedWithCustomError( - providerRegistry, - 'ProviderNotFound', - ); - }); - }); - - it('Should update stake and url', async () => { - const updates = { - addStake: BigInt(provider.stake) * 2n, - endpoint: 'new-endpoint', - }; - await MOR.connect(PROVIDER).approve(providerRegistry, updates.addStake); - - await providerRegistry.connect(PROVIDER).providerRegister(PROVIDER, updates.addStake, updates.endpoint); - - const providerData = await providerRegistry.getProvider(PROVIDER); - expect(providerData).deep.equal([ - updates.endpoint, - BigInt(provider.stake) + updates.addStake, - provider.createdAt, - provider.limitPeriodEnd, - provider.limitPeriodEarned, - provider.isDeleted, - ]); - }); - - it('Should emit event on update', async () => { - const updates = { - addStake: BigInt(provider.stake) * 2n, - endpoint: 'new-endpoint', - }; - await MOR.connect(PROVIDER).approve(providerRegistry, updates.addStake); - await expect(providerRegistry.connect(PROVIDER).providerRegister(PROVIDER, updates.addStake, updates.endpoint)) - .to.emit(providerRegistry, 'ProviderRegisteredUpdated') - .withArgs(PROVIDER); - }); - - describe('Getters', () => { - it('Should get by address', async () => { - const providerData = await providerRegistry.getProvider(PROVIDER); - expect(providerData).deep.equal([ - provider.endpoint, - provider.stake, - provider.createdAt, - provider.limitPeriodEnd, - provider.limitPeriodEarned, - provider.isDeleted, - ]); - }); - }); - - describe('Min stake', () => { - it('Should set min stake', async () => { - const minStake = 100n; - await expect(providerRegistry.providerSetMinStake(minStake)) - .to.emit(providerRegistry, 'ProviderMinStakeUpdated') - .withArgs(minStake); - - expect(await providerRegistry.providerMinimumStake()).eq(minStake); - }); - - it('Should error when not owner is setting min stake', async () => { - await expect(providerRegistry.connect(SECOND).providerSetMinStake(100)).to.be.revertedWithCustomError( - diamond, - 'OwnableUnauthorizedAccount', - ); - }); - }); - }); -}); diff --git a/smart-contracts/test/diamond/SessionRouter/closeSession.test.ts b/smart-contracts/test/diamond/SessionRouter/closeSession.test.ts deleted file mode 100644 index eaf4414f..00000000 --- a/smart-contracts/test/diamond/SessionRouter/closeSession.test.ts +++ /dev/null @@ -1,512 +0,0 @@ -import { - IBidStorage, - IMarketplace__factory, - IModelRegistry__factory, - IModelStorage, - IProviderRegistry__factory, - IProviderStorage, - ISessionRouter__factory, - LumerinDiamond, - Marketplace, - ModelRegistry, - MorpheusToken, - ProviderRegistry, - SessionRouter, -} from '@ethers-v6'; -import { SignerWithAddress } from '@nomicfoundation/hardhat-ethers/signers'; -import { expect } from 'chai'; -import { Addressable, Fragment } from 'ethers'; -import { ethers } from 'hardhat'; - -import { MAX_UINT8 } from '@/scripts/utils/constants'; -import { getHex, randomBytes32, wei } from '@/scripts/utils/utils'; -import { FacetAction } from '@/test/helpers/enums'; -import { getDefaultPools } from '@/test/helpers/pool-helper'; -import { Reverter } from '@/test/helpers/reverter'; -import { getCurrentBlockTime, setTime } from '@/utils/block-helper'; -import { getProviderApproval, getReport } from '@/utils/provider-helper'; -import { DAY, HOUR, YEAR } from '@/utils/time'; - -describe('Session closeout', () => { - const reverter = new Reverter(); - - let OWNER: SignerWithAddress; - let SECOND: SignerWithAddress; - let THIRD: SignerWithAddress; - let PROVIDER: SignerWithAddress; - - let diamond: LumerinDiamond; - let marketplace: Marketplace; - let modelRegistry: ModelRegistry; - let providerRegistry: ProviderRegistry; - let sessionRouter: SessionRouter; - - let MOR: MorpheusToken; - - async function deployProvider(): Promise< - IProviderStorage.ProviderStruct & { - address: Addressable; - } - > { - const provider = { - endpoint: 'localhost:3334', - stake: wei(100), - createdAt: 0n, - limitPeriodEnd: 0n, - limitPeriodEarned: 0n, - isDeleted: false, - address: PROVIDER, - }; - - await MOR.transfer(PROVIDER, provider.stake * 100n); - await MOR.connect(PROVIDER).approve(sessionRouter, provider.stake); - - await providerRegistry.connect(PROVIDER).providerRegister(provider.address, provider.stake, provider.endpoint); - provider.createdAt = await getCurrentBlockTime(); - provider.limitPeriodEnd = provider.createdAt + YEAR; - - return provider; - } - - async function deployModel(): Promise< - IModelStorage.ModelStruct & { - modelId: string; - } - > { - const model = { - modelId: randomBytes32(), - ipfsCID: getHex(Buffer.from('ipfs://ipfsaddress')), - fee: 100, - stake: 100, - owner: OWNER, - name: 'Llama 2.0', - tags: ['llama', 'animal', 'cute'], - createdAt: 0n, - isDeleted: false, - }; - - await MOR.approve(modelRegistry, model.stake); - - await modelRegistry.modelRegister( - model.modelId, - model.ipfsCID, - model.fee, - model.stake, - model.owner, - model.name, - model.tags, - ); - model.createdAt = await getCurrentBlockTime(); - - return model; - } - - async function deployBid(model: any): Promise< - [ - IBidStorage.BidStruct & { - id: string; - modelId: string; - }, - { - durationSeconds: bigint; - totalCost: bigint; - pricePerSecond: bigint; - user: SignerWithAddress; - provider: SignerWithAddress; - modelId: any; - bidId: string; - stake: bigint; - }, - ] - > { - let bid = { - id: '', - modelId: model.modelId, - pricePerSecond: wei(0.0001), - nonce: 0, - createdAt: 0n, - deletedAt: 0, - provider: PROVIDER, - }; - - await MOR.approve(modelRegistry, 10000n * 10n ** 18n); - - bid.id = await marketplace.connect(PROVIDER).postModelBid.staticCall(bid.provider, bid.modelId, bid.pricePerSecond); - await marketplace.connect(PROVIDER).postModelBid(bid.provider, bid.modelId, bid.pricePerSecond); - - bid.createdAt = await getCurrentBlockTime(); - - // generating data for sample session - const durationSeconds = HOUR; - const totalCost = bid.pricePerSecond * durationSeconds; - const totalSupply = await sessionRouter.totalMORSupply(await getCurrentBlockTime()); - const todaysBudget = await sessionRouter.getTodaysBudget(await getCurrentBlockTime()); - - const expectedSession = { - durationSeconds, - totalCost, - pricePerSecond: bid.pricePerSecond, - user: SECOND, - provider: bid.provider, - modelId: bid.modelId, - bidId: bid.id, - stake: (totalCost * totalSupply) / todaysBudget, - }; - - await MOR.transfer(SECOND, expectedSession.stake); - await MOR.connect(SECOND).approve(modelRegistry, expectedSession.stake); - - return [bid, expectedSession]; - } - - before('setup', async () => { - [OWNER, SECOND, THIRD, PROVIDER] = await ethers.getSigners(); - - const LinearDistributionIntervalDecrease = await ethers.getContractFactory('LinearDistributionIntervalDecrease'); - const linearDistributionIntervalDecrease = await LinearDistributionIntervalDecrease.deploy(); - - const [LumerinDiamond, Marketplace, ModelRegistry, ProviderRegistry, SessionRouter, MorpheusToken] = - await Promise.all([ - ethers.getContractFactory('LumerinDiamond'), - ethers.getContractFactory('Marketplace'), - ethers.getContractFactory('ModelRegistry'), - ethers.getContractFactory('ProviderRegistry'), - ethers.getContractFactory('SessionRouter', { - libraries: { - LinearDistributionIntervalDecrease: linearDistributionIntervalDecrease, - }, - }), - ethers.getContractFactory('MorpheusToken'), - ]); - - [diamond, marketplace, modelRegistry, providerRegistry, sessionRouter, MOR] = await Promise.all([ - LumerinDiamond.deploy(), - Marketplace.deploy(), - ModelRegistry.deploy(), - ProviderRegistry.deploy(), - SessionRouter.deploy(), - MorpheusToken.deploy(), - ]); - - await diamond.__LumerinDiamond_init(); - - await diamond['diamondCut((address,uint8,bytes4[])[])']([ - { - facetAddress: marketplace, - action: FacetAction.Add, - functionSelectors: IMarketplace__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: providerRegistry, - action: FacetAction.Add, - functionSelectors: IProviderRegistry__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: sessionRouter, - action: FacetAction.Add, - functionSelectors: ISessionRouter__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: modelRegistry, - action: FacetAction.Add, - functionSelectors: IModelRegistry__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - ]); - - marketplace = marketplace.attach(diamond.target) as Marketplace; - providerRegistry = providerRegistry.attach(diamond.target) as ProviderRegistry; - modelRegistry = modelRegistry.attach(diamond.target) as ModelRegistry; - sessionRouter = sessionRouter.attach(diamond.target) as SessionRouter; - - await marketplace.__Marketplace_init(MOR); - await modelRegistry.__ModelRegistry_init(); - await providerRegistry.__ProviderRegistry_init(); - - await sessionRouter.__SessionRouter_init(OWNER, getDefaultPools()); - - await reverter.snapshot(); - }); - - afterEach(reverter.revert); - - describe('Actions', () => { - let provider: IProviderStorage.ProviderStruct; - let model: IModelStorage.ModelStruct & { - modelId: string; - }; - let bid: IBidStorage.BidStruct & { - id: string; - modelId: string; - }; - let session: { - durationSeconds: bigint; - totalCost: bigint; - pricePerSecond: bigint; - user: SignerWithAddress; - provider: SignerWithAddress; - modelId: any; - bidId: string; - stake: bigint; - }; - - beforeEach(async () => { - provider = await deployProvider(); - model = await deployModel(); - [bid, session] = await deployBid(model); - }); - - it('should open short (<1D) session and close after expiration', async () => { - // open session - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - - const sessionId = await sessionRouter.connect(SECOND).openSession.staticCall(session.stake, msg, signature); - await sessionRouter.connect(SECOND).openSession(session.stake, msg, signature); - - await setTime(Number((await getCurrentBlockTime()) + session.durationSeconds * 2n)); - - const userBalanceBefore = await MOR.balanceOf(SECOND); - const providerBalanceBefore = await MOR.balanceOf(PROVIDER); - - // close session - const report = await getReport(PROVIDER, sessionId, 10, 1000); - await sessionRouter.connect(SECOND).closeSession(report.msg, report.signature); - - // verify session is closed without dispute - const sessionData = await sessionRouter.sessions(sessionId); - expect(sessionData.closeoutType).to.equal(0n); - - // verify balances - const userBalanceAfter = await MOR.balanceOf(SECOND); - const providerBalanceAfter = await MOR.balanceOf(PROVIDER); - - const userStakeReturned = userBalanceAfter - userBalanceBefore; - const providerEarned = providerBalanceAfter - providerBalanceBefore; - - const totalPrice = (sessionData.endsAt - sessionData.openedAt) * session.pricePerSecond; - - expect(userStakeReturned).to.closeTo(0, Number(session.pricePerSecond) * 5); - expect(providerEarned).to.closeTo(totalPrice, 1); - }); - - it('should open short (<1D) session and close early', async () => { - // open session - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - const sessionId = await sessionRouter.connect(SECOND).openSession.staticCall(session.stake, msg, signature); - await sessionRouter.connect(SECOND).openSession(session.stake, msg, signature); - - await setTime(Number((await getCurrentBlockTime()) + session.durationSeconds / 2n - 1n)); - - const userBalanceBefore = await MOR.balanceOf(SECOND); - const providerBalanceBefore = await MOR.balanceOf(PROVIDER); - - // close session - const report = await getReport(PROVIDER, sessionId, 10, 1000); - await sessionRouter.connect(SECOND).closeSession(report.msg, report.signature); - - // verify session is closed without dispute - const sessionData = await sessionRouter.sessions(sessionId); - expect(sessionData.closeoutType).to.equal(0n); - - // verify balances - const userBalanceAfter = await MOR.balanceOf(SECOND); - const providerBalanceAfter = await MOR.balanceOf(PROVIDER); - - const userStakeReturned = userBalanceAfter - userBalanceBefore; - const providerEarned = providerBalanceAfter - providerBalanceBefore; - - expect(userStakeReturned).to.closeTo(session.stake / 2n, Number(session.pricePerSecond) * 5); - expect(providerEarned).to.closeTo(session.totalCost / 2n, 1); - }); - - it('should open and close early with user report - dispute', async () => { - // open session - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - const sessionId = await sessionRouter.connect(SECOND).openSession.staticCall(session.stake, msg, signature); - await sessionRouter.connect(SECOND).openSession(session.stake, msg, signature); - - // wait half of the session - await setTime(Number((await getCurrentBlockTime()) + session.durationSeconds / 2n - 1n)); - - const userBalanceBefore = await MOR.balanceOf(SECOND); - const providerBalanceBefore = await MOR.balanceOf(PROVIDER); - - // close session with user signature - const report = await getReport(SECOND, sessionId, 10, 1000); - await sessionRouter.connect(SECOND).closeSession(report.msg, report.signature); - - // verify session is closed with dispute - const sessionData = await sessionRouter.sessions(sessionId); - const totalCost = sessionData.pricePerSecond * (sessionData.closedAt - sessionData.openedAt); - - // verify balances - const userBalanceAfter = await MOR.balanceOf(SECOND); - const providerBalanceAfter = await MOR.balanceOf(PROVIDER); - - const claimableProvider = await sessionRouter.getProviderClaimableBalance(sessionData.id); - - const [userAvail, userHold] = await sessionRouter.withdrawableUserStake(sessionData.user, MAX_UINT8); - - expect(sessionData.closeoutType).to.equal(1n); - expect(providerBalanceAfter - providerBalanceBefore).to.equal(0n); - expect(claimableProvider).to.equal(0n); - expect(session.stake / 2n).to.closeTo(userBalanceAfter - userBalanceBefore, 1); - expect(userAvail).to.equal(0n); - expect(userHold).to.closeTo(userBalanceAfter - userBalanceBefore, 1); - - // verify provider balance after dispute is released - await setTime(Number((await getCurrentBlockTime()) + DAY)); - const claimableProvider2 = await sessionRouter.getProviderClaimableBalance(sessionId); - expect(claimableProvider2).to.equal(totalCost); - - // claim provider balance - await sessionRouter.claimProviderBalance(sessionId, claimableProvider2); - - // verify provider balance after claim - const providerBalanceAfterClaim = await MOR.balanceOf(PROVIDER); - const providerClaimed = providerBalanceAfterClaim - providerBalanceAfter; - expect(providerClaimed).to.equal(totalCost); - }); - - it('should error when not a user trying to close', async () => { - // open session - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - const sessionId = await sessionRouter.connect(SECOND).openSession.staticCall(session.stake, msg, signature); - await sessionRouter.connect(SECOND).openSession(session.stake, msg, signature); - - // wait half of the session - await setTime(Number((await getCurrentBlockTime()) + session.durationSeconds / 2n - 1n)); - - // close session with user signature - const report = await getReport(SECOND, sessionId, 10, 10); - - await expect( - sessionRouter.connect(THIRD).closeSession(report.msg, report.signature), - ).to.be.revertedWithCustomError(sessionRouter, 'NotOwnerOrUser'); - }); - - it('should limit reward by stake amount', async () => { - // expected bid - const expectedBid = { - id: '', - providerAddr: await PROVIDER.getAddress(), - modelId: model.modelId, - pricePerSecond: wei('0.1'), - nonce: 0n, - createdAt: 0n, - deletedAt: 0n, - }; - - // add single bid - const postBidId = await marketplace - .connect(PROVIDER) - .postModelBid.staticCall(expectedBid.providerAddr, expectedBid.modelId, expectedBid.pricePerSecond); - await marketplace - .connect(PROVIDER) - .postModelBid(expectedBid.providerAddr, expectedBid.modelId, expectedBid.pricePerSecond); - - expectedBid.id = postBidId; - expectedBid.createdAt = await getCurrentBlockTime(); - - // calculate data for session opening - const totalCost = BigInt(provider.stake) * 2n; - const durationSeconds = totalCost / expectedBid.pricePerSecond; - const totalSupply = await sessionRouter.totalMORSupply(await getCurrentBlockTime()); - const todaysBudget = await sessionRouter.getTodaysBudget(await getCurrentBlockTime()); - - const expectedSession = { - durationSeconds, - totalCost, - pricePerSecond: expectedBid.pricePerSecond, - user: await SECOND.getAddress(), - provider: expectedBid.providerAddr, - modelId: expectedBid.modelId, - bidId: expectedBid.id, - stake: (totalCost * totalSupply) / todaysBudget, - }; - - // set user balance and approve funds - await MOR.transfer(SECOND, expectedSession.stake); - await MOR.connect(SECOND).approve(modelRegistry, expectedSession.stake); - - // open session - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), expectedSession.bidId); - const sessionId = await sessionRouter - .connect(SECOND) - .openSession.staticCall(expectedSession.stake, msg, signature); - await sessionRouter.connect(SECOND).openSession(expectedSession.stake, msg, signature); - - // wait till session ends - await setTime(Number((await getCurrentBlockTime()) + expectedSession.durationSeconds)); - - const providerBalanceBefore = await MOR.balanceOf(PROVIDER); - // close session without dispute - const report = await getReport(PROVIDER, sessionId, 10, 1000); - await sessionRouter.connect(SECOND).closeSession(report.msg, report.signature); - - const providerBalanceAfter = await MOR.balanceOf(PROVIDER); - - const providerEarned = providerBalanceAfter - providerBalanceBefore; - - expect(providerEarned).to.equal(provider.stake); - - // check provider record if earning was updated - const providerRecord = await providerRegistry.getProvider(PROVIDER); - expect(providerRecord.limitPeriodEarned).to.equal(provider.stake); - }); - - it('should error if session is not exists', async () => { - // close session - const report = await getReport(PROVIDER, randomBytes32(), 10, 1000); - - await expect(sessionRouter.connect(SECOND).closeSession(report.msg, report.signature)).to.revertedWithCustomError( - sessionRouter, - 'SessionNotFound', - ); - }); - - it('should error if session is already closed', async () => { - // open session - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - - const sessionId = await sessionRouter.connect(SECOND).openSession.staticCall(session.stake, msg, signature); - await sessionRouter.connect(SECOND).openSession(session.stake, msg, signature); - - await setTime(Number((await getCurrentBlockTime()) + session.durationSeconds * 2n)); - - // close session - const report = await getReport(PROVIDER, sessionId, 10, 1000); - await sessionRouter.connect(SECOND).closeSession(report.msg, report.signature); - - await expect(sessionRouter.connect(SECOND).closeSession(report.msg, report.signature)).to.revertedWithCustomError( - sessionRouter, - 'SessionAlreadyClosed', - ); - }); - - it('should error when approval expired', async () => { - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - // open session - const sessionId = await sessionRouter.connect(SECOND).openSession.staticCall(session.stake, msg, signature); - await sessionRouter.connect(SECOND).openSession(session.stake, msg, signature); - // close session - const report = await getReport(PROVIDER, sessionId, 10, 1000); - const ttl = await sessionRouter.SIGNATURE_TTL(); - await setTime(Number((await getCurrentBlockTime()) + ttl + 1n)); - await expect( - sessionRouter.connect(SECOND).closeSession(report.msg, report.signature), - ).to.be.revertedWithCustomError(sessionRouter, 'SignatureExpired'); - }); - - it('should reset provider limitPeriodEarned after period', async () => {}); - - it('should error with WithdrawableBalanceLimitByStakeReached() if claiming more that stake for a period', async () => {}); - }); -}); diff --git a/smart-contracts/test/diamond/SessionRouter/openSession.test.ts b/smart-contracts/test/diamond/SessionRouter/openSession.test.ts deleted file mode 100644 index 47d75d70..00000000 --- a/smart-contracts/test/diamond/SessionRouter/openSession.test.ts +++ /dev/null @@ -1,612 +0,0 @@ -import { - IBidStorage, - IMarketplace__factory, - IModelRegistry__factory, - IModelStorage, - IProviderRegistry__factory, - IProviderStorage, - ISessionRouter__factory, - LumerinDiamond, - Marketplace, - ModelRegistry, - MorpheusToken, - ProviderRegistry, - SessionRouter, -} from '@ethers-v6'; -import { SignerWithAddress } from '@nomicfoundation/hardhat-ethers/signers'; -import { expect } from 'chai'; -import { Addressable, Fragment, randomBytes, resolveAddress } from 'ethers'; -import { ethers } from 'hardhat'; - -import { getHex, randomBytes32, startOfTheDay, wei } from '@/scripts/utils/utils'; -import { FacetAction } from '@/test/helpers/enums'; -import { getDefaultPools } from '@/test/helpers/pool-helper'; -import { Reverter } from '@/test/helpers/reverter'; -import { getCurrentBlockTime, setTime } from '@/utils/block-helper'; -import { getProviderApproval, getReport } from '@/utils/provider-helper'; -import { DAY, HOUR, YEAR } from '@/utils/time'; - -describe('session actions', () => { - const reverter = new Reverter(); - - let OWNER: SignerWithAddress; - let SECOND: SignerWithAddress; - let THIRD: SignerWithAddress; - let PROVIDER: SignerWithAddress; - - let diamond: LumerinDiamond; - let marketplace: Marketplace; - let modelRegistry: ModelRegistry; - let providerRegistry: ProviderRegistry; - let sessionRouter: SessionRouter; - - let MOR: MorpheusToken; - - async function deployProvider(): Promise< - IProviderStorage.ProviderStruct & { - address: Addressable; - } - > { - const provider = { - endpoint: 'localhost:3334', - stake: wei(100), - createdAt: 0n, - limitPeriodEnd: 0n, - limitPeriodEarned: 0n, - isDeleted: false, - address: PROVIDER, - }; - - await MOR.transfer(PROVIDER, provider.stake * 100n); - await MOR.connect(PROVIDER).approve(sessionRouter, provider.stake); - - await providerRegistry.connect(PROVIDER).providerRegister(provider.address, provider.stake, provider.endpoint); - provider.createdAt = await getCurrentBlockTime(); - provider.limitPeriodEnd = provider.createdAt + YEAR; - - return provider; - } - - async function deployModel(): Promise< - IModelStorage.ModelStruct & { - modelId: string; - } - > { - const model = { - modelId: randomBytes32(), - ipfsCID: getHex(Buffer.from('ipfs://ipfsaddress')), - fee: 100, - stake: 100, - owner: OWNER, - name: 'Llama 2.0', - tags: ['llama', 'animal', 'cute'], - createdAt: 0n, - isDeleted: false, - }; - - await MOR.approve(modelRegistry, model.stake); - - await modelRegistry.modelRegister( - model.modelId, - model.ipfsCID, - model.fee, - model.stake, - model.owner, - model.name, - model.tags, - ); - model.createdAt = await getCurrentBlockTime(); - - return model; - } - - async function deployBid(model: any): Promise< - [ - IBidStorage.BidStruct & { - id: string; - modelId: string; - }, - { - durationSeconds: bigint; - totalCost: bigint; - pricePerSecond: bigint; - user: SignerWithAddress; - provider: SignerWithAddress; - modelId: any; - bidId: string; - stake: bigint; - }, - ] - > { - let bid = { - id: '', - modelId: model.modelId, - pricePerSecond: wei(0.0001), - nonce: 0, - createdAt: 0n, - deletedAt: 0, - provider: PROVIDER, - }; - - await MOR.approve(modelRegistry, 10000n * 10n ** 18n); - - bid.id = await marketplace.connect(PROVIDER).postModelBid.staticCall(bid.provider, bid.modelId, bid.pricePerSecond); - await marketplace.connect(PROVIDER).postModelBid(bid.provider, bid.modelId, bid.pricePerSecond); - - bid.createdAt = await getCurrentBlockTime(); - - // generating data for sample session - const durationSeconds = HOUR; - const totalCost = bid.pricePerSecond * durationSeconds; - const totalSupply = await sessionRouter.totalMORSupply(await getCurrentBlockTime()); - const todaysBudget = await sessionRouter.getTodaysBudget(await getCurrentBlockTime()); - - const expectedSession = { - durationSeconds, - totalCost, - pricePerSecond: bid.pricePerSecond, - user: SECOND, - provider: bid.provider, - modelId: bid.modelId, - bidId: bid.id, - stake: (totalCost * totalSupply) / todaysBudget, - }; - - await MOR.transfer(SECOND, expectedSession.stake); - await MOR.connect(SECOND).approve(modelRegistry, expectedSession.stake); - - return [bid, expectedSession]; - } - - async function openSession(session: any) { - // open session - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - const sessionId = await sessionRouter.connect(SECOND).openSession.staticCall(session.stake, msg, signature); - await sessionRouter.connect(SECOND).openSession(session.stake, msg, signature); - - return sessionId; - } - - before('setup', async () => { - [OWNER, SECOND, THIRD, PROVIDER] = await ethers.getSigners(); - - const LinearDistributionIntervalDecrease = await ethers.getContractFactory('LinearDistributionIntervalDecrease'); - const linearDistributionIntervalDecrease = await LinearDistributionIntervalDecrease.deploy(); - - const [LumerinDiamond, Marketplace, ModelRegistry, ProviderRegistry, SessionRouter, MorpheusToken] = - await Promise.all([ - ethers.getContractFactory('LumerinDiamond'), - ethers.getContractFactory('Marketplace'), - ethers.getContractFactory('ModelRegistry'), - ethers.getContractFactory('ProviderRegistry'), - ethers.getContractFactory('SessionRouter', { - libraries: { - LinearDistributionIntervalDecrease: linearDistributionIntervalDecrease, - }, - }), - ethers.getContractFactory('MorpheusToken'), - ]); - - [diamond, marketplace, modelRegistry, providerRegistry, sessionRouter, MOR] = await Promise.all([ - LumerinDiamond.deploy(), - Marketplace.deploy(), - ModelRegistry.deploy(), - ProviderRegistry.deploy(), - SessionRouter.deploy(), - MorpheusToken.deploy(), - ]); - - await diamond.__LumerinDiamond_init(); - - await diamond['diamondCut((address,uint8,bytes4[])[])']([ - { - facetAddress: marketplace, - action: FacetAction.Add, - functionSelectors: IMarketplace__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: providerRegistry, - action: FacetAction.Add, - functionSelectors: IProviderRegistry__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: sessionRouter, - action: FacetAction.Add, - functionSelectors: ISessionRouter__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: modelRegistry, - action: FacetAction.Add, - functionSelectors: IModelRegistry__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - ]); - - marketplace = marketplace.attach(diamond.target) as Marketplace; - providerRegistry = providerRegistry.attach(diamond.target) as ProviderRegistry; - modelRegistry = modelRegistry.attach(diamond.target) as ModelRegistry; - sessionRouter = sessionRouter.attach(diamond.target) as SessionRouter; - - await marketplace.__Marketplace_init(MOR); - await modelRegistry.__ModelRegistry_init(); - await providerRegistry.__ProviderRegistry_init(); - - await sessionRouter.__SessionRouter_init(OWNER, getDefaultPools()); - - await reverter.snapshot(); - }); - - afterEach(reverter.revert); - - describe('Actions', () => { - let provider: IProviderStorage.ProviderStruct; - let model: IModelStorage.ModelStruct & { - modelId: string; - }; - let bid: IBidStorage.BidStruct & { - id: string; - modelId: string; - }; - let session: { - durationSeconds: bigint; - totalCost: bigint; - pricePerSecond: bigint; - user: SignerWithAddress; - provider: SignerWithAddress; - modelId: any; - bidId: string; - stake: bigint; - }; - - beforeEach(async () => { - provider = await deployProvider(); - model = await deployModel(); - [bid, session] = await deployBid(model); - }); - - describe('positive cases', () => { - it('should open session without error', async () => { - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - const sessionId = await sessionRouter.connect(SECOND).openSession.staticCall(session.stake, msg, signature); - await sessionRouter.connect(SECOND).openSession(session.stake, msg, signature); - - expect(sessionId).to.be.a('string'); - }); - - it('should emit SessionOpened event', async () => { - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - - const sessionId = await sessionRouter.connect(SECOND).openSession.staticCall(session.stake, msg, signature); - await expect(sessionRouter.connect(SECOND).openSession(session.stake, msg, signature)) - .to.emit(sessionRouter, 'SessionOpened') - .withArgs(session.user, sessionId, session.provider); - }); - - it('should verify session fields after opening', async () => { - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - const sessionId = await sessionRouter.connect(SECOND).openSession.staticCall(session.stake, msg, signature); - await sessionRouter.connect(SECOND).openSession(session.stake, msg, signature); - - const sessionData = await sessionRouter.sessions(sessionId); - const createdAt = await getCurrentBlockTime(); - - expect(sessionData).to.deep.equal([ - sessionId, - await resolveAddress(session.user), - await resolveAddress(session.provider), - session.modelId, - session.bidId, - session.stake, - session.pricePerSecond, - getHex(Buffer.from(''), 0), - 0n, - 0n, - createdAt, - sessionData.endsAt, // skipped in this test - 0n, - ]); - }); - - it('should verify balances after opening', async () => { - const srBefore = await MOR.balanceOf(sessionRouter); - const userBefore = await MOR.balanceOf(SECOND); - - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - await sessionRouter.connect(SECOND).openSession(session.stake, msg, signature); - - const srAfter = await MOR.balanceOf(sessionRouter); - const userAfter = await MOR.balanceOf(SECOND); - - expect(srAfter - srBefore).to.equal(session.stake); - expect(userBefore - userAfter).to.equal(session.stake); - }); - - it('should allow opening two sessions in the same block', async () => { - await MOR.transfer(SECOND, session.stake * 2n); - await MOR.connect(SECOND).approve(sessionRouter, session.stake * 2n); - - const apprv1 = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - await setTime(Number(await getCurrentBlockTime()) + 1); - const apprv2 = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - - await ethers.provider.send('evm_setAutomine', [false]); - - await sessionRouter.connect(SECOND).openSession(session.stake, apprv1.msg, apprv1.signature); - await sessionRouter.connect(SECOND).openSession(session.stake, apprv2.msg, apprv2.signature); - - await ethers.provider.send('evm_setAutomine', [true]); - await ethers.provider.send('evm_mine', []); - - const sessionId1 = await sessionRouter.getSessionId(SECOND, PROVIDER, session.stake, 0); - const sessionId2 = await sessionRouter.getSessionId(SECOND, PROVIDER, session.stake, 1); - - expect(sessionId1).not.to.equal(sessionId2); - - const session1 = await sessionRouter.sessions(sessionId1); - const session2 = await sessionRouter.sessions(sessionId2); - - expect(session1.stake).to.equal(session.stake); - expect(session2.stake).to.equal(session.stake); - }); - - it('should partially use remaining staked tokens for the opening session', async () => { - const sessionId = await openSession(session); - await setTime(Number((await getCurrentBlockTime()) + session.durationSeconds / 2n)); - - // close session - const report = await getReport(PROVIDER, sessionId, 10, 10); - await sessionRouter.connect(SECOND).closeSession(report.msg, report.signature); - - await setTime(Number(startOfTheDay(await getCurrentBlockTime()) + DAY)); - - const [avail] = await sessionRouter.withdrawableUserStake(SECOND, 255); - expect(avail > 0).to.be.true; - - // reset allowance - await MOR.connect(SECOND).approve(sessionRouter, 0n); - - const stake = avail / 2n; - - const approval = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - await sessionRouter.connect(SECOND).openSession(stake, approval.msg, approval.signature); - - const [avail2] = await sessionRouter.withdrawableUserStake(SECOND, 255); - expect(avail2).to.be.equal(stake); - }); - - it('should use all remaining staked tokens for the opening session', async () => { - const sessionId = await openSession(session); - await setTime(Number((await getCurrentBlockTime()) + session.durationSeconds / 2n)); - - // close session - const report = await getReport(PROVIDER, sessionId, 10, 10); - await sessionRouter.connect(SECOND).closeSession(report.msg, report.signature); - - await setTime(Number(startOfTheDay(await getCurrentBlockTime()) + DAY)); - - const [avail] = await sessionRouter.withdrawableUserStake(SECOND, 255); - expect(avail > 0).to.be.true; - - // reset allowance - await MOR.connect(SECOND).approve(sessionRouter, 0n); - - const approval = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - await sessionRouter.connect(SECOND).openSession(avail, approval.msg, approval.signature); - - const [avail2] = await sessionRouter.withdrawableUserStake(SECOND, 255); - expect(avail2).to.be.equal(0n); - }); - - it('should use remaining staked tokens and allowance for opening session', async () => { - const sessionId = await openSession(session); - await setTime(Number((await getCurrentBlockTime()) + session.durationSeconds / 2n)); - - // close session - const report = await getReport(PROVIDER, sessionId, 10, 10); - await sessionRouter.connect(SECOND).closeSession(report.msg, report.signature); - - await setTime(Number(startOfTheDay(await getCurrentBlockTime()) + DAY)); - - const [avail] = await sessionRouter.withdrawableUserStake(SECOND, 255); - expect(avail > 0).to.be.true; - - const allowancePart = 1000n; - const balanceBefore = await MOR.balanceOf(SECOND); - - // reset allowance - await MOR.connect(SECOND).approve(sessionRouter, allowancePart); - - const approval = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - await sessionRouter.connect(SECOND).openSession(avail + allowancePart, approval.msg, approval.signature); - - // check all onHold used - const [avail2] = await sessionRouter.withdrawableUserStake(SECOND, 255); - expect(avail2).to.be.equal(0n); - - // check allowance used - const balanceAfter = await MOR.balanceOf(SECOND); - expect(balanceBefore - balanceAfter).to.be.equal(allowancePart); - }); - }); - - describe('negative cases', () => { - it('should error when approval generated for a different user', async () => { - const { msg, signature } = await getProviderApproval(PROVIDER, await THIRD.getAddress(), session.bidId); - - await expect( - sessionRouter.connect(SECOND).openSession(session.stake, msg, signature), - ).to.be.revertedWithCustomError(sessionRouter, 'ApprovedForAnotherUser'); - }); - - it('should error when approval expired', async () => { - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - const ttl = await sessionRouter.SIGNATURE_TTL(); - await setTime(Number((await getCurrentBlockTime()) + ttl) + 1); - - await expect( - sessionRouter.connect(SECOND).openSession(session.stake, msg, signature), - ).to.be.revertedWithCustomError(sessionRouter, 'SignatureExpired'); - }); - - it('should error when bid not exist', async () => { - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), randomBytes32()); - await expect( - sessionRouter.connect(SECOND).openSession(session.stake, msg, signature), - ).to.be.revertedWithCustomError(sessionRouter, 'BidNotFound'); - }); - - it('should error when bid is deleted', async () => { - await marketplace.connect(PROVIDER).deleteModelBid(session.bidId); - - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - await expect( - sessionRouter.connect(SECOND).openSession(session.stake, msg, signature), - ).to.be.revertedWithCustomError(sessionRouter, 'BidNotFound'); - }); - - it('should error when signature has invalid length', async () => { - const { msg } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - - await expect(sessionRouter.connect(SECOND).openSession(session.stake, msg, '0x00')).to.be.revertedWith( - 'ECDSA: invalid signature length', - ); - }); - - it('should error when signature is invalid', async () => { - const { msg } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - const sig = randomBytes(65); - - await expect(sessionRouter.connect(SECOND).openSession(session.stake, msg, sig)).to.be.reverted; - }); - - it('should error when opening two bids with same signature', async () => { - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - await sessionRouter.connect(SECOND).openSession(session.stake, msg, signature); - - await approveUserFunds(session.stake); - - await expect( - sessionRouter.connect(SECOND).openSession(session.stake, msg, signature), - ).to.be.revertedWithCustomError(sessionRouter, 'DuplicateApproval'); - }); - - it('should not error when opening two bids same time', async () => { - const appr1 = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - await sessionRouter.connect(SECOND).openSession(session.stake, appr1.msg, appr1.signature); - - await approveUserFunds(session.stake); - const appr2 = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - await sessionRouter.connect(SECOND).openSession(session.stake, appr2.msg, appr2.signature); - }); - - it('should error with insufficient allowance', async () => { - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - await expect(sessionRouter.connect(SECOND).openSession(session.stake * 2n, msg, signature)).to.be.revertedWith( - 'ERC20: insufficient allowance', - ); - }); - - it('should error with insufficient allowance', async () => { - const stake = (await MOR.balanceOf(SECOND)) + 1n; - await MOR.connect(SECOND).approve(sessionRouter, stake); - - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - await expect(sessionRouter.connect(SECOND).openSession(stake, msg, signature)).to.be.revertedWith( - 'ERC20: transfer amount exceeds balance', - ); - }); - - it('should error if session time too short', async () => { - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - await expect(sessionRouter.connect(SECOND).openSession(0, msg, signature)).to.be.revertedWithCustomError( - sessionRouter, - 'SessionTooShort', - ); - }); - - it('should error if chainId is invalid', async () => { - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId, 1n); - await expect(sessionRouter.connect(SECOND).openSession(0, msg, signature)).to.be.revertedWithCustomError( - sessionRouter, - 'WrongChainId', - ); - }); - }); - - describe('verify session end time', () => { - it("session that doesn't span across midnight (1h)", async () => { - const durationSeconds = HOUR; - const stake = await getStake(durationSeconds, session.pricePerSecond); - - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - const sessionId = await sessionRouter.connect(SECOND).openSession.staticCall(stake, msg, signature); - await sessionRouter.connect(SECOND).openSession(stake, msg, signature); - - const sessionData = await sessionRouter.sessions(sessionId); - - expect(sessionData.endsAt).to.equal((await getCurrentBlockTime()) + durationSeconds); - }); - - it('session that spans across midnight (6h) should last 6h', async () => { - const tomorrow9pm = startOfTheDay(await getCurrentBlockTime()) + DAY + 21n * HOUR; - await setTime(Number(tomorrow9pm)); - - // the stake is enough to cover the first day (3h till midnight) and the next day (< 6h) - const durationSeconds = 6n * HOUR; - const stake = await getStake(durationSeconds, session.pricePerSecond); - await approveUserFunds(stake); - - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - const sessionId = await sessionRouter.connect(SECOND).openSession.staticCall(stake, msg, signature); - await sessionRouter.connect(SECOND).openSession(stake, msg, signature); - - const expEndsAt = (await getCurrentBlockTime()) + durationSeconds; - const sessionData = await sessionRouter.sessions(sessionId); - - expect(sessionData.endsAt).closeTo(expEndsAt, 10); - }); - - it('session that lasts multiple days', async () => { - const midnight = startOfTheDay(await getCurrentBlockTime()) + DAY; - await setTime(Number(midnight)); - - // the stake is enough to cover the whole day + extra 1h - const durationSeconds = 25n * HOUR; - const stake = await sessionRouter.stipendToStake( - durationSeconds * session.pricePerSecond, - await getCurrentBlockTime(), - ); - - await approveUserFunds(stake); - - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - const sessionId = await sessionRouter.connect(SECOND).openSession.staticCall(stake, msg, signature); - await sessionRouter.connect(SECOND).openSession(stake, msg, signature); - - const sessionData = await sessionRouter.sessions(sessionId); - const durSeconds = Number(sessionData.endsAt - sessionData.openedAt); - - expect(durSeconds).to.equal(DAY); - }); - }); - }); - - async function approveUserFunds(amount: bigint) { - await MOR.transfer(SECOND, amount); - await MOR.connect(SECOND).approve(sessionRouter, amount); - } - - async function getStake(durationSeconds: bigint, pricePerSecond: bigint): Promise { - const totalCost = pricePerSecond * durationSeconds; - const totalSupply = await sessionRouter.totalMORSupply(await getCurrentBlockTime()); - const todaysBudget = await sessionRouter.getTodaysBudget(await getCurrentBlockTime()); - return (totalCost * totalSupply) / todaysBudget; - } -}); diff --git a/smart-contracts/test/diamond/SessionRouter/readFunctions.test.ts b/smart-contracts/test/diamond/SessionRouter/readFunctions.test.ts deleted file mode 100644 index b201d62c..00000000 --- a/smart-contracts/test/diamond/SessionRouter/readFunctions.test.ts +++ /dev/null @@ -1,341 +0,0 @@ -import { - IBidStorage, - IMarketplace__factory, - IModelRegistry__factory, - IModelStorage, - IProviderRegistry__factory, - IProviderStorage, - ISessionRouter__factory, - LumerinDiamond, - Marketplace, - ModelRegistry, - MorpheusToken, - ProviderRegistry, - SessionRouter, -} from '@ethers-v6'; -import { SignerWithAddress } from '@nomicfoundation/hardhat-ethers/signers'; -import { expect } from 'chai'; -import { Addressable, Fragment } from 'ethers'; -import { ethers } from 'hardhat'; - -import { getHex, randomBytes32, wei } from '@/scripts/utils/utils'; -import { FacetAction } from '@/test/helpers/enums'; -import { getDefaultPools } from '@/test/helpers/pool-helper'; -import { Reverter } from '@/test/helpers/reverter'; -import { getCurrentBlockTime, setTime } from '@/utils/block-helper'; -import { getProviderApproval, getReport } from '@/utils/provider-helper'; -import { HOUR, YEAR } from '@/utils/time'; - -describe('Session router', () => { - const reverter = new Reverter(); - - let OWNER: SignerWithAddress; - let SECOND: SignerWithAddress; - let THIRD: SignerWithAddress; - let PROVIDER: SignerWithAddress; - - let diamond: LumerinDiamond; - let marketplace: Marketplace; - let modelRegistry: ModelRegistry; - let providerRegistry: ProviderRegistry; - let sessionRouter: SessionRouter; - - let MOR: MorpheusToken; - - async function deployProvider(): Promise< - IProviderStorage.ProviderStruct & { - address: Addressable; - } - > { - const provider = { - endpoint: 'localhost:3334', - stake: wei(100), - createdAt: 0n, - limitPeriodEnd: 0n, - limitPeriodEarned: 0n, - isDeleted: false, - address: PROVIDER, - }; - - await MOR.transfer(PROVIDER, provider.stake * 100n); - await MOR.connect(PROVIDER).approve(sessionRouter, provider.stake); - - await providerRegistry.connect(PROVIDER).providerRegister(provider.address, provider.stake, provider.endpoint); - provider.createdAt = await getCurrentBlockTime(); - provider.limitPeriodEnd = provider.createdAt + YEAR; - - return provider; - } - - async function deployModel(): Promise< - IModelStorage.ModelStruct & { - modelId: string; - } - > { - const model = { - modelId: randomBytes32(), - ipfsCID: getHex(Buffer.from('ipfs://ipfsaddress')), - fee: 100, - stake: 100, - owner: OWNER, - name: 'Llama 2.0', - tags: ['llama', 'animal', 'cute'], - createdAt: 0n, - isDeleted: false, - }; - - await MOR.approve(modelRegistry, model.stake); - - await modelRegistry.modelRegister( - model.modelId, - model.ipfsCID, - model.fee, - model.stake, - model.owner, - model.name, - model.tags, - ); - model.createdAt = await getCurrentBlockTime(); - - return model; - } - - async function deployBid(model: any): Promise< - [ - IBidStorage.BidStruct & { - id: string; - modelId: string; - }, - { - durationSeconds: bigint; - totalCost: bigint; - pricePerSecond: bigint; - user: SignerWithAddress; - provider: SignerWithAddress; - modelId: any; - bidId: string; - stake: bigint; - }, - ] - > { - let bid = { - id: '', - modelId: model.modelId, - pricePerSecond: wei(0.0001), - nonce: 0, - createdAt: 0n, - deletedAt: 0, - provider: PROVIDER, - }; - - await MOR.approve(modelRegistry, 10000n * 10n ** 18n); - - bid.id = await marketplace.connect(PROVIDER).postModelBid.staticCall(bid.provider, bid.modelId, bid.pricePerSecond); - await marketplace.connect(PROVIDER).postModelBid(bid.provider, bid.modelId, bid.pricePerSecond); - - bid.createdAt = await getCurrentBlockTime(); - - // generating data for sample session - const durationSeconds = HOUR; - const totalCost = bid.pricePerSecond * durationSeconds; - const totalSupply = await sessionRouter.totalMORSupply(await getCurrentBlockTime()); - const todaysBudget = await sessionRouter.getTodaysBudget(await getCurrentBlockTime()); - - const expectedSession = { - durationSeconds, - totalCost, - pricePerSecond: bid.pricePerSecond, - user: SECOND, - provider: bid.provider, - modelId: bid.modelId, - bidId: bid.id, - stake: (totalCost * totalSupply) / todaysBudget, - }; - - await MOR.transfer(SECOND, expectedSession.stake); - await MOR.connect(SECOND).approve(modelRegistry, expectedSession.stake); - - return [bid, expectedSession]; - } - - before('setup', async () => { - [OWNER, SECOND, THIRD, PROVIDER] = await ethers.getSigners(); - - const LinearDistributionIntervalDecrease = await ethers.getContractFactory('LinearDistributionIntervalDecrease'); - const linearDistributionIntervalDecrease = await LinearDistributionIntervalDecrease.deploy(); - - const [LumerinDiamond, Marketplace, ModelRegistry, ProviderRegistry, SessionRouter, MorpheusToken] = - await Promise.all([ - ethers.getContractFactory('LumerinDiamond'), - ethers.getContractFactory('Marketplace'), - ethers.getContractFactory('ModelRegistry'), - ethers.getContractFactory('ProviderRegistry'), - ethers.getContractFactory('SessionRouter', { - libraries: { - LinearDistributionIntervalDecrease: linearDistributionIntervalDecrease, - }, - }), - ethers.getContractFactory('MorpheusToken'), - ]); - - [diamond, marketplace, modelRegistry, providerRegistry, sessionRouter, MOR] = await Promise.all([ - LumerinDiamond.deploy(), - Marketplace.deploy(), - ModelRegistry.deploy(), - ProviderRegistry.deploy(), - SessionRouter.deploy(), - MorpheusToken.deploy(), - ]); - - await diamond.__LumerinDiamond_init(); - - await diamond['diamondCut((address,uint8,bytes4[])[])']([ - { - facetAddress: marketplace, - action: FacetAction.Add, - functionSelectors: IMarketplace__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: providerRegistry, - action: FacetAction.Add, - functionSelectors: IProviderRegistry__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: sessionRouter, - action: FacetAction.Add, - functionSelectors: ISessionRouter__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: modelRegistry, - action: FacetAction.Add, - functionSelectors: IModelRegistry__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - ]); - - marketplace = marketplace.attach(diamond.target) as Marketplace; - providerRegistry = providerRegistry.attach(diamond.target) as ProviderRegistry; - modelRegistry = modelRegistry.attach(diamond.target) as ModelRegistry; - sessionRouter = sessionRouter.attach(diamond.target) as SessionRouter; - - await marketplace.__Marketplace_init(MOR); - await modelRegistry.__ModelRegistry_init(); - await providerRegistry.__ProviderRegistry_init(); - - await sessionRouter.__SessionRouter_init(OWNER, getDefaultPools()); - - await reverter.snapshot(); - }); - - afterEach(reverter.revert); - - describe('Actions', () => { - let provider: IProviderStorage.ProviderStruct; - let model: IModelStorage.ModelStruct & { - modelId: string; - }; - let bid: IBidStorage.BidStruct & { - id: string; - modelId: string; - }; - let session: { - durationSeconds: bigint; - totalCost: bigint; - pricePerSecond: bigint; - user: SignerWithAddress; - provider: SignerWithAddress; - modelId: any; - bidId: string; - stake: bigint; - }; - - beforeEach(async () => { - provider = await deployProvider(); - model = await deployModel(); - [bid, session] = await deployBid(model); - }); - - describe('session read functions', () => { - const exp = { - initialReward: 3456000000000000000000n, - rewardDecrease: 592558728240000000n, - payoutStart: 1707393600n, - decreaseInterval: 86400n, - blockTimeEpochSeconds: BigInt(new Date('2024-05-02T09:19:57Z').getTime()) / 1000n, - balance: 286534931460577320000000n, - }; - - it('should get compute balance equal to one on L1', async () => { - await sessionRouter.setPoolConfig(3n, { - initialReward: exp.initialReward, - rewardDecrease: exp.rewardDecrease, - payoutStart: exp.payoutStart, - decreaseInterval: exp.decreaseInterval, - }); - - const balance = await sessionRouter.getComputeBalance(exp.blockTimeEpochSeconds); - - expect(balance).to.equal(exp.balance); - }); - - it('should revert if caller is not an owner', async () => { - await expect( - sessionRouter.connect(SECOND).setPoolConfig(3, { - initialReward: exp.initialReward, - rewardDecrease: exp.rewardDecrease, - payoutStart: exp.payoutStart, - decreaseInterval: exp.decreaseInterval, - }), - ).to.revertedWithCustomError(diamond, 'OwnableUnauthorizedAccount'); - }); - - it('should revert if pool is not exists', async () => { - await expect( - sessionRouter.setPoolConfig(9999, { - initialReward: exp.initialReward, - rewardDecrease: exp.rewardDecrease, - payoutStart: exp.payoutStart, - decreaseInterval: exp.decreaseInterval, - }), - ).to.revertedWithCustomError(sessionRouter, 'PoolIndexOutOfBounds'); - }); - }); - - describe('getProviderClaimableBalance', () => { - it('should be correct for contract that closed early due to dispute [H-6]', async () => { - // open session - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - const sessionId = await sessionRouter.connect(SECOND).openSession.staticCall(session.stake, msg, signature); - await sessionRouter.connect(SECOND).openSession(session.stake, msg, signature); - - await setTime(Number((await getCurrentBlockTime()) + session.durationSeconds / 2n - 1n)); - - // close session with dispute / user report - const report = await getReport(SECOND, sessionId, 10, 10); - await sessionRouter.connect(SECOND).closeSession(report.msg, report.signature); - - // verify session is closed with dispute - const sessionData = await sessionRouter.sessions(sessionId); - expect(sessionData.closeoutType).to.equal(1n); - - const sessionCost = session.pricePerSecond * (sessionData.closedAt - sessionData.openedAt); - - // immediately after claimable balance should be 0 - const claimable = await sessionRouter.getProviderClaimableBalance(sessionId); - expect(claimable).to.equal(0n); - - // after 24 hours claimable balance should be correct - await setTime(Number((await getCurrentBlockTime()) + 24n * HOUR)); - const claimable2 = await sessionRouter.getProviderClaimableBalance(sessionId); - expect(claimable2).to.equal(sessionCost); - }); - }); - }); -}); diff --git a/smart-contracts/test/diamond/SessionRouter/stats.test.ts b/smart-contracts/test/diamond/SessionRouter/stats.test.ts deleted file mode 100644 index 6e8fe9af..00000000 --- a/smart-contracts/test/diamond/SessionRouter/stats.test.ts +++ /dev/null @@ -1,320 +0,0 @@ -import { - IBidStorage, - IMarketplace__factory, - IModelRegistry__factory, - IModelStorage, - IProviderRegistry__factory, - IProviderStorage, - ISessionRouter__factory, - LumerinDiamond, - Marketplace, - ModelRegistry, - MorpheusToken, - ProviderRegistry, - SessionRouter, -} from '@ethers-v6'; -import { SignerWithAddress } from '@nomicfoundation/hardhat-ethers/signers'; -import { expect } from 'chai'; -import { Addressable, Fragment, resolveAddress } from 'ethers'; -import { ethers } from 'hardhat'; - -import { getHex, randomBytes32, wei } from '@/scripts/utils/utils'; -import { FacetAction } from '@/test/helpers/enums'; -import { getDefaultPools } from '@/test/helpers/pool-helper'; -import { Reverter } from '@/test/helpers/reverter'; -import { getCurrentBlockTime, setTime } from '@/utils/block-helper'; -import { getProviderApproval, getReport } from '@/utils/provider-helper'; -import { HOUR, YEAR } from '@/utils/time'; - -describe('Session router - stats tests', () => { - const reverter = new Reverter(); - - let OWNER: SignerWithAddress; - let SECOND: SignerWithAddress; - let THIRD: SignerWithAddress; - let PROVIDER: SignerWithAddress; - - let diamond: LumerinDiamond; - let marketplace: Marketplace; - let modelRegistry: ModelRegistry; - let providerRegistry: ProviderRegistry; - let sessionRouter: SessionRouter; - - let MOR: MorpheusToken; - - async function deployProvider(): Promise< - IProviderStorage.ProviderStruct & { - address: Addressable; - } - > { - const provider = { - endpoint: 'localhost:3334', - stake: wei(100), - createdAt: 0n, - limitPeriodEnd: 0n, - limitPeriodEarned: 0n, - isDeleted: false, - address: PROVIDER, - }; - - await MOR.transfer(PROVIDER, provider.stake * 100n); - await MOR.connect(PROVIDER).approve(sessionRouter, provider.stake); - - await providerRegistry.connect(PROVIDER).providerRegister(provider.address, provider.stake, provider.endpoint); - provider.createdAt = await getCurrentBlockTime(); - provider.limitPeriodEnd = provider.createdAt + YEAR; - - return provider; - } - - async function deployModel(): Promise< - IModelStorage.ModelStruct & { - modelId: string; - } - > { - const model = { - modelId: randomBytes32(), - ipfsCID: getHex(Buffer.from('ipfs://ipfsaddress')), - fee: 100, - stake: 100, - owner: OWNER, - name: 'Llama 2.0', - tags: ['llama', 'animal', 'cute'], - createdAt: 0n, - isDeleted: false, - }; - - await MOR.approve(modelRegistry, model.stake); - - await modelRegistry.modelRegister( - model.modelId, - model.ipfsCID, - model.fee, - model.stake, - model.owner, - model.name, - model.tags, - ); - model.createdAt = await getCurrentBlockTime(); - - return model; - } - - async function deployBid(model: any): Promise< - [ - IBidStorage.BidStruct & { - id: string; - modelId: string; - }, - { - durationSeconds: bigint; - totalCost: bigint; - pricePerSecond: bigint; - user: SignerWithAddress; - provider: SignerWithAddress; - modelId: any; - bidId: string; - stake: bigint; - }, - ] - > { - let bid = { - id: '', - modelId: model.modelId, - pricePerSecond: wei(0.0001), - nonce: 0, - createdAt: 0n, - deletedAt: 0, - provider: PROVIDER, - }; - - await MOR.approve(modelRegistry, 10000n * 10n ** 18n); - - bid.id = await marketplace.connect(PROVIDER).postModelBid.staticCall(bid.provider, bid.modelId, bid.pricePerSecond); - await marketplace.connect(PROVIDER).postModelBid(bid.provider, bid.modelId, bid.pricePerSecond); - - bid.createdAt = await getCurrentBlockTime(); - - // generating data for sample session - const durationSeconds = HOUR; - const totalCost = bid.pricePerSecond * durationSeconds; - const totalSupply = await sessionRouter.totalMORSupply(await getCurrentBlockTime()); - const todaysBudget = await sessionRouter.getTodaysBudget(await getCurrentBlockTime()); - - const expectedSession = { - durationSeconds, - totalCost, - pricePerSecond: bid.pricePerSecond, - user: SECOND, - provider: bid.provider, - modelId: bid.modelId, - bidId: bid.id, - stake: (totalCost * totalSupply) / todaysBudget, - }; - - await MOR.transfer(SECOND, expectedSession.stake); - await MOR.connect(SECOND).approve(modelRegistry, expectedSession.stake); - - return [bid, expectedSession]; - } - before('setup', async () => { - [OWNER, SECOND, THIRD, PROVIDER] = await ethers.getSigners(); - - const LinearDistributionIntervalDecrease = await ethers.getContractFactory('LinearDistributionIntervalDecrease'); - const linearDistributionIntervalDecrease = await LinearDistributionIntervalDecrease.deploy(); - - const [LumerinDiamond, Marketplace, ModelRegistry, ProviderRegistry, SessionRouter, MorpheusToken] = - await Promise.all([ - ethers.getContractFactory('LumerinDiamond'), - ethers.getContractFactory('Marketplace'), - ethers.getContractFactory('ModelRegistry'), - ethers.getContractFactory('ProviderRegistry'), - ethers.getContractFactory('SessionRouter', { - libraries: { - LinearDistributionIntervalDecrease: linearDistributionIntervalDecrease, - }, - }), - ethers.getContractFactory('MorpheusToken'), - ]); - - [diamond, marketplace, modelRegistry, providerRegistry, sessionRouter, MOR] = await Promise.all([ - LumerinDiamond.deploy(), - Marketplace.deploy(), - ModelRegistry.deploy(), - ProviderRegistry.deploy(), - SessionRouter.deploy(), - MorpheusToken.deploy(), - ]); - - await diamond.__LumerinDiamond_init(); - - await diamond['diamondCut((address,uint8,bytes4[])[])']([ - { - facetAddress: marketplace, - action: FacetAction.Add, - functionSelectors: IMarketplace__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: providerRegistry, - action: FacetAction.Add, - functionSelectors: IProviderRegistry__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: sessionRouter, - action: FacetAction.Add, - functionSelectors: ISessionRouter__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: modelRegistry, - action: FacetAction.Add, - functionSelectors: IModelRegistry__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - ]); - - marketplace = marketplace.attach(diamond.target) as Marketplace; - providerRegistry = providerRegistry.attach(diamond.target) as ProviderRegistry; - modelRegistry = modelRegistry.attach(diamond.target) as ModelRegistry; - sessionRouter = sessionRouter.attach(diamond.target) as SessionRouter; - - await marketplace.__Marketplace_init(MOR); - await modelRegistry.__ModelRegistry_init(); - await providerRegistry.__ProviderRegistry_init(); - - await sessionRouter.__SessionRouter_init(OWNER, getDefaultPools()); - - await reverter.snapshot(); - }); - - afterEach(reverter.revert); - - describe('Actions', () => { - let provider: IProviderStorage.ProviderStruct; - let model: IModelStorage.ModelStruct & { - modelId: string; - }; - let bid: IBidStorage.BidStruct & { - id: string; - modelId: string; - }; - let session: { - durationSeconds: bigint; - totalCost: bigint; - pricePerSecond: bigint; - user: SignerWithAddress; - provider: SignerWithAddress; - modelId: any; - bidId: string; - stake: bigint; - }; - - beforeEach(async () => { - provider = await deployProvider(); - model = await deployModel(); - [bid, session] = await deployBid(model); - }); - - it('should update provider-model stats', async () => { - await openCloseSession(session.bidId, HOUR, session.pricePerSecond, 100, 1000, true); - - await openCloseSession(session.bidId, HOUR, session.pricePerSecond, 150, 2000, true); - - const [bidIds, bids, stats] = await sessionRouter.getActiveBidsRatingByModel(session.modelId, 0n, 100); - - expect(bidIds).to.deep.equal([session.bidId]); - expect(bids[0]).to.deep.equal([ - await resolveAddress(bid.provider), - bid.modelId, - bid.pricePerSecond, - bid.nonce, - bid.createdAt, - bid.deletedAt, - ]); - expect(stats[0].successCount).to.equal(2); - expect(stats[0].totalCount).to.equal(2); - expect(Number(stats[0].tpsScaled1000.mean)).to.greaterThan(0); - expect(Number(stats[0].ttftMs.mean)).to.greaterThan(0); - }); - }); - - async function openCloseSession( - bidId: string, - durationSeconds: bigint, - pricePerSecond: bigint, - tps: number, - ttft: number, - success = true, - ) { - // open session - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), bidId); - const stake = await getStake(durationSeconds, pricePerSecond); - - await MOR.transfer(SECOND, stake); - await MOR.connect(SECOND).approve(sessionRouter, stake); - const sessionId = await sessionRouter.connect(SECOND).openSession.staticCall(stake, msg, signature); - await sessionRouter.connect(SECOND).openSession(stake, msg, signature); - - // wait till end of the session - await setTime(Number((await getCurrentBlockTime()) + durationSeconds)); - - // close session - const signer = success ? PROVIDER : SECOND; - const report = await getReport(signer, sessionId, tps, ttft); - - await sessionRouter.connect(SECOND).closeSession(report.msg, report.signature); - } - - async function getStake(durationSeconds: bigint, pricePerSecond: bigint): Promise { - const totalCost = pricePerSecond * durationSeconds; - const totalSupply = await sessionRouter.totalMORSupply(await getCurrentBlockTime()); - const todaysBudget = await sessionRouter.getTodaysBudget(await getCurrentBlockTime()); - return (totalCost * totalSupply) / todaysBudget; - } -}); diff --git a/smart-contracts/test/diamond/SessionRouter/userOnHold.test.ts b/smart-contracts/test/diamond/SessionRouter/userOnHold.test.ts deleted file mode 100644 index c00b2a36..00000000 --- a/smart-contracts/test/diamond/SessionRouter/userOnHold.test.ts +++ /dev/null @@ -1,337 +0,0 @@ -import { - IBidStorage, - IMarketplace__factory, - IModelRegistry__factory, - IModelStorage, - IProviderRegistry__factory, - IProviderStorage, - ISessionRouter__factory, - LumerinDiamond, - Marketplace, - ModelRegistry, - MorpheusToken, - ProviderRegistry, - SessionRouter, -} from '@ethers-v6'; -import { SignerWithAddress } from '@nomicfoundation/hardhat-ethers/signers'; -import { expect } from 'chai'; -import { Addressable, Fragment } from 'ethers'; -import { ethers } from 'hardhat'; - -import { MAX_UINT8 } from '@/scripts/utils/constants'; -import { getHex, randomBytes32, wei } from '@/scripts/utils/utils'; -import { FacetAction } from '@/test/helpers/enums'; -import { getDefaultPools } from '@/test/helpers/pool-helper'; -import { Reverter } from '@/test/helpers/reverter'; -import { getCurrentBlockTime, setTime } from '@/utils/block-helper'; -import { getProviderApproval, getReport } from '@/utils/provider-helper'; -import { DAY, HOUR, YEAR } from '@/utils/time'; - -describe('User on hold tests', () => { - const reverter = new Reverter(); - - let OWNER: SignerWithAddress; - let SECOND: SignerWithAddress; - let THIRD: SignerWithAddress; - let PROVIDER: SignerWithAddress; - - let diamond: LumerinDiamond; - let marketplace: Marketplace; - let modelRegistry: ModelRegistry; - let providerRegistry: ProviderRegistry; - let sessionRouter: SessionRouter; - - let MOR: MorpheusToken; - - async function deployProvider(): Promise< - IProviderStorage.ProviderStruct & { - address: Addressable; - } - > { - const provider = { - endpoint: 'localhost:3334', - stake: wei(100), - createdAt: 0n, - limitPeriodEnd: 0n, - limitPeriodEarned: 0n, - isDeleted: false, - address: PROVIDER, - }; - - await MOR.transfer(PROVIDER, provider.stake * 100n); - await MOR.connect(PROVIDER).approve(sessionRouter, provider.stake); - - await providerRegistry.connect(PROVIDER).providerRegister(provider.address, provider.stake, provider.endpoint); - provider.createdAt = await getCurrentBlockTime(); - provider.limitPeriodEnd = provider.createdAt + YEAR; - - return provider; - } - - async function deployModel(): Promise< - IModelStorage.ModelStruct & { - modelId: string; - } - > { - const model = { - modelId: randomBytes32(), - ipfsCID: getHex(Buffer.from('ipfs://ipfsaddress')), - fee: 100, - stake: 100, - owner: OWNER, - name: 'Llama 2.0', - tags: ['llama', 'animal', 'cute'], - createdAt: 0n, - isDeleted: false, - }; - - await MOR.approve(modelRegistry, model.stake); - - await modelRegistry.modelRegister( - model.modelId, - model.ipfsCID, - model.fee, - model.stake, - model.owner, - model.name, - model.tags, - ); - model.createdAt = await getCurrentBlockTime(); - - return model; - } - - async function deployBid(model: any): Promise< - [ - IBidStorage.BidStruct & { - id: string; - modelId: string; - }, - { - durationSeconds: bigint; - totalCost: bigint; - pricePerSecond: bigint; - user: SignerWithAddress; - provider: SignerWithAddress; - modelId: any; - bidId: string; - stake: bigint; - }, - ] - > { - let bid = { - id: '', - modelId: model.modelId, - pricePerSecond: wei(0.0001), - nonce: 0, - createdAt: 0n, - deletedAt: 0, - provider: PROVIDER, - }; - - await MOR.approve(modelRegistry, 10000n * 10n ** 18n); - - bid.id = await marketplace.connect(PROVIDER).postModelBid.staticCall(bid.provider, bid.modelId, bid.pricePerSecond); - await marketplace.connect(PROVIDER).postModelBid(bid.provider, bid.modelId, bid.pricePerSecond); - - bid.createdAt = await getCurrentBlockTime(); - - // generating data for sample session - const durationSeconds = HOUR; - const totalCost = bid.pricePerSecond * durationSeconds; - const totalSupply = await sessionRouter.totalMORSupply(await getCurrentBlockTime()); - const todaysBudget = await sessionRouter.getTodaysBudget(await getCurrentBlockTime()); - - const expectedSession = { - durationSeconds, - totalCost, - pricePerSecond: bid.pricePerSecond, - user: SECOND, - provider: bid.provider, - modelId: bid.modelId, - bidId: bid.id, - stake: (totalCost * totalSupply) / todaysBudget, - }; - - await MOR.transfer(SECOND, expectedSession.stake); - await MOR.connect(SECOND).approve(modelRegistry, expectedSession.stake); - - return [bid, expectedSession]; - } - - async function openSession(session: any) { - // open session - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - const sessionId = await sessionRouter.connect(SECOND).openSession.staticCall(session.stake, msg, signature); - await sessionRouter.connect(SECOND).openSession(session.stake, msg, signature); - - return sessionId; - } - - async function openEarlyCloseSession(session: any, sessionId: string) { - await setTime(Number(await getCurrentBlockTime()) + Number(session.durationSeconds) / 2); - - // close session - const report = await getReport(PROVIDER, sessionId, 10, 10); - await sessionRouter.connect(SECOND).closeSession(report.msg, report.signature); - - return session.stake / 2n; - } - before('setup', async () => { - [OWNER, SECOND, THIRD, PROVIDER] = await ethers.getSigners(); - - const LinearDistributionIntervalDecrease = await ethers.getContractFactory('LinearDistributionIntervalDecrease'); - const linearDistributionIntervalDecrease = await LinearDistributionIntervalDecrease.deploy(); - - const [LumerinDiamond, Marketplace, ModelRegistry, ProviderRegistry, SessionRouter, MorpheusToken] = - await Promise.all([ - ethers.getContractFactory('LumerinDiamond'), - ethers.getContractFactory('Marketplace'), - ethers.getContractFactory('ModelRegistry'), - ethers.getContractFactory('ProviderRegistry'), - ethers.getContractFactory('SessionRouter', { - libraries: { - LinearDistributionIntervalDecrease: linearDistributionIntervalDecrease, - }, - }), - ethers.getContractFactory('MorpheusToken'), - ]); - - [diamond, marketplace, modelRegistry, providerRegistry, sessionRouter, MOR] = await Promise.all([ - LumerinDiamond.deploy(), - Marketplace.deploy(), - ModelRegistry.deploy(), - ProviderRegistry.deploy(), - SessionRouter.deploy(), - MorpheusToken.deploy(), - ]); - - await diamond.__LumerinDiamond_init(); - - await diamond['diamondCut((address,uint8,bytes4[])[])']([ - { - facetAddress: marketplace, - action: FacetAction.Add, - functionSelectors: IMarketplace__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: providerRegistry, - action: FacetAction.Add, - functionSelectors: IProviderRegistry__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: sessionRouter, - action: FacetAction.Add, - functionSelectors: ISessionRouter__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: modelRegistry, - action: FacetAction.Add, - functionSelectors: IModelRegistry__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - ]); - - marketplace = marketplace.attach(diamond.target) as Marketplace; - providerRegistry = providerRegistry.attach(diamond.target) as ProviderRegistry; - modelRegistry = modelRegistry.attach(diamond.target) as ModelRegistry; - sessionRouter = sessionRouter.attach(diamond.target) as SessionRouter; - - await marketplace.__Marketplace_init(MOR); - await modelRegistry.__ModelRegistry_init(); - await providerRegistry.__ProviderRegistry_init(); - - await sessionRouter.__SessionRouter_init(OWNER, getDefaultPools()); - - await reverter.snapshot(); - }); - - afterEach(reverter.revert); - - describe('Actions', () => { - let provider: IProviderStorage.ProviderStruct; - let model: IModelStorage.ModelStruct & { - modelId: string; - }; - let bid: IBidStorage.BidStruct & { - id: string; - modelId: string; - }; - let session: { - durationSeconds: bigint; - totalCost: bigint; - pricePerSecond: bigint; - user: SignerWithAddress; - provider: SignerWithAddress; - modelId: any; - bidId: string; - stake: bigint; - }; - let sessionId: string; - let expectedOnHold: bigint; - - beforeEach(async () => { - provider = await deployProvider(); - model = await deployModel(); - [bid, session] = await deployBid(model); - sessionId = await openSession(session); - expectedOnHold = await openEarlyCloseSession(session, sessionId); - }); - - it('user stake should be locked right after closeout', async () => { - // right after closeout - const [available, onHold] = await sessionRouter.withdrawableUserStake(SECOND, MAX_UINT8); - expect(available).to.equal(0n); - expect(onHold).to.closeTo(expectedOnHold, BigInt(0.01 * Number(expectedOnHold))); - }); - - it('user stake should be locked before the next day', async () => { - // before next day - await setTime(startOfTomorrow(Number(await getCurrentBlockTime())) - Number(HOUR)); - const [available3, onHold3] = await sessionRouter.withdrawableUserStake(SECOND, Number(MAX_UINT8)); - expect(available3).to.equal(0n); - expect(onHold3).to.closeTo(expectedOnHold, BigInt(0.01 * Number(expectedOnHold))); - }); - - it('user stake should be available on the next day and withdrawable', async () => { - await setTime(startOfTomorrow(Number(await getCurrentBlockTime()))); - const [available2, onHold2] = await sessionRouter.withdrawableUserStake(SECOND, Number(MAX_UINT8)); - expect(available2).to.closeTo(expectedOnHold, BigInt(0.01 * Number(expectedOnHold))); - expect(onHold2).to.equal(0n); - - const balanceBefore = await MOR.balanceOf(SECOND); - await sessionRouter.connect(SECOND).withdrawUserStake(available2, Number(MAX_UINT8)); - const balanceAfter = await MOR.balanceOf(SECOND); - const balanceDelta = balanceAfter - balanceBefore; - expect(balanceDelta).to.closeTo(expectedOnHold, BigInt(0.01 * Number(expectedOnHold))); - }); - - it("user shouldn't be able to withdraw more than there is available stake", async () => { - await setTime(startOfTomorrow(Number(await getCurrentBlockTime()))); - const [available2] = await sessionRouter.withdrawableUserStake(SECOND, Number(MAX_UINT8)); - - await expect( - sessionRouter.connect(SECOND).withdrawUserStake(available2 * 2n, Number(MAX_UINT8)), - ).to.be.revertedWithCustomError(sessionRouter, 'NotEnoughWithdrawableBalance'); - }); - - it('should revert if amount to withdraw is 0', async () => { - await expect(sessionRouter.connect(SECOND).withdrawUserStake(0, Number(MAX_UINT8))).to.be.revertedWithCustomError( - sessionRouter, - 'AmountToWithdrawIsZero', - ); - }); - }); -}); - -function startOfTomorrow(epochSeconds: number): number { - const startOfToday = epochSeconds - (epochSeconds % Number(DAY)); - return startOfToday + Number(DAY); -} diff --git a/smart-contracts/test/diamond/SessionRouter/writeFunctions.test.ts b/smart-contracts/test/diamond/SessionRouter/writeFunctions.test.ts deleted file mode 100644 index cb3cbdc5..00000000 --- a/smart-contracts/test/diamond/SessionRouter/writeFunctions.test.ts +++ /dev/null @@ -1,348 +0,0 @@ -import { - IBidStorage, - IMarketplace__factory, - IModelRegistry__factory, - IModelStorage, - IProviderRegistry__factory, - IProviderStorage, - ISessionRouter__factory, - LumerinDiamond, - Marketplace, - ModelRegistry, - MorpheusToken, - ProviderRegistry, - SessionRouter, -} from '@ethers-v6'; -import { SignerWithAddress } from '@nomicfoundation/hardhat-ethers/signers'; -import { expect } from 'chai'; -import { Addressable, Fragment } from 'ethers'; -import { ethers } from 'hardhat'; - -import { ZERO_ADDR } from '@/scripts/utils/constants'; -import { getHex, randomBytes32, wei } from '@/scripts/utils/utils'; -import { FacetAction } from '@/test/helpers/enums'; -import { getDefaultPools } from '@/test/helpers/pool-helper'; -import { Reverter } from '@/test/helpers/reverter'; -import { getCurrentBlockTime, setTime } from '@/utils/block-helper'; -import { getProviderApproval, getReport } from '@/utils/provider-helper'; -import { HOUR, YEAR } from '@/utils/time'; - -describe('Session router', () => { - const reverter = new Reverter(); - - let OWNER: SignerWithAddress; - let SECOND: SignerWithAddress; - let THIRD: SignerWithAddress; - let PROVIDER: SignerWithAddress; - - let diamond: LumerinDiamond; - let marketplace: Marketplace; - let modelRegistry: ModelRegistry; - let providerRegistry: ProviderRegistry; - let sessionRouter: SessionRouter; - - let MOR: MorpheusToken; - - async function deployProvider(): Promise< - IProviderStorage.ProviderStruct & { - address: Addressable; - } - > { - const provider = { - endpoint: 'localhost:3334', - stake: wei(100), - createdAt: 0n, - limitPeriodEnd: 0n, - limitPeriodEarned: 0n, - isDeleted: false, - address: PROVIDER, - }; - - await MOR.transfer(PROVIDER, provider.stake * 100n); - await MOR.connect(PROVIDER).approve(sessionRouter, provider.stake); - - await providerRegistry.connect(PROVIDER).providerRegister(provider.address, provider.stake, provider.endpoint); - provider.createdAt = await getCurrentBlockTime(); - provider.limitPeriodEnd = provider.createdAt + YEAR; - - return provider; - } - - async function deployModel(): Promise< - IModelStorage.ModelStruct & { - modelId: string; - } - > { - const model = { - modelId: randomBytes32(), - ipfsCID: getHex(Buffer.from('ipfs://ipfsaddress')), - fee: 100, - stake: 100, - owner: OWNER, - name: 'Llama 2.0', - tags: ['llama', 'animal', 'cute'], - createdAt: 0n, - isDeleted: false, - }; - - await MOR.approve(modelRegistry, model.stake); - - await modelRegistry.modelRegister( - model.modelId, - model.ipfsCID, - model.fee, - model.stake, - model.owner, - model.name, - model.tags, - ); - model.createdAt = await getCurrentBlockTime(); - - return model; - } - - async function deployBid(model: any): Promise< - [ - IBidStorage.BidStruct & { - id: string; - modelId: string; - }, - { - durationSeconds: bigint; - totalCost: bigint; - pricePerSecond: bigint; - user: SignerWithAddress; - provider: SignerWithAddress; - modelId: any; - bidId: string; - stake: bigint; - }, - ] - > { - let bid = { - id: '', - modelId: model.modelId, - pricePerSecond: wei(0.0001), - nonce: 0, - createdAt: 0n, - deletedAt: 0, - provider: PROVIDER, - }; - - await MOR.approve(modelRegistry, 10000n * 10n ** 18n); - - bid.id = await marketplace.connect(PROVIDER).postModelBid.staticCall(bid.provider, bid.modelId, bid.pricePerSecond); - await marketplace.connect(PROVIDER).postModelBid(bid.provider, bid.modelId, bid.pricePerSecond); - - bid.createdAt = await getCurrentBlockTime(); - - // generating data for sample session - const durationSeconds = HOUR; - const totalCost = bid.pricePerSecond * durationSeconds; - const totalSupply = await sessionRouter.totalMORSupply(await getCurrentBlockTime()); - const todaysBudget = await sessionRouter.getTodaysBudget(await getCurrentBlockTime()); - - const expectedSession = { - durationSeconds, - totalCost, - pricePerSecond: bid.pricePerSecond, - user: SECOND, - provider: bid.provider, - modelId: bid.modelId, - bidId: bid.id, - stake: (totalCost * totalSupply) / todaysBudget, - }; - - await MOR.transfer(SECOND, expectedSession.stake); - await MOR.connect(SECOND).approve(modelRegistry, expectedSession.stake); - - return [bid, expectedSession]; - } - - async function openSession(session: any) { - // open session - const { msg, signature } = await getProviderApproval(PROVIDER, await SECOND.getAddress(), session.bidId); - const sessionId = await sessionRouter.connect(SECOND).openSession.staticCall(session.stake, msg, signature); - await sessionRouter.connect(SECOND).openSession(session.stake, msg, signature); - - return sessionId; - } - - before('setup', async () => { - [OWNER, SECOND, THIRD, PROVIDER] = await ethers.getSigners(); - - const LinearDistributionIntervalDecrease = await ethers.getContractFactory('LinearDistributionIntervalDecrease'); - const linearDistributionIntervalDecrease = await LinearDistributionIntervalDecrease.deploy(); - - const [LumerinDiamond, Marketplace, ModelRegistry, ProviderRegistry, SessionRouter, MorpheusToken] = - await Promise.all([ - ethers.getContractFactory('LumerinDiamond'), - ethers.getContractFactory('Marketplace'), - ethers.getContractFactory('ModelRegistry'), - ethers.getContractFactory('ProviderRegistry'), - ethers.getContractFactory('SessionRouter', { - libraries: { - LinearDistributionIntervalDecrease: linearDistributionIntervalDecrease, - }, - }), - ethers.getContractFactory('MorpheusToken'), - ]); - - [diamond, marketplace, modelRegistry, providerRegistry, sessionRouter, MOR] = await Promise.all([ - LumerinDiamond.deploy(), - Marketplace.deploy(), - ModelRegistry.deploy(), - ProviderRegistry.deploy(), - SessionRouter.deploy(), - MorpheusToken.deploy(), - ]); - - await diamond.__LumerinDiamond_init(); - - await diamond['diamondCut((address,uint8,bytes4[])[])']([ - { - facetAddress: marketplace, - action: FacetAction.Add, - functionSelectors: IMarketplace__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: providerRegistry, - action: FacetAction.Add, - functionSelectors: IProviderRegistry__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: sessionRouter, - action: FacetAction.Add, - functionSelectors: ISessionRouter__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - { - facetAddress: modelRegistry, - action: FacetAction.Add, - functionSelectors: IModelRegistry__factory.createInterface() - .fragments.filter(Fragment.isFunction) - .map((f) => f.selector), - }, - ]); - - marketplace = marketplace.attach(diamond.target) as Marketplace; - providerRegistry = providerRegistry.attach(diamond.target) as ProviderRegistry; - modelRegistry = modelRegistry.attach(diamond.target) as ModelRegistry; - sessionRouter = sessionRouter.attach(diamond.target) as SessionRouter; - - await marketplace.__Marketplace_init(MOR); - await modelRegistry.__ModelRegistry_init(); - await providerRegistry.__ProviderRegistry_init(); - - await sessionRouter.__SessionRouter_init(OWNER, getDefaultPools()); - - await reverter.snapshot(); - }); - - afterEach(reverter.revert); - - describe('Diamond functionality', () => { - describe('#__SessionRouter_init', () => { - it('should set correct data after creation', async () => { - expect(await sessionRouter.getFundingAccount()).to.eq(await OWNER.getAddress()); - - const pools = await sessionRouter.pools(); - for (let i = 0; i < pools.length; i++) { - const pool = getDefaultPools()[i]; - expect(pools[i].payoutStart).to.eq(pool.payoutStart); - expect(pools[i].decreaseInterval).to.eq(pool.decreaseInterval); - expect(pools[i].initialReward).to.eq(pool.initialReward); - expect(pools[i].rewardDecrease).to.eq(pool.rewardDecrease); - } - }); - it('should revert if try to call init function twice', async () => { - const reason = 'Initializable: contract is already initialized'; - - await expect(sessionRouter.__SessionRouter_init(OWNER, getDefaultPools())).to.be.rejectedWith(reason); - }); - }); - }); - - describe('session write functions', () => { - let provider: IProviderStorage.ProviderStruct; - let model: IModelStorage.ModelStruct & { - modelId: string; - }; - let bid: IBidStorage.BidStruct & { - id: string; - modelId: string; - }; - let session: { - durationSeconds: bigint; - totalCost: bigint; - pricePerSecond: bigint; - user: SignerWithAddress; - provider: SignerWithAddress; - modelId: any; - bidId: string; - stake: bigint; - }; - let sessionId: string; - - beforeEach(async () => { - provider = await deployProvider(); - model = await deployModel(); - [bid, session] = await deployBid(model); - sessionId = await openSession(session); - }); - - it('should block erase if session not closed', async () => { - // check history - const sessionIds = await sessionRouter.getSessionsByUser(SECOND, 0, 10); - expect(sessionIds.length).to.equal(1); - expect((await sessionRouter.sessions(sessionIds[0])).id).to.equal(sessionId); - - // erase history fails - await expect(sessionRouter.connect(SECOND).deleteHistory(sessionId)).to.be.revertedWithCustomError( - sessionRouter, - 'SessionNotClosed', - ); - }); - - it('should block erase if caller is not an owner', async () => { - await expect(sessionRouter.connect(THIRD).deleteHistory(sessionId)).to.be.revertedWithCustomError( - sessionRouter, - 'NotOwnerOrUser', - ); - }); - - it('erase history', async () => { - // wait for half of the session - await setTime(Number((await getCurrentBlockTime()) + session.durationSeconds / 2n)); - - // close session - const report = await getReport(PROVIDER, sessionId, 10, 10); - await sessionRouter.connect(SECOND).closeSession(report.msg, report.signature); - - // check history - const sessionIds = await sessionRouter.getSessionsByUser(SECOND, 0, 10); - expect(sessionIds.length).to.equal(1); - expect((await sessionRouter.sessions(sessionIds[0])).id).to.equal(sessionId); - expect((await sessionRouter.sessions(sessionIds[0])).user).to.equal(await SECOND.getAddress()); - - // erase history - await sessionRouter.connect(SECOND).deleteHistory(sessionId); - - const sessionData = await sessionRouter.sessions(sessionId); - expect(sessionData.user).to.equal(ZERO_ADDR); - - // TODO: fix history so user is not exposed using getSessionsByUser - // const sessionIds2 = await sessionRouter.getSessionsByUser([ - // user.account.address, - // 0n, - // 10, - // ]); - // expect(sessionIds2.length).to.equal(0); - }); - }); -}); diff --git a/smart-contracts/test/diamond/facets/Marketplace.test.ts b/smart-contracts/test/diamond/facets/Marketplace.test.ts new file mode 100644 index 00000000..04172e7c --- /dev/null +++ b/smart-contracts/test/diamond/facets/Marketplace.test.ts @@ -0,0 +1,269 @@ +import { LumerinDiamond, Marketplace, ModelRegistry, MorpheusToken, ProviderRegistry } from '@ethers-v6'; +import { SignerWithAddress } from '@nomicfoundation/hardhat-ethers/signers'; +import { expect } from 'chai'; +import { ethers } from 'hardhat'; + +import { getHex, wei } from '@/scripts/utils/utils'; +import { + deployFacetMarketplace, + deployFacetModelRegistry, + deployFacetProviderRegistry, + deployFacetSessionRouter, + deployLumerinDiamond, + deployMORToken, +} from '@/test/helpers/deployers'; +import { Reverter } from '@/test/helpers/reverter'; +import { setNextTime } from '@/utils/block-helper'; + +describe('Marketplace', () => { + const reverter = new Reverter(); + + let OWNER: SignerWithAddress; + let SECOND: SignerWithAddress; + let PROVIDER: SignerWithAddress; + + let diamond: LumerinDiamond; + let marketplace: Marketplace; + let modelRegistry: ModelRegistry; + let providerRegistry: ProviderRegistry; + + let token: MorpheusToken; + + const modelId1 = getHex(Buffer.from('1')); + const modelId2 = getHex(Buffer.from('2')); + + before(async () => { + [OWNER, SECOND, PROVIDER] = await ethers.getSigners(); + + [diamond, token] = await Promise.all([deployLumerinDiamond(), deployMORToken()]); + + [providerRegistry, modelRegistry, , marketplace] = await Promise.all([ + deployFacetProviderRegistry(diamond), + deployFacetModelRegistry(diamond), + deployFacetSessionRouter(diamond, OWNER), + deployFacetMarketplace(diamond, token), + ]); + + await token.transfer(SECOND, wei(1000)); + await token.connect(SECOND).approve(providerRegistry, wei(1000)); + await token.approve(providerRegistry, wei(1000)); + await token.connect(SECOND).approve(modelRegistry, wei(1000)); + await token.approve(modelRegistry, wei(1000)); + await token.connect(SECOND).approve(marketplace, wei(1000)); + await token.approve(marketplace, wei(1000)); + + const ipfsCID = getHex(Buffer.from('ipfs://ipfsaddress')); + await providerRegistry.connect(SECOND).providerRegister(wei(100), 'test'); + await modelRegistry.connect(SECOND).modelRegister(modelId1, ipfsCID, 0, wei(100), 'name', ['tag_1']); + await modelRegistry.connect(SECOND).modelRegister(modelId2, ipfsCID, 0, wei(100), 'name', ['tag_1']); + + await reverter.snapshot(); + }); + + afterEach(reverter.revert); + + describe('#__Marketplace_init', () => { + it('should set correct data after creation', async () => { + expect(await marketplace.getToken()).to.eq(await token.getAddress()); + }); + it('should revert if try to call init function twice', async () => { + await expect(marketplace.__Marketplace_init(token)).to.be.rejectedWith( + 'Initializable: contract is already initialized', + ); + }); + }); + + describe('#setMarketplaceBidFee', async () => { + it('should set marketplace bid fee', async () => { + const fee = wei(100); + + await expect(marketplace.setMarketplaceBidFee(fee)).to.emit(marketplace, 'MaretplaceFeeUpdated').withArgs(fee); + + expect(await marketplace.getBidFee()).eq(fee); + }); + it('should throw error when caller is not an owner', async () => { + await expect(marketplace.connect(SECOND).setMarketplaceBidFee(100)).to.be.revertedWithCustomError( + diamond, + 'OwnableUnauthorizedAccount', + ); + }); + }); + + describe('#postModelBid', async () => { + beforeEach(async () => { + await marketplace.setMarketplaceBidFee(wei(1)); + }); + + it('should post a model bid', async () => { + await setNextTime(300); + await marketplace.connect(SECOND).postModelBid(modelId1, wei(10)); + + const bidId = await marketplace.getBidId(SECOND, modelId1, 0); + const data = await marketplace.getBid(bidId); + expect(data.provider).to.eq(SECOND); + expect(data.modelId).to.eq(modelId1); + expect(data.pricePerSecond).to.eq(wei(10)); + expect(data.nonce).to.eq(0); + expect(data.createdAt).to.eq(300); + expect(data.deletedAt).to.eq(0); + + expect(await token.balanceOf(marketplace)).to.eq(wei(301)); + expect(await token.balanceOf(SECOND)).to.eq(wei(699)); + + expect(await marketplace.getProviderBids(SECOND, 0, 10)).deep.eq([bidId]); + expect(await marketplace.getModelBids(modelId1, 0, 10)).deep.eq([bidId]); + expect(await marketplace.getProviderActiveBids(SECOND, 0, 10)).deep.eq([bidId]); + expect(await marketplace.getModelActiveBids(modelId1, 0, 10)).deep.eq([bidId]); + }); + it('should post few model bids', async () => { + await setNextTime(300); + await marketplace.connect(SECOND).postModelBid(modelId1, wei(10)); + await marketplace.connect(SECOND).postModelBid(modelId2, wei(20)); + + const bidId1 = await marketplace.getBidId(SECOND, modelId1, 0); + let data = await marketplace.getBid(bidId1); + expect(data.provider).to.eq(SECOND); + expect(data.modelId).to.eq(modelId1); + expect(data.pricePerSecond).to.eq(wei(10)); + expect(data.nonce).to.eq(0); + expect(data.createdAt).to.eq(300); + expect(data.deletedAt).to.eq(0); + + const bidId2 = await marketplace.getBidId(SECOND, modelId2, 0); + data = await marketplace.getBid(bidId2); + expect(data.provider).to.eq(SECOND); + expect(data.modelId).to.eq(modelId2); + expect(data.pricePerSecond).to.eq(wei(20)); + expect(data.nonce).to.eq(0); + expect(data.createdAt).to.eq(301); + expect(data.deletedAt).to.eq(0); + + expect(await token.balanceOf(marketplace)).to.eq(wei(302)); + expect(await token.balanceOf(SECOND)).to.eq(wei(698)); + + expect(await marketplace.getProviderBids(SECOND, 0, 10)).deep.eq([bidId1, bidId2]); + expect(await marketplace.getModelBids(modelId1, 0, 10)).deep.eq([bidId1]); + expect(await marketplace.getModelBids(modelId2, 0, 10)).deep.eq([bidId2]); + expect(await marketplace.getProviderActiveBids(SECOND, 0, 10)).deep.eq([bidId1, bidId2]); + expect(await marketplace.getModelActiveBids(modelId1, 0, 10)).deep.eq([bidId1]); + expect(await marketplace.getModelActiveBids(modelId2, 0, 10)).deep.eq([bidId2]); + }); + it('should post a new model bid and delete an old bid when an old bid is active', async () => { + await setNextTime(300); + await marketplace.connect(SECOND).postModelBid(modelId1, wei(10)); + await marketplace.connect(SECOND).postModelBid(modelId1, wei(20)); + + const bidId1 = await marketplace.getBidId(SECOND, modelId1, 0); + let data = await marketplace.getBid(bidId1); + expect(data.deletedAt).to.eq(301); + + const bidId2 = await marketplace.getBidId(SECOND, modelId1, 1); + data = await marketplace.getBid(bidId2); + expect(data.provider).to.eq(SECOND); + expect(data.modelId).to.eq(modelId1); + expect(data.pricePerSecond).to.eq(wei(20)); + expect(data.nonce).to.eq(1); + expect(data.createdAt).to.eq(301); + expect(data.deletedAt).to.eq(0); + + expect(await token.balanceOf(marketplace)).to.eq(wei(302)); + expect(await token.balanceOf(SECOND)).to.eq(wei(698)); + + expect(await marketplace.getProviderBids(SECOND, 0, 10)).deep.eq([bidId1, bidId2]); + expect(await marketplace.getModelBids(modelId1, 0, 10)).deep.eq([bidId1, bidId2]); + expect(await marketplace.getProviderActiveBids(SECOND, 0, 10)).deep.eq([bidId2]); + expect(await marketplace.getModelActiveBids(modelId1, 0, 10)).deep.eq([bidId2]); + }); + it('should post a new model bid and skip the old bid delete', async () => { + await setNextTime(300); + await marketplace.connect(SECOND).postModelBid(modelId1, wei(10)); + + const bidId1 = await marketplace.getBidId(SECOND, modelId1, 0); + await marketplace.connect(SECOND).deleteModelBid(bidId1); + await marketplace.connect(SECOND).postModelBid(modelId1, wei(20)); + }); + it('should throw error when the provider is deregistered', async () => { + await providerRegistry.connect(SECOND).providerDeregister(); + await expect(marketplace.connect(SECOND).postModelBid(modelId1, wei(10))).to.be.revertedWithCustomError( + marketplace, + 'MarketplaceProviderNotFound', + ); + }); + it('should throw error when the model is deregistered', async () => { + await modelRegistry.connect(SECOND).modelDeregister(modelId1); + await expect(marketplace.connect(SECOND).postModelBid(modelId1, wei(10))).to.be.revertedWithCustomError( + marketplace, + 'MarketplaceModelNotFound', + ); + }); + }); + + describe('#deleteModelBid', async () => { + it('should delete a bid', async () => { + await setNextTime(300); + await marketplace.connect(SECOND).postModelBid(modelId1, wei(10)); + + const bidId1 = await marketplace.getBidId(SECOND, modelId1, 0); + await marketplace.connect(SECOND).deleteModelBid(bidId1); + + const data = await marketplace.getBid(bidId1); + expect(data.deletedAt).to.eq(301); + expect(await marketplace.isBidActive(bidId1)).to.eq(false); + }); + it('should throw error when caller is not an owner', async () => { + await marketplace.connect(SECOND).postModelBid(modelId1, wei(10)); + + const bidId1 = await marketplace.getBidId(SECOND, modelId1, 0); + await expect(marketplace.connect(PROVIDER).deleteModelBid(bidId1)).to.be.revertedWithCustomError( + diamond, + 'OwnableUnauthorizedAccount', + ); + }); + it('should throw error when bid already deleted', async () => { + await marketplace.connect(SECOND).postModelBid(modelId1, wei(10)); + + const bidId1 = await marketplace.getBidId(SECOND, modelId1, 0); + await marketplace.connect(SECOND).deleteModelBid(bidId1); + await expect(marketplace.connect(SECOND).deleteModelBid(bidId1)).to.be.revertedWithCustomError( + marketplace, + 'MarketplaceActiveBidNotFound', + ); + }); + }); + + describe('#withdraw', async () => { + beforeEach(async () => { + await marketplace.setMarketplaceBidFee(wei(1)); + }); + + it('should withdraw fee, all fee balance', async () => { + await marketplace.connect(SECOND).postModelBid(modelId1, wei(10)); + expect(await marketplace.getFeeBalance()).to.eq(wei(1)); + + await marketplace.withdraw(PROVIDER, wei(999)); + + expect(await marketplace.getFeeBalance()).to.eq(wei(0)); + expect(await token.balanceOf(marketplace)).to.eq(wei(300)); + expect(await token.balanceOf(PROVIDER)).to.eq(wei(1)); + }); + it('should withdraw fee, part of fee balance', async () => { + await marketplace.connect(SECOND).postModelBid(modelId1, wei(10)); + expect(await marketplace.getFeeBalance()).to.eq(wei(1)); + + await marketplace.withdraw(PROVIDER, wei(0.1)); + + expect(await marketplace.getFeeBalance()).to.eq(wei(0.9)); + expect(await token.balanceOf(marketplace)).to.eq(wei(300.9)); + expect(await token.balanceOf(PROVIDER)).to.eq(wei(0.1)); + }); + it('should throw error when caller is not an owner', async () => { + await expect(marketplace.connect(SECOND).withdraw(PROVIDER, wei(1))).to.be.revertedWithCustomError( + diamond, + 'OwnableUnauthorizedAccount', + ); + }); + }); +}); + +// npm run generate-types && npx hardhat test "test/diamond/facets/Marketplace.test.ts" +// npx hardhat coverage --solcoverjs ./.solcover.ts --testfiles "test/diamond/facets/Marketplace.test.ts" diff --git a/smart-contracts/test/diamond/facets/ModelRegistry.test.ts b/smart-contracts/test/diamond/facets/ModelRegistry.test.ts new file mode 100644 index 00000000..12e74d88 --- /dev/null +++ b/smart-contracts/test/diamond/facets/ModelRegistry.test.ts @@ -0,0 +1,224 @@ +import { LumerinDiamond, Marketplace, ModelRegistry, MorpheusToken, ProviderRegistry } from '@ethers-v6'; +import { SignerWithAddress } from '@nomicfoundation/hardhat-ethers/signers'; +import { expect } from 'chai'; +import { ethers } from 'hardhat'; + +import { getHex, wei } from '@/scripts/utils/utils'; +import { + deployFacetMarketplace, + deployFacetModelRegistry, + deployFacetProviderRegistry, + deployFacetSessionRouter, + deployLumerinDiamond, + deployMORToken, +} from '@/test/helpers/deployers'; +import { Reverter } from '@/test/helpers/reverter'; +import { setNextTime } from '@/utils/block-helper'; + +describe('ModelRegistry', () => { + const reverter = new Reverter(); + + let OWNER: SignerWithAddress; + let SECOND: SignerWithAddress; + + let diamond: LumerinDiamond; + let providerRegistry: ProviderRegistry; + let modelRegistry: ModelRegistry; + let marketplace: Marketplace; + + let token: MorpheusToken; + + const modelId = getHex(Buffer.from('1')); + const ipfsCID = getHex(Buffer.from('ipfs://ipfsaddress')); + + before(async () => { + [OWNER, SECOND] = await ethers.getSigners(); + + [diamond, token] = await Promise.all([deployLumerinDiamond(), deployMORToken()]); + + [providerRegistry, modelRegistry, , marketplace] = await Promise.all([ + deployFacetProviderRegistry(diamond), + deployFacetModelRegistry(diamond), + deployFacetSessionRouter(diamond, OWNER), + deployFacetMarketplace(diamond, token), + ]); + + await token.transfer(SECOND, wei(1000)); + await token.connect(SECOND).approve(providerRegistry, wei(1000)); + await token.approve(providerRegistry, wei(1000)); + await token.connect(SECOND).approve(modelRegistry, wei(1000)); + await token.approve(modelRegistry, wei(1000)); + await token.connect(SECOND).approve(marketplace, wei(1000)); + await token.approve(marketplace, wei(1000)); + + await reverter.snapshot(); + }); + + afterEach(reverter.revert); + + describe('#__ModelRegistry_init', () => { + it('should revert if try to call init function twice', async () => { + await expect(modelRegistry.__ModelRegistry_init()).to.be.rejectedWith( + 'Initializable: contract is already initialized', + ); + }); + }); + + describe('#modelSetMinStake', async () => { + it('should set min stake', async () => { + const minStake = wei(100); + + await expect(modelRegistry.modelSetMinStake(minStake)) + .to.emit(modelRegistry, 'ModelMinStakeUpdated') + .withArgs(minStake); + + expect(await modelRegistry.getModelMinimumStake()).eq(minStake); + }); + + it('should throw error when caller is not an owner', async () => { + await expect(modelRegistry.connect(SECOND).modelSetMinStake(100)).to.be.revertedWithCustomError( + diamond, + 'OwnableUnauthorizedAccount', + ); + }); + }); + + describe('#getModelIds', async () => { + it('should set min stake', async () => { + await setNextTime(300); + await modelRegistry + .connect(SECOND) + .modelRegister(getHex(Buffer.from('1')), ipfsCID, 0, wei(100), 'name1', ['tag_1']); + await modelRegistry + .connect(SECOND) + .modelRegister(getHex(Buffer.from('2')), ipfsCID, 0, wei(100), 'name2', ['tag_1']); + + const modelIds = await modelRegistry.getModelIds(0, 10); + expect(modelIds.length).to.eq(2); + }); + + it('should throw error when caller is not an owner', async () => { + await expect(modelRegistry.connect(SECOND).modelSetMinStake(100)).to.be.revertedWithCustomError( + diamond, + 'OwnableUnauthorizedAccount', + ); + }); + }); + + describe('#modelRegister', async () => { + it('should register a new model', async () => { + await setNextTime(300); + await modelRegistry.connect(SECOND).modelRegister(modelId, ipfsCID, 0, wei(100), 'name', ['tag_1']); + + const data = await modelRegistry.getModel(modelId); + expect(data.ipfsCID).to.eq(ipfsCID); + expect(data.fee).to.eq(0); + expect(data.stake).to.eq(wei(100)); + expect(data.owner).to.eq(SECOND); + expect(data.name).to.eq('name'); + expect(data.tags).deep.eq(['tag_1']); + expect(data.createdAt).to.eq(300); + expect(data.isDeleted).to.eq(false); + expect(await modelRegistry.getIsModelActive(modelId)).to.eq(true); + + expect(await token.balanceOf(modelRegistry)).to.eq(wei(100)); + expect(await token.balanceOf(SECOND)).to.eq(wei(900)); + + await modelRegistry.connect(SECOND).modelRegister(modelId, ipfsCID, 0, wei(0), 'name', ['tag_1']); + }); + it('should add stake to existed model', async () => { + const ipfsCID2 = getHex(Buffer.from('ipfs://ipfsaddress/2')); + + await setNextTime(300); + await modelRegistry.connect(SECOND).modelRegister(modelId, ipfsCID, 0, wei(100), 'name', ['tag_1']); + await modelRegistry.connect(SECOND).modelRegister(modelId, ipfsCID2, 1, wei(300), 'name2', ['tag_1', 'tag_2']); + + const data = await modelRegistry.getModel(modelId); + expect(data.ipfsCID).to.eq(ipfsCID2); + expect(data.fee).to.eq(1); + expect(data.stake).to.eq(wei(400)); + expect(data.owner).to.eq(SECOND); + expect(data.name).to.eq('name2'); + expect(data.tags).deep.eq(['tag_1', 'tag_2']); + expect(data.createdAt).to.eq(300); + expect(data.isDeleted).to.eq(false); + expect(await modelRegistry.getIsModelActive(modelId)).to.eq(true); + + expect(await token.balanceOf(modelRegistry)).to.eq(wei(400)); + expect(await token.balanceOf(SECOND)).to.eq(wei(600)); + }); + it('should activate deregistered model', async () => { + await setNextTime(300); + await modelRegistry.connect(SECOND).modelRegister(modelId, ipfsCID, 0, wei(100), 'name', ['tag_1']); + await modelRegistry.connect(SECOND).modelDeregister(modelId); + + let data = await modelRegistry.getModel(modelId); + expect(data.isDeleted).to.eq(true); + expect(await modelRegistry.getIsModelActive(modelId)).to.eq(false); + + await modelRegistry.connect(SECOND).modelRegister(modelId, ipfsCID, 4, wei(200), 'name3', ['tag_3']); + + data = await modelRegistry.getModel(modelId); + expect(data.ipfsCID).to.eq(ipfsCID); + expect(data.fee).to.eq(4); + expect(data.stake).to.eq(wei(200)); + expect(data.owner).to.eq(SECOND); + expect(data.name).to.eq('name3'); + expect(data.tags).deep.eq(['tag_3']); + expect(data.createdAt).to.eq(300); + expect(data.isDeleted).to.eq(false); + expect(await modelRegistry.getIsModelActive(modelId)).to.eq(true); + }); + it('should throw error when the caller is not a model owner', async () => { + await modelRegistry.connect(SECOND).modelRegister(modelId, ipfsCID, 0, wei(100), 'name', ['tag_1']); + await expect( + modelRegistry.connect(OWNER).modelRegister(modelId, ipfsCID, 0, wei(100), 'name', ['tag_1']), + ).to.be.revertedWithCustomError(modelRegistry, 'OwnableUnauthorizedAccount'); + }); + it('should throw error when the stake is too low', async () => { + await modelRegistry.modelSetMinStake(wei(2)); + await expect( + modelRegistry.connect(SECOND).modelRegister(modelId, ipfsCID, 0, wei(1), 'name', ['tag_1']), + ).to.be.revertedWithCustomError(modelRegistry, 'ModelStakeTooLow'); + }); + }); + + describe('#modelDeregister', async () => { + it('should deregister the model', async () => { + await setNextTime(300); + await modelRegistry.connect(SECOND).modelRegister(modelId, ipfsCID, 0, wei(100), 'name', ['tag_1']); + await modelRegistry.connect(SECOND).modelDeregister(modelId); + + expect((await modelRegistry.getModel(modelId)).isDeleted).to.equal(true); + expect(await modelRegistry.getIsModelActive(modelId)).to.eq(false); + expect(await token.balanceOf(modelRegistry)).to.eq(0); + expect(await token.balanceOf(SECOND)).to.eq(wei(1000)); + }); + it('should throw error when the caller is not an owner or specified address', async () => { + await expect(modelRegistry.connect(SECOND).modelDeregister(modelId)).to.be.revertedWithCustomError( + modelRegistry, + 'OwnableUnauthorizedAccount', + ); + }); + it('should throw error when model has active bids', async () => { + await providerRegistry.connect(SECOND).providerRegister(wei(100), 'test'); + await modelRegistry.connect(SECOND).modelRegister(modelId, ipfsCID, 0, wei(100), 'name', ['tag_1']); + await marketplace.connect(SECOND).postModelBid(modelId, wei(10)); + await expect(modelRegistry.connect(SECOND).modelDeregister(modelId)).to.be.revertedWithCustomError( + modelRegistry, + 'ModelHasActiveBids', + ); + }); + it('should throw error when delete model few times', async () => { + await modelRegistry.connect(SECOND).modelRegister(modelId, ipfsCID, 0, wei(100), 'name', ['tag_1']); + await modelRegistry.connect(SECOND).modelDeregister(modelId); + await expect(modelRegistry.connect(SECOND).modelDeregister(modelId)).to.be.revertedWithCustomError( + modelRegistry, + 'ModelHasAlreadyDeregistered', + ); + }); + }); +}); + +// npm run generate-types && npx hardhat test "test/diamond/facets/ModelRegistry.test.ts" +// npx hardhat coverage --solcoverjs ./.solcover.ts --testfiles "test/diamond/facets/ModelRegistry.test.ts" diff --git a/smart-contracts/test/diamond/facets/ProviderRegistry.test.ts b/smart-contracts/test/diamond/facets/ProviderRegistry.test.ts new file mode 100644 index 00000000..aa095630 --- /dev/null +++ b/smart-contracts/test/diamond/facets/ProviderRegistry.test.ts @@ -0,0 +1,202 @@ +import { LumerinDiamond, Marketplace, ModelRegistry, MorpheusToken, ProviderRegistry } from '@ethers-v6'; +import { SignerWithAddress } from '@nomicfoundation/hardhat-ethers/signers'; +import { expect } from 'chai'; +import { ethers } from 'hardhat'; + +import { getHex, wei } from '@/scripts/utils/utils'; +import { + deployFacetMarketplace, + deployFacetModelRegistry, + deployFacetProviderRegistry, + deployFacetSessionRouter, + deployLumerinDiamond, + deployMORToken, +} from '@/test/helpers/deployers'; +import { Reverter } from '@/test/helpers/reverter'; +import { setNextTime } from '@/utils/block-helper'; +import { YEAR } from '@/utils/time'; + +describe('ProviderRegistry', () => { + const reverter = new Reverter(); + + let OWNER: SignerWithAddress; + let PROVIDER: SignerWithAddress; + + let diamond: LumerinDiamond; + let providerRegistry: ProviderRegistry; + let modelRegistry: ModelRegistry; + let marketplace: Marketplace; + + let token: MorpheusToken; + + const modelId = getHex(Buffer.from('1')); + const ipfsCID = getHex(Buffer.from('ipfs://ipfsaddress')); + + before(async () => { + [OWNER, PROVIDER] = await ethers.getSigners(); + + [diamond, token] = await Promise.all([deployLumerinDiamond(), deployMORToken()]); + + [providerRegistry, modelRegistry, , marketplace] = await Promise.all([ + deployFacetProviderRegistry(diamond), + deployFacetModelRegistry(diamond), + deployFacetSessionRouter(diamond, OWNER), + deployFacetMarketplace(diamond, token), + ]); + + await token.transfer(PROVIDER, wei(1000)); + await token.connect(PROVIDER).approve(providerRegistry, wei(1000)); + await token.approve(providerRegistry, wei(1000)); + + await reverter.snapshot(); + }); + + afterEach(reverter.revert); + + describe('#__ProviderRegistry_init', () => { + it('should revert if try to call init function twice', async () => { + await expect(providerRegistry.__ProviderRegistry_init()).to.be.rejectedWith( + 'Initializable: contract is already initialized', + ); + }); + }); + + describe('#providerSetMinStake', async () => { + it('should set min stake', async () => { + const minStake = wei(100); + + await expect(providerRegistry.providerSetMinStake(minStake)) + .to.emit(providerRegistry, 'ProviderMinStakeUpdated') + .withArgs(minStake); + + expect(await providerRegistry.getProviderMinimumStake()).eq(minStake); + }); + + it('should throw error when caller is not an owner', async () => { + await expect(providerRegistry.connect(PROVIDER).providerSetMinStake(100)).to.be.revertedWithCustomError( + diamond, + 'OwnableUnauthorizedAccount', + ); + }); + }); + + describe('#providerRegister', async () => { + it('should register a new provider', async () => { + await setNextTime(300); + await providerRegistry.connect(PROVIDER).providerRegister(wei(100), 'test'); + + const data = await providerRegistry.getProvider(PROVIDER); + + expect(data.endpoint).to.eq('test'); + expect(data.stake).to.eq(wei(100)); + expect(data.createdAt).to.eq(300); + expect(data.limitPeriodEnd).to.eq(YEAR + 300); + expect(data.limitPeriodEarned).to.eq(0); + expect(data.isDeleted).to.eq(false); + expect(await providerRegistry.getIsProviderActive(PROVIDER)).to.eq(true); + + expect(await token.balanceOf(providerRegistry)).to.eq(wei(100)); + expect(await token.balanceOf(PROVIDER)).to.eq(wei(900)); + + await providerRegistry.connect(PROVIDER).providerRegister(wei(0), 'test'); + }); + it('should add stake to existed provider', async () => { + await setNextTime(300); + await providerRegistry.connect(PROVIDER).providerRegister(wei(100), 'test'); + await providerRegistry.connect(PROVIDER).providerRegister(wei(300), 'test2'); + + const data = await providerRegistry.getProvider(PROVIDER); + + expect(data.endpoint).to.eq('test2'); + expect(data.stake).to.eq(wei(400)); + expect(data.createdAt).to.eq(300); + expect(data.limitPeriodEnd).to.eq(YEAR + 300); + expect(data.limitPeriodEarned).to.eq(0); + expect(data.isDeleted).to.eq(false); + expect(await providerRegistry.getIsProviderActive(PROVIDER)).to.eq(true); + + expect(await token.balanceOf(providerRegistry)).to.eq(wei(400)); + expect(await token.balanceOf(PROVIDER)).to.eq(wei(600)); + }); + it('should activate deregistered provider', async () => { + await setNextTime(300); + await providerRegistry.connect(PROVIDER).providerRegister(wei(100), 'test'); + await setNextTime(301 + YEAR); + await providerRegistry.connect(PROVIDER).providerDeregister(); + + let data = await providerRegistry.getProvider(PROVIDER); + expect(data.isDeleted).to.eq(true); + expect(await providerRegistry.getIsProviderActive(PROVIDER)).to.eq(false); + + await providerRegistry.connect(PROVIDER).providerRegister(wei(1), 'test2'); + data = await providerRegistry.getProvider(PROVIDER); + + expect(data.endpoint).to.eq('test2'); + expect(data.stake).to.eq(wei(1)); + expect(data.createdAt).to.eq(300); + expect(data.limitPeriodEnd).to.eq(YEAR + 300); + expect(data.limitPeriodEarned).to.eq(0); + expect(data.isDeleted).to.eq(false); + expect(await providerRegistry.getIsProviderActive(PROVIDER)).to.eq(true); + }); + it('should throw error when the stake is too low', async () => { + await providerRegistry.providerSetMinStake(wei(2)); + await expect(providerRegistry.connect(PROVIDER).providerRegister(wei(0), '')).to.be.revertedWithCustomError( + providerRegistry, + 'ProviderStakeTooLow', + ); + }); + }); + + describe('#providerDeregister', async () => { + it('should deregister the provider', async () => { + await setNextTime(300); + await providerRegistry.connect(PROVIDER).providerRegister(wei(100), 'test'); + await setNextTime(301 + YEAR); + await providerRegistry.connect(PROVIDER).providerDeregister(); + + expect((await providerRegistry.getProvider(PROVIDER)).isDeleted).to.equal(true); + expect(await providerRegistry.getIsProviderActive(PROVIDER)).to.eq(false); + expect(await token.balanceOf(providerRegistry)).to.eq(0); + expect(await token.balanceOf(PROVIDER)).to.eq(wei(1000)); + }); + it('should deregister the provider without transfer', async () => { + await providerRegistry.providerSetMinStake(0); + await setNextTime(300); + await providerRegistry.connect(PROVIDER).providerRegister(wei(0), 'test'); + await setNextTime(301 + YEAR); + await providerRegistry.connect(PROVIDER).providerDeregister(); + + expect((await providerRegistry.getProvider(PROVIDER)).isDeleted).to.equal(true); + expect(await providerRegistry.getIsProviderActive(PROVIDER)).to.eq(false); + expect(await token.balanceOf(providerRegistry)).to.eq(0); + expect(await token.balanceOf(PROVIDER)).to.eq(wei(1000)); + }); + it('should throw error when provider is not found', async () => { + await expect(providerRegistry.connect(OWNER).providerDeregister()).to.be.revertedWithCustomError( + providerRegistry, + 'ProviderNotFound', + ); + }); + it('should throw error when provider has active bids', async () => { + await providerRegistry.connect(PROVIDER).providerRegister(wei(100), 'test'); + await modelRegistry.connect(PROVIDER).modelRegister(modelId, ipfsCID, 0, wei(100), 'name', ['tag_1']); + await marketplace.connect(PROVIDER).postModelBid(modelId, wei(10)); + await expect(providerRegistry.connect(PROVIDER).providerDeregister()).to.be.revertedWithCustomError( + providerRegistry, + 'ProviderHasActiveBids', + ); + }); + it('should throw error when delete provider few times', async () => { + await providerRegistry.connect(OWNER).providerRegister(wei(100), 'test'); + await providerRegistry.connect(OWNER).providerDeregister(); + await expect(providerRegistry.connect(OWNER).providerDeregister()).to.be.revertedWithCustomError( + providerRegistry, + 'ProviderHasAlreadyDeregistered', + ); + }); + }); +}); + +// npm run generate-types && npx hardhat test "test/diamond/facets/ProviderRegistry.test.ts" +// npx hardhat coverage --solcoverjs ./.solcover.ts --testfiles "test/diamond/facets/ProviderRegistry.test.ts" diff --git a/smart-contracts/test/diamond/facets/SessionRouter.test.ts b/smart-contracts/test/diamond/facets/SessionRouter.test.ts new file mode 100644 index 00000000..366ccda5 --- /dev/null +++ b/smart-contracts/test/diamond/facets/SessionRouter.test.ts @@ -0,0 +1,830 @@ +import { LumerinDiamond, Marketplace, ModelRegistry, MorpheusToken, ProviderRegistry, SessionRouter } from '@ethers-v6'; +import { SignerWithAddress } from '@nomicfoundation/hardhat-ethers/signers'; +import { expect } from 'chai'; +import { ethers } from 'hardhat'; + +import { getHex, wei } from '@/scripts/utils/utils'; +import { + deployFacetMarketplace, + deployFacetModelRegistry, + deployFacetProviderRegistry, + deployFacetSessionRouter, + deployLumerinDiamond, + deployMORToken, +} from '@/test/helpers/deployers'; +import { payoutStart } from '@/test/helpers/pool-helper'; +import { Reverter } from '@/test/helpers/reverter'; +import { setTime } from '@/utils/block-helper'; +import { getProviderApproval, getReceipt } from '@/utils/provider-helper'; +import { DAY } from '@/utils/time'; + +describe('SessionRouter', () => { + const reverter = new Reverter(); + + let OWNER: SignerWithAddress; + let SECOND: SignerWithAddress; + let FUNDING: SignerWithAddress; + let PROVIDER: SignerWithAddress; + + let diamond: LumerinDiamond; + let marketplace: Marketplace; + let modelRegistry: ModelRegistry; + let providerRegistry: ProviderRegistry; + let sessionRouter: SessionRouter; + + let token: MorpheusToken; + + let bidId = ''; + const modelId = getHex(Buffer.from('1')); + const bidPricePerSecond = wei(0.0001); + + before(async () => { + [OWNER, SECOND, FUNDING, PROVIDER] = await ethers.getSigners(); + + [diamond, token] = await Promise.all([deployLumerinDiamond(), deployMORToken()]); + + [providerRegistry, modelRegistry, sessionRouter, marketplace] = await Promise.all([ + deployFacetProviderRegistry(diamond), + deployFacetModelRegistry(diamond), + deployFacetSessionRouter(diamond, FUNDING), + deployFacetMarketplace(diamond, token), + ]); + + await token.transfer(SECOND, wei(10000)); + await token.transfer(PROVIDER, wei(10000)); + await token.transfer(FUNDING, wei(10000)); + await token.connect(PROVIDER).approve(providerRegistry, wei(10000)); + await token.connect(PROVIDER).approve(modelRegistry, wei(10000)); + await token.connect(PROVIDER).approve(marketplace, wei(10000)); + await token.connect(SECOND).approve(sessionRouter, wei(10000)); + await token.connect(FUNDING).approve(sessionRouter, wei(10000)); + + const ipfsCID = getHex(Buffer.from('ipfs://ipfsaddress')); + await providerRegistry.connect(PROVIDER).providerRegister(wei(0.2), 'test'); + await modelRegistry.connect(PROVIDER).modelRegister(modelId, ipfsCID, 0, wei(100), 'name', ['tag_1']); + + await marketplace.connect(PROVIDER).postModelBid(modelId, bidPricePerSecond); + bidId = await marketplace.getBidId(PROVIDER, modelId, 0); + + await reverter.snapshot(); + }); + + afterEach(reverter.revert); + + describe('#__SessionRouter_init', () => { + it('should set correct data after creation', async () => { + expect(await sessionRouter.getFundingAccount()).to.eq(FUNDING); + expect((await sessionRouter.getPools()).length).to.eq(5); + }); + it('should revert if try to call init function twice', async () => { + await expect(sessionRouter.__SessionRouter_init(FUNDING, [])).to.be.rejectedWith( + 'Initializable: contract is already initialized', + ); + }); + }); + + describe('#getTodaysBudget', () => { + it('should return not zero amount', async () => { + expect(await sessionRouter.getTodaysBudget(payoutStart)).to.eq(0); + expect(await sessionRouter.getTodaysBudget(payoutStart + 10 * DAY)).to.greaterThan(0); + }); + }); + + describe('#setPoolConfig', () => { + it('should reset pool', async () => { + await sessionRouter.setPoolConfig(0, { + payoutStart: 0, + decreaseInterval: DAY, + initialReward: wei(1000), + rewardDecrease: wei(10), + }); + + const pool = await sessionRouter.getPool(0); + expect(pool.payoutStart).to.eq(0); + expect(pool.decreaseInterval).to.eq(DAY); + expect(pool.initialReward).to.eq(wei(1000)); + expect(pool.rewardDecrease).to.eq(wei(10)); + }); + it('should throw error when the pool index is invalid', async () => { + await expect( + sessionRouter.setPoolConfig(100, { + payoutStart: 0, + decreaseInterval: DAY, + initialReward: wei(1000), + rewardDecrease: wei(10), + }), + ).to.be.revertedWithCustomError(sessionRouter, 'SessionPoolIndexOutOfBounds'); + }); + it('should throw error when the caller is invalid', async () => { + await expect( + sessionRouter.connect(SECOND).setPoolConfig(0, { + payoutStart: 0, + decreaseInterval: DAY, + initialReward: wei(1000), + rewardDecrease: wei(10), + }), + ).to.be.revertedWithCustomError(sessionRouter, 'OwnableUnauthorizedAccount'); + }); + }); + + describe('#openSession', () => { + let tokenBalBefore = 0n; + let secondBalBefore = 0n; + + beforeEach(async () => { + tokenBalBefore = await token.balanceOf(sessionRouter); + secondBalBefore = await token.balanceOf(SECOND); + }); + it('should open session', async () => { + await setTime(payoutStart + 10 * DAY); + const { msg, signature } = await getProviderApproval(PROVIDER, SECOND, bidId); + await sessionRouter.connect(SECOND).openSession(wei(50), msg, signature); + + const sessionId = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 0); + const data = await sessionRouter.getSession(sessionId); + expect(data.user).to.eq(SECOND); + expect(data.bidId).to.eq(bidId); + expect(data.stake).to.eq(wei(50)); + expect(data.closeoutReceipt).to.eq('0x'); + expect(data.closeoutType).to.eq(0); + expect(data.providerWithdrawnAmount).to.eq(0); + expect(data.openedAt).to.eq(payoutStart + 10 * DAY + 1); + expect(data.endsAt).to.greaterThan(data.openedAt); + expect(data.closedAt).to.eq(0); + expect(data.isActive).to.eq(true); + + const tokenBalAfter = await token.balanceOf(sessionRouter); + expect(tokenBalAfter - tokenBalBefore).to.eq(wei(50)); + const secondBalAfter = await token.balanceOf(SECOND); + expect(secondBalBefore - secondBalAfter).to.eq(wei(50)); + + expect(await sessionRouter.getIsProviderApprovalUsed(msg)).to.eq(true); + expect(await sessionRouter.getUserSessions(SECOND, 0, 10)).to.deep.eq([sessionId]); + expect(await sessionRouter.getProviderSessions(PROVIDER, 0, 10)).to.deep.eq([sessionId]); + expect(await sessionRouter.getModelSessions(modelId, 0, 10)).to.deep.eq([sessionId]); + }); + it('should open two different session wit the same input params', async () => { + await setTime(payoutStart + 10 * DAY); + const { msg: msg1, signature: signature1 } = await getProviderApproval(PROVIDER, SECOND, bidId); + await setTime(payoutStart + 10 * DAY + 1); + const { msg: msg2, signature: signature2 } = await getProviderApproval(PROVIDER, SECOND, bidId); + await sessionRouter.connect(SECOND).openSession(wei(50), msg1, signature1); + await sessionRouter.connect(SECOND).openSession(wei(50), msg2, signature2); + + const sessionId1 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 0); + const sessionId2 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 1); + + const tokenBalAfter = await token.balanceOf(sessionRouter); + expect(tokenBalAfter - tokenBalBefore).to.eq(wei(100)); + const secondBalAfter = await token.balanceOf(SECOND); + expect(secondBalBefore - secondBalAfter).to.eq(wei(100)); + + expect(await sessionRouter.getIsProviderApprovalUsed(msg1)).to.eq(true); + expect(await sessionRouter.getIsProviderApprovalUsed(msg2)).to.eq(true); + expect(await sessionRouter.getUserSessions(SECOND, 0, 10)).to.deep.eq([sessionId1, sessionId2]); + expect(await sessionRouter.getProviderSessions(PROVIDER, 0, 10)).to.deep.eq([sessionId1, sessionId2]); + expect(await sessionRouter.getModelSessions(modelId, 0, 10)).to.deep.eq([sessionId1, sessionId2]); + + expect(await sessionRouter.getTotalSessions(PROVIDER)).to.eq(2); + }); + it('should open session with max duration', async () => { + await setTime(payoutStart + 10 * DAY); + const { msg, signature } = await getProviderApproval(PROVIDER, SECOND, bidId); + await sessionRouter.connect(SECOND).openSession(wei(10000), msg, signature); + + const sessionId = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 0); + const data = await sessionRouter.getSession(sessionId); + expect(data.endsAt).to.eq(Number(data.openedAt.toString()) + DAY); + }); + it('should throw error when the approval is for an another user', async () => { + const { msg, signature } = await getProviderApproval(PROVIDER, OWNER, bidId); + await expect(sessionRouter.connect(SECOND).openSession(wei(50), msg, signature)).to.be.revertedWithCustomError( + sessionRouter, + 'SessionApprovedForAnotherUser', + ); + }); + it('should throw error when the approval is for an another chain', async () => { + const { msg, signature } = await getProviderApproval(PROVIDER, SECOND, bidId, 1n); + await expect(sessionRouter.connect(SECOND).openSession(wei(50), msg, signature)).to.be.revertedWithCustomError( + sessionRouter, + 'SesssionApprovedForAnotherChainId', + ); + }); + it('should throw error when an aprrove expired', async () => { + await setTime(payoutStart); + const { msg, signature } = await getProviderApproval(PROVIDER, SECOND, bidId); + await setTime(payoutStart + 600); + await expect(sessionRouter.connect(SECOND).openSession(wei(50), msg, signature)).to.be.revertedWithCustomError( + sessionRouter, + 'SesssionApproveExpired', + ); + }); + it('should throw error when the bid is not found', async () => { + const { msg, signature } = await getProviderApproval(PROVIDER, SECOND, getHex(Buffer.from('1'))); + await expect(sessionRouter.connect(SECOND).openSession(wei(50), msg, signature)).to.be.revertedWithCustomError( + sessionRouter, + 'SessionBidNotFound', + ); + }); + it('should throw error when the signature mismatch', async () => { + const { msg, signature } = await getProviderApproval(OWNER, SECOND, bidId); + await expect(sessionRouter.connect(SECOND).openSession(wei(50), msg, signature)).to.be.revertedWithCustomError( + sessionRouter, + 'SessionProviderSignatureMismatch', + ); + }); + it('should throw error when an approval duplicated', async () => { + await setTime(payoutStart + 10 * DAY); + const { msg, signature } = await getProviderApproval(PROVIDER, SECOND, bidId); + await sessionRouter.connect(SECOND).openSession(wei(50), msg, signature); + await expect(sessionRouter.connect(SECOND).openSession(wei(50), msg, signature)).to.be.revertedWithCustomError( + sessionRouter, + 'SessionDuplicateApproval', + ); + }); + it('should throw error when session duration too short', async () => { + const { msg, signature } = await getProviderApproval(PROVIDER, SECOND, bidId); + await expect(sessionRouter.connect(SECOND).openSession(wei(50), msg, signature)).to.be.revertedWithCustomError( + sessionRouter, + 'SessionTooShort', + ); + }); + }); + + describe('#closeSession', () => { + it('should close session and send rewards for the provider, with dispute, late closure', async () => { + const { sessionId, openedAt } = await _createSession(); + + const providerBalBefore = await token.balanceOf(PROVIDER); + const fundingBalBefore = await token.balanceOf(FUNDING); + + await setTime(openedAt + 5 * DAY); + const { msg: receiptMsg } = await getReceipt(PROVIDER, sessionId, 0, 0); + const { signature: receiptSig } = await getReceipt(OWNER, sessionId, 0, 0); + await sessionRouter.connect(SECOND).closeSession(receiptMsg, receiptSig); + + const session = await sessionRouter.getSession(sessionId); + const duration = session.endsAt - session.openedAt; + + expect(session.closedAt).to.eq(openedAt + 5 * DAY + 1); + expect(session.isActive).to.eq(false); + expect(session.closeoutReceipt).to.eq(receiptMsg); + expect(session.providerWithdrawnAmount).to.eq(bidPricePerSecond * duration); + + expect((await sessionRouter.getProvider(PROVIDER)).limitPeriodEarned).to.eq(bidPricePerSecond * duration); + + const providerBalAfter = await token.balanceOf(PROVIDER); + expect(providerBalAfter - providerBalBefore).to.eq(bidPricePerSecond * duration); + const fundingBalAfter = await token.balanceOf(FUNDING); + expect(fundingBalBefore - fundingBalAfter).to.eq(bidPricePerSecond * duration); + }); + it('should close session and send rewards for the provider, no dispute, early closure', async () => { + const { sessionId, openedAt } = await _createSession(); + const providerBalBefore = await token.balanceOf(PROVIDER); + + const fundingBalBefore = await token.balanceOf(FUNDING); + await setTime(openedAt + 200); + const { msg: receiptMsg, signature: receiptSig } = await getReceipt(PROVIDER, sessionId, 0, 0); + await sessionRouter.connect(SECOND).closeSession(receiptMsg, receiptSig); + + const session = await sessionRouter.getSession(sessionId); + const duration = session.closedAt - session.openedAt; + + expect(session.closedAt).to.eq(openedAt + 201); + expect(session.isActive).to.eq(false); + expect(session.closeoutReceipt).to.eq(receiptMsg); + expect(session.providerWithdrawnAmount).to.eq(bidPricePerSecond * duration); + + expect((await sessionRouter.getProvider(PROVIDER)).limitPeriodEarned).to.eq(bidPricePerSecond * duration); + + const providerBalAfter = await token.balanceOf(PROVIDER); + expect(providerBalAfter - providerBalBefore).to.eq(bidPricePerSecond * duration); + const fundingBalAfter = await token.balanceOf(FUNDING); + expect(fundingBalBefore - fundingBalAfter).to.eq(bidPricePerSecond * duration); + }); + it('should close session and send rewards for the provider, with dispute, late closure before end', async () => { + const { sessionId, secondsToDayEnd, openedAt } = await _createSession(); + + const providerBalBefore = await token.balanceOf(PROVIDER); + const fundingBalBefore = await token.balanceOf(FUNDING); + + await setTime(openedAt + secondsToDayEnd + 1); + const { msg: receiptMsg } = await getReceipt(PROVIDER, sessionId, 0, 0); + const { signature: receiptSig } = await getReceipt(OWNER, sessionId, 0, 0); + await sessionRouter.connect(SECOND).closeSession(receiptMsg, receiptSig); + + const session = await sessionRouter.getSession(sessionId); + const duration = BigInt(secondsToDayEnd); + + expect(session.closedAt).to.eq(openedAt + secondsToDayEnd + 2); + expect(session.isActive).to.eq(false); + expect(session.closeoutReceipt).to.eq(receiptMsg); + expect(session.providerWithdrawnAmount).to.eq(bidPricePerSecond * duration); + + expect((await sessionRouter.getProvider(PROVIDER)).limitPeriodEarned).to.eq(bidPricePerSecond * duration); + + const providerBalAfter = await token.balanceOf(PROVIDER); + expect(providerBalAfter - providerBalBefore).to.eq(bidPricePerSecond * duration); + const fundingBalAfter = await token.balanceOf(FUNDING); + expect(fundingBalBefore - fundingBalAfter).to.eq(bidPricePerSecond * duration); + }); + it('should close session and do not send rewards for the provider, with dispute, same day closure', async () => { + const { sessionId, secondsToDayEnd, openedAt } = await _createSession(); + + const providerBalBefore = await token.balanceOf(PROVIDER); + const fundingBalBefore = await token.balanceOf(FUNDING); + + await setTime(openedAt + secondsToDayEnd - 50); + const { msg: receiptMsg } = await getReceipt(PROVIDER, sessionId, 0, 0); + const { signature: receiptSig } = await getReceipt(OWNER, sessionId, 0, 0); + await sessionRouter.connect(SECOND).closeSession(receiptMsg, receiptSig); + + const session = await sessionRouter.getSession(sessionId); + + const duration = 0n; + expect(session.closedAt).to.eq(openedAt + secondsToDayEnd - 49); + expect(session.isActive).to.eq(false); + expect(session.closeoutReceipt).to.eq(receiptMsg); + + expect(session.providerWithdrawnAmount).to.eq(bidPricePerSecond * duration); + + expect((await sessionRouter.getProvider(PROVIDER)).limitPeriodEarned).to.eq(bidPricePerSecond * duration); + const providerBalAfter = await token.balanceOf(PROVIDER); + expect(providerBalAfter - providerBalBefore).to.eq(bidPricePerSecond * duration); + const fundingBalAfter = await token.balanceOf(FUNDING); + expect(fundingBalBefore - fundingBalAfter).to.eq(bidPricePerSecond * duration); + }); + it('should close session and send rewards for the user, late closure', async () => { + const { sessionId, openedAt } = await _createSession(); + + const userBalBefore = await token.balanceOf(SECOND); + const contractBalBefore = await token.balanceOf(sessionRouter); + + await setTime(openedAt + 5 * DAY); + const { msg: receiptMsg, signature: receiptSig } = await getReceipt(PROVIDER, sessionId, 0, 0); + await sessionRouter.connect(SECOND).closeSession(receiptMsg, receiptSig); + + const stakesOnHold = await sessionRouter.getUserStakesOnHold(SECOND, 20); + expect(stakesOnHold[0]).to.eq(0); + expect(stakesOnHold[1]).to.eq(0); + + const userBalAfter = await token.balanceOf(SECOND); + expect(userBalAfter - userBalBefore).to.eq(wei(50)); + const contractBalAfter = await token.balanceOf(sessionRouter); + expect(contractBalBefore - contractBalAfter).to.eq(wei(50)); + }); + it('should close session and send rewards for the user, early closure', async () => { + const { sessionId, openedAt, secondsToDayEnd } = await _createSession(); + + const userBalBefore = await token.balanceOf(SECOND); + const contractBalBefore = await token.balanceOf(sessionRouter); + + await setTime(openedAt + secondsToDayEnd + 1); + const { msg: receiptMsg, signature: receiptSig } = await getReceipt(PROVIDER, sessionId, 0, 0); + await sessionRouter.connect(SECOND).closeSession(receiptMsg, receiptSig); + + const stakesOnHold = await sessionRouter.getUserStakesOnHold(SECOND, 1); + expect(stakesOnHold[0]).to.eq(0); + expect(stakesOnHold[1]).to.greaterThan(0); + + const userBalAfter = await token.balanceOf(SECOND); + expect(userBalAfter - userBalBefore).to.lessThan(wei(50)); + const contractBalAfter = await token.balanceOf(sessionRouter); + expect(contractBalBefore - contractBalAfter).to.lessThan(wei(50)); + + await sessionRouter.getProviderModelStats(modelId, PROVIDER); + await sessionRouter.getModelStats(modelId); + }); + it('should throw error when the caller is invalid', async () => { + const { msg: receiptMsg, signature: receiptSig } = await getReceipt(PROVIDER, getHex(Buffer.from('1')), 0, 0); + await expect(sessionRouter.connect(SECOND).closeSession(receiptMsg, receiptSig)).to.be.revertedWithCustomError( + sessionRouter, + 'OwnableUnauthorizedAccount', + ); + }); + it('should throw error when the session already closed', async () => { + const { sessionId, openedAt, secondsToDayEnd } = await _createSession(); + + await setTime(openedAt + secondsToDayEnd + 1); + const { msg: receiptMsg, signature: receiptSig } = await getReceipt(PROVIDER, sessionId, 0, 0); + await sessionRouter.connect(SECOND).closeSession(receiptMsg, receiptSig); + + await expect(sessionRouter.connect(SECOND).closeSession(receiptMsg, receiptSig)).to.be.revertedWithCustomError( + sessionRouter, + 'SessionAlreadyClosed', + ); + }); + it('should throw error when the provider receipt for another chain', async () => { + const { sessionId } = await _createSession(); + + const { msg: receiptMsg, signature: receiptSig } = await getReceipt(PROVIDER, sessionId, 0, 0, 1n); + await expect(sessionRouter.connect(SECOND).closeSession(receiptMsg, receiptSig)).to.be.revertedWithCustomError( + sessionRouter, + 'SesssionReceiptForAnotherChainId', + ); + }); + it('should throw error when the provider receipt expired', async () => { + const { sessionId, openedAt } = await _createSession(); + + await setTime(openedAt + 100); + const { msg: receiptMsg, signature: receiptSig } = await getReceipt(PROVIDER, sessionId, 0, 0); + + await setTime(openedAt + 10000); + await expect(sessionRouter.connect(SECOND).closeSession(receiptMsg, receiptSig)).to.be.revertedWithCustomError( + sessionRouter, + 'SesssionReceiptExpired', + ); + }); + }); + + describe('#claimForProvider', () => { + it('should claim provider rewards, remainder, session closed with dispute', async () => { + const { sessionId, secondsToDayEnd, openedAt } = await _createSession(); + + await setTime(openedAt + secondsToDayEnd + 1); + const { msg: receiptMsg } = await getReceipt(PROVIDER, sessionId, 0, 0); + const { signature: receiptSig } = await getReceipt(OWNER, sessionId, 0, 0); + await sessionRouter.connect(SECOND).closeSession(receiptMsg, receiptSig); + + let session = await sessionRouter.getSession(sessionId); + const fullDuration = BigInt(secondsToDayEnd + 1); + const duration = 1n; + + const providerBalBefore = await token.balanceOf(PROVIDER); + const fundingBalBefore = await token.balanceOf(FUNDING); + + await sessionRouter.connect(PROVIDER).claimForProvider(sessionId); + session = await sessionRouter.getSession(sessionId); + + expect(session.providerWithdrawnAmount).to.eq(bidPricePerSecond * fullDuration); + expect(await sessionRouter.getProvidersTotalClaimed()).to.eq(bidPricePerSecond * fullDuration); + expect((await sessionRouter.getProvider(PROVIDER)).limitPeriodEarned).to.eq(bidPricePerSecond * fullDuration); + + const providerBalAfter = await token.balanceOf(PROVIDER); + expect(providerBalAfter - providerBalBefore).to.eq(bidPricePerSecond * duration); + const fundingBalAfter = await token.balanceOf(FUNDING); + expect(fundingBalBefore - fundingBalAfter).to.eq(bidPricePerSecond * duration); + }); + it('should claim provider rewards, full', async () => { + const { sessionId, openedAt } = await _createSession(); + + let session = await sessionRouter.getSession(sessionId); + const duration = session.endsAt - session.openedAt; + + const providerBalBefore = await token.balanceOf(PROVIDER); + const fundingBalBefore = await token.balanceOf(FUNDING); + + await setTime(openedAt + 5 * DAY + 1); + await sessionRouter.connect(PROVIDER).claimForProvider(sessionId); + session = await sessionRouter.getSession(sessionId); + + expect(session.providerWithdrawnAmount).to.eq(bidPricePerSecond * duration); + expect(await sessionRouter.getProvidersTotalClaimed()).to.eq(bidPricePerSecond * duration); + expect((await sessionRouter.getProvider(PROVIDER)).limitPeriodEarned).to.eq(bidPricePerSecond * duration); + + const providerBalAfter = await token.balanceOf(PROVIDER); + expect(providerBalAfter - providerBalBefore).to.eq(bidPricePerSecond * duration); + const fundingBalAfter = await token.balanceOf(FUNDING); + expect(fundingBalBefore - fundingBalAfter).to.eq(bidPricePerSecond * duration); + }); + it('should claim provider rewards with reward limiter amount for the period', async () => { + const providerBalBefore = await token.balanceOf(PROVIDER); + const fundingBalBefore = await token.balanceOf(FUNDING); + + await setTime(payoutStart + 10 * DAY); + const { msg: msg1, signature: sig1 } = await getProviderApproval(PROVIDER, SECOND, bidId); + await sessionRouter.connect(SECOND).openSession(wei(50), msg1, sig1); + + const sessionId1 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 0); + await setTime(payoutStart + 20 * DAY); + await sessionRouter.connect(PROVIDER).claimForProvider(sessionId1); + + await setTime(payoutStart + 30 * DAY); + const { msg: msg2, signature: sig2 } = await getProviderApproval(PROVIDER, SECOND, bidId); + await sessionRouter.connect(SECOND).openSession(wei(50), msg2, sig2); + + const sessionId2 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 1); + await setTime(payoutStart + 40 * DAY); + await sessionRouter.connect(PROVIDER).claimForProvider(sessionId2); + + expect(await sessionRouter.getProvidersTotalClaimed()).to.eq(wei(0.2)); + expect((await sessionRouter.getProvider(PROVIDER)).limitPeriodEarned).to.eq(wei(0.2)); + + const providerBalAfter = await token.balanceOf(PROVIDER); + expect(providerBalAfter - providerBalBefore).to.eq(wei(0.2)); + const fundingBalAfter = await token.balanceOf(FUNDING); + expect(fundingBalBefore - fundingBalAfter).to.eq(wei(0.2)); + }); + it('should throw error when caller is not the session provider', async () => { + const { sessionId } = await _createSession(); + + await expect(sessionRouter.connect(SECOND).claimForProvider(sessionId)).to.be.revertedWithCustomError( + sessionRouter, + 'OwnableUnauthorizedAccount', + ); + }); + it('should throw error when session is not end', async () => { + const { sessionId, openedAt } = await _createSession(); + + await setTime(openedAt + 10); + await expect(sessionRouter.connect(PROVIDER).claimForProvider(sessionId)).to.be.revertedWithCustomError( + sessionRouter, + 'SessionNotEndedOrNotExist', + ); + }); + }); + + describe('#claimForProvider', () => { + it('should claim provider rewards, remainder, session closed with dispute', async () => { + const { sessionId, secondsToDayEnd, openedAt } = await _createSession(); + + await setTime(openedAt + secondsToDayEnd + 1); + const { msg: receiptMsg } = await getReceipt(PROVIDER, sessionId, 0, 0); + const { signature: receiptSig } = await getReceipt(OWNER, sessionId, 0, 0); + await sessionRouter.connect(SECOND).closeSession(receiptMsg, receiptSig); + + let session = await sessionRouter.getSession(sessionId); + const fullDuration = BigInt(secondsToDayEnd + 1); + const duration = 1n; + + const providerBalBefore = await token.balanceOf(PROVIDER); + const fundingBalBefore = await token.balanceOf(FUNDING); + + await sessionRouter.connect(PROVIDER).claimForProvider(sessionId); + session = await sessionRouter.getSession(sessionId); + + expect(session.providerWithdrawnAmount).to.eq(bidPricePerSecond * fullDuration); + expect(await sessionRouter.getProvidersTotalClaimed()).to.eq(bidPricePerSecond * fullDuration); + expect((await sessionRouter.getProvider(PROVIDER)).limitPeriodEarned).to.eq(bidPricePerSecond * fullDuration); + + const providerBalAfter = await token.balanceOf(PROVIDER); + expect(providerBalAfter - providerBalBefore).to.eq(bidPricePerSecond * duration); + const fundingBalAfter = await token.balanceOf(FUNDING); + expect(fundingBalBefore - fundingBalAfter).to.eq(bidPricePerSecond * duration); + }); + it('should claim provider rewards, full', async () => { + const { sessionId, openedAt } = await _createSession(); + + let session = await sessionRouter.getSession(sessionId); + const duration = session.endsAt - session.openedAt; + + const providerBalBefore = await token.balanceOf(PROVIDER); + const fundingBalBefore = await token.balanceOf(FUNDING); + + await setTime(openedAt + 5 * DAY + 1); + await sessionRouter.connect(PROVIDER).claimForProvider(sessionId); + session = await sessionRouter.getSession(sessionId); + + expect(session.providerWithdrawnAmount).to.eq(bidPricePerSecond * duration); + expect(await sessionRouter.getProvidersTotalClaimed()).to.eq(bidPricePerSecond * duration); + expect((await sessionRouter.getProvider(PROVIDER)).limitPeriodEarned).to.eq(bidPricePerSecond * duration); + + const providerBalAfter = await token.balanceOf(PROVIDER); + expect(providerBalAfter - providerBalBefore).to.eq(bidPricePerSecond * duration); + const fundingBalAfter = await token.balanceOf(FUNDING); + expect(fundingBalBefore - fundingBalAfter).to.eq(bidPricePerSecond * duration); + }); + it('should claim provider rewards with reward limiter amount for the period', async () => { + const providerBalBefore = await token.balanceOf(PROVIDER); + const fundingBalBefore = await token.balanceOf(FUNDING); + + await setTime(payoutStart + 10 * DAY); + const { msg: msg1, signature: sig1 } = await getProviderApproval(PROVIDER, SECOND, bidId); + await sessionRouter.connect(SECOND).openSession(wei(50), msg1, sig1); + + const sessionId1 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 0); + await setTime(payoutStart + 20 * DAY); + await sessionRouter.connect(PROVIDER).claimForProvider(sessionId1); + + await setTime(payoutStart + 30 * DAY); + const { msg: msg2, signature: sig2 } = await getProviderApproval(PROVIDER, SECOND, bidId); + await sessionRouter.connect(SECOND).openSession(wei(50), msg2, sig2); + + const sessionId2 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 1); + await setTime(payoutStart + 40 * DAY); + await sessionRouter.connect(PROVIDER).claimForProvider(sessionId2); + + expect(await sessionRouter.getProvidersTotalClaimed()).to.eq(wei(0.2)); + expect((await sessionRouter.getProvider(PROVIDER)).limitPeriodEarned).to.eq(wei(0.2)); + + const providerBalAfter = await token.balanceOf(PROVIDER); + expect(providerBalAfter - providerBalBefore).to.eq(wei(0.2)); + const fundingBalAfter = await token.balanceOf(FUNDING); + expect(fundingBalBefore - fundingBalAfter).to.eq(wei(0.2)); + }); + it('should throw error when caller is not the session provider', async () => { + const { sessionId } = await _createSession(); + + await expect(sessionRouter.connect(SECOND).claimForProvider(sessionId)).to.be.revertedWithCustomError( + sessionRouter, + 'OwnableUnauthorizedAccount', + ); + }); + it('should throw error when session is not end', async () => { + const { sessionId, openedAt } = await _createSession(); + + await setTime(openedAt + 10); + await expect(sessionRouter.connect(PROVIDER).claimForProvider(sessionId)).to.be.revertedWithCustomError( + sessionRouter, + 'SessionNotEndedOrNotExist', + ); + }); + }); + + describe('#withdrawUserStakes', () => { + it('should withdraw the user stake on hold, one entity', async () => { + const openedAt = payoutStart + (payoutStart % DAY) + 10 * DAY - 201; + + await setTime(openedAt + 1 * DAY); + const { msg, signature } = await getProviderApproval(PROVIDER, SECOND, bidId); + await sessionRouter.connect(SECOND).openSession(wei(50), msg, signature); + const sessionId1 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 0); + + await setTime(openedAt + 2 * DAY + 1); + const { msg: receiptMsg, signature: receiptSig } = await getReceipt(PROVIDER, sessionId1, 0, 0); + await sessionRouter.connect(SECOND).closeSession(receiptMsg, receiptSig); + + const userBalBefore = await token.balanceOf(SECOND); + const contractBalBefore = await token.balanceOf(sessionRouter); + + await setTime(openedAt + 3 * DAY + 2); + await sessionRouter.connect(SECOND).withdrawUserStakes(1); + + const stakesOnHold = await sessionRouter.getUserStakesOnHold(SECOND, 1); + expect(stakesOnHold[0]).to.eq(0); + expect(stakesOnHold[1]).to.eq(0); + + const userBalAfter = await token.balanceOf(SECOND); + expect(userBalAfter - userBalBefore).to.greaterThan(0); + const contractBalAfter = await token.balanceOf(sessionRouter); + expect(contractBalBefore - contractBalAfter).to.greaterThan(0); + }); + it('should withdraw the user stake on hold, few entities, with on hold', async () => { + const openedAt = payoutStart + (payoutStart % DAY) + 10 * DAY - 201; + + await setTime(openedAt + 1 * DAY); + const { msg: msg1, signature: sig1 } = await getProviderApproval(PROVIDER, SECOND, bidId); + await sessionRouter.connect(SECOND).openSession(wei(50), msg1, sig1); + const sessionId1 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 0); + + await setTime(openedAt + 1 * DAY + 500); + const { msg: receiptMsg1, signature: receiptSig1 } = await getReceipt(PROVIDER, sessionId1, 0, 0); + await sessionRouter.connect(SECOND).closeSession(receiptMsg1, receiptSig1); + + await setTime(openedAt + 3 * DAY); + const { msg: msg2, signature: sig2 } = await getProviderApproval(PROVIDER, SECOND, bidId); + await sessionRouter.connect(SECOND).openSession(wei(50), msg2, sig2); + const sessionId2 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 1); + + await setTime(openedAt + 3 * DAY + 500); + const { msg: receiptMsg2, signature: receiptSig2 } = await getReceipt(PROVIDER, sessionId2, 0, 0); + await sessionRouter.connect(SECOND).closeSession(receiptMsg2, receiptSig2); + + const stakesOnHold = await sessionRouter.getUserStakesOnHold(SECOND, 20); + expect(stakesOnHold[0]).to.greaterThan(0); + expect(stakesOnHold[1]).to.greaterThan(0); + + const userBalBefore = await token.balanceOf(SECOND); + const contractBalBefore = await token.balanceOf(sessionRouter); + + await setTime(openedAt + 4 * DAY); + await sessionRouter.connect(SECOND).withdrawUserStakes(20); + + const userBalAfter = await token.balanceOf(SECOND); + expect(userBalAfter - userBalBefore).to.eq(stakesOnHold[0]); + const contractBalAfter = await token.balanceOf(sessionRouter); + expect(contractBalBefore - contractBalAfter).to.eq(stakesOnHold[0]); + }); + it('should withdraw the user stake on hold, few entities, without on hold', async () => { + const openedAt = payoutStart + (payoutStart % DAY) + 10 * DAY - 201; + + await setTime(openedAt + 1 * DAY); + const { msg: msg1, signature: sig1 } = await getProviderApproval(PROVIDER, SECOND, bidId); + await sessionRouter.connect(SECOND).openSession(wei(50), msg1, sig1); + const sessionId1 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 0); + + await setTime(openedAt + 1 * DAY + 500); + const { msg: receiptMsg1, signature: receiptSig1 } = await getReceipt(PROVIDER, sessionId1, 0, 0); + await sessionRouter.connect(SECOND).closeSession(receiptMsg1, receiptSig1); + + await setTime(openedAt + 3 * DAY); + const { msg: msg2, signature: sig2 } = await getProviderApproval(PROVIDER, SECOND, bidId); + await sessionRouter.connect(SECOND).openSession(wei(50), msg2, sig2); + const sessionId2 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 1); + + await setTime(openedAt + 3 * DAY + 500); + const { msg: receiptMsg2, signature: receiptSig2 } = await getReceipt(PROVIDER, sessionId2, 0, 0); + await sessionRouter.connect(SECOND).closeSession(receiptMsg2, receiptSig2); + + await setTime(openedAt + 10 * DAY); + const stakesOnHold = await sessionRouter.getUserStakesOnHold(SECOND, 20); + expect(stakesOnHold[0]).to.greaterThan(0); + expect(stakesOnHold[1]).to.eq(0); + + const userBalBefore = await token.balanceOf(SECOND); + const contractBalBefore = await token.balanceOf(sessionRouter); + + await sessionRouter.connect(SECOND).withdrawUserStakes(20); + + const userBalAfter = await token.balanceOf(SECOND); + expect(userBalAfter - userBalBefore).to.eq(stakesOnHold[0]); + const contractBalAfter = await token.balanceOf(sessionRouter); + expect(contractBalBefore - contractBalAfter).to.eq(stakesOnHold[0]); + }); + it('should withdraw the user stake on hold, few entities, without on hold, partial withdraw', async () => { + const openedAt = payoutStart + (payoutStart % DAY) + 10 * DAY - 201; + + await setTime(openedAt + 1 * DAY); + const { msg: msg1, signature: sig1 } = await getProviderApproval(PROVIDER, SECOND, bidId); + await sessionRouter.connect(SECOND).openSession(wei(50), msg1, sig1); + const sessionId1 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 0); + + await setTime(openedAt + 1 * DAY + 500); + const { msg: receiptMsg1, signature: receiptSig1 } = await getReceipt(PROVIDER, sessionId1, 0, 0); + await sessionRouter.connect(SECOND).closeSession(receiptMsg1, receiptSig1); + + await setTime(openedAt + 3 * DAY); + const { msg: msg2, signature: sig2 } = await getProviderApproval(PROVIDER, SECOND, bidId); + await sessionRouter.connect(SECOND).openSession(wei(50), msg2, sig2); + const sessionId2 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 1); + + await setTime(openedAt + 3 * DAY + 500); + const { msg: receiptMsg2, signature: receiptSig2 } = await getReceipt(PROVIDER, sessionId2, 0, 0); + await sessionRouter.connect(SECOND).closeSession(receiptMsg2, receiptSig2); + + await setTime(openedAt + 10 * DAY); + // First withdraw + let stakesOnHold = await sessionRouter.getUserStakesOnHold(SECOND, 1); + expect(stakesOnHold[0]).to.greaterThan(0); + expect(stakesOnHold[1]).to.eq(0); + + let userBalBefore = await token.balanceOf(SECOND); + let contractBalBefore = await token.balanceOf(sessionRouter); + + await sessionRouter.connect(SECOND).withdrawUserStakes(1); + + let userBalAfter = await token.balanceOf(SECOND); + expect(userBalAfter - userBalBefore).to.lessThan(stakesOnHold[0]); + let contractBalAfter = await token.balanceOf(sessionRouter); + expect(contractBalBefore - contractBalAfter).to.lessThan(stakesOnHold[0]); + + // Second withdraw + stakesOnHold = await sessionRouter.getUserStakesOnHold(SECOND, 1); + expect(stakesOnHold[0]).to.greaterThan(0); + expect(stakesOnHold[1]).to.eq(0); + + userBalBefore = await token.balanceOf(SECOND); + contractBalBefore = await token.balanceOf(sessionRouter); + + await sessionRouter.connect(SECOND).withdrawUserStakes(1); + + userBalAfter = await token.balanceOf(SECOND); + expect(userBalAfter - userBalBefore).to.eq(stakesOnHold[0]); + contractBalAfter = await token.balanceOf(sessionRouter); + expect(contractBalBefore - contractBalAfter).to.eq(stakesOnHold[0]); + }); + it('should throw error when withdraw amount is zero', async () => { + const openedAt = payoutStart + (payoutStart % DAY) + 10 * DAY - 201; + + await setTime(openedAt + 1 * DAY); + const { msg: msg1, signature: sig1 } = await getProviderApproval(PROVIDER, SECOND, bidId); + await sessionRouter.connect(SECOND).openSession(wei(50), msg1, sig1); + const sessionId1 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 0); + + const { msg: msg2, signature: sig2 } = await getProviderApproval(PROVIDER, SECOND, bidId); + await sessionRouter.connect(SECOND).openSession(wei(50), msg2, sig2); + const sessionId2 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 1); + + await setTime(openedAt + 1 * DAY + 500); + const { msg: receiptMsg1, signature: receiptSig1 } = await getReceipt(PROVIDER, sessionId1, 0, 0); + await sessionRouter.connect(SECOND).closeSession(receiptMsg1, receiptSig1); + + await setTime(openedAt + 1 * DAY + 550); + const { msg: receiptMsg2, signature: receiptSig2 } = await getReceipt(PROVIDER, sessionId2, 0, 0); + await sessionRouter.connect(SECOND).closeSession(receiptMsg2, receiptSig2); + + await expect(sessionRouter.connect(SECOND).withdrawUserStakes(20)).to.be.revertedWithCustomError( + sessionRouter, + 'SessionUserAmountToWithdrawIsZero', + ); + }); + }); + + const _createSession = async () => { + const secondsToDayEnd = 600n; + const openedAt = payoutStart + (payoutStart % DAY) + 10 * DAY - Number(secondsToDayEnd) - 1; + + await setTime(openedAt); + const { msg, signature } = await getProviderApproval(PROVIDER, SECOND, bidId); + await sessionRouter.connect(SECOND).openSession(wei(50), msg, signature); + + return { + sessionId: await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 0), + secondsToDayEnd: Number(secondsToDayEnd), + openedAt, + }; + }; +}); + +// npm run generate-types && npx hardhat test "test/diamond/facets/SessionRouter.test.ts" +// npx hardhat coverage --solcoverjs ./.solcover.ts --testfiles "test/diamond/facets/SessionRouter.test.ts" diff --git a/smart-contracts/test/helpers/deployers/diamond/facets/marketplace.ts b/smart-contracts/test/helpers/deployers/diamond/facets/marketplace.ts new file mode 100644 index 00000000..49ab75e6 --- /dev/null +++ b/smart-contracts/test/helpers/deployers/diamond/facets/marketplace.ts @@ -0,0 +1,40 @@ +import { Fragment } from 'ethers'; +import { ethers } from 'hardhat'; + +import { + IBidStorage__factory, + IMarketplace__factory, + LumerinDiamond, + Marketplace, + MorpheusToken, +} from '@/generated-types/ethers'; +import { FacetAction } from '@/test/helpers/deployers/diamond/lumerin-diamond'; + +export const deployFacetMarketplace = async (diamond: LumerinDiamond, token: MorpheusToken): Promise => { + let facet: Marketplace; + + const factory = await ethers.getContractFactory('Marketplace'); + facet = await factory.deploy(); + + await diamond['diamondCut((address,uint8,bytes4[])[])']([ + { + facetAddress: facet, + action: FacetAction.Add, + functionSelectors: IMarketplace__factory.createInterface() + .fragments.filter(Fragment.isFunction) + .map((f) => f.selector), + }, + { + facetAddress: facet, + action: FacetAction.Add, + functionSelectors: IBidStorage__factory.createInterface() + .fragments.filter(Fragment.isFunction) + .map((f) => f.selector), + }, + ]); + + facet = facet.attach(diamond.target) as Marketplace; + await facet.__Marketplace_init(token); + + return facet; +}; diff --git a/smart-contracts/test/helpers/deployers/diamond/facets/model-registry.ts b/smart-contracts/test/helpers/deployers/diamond/facets/model-registry.ts new file mode 100644 index 00000000..aa198cc7 --- /dev/null +++ b/smart-contracts/test/helpers/deployers/diamond/facets/model-registry.ts @@ -0,0 +1,27 @@ +import { Fragment } from 'ethers'; +import { ethers } from 'hardhat'; + +import { IModelRegistry__factory, LumerinDiamond, ModelRegistry } from '@/generated-types/ethers'; +import { FacetAction } from '@/test/helpers/deployers/diamond/lumerin-diamond'; + +export const deployFacetModelRegistry = async (diamond: LumerinDiamond): Promise => { + let facet: ModelRegistry; + + const factory = await ethers.getContractFactory('ModelRegistry'); + facet = await factory.deploy(); + + await diamond['diamondCut((address,uint8,bytes4[])[])']([ + { + facetAddress: facet, + action: FacetAction.Add, + functionSelectors: IModelRegistry__factory.createInterface() + .fragments.filter(Fragment.isFunction) + .map((f) => f.selector), + }, + ]); + + facet = facet.attach(diamond.target) as ModelRegistry; + await facet.__ModelRegistry_init(); + + return facet; +}; diff --git a/smart-contracts/test/helpers/deployers/diamond/facets/provider-registry.ts b/smart-contracts/test/helpers/deployers/diamond/facets/provider-registry.ts new file mode 100644 index 00000000..12ee4165 --- /dev/null +++ b/smart-contracts/test/helpers/deployers/diamond/facets/provider-registry.ts @@ -0,0 +1,27 @@ +import { Fragment } from 'ethers'; +import { ethers } from 'hardhat'; + +import { IProviderRegistry__factory, LumerinDiamond, ProviderRegistry } from '@/generated-types/ethers'; +import { FacetAction } from '@/test/helpers/deployers/diamond/lumerin-diamond'; + +export const deployFacetProviderRegistry = async (diamond: LumerinDiamond): Promise => { + let facet: ProviderRegistry; + + const factory = await ethers.getContractFactory('ProviderRegistry'); + facet = await factory.deploy(); + + await diamond['diamondCut((address,uint8,bytes4[])[])']([ + { + facetAddress: facet, + action: FacetAction.Add, + functionSelectors: IProviderRegistry__factory.createInterface() + .fragments.filter(Fragment.isFunction) + .map((f) => f.selector), + }, + ]); + + facet = facet.attach(diamond.target) as ProviderRegistry; + await facet.__ProviderRegistry_init(); + + return facet; +}; diff --git a/smart-contracts/test/helpers/deployers/diamond/facets/session-router.ts b/smart-contracts/test/helpers/deployers/diamond/facets/session-router.ts new file mode 100644 index 00000000..e1764bd7 --- /dev/null +++ b/smart-contracts/test/helpers/deployers/diamond/facets/session-router.ts @@ -0,0 +1,51 @@ +import { SignerWithAddress } from '@nomicfoundation/hardhat-ethers/signers'; +import { Fragment } from 'ethers'; +import { ethers } from 'hardhat'; + +import { + ISessionRouter__factory, + IStatsStorage__factory, + LumerinDiamond, + SessionRouter, +} from '@/generated-types/ethers'; +import { FacetAction } from '@/test/helpers/deployers/diamond/lumerin-diamond'; +import { getDefaultPools } from '@/test/helpers/pool-helper'; + +export const deployFacetSessionRouter = async ( + diamond: LumerinDiamond, + fundingAccount: SignerWithAddress, +): Promise => { + let facet: SessionRouter; + + const LDIDFactory = await ethers.getContractFactory('LinearDistributionIntervalDecrease'); + const LDID = await LDIDFactory.deploy(); + + const factory = await ethers.getContractFactory('SessionRouter', { + libraries: { + LinearDistributionIntervalDecrease: LDID, + }, + }); + facet = await factory.deploy(); + + await diamond['diamondCut((address,uint8,bytes4[])[])']([ + { + facetAddress: facet, + action: FacetAction.Add, + functionSelectors: ISessionRouter__factory.createInterface() + .fragments.filter(Fragment.isFunction) + .map((f) => f.selector), + }, + { + facetAddress: facet, + action: FacetAction.Add, + functionSelectors: IStatsStorage__factory.createInterface() + .fragments.filter(Fragment.isFunction) + .map((f) => f.selector), + }, + ]); + + facet = facet.attach(diamond.target) as SessionRouter; + await facet.__SessionRouter_init(fundingAccount, getDefaultPools()); + + return facet; +}; diff --git a/smart-contracts/test/helpers/deployers/diamond/index.ts b/smart-contracts/test/helpers/deployers/diamond/index.ts new file mode 100644 index 00000000..36c74670 --- /dev/null +++ b/smart-contracts/test/helpers/deployers/diamond/index.ts @@ -0,0 +1,5 @@ +export * from './facets/model-registry'; +export * from './facets/marketplace'; +export * from './facets/provider-registry'; +export * from './facets/session-router'; +export * from './lumerin-diamond'; diff --git a/smart-contracts/test/helpers/deployers/diamond/lumerin-diamond.ts b/smart-contracts/test/helpers/deployers/diamond/lumerin-diamond.ts new file mode 100644 index 00000000..f2220ca2 --- /dev/null +++ b/smart-contracts/test/helpers/deployers/diamond/lumerin-diamond.ts @@ -0,0 +1,17 @@ +import { ethers } from 'hardhat'; + +import { LumerinDiamond } from '@/generated-types/ethers'; + +export enum FacetAction { + Add = 0, + Replace = 1, + Remove = 2, +} + +export const deployLumerinDiamond = async (): Promise => { + const factory = await ethers.getContractFactory('LumerinDiamond'); + const contract = await factory.deploy(); + await contract.__LumerinDiamond_init(); + + return contract; +}; diff --git a/smart-contracts/test/helpers/deployers/index.ts b/smart-contracts/test/helpers/deployers/index.ts new file mode 100644 index 00000000..6d1b8c98 --- /dev/null +++ b/smart-contracts/test/helpers/deployers/index.ts @@ -0,0 +1,2 @@ +export * from './diamond'; +export * from './mock'; diff --git a/smart-contracts/test/helpers/deployers/mock/index.ts b/smart-contracts/test/helpers/deployers/mock/index.ts new file mode 100644 index 00000000..cfcb186e --- /dev/null +++ b/smart-contracts/test/helpers/deployers/mock/index.ts @@ -0,0 +1 @@ +export * from './tokens/morpheus-token'; diff --git a/smart-contracts/test/helpers/deployers/mock/tokens/morpheus-token.ts b/smart-contracts/test/helpers/deployers/mock/tokens/morpheus-token.ts new file mode 100644 index 00000000..8ff21270 --- /dev/null +++ b/smart-contracts/test/helpers/deployers/mock/tokens/morpheus-token.ts @@ -0,0 +1,10 @@ +import { ethers } from 'hardhat'; + +import { MorpheusToken } from '@/generated-types/ethers'; + +export const deployMORToken = async (): Promise => { + const factory = await ethers.getContractFactory('MorpheusToken'); + const contract = await factory.deploy(); + + return contract; +}; diff --git a/smart-contracts/test/helpers/enums.ts b/smart-contracts/test/helpers/enums.ts deleted file mode 100644 index 89962ab6..00000000 --- a/smart-contracts/test/helpers/enums.ts +++ /dev/null @@ -1,5 +0,0 @@ -export enum FacetAction { - Add = 0, - Replace = 1, - Remove = 2, -} diff --git a/smart-contracts/test/helpers/pool-helper.ts b/smart-contracts/test/helpers/pool-helper.ts index c5117be6..914c3538 100644 --- a/smart-contracts/test/helpers/pool-helper.ts +++ b/smart-contracts/test/helpers/pool-helper.ts @@ -1,9 +1,11 @@ import { ISessionStorage } from '../../generated-types/ethers/contracts/interfaces/facets/ISessionRouter'; +export const payoutStart = 1707393600; + export function getDefaultPools(): ISessionStorage.PoolStruct[] { return [ { - payoutStart: 1707393600n, + payoutStart: payoutStart, decreaseInterval: 86400n, initialReward: 3456000000000000000000n, rewardDecrease: 592558728240000000n, @@ -15,19 +17,19 @@ export function getDefaultPools(): ISessionStorage.PoolStruct[] { rewardDecrease: 592558728240000000n, }, { - payoutStart: 1707393600n, + payoutStart: payoutStart, decreaseInterval: 86400n, initialReward: 3456000000000000000000n, rewardDecrease: 592558728240000000n, }, { - payoutStart: 1707393600n, + payoutStart: payoutStart, decreaseInterval: 86400n, initialReward: 3456000000000000000000n, rewardDecrease: 592558728240000000n, }, { - payoutStart: 1707393600n, + payoutStart: payoutStart, decreaseInterval: 86400n, initialReward: 576000000000000000000n, rewardDecrease: 98759788040000000n, diff --git a/smart-contracts/test/helpers/reverter.ts b/smart-contracts/test/helpers/reverter.ts index 207c36b5..c2ff7420 100644 --- a/smart-contracts/test/helpers/reverter.ts +++ b/smart-contracts/test/helpers/reverter.ts @@ -1,6 +1,7 @@ import { network } from 'hardhat'; export class Reverter { + // eslint-disable-next-line @typescript-eslint/no-explicit-any private snapshotId: any; revert = async () => { diff --git a/smart-contracts/test/helpers/time.ts b/smart-contracts/test/helpers/time.ts new file mode 100644 index 00000000..655e0f6d --- /dev/null +++ b/smart-contracts/test/helpers/time.ts @@ -0,0 +1,2 @@ +export const day = 86400; +export const year = 31536000; diff --git a/smart-contracts/utils/provider-helper.ts b/smart-contracts/utils/provider-helper.ts index 46d2bca5..ccfe5cd1 100644 --- a/smart-contracts/utils/provider-helper.ts +++ b/smart-contracts/utils/provider-helper.ts @@ -3,13 +3,20 @@ import { AbiCoder, getBytes, keccak256 } from 'ethers'; import { getChainId, getCurrentBlockTime } from './block-helper'; -export const getProviderApproval = async (provider: SignerWithAddress, user: string, bidId: string, chainId = 0n) => { +export const getProviderApproval = async ( + provider: SignerWithAddress, + user: SignerWithAddress, + bidId: string, + chainId = 0n, +) => { chainId = chainId || (await getChainId()); const timestamp = await getCurrentBlockTime(); + const msg = AbiCoder.defaultAbiCoder().encode( ['bytes32', 'uint256', 'address', 'uint128'], - [bidId, chainId, user, timestamp], + [bidId, chainId, user.address, timestamp], ); + const signature = await provider.signMessage(getBytes(keccak256(msg))); return { @@ -18,11 +25,19 @@ export const getProviderApproval = async (provider: SignerWithAddress, user: str }; }; -export const getReport = async (reporter: SignerWithAddress, sessionId: string, tps: number, ttftMs: number) => { +export const getReceipt = async ( + reporter: SignerWithAddress, + sessionId: string, + tps: number, + ttftMs: number, + chainId = 0n, +) => { + chainId = chainId || (await getChainId()); const timestamp = await getCurrentBlockTime(); + const msg = AbiCoder.defaultAbiCoder().encode( ['bytes32', 'uint256', 'uint128', 'uint32', 'uint32'], - [sessionId, await getChainId(), timestamp, tps * 1000, ttftMs], + [sessionId, chainId, timestamp, tps * 1000, ttftMs], ); const signature = await reporter.signMessage(getBytes(keccak256(msg))); diff --git a/smart-contracts/utils/time.ts b/smart-contracts/utils/time.ts index 79742b5d..c55ce8d7 100644 --- a/smart-contracts/utils/time.ts +++ b/smart-contracts/utils/time.ts @@ -1,5 +1,5 @@ -export const SECOND = 1n; -export const MINUTE = 60n * SECOND; -export const HOUR = 60n * MINUTE; -export const DAY = 24n * HOUR; -export const YEAR = 365n * DAY; +export const SECOND = 1; +export const MINUTE = 60 * SECOND; +export const HOUR = 60 * MINUTE; +export const DAY = 24 * HOUR; +export const YEAR = 365 * DAY; diff --git a/smart-contracts/wagmi.config.ts b/smart-contracts/wagmi.config.ts index a4838e9d..2f9ea2ce 100644 --- a/smart-contracts/wagmi.config.ts +++ b/smart-contracts/wagmi.config.ts @@ -1,16 +1,12 @@ -import { defineConfig } from "@wagmi/cli"; -import { hardhat } from "@wagmi/cli/plugins"; +import { defineConfig } from '@wagmi/cli'; +import { hardhat } from '@wagmi/cli/plugins'; export default defineConfig({ - out: "bindings/ts/abi.ts", + out: 'bindings/ts/abi.ts', plugins: [ hardhat({ - project: ".", - include: [ - "facets/**/*.json", - "MorpheusToken.sol/*.json", - "ERC20.sol/*.json", - ], + project: '.', + include: ['facets/**/*.json', 'MorpheusToken.sol/*.json', 'ERC20.sol/*.json'], }), ], }); From 7b171b2e07646776062c0280d2133b861cbfe067 Mon Sep 17 00:00:00 2001 From: Oleksandr Date: Wed, 16 Oct 2024 00:45:18 +0300 Subject: [PATCH 2/9] remove new time.ts --- smart-contracts/test/helpers/time.ts | 2 -- 1 file changed, 2 deletions(-) delete mode 100644 smart-contracts/test/helpers/time.ts diff --git a/smart-contracts/test/helpers/time.ts b/smart-contracts/test/helpers/time.ts deleted file mode 100644 index 655e0f6d..00000000 --- a/smart-contracts/test/helpers/time.ts +++ /dev/null @@ -1,2 +0,0 @@ -export const day = 86400; -export const year = 31536000; From f72be0055d593502bd220f3e054f3541a1b26409 Mon Sep 17 00:00:00 2001 From: Oleksandr Date: Mon, 21 Oct 2024 12:47:21 +0300 Subject: [PATCH 3/9] fix storage for facets, deploy to the testnet --- .../contracts/diamond/facets/Marketplace.sol | 51 ++++--- .../diamond/facets/ModelRegistry.sol | 29 ++-- .../diamond/facets/ProviderRegistry.sol | 28 ++-- .../diamond/facets/SessionRouter.sol | 78 ++++++----- .../contracts/diamond/storages/BidStorage.sol | 77 +++-------- .../diamond/storages/MarketplaceStorage.sol | 27 ++-- .../diamond/storages/ModelStorage.sol | 42 +++--- .../diamond/storages/ProviderStorage.sol | 38 +++--- .../diamond/storages/SessionStorage.sol | 128 +++++++++--------- .../interfaces/facets/IModelRegistry.sol | 2 +- .../interfaces/facets/IProviderRegistry.sol | 2 +- .../interfaces/storage/IBidStorage.sol | 4 +- .../interfaces/storage/IModelStorage.sol | 2 + .../interfaces/storage/IProviderStorage.sol | 2 + .../deploy/1_full_protocol.migration.ts | 113 ++++++++++++++++ .../deploy/data/config_arbitrum_sepolia.json | 39 ++++++ .../deploy/helpers/config-parser.ts | 18 +++ smart-contracts/hardhat.config.ts | 18 +++ .../test/diamond/facets/ModelRegistry.test.ts | 6 +- .../diamond/facets/ProviderRegistry.test.ts | 6 +- 20 files changed, 452 insertions(+), 258 deletions(-) create mode 100644 smart-contracts/deploy/1_full_protocol.migration.ts create mode 100644 smart-contracts/deploy/data/config_arbitrum_sepolia.json create mode 100644 smart-contracts/deploy/helpers/config-parser.ts diff --git a/smart-contracts/contracts/diamond/facets/Marketplace.sol b/smart-contracts/contracts/diamond/facets/Marketplace.sol index 0a190a46..b25b2837 100644 --- a/smart-contracts/contracts/diamond/facets/Marketplace.sol +++ b/smart-contracts/contracts/diamond/facets/Marketplace.sol @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MIT pragma solidity ^0.8.24; +import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import {SafeERC20, IERC20} from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; import {OwnableDiamondStorage} from "../presets/OwnableDiamondStorage.sol"; @@ -21,13 +22,16 @@ contract Marketplace is BidStorage { using SafeERC20 for IERC20; + using EnumerableSet for EnumerableSet.Bytes32Set; - function __Marketplace_init(address token_) external initializer(MARKETPLACE_STORAGE_SLOT) { - setToken(IERC20(token_)); + function __Marketplace_init(address token_) external initializer(BIDS_STORAGE_SLOT) { + BidsStorage storage bidsStorage = getBidsStorage(); + bidsStorage.token = token_; } function setMarketplaceBidFee(uint256 bidFee_) external onlyOwner { - setBidFee(bidFee_); + MarketStorage storage marketStorage = getMarketStorage(); + marketStorage.bidFee = bidFee_; emit MaretplaceFeeUpdated(bidFee_); } @@ -42,13 +46,15 @@ contract Marketplace is revert MarketplaceModelNotFound(); } - uint256 fee_ = getBidFee(); - getToken().safeTransferFrom(_msgSender(), address(this), fee_); + BidsStorage storage bidsStorage = getBidsStorage(); + MarketStorage storage marketStorage = getMarketStorage(); - setFeeBalance(getFeeBalance() + fee_); + // TODO: check it + IERC20(bidsStorage.token).safeTransferFrom(_msgSender(), address(this), marketStorage.bidFee); + marketStorage.feeBalance += marketStorage.bidFee; bytes32 providerModelId_ = getProviderModelId(provider_, modelId_); - uint256 providerModelNonce_ = incrementBidNonce(providerModelId_); + uint256 providerModelNonce_ = bidsStorage.providerModelNonce[providerModelId_]++; bytes32 bidId_ = getBidId(provider_, modelId_, providerModelNonce_); if (providerModelNonce_ != 0) { @@ -58,18 +64,17 @@ contract Marketplace is } } - Bid storage bid = bids(bidId_); + Bid storage bid = bidsStorage.bids[bidId_]; bid.provider = provider_; bid.modelId = modelId_; bid.pricePerSecond = pricePerSecond_; bid.nonce = providerModelNonce_; bid.createdAt = uint128(block.timestamp); - addProviderBid(provider_, bidId_); - addModelBid(modelId_, bidId_); - - addProviderActiveBids(provider_, bidId_); - addModelActiveBids(modelId_, bidId_); + bidsStorage.providerBids[provider_].add(bidId_); + bidsStorage.providerActiveBids[provider_].add(bidId_); + bidsStorage.modelBids[modelId_].add(bidId_); + bidsStorage.modelActiveBids[modelId_].add(bidId_); emit MarketplaceBidPosted(provider_, modelId_, providerModelNonce_); @@ -77,7 +82,8 @@ contract Marketplace is } function deleteModelBid(bytes32 bidId_) external { - _onlyAccount(bids(bidId_).provider); + BidsStorage storage bidsStorage = getBidsStorage(); + _onlyAccount(bidsStorage.bids[bidId_].provider); if (!isBidActive(bidId_)) { revert MarketplaceActiveBidNotFound(); @@ -87,21 +93,24 @@ contract Marketplace is } function withdraw(address recipient_, uint256 amount_) external onlyOwner { - uint256 feeBalance_ = getFeeBalance(); - amount_ = amount_ > feeBalance_ ? feeBalance_ : amount_; + BidsStorage storage bidsStorage = getBidsStorage(); + MarketStorage storage marketStorage = getMarketStorage(); + + amount_ = amount_ > marketStorage.feeBalance ? marketStorage.feeBalance : amount_; - setFeeBalance(getFeeBalance() - amount_); + marketStorage.feeBalance -= amount_; - getToken().safeTransfer(recipient_, amount_); + IERC20(bidsStorage.token).safeTransfer(recipient_, amount_); } function _deleteBid(bytes32 bidId_) private { - Bid storage bid = bids(bidId_); + BidsStorage storage bidsStorage = getBidsStorage(); + Bid storage bid = bidsStorage.bids[bidId_]; bid.deletedAt = uint128(block.timestamp); - removeProviderActiveBids(bid.provider, bidId_); - removeModelActiveBids(bid.modelId, bidId_); + bidsStorage.providerActiveBids[bid.provider].remove(bidId_); + bidsStorage.modelActiveBids[bid.modelId].remove(bidId_); emit MarketplaceBidDeleted(bid.provider, bid.modelId, bid.nonce); } diff --git a/smart-contracts/contracts/diamond/facets/ModelRegistry.sol b/smart-contracts/contracts/diamond/facets/ModelRegistry.sol index ae0c8be5..61e2ce7d 100644 --- a/smart-contracts/contracts/diamond/facets/ModelRegistry.sol +++ b/smart-contracts/contracts/diamond/facets/ModelRegistry.sol @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MIT pragma solidity ^0.8.24; +import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import {SafeERC20, IERC20} from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; import {OwnableDiamondStorage} from "../presets/OwnableDiamondStorage.sol"; @@ -11,14 +12,16 @@ import {ModelStorage} from "../storages/ModelStorage.sol"; import {IModelRegistry} from "../../interfaces/facets/IModelRegistry.sol"; contract ModelRegistry is IModelRegistry, OwnableDiamondStorage, ModelStorage, BidStorage { + using EnumerableSet for EnumerableSet.Bytes32Set; using SafeERC20 for IERC20; - function __ModelRegistry_init() external initializer(MODEL_STORAGE_SLOT) {} + function __ModelRegistry_init() external initializer(MODELS_STORAGE_SLOT) {} function modelSetMinStake(uint256 modelMinimumStake_) external onlyOwner { - setModelMinimumStake(modelMinimumStake_); + ModelsStorage storage modelsStorage = getModelsStorage(); + modelsStorage.modelMinimumStake = modelMinimumStake_; - emit ModelMinStakeUpdated(modelMinimumStake_); + emit ModelMinimumStakeUpdated(modelMinimumStake_); } function modelRegister( @@ -29,20 +32,22 @@ contract ModelRegistry is IModelRegistry, OwnableDiamondStorage, ModelStorage, B string calldata name_, string[] memory tags_ ) external { - Model storage model = models(modelId_); + ModelsStorage storage modelsStorage = getModelsStorage(); + Model storage model = modelsStorage.models[modelId_]; uint256 newStake_ = model.stake + amount_; - uint256 minStake_ = getModelMinimumStake(); + uint256 minStake_ = modelsStorage.modelMinimumStake; if (newStake_ < minStake_) { revert ModelStakeTooLow(newStake_, minStake_); } if (amount_ > 0) { - getToken().safeTransferFrom(_msgSender(), address(this), amount_); + BidsStorage storage bidsStorage = getBidsStorage(); + IERC20(bidsStorage.token).safeTransferFrom(_msgSender(), address(this), amount_); } if (model.createdAt == 0) { - addModelId(modelId_); + modelsStorage.modelIds.add(modelId_); model.createdAt = uint128(block.timestamp); model.owner = _msgSender(); @@ -57,11 +62,14 @@ contract ModelRegistry is IModelRegistry, OwnableDiamondStorage, ModelStorage, B model.tags = tags_; model.isDeleted = false; + modelsStorage.activeModels.add(modelId_); + emit ModelRegisteredUpdated(_msgSender(), modelId_); } function modelDeregister(bytes32 modelId_) external { - Model storage model = models(modelId_); + ModelsStorage storage modelsStorage = getModelsStorage(); + Model storage model = modelsStorage.models[modelId_]; _onlyAccount(model.owner); if (!isModelActiveBidsEmpty(modelId_)) { @@ -76,7 +84,10 @@ contract ModelRegistry is IModelRegistry, OwnableDiamondStorage, ModelStorage, B model.stake = 0; model.isDeleted = true; - getToken().safeTransfer(model.owner, withdrawAmount_); + modelsStorage.activeModels.remove(modelId_); + + BidsStorage storage bidsStorage = getBidsStorage(); + IERC20(bidsStorage.token).safeTransfer(model.owner, withdrawAmount_); emit ModelDeregistered(model.owner, modelId_); } diff --git a/smart-contracts/contracts/diamond/facets/ProviderRegistry.sol b/smart-contracts/contracts/diamond/facets/ProviderRegistry.sol index 996a0afd..7a8c3d67 100644 --- a/smart-contracts/contracts/diamond/facets/ProviderRegistry.sol +++ b/smart-contracts/contracts/diamond/facets/ProviderRegistry.sol @@ -2,6 +2,7 @@ pragma solidity ^0.8.24; import {SafeERC20, IERC20} from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; +import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import {OwnableDiamondStorage} from "../presets/OwnableDiamondStorage.sol"; @@ -11,25 +12,30 @@ import {ProviderStorage} from "../storages/ProviderStorage.sol"; import {IProviderRegistry} from "../../interfaces/facets/IProviderRegistry.sol"; contract ProviderRegistry is IProviderRegistry, OwnableDiamondStorage, ProviderStorage, BidStorage { + using EnumerableSet for EnumerableSet.AddressSet; using SafeERC20 for IERC20; - function __ProviderRegistry_init() external initializer(PROVIDER_STORAGE_SLOT) {} + function __ProviderRegistry_init() external initializer(PROVIDERS_STORAGE_SLOT) {} function providerSetMinStake(uint256 providerMinimumStake_) external onlyOwner { - setProviderMinimumStake(providerMinimumStake_); + PovidersStorage storage providersStorage = getProvidersStorage(); + providersStorage.providerMinimumStake = providerMinimumStake_; - emit ProviderMinStakeUpdated(providerMinimumStake_); + emit ProviderMinimumStakeUpdated(providerMinimumStake_); } function providerRegister(uint256 amount_, string calldata endpoint_) external { + BidsStorage storage bidsStorage = getBidsStorage(); + if (amount_ > 0) { - getToken().safeTransferFrom(_msgSender(), address(this), amount_); + IERC20(bidsStorage.token).safeTransferFrom(_msgSender(), address(this), amount_); } - Provider storage provider = providers(_msgSender()); + PovidersStorage storage providersStorage = getProvidersStorage(); + Provider storage provider = providersStorage.providers[_msgSender()]; uint256 newStake_ = provider.stake + amount_; - uint256 minStake_ = getProviderMinimumStake(); + uint256 minStake_ = providersStorage.providerMinimumStake; if (newStake_ < minStake_) { revert ProviderStakeTooLow(newStake_, minStake_); } @@ -45,11 +51,14 @@ contract ProviderRegistry is IProviderRegistry, OwnableDiamondStorage, ProviderS provider.endpoint = endpoint_; provider.stake = newStake_; + providersStorage.activeProviders.add(_msgSender()); + emit ProviderRegistered(_msgSender()); } function providerDeregister() external { - Provider storage provider = providers(_msgSender()); + PovidersStorage storage providersStorage = getProvidersStorage(); + Provider storage provider = providersStorage.providers[_msgSender()]; if (provider.createdAt == 0) { revert ProviderNotFound(); @@ -66,8 +75,11 @@ contract ProviderRegistry is IProviderRegistry, OwnableDiamondStorage, ProviderS provider.stake -= withdrawAmount_; provider.isDeleted = true; + providersStorage.activeProviders.remove(_msgSender()); + if (withdrawAmount_ > 0) { - getToken().safeTransfer(_msgSender(), withdrawAmount_); + BidsStorage storage bidsStorage = getBidsStorage(); + IERC20(bidsStorage.token).safeTransfer(_msgSender(), withdrawAmount_); } emit ProviderDeregistered(_msgSender()); diff --git a/smart-contracts/contracts/diamond/facets/SessionRouter.sol b/smart-contracts/contracts/diamond/facets/SessionRouter.sol index 331c6beb..1b2cac61 100644 --- a/smart-contracts/contracts/diamond/facets/SessionRouter.sol +++ b/smart-contracts/contracts/diamond/facets/SessionRouter.sol @@ -29,13 +29,18 @@ contract SessionRouter is using Math for *; using LibSD for LibSD.SD; using SafeERC20 for IERC20; + using EnumerableSet for EnumerableSet.Bytes32Set; function __SessionRouter_init( address fundingAccount_, Pool[] calldata pools_ - ) external initializer(SESSION_STORAGE_SLOT) { - setFundingAccount(fundingAccount_); - setPools(pools_); + ) external initializer(SESSIONS_STORAGE_SLOT) { + SessionsStorage storage sessionsStorage = getSessionsStorage(); + + sessionsStorage.fundingAccount = fundingAccount_; + for (uint256 i = 0; i < pools_.length; i++) { + sessionsStorage.pools.push(pools_[i]); + } } //////////////////////////// @@ -48,11 +53,11 @@ contract SessionRouter is * @dev call 'Distribution.pools(3)' where '3' is a poolId */ function setPoolConfig(uint256 index_, Pool calldata pool_) external onlyOwner { - if (index_ >= pools().length) { + if (index_ >= getSessionsStorage().pools.length) { revert SessionPoolIndexOutOfBounds(); } - setPool(index_, pool_); + getSessionsStorage().pools[index_] = pool_; } //////////////////////// @@ -68,22 +73,25 @@ contract SessionRouter is revert SessionBidNotFound(); } - Bid storage bid = bids(bidId_); + BidsStorage storage bidsStorage = getBidsStorage(); + SessionsStorage storage sessionsStorage = getSessionsStorage(); + + Bid storage bid = bidsStorage.bids[bidId_]; if (!_isValidProviderReceipt(bid.provider, approvalEncoded_, signature_)) { revert SessionProviderSignatureMismatch(); } - if (getIsProviderApprovalUsed(approvalEncoded_)) { + if (sessionsStorage.isProviderApprovalUsed[approvalEncoded_]) { revert SessionDuplicateApproval(); } uint128 endsAt_ = getSessionEnd(amount_, bid.pricePerSecond, uint128(block.timestamp)); - bytes32 sessionId_ = getSessionId(_msgSender(), bid.provider, bidId_, incrementSessionNonce()); + bytes32 sessionId_ = getSessionId(_msgSender(), bid.provider, bidId_, sessionsStorage.sessionNonce++); if (endsAt_ - block.timestamp < MIN_SESSION_DURATION) { revert SessionTooShort(); } - Session storage session = sessions(sessionId_); + Session storage session = sessionsStorage.sessions[sessionId_]; session.user = _msgSender(); session.stake = amount_; @@ -92,12 +100,13 @@ contract SessionRouter is session.endsAt = endsAt_; session.isActive = true; - addUserSessionId(_msgSender(), sessionId_); - addProviderSessionId(bid.provider, sessionId_); - addModelSessionId(bid.modelId, sessionId_); - setIsProviderApprovalUsed(approvalEncoded_, true); + sessionsStorage.userSessions[_msgSender()].add(sessionId_); + sessionsStorage.providerSessions[bid.provider].add(sessionId_); + sessionsStorage.modelSessions[bid.modelId].add(sessionId_); - getToken().safeTransferFrom(_msgSender(), address(this), amount_); + sessionsStorage.isProviderApprovalUsed[approvalEncoded_] = true; + + IERC20(bidsStorage.token).safeTransferFrom(_msgSender(), address(this), amount_); emit SessionOpened(_msgSender(), sessionId_, bid.provider); @@ -161,8 +170,8 @@ contract SessionRouter is function closeSession(bytes calldata receiptEncoded_, bytes calldata signature_) external { (bytes32 sessionId_, uint32 tpsScaled1000_, uint32 ttftMs_) = _extractProviderReceipt(receiptEncoded_); - Session storage session = sessions(sessionId_); - Bid storage bid = bids(session.bidId); + Session storage session = getSessionsStorage().sessions[sessionId_]; + Bid storage bid = getBidsStorage().bids[session.bidId]; _onlyAccount(session.user); if (session.closedAt != 0) { @@ -201,10 +210,10 @@ contract SessionRouter is uint256 userInitialLock_ = userDuration_ * bid.pricePerSecond; userStakeToLock_ = session.stake.min(stipendToStake(userInitialLock_, startOfToday_)); - addUserStakeOnHold(session.user, OnHold(userStakeToLock_, uint128(startOfToday_ + 1 days))); + getSessionsStorage().userStakesOnHold[session.user].push(OnHold(userStakeToLock_, uint128(startOfToday_ + 1 days))); } uint256 userAmountToWithdraw_ = session.stake - userStakeToLock_; - getToken().safeTransfer(session.user, userAmountToWithdraw_); + IERC20(getBidsStorage().token).safeTransfer(session.user, userAmountToWithdraw_); //// END //// STATS @@ -261,8 +270,8 @@ contract SessionRouter is * @dev Allows providers to receive their funds after the end or closure of the session */ function claimForProvider(bytes32 sessionId_) external { - Session storage session = sessions(sessionId_); - Bid storage bid = bids(session.bidId); + Session storage session = getSessionsStorage().sessions[sessionId_]; + Bid storage bid = getBidsStorage().bids[session.bidId]; _onlyAccount(bid.provider); @@ -282,8 +291,12 @@ contract SessionRouter is * @param amount_ Amount of reward to send */ function _claimForProvider(Session storage session, uint256 amount_) private { - Bid storage bid = bids(session.bidId); - Provider storage provider = providers(bid.provider); + SessionsStorage storage sessionsStorage = getSessionsStorage(); + BidsStorage storage bidsStorage = getBidsStorage(); + PovidersStorage storage providersStorage = getProvidersStorage(); + + Bid storage bid = bidsStorage.bids[session.bidId]; + Provider storage provider = providersStorage.providers[bid.provider]; if (block.timestamp > provider.limitPeriodEnd) { provider.limitPeriodEnd = uint128(block.timestamp) + PROVIDER_REWARD_LIMITER_PERIOD; @@ -300,9 +313,9 @@ contract SessionRouter is session.providerWithdrawnAmount += amount_; provider.limitPeriodEarned += amount_; - increaseProvidersTotalClaimed(amount_); + sessionsStorage.providersTotalClaimed += amount_; - getToken().safeTransferFrom(getFundingAccount(), bid.provider, amount_); + IERC20(bidsStorage.token).safeTransferFrom(sessionsStorage.fundingAccount, bid.provider, amount_); } /** @@ -316,7 +329,7 @@ contract SessionRouter is address user_, uint8 iterations_ ) external view returns (uint256 available_, uint256 hold_) { - OnHold[] memory onHold = userStakesOnHold(user_); + OnHold[] memory onHold = getSessionsStorage().userStakesOnHold[user_]; iterations_ = iterations_ > onHold.length ? uint8(onHold.length) : iterations_; for (uint256 i = 0; i < onHold.length; i++) { @@ -333,7 +346,7 @@ contract SessionRouter is function withdrawUserStakes(uint8 iterations_) external { uint256 amount_ = 0; - OnHold[] storage onHoldEntries = userStakesOnHold(_msgSender()); + OnHold[] storage onHoldEntries = getSessionsStorage().userStakesOnHold[_msgSender()]; uint8 i = iterations_ >= onHoldEntries.length ? uint8(onHoldEntries.length) : iterations_; i--; @@ -356,7 +369,7 @@ contract SessionRouter is revert SessionUserAmountToWithdrawIsZero(); } - getToken().safeTransfer(_msgSender(), amount_); + IERC20(getBidsStorage().token).safeTransfer(_msgSender(), amount_); emit UserWithdrawn(_msgSender(), amount_); } @@ -376,7 +389,8 @@ contract SessionRouter is * @dev Returns today's compute balance in MOR without claimed amount */ function getComputeBalance(uint128 timestamp_) public view returns (uint256) { - Pool storage pool = pool(COMPUTE_POOL_INDEX); + SessionsStorage storage sessionsStorage = getSessionsStorage(); + Pool storage pool = sessionsStorage.pools[COMPUTE_POOL_INDEX]; uint256 periodReward = LinearDistributionIntervalDecrease.getPeriodReward( pool.initialReward, @@ -387,7 +401,7 @@ contract SessionRouter is uint128(startOfTheDay(timestamp_)) ); - return periodReward - getProvidersTotalClaimed(); + return periodReward - sessionsStorage.providersTotalClaimed; } /** @@ -398,7 +412,9 @@ contract SessionRouter is uint256 startOfTheDay_ = startOfTheDay(timestamp_); uint256 totalSupply_ = 0; - Pool[] storage pools = pools(); + SessionsStorage storage sessionsStorage = getSessionsStorage(); + Pool[] storage pools = sessionsStorage.pools; + for (uint256 i = 0; i < pools.length; i++) { if (i == COMPUTE_POOL_INDEX) continue; @@ -412,7 +428,7 @@ contract SessionRouter is ); } - return totalSupply_ + getProvidersTotalClaimed(); + return totalSupply_ + sessionsStorage.providersTotalClaimed; } function startOfTheDay(uint128 timestamp_) public pure returns (uint128) { diff --git a/smart-contracts/contracts/diamond/storages/BidStorage.sol b/smart-contracts/contracts/diamond/storages/BidStorage.sol index 667b1712..067a1103 100644 --- a/smart-contracts/contracts/diamond/storages/BidStorage.sol +++ b/smart-contracts/contracts/diamond/storages/BidStorage.sol @@ -3,7 +3,6 @@ pragma solidity ^0.8.24; import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; - import {Paginator} from "@solarity/solidity-lib/libs/arrays/Paginator.sol"; import {IBidStorage} from "../../interfaces/storage/IBidStorage.sol"; @@ -12,21 +11,21 @@ contract BidStorage is IBidStorage { using Paginator for *; using EnumerableSet for EnumerableSet.Bytes32Set; - struct BDStorage { - IERC20 token; // MOR token + struct BidsStorage { + address token; // MOR token mapping(bytes32 bidId => Bid) bids; // bidId = keccak256(provider, modelId, nonce) - mapping(bytes32 modelId => bytes32[]) modelBids; // keccak256(provider, modelId) => all bidIds + mapping(bytes32 modelId => EnumerableSet.Bytes32Set) modelBids; // keccak256(provider, modelId) => all bidIds mapping(bytes32 modelId => EnumerableSet.Bytes32Set) modelActiveBids; // modelId => active bidIds - mapping(address provider => bytes32[]) providerBids; // provider => all bidIds + mapping(address provider => EnumerableSet.Bytes32Set) providerBids; // provider => all bidIds mapping(address provider => EnumerableSet.Bytes32Set) providerActiveBids; // provider => active bidIds mapping(bytes32 providerModelId => uint256) providerModelNonce; // keccak256(provider, modelId) => last nonce } - bytes32 public constant BID_STORAGE_SLOT = keccak256("diamond.standard.bid.storage"); + bytes32 public constant BIDS_STORAGE_SLOT = keccak256("diamond.standard.bids.storage"); /** PUBLIC, GETTERS */ function getBid(bytes32 bidId_) external view returns (Bid memory) { - return _getBidStorage().bids[bidId_]; + return getBidsStorage().bids[bidId_]; } function getProviderActiveBids( @@ -34,15 +33,15 @@ contract BidStorage is IBidStorage { uint256 offset_, uint256 limit_ ) external view returns (bytes32[] memory) { - return _getBidStorage().providerActiveBids[provider_].part(offset_, limit_); + return getBidsStorage().providerActiveBids[provider_].part(offset_, limit_); } function getModelActiveBids( bytes32 modelId_, uint256 offset_, uint256 limit_ - ) public view returns (bytes32[] memory) { - return _getBidStorage().modelActiveBids[modelId_].part(offset_, limit_); + ) external view returns (bytes32[] memory) { + return getBidsStorage().modelActiveBids[modelId_].part(offset_, limit_); } function getProviderBids( @@ -50,72 +49,34 @@ contract BidStorage is IBidStorage { uint256 offset_, uint256 limit_ ) external view returns (bytes32[] memory) { - return _getBidStorage().providerBids[provider_].part(offset_, limit_); + return getBidsStorage().providerBids[provider_].part(offset_, limit_); } function getModelBids(bytes32 modelId_, uint256 offset_, uint256 limit_) external view returns (bytes32[] memory) { - return _getBidStorage().modelBids[modelId_].part(offset_, limit_); + return getBidsStorage().modelBids[modelId_].part(offset_, limit_); } - function getToken() public view returns (IERC20) { - return _getBidStorage().token; + function getToken() external view returns (address) { + return getBidsStorage().token; } function isBidActive(bytes32 bidId_) public view returns (bool) { - Bid storage bid = _getBidStorage().bids[bidId_]; + Bid storage bid = getBidsStorage().bids[bidId_]; return bid.createdAt != 0 && bid.deletedAt == 0; } - /** INTERNAL, GETTERS */ - function bids(bytes32 bidId_) internal view returns (Bid storage) { - return _getBidStorage().bids[bidId_]; - } - + /** INTERNAL */ function isModelActiveBidsEmpty(bytes32 modelId) internal view returns (bool) { - return _getBidStorage().modelActiveBids[modelId].length() == 0; + return getBidsStorage().modelActiveBids[modelId].length() == 0; } function isProviderActiveBidsEmpty(address provider) internal view returns (bool) { - return _getBidStorage().providerActiveBids[provider].length() == 0; - } - - /** INTERNAL, SETTERS */ - function addProviderActiveBids(address provider_, bytes32 bidId_) internal { - _getBidStorage().providerActiveBids[provider_].add(bidId_); - } - - function removeProviderActiveBids(address provider_, bytes32 bidId_) internal { - _getBidStorage().providerActiveBids[provider_].remove(bidId_); - } - - function addModelActiveBids(bytes32 modelId_, bytes32 bidId_) internal { - _getBidStorage().modelActiveBids[modelId_].add(bidId_); - } - - function removeModelActiveBids(bytes32 modelId_, bytes32 bidId_) internal { - _getBidStorage().modelActiveBids[modelId_].remove(bidId_); - } - - function addProviderBid(address provider_, bytes32 bidId_) internal { - _getBidStorage().providerBids[provider_].push(bidId_); - } - - function addModelBid(bytes32 modelId_, bytes32 bidId_) internal { - _getBidStorage().modelBids[modelId_].push(bidId_); - } - - function setToken(IERC20 token_) internal { - _getBidStorage().token = token_; - } - - function incrementBidNonce(bytes32 providerModelId_) internal returns (uint256) { - return _getBidStorage().providerModelNonce[providerModelId_]++; + return getBidsStorage().providerActiveBids[provider].length() == 0; } - /** PRIVATE */ - function _getBidStorage() private pure returns (BDStorage storage ds) { - bytes32 slot_ = BID_STORAGE_SLOT; + function getBidsStorage() internal pure returns (BidsStorage storage ds) { + bytes32 slot_ = BIDS_STORAGE_SLOT; assembly { ds.slot := slot_ diff --git a/smart-contracts/contracts/diamond/storages/MarketplaceStorage.sol b/smart-contracts/contracts/diamond/storages/MarketplaceStorage.sol index b5445cf6..d905fada 100644 --- a/smart-contracts/contracts/diamond/storages/MarketplaceStorage.sol +++ b/smart-contracts/contracts/diamond/storages/MarketplaceStorage.sol @@ -4,34 +4,25 @@ pragma solidity ^0.8.24; import {IMarketplaceStorage} from "../../interfaces/storage/IMarketplaceStorage.sol"; contract MarketplaceStorage is IMarketplaceStorage { - struct MPStorage { + struct MarketStorage { uint256 feeBalance; // Total fees balance of the contract uint256 bidFee; } - bytes32 public constant MARKETPLACE_STORAGE_SLOT = keccak256("diamond.standard.marketplace.storage"); + bytes32 public constant MARKET_STORAGE_SLOT = keccak256("diamond.standard.market.storage"); /** PUBLIC, GETTERS */ - function getBidFee() public view returns (uint256) { - return _getMarketplaceStorage().bidFee; + function getBidFee() external view returns (uint256) { + return getMarketStorage().bidFee; } - function getFeeBalance() public view returns (uint256) { - return _getMarketplaceStorage().feeBalance; + function getFeeBalance() external view returns (uint256) { + return getMarketStorage().feeBalance; } - /** INTERNAL, SETTERS */ - function setBidFee(uint256 bidFee_) internal { - _getMarketplaceStorage().bidFee = bidFee_; - } - - function setFeeBalance(uint256 feeBalance_) internal { - _getMarketplaceStorage().feeBalance = feeBalance_; - } - - /** PRIVATE */ - function _getMarketplaceStorage() private pure returns (MPStorage storage ds) { - bytes32 slot_ = MARKETPLACE_STORAGE_SLOT; + /** INTERNAL */ + function getMarketStorage() internal pure returns (MarketStorage storage ds) { + bytes32 slot_ = MARKET_STORAGE_SLOT; assembly { ds.slot := slot_ diff --git a/smart-contracts/contracts/diamond/storages/ModelStorage.sol b/smart-contracts/contracts/diamond/storages/ModelStorage.sol index 9600e900..344d3392 100644 --- a/smart-contracts/contracts/diamond/storages/ModelStorage.sol +++ b/smart-contracts/contracts/diamond/storages/ModelStorage.sol @@ -1,55 +1,49 @@ // SPDX-License-Identifier: MIT pragma solidity ^0.8.24; +import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import {Paginator} from "@solarity/solidity-lib/libs/arrays/Paginator.sol"; import {IModelStorage} from "../../interfaces/storage/IModelStorage.sol"; contract ModelStorage is IModelStorage { using Paginator for *; + using EnumerableSet for EnumerableSet.Bytes32Set; - struct MDLStorage { + struct ModelsStorage { uint256 modelMinimumStake; - bytes32[] modelIds; + EnumerableSet.Bytes32Set modelIds; mapping(bytes32 modelId => Model) models; + // TODO: move vars below to the graph in the future + EnumerableSet.Bytes32Set activeModels; } - bytes32 public constant MODEL_STORAGE_SLOT = keccak256("diamond.standard.model.storage"); + bytes32 public constant MODELS_STORAGE_SLOT = keccak256("diamond.standard.models.storage"); /** PUBLIC, GETTERS */ function getModel(bytes32 modelId_) external view returns (Model memory) { - return _getModelStorage().models[modelId_]; + return getModelsStorage().models[modelId_]; } function getModelIds(uint256 offset_, uint256 limit_) external view returns (bytes32[] memory) { - return _getModelStorage().modelIds.part(offset_, limit_); + return getModelsStorage().modelIds.part(offset_, limit_); } - function getModelMinimumStake() public view returns (uint256) { - return _getModelStorage().modelMinimumStake; + function getModelMinimumStake() external view returns (uint256) { + return getModelsStorage().modelMinimumStake; } - function getIsModelActive(bytes32 modelId_) public view returns (bool) { - return !_getModelStorage().models[modelId_].isDeleted; - } - - /** INTERNAL, GETTERS */ - function models(bytes32 modelId_) internal view returns (Model storage) { - return _getModelStorage().models[modelId_]; + function getActiveModels(uint256 offset_, uint256 limit_) external view returns (bytes32[] memory) { + return getModelsStorage().activeModels.part(offset_, limit_); } - /** INTERNAL, SETTERS */ - function addModelId(bytes32 modelId_) internal { - _getModelStorage().modelIds.push(modelId_); - } - - function setModelMinimumStake(uint256 modelMinimumStake_) internal { - _getModelStorage().modelMinimumStake = modelMinimumStake_; + function getIsModelActive(bytes32 modelId_) public view returns (bool) { + return !getModelsStorage().models[modelId_].isDeleted; } - /** PRIVATE */ - function _getModelStorage() private pure returns (MDLStorage storage ds) { - bytes32 slot_ = MODEL_STORAGE_SLOT; + /** INTERNAL */ + function getModelsStorage() internal pure returns (ModelsStorage storage ds) { + bytes32 slot_ = MODELS_STORAGE_SLOT; assembly { ds.slot := slot_ diff --git a/smart-contracts/contracts/diamond/storages/ProviderStorage.sol b/smart-contracts/contracts/diamond/storages/ProviderStorage.sol index 4318bd29..ebcf5f2e 100644 --- a/smart-contracts/contracts/diamond/storages/ProviderStorage.sol +++ b/smart-contracts/contracts/diamond/storages/ProviderStorage.sol @@ -1,45 +1,47 @@ // SPDX-License-Identifier: MIT pragma solidity ^0.8.24; +import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; +import {Paginator} from "@solarity/solidity-lib/libs/arrays/Paginator.sol"; + import {IProviderStorage} from "../../interfaces/storage/IProviderStorage.sol"; contract ProviderStorage is IProviderStorage { - struct PRVDRStorage { + using Paginator for *; + using EnumerableSet for EnumerableSet.AddressSet; + + struct PovidersStorage { uint256 providerMinimumStake; mapping(address => Provider) providers; + // TODO: move vars below to the graph in the future + EnumerableSet.AddressSet activeProviders; } // Reward for this period will be limited by the stake uint128 constant PROVIDER_REWARD_LIMITER_PERIOD = 365 days; - bytes32 public constant PROVIDER_STORAGE_SLOT = keccak256("diamond.standard.provider.storage"); + bytes32 public constant PROVIDERS_STORAGE_SLOT = keccak256("diamond.standard.providers.storage"); /** PUBLIC, GETTERS */ function getProvider(address provider_) external view returns (Provider memory) { - return providers(provider_); - } - - function getProviderMinimumStake() public view returns (uint256) { - return _getProviderStorage().providerMinimumStake; + return getProvidersStorage().providers[provider_]; } - function getIsProviderActive(address provider_) public view returns (bool) { - return !providers(provider_).isDeleted; + function getProviderMinimumStake() external view returns (uint256) { + return getProvidersStorage().providerMinimumStake; } - /** INTERNAL, GETTERS */ - function providers(address provider_) internal view returns (Provider storage) { - return _getProviderStorage().providers[provider_]; + function getActiveProviders(uint256 offset_, uint256 limit_) external view returns (address[] memory) { + return getProvidersStorage().activeProviders.part(offset_, limit_); } - /** INTERNAL, SETTERS */ - function setProviderMinimumStake(uint256 providerMinimumStake_) internal { - _getProviderStorage().providerMinimumStake = providerMinimumStake_; + function getIsProviderActive(address provider_) public view returns (bool) { + return !getProvidersStorage().providers[provider_].isDeleted; } - /** PRIVATE */ - function _getProviderStorage() private pure returns (PRVDRStorage storage ds) { - bytes32 slot_ = PROVIDER_STORAGE_SLOT; + /** INTERNAL */ + function getProvidersStorage() internal pure returns (PovidersStorage storage ds) { + bytes32 slot_ = PROVIDERS_STORAGE_SLOT; assembly { ds.slot := slot_ diff --git a/smart-contracts/contracts/diamond/storages/SessionStorage.sol b/smart-contracts/contracts/diamond/storages/SessionStorage.sol index 20ab07e5..0823570e 100644 --- a/smart-contracts/contracts/diamond/storages/SessionStorage.sol +++ b/smart-contracts/contracts/diamond/storages/SessionStorage.sol @@ -10,7 +10,7 @@ contract SessionStorage is ISessionStorage { using Paginator for *; using EnumerableSet for EnumerableSet.Bytes32Set; - struct SNStorage { + struct SessionsStorage { // Account which stores the MOR tokens with infinite allowance for this contract address fundingAccount; // Distribution pools configuration that mirrors L1 contract @@ -28,7 +28,7 @@ contract SessionStorage is ISessionStorage { mapping(bytes providerApproval => bool) isProviderApprovalUsed; } - bytes32 public constant SESSION_STORAGE_SLOT = keccak256("diamond.standard.session.storage"); + bytes32 public constant SESSIONS_STORAGE_SLOT = keccak256("diamond.standard.sessions.storage"); uint32 public constant MIN_SESSION_DURATION = 5 minutes; uint32 public constant MAX_SESSION_DURATION = 1 days; uint32 public constant SIGNATURE_TTL = 10 minutes; @@ -36,11 +36,11 @@ contract SessionStorage is ISessionStorage { /** PUBLIC, GETTERS */ function getSession(bytes32 sessionId_) external view returns (Session memory) { - return _getSessionStorage().sessions[sessionId_]; + return getSessionsStorage().sessions[sessionId_]; } function getUserSessions(address user_, uint256 offset_, uint256 limit_) external view returns (bytes32[] memory) { - return _getSessionStorage().userSessions[user_].part(offset_, limit_); + return getSessionsStorage().userSessions[user_].part(offset_, limit_); } function getProviderSessions( @@ -48,7 +48,7 @@ contract SessionStorage is ISessionStorage { uint256 offset_, uint256 limit_ ) external view returns (bytes32[] memory) { - return _getSessionStorage().providerSessions[provider_].part(offset_, limit_); + return getSessionsStorage().providerSessions[provider_].part(offset_, limit_); } function getModelSessions( @@ -56,98 +56,98 @@ contract SessionStorage is ISessionStorage { uint256 offset_, uint256 limit_ ) external view returns (bytes32[] memory) { - return _getSessionStorage().modelSessions[modelId_].part(offset_, limit_); + return getSessionsStorage().modelSessions[modelId_].part(offset_, limit_); } function getPools() external view returns (Pool[] memory) { - return _getSessionStorage().pools; + return getSessionsStorage().pools; } function getPool(uint256 index_) external view returns (Pool memory) { - return _getSessionStorage().pools[index_]; + return getSessionsStorage().pools[index_]; } - function getFundingAccount() public view returns (address) { - return _getSessionStorage().fundingAccount; + function getFundingAccount() external view returns (address) { + return getSessionsStorage().fundingAccount; } function getTotalSessions(address providerAddr_) public view returns (uint256) { - return _getSessionStorage().providerSessions[providerAddr_].length(); + return getSessionsStorage().providerSessions[providerAddr_].length(); } - function getProvidersTotalClaimed() public view returns (uint256) { - return _getSessionStorage().providersTotalClaimed; + function getProvidersTotalClaimed() external view returns (uint256) { + return getSessionsStorage().providersTotalClaimed; } - function getIsProviderApprovalUsed(bytes memory approval_) public view returns (bool) { - return _getSessionStorage().isProviderApprovalUsed[approval_]; + function getIsProviderApprovalUsed(bytes memory approval_) external view returns (bool) { + return getSessionsStorage().isProviderApprovalUsed[approval_]; } - /** INTERNAL, GETTERS */ - function pools() internal view returns (Pool[] storage) { - return _getSessionStorage().pools; - } + // /** INTERNAL, GETTERS */ + // function pools() internal view returns (Pool[] storage) { + // return _getSessionStorage().pools; + // } - function pool(uint256 poolIndex_) internal view returns (Pool storage) { - return _getSessionStorage().pools[poolIndex_]; - } + // function pool(uint256 poolIndex_) internal view returns (Pool storage) { + // return _getSessionStorage().pools[poolIndex_]; + // } - function userStakesOnHold(address user_) internal view returns (OnHold[] storage) { - return _getSessionStorage().userStakesOnHold[user_]; - } + // function userStakesOnHold(address user_) internal view returns (OnHold[] storage) { + // return _getSessionStorage().userStakesOnHold[user_]; + // } - function sessions(bytes32 sessionId_) internal view returns (Session storage) { - return _getSessionStorage().sessions[sessionId_]; - } + // function sessions(bytes32 sessionId_) internal view returns (Session storage) { + // return _getSessionStorage().sessions[sessionId_]; + // } /** INTERNAL, SETTERS */ - function setFundingAccount(address fundingAccount_) internal { - _getSessionStorage().fundingAccount = fundingAccount_; - } + // function setFundingAccount(address fundingAccount_) internal { + // _getSessionStorage().fundingAccount = fundingAccount_; + // } - function setPools(Pool[] calldata pools_) internal { - SNStorage storage s = _getSessionStorage(); + // function setPools(Pool[] calldata pools_) internal { + // SNStorage storage s = _getSessionStorage(); - for (uint256 i = 0; i < pools_.length; i++) { - s.pools.push(pools_[i]); - } - } + // for (uint256 i = 0; i < pools_.length; i++) { + // s.pools.push(pools_[i]); + // } + // } - function setPool(uint256 index_, Pool calldata pool_) internal { - _getSessionStorage().pools[index_] = pool_; - } + // function setPool(uint256 index_, Pool calldata pool_) internal { + // _getSessionStorage().pools[index_] = pool_; + // } - function addUserSessionId(address user_, bytes32 sessionId_) internal { - _getSessionStorage().userSessions[user_].add(sessionId_); - } + // function addUserSessionId(address user_, bytes32 sessionId_) internal { + // _getSessionStorage().userSessions[user_].add(sessionId_); + // } - function addProviderSessionId(address provider_, bytes32 sessionId_) internal { - _getSessionStorage().providerSessions[provider_].add(sessionId_); - } + // function addProviderSessionId(address provider_, bytes32 sessionId_) internal { + // _getSessionStorage().providerSessions[provider_].add(sessionId_); + // } - function addModelSessionId(bytes32 modelId, bytes32 sessionId) internal { - _getSessionStorage().modelSessions[modelId].add(sessionId); - } + // function addModelSessionId(bytes32 modelId, bytes32 sessionId) internal { + // _getSessionStorage().modelSessions[modelId].add(sessionId); + // } - function addUserStakeOnHold(address user, OnHold memory onHold) internal { - _getSessionStorage().userStakesOnHold[user].push(onHold); - } + // function addUserStakeOnHold(address user, OnHold memory onHold) internal { + // _getSessionStorage().userStakesOnHold[user].push(onHold); + // } - function increaseProvidersTotalClaimed(uint256 amount) internal { - _getSessionStorage().providersTotalClaimed += amount; - } + // function increaseProvidersTotalClaimed(uint256 amount) internal { + // _getSessionStorage().providersTotalClaimed += amount; + // } - function incrementSessionNonce() internal returns (uint256) { - return _getSessionStorage().sessionNonce++; - } + // function incrementSessionNonce() internal returns (uint256) { + // return _getSessionStorage().sessionNonce++; + // } - function setIsProviderApprovalUsed(bytes memory approval_, bool isUsed_) internal { - _getSessionStorage().isProviderApprovalUsed[approval_] = isUsed_; - } + // function setIsProviderApprovalUsed(bytes memory approval_, bool isUsed_) internal { + // _getSessionStorage().isProviderApprovalUsed[approval_] = isUsed_; + // } - /** PRIVATE */ - function _getSessionStorage() private pure returns (SNStorage storage ds) { - bytes32 slot_ = SESSION_STORAGE_SLOT; + /** INTERNAL */ + function getSessionsStorage() internal pure returns (SessionsStorage storage ds) { + bytes32 slot_ = SESSIONS_STORAGE_SLOT; assembly { ds.slot := slot_ diff --git a/smart-contracts/contracts/interfaces/facets/IModelRegistry.sol b/smart-contracts/contracts/interfaces/facets/IModelRegistry.sol index de19f793..5f2e1f53 100644 --- a/smart-contracts/contracts/interfaces/facets/IModelRegistry.sol +++ b/smart-contracts/contracts/interfaces/facets/IModelRegistry.sol @@ -6,7 +6,7 @@ import {IModelStorage} from "../storage/IModelStorage.sol"; interface IModelRegistry is IModelStorage { event ModelRegisteredUpdated(address indexed owner, bytes32 indexed modelId); event ModelDeregistered(address indexed owner, bytes32 indexed modelId); - event ModelMinStakeUpdated(uint256 newStake); + event ModelMinimumStakeUpdated(uint256 modelMinimumStake); error ModelStakeTooLow(uint256 amount, uint256 minAmount); error ModelHasAlreadyDeregistered(); error ModelNotFound(); diff --git a/smart-contracts/contracts/interfaces/facets/IProviderRegistry.sol b/smart-contracts/contracts/interfaces/facets/IProviderRegistry.sol index 35c3e392..20cf7d84 100644 --- a/smart-contracts/contracts/interfaces/facets/IProviderRegistry.sol +++ b/smart-contracts/contracts/interfaces/facets/IProviderRegistry.sol @@ -6,7 +6,7 @@ import {IProviderStorage} from "../storage/IProviderStorage.sol"; interface IProviderRegistry is IProviderStorage { event ProviderRegistered(address indexed provider); event ProviderDeregistered(address indexed provider); - event ProviderMinStakeUpdated(uint256 newStake); + event ProviderMinimumStakeUpdated(uint256 providerMinimumStake); event ProviderWithdrawn(address indexed provider, uint256 amount); error ProviderStakeTooLow(uint256 amount, uint256 minAmount); error ProviderNotDeregistered(); diff --git a/smart-contracts/contracts/interfaces/storage/IBidStorage.sol b/smart-contracts/contracts/interfaces/storage/IBidStorage.sol index d70d056d..3815f663 100644 --- a/smart-contracts/contracts/interfaces/storage/IBidStorage.sol +++ b/smart-contracts/contracts/interfaces/storage/IBidStorage.sol @@ -1,8 +1,6 @@ // SPDX-License-Identifier: MIT pragma solidity ^0.8.24; -import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; - interface IBidStorage { struct Bid { address provider; @@ -35,7 +33,7 @@ interface IBidStorage { function getModelBids(bytes32 modelId_, uint256 offset_, uint256 limit_) external view returns (bytes32[] memory); - function getToken() external view returns (IERC20); + function getToken() external view returns (address); function isBidActive(bytes32 bidId_) external view returns (bool); } diff --git a/smart-contracts/contracts/interfaces/storage/IModelStorage.sol b/smart-contracts/contracts/interfaces/storage/IModelStorage.sol index ec76fef7..92085d3a 100644 --- a/smart-contracts/contracts/interfaces/storage/IModelStorage.sol +++ b/smart-contracts/contracts/interfaces/storage/IModelStorage.sol @@ -19,5 +19,7 @@ interface IModelStorage { function getModelMinimumStake() external view returns (uint256); + function getActiveModels(uint256 offset_, uint256 limit_) external view returns (bytes32[] memory); + function getIsModelActive(bytes32 modelId_) external view returns (bool); } diff --git a/smart-contracts/contracts/interfaces/storage/IProviderStorage.sol b/smart-contracts/contracts/interfaces/storage/IProviderStorage.sol index 75935a00..3121101f 100644 --- a/smart-contracts/contracts/interfaces/storage/IProviderStorage.sol +++ b/smart-contracts/contracts/interfaces/storage/IProviderStorage.sol @@ -15,5 +15,7 @@ interface IProviderStorage { function getProviderMinimumStake() external view returns (uint256); + function getActiveProviders(uint256 offset_, uint256 limit_) external view returns (address[] memory); + function getIsProviderActive(address provider_) external view returns (bool); } diff --git a/smart-contracts/deploy/1_full_protocol.migration.ts b/smart-contracts/deploy/1_full_protocol.migration.ts new file mode 100644 index 00000000..01d77cc9 --- /dev/null +++ b/smart-contracts/deploy/1_full_protocol.migration.ts @@ -0,0 +1,113 @@ +import { Deployer, Reporter } from '@solarity/hardhat-migrate'; +import { Fragment } from 'ethers'; + +import { parseConfig } from './helpers/config-parser'; + +import { + IBidStorage__factory, + IMarketplace__factory, + IModelRegistry__factory, + IProviderRegistry__factory, + ISessionRouter__factory, + IStatsStorage__factory, + LinearDistributionIntervalDecrease__factory, + LumerinDiamond__factory, + Marketplace, + Marketplace__factory, + ModelRegistry, + ModelRegistry__factory, + ProviderRegistry, + ProviderRegistry__factory, + SessionRouter, + SessionRouter__factory, +} from '@/generated-types/ethers'; +import { FacetAction } from '@/test/helpers/deployers'; + +module.exports = async function (deployer: Deployer) { + const config = parseConfig(); + + const lumerinDiamond = await deployer.deploy(LumerinDiamond__factory); + await lumerinDiamond.__LumerinDiamond_init(); + + const ldid = await deployer.deploy(LinearDistributionIntervalDecrease__factory); + + let providerRegistryFacet = await deployer.deploy(ProviderRegistry__factory); + let modelRegistryFacet = await deployer.deploy(ModelRegistry__factory); + let marketplaceFacet = await deployer.deploy(Marketplace__factory); + let sessionRouterFacet = await deployer.deploy(SessionRouter__factory, { + libraries: { + LinearDistributionIntervalDecrease: ldid, + }, + }); + + await lumerinDiamond['diamondCut((address,uint8,bytes4[])[])']([ + { + facetAddress: providerRegistryFacet, + action: FacetAction.Add, + functionSelectors: IProviderRegistry__factory.createInterface() + .fragments.filter(Fragment.isFunction) + .map((f) => f.selector), + }, + { + facetAddress: modelRegistryFacet, + action: FacetAction.Add, + functionSelectors: IModelRegistry__factory.createInterface() + .fragments.filter(Fragment.isFunction) + .map((f) => f.selector), + }, + { + facetAddress: marketplaceFacet, + action: FacetAction.Add, + functionSelectors: IMarketplace__factory.createInterface() + .fragments.filter(Fragment.isFunction) + .map((f) => f.selector), + }, + { + facetAddress: marketplaceFacet, + action: FacetAction.Add, + functionSelectors: IBidStorage__factory.createInterface() + .fragments.filter(Fragment.isFunction) + .map((f) => f.selector), + }, + { + facetAddress: sessionRouterFacet, + action: FacetAction.Add, + functionSelectors: ISessionRouter__factory.createInterface() + .fragments.filter(Fragment.isFunction) + .map((f) => f.selector), + }, + { + facetAddress: sessionRouterFacet, + action: FacetAction.Add, + functionSelectors: IStatsStorage__factory.createInterface() + .fragments.filter(Fragment.isFunction) + .map((f) => f.selector), + }, + ]); + + providerRegistryFacet = providerRegistryFacet.attach(lumerinDiamond.target) as ProviderRegistry; + await providerRegistryFacet.__ProviderRegistry_init(); + modelRegistryFacet = modelRegistryFacet.attach(lumerinDiamond.target) as ModelRegistry; + await modelRegistryFacet.__ModelRegistry_init(); + marketplaceFacet = marketplaceFacet.attach(lumerinDiamond.target) as Marketplace; + await marketplaceFacet.__Marketplace_init(config.MOR); + sessionRouterFacet = sessionRouterFacet.attach(lumerinDiamond.target) as SessionRouter; + await sessionRouterFacet.__SessionRouter_init(config.fundingAccount, config.pools); + + await providerRegistryFacet.providerSetMinStake(config.providerMinStake); + await modelRegistryFacet.modelSetMinStake(config.modelMinStake); + await marketplaceFacet.setMarketplaceBidFee(config.marketplaceBidFee); + + Reporter.reportContracts( + ['Lumerin Diamond', await lumerinDiamond.getAddress()], + ['Linear Distribution Interval Decrease Library', await ldid.getAddress()], + ); +}; + +// npx hardhat migrate --only 1 + +// npx hardhat migrate --network arbitrum_sepolia --only 1 --verify +// npx hardhat migrate --network arbitrum_sepolia --only 1 --verify --continue + +// npx hardhat migrate --network arbitrum --only 1 --verify +// npx hardhat migrate --network arbitrum --only 1 --verify --continue diff --git a/smart-contracts/deploy/data/config_arbitrum_sepolia.json b/smart-contracts/deploy/data/config_arbitrum_sepolia.json new file mode 100644 index 00000000..5d63728a --- /dev/null +++ b/smart-contracts/deploy/data/config_arbitrum_sepolia.json @@ -0,0 +1,39 @@ +{ + "MOR": "0x34a285a1b1c166420df5b6630132542923b5b27e", + "fundingAccount": "0x19ec1E4b714990620edf41fE28e9a1552953a7F4", + "providerMinStake": "200000000000000000", + "modelMinStake": "100000000000000000", + "marketplaceBidFee": "300000000000000000", + "pools": [ + { + "payoutStart": 1707393600, + "decreaseInterval": 86400, + "initialReward": "3456000000000000000000", + "rewardDecrease": "592558728240000000" + }, + { + "payoutStart": 1707393600, + "decreaseInterval": 86400, + "initialReward": "3456000000000000000000", + "rewardDecrease": "592558728240000000" + }, + { + "payoutStart": 1707393600, + "decreaseInterval": 86400, + "initialReward": "3456000000000000000000", + "rewardDecrease": "592558728240000000" + }, + { + "payoutStart": 1707393600, + "decreaseInterval": 86400, + "initialReward": "3456000000000000000000", + "rewardDecrease": "592558728240000000" + }, + { + "payoutStart": 1707393600, + "decreaseInterval": 86400, + "initialReward": "576000000000000000000", + "rewardDecrease": "98759788040000000" + } + ] +} diff --git a/smart-contracts/deploy/helpers/config-parser.ts b/smart-contracts/deploy/helpers/config-parser.ts new file mode 100644 index 00000000..8c7bc0fb --- /dev/null +++ b/smart-contracts/deploy/helpers/config-parser.ts @@ -0,0 +1,18 @@ +import { readFileSync } from 'fs'; + +import { ISessionStorage } from '@/generated-types/ethers'; + +export type Config = { + MOR: string; + fundingAccount: string; + pools: ISessionStorage.PoolStruct[]; + providerMinStake: string; + modelMinStake: string; + marketplaceBidFee: string; +}; + +export function parseConfig(): Config { + const configPath = `deploy/data/config_arbitrum_sepolia.json`; + + return JSON.parse(readFileSync(configPath, 'utf-8')) as Config; +} diff --git a/smart-contracts/hardhat.config.ts b/smart-contracts/hardhat.config.ts index a040ce23..256700e3 100644 --- a/smart-contracts/hardhat.config.ts +++ b/smart-contracts/hardhat.config.ts @@ -12,6 +12,10 @@ import 'tsconfig-paths/register'; dotenv.config(); +function privateKey() { + return process.env.PRIVATE_KEY !== undefined ? [process.env.PRIVATE_KEY] : []; +} + function typechainTarget() { const target = process.env.TYPECHAIN_TARGET; @@ -32,6 +36,9 @@ const config: HardhatUserConfig = { // auto: true, // interval: 10_000, // }, + // forking: { + // url: `https://arbitrum-sepolia.infura.io/v3/${process.env.INFURA_KEY}`, + // }, }, localhost: { url: 'http://127.0.0.1:8545', @@ -39,6 +46,11 @@ const config: HardhatUserConfig = { gasMultiplier: 1.2, timeout: 1000000000000000, }, + arbitrum_sepolia: { + url: `https://arbitrum-sepolia.infura.io/v3/${process.env.INFURA_KEY}`, + accounts: privateKey(), + gasMultiplier: 1.1, + }, }, solidity: { version: '0.8.24', @@ -78,6 +90,12 @@ const config: HardhatUserConfig = { discriminateTypes: true, dontOverrideCompile: forceTypechain(), }, + etherscan: { + apiKey: { + mainnet: `${process.env.ETHERSCAN_KEY}`, + arbitrumSepolia: `${process.env.ARBITRUM_KEY}`, + }, + }, }; export default config; diff --git a/smart-contracts/test/diamond/facets/ModelRegistry.test.ts b/smart-contracts/test/diamond/facets/ModelRegistry.test.ts index 12e74d88..a4ea9b17 100644 --- a/smart-contracts/test/diamond/facets/ModelRegistry.test.ts +++ b/smart-contracts/test/diamond/facets/ModelRegistry.test.ts @@ -69,7 +69,7 @@ describe('ModelRegistry', () => { const minStake = wei(100); await expect(modelRegistry.modelSetMinStake(minStake)) - .to.emit(modelRegistry, 'ModelMinStakeUpdated') + .to.emit(modelRegistry, 'ModelMinimumStakeUpdated') .withArgs(minStake); expect(await modelRegistry.getModelMinimumStake()).eq(minStake); @@ -124,6 +124,8 @@ describe('ModelRegistry', () => { expect(await token.balanceOf(modelRegistry)).to.eq(wei(100)); expect(await token.balanceOf(SECOND)).to.eq(wei(900)); + expect(await modelRegistry.getActiveModels(0, 10)).to.deep.eq([modelId]); + await modelRegistry.connect(SECOND).modelRegister(modelId, ipfsCID, 0, wei(0), 'name', ['tag_1']); }); it('should add stake to existed model', async () => { @@ -193,6 +195,8 @@ describe('ModelRegistry', () => { expect(await modelRegistry.getIsModelActive(modelId)).to.eq(false); expect(await token.balanceOf(modelRegistry)).to.eq(0); expect(await token.balanceOf(SECOND)).to.eq(wei(1000)); + + expect(await modelRegistry.getActiveModels(0, 10)).to.deep.eq([]); }); it('should throw error when the caller is not an owner or specified address', async () => { await expect(modelRegistry.connect(SECOND).modelDeregister(modelId)).to.be.revertedWithCustomError( diff --git a/smart-contracts/test/diamond/facets/ProviderRegistry.test.ts b/smart-contracts/test/diamond/facets/ProviderRegistry.test.ts index aa095630..bfc82b92 100644 --- a/smart-contracts/test/diamond/facets/ProviderRegistry.test.ts +++ b/smart-contracts/test/diamond/facets/ProviderRegistry.test.ts @@ -66,7 +66,7 @@ describe('ProviderRegistry', () => { const minStake = wei(100); await expect(providerRegistry.providerSetMinStake(minStake)) - .to.emit(providerRegistry, 'ProviderMinStakeUpdated') + .to.emit(providerRegistry, 'ProviderMinimumStakeUpdated') .withArgs(minStake); expect(await providerRegistry.getProviderMinimumStake()).eq(minStake); @@ -98,6 +98,8 @@ describe('ProviderRegistry', () => { expect(await token.balanceOf(providerRegistry)).to.eq(wei(100)); expect(await token.balanceOf(PROVIDER)).to.eq(wei(900)); + expect(await providerRegistry.getActiveProviders(0, 10)).to.deep.eq([PROVIDER.address]); + await providerRegistry.connect(PROVIDER).providerRegister(wei(0), 'test'); }); it('should add stake to existed provider', async () => { @@ -159,6 +161,8 @@ describe('ProviderRegistry', () => { expect(await providerRegistry.getIsProviderActive(PROVIDER)).to.eq(false); expect(await token.balanceOf(providerRegistry)).to.eq(0); expect(await token.balanceOf(PROVIDER)).to.eq(wei(1000)); + + expect(await providerRegistry.getActiveProviders(0, 10)).to.deep.eq([]); }); it('should deregister the provider without transfer', async () => { await providerRegistry.providerSetMinStake(0); From b2f1064b48cde3eb63032df43d02490f0a6202e0 Mon Sep 17 00:00:00 2001 From: Oleksandr Date: Mon, 21 Oct 2024 12:47:40 +0300 Subject: [PATCH 4/9] fix .enf and config file --- smart-contracts/.env.example | 4 +++- smart-contracts/hardhat.config.ts | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/smart-contracts/.env.example b/smart-contracts/.env.example index 54ead7ba..aa0dc9d2 100644 --- a/smart-contracts/.env.example +++ b/smart-contracts/.env.example @@ -1,5 +1,7 @@ COINMARKETCAP_API_KEY= # for coverage report to estimate deployment and calls price ETH_NODE_ADDRESS=https://arb-sepolia.g.alchemy.com/v2/SOME_API_KEY ETHERSCAN_API_KEY= # for coverage report to estimate deployment and calls price +ARBITRUM_API_KEY= # for coverage report to estimate deployment and calls price MOR_TOKEN_ADDRESS= # MOR token address -OWNER_PRIVATE_KEY= # contract owner private key \ No newline at end of file +OWNER_PRIVATE_KEY= # contract owner private key +PRIVATE_KEY= # deployer private key \ No newline at end of file diff --git a/smart-contracts/hardhat.config.ts b/smart-contracts/hardhat.config.ts index 256700e3..953250c8 100644 --- a/smart-contracts/hardhat.config.ts +++ b/smart-contracts/hardhat.config.ts @@ -92,8 +92,8 @@ const config: HardhatUserConfig = { }, etherscan: { apiKey: { - mainnet: `${process.env.ETHERSCAN_KEY}`, - arbitrumSepolia: `${process.env.ARBITRUM_KEY}`, + mainnet: `${process.env.ETHERSCAN_API_KEY}`, + arbitrumSepolia: `${process.env.ARBITRUM_API_KEY}`, }, }, }; From 48975bd4524d4693dccd37c1daf905701409be2d Mon Sep 17 00:00:00 2001 From: Oleksandr Date: Thu, 24 Oct 2024 12:59:56 +0300 Subject: [PATCH 5/9] add direct MOR payment for the session. Add docs to the interfaces --- .../contracts/diamond/facets/Marketplace.sol | 3 +- .../diamond/facets/SessionRouter.sol | 230 +++++++----- .../diamond/storages/ModelStorage.sol | 2 +- .../diamond/storages/SessionStorage.sol | 73 +--- .../interfaces/facets/IMarketplace.sol | 36 +- .../interfaces/facets/IModelRegistry.sol | 20 ++ .../interfaces/facets/IProviderRegistry.sol | 9 +- .../interfaces/facets/ISessionRouter.sol | 85 ++++- .../interfaces/storage/IBidStorage.sol | 48 ++- .../storage/IMarketplaceStorage.sol | 6 + .../interfaces/storage/IModelStorage.sol | 36 +- .../interfaces/storage/IProviderStorage.sol | 35 +- .../interfaces/storage/ISessionStorage.sol | 78 ++++- .../interfaces/storage/IStatsStorage.sol | 23 +- .../deploy/1_full_protocol.migration.ts | 3 +- .../test/diamond/facets/ModelRegistry.test.ts | 4 +- .../test/diamond/facets/SessionRouter.test.ts | 331 ++++++++++-------- .../diamond/facets/session-router.ts | 3 +- 18 files changed, 686 insertions(+), 339 deletions(-) diff --git a/smart-contracts/contracts/diamond/facets/Marketplace.sol b/smart-contracts/contracts/diamond/facets/Marketplace.sol index b25b2837..c3e20659 100644 --- a/smart-contracts/contracts/diamond/facets/Marketplace.sol +++ b/smart-contracts/contracts/diamond/facets/Marketplace.sol @@ -49,9 +49,8 @@ contract Marketplace is BidsStorage storage bidsStorage = getBidsStorage(); MarketStorage storage marketStorage = getMarketStorage(); - // TODO: check it IERC20(bidsStorage.token).safeTransferFrom(_msgSender(), address(this), marketStorage.bidFee); - marketStorage.feeBalance += marketStorage.bidFee; + marketStorage.feeBalance += marketStorage.bidFee; bytes32 providerModelId_ = getProviderModelId(provider_, modelId_); uint256 providerModelNonce_ = bidsStorage.providerModelNonce[providerModelId_]++; diff --git a/smart-contracts/contracts/diamond/facets/SessionRouter.sol b/smart-contracts/contracts/diamond/facets/SessionRouter.sol index 1b2cac61..e6510345 100644 --- a/smart-contracts/contracts/diamond/facets/SessionRouter.sol +++ b/smart-contracts/contracts/diamond/facets/SessionRouter.sol @@ -18,6 +18,8 @@ import {LibSD} from "../../libs/LibSD.sol"; import {ISessionRouter} from "../../interfaces/facets/ISessionRouter.sol"; +import "hardhat/console.sol"; + contract SessionRouter is ISessionRouter, OwnableDiamondStorage, @@ -33,10 +35,12 @@ contract SessionRouter is function __SessionRouter_init( address fundingAccount_, + uint128 maxSessionDuration_, Pool[] calldata pools_ ) external initializer(SESSIONS_STORAGE_SLOT) { SessionsStorage storage sessionsStorage = getSessionsStorage(); + setMaxSessionDuration(maxSessionDuration_); sessionsStorage.fundingAccount = fundingAccount_; for (uint256 i = 0; i < pools_.length; i++) { sessionsStorage.pools.push(pools_[i]); @@ -60,45 +64,40 @@ contract SessionRouter is getSessionsStorage().pools[index_] = pool_; } + function setMaxSessionDuration(uint128 maxSessionDuration_) public onlyOwner { + if (maxSessionDuration_ <= MIN_SESSION_DURATION) { + revert SessionMaxDurationTooShort(); + } + + getSessionsStorage().maxSessionDuration = maxSessionDuration_; + } + //////////////////////// /// OPEN SESSION /// //////////////////////// function openSession( uint256 amount_, + bool isDirectPaymentFromUser_, bytes calldata approvalEncoded_, bytes calldata signature_ ) external returns (bytes32) { - bytes32 bidId_ = _extractProviderApproval(approvalEncoded_); - if (!isBidActive(bidId_)) { - revert SessionBidNotFound(); - } - - BidsStorage storage bidsStorage = getBidsStorage(); SessionsStorage storage sessionsStorage = getSessionsStorage(); - Bid storage bid = bidsStorage.bids[bidId_]; - if (!_isValidProviderReceipt(bid.provider, approvalEncoded_, signature_)) { - revert SessionProviderSignatureMismatch(); - } - if (sessionsStorage.isProviderApprovalUsed[approvalEncoded_]) { - revert SessionDuplicateApproval(); - } + bytes32 bidId_ = _extractProviderApproval(approvalEncoded_); + Bid storage bid = getBidsStorage().bids[bidId_]; - uint128 endsAt_ = getSessionEnd(amount_, bid.pricePerSecond, uint128(block.timestamp)); bytes32 sessionId_ = getSessionId(_msgSender(), bid.provider, bidId_, sessionsStorage.sessionNonce++); - - if (endsAt_ - block.timestamp < MIN_SESSION_DURATION) { - revert SessionTooShort(); - } - Session storage session = sessionsStorage.sessions[sessionId_]; + uint128 endsAt_ = _validateSession(bidId_, amount_, isDirectPaymentFromUser_, approvalEncoded_, signature_); + session.user = _msgSender(); session.stake = amount_; session.bidId = bidId_; session.openedAt = uint128(block.timestamp); session.endsAt = endsAt_; session.isActive = true; + session.isDirectPaymentFromUser = isDirectPaymentFromUser_; sessionsStorage.userSessions[_msgSender()].add(sessionId_); sessionsStorage.providerSessions[bid.provider].add(sessionId_); @@ -106,12 +105,46 @@ contract SessionRouter is sessionsStorage.isProviderApprovalUsed[approvalEncoded_] = true; - IERC20(bidsStorage.token).safeTransferFrom(_msgSender(), address(this), amount_); + IERC20(getBidsStorage().token).safeTransferFrom(_msgSender(), address(this), amount_); emit SessionOpened(_msgSender(), sessionId_, bid.provider); return sessionId_; } + + function _validateSession( + bytes32 bidId_, + uint256 amount_, + bool isDirectPaymentFromUser_, + bytes calldata approvalEncoded_, + bytes calldata signature_ + ) private view returns (uint128) { + if (!isBidActive(bidId_)) { + revert SessionBidNotFound(); + } + + Bid storage bid = getBidsStorage().bids[bidId_]; + if (!_isValidProviderReceipt(bid.provider, approvalEncoded_, signature_)) { + revert SessionProviderSignatureMismatch(); + } + if (getSessionsStorage().isProviderApprovalUsed[approvalEncoded_]) { + revert SessionDuplicateApproval(); + } + + uint128 endsAt_ = getSessionEnd(amount_, bid.pricePerSecond, uint128(block.timestamp)); + uint128 duration_ = endsAt_ - uint128(block.timestamp); + + if (duration_ < MIN_SESSION_DURATION) { + revert SessionTooShort(); + } + + // This situation cannot be achieved in theory, but just in case, I'll leave it at that + if (isDirectPaymentFromUser_ && (duration_ * bid.pricePerSecond) > amount_) { + revert SessionStakeTooLow(); + } + + return endsAt_; + } function getSessionId( address user_, @@ -125,8 +158,8 @@ contract SessionRouter is function getSessionEnd(uint256 amount_, uint256 pricePerSecond_, uint128 openedAt_) public view returns (uint128) { uint128 duration_ = uint128(stakeToStipend(amount_, openedAt_) / pricePerSecond_); - if (duration_ > MAX_SESSION_DURATION) { - duration_ = MAX_SESSION_DURATION; + if (duration_ > getSessionsStorage().maxSessionDuration) { + duration_ = getSessionsStorage().maxSessionDuration; } return openedAt_ + duration_; @@ -135,6 +168,7 @@ contract SessionRouter is /** * @dev Returns stipend of user based on their stake * (User session stake amount / MOR Supply without Compute) * (MOR Compute Supply / 100) + * (User share) * (Rewards for all computes) */ function stakeToStipend(uint256 amount_, uint128 timestamp_) public view returns (uint256) { uint256 totalMorSupply_ = totalMORSupply(timestamp_); @@ -182,41 +216,86 @@ contract SessionRouter is session.closeoutReceipt = receiptEncoded_; // TODO: Remove that field in favor of tps and ttftMs session.closedAt = uint128(block.timestamp); - //// PROVIDER REWARDS - uint128 startOfToday_ = startOfTheDay(uint128(block.timestamp)); - // The session should be closed the day after the end of the session to prevent provider rewards locking - bool isClosingLate_ = startOfToday_ > startOfTheDay(session.endsAt); bool noDispute_ = _isValidProviderReceipt(bid.provider, receiptEncoded_, signature_); - uint128 duration_; - if (noDispute_ || isClosingLate_) { - // Session was closed without dispute or next day after it expected to end - duration_ = uint128(session.endsAt.min(session.closedAt)) - session.openedAt; - } else { - // Session was closed on the same day or earlier with dispute - // withdraw all funds except for today's session cost - duration_ = startOfToday_ - uint128(session.openedAt.min(uint256(startOfToday_))); + _rewardUserAfterClose(session, bid); + _rewardProviderAfterClose(noDispute_, session, bid); + _setStats(noDispute_, ttftMs_, tpsScaled1000_, session, bid); + + emit SessionClosed(session.user, sessionId_, bid.provider); + } + + function _extractProviderReceipt(bytes calldata receiptEncoded_) private view returns (bytes32, uint32, uint32) { + (bytes32 sessionId_, uint256 chainId_, uint128 timestamp_, uint32 tpsScaled1000_, uint32 ttftMs_) = abi.decode( + receiptEncoded_, + (bytes32, uint256, uint128, uint32, uint32) + ); + + if (chainId_ != block.chainid) { + revert SesssionReceiptForAnotherChainId(); + } + if (block.timestamp > timestamp_ + SIGNATURE_TTL) { + revert SesssionReceiptExpired(); } - uint256 providerAmountToWithdraw_ = (duration_ * bid.pricePerSecond) - session.providerWithdrawnAmount; + + return (sessionId_, tpsScaled1000_, ttftMs_); + } + + function _getProviderRewards(Session storage session, Bid storage bid, bool isIncludeWithdrawnAmount_) private view returns (uint256) { + uint256 sessionEnd_ = session.closedAt == 0 ? session.endsAt : session.closedAt.min(session.endsAt); + if (block.timestamp < sessionEnd_) { + return 0; + } + + uint256 withdrawnAmount = isIncludeWithdrawnAmount_ ? session.providerWithdrawnAmount : 0; + + return (sessionEnd_ - session.openedAt) * bid.pricePerSecond - withdrawnAmount; + } + + function _rewardProviderAfterClose( + bool noDispute_, + Session storage session, + Bid storage bid + ) internal { + uint128 startOfToday_ = startOfTheDay(uint128(block.timestamp)); + bool isClosingLate_ = uint128(block.timestamp) > session.endsAt; + + uint256 providerAmountToWithdraw_ = _getProviderRewards(session, bid, true); + uint256 providerOnHoldAmount = 0; + if (!noDispute_ && !isClosingLate_) { + providerOnHoldAmount = (session.endsAt.min(session.closedAt) - startOfToday_.max(session.openedAt)) * bid.pricePerSecond; + } + providerAmountToWithdraw_ -= providerOnHoldAmount; + _claimForProvider(session, providerAmountToWithdraw_); - //// END + } + + function _rewardUserAfterClose(Session storage session, Bid storage bid) private { + uint128 startOfToday_ = startOfTheDay(uint128(block.timestamp)); + bool isClosingLate_ = uint128(block.timestamp) > session.endsAt; - //// USER REWARDS - // We have to lock today's stake so the user won't get the reward twice + uint256 userStakeToProvider = session.isDirectPaymentFromUser ? _getProviderRewards(session, bid, false) : 0; + uint256 userStake = session.stake - userStakeToProvider; uint256 userStakeToLock_ = 0; if (!isClosingLate_) { // Session was closed on the same day, lock today's stake uint256 userDuration_ = session.endsAt.min(session.closedAt) - session.openedAt.max(startOfToday_); uint256 userInitialLock_ = userDuration_ * bid.pricePerSecond; - userStakeToLock_ = session.stake.min(stipendToStake(userInitialLock_, startOfToday_)); + userStakeToLock_ = userStake.min(stipendToStake(userInitialLock_, startOfToday_)); getSessionsStorage().userStakesOnHold[session.user].push(OnHold(userStakeToLock_, uint128(startOfToday_ + 1 days))); } - uint256 userAmountToWithdraw_ = session.stake - userStakeToLock_; + uint256 userAmountToWithdraw_ = userStake - userStakeToLock_; IERC20(getBidsStorage().token).safeTransfer(session.user, userAmountToWithdraw_); - //// END + } - //// STATS + function _setStats( + bool noDispute_, + uint32 ttftMs_, + uint32 tpsScaled1000_, + Session storage session, + Bid storage bid + ) internal { ProviderModelStats storage prStats = providerModelStats(bid.modelId, bid.provider); ModelStats storage modelStats = modelStats(bid.modelId); @@ -245,25 +324,6 @@ contract SessionRouter is } else { session.closeoutType = 1; } - //// END - - emit SessionClosed(session.user, sessionId_, bid.provider); - } - - function _extractProviderReceipt(bytes calldata receiptEncoded_) private view returns (bytes32, uint32, uint32) { - (bytes32 sessionId_, uint256 chainId_, uint128 timestamp_, uint32 tpsScaled1000_, uint32 ttftMs_) = abi.decode( - receiptEncoded_, - (bytes32, uint256, uint128, uint32, uint32) - ); - - if (chainId_ != block.chainid) { - revert SesssionReceiptForAnotherChainId(); - } - if (block.timestamp > timestamp_ + SIGNATURE_TTL) { - revert SesssionReceiptExpired(); - } - - return (sessionId_, tpsScaled1000_, ttftMs_); } /** @@ -274,15 +334,7 @@ contract SessionRouter is Bid storage bid = getBidsStorage().bids[session.bidId]; _onlyAccount(bid.provider); - - uint256 sessionEnd_ = session.closedAt == 0 ? session.endsAt : session.closedAt; - if (sessionEnd_ > block.timestamp) { - revert SessionNotEndedOrNotExist(); - } - - uint256 amount_ = (sessionEnd_ - session.openedAt) * bid.pricePerSecond - session.providerWithdrawnAmount; - - _claimForProvider(session, amount_); + _claimForProvider(session, _getProviderRewards(session, bid, true)); } /** @@ -291,12 +343,8 @@ contract SessionRouter is * @param amount_ Amount of reward to send */ function _claimForProvider(Session storage session, uint256 amount_) private { - SessionsStorage storage sessionsStorage = getSessionsStorage(); - BidsStorage storage bidsStorage = getBidsStorage(); - PovidersStorage storage providersStorage = getProvidersStorage(); - - Bid storage bid = bidsStorage.bids[session.bidId]; - Provider storage provider = providersStorage.providers[bid.provider]; + Bid storage bid = getBidsStorage().bids[session.bidId]; + Provider storage provider = getProvidersStorage().providers[bid.provider]; if (block.timestamp > provider.limitPeriodEnd) { provider.limitPeriodEnd = uint128(block.timestamp) + PROVIDER_REWARD_LIMITER_PERIOD; @@ -306,16 +354,19 @@ contract SessionRouter is uint256 providerClaimLimit_ = provider.stake - provider.limitPeriodEarned; amount_ = amount_.min(providerClaimLimit_); - if (amount_ == 0) { return; } session.providerWithdrawnAmount += amount_; provider.limitPeriodEarned += amount_; - sessionsStorage.providersTotalClaimed += amount_; + getSessionsStorage().providersTotalClaimed += amount_; - IERC20(bidsStorage.token).safeTransferFrom(sessionsStorage.fundingAccount, bid.provider, amount_); + if (session.isDirectPaymentFromUser) { + IERC20(getBidsStorage().token).safeTransfer(bid.provider, amount_); + } else { + IERC20(getBidsStorage().token).safeTransferFrom(getSessionsStorage().fundingAccount, bid.provider, amount_); + } } /** @@ -347,22 +398,25 @@ contract SessionRouter is uint256 amount_ = 0; OnHold[] storage onHoldEntries = getSessionsStorage().userStakesOnHold[_msgSender()]; - uint8 i = iterations_ >= onHoldEntries.length ? uint8(onHoldEntries.length) : iterations_; - i--; + uint8 i_ = iterations_ >= onHoldEntries.length ? uint8(onHoldEntries.length) : iterations_; + if (i_ == 0) { + revert SessionUserAmountToWithdrawIsZero(); + } + i_--; - while (i >= 0) { - if (block.timestamp < onHoldEntries[i].releaseAt) { - if (i == 0) break; - i--; + while (i_ >= 0) { + if (block.timestamp < onHoldEntries[i_].releaseAt) { + if (i_ == 0) break; + i_--; continue; } - amount_ += onHoldEntries[i].amount; + amount_ += onHoldEntries[i_].amount; onHoldEntries.pop(); - if (i == 0) break; - i--; + if (i_ == 0) break; + i_--; } if (amount_ == 0) { diff --git a/smart-contracts/contracts/diamond/storages/ModelStorage.sol b/smart-contracts/contracts/diamond/storages/ModelStorage.sol index 344d3392..42b5f2f9 100644 --- a/smart-contracts/contracts/diamond/storages/ModelStorage.sol +++ b/smart-contracts/contracts/diamond/storages/ModelStorage.sol @@ -33,7 +33,7 @@ contract ModelStorage is IModelStorage { return getModelsStorage().modelMinimumStake; } - function getActiveModels(uint256 offset_, uint256 limit_) external view returns (bytes32[] memory) { + function getActiveModelIds(uint256 offset_, uint256 limit_) external view returns (bytes32[] memory) { return getModelsStorage().activeModels.part(offset_, limit_); } diff --git a/smart-contracts/contracts/diamond/storages/SessionStorage.sol b/smart-contracts/contracts/diamond/storages/SessionStorage.sol index 0823570e..fa9817cc 100644 --- a/smart-contracts/contracts/diamond/storages/SessionStorage.sol +++ b/smart-contracts/contracts/diamond/storages/SessionStorage.sol @@ -20,6 +20,8 @@ contract SessionStorage is ISessionStorage { // Used to generate unique session ID uint256 sessionNonce; mapping(bytes32 sessionId => Session) sessions; + // Max ession duration + uint128 maxSessionDuration; // Session registry for providers, users and models mapping(address user => EnumerableSet.Bytes32Set) userSessions; mapping(address provider => EnumerableSet.Bytes32Set) providerSessions; @@ -30,7 +32,6 @@ contract SessionStorage is ISessionStorage { bytes32 public constant SESSIONS_STORAGE_SLOT = keccak256("diamond.standard.sessions.storage"); uint32 public constant MIN_SESSION_DURATION = 5 minutes; - uint32 public constant MAX_SESSION_DURATION = 1 days; uint32 public constant SIGNATURE_TTL = 10 minutes; uint256 public constant COMPUTE_POOL_INDEX = 3; @@ -71,8 +72,8 @@ contract SessionStorage is ISessionStorage { return getSessionsStorage().fundingAccount; } - function getTotalSessions(address providerAddr_) public view returns (uint256) { - return getSessionsStorage().providerSessions[providerAddr_].length(); + function getTotalSessions(address provider_) public view returns (uint256) { + return getSessionsStorage().providerSessions[provider_].length(); } function getProvidersTotalClaimed() external view returns (uint256) { @@ -82,68 +83,10 @@ contract SessionStorage is ISessionStorage { function getIsProviderApprovalUsed(bytes memory approval_) external view returns (bool) { return getSessionsStorage().isProviderApprovalUsed[approval_]; } - - // /** INTERNAL, GETTERS */ - // function pools() internal view returns (Pool[] storage) { - // return _getSessionStorage().pools; - // } - - // function pool(uint256 poolIndex_) internal view returns (Pool storage) { - // return _getSessionStorage().pools[poolIndex_]; - // } - - // function userStakesOnHold(address user_) internal view returns (OnHold[] storage) { - // return _getSessionStorage().userStakesOnHold[user_]; - // } - - // function sessions(bytes32 sessionId_) internal view returns (Session storage) { - // return _getSessionStorage().sessions[sessionId_]; - // } - - /** INTERNAL, SETTERS */ - // function setFundingAccount(address fundingAccount_) internal { - // _getSessionStorage().fundingAccount = fundingAccount_; - // } - - // function setPools(Pool[] calldata pools_) internal { - // SNStorage storage s = _getSessionStorage(); - - // for (uint256 i = 0; i < pools_.length; i++) { - // s.pools.push(pools_[i]); - // } - // } - - // function setPool(uint256 index_, Pool calldata pool_) internal { - // _getSessionStorage().pools[index_] = pool_; - // } - - // function addUserSessionId(address user_, bytes32 sessionId_) internal { - // _getSessionStorage().userSessions[user_].add(sessionId_); - // } - - // function addProviderSessionId(address provider_, bytes32 sessionId_) internal { - // _getSessionStorage().providerSessions[provider_].add(sessionId_); - // } - - // function addModelSessionId(bytes32 modelId, bytes32 sessionId) internal { - // _getSessionStorage().modelSessions[modelId].add(sessionId); - // } - - // function addUserStakeOnHold(address user, OnHold memory onHold) internal { - // _getSessionStorage().userStakesOnHold[user].push(onHold); - // } - - // function increaseProvidersTotalClaimed(uint256 amount) internal { - // _getSessionStorage().providersTotalClaimed += amount; - // } - - // function incrementSessionNonce() internal returns (uint256) { - // return _getSessionStorage().sessionNonce++; - // } - - // function setIsProviderApprovalUsed(bytes memory approval_, bool isUsed_) internal { - // _getSessionStorage().isProviderApprovalUsed[approval_] = isUsed_; - // } + + function getMaxSessionDuration() external view returns (uint128) { + return getSessionsStorage().maxSessionDuration; + } /** INTERNAL */ function getSessionsStorage() internal pure returns (SessionsStorage storage ds) { diff --git a/smart-contracts/contracts/interfaces/facets/IMarketplace.sol b/smart-contracts/contracts/interfaces/facets/IMarketplace.sol index ab4ab154..e24723d6 100644 --- a/smart-contracts/contracts/interfaces/facets/IMarketplace.sol +++ b/smart-contracts/contracts/interfaces/facets/IMarketplace.sol @@ -4,25 +4,57 @@ pragma solidity ^0.8.24; import {IMarketplaceStorage} from "../storage/IMarketplaceStorage.sol"; interface IMarketplace is IMarketplaceStorage { + event MaretplaceFeeUpdated(uint256 bidFee); event MarketplaceBidPosted(address indexed provider, bytes32 indexed modelId, uint256 nonce); event MarketplaceBidDeleted(address indexed provider, bytes32 indexed modelId, uint256 nonce); - event MaretplaceFeeUpdated(uint256 bidFee); - error MarketplaceProviderNotFound(); error MarketplaceModelNotFound(); error MarketplaceActiveBidNotFound(); + /** + * The function to initialize the facet. + * @param token_ Stake token (MOR) + */ function __Marketplace_init(address token_) external; + /** + * The function to set the bidFee. + * @param bidFee_ Amount of tokens + */ function setMarketplaceBidFee(uint256 bidFee_) external; + /** + * The function to create the bid. + * @param modelId_ The mode ID + * @param pricePerSecond_ The price per second + */ function postModelBid(bytes32 modelId_, uint256 pricePerSecond_) external returns (bytes32); + /** + * The function to delete the bid. + * @param bidId_ The bid ID + */ function deleteModelBid(bytes32 bidId_) external; + /** + * The function to withdraw the stake amount. + * @param recipient_ The recipient address. + * @param amount_ The amount. + */ function withdraw(address recipient_, uint256 amount_) external; + /** + * The function to get bid ID. + * @param provider_ The provider address. + * @param modelId_ The model ID. + * @param nonce_ The nonce. + */ function getBidId(address provider_, bytes32 modelId_, uint256 nonce_) external view returns (bytes32); + /** + * The function to returns provider model ID + * @param provider_ The provider address. + * @param modelId_ The model ID. + */ function getProviderModelId(address provider_, bytes32 modelId_) external view returns (bytes32); } diff --git a/smart-contracts/contracts/interfaces/facets/IModelRegistry.sol b/smart-contracts/contracts/interfaces/facets/IModelRegistry.sol index 5f2e1f53..c842a04d 100644 --- a/smart-contracts/contracts/interfaces/facets/IModelRegistry.sol +++ b/smart-contracts/contracts/interfaces/facets/IModelRegistry.sol @@ -12,10 +12,26 @@ interface IModelRegistry is IModelStorage { error ModelNotFound(); error ModelHasActiveBids(); + /** + * The function to initialize the facet. + */ function __ModelRegistry_init() external; + /** + * The function to set the minimal stake for models. + * @param modelMinimumStake_ Amount of tokens + */ function modelSetMinStake(uint256 modelMinimumStake_) external; + /** + * The function to register the model. + * @param modelId_ The model ID. + * @param ipfsCID_ The model IPFS CID. + * @param fee_ The model fee. + * @param amount_ The model stake amount. + * @param name_ The model name. + * @param tags_ The model tags. + */ function modelRegister( bytes32 modelId_, bytes32 ipfsCID_, @@ -25,5 +41,9 @@ interface IModelRegistry is IModelStorage { string[] calldata tags_ ) external; + /** + * The function to deregister the model. + * @param modelId_ The model ID. + */ function modelDeregister(bytes32 modelId_) external; } diff --git a/smart-contracts/contracts/interfaces/facets/IProviderRegistry.sol b/smart-contracts/contracts/interfaces/facets/IProviderRegistry.sol index 20cf7d84..3c05429d 100644 --- a/smart-contracts/contracts/interfaces/facets/IProviderRegistry.sol +++ b/smart-contracts/contracts/interfaces/facets/IProviderRegistry.sol @@ -16,23 +16,26 @@ interface IProviderRegistry is IProviderStorage { error ProviderNotFound(); error ProviderHasAlreadyDeregistered(); + /** + * The function to initialize the facet. + */ function __ProviderRegistry_init() external; /** - * @notice Sets the minimum stake required for a provider + * @notice The function to the minimum stake required for a provider * @param providerMinimumStake_ The minimal stake */ function providerSetMinStake(uint256 providerMinimumStake_) external; /** - * @notice Register a provider + * @notice The function to register the provider * @param amount_ The amount of stake to add * @param endpoint_ The provider endpoint (host.com:1234) */ function providerRegister(uint256 amount_, string calldata endpoint_) external; /** - * @notice Deregister a provider + * @notice The function to deregister the provider */ function providerDeregister() external; } diff --git a/smart-contracts/contracts/interfaces/facets/ISessionRouter.sol b/smart-contracts/contracts/interfaces/facets/ISessionRouter.sol index 9a342996..48bed3dc 100644 --- a/smart-contracts/contracts/interfaces/facets/ISessionRouter.sol +++ b/smart-contracts/contracts/interfaces/facets/ISessionRouter.sol @@ -13,7 +13,7 @@ interface ISessionRouter is ISessionStorage { error SesssionApproveExpired(); error SesssionApprovedForAnotherChainId(); error SessionDuplicateApproval(); - error SessionApprovedForAnotherUser(); // Means that approval generated for another user address, protection from front-running + error SessionApprovedForAnotherUser(); error SesssionReceiptForAnotherChainId(); error SesssionReceiptExpired(); error SessionTooShort(); @@ -23,23 +23,54 @@ interface ISessionRouter is ISessionStorage { error SessionBidNotFound(); error SessionPoolIndexOutOfBounds(); error SessionUserAmountToWithdrawIsZero(); + error SessionMaxDurationTooShort(); + error SessionStakeTooLow(); - function __SessionRouter_init(address fundingAccount_, Pool[] calldata pools_) external; + /** + * The function to initialize the facet. + * @param fundingAccount_ The funding address (treaasury) + * @param maxSessionDuration_ The max session duration + * @param pools_ The pools data + */ + function __SessionRouter_init(address fundingAccount_, uint128 maxSessionDuration_, Pool[] calldata pools_) external; /** * @notice Sets distibution pool configuration * @dev parameters should be the same as in Ethereum L1 Distribution contract * @dev at address 0x47176B2Af9885dC6C4575d4eFd63895f7Aaa4790 - * @dev call 'Distribution.pools(3)' where '3' is a poolId + * @dev call 'Distribution.pools(3)' where '3' is a poolId. + * @param index_ The pool index. + * @param pool_ The pool data. */ function setPoolConfig(uint256 index_, Pool calldata pool_) external; + /** + * The function to set the max session duration. + * @param maxSessionDuration_ The max session duration. + */ + function setMaxSessionDuration(uint128 maxSessionDuration_) external; + + /** + * The function to open the session. + * @param amount_ The stake amount. + * @param isDirectPaymentFromUser_ If active, provider rewarded from the user stake. + * @param approvalEncoded_ Provider approval. + * @param signature_ Provider signature. + */ function openSession( uint256 amount_, + bool isDirectPaymentFromUser_, bytes calldata approvalEncoded_, bytes calldata signature_ ) external returns (bytes32); + /** + * The function to get session ID/ + * @param user_ The user address. + * @param provider_ The provider address. + * @param bidId_ The bid ID. + * @param sessionNonce_ The session nounce. + */ function getSessionId( address user_, address provider_, @@ -47,48 +78,82 @@ interface ISessionRouter is ISessionStorage { uint256 sessionNonce_ ) external pure returns (bytes32); + /** + * The function to returns the session end timestamp + * @param amount_ The stake amount. + * @param pricePerSecond_ The price per second. + * @param openedAt_ The opened at timestamp. + */ function getSessionEnd(uint256 amount_, uint256 pricePerSecond_, uint128 openedAt_) external view returns (uint128); /** - * @dev Returns stipend of user based on their stake + * Returns stipend of user based on their stake * (User session stake amount / MOR Supply without Compute) * (MOR Compute Supply / 100) + * @param amount_ The amount of tokens. + * @param timestamp_ The timestamp when the TX executes. */ function stakeToStipend(uint256 amount_, uint128 timestamp_) external view returns (uint256); + /** + * The function to close session. + * @param receiptEncoded_ Provider receipt + * @param signature_ Provider signature + */ function closeSession(bytes calldata receiptEncoded_, bytes calldata signature_) external; /** - * @dev Allows providers to receive their funds after the end or closure of the session + * Allows providers to receive their funds after the end or closure of the session. + * @param sessionId_ The session ID. */ function claimForProvider(bytes32 sessionId_) external; /** - * @notice Returns stake of user based on their stipend + * Returns stake of user based on their stipend. + * @param stipend_ The stake amount. + * @param timestamp_ The timestamp when the TX executed. */ function stipendToStake(uint256 stipend_, uint128 timestamp_) external view returns (uint256); + /** + * The function to return available and locked amount of the user tokens. + * @param user_ The user address. + * @param iterations_ The loop interaction amount. + * @return available_ The available to withdraw. + * @return hold_ The locked amount. + */ function getUserStakesOnHold( address user_, uint8 iterations_ ) external view returns (uint256 available_, uint256 hold_); + /** + * The function to withdraw user stakes. + * @param iterations_ The loop interaction amount. + */ function withdrawUserStakes(uint8 iterations_) external; /** - * @dev Returns today's budget in MOR. 1% + * Returns today's budget in MOR. 1%. + * @param timestamp_ The timestamp when the TX executed. */ function getTodaysBudget(uint128 timestamp_) external view returns (uint256); /** - * @dev Returns today's compute balance in MOR without claimed amount + * Returns today's compute balance in MOR without claimed amount. + * @param timestamp_ The timestamp when the TX executed. */ function getComputeBalance(uint128 timestamp_) external view returns (uint256); /** - * @dev Total amount of MOR tokens that were distributed across all pools - * without compute pool rewards and with compute claimed rewards + * Total amount of MOR tokens that were distributed across all pools + * without compute pool rewards and with compute claimed rewards. + * @param timestamp_ The timestamp when the TX executed. */ function totalMORSupply(uint128 timestamp_) external view returns (uint256); + /** + * The function to return the timestamp on the start of day + * @param timestamp_ The timestamp when the TX executed. + */ function startOfTheDay(uint128 timestamp_) external pure returns (uint128); } diff --git a/smart-contracts/contracts/interfaces/storage/IBidStorage.sol b/smart-contracts/contracts/interfaces/storage/IBidStorage.sol index 3815f663..559d100d 100644 --- a/smart-contracts/contracts/interfaces/storage/IBidStorage.sol +++ b/smart-contracts/contracts/interfaces/storage/IBidStorage.sol @@ -2,38 +2,82 @@ pragma solidity ^0.8.24; interface IBidStorage { + /** + * The structure that stores the bid data. + * @param provider The provider addres. + * @param modelId The model ID. + * @param pricePerSecond The price per second. + * @param nonce The bid creates with this nounce (related to provider nounce). + * @param createdAt The timestamp when the bid is created. + * @param deletedAt The timestamp when the bid is deleted. + */ struct Bid { address provider; bytes32 modelId; - uint256 pricePerSecond; // Hourly price + uint256 pricePerSecond; uint256 nonce; uint128 createdAt; uint128 deletedAt; } - + + /** + * The function returns the bid structure. + * @param bidId_ Bid ID. + */ function getBid(bytes32 bidId_) external view returns (Bid memory); + /** + * The function returns active provider bids. + * @param provider_ Provider address. + * @param offset_ Offset for the pagination. + * @param limit_ Number of entities to return. + */ function getProviderActiveBids( address provider_, uint256 offset_, uint256 limit_ ) external view returns (bytes32[] memory); + /** + * The function returns active model bids. + * @param modelId_ Model ID. + * @param offset_ Offset for the pagination. + * @param limit_ Number of entities to return. + */ function getModelActiveBids( bytes32 modelId_, uint256 offset_, uint256 limit_ ) external view returns (bytes32[] memory); + /** + * The function returns provider bids. + * @param provider_ Provider address. + * @param offset_ Offset for the pagination. + * @param limit_ Number of entities to return. + */ function getProviderBids( address provider_, uint256 offset_, uint256 limit_ ) external view returns (bytes32[] memory); + /** + * The function returns model bids. + * @param modelId_ Model ID. + * @param offset_ Offset for the pagination. + * @param limit_ Number of entities to return. + */ function getModelBids(bytes32 modelId_, uint256 offset_, uint256 limit_) external view returns (bytes32[] memory); + /** + * The function returns stake token (MOR). + */ function getToken() external view returns (address); + /** + * The function returns bid status, active or not. + * @param bidId_ Bid ID. + */ function isBidActive(bytes32 bidId_) external view returns (bool); } diff --git a/smart-contracts/contracts/interfaces/storage/IMarketplaceStorage.sol b/smart-contracts/contracts/interfaces/storage/IMarketplaceStorage.sol index 197000ee..36ca7c70 100644 --- a/smart-contracts/contracts/interfaces/storage/IMarketplaceStorage.sol +++ b/smart-contracts/contracts/interfaces/storage/IMarketplaceStorage.sol @@ -2,7 +2,13 @@ pragma solidity ^0.8.24; interface IMarketplaceStorage { + /** + * The function returns bid fee on creation. + */ function getBidFee() external view returns (uint256); + /** + * The function returns fee balance. + */ function getFeeBalance() external view returns (uint256); } diff --git a/smart-contracts/contracts/interfaces/storage/IModelStorage.sol b/smart-contracts/contracts/interfaces/storage/IModelStorage.sol index 92085d3a..64529112 100644 --- a/smart-contracts/contracts/interfaces/storage/IModelStorage.sol +++ b/smart-contracts/contracts/interfaces/storage/IModelStorage.sol @@ -2,8 +2,19 @@ pragma solidity ^0.8.24; interface IModelStorage { + /** + * The structure that stores the model data. + * @param ipfsCID https://docs.ipfs.tech/concepts/content-addressing/#what-is-a-cid. Up to the model maintainer to keep up to date. + * @param fee The model fee. Readonly for now. + * @param stake The stake amount. + * @param owner The owner. + * @param name The name. Readonly for now. + * @param tags The tags. Readonly for now. + * @param createdAt The timestamp when the model is created. + * @param isDeleted The model status. + */ struct Model { - bytes32 ipfsCID; // https://docs.ipfs.tech/concepts/content-addressing/#what-is-a-cid. Up to the model maintainer to keep up to date + bytes32 ipfsCID; uint256 fee; // The fee is a royalty placeholder that isn't currently used uint256 stake; address owner; @@ -13,13 +24,34 @@ interface IModelStorage { bool isDeleted; } + /** + * The function returns the model structure. + * @param modelId_ Model ID. + */ function getModel(bytes32 modelId_) external view returns (Model memory); + /** + * The function returns the model IDs. + * @param offset_ Offset for the pagination. + * @param limit_ Number of entities to return. + */ function getModelIds(uint256 offset_, uint256 limit_) external view returns (bytes32[] memory); + /** + * The function returns the model minimal stake. + */ function getModelMinimumStake() external view returns (uint256); - function getActiveModels(uint256 offset_, uint256 limit_) external view returns (bytes32[] memory); + /** + * The function returns active model IDs. + * @param offset_ Offset for the pagination. + * @param limit_ Number of entities to return. + */ + function getActiveModelIds(uint256 offset_, uint256 limit_) external view returns (bytes32[] memory); + /** + * The function returns the model status, active or not. + * @param modelId_ Model ID. + */ function getIsModelActive(bytes32 modelId_) external view returns (bool); } diff --git a/smart-contracts/contracts/interfaces/storage/IProviderStorage.sol b/smart-contracts/contracts/interfaces/storage/IProviderStorage.sol index 3121101f..74fc8e1d 100644 --- a/smart-contracts/contracts/interfaces/storage/IProviderStorage.sol +++ b/smart-contracts/contracts/interfaces/storage/IProviderStorage.sol @@ -2,20 +2,45 @@ pragma solidity ^0.8.24; interface IProviderStorage { + /** + * The structure that stores the provider data. + * @param endpoint Example 'domain.com:1234'. Readonly for now. + * @param stake The stake amount. + * @param createdAt The timestamp when the provider is created. + * @param limitPeriodEnd Timestamp that indicate limit period end for provider rewards. + * @param limitPeriodEarned The amount of tokens that provider can receive before `limitPeriodEnd`. + * @param isDeleted The provider status. + */ struct Provider { - string endpoint; // Example 'domain.com:1234' - uint256 stake; // Stake amount, which also server as a reward limiter - uint128 createdAt; // Timestamp of the registration - uint128 limitPeriodEnd; // Timestamp of the limiter period end - uint256 limitPeriodEarned; // Total earned during the last limiter period + string endpoint; + uint256 stake; + uint128 createdAt; + uint128 limitPeriodEnd; + uint256 limitPeriodEarned; bool isDeleted; } + /** + * The function returns provider structure. + * @param provider_ Provider address + */ function getProvider(address provider_) external view returns (Provider memory); + /** + * The function returns provider minimal stake. + */ function getProviderMinimumStake() external view returns (uint256); + /** + * The function returns list of active providers. + * @param offset_ Offset for the pagination. + * @param limit_ Number of entities to return. + */ function getActiveProviders(uint256 offset_, uint256 limit_) external view returns (address[] memory); + /** + * The function returns provider status, active or not. + * @param provider_ Provider address + */ function getIsProviderActive(address provider_) external view returns (bool); } diff --git a/smart-contracts/contracts/interfaces/storage/ISessionStorage.sol b/smart-contracts/contracts/interfaces/storage/ISessionStorage.sol index 62810691..b2eb97ed 100644 --- a/smart-contracts/contracts/interfaces/storage/ISessionStorage.sol +++ b/smart-contracts/contracts/interfaces/storage/ISessionStorage.sol @@ -2,28 +2,48 @@ pragma solidity ^0.8.24; interface ISessionStorage { + /** + * The structure that stores the bid data. + * @param user The user address. User opens the session + * @param bidId The bid ID. + * @param stake The stake amount. + * @param closeoutReceipt The receipt from provider when session closed. + * @param closeoutType The closeout type. + * @param providerWithdrawnAmount Provider withdrawn amount fot this session. + * @param openedAt The timestamp when the session is opened. Setted on creation. + * @param endsAt The timestamp when the session is ends. Setted on creation. + * @param closedAt The timestamp when the session is closed. Setted on close. + * @param isActive The session status. + * @param isDirectPaymentFromUser If active, user pay for provider from his stake. + */ struct Session { address user; bytes32 bidId; uint256 stake; bytes closeoutReceipt; - // TODO: Use enum? uint256 closeoutType; - // Amount of funds that was already withdrawn by provider (we allow to withdraw for the previous day) uint256 providerWithdrawnAmount; uint128 openedAt; - // Expected end time considering the stake provided uint128 endsAt; uint128 closedAt; bool isActive; + bool isDirectPaymentFromUser; } + /** + * The structure that stores information about locked user funds. + * @param amount The locked amount. + * @param releaseAt The timestampl when funds will be available. + */ struct OnHold { uint256 amount; // In epoch seconds. TODO: consider using hours to reduce storage cost uint128 releaseAt; } + /** + * The structure that stores the Pool data. Should be the same with 0x47176B2Af9885dC6C4575d4eFd63895f7Aaa4790 on the Eth mainnet + */ struct Pool { uint256 initialReward; uint256 rewardDecrease; @@ -31,31 +51,79 @@ interface ISessionStorage { uint128 decreaseInterval; } + /** + * The function returns the session structure. + * @param sessionId_ Session ID + */ function getSession(bytes32 sessionId_) external view returns (Session memory); - function getUserSessions(address user, uint256 offset_, uint256 limit_) external view returns (bytes32[] memory); + /** + * The function returns the user session IDs. + * @param user_ The user address + * @param offset_ Offset for the pagination. + * @param limit_ Number of entities to return. + */ + function getUserSessions(address user_, uint256 offset_, uint256 limit_) external view returns (bytes32[] memory); + /** + * The function returns the provider session IDs. + * @param provider_ The provider address + * @param offset_ Offset for the pagination. + * @param limit_ Number of entities to return. + */ function getProviderSessions( address provider_, uint256 offset_, uint256 limit_ ) external view returns (bytes32[] memory); + /** + * The function returns the model session IDs. + * @param modelId_ The model ID + * @param offset_ Offset for the pagination. + * @param limit_ Number of entities to return. + */ function getModelSessions( bytes32 modelId_, uint256 offset_, uint256 limit_ ) external view returns (bytes32[] memory); + /** + * The function returns the pools info. + */ function getPools() external view returns (Pool[] memory); + /** + * The function returns the pools info. + * @param index_ Pool index + */ function getPool(uint256 index_) external view returns (Pool memory); + /** + * The function returns the funcding (treasury) address for providers payments. + */ function getFundingAccount() external view returns (address); - function getTotalSessions(address providerAddr_) external view returns (uint256); + /** + * The function returns total amount of sessions for the provider. + * @param provider_ Provider address + */ + function getTotalSessions(address provider_) external view returns (uint256); + /** + * The function returns total amount of claimed token by providers. + */ function getProvidersTotalClaimed() external view returns (uint256); + /** + * Check the approval for usage. + * @param approval_ Approval from provider + */ function getIsProviderApprovalUsed(bytes memory approval_) external view returns (bool); + + /** + * The function returns max session duration. + */ + function getMaxSessionDuration() external view returns (uint128); } diff --git a/smart-contracts/contracts/interfaces/storage/IStatsStorage.sol b/smart-contracts/contracts/interfaces/storage/IStatsStorage.sol index 1eeb6bde..6d2b15d0 100644 --- a/smart-contracts/contracts/interfaces/storage/IStatsStorage.sol +++ b/smart-contracts/contracts/interfaces/storage/IStatsStorage.sol @@ -11,19 +11,34 @@ interface IStatsStorage { uint32 count; } + /** + * The structure that stores the provider model stats. + * @param tpsScaled1000 Tokens per second running average + * @param ttftMs Time to first token running average in milliseconds + * @param totalDuration Total duration of sessions + * @param successCount Number of observations + * @param totalCount + */ struct ProviderModelStats { - LibSD.SD tpsScaled1000; // Tokens per second running average - LibSD.SD ttftMs; // Time to first token running average in milliseconds - uint32 totalDuration; // Total duration of sessions - uint32 successCount; // Number of observations + LibSD.SD tpsScaled1000; + LibSD.SD ttftMs; + uint32 totalDuration; + uint32 successCount; uint32 totalCount; // TODO: consider adding SD with weldford algorithm } + /** + * @param modelId_ The model ID. + * @param provider_ The provider address. + */ function getProviderModelStats( bytes32 modelId_, address provider_ ) external view returns (ProviderModelStats memory); + /** + * @param modelId_ The model ID. + */ function getModelStats(bytes32 modelId_) external view returns (ModelStats memory); } diff --git a/smart-contracts/deploy/1_full_protocol.migration.ts b/smart-contracts/deploy/1_full_protocol.migration.ts index 01d77cc9..55bb994d 100644 --- a/smart-contracts/deploy/1_full_protocol.migration.ts +++ b/smart-contracts/deploy/1_full_protocol.migration.ts @@ -22,6 +22,7 @@ import { SessionRouter__factory, } from '@/generated-types/ethers'; import { FacetAction } from '@/test/helpers/deployers'; +import { DAY } from '@/utils/time'; module.exports = async function (deployer: Deployer) { const config = parseConfig(); @@ -92,7 +93,7 @@ module.exports = async function (deployer: Deployer) { marketplaceFacet = marketplaceFacet.attach(lumerinDiamond.target) as Marketplace; await marketplaceFacet.__Marketplace_init(config.MOR); sessionRouterFacet = sessionRouterFacet.attach(lumerinDiamond.target) as SessionRouter; - await sessionRouterFacet.__SessionRouter_init(config.fundingAccount, config.pools); + await sessionRouterFacet.__SessionRouter_init(config.fundingAccount, 7 * DAY, config.pools); await providerRegistryFacet.providerSetMinStake(config.providerMinStake); await modelRegistryFacet.modelSetMinStake(config.modelMinStake); diff --git a/smart-contracts/test/diamond/facets/ModelRegistry.test.ts b/smart-contracts/test/diamond/facets/ModelRegistry.test.ts index a4ea9b17..01fe0f26 100644 --- a/smart-contracts/test/diamond/facets/ModelRegistry.test.ts +++ b/smart-contracts/test/diamond/facets/ModelRegistry.test.ts @@ -124,7 +124,7 @@ describe('ModelRegistry', () => { expect(await token.balanceOf(modelRegistry)).to.eq(wei(100)); expect(await token.balanceOf(SECOND)).to.eq(wei(900)); - expect(await modelRegistry.getActiveModels(0, 10)).to.deep.eq([modelId]); + expect(await modelRegistry.getActiveModelIds(0, 10)).to.deep.eq([modelId]); await modelRegistry.connect(SECOND).modelRegister(modelId, ipfsCID, 0, wei(0), 'name', ['tag_1']); }); @@ -196,7 +196,7 @@ describe('ModelRegistry', () => { expect(await token.balanceOf(modelRegistry)).to.eq(0); expect(await token.balanceOf(SECOND)).to.eq(wei(1000)); - expect(await modelRegistry.getActiveModels(0, 10)).to.deep.eq([]); + expect(await modelRegistry.getActiveModelIds(0, 10)).to.deep.eq([]); }); it('should throw error when the caller is not an owner or specified address', async () => { await expect(modelRegistry.connect(SECOND).modelDeregister(modelId)).to.be.revertedWithCustomError( diff --git a/smart-contracts/test/diamond/facets/SessionRouter.test.ts b/smart-contracts/test/diamond/facets/SessionRouter.test.ts index 366ccda5..d5b6f0da 100644 --- a/smart-contracts/test/diamond/facets/SessionRouter.test.ts +++ b/smart-contracts/test/diamond/facets/SessionRouter.test.ts @@ -77,7 +77,7 @@ describe('SessionRouter', () => { expect((await sessionRouter.getPools()).length).to.eq(5); }); it('should revert if try to call init function twice', async () => { - await expect(sessionRouter.__SessionRouter_init(FUNDING, [])).to.be.rejectedWith( + await expect(sessionRouter.__SessionRouter_init(FUNDING, DAY, [])).to.be.rejectedWith( 'Initializable: contract is already initialized', ); }); @@ -127,6 +127,28 @@ describe('SessionRouter', () => { }); }); + describe('#setMaxSessionDuration', () => { + it('should set max session duration', async () => { + await sessionRouter.setMaxSessionDuration(7 * DAY); + expect(await sessionRouter.getMaxSessionDuration()).to.eq(7 * DAY); + + await sessionRouter.setMaxSessionDuration(8 * DAY); + expect(await sessionRouter.getMaxSessionDuration()).to.eq(8 * DAY); + }); + it('should throw error when max session duration too low', async () => { + await expect(sessionRouter.setMaxSessionDuration(1)).to.be.revertedWithCustomError( + sessionRouter, + 'SessionMaxDurationTooShort', + ); + }); + it('should throw error when the caller is invalid', async () => { + await expect(sessionRouter.connect(SECOND).setMaxSessionDuration(1)).to.be.revertedWithCustomError( + sessionRouter, + 'OwnableUnauthorizedAccount', + ); + }); + }); + describe('#openSession', () => { let tokenBalBefore = 0n; let secondBalBefore = 0n; @@ -138,7 +160,7 @@ describe('SessionRouter', () => { it('should open session', async () => { await setTime(payoutStart + 10 * DAY); const { msg, signature } = await getProviderApproval(PROVIDER, SECOND, bidId); - await sessionRouter.connect(SECOND).openSession(wei(50), msg, signature); + await sessionRouter.connect(SECOND).openSession(wei(50), false, msg, signature); const sessionId = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 0); const data = await sessionRouter.getSession(sessionId); @@ -152,6 +174,7 @@ describe('SessionRouter', () => { expect(data.endsAt).to.greaterThan(data.openedAt); expect(data.closedAt).to.eq(0); expect(data.isActive).to.eq(true); + expect(data.isDirectPaymentFromUser).to.eq(false); const tokenBalAfter = await token.balanceOf(sessionRouter); expect(tokenBalAfter - tokenBalBefore).to.eq(wei(50)); @@ -168,8 +191,8 @@ describe('SessionRouter', () => { const { msg: msg1, signature: signature1 } = await getProviderApproval(PROVIDER, SECOND, bidId); await setTime(payoutStart + 10 * DAY + 1); const { msg: msg2, signature: signature2 } = await getProviderApproval(PROVIDER, SECOND, bidId); - await sessionRouter.connect(SECOND).openSession(wei(50), msg1, signature1); - await sessionRouter.connect(SECOND).openSession(wei(50), msg2, signature2); + await sessionRouter.connect(SECOND).openSession(wei(50), false, msg1, signature1); + await sessionRouter.connect(SECOND).openSession(wei(50), false, msg2, signature2); const sessionId1 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 0); const sessionId2 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 1); @@ -190,69 +213,91 @@ describe('SessionRouter', () => { it('should open session with max duration', async () => { await setTime(payoutStart + 10 * DAY); const { msg, signature } = await getProviderApproval(PROVIDER, SECOND, bidId); - await sessionRouter.connect(SECOND).openSession(wei(10000), msg, signature); + await sessionRouter.connect(SECOND).openSession(wei(10000), false, msg, signature); const sessionId = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 0); const data = await sessionRouter.getSession(sessionId); expect(data.endsAt).to.eq(Number(data.openedAt.toString()) + DAY); }); + it('should open session with valid amount for direct user payment', async () => { + await setTime(payoutStart + 10 * DAY); + const { msg, signature } = await getProviderApproval(PROVIDER, SECOND, bidId); + await sessionRouter.connect(SECOND).openSession(wei(50), true, msg, signature); + + const sessionId = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 0); + const data = await sessionRouter.getSession(sessionId); + expect(data.user).to.eq(SECOND); + expect(data.bidId).to.eq(bidId); + expect(data.stake).to.eq(wei(50)); + expect(data.closeoutReceipt).to.eq('0x'); + expect(data.closeoutType).to.eq(0); + expect(data.providerWithdrawnAmount).to.eq(0); + expect(data.openedAt).to.eq(payoutStart + 10 * DAY + 1); + expect(data.endsAt).to.greaterThan(data.openedAt); + expect(data.closedAt).to.eq(0); + expect(data.isActive).to.eq(true); + expect(data.isDirectPaymentFromUser).to.eq(true); + + const tokenBalAfter = await token.balanceOf(sessionRouter); + expect(tokenBalAfter - tokenBalBefore).to.eq(wei(50)); + const secondBalAfter = await token.balanceOf(SECOND); + expect(secondBalBefore - secondBalAfter).to.eq(wei(50)); + + expect(await sessionRouter.getIsProviderApprovalUsed(msg)).to.eq(true); + expect(await sessionRouter.getUserSessions(SECOND, 0, 10)).to.deep.eq([sessionId]); + expect(await sessionRouter.getProviderSessions(PROVIDER, 0, 10)).to.deep.eq([sessionId]); + expect(await sessionRouter.getModelSessions(modelId, 0, 10)).to.deep.eq([sessionId]); + }); it('should throw error when the approval is for an another user', async () => { const { msg, signature } = await getProviderApproval(PROVIDER, OWNER, bidId); - await expect(sessionRouter.connect(SECOND).openSession(wei(50), msg, signature)).to.be.revertedWithCustomError( - sessionRouter, - 'SessionApprovedForAnotherUser', - ); + await expect( + sessionRouter.connect(SECOND).openSession(wei(50), false, msg, signature), + ).to.be.revertedWithCustomError(sessionRouter, 'SessionApprovedForAnotherUser'); }); it('should throw error when the approval is for an another chain', async () => { const { msg, signature } = await getProviderApproval(PROVIDER, SECOND, bidId, 1n); - await expect(sessionRouter.connect(SECOND).openSession(wei(50), msg, signature)).to.be.revertedWithCustomError( - sessionRouter, - 'SesssionApprovedForAnotherChainId', - ); + await expect( + sessionRouter.connect(SECOND).openSession(wei(50), false, msg, signature), + ).to.be.revertedWithCustomError(sessionRouter, 'SesssionApprovedForAnotherChainId'); }); it('should throw error when an aprrove expired', async () => { await setTime(payoutStart); const { msg, signature } = await getProviderApproval(PROVIDER, SECOND, bidId); await setTime(payoutStart + 600); - await expect(sessionRouter.connect(SECOND).openSession(wei(50), msg, signature)).to.be.revertedWithCustomError( - sessionRouter, - 'SesssionApproveExpired', - ); + await expect( + sessionRouter.connect(SECOND).openSession(wei(50), false, msg, signature), + ).to.be.revertedWithCustomError(sessionRouter, 'SesssionApproveExpired'); }); it('should throw error when the bid is not found', async () => { const { msg, signature } = await getProviderApproval(PROVIDER, SECOND, getHex(Buffer.from('1'))); - await expect(sessionRouter.connect(SECOND).openSession(wei(50), msg, signature)).to.be.revertedWithCustomError( - sessionRouter, - 'SessionBidNotFound', - ); + await expect( + sessionRouter.connect(SECOND).openSession(wei(50), false, msg, signature), + ).to.be.revertedWithCustomError(sessionRouter, 'SessionBidNotFound'); }); it('should throw error when the signature mismatch', async () => { const { msg, signature } = await getProviderApproval(OWNER, SECOND, bidId); - await expect(sessionRouter.connect(SECOND).openSession(wei(50), msg, signature)).to.be.revertedWithCustomError( - sessionRouter, - 'SessionProviderSignatureMismatch', - ); + await expect( + sessionRouter.connect(SECOND).openSession(wei(50), false, msg, signature), + ).to.be.revertedWithCustomError(sessionRouter, 'SessionProviderSignatureMismatch'); }); it('should throw error when an approval duplicated', async () => { await setTime(payoutStart + 10 * DAY); const { msg, signature } = await getProviderApproval(PROVIDER, SECOND, bidId); - await sessionRouter.connect(SECOND).openSession(wei(50), msg, signature); - await expect(sessionRouter.connect(SECOND).openSession(wei(50), msg, signature)).to.be.revertedWithCustomError( - sessionRouter, - 'SessionDuplicateApproval', - ); + await sessionRouter.connect(SECOND).openSession(wei(50), false, msg, signature); + await expect( + sessionRouter.connect(SECOND).openSession(wei(50), false, msg, signature), + ).to.be.revertedWithCustomError(sessionRouter, 'SessionDuplicateApproval'); }); it('should throw error when session duration too short', async () => { const { msg, signature } = await getProviderApproval(PROVIDER, SECOND, bidId); - await expect(sessionRouter.connect(SECOND).openSession(wei(50), msg, signature)).to.be.revertedWithCustomError( - sessionRouter, - 'SessionTooShort', - ); + await expect( + sessionRouter.connect(SECOND).openSession(wei(50), false, msg, signature), + ).to.be.revertedWithCustomError(sessionRouter, 'SessionTooShort'); }); }); describe('#closeSession', () => { - it('should close session and send rewards for the provider, with dispute, late closure', async () => { + it('should close session and send rewards for the provider, late closure', async () => { const { sessionId, openedAt } = await _createSession(); const providerBalBefore = await token.balanceOf(PROVIDER); @@ -302,7 +347,7 @@ describe('SessionRouter', () => { const fundingBalAfter = await token.balanceOf(FUNDING); expect(fundingBalBefore - fundingBalAfter).to.eq(bidPricePerSecond * duration); }); - it('should close session and send rewards for the provider, with dispute, late closure before end', async () => { + it('should close session and send rewards for the provider, late closure before end', async () => { const { sessionId, secondsToDayEnd, openedAt } = await _createSession(); const providerBalBefore = await token.balanceOf(PROVIDER); @@ -354,6 +399,39 @@ describe('SessionRouter', () => { const fundingBalAfter = await token.balanceOf(FUNDING); expect(fundingBalBefore - fundingBalAfter).to.eq(bidPricePerSecond * duration); }); + it('should close session and send rewards for the provider, early closure', async () => { + const { sessionId, openedAt, secondsToDayEnd } = await _createSession(true); + + const providerBalBefore = await token.balanceOf(PROVIDER); + const fundingBalBefore = await token.balanceOf(FUNDING); + const contractBalBefore = await token.balanceOf(sessionRouter); + const secondBalBefore = await token.balanceOf(SECOND); + + await setTime(openedAt + secondsToDayEnd + 100); + const { msg: receiptMsg } = await getReceipt(PROVIDER, sessionId, 0, 0); + const { signature: receiptSig } = await getReceipt(OWNER, sessionId, 0, 0); + await sessionRouter.connect(SECOND).closeSession(receiptMsg, receiptSig); + + const session = await sessionRouter.getSession(sessionId); + const duration = BigInt(secondsToDayEnd); + + expect(session.closedAt).to.eq(openedAt + secondsToDayEnd + 100 + 1); + expect(session.isActive).to.eq(false); + expect(session.closeoutReceipt).to.eq(receiptMsg); + expect(session.providerWithdrawnAmount).to.eq(bidPricePerSecond * duration); + + expect((await sessionRouter.getProvider(PROVIDER)).limitPeriodEarned).to.eq(bidPricePerSecond * duration); + + const providerBalAfter = await token.balanceOf(PROVIDER); + expect(providerBalAfter - providerBalBefore).to.eq(bidPricePerSecond * duration); + const fundingBalAfter = await token.balanceOf(FUNDING); + expect(fundingBalBefore - fundingBalAfter).to.eq(0); + const contractBalAfter = await token.balanceOf(sessionRouter); + const secondBalAfter = await token.balanceOf(SECOND); + expect(contractBalBefore - contractBalAfter).to.eq( + bidPricePerSecond * duration + secondBalAfter - secondBalBefore, + ); + }); it('should close session and send rewards for the user, late closure', async () => { const { sessionId, openedAt } = await _createSession(); @@ -395,6 +473,53 @@ describe('SessionRouter', () => { await sessionRouter.getProviderModelStats(modelId, PROVIDER); await sessionRouter.getModelStats(modelId); }); + it('should claim provider rewards and close session, late closure', async () => { + const { sessionId, openedAt } = await _createSession(true); + + let userBalBefore = await token.balanceOf(SECOND); + let providerBalBefore = await token.balanceOf(PROVIDER); + let contractBalBefore = await token.balanceOf(sessionRouter); + + // Claim for Provider + await setTime(openedAt + 5 * DAY); + await sessionRouter.connect(PROVIDER).claimForProvider(sessionId); + + let session = await sessionRouter.getSession(sessionId); + const duration = session.endsAt - session.openedAt; + + expect(session.providerWithdrawnAmount).to.eq(bidPricePerSecond * duration); + expect(await sessionRouter.getProvidersTotalClaimed()).to.eq(bidPricePerSecond * duration); + expect((await sessionRouter.getProvider(PROVIDER)).limitPeriodEarned).to.eq(bidPricePerSecond * duration); + + const userBalAfter = await token.balanceOf(SECOND); + expect(userBalAfter - userBalBefore).to.eq(0); + const providerBalAfter = await token.balanceOf(PROVIDER); + expect(providerBalAfter - providerBalBefore).to.eq(bidPricePerSecond * duration); + const contractBalAfter = await token.balanceOf(sessionRouter); + expect(contractBalBefore - contractBalAfter).to.eq(bidPricePerSecond * duration); + + // Close session + userBalBefore = userBalAfter; + providerBalBefore = providerBalAfter; + contractBalBefore = contractBalAfter; + + await setTime(openedAt + 6 * DAY); + const { msg: receiptMsg, signature: receiptSig } = await getReceipt(PROVIDER, sessionId, 0, 0); + await sessionRouter.connect(SECOND).closeSession(receiptMsg, receiptSig); + + const userBalAfterClose = await token.balanceOf(SECOND); + expect(userBalAfterClose - userBalBefore).to.eq(wei(50) - bidPricePerSecond * duration); + const providerBalAfterClose = await token.balanceOf(PROVIDER); + expect(providerBalAfterClose - providerBalBefore).to.eq(0); + const contractBalAfterClose = await token.balanceOf(sessionRouter); + expect(contractBalBefore - contractBalAfterClose).to.eq(wei(50) - bidPricePerSecond * duration); + + session = await sessionRouter.getSession(sessionId); + expect(session.closedAt).to.eq(openedAt + 6 * DAY + 1); + expect(session.isActive).to.eq(false); + expect(session.closeoutReceipt).to.eq(receiptMsg); + expect(session.providerWithdrawnAmount).to.eq(bidPricePerSecond * duration); + }); it('should throw error when the caller is invalid', async () => { const { msg: receiptMsg, signature: receiptSig } = await getReceipt(PROVIDER, getHex(Buffer.from('1')), 0, 0); await expect(sessionRouter.connect(SECOND).closeSession(receiptMsg, receiptSig)).to.be.revertedWithCustomError( @@ -493,7 +618,7 @@ describe('SessionRouter', () => { await setTime(payoutStart + 10 * DAY); const { msg: msg1, signature: sig1 } = await getProviderApproval(PROVIDER, SECOND, bidId); - await sessionRouter.connect(SECOND).openSession(wei(50), msg1, sig1); + await sessionRouter.connect(SECOND).openSession(wei(50), false, msg1, sig1); const sessionId1 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 0); await setTime(payoutStart + 20 * DAY); @@ -501,7 +626,7 @@ describe('SessionRouter', () => { await setTime(payoutStart + 30 * DAY); const { msg: msg2, signature: sig2 } = await getProviderApproval(PROVIDER, SECOND, bidId); - await sessionRouter.connect(SECOND).openSession(wei(50), msg2, sig2); + await sessionRouter.connect(SECOND).openSession(wei(50), false, msg2, sig2); const sessionId2 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 1); await setTime(payoutStart + 40 * DAY); @@ -515,102 +640,19 @@ describe('SessionRouter', () => { const fundingBalAfter = await token.balanceOf(FUNDING); expect(fundingBalBefore - fundingBalAfter).to.eq(wei(0.2)); }); - it('should throw error when caller is not the session provider', async () => { - const { sessionId } = await _createSession(); - - await expect(sessionRouter.connect(SECOND).claimForProvider(sessionId)).to.be.revertedWithCustomError( - sessionRouter, - 'OwnableUnauthorizedAccount', - ); - }); - it('should throw error when session is not end', async () => { + it('should claim zero when session is not end', async () => { const { sessionId, openedAt } = await _createSession(); - await setTime(openedAt + 10); - await expect(sessionRouter.connect(PROVIDER).claimForProvider(sessionId)).to.be.revertedWithCustomError( - sessionRouter, - 'SessionNotEndedOrNotExist', - ); - }); - }); - - describe('#claimForProvider', () => { - it('should claim provider rewards, remainder, session closed with dispute', async () => { - const { sessionId, secondsToDayEnd, openedAt } = await _createSession(); - - await setTime(openedAt + secondsToDayEnd + 1); - const { msg: receiptMsg } = await getReceipt(PROVIDER, sessionId, 0, 0); - const { signature: receiptSig } = await getReceipt(OWNER, sessionId, 0, 0); - await sessionRouter.connect(SECOND).closeSession(receiptMsg, receiptSig); - - let session = await sessionRouter.getSession(sessionId); - const fullDuration = BigInt(secondsToDayEnd + 1); - const duration = 1n; - const providerBalBefore = await token.balanceOf(PROVIDER); const fundingBalBefore = await token.balanceOf(FUNDING); + await setTime(openedAt + 10); await sessionRouter.connect(PROVIDER).claimForProvider(sessionId); - session = await sessionRouter.getSession(sessionId); - - expect(session.providerWithdrawnAmount).to.eq(bidPricePerSecond * fullDuration); - expect(await sessionRouter.getProvidersTotalClaimed()).to.eq(bidPricePerSecond * fullDuration); - expect((await sessionRouter.getProvider(PROVIDER)).limitPeriodEarned).to.eq(bidPricePerSecond * fullDuration); - - const providerBalAfter = await token.balanceOf(PROVIDER); - expect(providerBalAfter - providerBalBefore).to.eq(bidPricePerSecond * duration); - const fundingBalAfter = await token.balanceOf(FUNDING); - expect(fundingBalBefore - fundingBalAfter).to.eq(bidPricePerSecond * duration); - }); - it('should claim provider rewards, full', async () => { - const { sessionId, openedAt } = await _createSession(); - - let session = await sessionRouter.getSession(sessionId); - const duration = session.endsAt - session.openedAt; - - const providerBalBefore = await token.balanceOf(PROVIDER); - const fundingBalBefore = await token.balanceOf(FUNDING); - - await setTime(openedAt + 5 * DAY + 1); - await sessionRouter.connect(PROVIDER).claimForProvider(sessionId); - session = await sessionRouter.getSession(sessionId); - - expect(session.providerWithdrawnAmount).to.eq(bidPricePerSecond * duration); - expect(await sessionRouter.getProvidersTotalClaimed()).to.eq(bidPricePerSecond * duration); - expect((await sessionRouter.getProvider(PROVIDER)).limitPeriodEarned).to.eq(bidPricePerSecond * duration); const providerBalAfter = await token.balanceOf(PROVIDER); - expect(providerBalAfter - providerBalBefore).to.eq(bidPricePerSecond * duration); + expect(providerBalAfter - providerBalBefore).to.eq(0); const fundingBalAfter = await token.balanceOf(FUNDING); - expect(fundingBalBefore - fundingBalAfter).to.eq(bidPricePerSecond * duration); - }); - it('should claim provider rewards with reward limiter amount for the period', async () => { - const providerBalBefore = await token.balanceOf(PROVIDER); - const fundingBalBefore = await token.balanceOf(FUNDING); - - await setTime(payoutStart + 10 * DAY); - const { msg: msg1, signature: sig1 } = await getProviderApproval(PROVIDER, SECOND, bidId); - await sessionRouter.connect(SECOND).openSession(wei(50), msg1, sig1); - - const sessionId1 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 0); - await setTime(payoutStart + 20 * DAY); - await sessionRouter.connect(PROVIDER).claimForProvider(sessionId1); - - await setTime(payoutStart + 30 * DAY); - const { msg: msg2, signature: sig2 } = await getProviderApproval(PROVIDER, SECOND, bidId); - await sessionRouter.connect(SECOND).openSession(wei(50), msg2, sig2); - - const sessionId2 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 1); - await setTime(payoutStart + 40 * DAY); - await sessionRouter.connect(PROVIDER).claimForProvider(sessionId2); - - expect(await sessionRouter.getProvidersTotalClaimed()).to.eq(wei(0.2)); - expect((await sessionRouter.getProvider(PROVIDER)).limitPeriodEarned).to.eq(wei(0.2)); - - const providerBalAfter = await token.balanceOf(PROVIDER); - expect(providerBalAfter - providerBalBefore).to.eq(wei(0.2)); - const fundingBalAfter = await token.balanceOf(FUNDING); - expect(fundingBalBefore - fundingBalAfter).to.eq(wei(0.2)); + expect(fundingBalBefore - fundingBalAfter).to.eq(0); }); it('should throw error when caller is not the session provider', async () => { const { sessionId } = await _createSession(); @@ -620,15 +662,6 @@ describe('SessionRouter', () => { 'OwnableUnauthorizedAccount', ); }); - it('should throw error when session is not end', async () => { - const { sessionId, openedAt } = await _createSession(); - - await setTime(openedAt + 10); - await expect(sessionRouter.connect(PROVIDER).claimForProvider(sessionId)).to.be.revertedWithCustomError( - sessionRouter, - 'SessionNotEndedOrNotExist', - ); - }); }); describe('#withdrawUserStakes', () => { @@ -637,10 +670,10 @@ describe('SessionRouter', () => { await setTime(openedAt + 1 * DAY); const { msg, signature } = await getProviderApproval(PROVIDER, SECOND, bidId); - await sessionRouter.connect(SECOND).openSession(wei(50), msg, signature); + await sessionRouter.connect(SECOND).openSession(wei(50), false, msg, signature); const sessionId1 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 0); - await setTime(openedAt + 2 * DAY + 1); + await setTime(openedAt + 1 * DAY + 100); const { msg: receiptMsg, signature: receiptSig } = await getReceipt(PROVIDER, sessionId1, 0, 0); await sessionRouter.connect(SECOND).closeSession(receiptMsg, receiptSig); @@ -664,7 +697,7 @@ describe('SessionRouter', () => { await setTime(openedAt + 1 * DAY); const { msg: msg1, signature: sig1 } = await getProviderApproval(PROVIDER, SECOND, bidId); - await sessionRouter.connect(SECOND).openSession(wei(50), msg1, sig1); + await sessionRouter.connect(SECOND).openSession(wei(50), false, msg1, sig1); const sessionId1 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 0); await setTime(openedAt + 1 * DAY + 500); @@ -673,7 +706,7 @@ describe('SessionRouter', () => { await setTime(openedAt + 3 * DAY); const { msg: msg2, signature: sig2 } = await getProviderApproval(PROVIDER, SECOND, bidId); - await sessionRouter.connect(SECOND).openSession(wei(50), msg2, sig2); + await sessionRouter.connect(SECOND).openSession(wei(50), false, msg2, sig2); const sessionId2 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 1); await setTime(openedAt + 3 * DAY + 500); @@ -700,7 +733,7 @@ describe('SessionRouter', () => { await setTime(openedAt + 1 * DAY); const { msg: msg1, signature: sig1 } = await getProviderApproval(PROVIDER, SECOND, bidId); - await sessionRouter.connect(SECOND).openSession(wei(50), msg1, sig1); + await sessionRouter.connect(SECOND).openSession(wei(50), false, msg1, sig1); const sessionId1 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 0); await setTime(openedAt + 1 * DAY + 500); @@ -709,7 +742,7 @@ describe('SessionRouter', () => { await setTime(openedAt + 3 * DAY); const { msg: msg2, signature: sig2 } = await getProviderApproval(PROVIDER, SECOND, bidId); - await sessionRouter.connect(SECOND).openSession(wei(50), msg2, sig2); + await sessionRouter.connect(SECOND).openSession(wei(50), false, msg2, sig2); const sessionId2 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 1); await setTime(openedAt + 3 * DAY + 500); @@ -736,7 +769,7 @@ describe('SessionRouter', () => { await setTime(openedAt + 1 * DAY); const { msg: msg1, signature: sig1 } = await getProviderApproval(PROVIDER, SECOND, bidId); - await sessionRouter.connect(SECOND).openSession(wei(50), msg1, sig1); + await sessionRouter.connect(SECOND).openSession(wei(50), false, msg1, sig1); const sessionId1 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 0); await setTime(openedAt + 1 * DAY + 500); @@ -745,7 +778,7 @@ describe('SessionRouter', () => { await setTime(openedAt + 3 * DAY); const { msg: msg2, signature: sig2 } = await getProviderApproval(PROVIDER, SECOND, bidId); - await sessionRouter.connect(SECOND).openSession(wei(50), msg2, sig2); + await sessionRouter.connect(SECOND).openSession(wei(50), false, msg2, sig2); const sessionId2 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 1); await setTime(openedAt + 3 * DAY + 500); @@ -788,11 +821,11 @@ describe('SessionRouter', () => { await setTime(openedAt + 1 * DAY); const { msg: msg1, signature: sig1 } = await getProviderApproval(PROVIDER, SECOND, bidId); - await sessionRouter.connect(SECOND).openSession(wei(50), msg1, sig1); + await sessionRouter.connect(SECOND).openSession(wei(50), false, msg1, sig1); const sessionId1 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 0); const { msg: msg2, signature: sig2 } = await getProviderApproval(PROVIDER, SECOND, bidId); - await sessionRouter.connect(SECOND).openSession(wei(50), msg2, sig2); + await sessionRouter.connect(SECOND).openSession(wei(50), false, msg2, sig2); const sessionId2 = await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 1); await setTime(openedAt + 1 * DAY + 500); @@ -808,15 +841,21 @@ describe('SessionRouter', () => { 'SessionUserAmountToWithdrawIsZero', ); }); + it('should throw error when amount of itterations are zero', async () => { + await expect(sessionRouter.connect(SECOND).withdrawUserStakes(0)).to.be.revertedWithCustomError( + sessionRouter, + 'SessionUserAmountToWithdrawIsZero', + ); + }); }); - const _createSession = async () => { + const _createSession = async (isDirectPaymentFromUser = false) => { const secondsToDayEnd = 600n; const openedAt = payoutStart + (payoutStart % DAY) + 10 * DAY - Number(secondsToDayEnd) - 1; await setTime(openedAt); const { msg, signature } = await getProviderApproval(PROVIDER, SECOND, bidId); - await sessionRouter.connect(SECOND).openSession(wei(50), msg, signature); + await sessionRouter.connect(SECOND).openSession(wei(50), isDirectPaymentFromUser, msg, signature); return { sessionId: await sessionRouter.getSessionId(SECOND, PROVIDER, bidId, 0), diff --git a/smart-contracts/test/helpers/deployers/diamond/facets/session-router.ts b/smart-contracts/test/helpers/deployers/diamond/facets/session-router.ts index e1764bd7..39f63f5f 100644 --- a/smart-contracts/test/helpers/deployers/diamond/facets/session-router.ts +++ b/smart-contracts/test/helpers/deployers/diamond/facets/session-router.ts @@ -10,6 +10,7 @@ import { } from '@/generated-types/ethers'; import { FacetAction } from '@/test/helpers/deployers/diamond/lumerin-diamond'; import { getDefaultPools } from '@/test/helpers/pool-helper'; +import { DAY } from '@/utils/time'; export const deployFacetSessionRouter = async ( diamond: LumerinDiamond, @@ -45,7 +46,7 @@ export const deployFacetSessionRouter = async ( ]); facet = facet.attach(diamond.target) as SessionRouter; - await facet.__SessionRouter_init(fundingAccount, getDefaultPools()); + await facet.__SessionRouter_init(fundingAccount, DAY, getDefaultPools()); return facet; }; From a0ef50640f60725abc0acd22e84b6dacbbcd74fa Mon Sep 17 00:00:00 2001 From: Oleksandr Date: Thu, 24 Oct 2024 13:01:34 +0300 Subject: [PATCH 6/9] remove dev imports --- smart-contracts/contracts/diamond/facets/SessionRouter.sol | 2 -- 1 file changed, 2 deletions(-) diff --git a/smart-contracts/contracts/diamond/facets/SessionRouter.sol b/smart-contracts/contracts/diamond/facets/SessionRouter.sol index e6510345..1995fc54 100644 --- a/smart-contracts/contracts/diamond/facets/SessionRouter.sol +++ b/smart-contracts/contracts/diamond/facets/SessionRouter.sol @@ -18,8 +18,6 @@ import {LibSD} from "../../libs/LibSD.sol"; import {ISessionRouter} from "../../interfaces/facets/ISessionRouter.sol"; -import "hardhat/console.sol"; - contract SessionRouter is ISessionRouter, OwnableDiamondStorage, From e1cb781ace93e137d41964ed1276a6bfad67d2d8 Mon Sep 17 00:00:00 2001 From: Oleksandr Date: Thu, 24 Oct 2024 13:18:36 +0300 Subject: [PATCH 7/9] add linter fixes --- .../diamond/facets/ProviderRegistry.sol | 2 +- .../diamond/facets/SessionRouter.sol | 34 +++++++++++-------- .../diamond/storages/ModelStorage.sol | 4 +-- .../diamond/storages/ProviderStorage.sol | 2 +- .../diamond/storages/SessionStorage.sol | 2 +- .../interfaces/facets/ISessionRouter.sol | 6 +++- .../interfaces/storage/IBidStorage.sol | 2 +- 7 files changed, 30 insertions(+), 22 deletions(-) diff --git a/smart-contracts/contracts/diamond/facets/ProviderRegistry.sol b/smart-contracts/contracts/diamond/facets/ProviderRegistry.sol index 7a8c3d67..da00404e 100644 --- a/smart-contracts/contracts/diamond/facets/ProviderRegistry.sol +++ b/smart-contracts/contracts/diamond/facets/ProviderRegistry.sol @@ -28,7 +28,7 @@ contract ProviderRegistry is IProviderRegistry, OwnableDiamondStorage, ProviderS BidsStorage storage bidsStorage = getBidsStorage(); if (amount_ > 0) { - IERC20(bidsStorage.token).safeTransferFrom(_msgSender(), address(this), amount_); + IERC20(bidsStorage.token).safeTransferFrom(_msgSender(), address(this), amount_); } PovidersStorage storage providersStorage = getProvidersStorage(); diff --git a/smart-contracts/contracts/diamond/facets/SessionRouter.sol b/smart-contracts/contracts/diamond/facets/SessionRouter.sol index 1995fc54..98dc5ae1 100644 --- a/smart-contracts/contracts/diamond/facets/SessionRouter.sol +++ b/smart-contracts/contracts/diamond/facets/SessionRouter.sol @@ -109,7 +109,7 @@ contract SessionRouter is return sessionId_; } - + function _validateSession( bytes32 bidId_, uint256 amount_, @@ -239,29 +239,31 @@ contract SessionRouter is return (sessionId_, tpsScaled1000_, ttftMs_); } - function _getProviderRewards(Session storage session, Bid storage bid, bool isIncludeWithdrawnAmount_) private view returns (uint256) { + function _getProviderRewards( + Session storage session, + Bid storage bid, + bool isIncludeWithdrawnAmount_ + ) private view returns (uint256) { uint256 sessionEnd_ = session.closedAt == 0 ? session.endsAt : session.closedAt.min(session.endsAt); if (block.timestamp < sessionEnd_) { return 0; } - uint256 withdrawnAmount = isIncludeWithdrawnAmount_ ? session.providerWithdrawnAmount : 0; + uint256 withdrawnAmount = isIncludeWithdrawnAmount_ ? session.providerWithdrawnAmount : 0; return (sessionEnd_ - session.openedAt) * bid.pricePerSecond - withdrawnAmount; } - function _rewardProviderAfterClose( - bool noDispute_, - Session storage session, - Bid storage bid - ) internal { + function _rewardProviderAfterClose(bool noDispute_, Session storage session, Bid storage bid) internal { uint128 startOfToday_ = startOfTheDay(uint128(block.timestamp)); bool isClosingLate_ = uint128(block.timestamp) > session.endsAt; uint256 providerAmountToWithdraw_ = _getProviderRewards(session, bid, true); uint256 providerOnHoldAmount = 0; if (!noDispute_ && !isClosingLate_) { - providerOnHoldAmount = (session.endsAt.min(session.closedAt) - startOfToday_.max(session.openedAt)) * bid.pricePerSecond; + providerOnHoldAmount = + (session.endsAt.min(session.closedAt) - startOfToday_.max(session.openedAt)) * + bid.pricePerSecond; } providerAmountToWithdraw_ -= providerOnHoldAmount; @@ -281,17 +283,19 @@ contract SessionRouter is uint256 userInitialLock_ = userDuration_ * bid.pricePerSecond; userStakeToLock_ = userStake.min(stipendToStake(userInitialLock_, startOfToday_)); - getSessionsStorage().userStakesOnHold[session.user].push(OnHold(userStakeToLock_, uint128(startOfToday_ + 1 days))); + getSessionsStorage().userStakesOnHold[session.user].push( + OnHold(userStakeToLock_, uint128(startOfToday_ + 1 days)) + ); } uint256 userAmountToWithdraw_ = userStake - userStakeToLock_; IERC20(getBidsStorage().token).safeTransfer(session.user, userAmountToWithdraw_); } function _setStats( - bool noDispute_, - uint32 ttftMs_, - uint32 tpsScaled1000_, - Session storage session, + bool noDispute_, + uint32 ttftMs_, + uint32 tpsScaled1000_, + Session storage session, Bid storage bid ) internal { ProviderModelStats storage prStats = providerModelStats(bid.modelId, bid.provider); @@ -398,7 +402,7 @@ contract SessionRouter is OnHold[] storage onHoldEntries = getSessionsStorage().userStakesOnHold[_msgSender()]; uint8 i_ = iterations_ >= onHoldEntries.length ? uint8(onHoldEntries.length) : iterations_; if (i_ == 0) { - revert SessionUserAmountToWithdrawIsZero(); + revert SessionUserAmountToWithdrawIsZero(); } i_--; diff --git a/smart-contracts/contracts/diamond/storages/ModelStorage.sol b/smart-contracts/contracts/diamond/storages/ModelStorage.sol index 42b5f2f9..37c8ba00 100644 --- a/smart-contracts/contracts/diamond/storages/ModelStorage.sol +++ b/smart-contracts/contracts/diamond/storages/ModelStorage.sol @@ -14,7 +14,7 @@ contract ModelStorage is IModelStorage { uint256 modelMinimumStake; EnumerableSet.Bytes32Set modelIds; mapping(bytes32 modelId => Model) models; - // TODO: move vars below to the graph in the future + // TODO: move vars below to the graph in the future EnumerableSet.Bytes32Set activeModels; } @@ -34,7 +34,7 @@ contract ModelStorage is IModelStorage { } function getActiveModelIds(uint256 offset_, uint256 limit_) external view returns (bytes32[] memory) { - return getModelsStorage().activeModels.part(offset_, limit_); + return getModelsStorage().activeModels.part(offset_, limit_); } function getIsModelActive(bytes32 modelId_) public view returns (bool) { diff --git a/smart-contracts/contracts/diamond/storages/ProviderStorage.sol b/smart-contracts/contracts/diamond/storages/ProviderStorage.sol index ebcf5f2e..94cbeb78 100644 --- a/smart-contracts/contracts/diamond/storages/ProviderStorage.sol +++ b/smart-contracts/contracts/diamond/storages/ProviderStorage.sol @@ -32,7 +32,7 @@ contract ProviderStorage is IProviderStorage { } function getActiveProviders(uint256 offset_, uint256 limit_) external view returns (address[] memory) { - return getProvidersStorage().activeProviders.part(offset_, limit_); + return getProvidersStorage().activeProviders.part(offset_, limit_); } function getIsProviderActive(address provider_) public view returns (bool) { diff --git a/smart-contracts/contracts/diamond/storages/SessionStorage.sol b/smart-contracts/contracts/diamond/storages/SessionStorage.sol index fa9817cc..cbe96682 100644 --- a/smart-contracts/contracts/diamond/storages/SessionStorage.sol +++ b/smart-contracts/contracts/diamond/storages/SessionStorage.sol @@ -83,7 +83,7 @@ contract SessionStorage is ISessionStorage { function getIsProviderApprovalUsed(bytes memory approval_) external view returns (bool) { return getSessionsStorage().isProviderApprovalUsed[approval_]; } - + function getMaxSessionDuration() external view returns (uint128) { return getSessionsStorage().maxSessionDuration; } diff --git a/smart-contracts/contracts/interfaces/facets/ISessionRouter.sol b/smart-contracts/contracts/interfaces/facets/ISessionRouter.sol index 48bed3dc..f58439bd 100644 --- a/smart-contracts/contracts/interfaces/facets/ISessionRouter.sol +++ b/smart-contracts/contracts/interfaces/facets/ISessionRouter.sol @@ -32,7 +32,11 @@ interface ISessionRouter is ISessionStorage { * @param maxSessionDuration_ The max session duration * @param pools_ The pools data */ - function __SessionRouter_init(address fundingAccount_, uint128 maxSessionDuration_, Pool[] calldata pools_) external; + function __SessionRouter_init( + address fundingAccount_, + uint128 maxSessionDuration_, + Pool[] calldata pools_ + ) external; /** * @notice Sets distibution pool configuration diff --git a/smart-contracts/contracts/interfaces/storage/IBidStorage.sol b/smart-contracts/contracts/interfaces/storage/IBidStorage.sol index 559d100d..708bf7e8 100644 --- a/smart-contracts/contracts/interfaces/storage/IBidStorage.sol +++ b/smart-contracts/contracts/interfaces/storage/IBidStorage.sol @@ -19,7 +19,7 @@ interface IBidStorage { uint128 createdAt; uint128 deletedAt; } - + /** * The function returns the bid structure. * @param bidId_ Bid ID. From 7574f97c792da630517d3b17c4c5518a4d534f28 Mon Sep 17 00:00:00 2001 From: Oleksandr Date: Fri, 25 Oct 2024 13:27:52 +0300 Subject: [PATCH 8/9] add min and max price fro the bid price per second --- .../contracts/diamond/facets/Marketplace.sol | 31 ++++++++++++++++++- .../diamond/storages/MarketplaceStorage.sol | 6 ++++ .../interfaces/facets/IMarketplace.sol | 19 +++++++++++- .../storage/IMarketplaceStorage.sol | 7 +++++ .../test/diamond/facets/Marketplace.test.ts | 31 +++++++++++++++++-- .../test/diamond/facets/ModelRegistry.test.ts | 2 +- .../diamond/facets/ProviderRegistry.test.ts | 2 +- .../test/diamond/facets/SessionRouter.test.ts | 2 +- .../deployers/diamond/facets/marketplace.ts | 9 ++++-- 9 files changed, 100 insertions(+), 9 deletions(-) diff --git a/smart-contracts/contracts/diamond/facets/Marketplace.sol b/smart-contracts/contracts/diamond/facets/Marketplace.sol index c3e20659..81be7d44 100644 --- a/smart-contracts/contracts/diamond/facets/Marketplace.sol +++ b/smart-contracts/contracts/diamond/facets/Marketplace.sol @@ -24,9 +24,15 @@ contract Marketplace is using SafeERC20 for IERC20; using EnumerableSet for EnumerableSet.Bytes32Set; - function __Marketplace_init(address token_) external initializer(BIDS_STORAGE_SLOT) { + function __Marketplace_init( + address token_, + uint256 bidMinPricePerSecond_, + uint256 bidMaxPricePerSecond_ + ) external initializer(BIDS_STORAGE_SLOT) { BidsStorage storage bidsStorage = getBidsStorage(); bidsStorage.token = token_; + + setMinMaxBidPricePerSecond(bidMinPricePerSecond_, bidMaxPricePerSecond_); } function setMarketplaceBidFee(uint256 bidFee_) external onlyOwner { @@ -36,6 +42,25 @@ contract Marketplace is emit MaretplaceFeeUpdated(bidFee_); } + function setMinMaxBidPricePerSecond( + uint256 bidMinPricePerSecond_, + uint256 bidMaxPricePerSecond_ + ) public onlyOwner { + if (bidMinPricePerSecond_ == 0) { + revert MarketplaceBidMinPricePerSecondIsZero(); + } + + if (bidMinPricePerSecond_ > bidMaxPricePerSecond_) { + revert MarketplaceBidMinPricePerSecondIsInvalid(); + } + + MarketStorage storage marketStorage = getMarketStorage(); + marketStorage.bidMinPricePerSecond = bidMinPricePerSecond_; + marketStorage.bidMaxPricePerSecond = bidMaxPricePerSecond_; + + emit MarketplaceBidMinMaxPriceUpdated(bidMinPricePerSecond_, bidMaxPricePerSecond_); + } + function postModelBid(bytes32 modelId_, uint256 pricePerSecond_) external returns (bytes32 bidId) { address provider_ = _msgSender(); @@ -49,6 +74,10 @@ contract Marketplace is BidsStorage storage bidsStorage = getBidsStorage(); MarketStorage storage marketStorage = getMarketStorage(); + if (pricePerSecond_ < marketStorage.bidMinPricePerSecond || pricePerSecond_ > marketStorage.bidMaxPricePerSecond) { + revert MarketplaceBidPricePerSecondInvalid(); + } + IERC20(bidsStorage.token).safeTransferFrom(_msgSender(), address(this), marketStorage.bidFee); marketStorage.feeBalance += marketStorage.bidFee; diff --git a/smart-contracts/contracts/diamond/storages/MarketplaceStorage.sol b/smart-contracts/contracts/diamond/storages/MarketplaceStorage.sol index d905fada..91ad4158 100644 --- a/smart-contracts/contracts/diamond/storages/MarketplaceStorage.sol +++ b/smart-contracts/contracts/diamond/storages/MarketplaceStorage.sol @@ -7,6 +7,8 @@ contract MarketplaceStorage is IMarketplaceStorage { struct MarketStorage { uint256 feeBalance; // Total fees balance of the contract uint256 bidFee; + uint256 bidMinPricePerSecond; + uint256 bidMaxPricePerSecond; } bytes32 public constant MARKET_STORAGE_SLOT = keccak256("diamond.standard.market.storage"); @@ -20,6 +22,10 @@ contract MarketplaceStorage is IMarketplaceStorage { return getMarketStorage().feeBalance; } + function getMinMaxBidPricePerSecond() external view returns (uint256, uint256) { + return (getMarketStorage().bidMinPricePerSecond, getMarketStorage().bidMaxPricePerSecond); + } + /** INTERNAL */ function getMarketStorage() internal pure returns (MarketStorage storage ds) { bytes32 slot_ = MARKET_STORAGE_SLOT; diff --git a/smart-contracts/contracts/interfaces/facets/IMarketplace.sol b/smart-contracts/contracts/interfaces/facets/IMarketplace.sol index e24723d6..04b72b7d 100644 --- a/smart-contracts/contracts/interfaces/facets/IMarketplace.sol +++ b/smart-contracts/contracts/interfaces/facets/IMarketplace.sol @@ -7,15 +7,22 @@ interface IMarketplace is IMarketplaceStorage { event MaretplaceFeeUpdated(uint256 bidFee); event MarketplaceBidPosted(address indexed provider, bytes32 indexed modelId, uint256 nonce); event MarketplaceBidDeleted(address indexed provider, bytes32 indexed modelId, uint256 nonce); + event MarketplaceBidMinMaxPriceUpdated(uint256 bidMinPricePerSecond, uint256 bidMaxPricePerSecond); + error MarketplaceProviderNotFound(); error MarketplaceModelNotFound(); error MarketplaceActiveBidNotFound(); + error MarketplaceBidMinPricePerSecondIsZero(); + error MarketplaceBidMinPricePerSecondIsInvalid(); + error MarketplaceBidPricePerSecondInvalid(); /** * The function to initialize the facet. * @param token_ Stake token (MOR) + * @param bidMinPricePerSecond_ Min price per second for bid + * @param bidMaxPricePerSecond_ Max price per second for bid */ - function __Marketplace_init(address token_) external; + function __Marketplace_init(address token_, uint256 bidMinPricePerSecond_, uint256 bidMaxPricePerSecond_) external; /** * The function to set the bidFee. @@ -23,6 +30,16 @@ interface IMarketplace is IMarketplaceStorage { */ function setMarketplaceBidFee(uint256 bidFee_) external; + /** + * The function to set the min and max price per second for bid. + * @param bidMinPricePerSecond_ Min price per second for bid + * @param bidMaxPricePerSecond_ Max price per second for bid + */ + function setMinMaxBidPricePerSecond( + uint256 bidMinPricePerSecond_, + uint256 bidMaxPricePerSecond_ + ) external; + /** * The function to create the bid. * @param modelId_ The mode ID diff --git a/smart-contracts/contracts/interfaces/storage/IMarketplaceStorage.sol b/smart-contracts/contracts/interfaces/storage/IMarketplaceStorage.sol index 36ca7c70..bbf05153 100644 --- a/smart-contracts/contracts/interfaces/storage/IMarketplaceStorage.sol +++ b/smart-contracts/contracts/interfaces/storage/IMarketplaceStorage.sol @@ -11,4 +11,11 @@ interface IMarketplaceStorage { * The function returns fee balance. */ function getFeeBalance() external view returns (uint256); + + /** + * The function returns min and max price per second for bid. + * @return Min bid price per second + * @return Max bid price per second + */ + function getMinMaxBidPricePerSecond() external view returns (uint256, uint256); } diff --git a/smart-contracts/test/diamond/facets/Marketplace.test.ts b/smart-contracts/test/diamond/facets/Marketplace.test.ts index 04172e7c..90762b05 100644 --- a/smart-contracts/test/diamond/facets/Marketplace.test.ts +++ b/smart-contracts/test/diamond/facets/Marketplace.test.ts @@ -41,7 +41,7 @@ describe('Marketplace', () => { deployFacetProviderRegistry(diamond), deployFacetModelRegistry(diamond), deployFacetSessionRouter(diamond, OWNER), - deployFacetMarketplace(diamond, token), + deployFacetMarketplace(diamond, token, wei(0.0001), wei(900)), ]); await token.transfer(SECOND, wei(1000)); @@ -67,7 +67,7 @@ describe('Marketplace', () => { expect(await marketplace.getToken()).to.eq(await token.getAddress()); }); it('should revert if try to call init function twice', async () => { - await expect(marketplace.__Marketplace_init(token)).to.be.rejectedWith( + await expect(marketplace.__Marketplace_init(token, wei(0.001), wei(0.002))).to.be.rejectedWith( 'Initializable: contract is already initialized', ); }); @@ -89,6 +89,33 @@ describe('Marketplace', () => { }); }); + describe('#setMinMaxBidPricePerSecond', async () => { + it('should set min and max price per second', async () => { + await expect(marketplace.setMinMaxBidPricePerSecond(wei(1), wei(2))) + .to.emit(marketplace, 'MarketplaceBidMinMaxPriceUpdated') + .withArgs(wei(1), wei(2)); + + expect(await marketplace.getMinMaxBidPricePerSecond()).deep.eq([wei(1), wei(2)]); + }); + it('should throw error when caller is not an owner', async () => { + await expect( + marketplace.connect(SECOND).setMinMaxBidPricePerSecond(wei(1), wei(2)), + ).to.be.revertedWithCustomError(diamond, 'OwnableUnauthorizedAccount'); + }); + it('should throw error when min price is zero', async () => { + await expect(marketplace.setMinMaxBidPricePerSecond(wei(0), wei(2))).to.be.revertedWithCustomError( + marketplace, + 'MarketplaceBidMinPricePerSecondIsZero', + ); + }); + it('should throw error when min price greater then max price', async () => { + await expect(marketplace.setMinMaxBidPricePerSecond(wei(3), wei(2))).to.be.revertedWithCustomError( + marketplace, + 'MarketplaceBidMinPricePerSecondIsInvalid', + ); + }); + }); + describe('#postModelBid', async () => { beforeEach(async () => { await marketplace.setMarketplaceBidFee(wei(1)); diff --git a/smart-contracts/test/diamond/facets/ModelRegistry.test.ts b/smart-contracts/test/diamond/facets/ModelRegistry.test.ts index 01fe0f26..04bddbc0 100644 --- a/smart-contracts/test/diamond/facets/ModelRegistry.test.ts +++ b/smart-contracts/test/diamond/facets/ModelRegistry.test.ts @@ -40,7 +40,7 @@ describe('ModelRegistry', () => { deployFacetProviderRegistry(diamond), deployFacetModelRegistry(diamond), deployFacetSessionRouter(diamond, OWNER), - deployFacetMarketplace(diamond, token), + deployFacetMarketplace(diamond, token, wei(0.0001), wei(900)), ]); await token.transfer(SECOND, wei(1000)); diff --git a/smart-contracts/test/diamond/facets/ProviderRegistry.test.ts b/smart-contracts/test/diamond/facets/ProviderRegistry.test.ts index bfc82b92..9a392db2 100644 --- a/smart-contracts/test/diamond/facets/ProviderRegistry.test.ts +++ b/smart-contracts/test/diamond/facets/ProviderRegistry.test.ts @@ -41,7 +41,7 @@ describe('ProviderRegistry', () => { deployFacetProviderRegistry(diamond), deployFacetModelRegistry(diamond), deployFacetSessionRouter(diamond, OWNER), - deployFacetMarketplace(diamond, token), + deployFacetMarketplace(diamond, token, wei(0.0001), wei(900)), ]); await token.transfer(PROVIDER, wei(1000)); diff --git a/smart-contracts/test/diamond/facets/SessionRouter.test.ts b/smart-contracts/test/diamond/facets/SessionRouter.test.ts index d5b6f0da..cb994f01 100644 --- a/smart-contracts/test/diamond/facets/SessionRouter.test.ts +++ b/smart-contracts/test/diamond/facets/SessionRouter.test.ts @@ -47,7 +47,7 @@ describe('SessionRouter', () => { deployFacetProviderRegistry(diamond), deployFacetModelRegistry(diamond), deployFacetSessionRouter(diamond, FUNDING), - deployFacetMarketplace(diamond, token), + deployFacetMarketplace(diamond, token, wei(0.0001), wei(900)), ]); await token.transfer(SECOND, wei(10000)); diff --git a/smart-contracts/test/helpers/deployers/diamond/facets/marketplace.ts b/smart-contracts/test/helpers/deployers/diamond/facets/marketplace.ts index 49ab75e6..33229561 100644 --- a/smart-contracts/test/helpers/deployers/diamond/facets/marketplace.ts +++ b/smart-contracts/test/helpers/deployers/diamond/facets/marketplace.ts @@ -10,7 +10,12 @@ import { } from '@/generated-types/ethers'; import { FacetAction } from '@/test/helpers/deployers/diamond/lumerin-diamond'; -export const deployFacetMarketplace = async (diamond: LumerinDiamond, token: MorpheusToken): Promise => { +export const deployFacetMarketplace = async ( + diamond: LumerinDiamond, + token: MorpheusToken, + bidMinPrice: bigint, + bidMaxPrice: bigint, +): Promise => { let facet: Marketplace; const factory = await ethers.getContractFactory('Marketplace'); @@ -34,7 +39,7 @@ export const deployFacetMarketplace = async (diamond: LumerinDiamond, token: Mor ]); facet = facet.attach(diamond.target) as Marketplace; - await facet.__Marketplace_init(token); + await facet.__Marketplace_init(token, bidMinPrice, bidMaxPrice); return facet; }; From 5f58b393a9a5c8099deef8f627bb2d2f92fd906a Mon Sep 17 00:00:00 2001 From: Oleksandr Date: Fri, 25 Oct 2024 13:33:31 +0300 Subject: [PATCH 9/9] fix deployment --- smart-contracts/deploy/1_full_protocol.migration.ts | 8 +++++++- smart-contracts/deploy/data/config_arbitrum_sepolia.json | 3 +++ smart-contracts/deploy/helpers/config-parser.ts | 2 ++ smart-contracts/hardhat.config.ts | 6 +++--- 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/smart-contracts/deploy/1_full_protocol.migration.ts b/smart-contracts/deploy/1_full_protocol.migration.ts index 55bb994d..ac3b6c87 100644 --- a/smart-contracts/deploy/1_full_protocol.migration.ts +++ b/smart-contracts/deploy/1_full_protocol.migration.ts @@ -91,7 +91,11 @@ module.exports = async function (deployer: Deployer) { modelRegistryFacet = modelRegistryFacet.attach(lumerinDiamond.target) as ModelRegistry; await modelRegistryFacet.__ModelRegistry_init(); marketplaceFacet = marketplaceFacet.attach(lumerinDiamond.target) as Marketplace; - await marketplaceFacet.__Marketplace_init(config.MOR); + await marketplaceFacet.__Marketplace_init( + config.MOR, + config.marketplaceMinBidPricePerSecond, + config.marketplaceMaxBidPricePerSecond, + ); sessionRouterFacet = sessionRouterFacet.attach(lumerinDiamond.target) as SessionRouter; await sessionRouterFacet.__SessionRouter_init(config.fundingAccount, 7 * DAY, config.pools); @@ -99,6 +103,8 @@ module.exports = async function (deployer: Deployer) { await modelRegistryFacet.modelSetMinStake(config.modelMinStake); await marketplaceFacet.setMarketplaceBidFee(config.marketplaceBidFee); + // TODO: add allowance from the treasury + Reporter.reportContracts( ['Lumerin Diamond', await lumerinDiamond.getAddress()], ['Linear Distribution Interval Decrease Library', await ldid.getAddress()], diff --git a/smart-contracts/deploy/data/config_arbitrum_sepolia.json b/smart-contracts/deploy/data/config_arbitrum_sepolia.json index 5d63728a..2436772e 100644 --- a/smart-contracts/deploy/data/config_arbitrum_sepolia.json +++ b/smart-contracts/deploy/data/config_arbitrum_sepolia.json @@ -4,6 +4,9 @@ "providerMinStake": "200000000000000000", "modelMinStake": "100000000000000000", "marketplaceBidFee": "300000000000000000", + "marketplaceMinBidPricePerSecond": "5000000000000000", + "marketplaceMaxBidPricePerSecond": "20000000000000000000", + "pools": [ { "payoutStart": 1707393600, diff --git a/smart-contracts/deploy/helpers/config-parser.ts b/smart-contracts/deploy/helpers/config-parser.ts index 8c7bc0fb..a17a6f71 100644 --- a/smart-contracts/deploy/helpers/config-parser.ts +++ b/smart-contracts/deploy/helpers/config-parser.ts @@ -9,6 +9,8 @@ export type Config = { providerMinStake: string; modelMinStake: string; marketplaceBidFee: string; + marketplaceMinBidPricePerSecond: string; + marketplaceMaxBidPricePerSecond: string; }; export function parseConfig(): Config { diff --git a/smart-contracts/hardhat.config.ts b/smart-contracts/hardhat.config.ts index 953250c8..abef6980 100644 --- a/smart-contracts/hardhat.config.ts +++ b/smart-contracts/hardhat.config.ts @@ -36,9 +36,9 @@ const config: HardhatUserConfig = { // auto: true, // interval: 10_000, // }, - // forking: { - // url: `https://arbitrum-sepolia.infura.io/v3/${process.env.INFURA_KEY}`, - // }, + forking: { + url: `https://arbitrum-sepolia.infura.io/v3/${process.env.INFURA_KEY}`, + }, }, localhost: { url: 'http://127.0.0.1:8545',