From 1430d56258b4e814b388e497320fd76354bfb478 Mon Sep 17 00:00:00 2001 From: quaq <56312047+0x0aa0@users.noreply.github.com> Date: Thu, 12 Dec 2024 11:10:22 -0600 Subject: [PATCH 1/8] Payments storage packing (#942) --- contracts/script/EigenDADeployer.s.sol | 38 +++++++- contracts/src/core/EigenDAServiceManager.sol | 6 +- .../src/core/EigenDAServiceManagerStorage.sol | 8 +- contracts/src/interfaces/IPaymentVault.sol | 40 ++++---- contracts/src/payments/PaymentVault.sol | 69 ++++++------- .../src/payments/PaymentVaultStorage.sol | 26 ++--- contracts/test/rollup/MockRollup.t.sol | 4 +- contracts/test/unit/EigenDABlobUtils.t.sol | 6 +- .../test/unit/EigenDAServiceManagerUnit.t.sol | 4 +- contracts/test/unit/PaymentVaultUnit.t.sol | 96 ++++++++++--------- 10 files changed, 179 insertions(+), 118 deletions(-) diff --git a/contracts/script/EigenDADeployer.s.sol b/contracts/script/EigenDADeployer.s.sol index 9c63cf4c31..1b24709282 100644 --- a/contracts/script/EigenDADeployer.s.sol +++ b/contracts/script/EigenDADeployer.s.sol @@ -22,9 +22,10 @@ import {IEigenDAThresholdRegistry} from "../src/interfaces/IEigenDAThresholdRegi import {IEigenDABatchMetadataStorage} from "../src/interfaces/IEigenDABatchMetadataStorage.sol"; import {IEigenDASignatureVerifier} from "../src/interfaces/IEigenDASignatureVerifier.sol"; import {IEigenDARelayRegistry} from "../src/interfaces/IEigenDARelayRegistry.sol"; +import {IPaymentVault} from "../src/interfaces/IPaymentVault.sol"; +import {PaymentVault} from "../src/payments/PaymentVault.sol"; import {EigenDARelayRegistry} from "../src/core/EigenDARelayRegistry.sol"; import {ISocketRegistry, SocketRegistry} from "eigenlayer-middleware/SocketRegistry.sol"; - import {DeployOpenEigenLayer, ProxyAdmin, ERC20PresetFixedSupply, TransparentUpgradeableProxy, IPauserRegistry} from "./DeployOpenEigenLayer.s.sol"; import "forge-std/Test.sol"; import "forge-std/Script.sol"; @@ -49,6 +50,7 @@ contract EigenDADeployer is DeployOpenEigenLayer { IStakeRegistry public stakeRegistry; ISocketRegistry public socketRegistry; OperatorStateRetriever public operatorStateRetriever; + IPaymentVault public paymentVault; EigenDARelayRegistry public eigenDARelayRegistry; BLSApkRegistry public apkRegistryImplementation; @@ -59,6 +61,14 @@ contract EigenDADeployer is DeployOpenEigenLayer { EigenDAThresholdRegistry public eigenDAThresholdRegistryImplementation; EigenDARelayRegistry public eigenDARelayRegistryImplementation; ISocketRegistry public socketRegistryImplementation; + IPaymentVault public paymentVaultImplementation; + + uint64 _minNumSymbols = 4096; + uint64 _pricePerSymbol = 0.4470 gwei; + uint64 _priceUpdateCooldown = 1; + uint64 _globalSymbolsPerPeriod = 131072; + uint64 _reservationPeriodInterval = 300; + uint64 _globalRatePeriodInterval = 30; struct AddressConfig { address eigenLayerCommunityMultisig; @@ -135,6 +145,29 @@ contract EigenDADeployer is DeployOpenEigenLayer { address(new TransparentUpgradeableProxy(address(emptyContract), address(eigenDAProxyAdmin), "")) ); + { + paymentVault = IPaymentVault( + address(new TransparentUpgradeableProxy(address(emptyContract), address(eigenDAProxyAdmin), "")) + ); + + paymentVaultImplementation = new PaymentVault(); + + eigenDAProxyAdmin.upgradeAndCall( + TransparentUpgradeableProxy(payable(address(paymentVault))), + address(paymentVaultImplementation), + abi.encodeWithSelector( + PaymentVault.initialize.selector, + addressConfig.eigenDACommunityMultisig, + _minNumSymbols, + _pricePerSymbol, + _priceUpdateCooldown, + _globalSymbolsPerPeriod, + _reservationPeriodInterval, + _globalRatePeriodInterval + ) + ); + } + indexRegistryImplementation = new IndexRegistry( registryCoordinator ); @@ -222,7 +255,8 @@ contract EigenDADeployer is DeployOpenEigenLayer { registryCoordinator, stakeRegistry, eigenDAThresholdRegistry, - eigenDARelayRegistry + eigenDARelayRegistry, + paymentVault ); address[] memory confirmers = new address[](1); diff --git a/contracts/src/core/EigenDAServiceManager.sol b/contracts/src/core/EigenDAServiceManager.sol index fcbfb2895a..029785eb7c 100644 --- a/contracts/src/core/EigenDAServiceManager.sol +++ b/contracts/src/core/EigenDAServiceManager.sol @@ -10,6 +10,7 @@ import {IRegistryCoordinator} from "eigenlayer-middleware/interfaces/IRegistryCo import {IStakeRegistry} from "eigenlayer-middleware/interfaces/IStakeRegistry.sol"; import {IEigenDAThresholdRegistry} from "../interfaces/IEigenDAThresholdRegistry.sol"; import {IEigenDARelayRegistry} from "../interfaces/IEigenDARelayRegistry.sol"; +import {IPaymentVault} from "../interfaces/IPaymentVault.sol"; import {EigenDAServiceManagerStorage} from "./EigenDAServiceManagerStorage.sol"; import {EigenDAHasher} from "../libraries/EigenDAHasher.sol"; import "../interfaces/IEigenDAStructs.sol"; @@ -40,11 +41,12 @@ contract EigenDAServiceManager is EigenDAServiceManagerStorage, ServiceManagerBa IRegistryCoordinator __registryCoordinator, IStakeRegistry __stakeRegistry, IEigenDAThresholdRegistry __eigenDAThresholdRegistry, - IEigenDARelayRegistry __eigenDARelayRegistry + IEigenDARelayRegistry __eigenDARelayRegistry, + IPaymentVault __paymentVault ) BLSSignatureChecker(__registryCoordinator) ServiceManagerBase(__avsDirectory, __rewardsCoordinator, __registryCoordinator, __stakeRegistry) - EigenDAServiceManagerStorage(__eigenDAThresholdRegistry, __eigenDARelayRegistry) + EigenDAServiceManagerStorage(__eigenDAThresholdRegistry, __eigenDARelayRegistry, __paymentVault) { _disableInitializers(); } diff --git a/contracts/src/core/EigenDAServiceManagerStorage.sol b/contracts/src/core/EigenDAServiceManagerStorage.sol index d98cdbbdb5..3a04661d99 100644 --- a/contracts/src/core/EigenDAServiceManagerStorage.sol +++ b/contracts/src/core/EigenDAServiceManagerStorage.sol @@ -4,6 +4,7 @@ pragma solidity ^0.8.9; import {IEigenDAServiceManager} from "../interfaces/IEigenDAServiceManager.sol"; import {IEigenDAThresholdRegistry} from "../interfaces/IEigenDAThresholdRegistry.sol"; import {IEigenDARelayRegistry} from "../interfaces/IEigenDARelayRegistry.sol"; +import {IPaymentVault} from "../interfaces/IPaymentVault.sol"; /** * @title Storage variables for the `EigenDAServiceManager` contract. @@ -39,13 +40,16 @@ abstract contract EigenDAServiceManagerStorage is IEigenDAServiceManager { IEigenDAThresholdRegistry public immutable eigenDAThresholdRegistry; IEigenDARelayRegistry public immutable eigenDARelayRegistry; - + IPaymentVault public immutable paymentVault; + constructor( IEigenDAThresholdRegistry _eigenDAThresholdRegistry, - IEigenDARelayRegistry _eigenDARelayRegistry + IEigenDARelayRegistry _eigenDARelayRegistry, + IPaymentVault _paymentVault ) { eigenDAThresholdRegistry = _eigenDAThresholdRegistry; eigenDARelayRegistry = _eigenDARelayRegistry; + paymentVault = _paymentVault; } /// @notice The current batchId diff --git a/contracts/src/interfaces/IPaymentVault.sol b/contracts/src/interfaces/IPaymentVault.sol index 399fe3f6ce..0fcfc32a93 100644 --- a/contracts/src/interfaces/IPaymentVault.sol +++ b/contracts/src/interfaces/IPaymentVault.sol @@ -7,28 +7,32 @@ interface IPaymentVault { uint64 symbolsPerSecond; // Number of symbols reserved per second uint64 startTimestamp; // timestamp of epoch where reservation begins uint64 endTimestamp; // timestamp of epoch where reservation ends - bytes quorumNumbers; // quorum numbers in an ordered bytes array - bytes quorumSplits; // quorum splits in a bytes array that correspond to the quorum numbers + bytes quorumNumbers; // quorum numbers in an ordered bytes array + bytes quorumSplits; // quorum splits in a bytes array that correspond to the quorum numbers + } + + struct OnDemandPayment { + uint80 totalDeposit; } /// @notice Emitted when a reservation is created or updated event ReservationUpdated(address indexed account, Reservation reservation); /// @notice Emitted when an on-demand payment is created or updated - event OnDemandPaymentUpdated(address indexed account, uint256 onDemandPayment, uint256 totalDeposit); - /// @notice Emitted when globalSymbolsPerBin is updated - event GlobalSymbolsPerBinUpdated(uint256 previousValue, uint256 newValue); - /// @notice Emitted when reservationBinInterval is updated - event ReservationBinIntervalUpdated(uint256 previousValue, uint256 newValue); - /// @notice Emitted when globalRateBinInterval is updated - event GlobalRateBinIntervalUpdated(uint256 previousValue, uint256 newValue); + event OnDemandPaymentUpdated(address indexed account, uint80 onDemandPayment, uint80 totalDeposit); + /// @notice Emitted when globalSymbolsPerPeriod is updated + event GlobalSymbolsPerPeriodUpdated(uint64 previousValue, uint64 newValue); + /// @notice Emitted when reservationPeriodInterval is updated + event ReservationPeriodIntervalUpdated(uint64 previousValue, uint64 newValue); + /// @notice Emitted when globalRatePeriodInterval is updated + event GlobalRatePeriodIntervalUpdated(uint64 previousValue, uint64 newValue); /// @notice Emitted when priceParams are updated event PriceParamsUpdated( - uint256 previousMinNumSymbols, - uint256 newMinNumSymbols, - uint256 previousPricePerSymbol, - uint256 newPricePerSymbol, - uint256 previousPriceUpdateCooldown, - uint256 newPriceUpdateCooldown + uint64 previousMinNumSymbols, + uint64 newMinNumSymbols, + uint64 previousPricePerSymbol, + uint64 newPricePerSymbol, + uint64 previousPriceUpdateCooldown, + uint64 newPriceUpdateCooldown ); /** @@ -54,8 +58,8 @@ interface IPaymentVault { function getReservations(address[] memory _accounts) external view returns (Reservation[] memory _reservations); /// @notice Fetches the current total on demand balance of an account - function getOnDemandAmount(address _account) external view returns (uint256); + function getOnDemandTotalDeposit(address _account) external view returns (uint80); /// @notice Fetches the current total on demand balances for a set of accounts - function getOnDemandAmounts(address[] memory _accounts) external view returns (uint256[] memory _payments); -} \ No newline at end of file + function getOnDemandTotalDeposits(address[] memory _accounts) external view returns (uint80[] memory _payments); +} diff --git a/contracts/src/payments/PaymentVault.sol b/contracts/src/payments/PaymentVault.sol index 23bd0205b6..9dae3cd17c 100644 --- a/contracts/src/payments/PaymentVault.sol +++ b/contracts/src/payments/PaymentVault.sol @@ -9,7 +9,7 @@ import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; * @title Entrypoint for making reservations and on demand payments for EigenDA. * @author Layr Labs, Inc. **/ -contract PaymentVault is PaymentVaultStorage, OwnableUpgradeable { +contract PaymentVault is OwnableUpgradeable, PaymentVaultStorage { constructor() { _disableInitializers(); @@ -25,23 +25,23 @@ contract PaymentVault is PaymentVaultStorage, OwnableUpgradeable { function initialize( address _initialOwner, - uint256 _minNumSymbols, - uint256 _globalSymbolsPerBin, - uint256 _pricePerSymbol, - uint256 _reservationBinInterval, - uint256 _priceUpdateCooldown, - uint256 _globalRateBinInterval + uint64 _minNumSymbols, + uint64 _pricePerSymbol, + uint64 _priceUpdateCooldown, + uint64 _globalSymbolsPerPeriod, + uint64 _reservationPeriodInterval, + uint64 _globalRatePeriodInterval ) public initializer { _transferOwnership(_initialOwner); minNumSymbols = _minNumSymbols; - globalSymbolsPerBin = _globalSymbolsPerBin; pricePerSymbol = _pricePerSymbol; - reservationBinInterval = _reservationBinInterval; priceUpdateCooldown = _priceUpdateCooldown; - globalRateBinInterval = _globalRateBinInterval; + lastPriceUpdateTime = uint64(block.timestamp); - lastPriceUpdateTime = block.timestamp; + globalSymbolsPerPeriod = _globalSymbolsPerPeriod; + reservationPeriodInterval = _reservationPeriodInterval; + globalRatePeriodInterval = _globalRatePeriodInterval; } /** @@ -53,7 +53,7 @@ contract PaymentVault is PaymentVaultStorage, OwnableUpgradeable { address _account, Reservation memory _reservation ) external onlyOwner { - _checkQuorumSplit(_reservation.quorumNumbers, _reservation.quorumSplits); + _checkQuorumSplit(_reservation.quorumNumbers, _reservation.quorumSplits); require(_reservation.endTimestamp > _reservation.startTimestamp, "end timestamp must be greater than start timestamp"); reservations[_account] = _reservation; emit ReservationUpdated(_account, _reservation); @@ -64,39 +64,41 @@ contract PaymentVault is PaymentVaultStorage, OwnableUpgradeable { * @param _account is the address to deposit the funds for */ function depositOnDemand(address _account) external payable { - _deposit(_account, msg.value); + _deposit(_account, msg.value); } function setPriceParams( - uint256 _minNumSymbols, - uint256 _pricePerSymbol, - uint256 _priceUpdateCooldown + uint64 _minNumSymbols, + uint64 _pricePerSymbol, + uint64 _priceUpdateCooldown ) external onlyOwner { require(block.timestamp >= lastPriceUpdateTime + priceUpdateCooldown, "price update cooldown not surpassed"); + emit PriceParamsUpdated( minNumSymbols, _minNumSymbols, pricePerSymbol, _pricePerSymbol, priceUpdateCooldown, _priceUpdateCooldown ); + pricePerSymbol = _pricePerSymbol; minNumSymbols = _minNumSymbols; priceUpdateCooldown = _priceUpdateCooldown; - lastPriceUpdateTime = block.timestamp; + lastPriceUpdateTime = uint64(block.timestamp); } - function setGlobalSymbolsPerBin(uint256 _globalSymbolsPerBin) external onlyOwner { - emit GlobalSymbolsPerBinUpdated(globalSymbolsPerBin, _globalSymbolsPerBin); - globalSymbolsPerBin = _globalSymbolsPerBin; + function setGlobalSymbolsPerPeriod(uint64 _globalSymbolsPerPeriod) external onlyOwner { + emit GlobalSymbolsPerPeriodUpdated(globalSymbolsPerPeriod, _globalSymbolsPerPeriod); + globalSymbolsPerPeriod = _globalSymbolsPerPeriod; } - function setReservationBinInterval(uint256 _reservationBinInterval) external onlyOwner { - emit ReservationBinIntervalUpdated(reservationBinInterval, _reservationBinInterval); - reservationBinInterval = _reservationBinInterval; + function setReservationPeriodInterval(uint64 _reservationPeriodInterval) external onlyOwner { + emit ReservationPeriodIntervalUpdated(reservationPeriodInterval, _reservationPeriodInterval); + reservationPeriodInterval = _reservationPeriodInterval; } - function setGlobalRateBinInterval(uint256 _globalRateBinInterval) external onlyOwner { - emit GlobalRateBinIntervalUpdated(globalRateBinInterval, _globalRateBinInterval); - globalRateBinInterval = _globalRateBinInterval; + function setGlobalRatePeriodInterval(uint64 _globalRatePeriodInterval) external onlyOwner { + emit GlobalRatePeriodIntervalUpdated(globalRatePeriodInterval, _globalRatePeriodInterval); + globalRatePeriodInterval = _globalRatePeriodInterval; } function withdraw(uint256 _amount) external onlyOwner { @@ -116,8 +118,9 @@ contract PaymentVault is PaymentVaultStorage, OwnableUpgradeable { } function _deposit(address _account, uint256 _amount) internal { - onDemandPayments[_account] += _amount; - emit OnDemandPaymentUpdated(_account, _amount, onDemandPayments[_account]); + require(_amount <= type(uint80).max, "amount must be less than or equal to 80 bits"); + onDemandPayments[_account].totalDeposit += uint80(_amount); + emit OnDemandPaymentUpdated(_account, uint80(_amount), onDemandPayments[_account].totalDeposit); } /// @notice Fetches the current reservation for an account @@ -134,15 +137,15 @@ contract PaymentVault is PaymentVaultStorage, OwnableUpgradeable { } /// @notice Fetches the current total on demand balance of an account - function getOnDemandAmount(address _account) external view returns (uint256) { - return onDemandPayments[_account]; + function getOnDemandTotalDeposit(address _account) external view returns (uint80) { + return onDemandPayments[_account].totalDeposit; } /// @notice Fetches the current total on demand balances for a set of accounts - function getOnDemandAmounts(address[] memory _accounts) external view returns (uint256[] memory _payments) { - _payments = new uint256[](_accounts.length); + function getOnDemandTotalDeposits(address[] memory _accounts) external view returns (uint80[] memory _payments) { + _payments = new uint80[](_accounts.length); for(uint256 i; i < _accounts.length; ++i){ - _payments[i] = onDemandPayments[_accounts[i]]; + _payments[i] = onDemandPayments[_accounts[i]].totalDeposit; } } } \ No newline at end of file diff --git a/contracts/src/payments/PaymentVaultStorage.sol b/contracts/src/payments/PaymentVaultStorage.sol index 4c364c3c13..c8aae718df 100644 --- a/contracts/src/payments/PaymentVaultStorage.sol +++ b/contracts/src/payments/PaymentVaultStorage.sol @@ -6,25 +6,25 @@ import {IPaymentVault} from "../interfaces/IPaymentVault.sol"; abstract contract PaymentVaultStorage is IPaymentVault { /// @notice minimum chargeable size for on-demand payments - uint256 public minNumSymbols; + uint64 public minNumSymbols; /// @notice price per symbol in wei - uint256 public pricePerSymbol; + uint64 public pricePerSymbol; /// @notice cooldown period before the price can be updated again - uint256 public priceUpdateCooldown; - /// @notice maximum number of symbols to disperse per second network-wide for on-demand payments (applied to only ETH and EIGEN) - uint256 public globalSymbolsPerBin; - /// @notice reservation bin duration - uint256 public reservationBinInterval; - /// @notice global rate bin size - uint256 public globalRateBinInterval; - + uint64 public priceUpdateCooldown; /// @notice timestamp of the last price update - uint256 public lastPriceUpdateTime; + uint64 public lastPriceUpdateTime; + + /// @notice maximum number of symbols to disperse per second network-wide for on-demand payments (applied to only ETH and EIGEN) + uint64 public globalSymbolsPerPeriod; + /// @notice reservation period interval + uint64 public reservationPeriodInterval; + /// @notice global rate period interval + uint64 public globalRatePeriodInterval; /// @notice mapping from user address to current reservation mapping(address => Reservation) public reservations; /// @notice mapping from user address to current on-demand payment - mapping(address => uint256) public onDemandPayments; + mapping(address => OnDemandPayment) public onDemandPayments; - uint256[42] private __GAP; + uint256[46] private __GAP; } \ No newline at end of file diff --git a/contracts/test/rollup/MockRollup.t.sol b/contracts/test/rollup/MockRollup.t.sol index bb676e1a1b..aa37291b2e 100644 --- a/contracts/test/rollup/MockRollup.t.sol +++ b/contracts/test/rollup/MockRollup.t.sol @@ -17,6 +17,7 @@ import {IEigenDABatchMetadataStorage} from "../../src/interfaces/IEigenDABatchMe import {IEigenDASignatureVerifier} from "../../src/interfaces/IEigenDASignatureVerifier.sol"; import {OperatorStateRetriever} from "../../lib/eigenlayer-middleware/src/OperatorStateRetriever.sol"; import {IEigenDARelayRegistry} from "../../src/interfaces/IEigenDARelayRegistry.sol"; +import {IPaymentVault} from "../../src/interfaces/IPaymentVault.sol"; import {EigenDARelayRegistry} from "../../src/core/EigenDARelayRegistry.sol"; import {IRegistryCoordinator} from "../../lib/eigenlayer-middleware/src/interfaces/IRegistryCoordinator.sol"; import "../../src/interfaces/IEigenDAStructs.sol"; @@ -94,7 +95,8 @@ contract MockRollupTest is BLSMockAVSDeployer { registryCoordinator, stakeRegistry, eigenDAThresholdRegistry, - eigenDARelayRegistry + eigenDARelayRegistry, + IPaymentVault(address(0)) ); eigenDAThresholdRegistryImplementation = new EigenDAThresholdRegistry(); diff --git a/contracts/test/unit/EigenDABlobUtils.t.sol b/contracts/test/unit/EigenDABlobUtils.t.sol index 3eca831bc4..5e5d326c69 100644 --- a/contracts/test/unit/EigenDABlobUtils.t.sol +++ b/contracts/test/unit/EigenDABlobUtils.t.sol @@ -17,6 +17,9 @@ import {IEigenDASignatureVerifier} from "../../src/interfaces/IEigenDASignatureV import {IRegistryCoordinator} from "../../lib/eigenlayer-middleware/src/interfaces/IRegistryCoordinator.sol"; import {IEigenDARelayRegistry} from "../../src/interfaces/IEigenDARelayRegistry.sol"; import {EigenDARelayRegistry} from "../../src/core/EigenDARelayRegistry.sol"; +import {IPaymentVault} from "../../src/interfaces/IPaymentVault.sol"; +import {PaymentVault} from "../../src/payments/PaymentVault.sol"; + import "../../src/interfaces/IEigenDAStructs.sol"; import "forge-std/StdStorage.sol"; @@ -78,7 +81,8 @@ contract EigenDABlobUtilsUnit is BLSMockAVSDeployer { registryCoordinator, stakeRegistry, eigenDAThresholdRegistry, - eigenDARelayRegistry + eigenDARelayRegistry, + IPaymentVault(address(0)) ); eigenDAThresholdRegistryImplementation = new EigenDAThresholdRegistry(); diff --git a/contracts/test/unit/EigenDAServiceManagerUnit.t.sol b/contracts/test/unit/EigenDAServiceManagerUnit.t.sol index 669edf49f6..f550889bcb 100644 --- a/contracts/test/unit/EigenDAServiceManagerUnit.t.sol +++ b/contracts/test/unit/EigenDAServiceManagerUnit.t.sol @@ -14,6 +14,7 @@ import {IEigenDABatchMetadataStorage} from "../../src/interfaces/IEigenDABatchMe import {IEigenDASignatureVerifier} from "../../src/interfaces/IEigenDASignatureVerifier.sol"; import {IRegistryCoordinator} from "../../lib/eigenlayer-middleware/src/interfaces/IRegistryCoordinator.sol"; import {IEigenDARelayRegistry} from "../../src/interfaces/IEigenDARelayRegistry.sol"; +import {IPaymentVault} from "../../src/interfaces/IPaymentVault.sol"; import {EigenDARelayRegistry} from "../../src/core/EigenDARelayRegistry.sol"; import "../../src/interfaces/IEigenDAStructs.sol"; @@ -75,7 +76,8 @@ contract EigenDAServiceManagerUnit is BLSMockAVSDeployer { registryCoordinator, stakeRegistry, eigenDAThresholdRegistry, - eigenDARelayRegistry + eigenDARelayRegistry, + IPaymentVault(address(0)) ); address[] memory confirmers = new address[](1); diff --git a/contracts/test/unit/PaymentVaultUnit.t.sol b/contracts/test/unit/PaymentVaultUnit.t.sol index 7d0f386358..740c2e7a55 100644 --- a/contracts/test/unit/PaymentVaultUnit.t.sol +++ b/contracts/test/unit/PaymentVaultUnit.t.sol @@ -12,17 +12,17 @@ contract PaymentVaultUnit is Test { using stdStorage for StdStorage; event ReservationUpdated(address indexed account, IPaymentVault.Reservation reservation); - event OnDemandPaymentUpdated(address indexed account, uint256 onDemandPayment, uint256 totalDeposit); - event GlobalSymbolsPerBinUpdated(uint256 previousValue, uint256 newValue); - event ReservationBinIntervalUpdated(uint256 previousValue, uint256 newValue); - event GlobalRateBinIntervalUpdated(uint256 previousValue, uint256 newValue); + event OnDemandPaymentUpdated(address indexed account, uint80 onDemandPayment, uint80 totalDeposit); + event GlobalSymbolsPerPeriodUpdated(uint64 previousValue, uint64 newValue); + event ReservationPeriodIntervalUpdated(uint64 previousValue, uint64 newValue); + event GlobalRatePeriodIntervalUpdated(uint64 previousValue, uint64 newValue); event PriceParamsUpdated( - uint256 previousMinNumSymbols, - uint256 newMinNumSymbols, - uint256 previousPricePerSymbol, - uint256 newPricePerSymbol, - uint256 previousPriceUpdateCooldown, - uint256 newPriceUpdateCooldown + uint64 previousMinNumSymbols, + uint64 newMinNumSymbols, + uint64 previousPricePerSymbol, + uint64 newPricePerSymbol, + uint64 previousPriceUpdateCooldown, + uint64 newPriceUpdateCooldown ); PaymentVault paymentVault; @@ -34,13 +34,12 @@ contract PaymentVaultUnit is Test { address user = address(uint160(uint256(keccak256(abi.encodePacked("user"))))); address user2 = address(uint160(uint256(keccak256(abi.encodePacked("user2"))))); - uint256 minNumSymbols = 1; - uint256 globalSymbolsPerBin = 2; - uint256 pricePerSymbol = 3; - uint256 reservationBinInterval = 4; - uint256 globalRateBinInterval = 5; - - uint256 priceUpdateCooldown = 6 days; + uint64 minNumSymbols = 1; + uint64 globalSymbolsPerPeriod = 2; + uint64 pricePerSymbol = 3; + uint64 reservationPeriodInterval = 4; + uint64 globalRatePeriodInterval = 5; + uint64 priceUpdateCooldown = 6 days; bytes quorumNumbers = hex"0001"; bytes quorumSplits = hex"3232"; @@ -58,11 +57,11 @@ contract PaymentVaultUnit is Test { PaymentVault.initialize.selector, initialOwner, minNumSymbols, - globalSymbolsPerBin, pricePerSymbol, - reservationBinInterval, priceUpdateCooldown, - globalRateBinInterval + globalSymbolsPerPeriod, + reservationPeriodInterval, + globalRatePeriodInterval ) ) ) @@ -75,11 +74,11 @@ contract PaymentVaultUnit is Test { function test_initialize() public { require(paymentVault.owner() == initialOwner, "Owner is not set"); assertEq(paymentVault.minNumSymbols(), minNumSymbols); - assertEq(paymentVault.globalSymbolsPerBin(), globalSymbolsPerBin); + assertEq(paymentVault.globalSymbolsPerPeriod(), globalSymbolsPerPeriod); assertEq(paymentVault.pricePerSymbol(), pricePerSymbol); - assertEq(paymentVault.reservationBinInterval(), reservationBinInterval); + assertEq(paymentVault.reservationPeriodInterval(), reservationPeriodInterval); assertEq(paymentVault.priceUpdateCooldown(), priceUpdateCooldown); - assertEq(paymentVault.globalRateBinInterval(), globalRateBinInterval); + assertEq(paymentVault.globalRatePeriodInterval(), globalRatePeriodInterval); vm.expectRevert("Initializable: contract is already initialized"); paymentVault.initialize(address(0), 0, 0, 0, 0, 0, 0); @@ -167,13 +166,13 @@ contract PaymentVaultUnit is Test { emit OnDemandPaymentUpdated(user, 100 ether, 100 ether); vm.prank(user); paymentVault.depositOnDemand{value: 100 ether}(user); - assertEq(paymentVault.onDemandPayments(user), 100 ether); + assertEq(paymentVault.getOnDemandTotalDeposit(user), 100 ether); vm.expectEmit(address(paymentVault)); emit OnDemandPaymentUpdated(user, 100 ether, 200 ether); vm.prank(user); paymentVault.depositOnDemand{value: 100 ether}(user); - assertEq(paymentVault.onDemandPayments(user), 200 ether); + assertEq(paymentVault.getOnDemandTotalDeposit(user), 200 ether); } function test_depositOnDemand_forOtherUser() public { @@ -184,8 +183,8 @@ contract PaymentVaultUnit is Test { emit OnDemandPaymentUpdated(user2, 100 ether, 100 ether); vm.prank(user); paymentVault.depositOnDemand{value: 100 ether}(user2); - assertEq(paymentVault.onDemandPayments(user2), 100 ether); - assertEq(paymentVault.onDemandPayments(user), 0); + assertEq(paymentVault.getOnDemandTotalDeposit(user2), 100 ether); + assertEq(paymentVault.getOnDemandTotalDeposit(user), 0); } function test_depositOnDemand_fallback() public { @@ -195,7 +194,7 @@ contract PaymentVaultUnit is Test { emit OnDemandPaymentUpdated(user, 100 ether, 100 ether); vm.prank(user); payable(paymentVault).call{value: 100 ether}(hex"69"); - assertEq(paymentVault.onDemandPayments(user), 100 ether); + assertEq(paymentVault.getOnDemandTotalDeposit(user), 100 ether); } function test_depositOnDemand_recieve() public { @@ -205,7 +204,14 @@ contract PaymentVaultUnit is Test { emit OnDemandPaymentUpdated(user, 100 ether, 100 ether); vm.prank(user); payable(paymentVault).call{value: 100 ether}(""); - assertEq(paymentVault.onDemandPayments(user), 100 ether); + assertEq(paymentVault.getOnDemandTotalDeposit(user), 100 ether); + } + + function test_depositOnDemand_revertUint80Overflow() public { + vm.deal(user, uint256(type(uint80).max) + 1); + vm.expectRevert("amount must be less than or equal to 80 bits"); + vm.prank(user); + paymentVault.depositOnDemand{value: uint256(type(uint80).max) + 1}(user); } function test_setPriceParams() public { @@ -230,28 +236,28 @@ contract PaymentVaultUnit is Test { paymentVault.setPriceParams(minNumSymbols + 1, pricePerSymbol + 1, priceUpdateCooldown + 1); } - function test_setGlobalRateBinInterval() public { + function test_setGlobalRatePeriodInterval() public { vm.expectEmit(address(paymentVault)); - emit GlobalRateBinIntervalUpdated(globalRateBinInterval, globalRateBinInterval + 1); + emit GlobalRatePeriodIntervalUpdated(globalRatePeriodInterval, globalRatePeriodInterval + 1); vm.prank(initialOwner); - paymentVault.setGlobalRateBinInterval(globalRateBinInterval + 1); - assertEq(paymentVault.globalRateBinInterval(), globalRateBinInterval + 1); + paymentVault.setGlobalRatePeriodInterval(globalRatePeriodInterval + 1); + assertEq(paymentVault.globalRatePeriodInterval(), globalRatePeriodInterval + 1); } - function test_setGlobalSymbolsPerBin() public { + function test_setGlobalSymbolsPerPeriod() public { vm.expectEmit(address(paymentVault)); - emit GlobalSymbolsPerBinUpdated(globalSymbolsPerBin, globalSymbolsPerBin + 1); + emit GlobalSymbolsPerPeriodUpdated(globalSymbolsPerPeriod, globalSymbolsPerPeriod + 1); vm.prank(initialOwner); - paymentVault.setGlobalSymbolsPerBin(globalSymbolsPerBin + 1); - assertEq(paymentVault.globalSymbolsPerBin(), globalSymbolsPerBin + 1); + paymentVault.setGlobalSymbolsPerPeriod(globalSymbolsPerPeriod + 1); + assertEq(paymentVault.globalSymbolsPerPeriod(), globalSymbolsPerPeriod + 1); } - function test_setReservationBinInterval() public { + function test_setReservationPeriodInterval() public { vm.expectEmit(address(paymentVault)); - emit ReservationBinIntervalUpdated(reservationBinInterval, reservationBinInterval + 1); + emit ReservationPeriodIntervalUpdated(reservationPeriodInterval, reservationPeriodInterval + 1); vm.prank(initialOwner); - paymentVault.setReservationBinInterval(reservationBinInterval + 1); - assertEq(paymentVault.reservationBinInterval(), reservationBinInterval + 1); + paymentVault.setReservationPeriodInterval(reservationPeriodInterval + 1); + assertEq(paymentVault.reservationPeriodInterval(), reservationPeriodInterval + 1); } function test_withdraw() public { @@ -286,11 +292,11 @@ contract PaymentVaultUnit is Test { vm.expectRevert("Ownable: caller is not the owner"); paymentVault.setPriceParams(minNumSymbols + 1, pricePerSymbol + 1, priceUpdateCooldown + 1); vm.expectRevert("Ownable: caller is not the owner"); - paymentVault.setGlobalRateBinInterval(globalRateBinInterval + 1); + paymentVault.setGlobalRatePeriodInterval(globalRatePeriodInterval + 1); vm.expectRevert("Ownable: caller is not the owner"); - paymentVault.setGlobalSymbolsPerBin(globalSymbolsPerBin + 1); + paymentVault.setGlobalSymbolsPerPeriod(globalSymbolsPerPeriod + 1); vm.expectRevert("Ownable: caller is not the owner"); - paymentVault.setReservationBinInterval(reservationBinInterval + 1); + paymentVault.setReservationPeriodInterval(reservationPeriodInterval + 1); } function test_getReservations() public { @@ -335,7 +341,7 @@ contract PaymentVaultUnit is Test { accounts[0] = user; accounts[1] = user2; - uint256[] memory payments = paymentVault.getOnDemandAmounts(accounts); + uint80[] memory payments = paymentVault.getOnDemandTotalDeposits(accounts); assertEq(payments[0], 100 ether); assertEq(payments[1], 200 ether); } From be47a6c79d8d9060af8751a245203bb88475a050 Mon Sep 17 00:00:00 2001 From: hopeyen <60078528+hopeyen@users.noreply.github.com> Date: Fri, 13 Dec 2024 00:17:21 +0700 Subject: [PATCH 2/8] fix: cumulative payment dynamo db unit conversion (#979) --- core/meterer/meterer.go | 11 +++-- core/meterer/meterer_test.go | 52 +++++++++++----------- core/meterer/offchain_store.go | 81 +++++++++++++++++++--------------- 3 files changed, 78 insertions(+), 66 deletions(-) diff --git a/core/meterer/meterer.go b/core/meterer/meterer.go index 7195972c49..a27e7853e0 100644 --- a/core/meterer/meterer.go +++ b/core/meterer/meterer.go @@ -3,6 +3,7 @@ package meterer import ( "context" "fmt" + "math/big" "slices" "time" @@ -230,11 +231,11 @@ func (m *Meterer) ValidatePayment(ctx context.Context, header core.PaymentMetada return fmt.Errorf("failed to get relevant on-demand records: %w", err) } // the current request must increment cumulative payment by a magnitude sufficient to cover the blob size - if prevPmt+m.PaymentCharged(numSymbols) > header.CumulativePayment.Uint64() { + if prevPmt.Add(prevPmt, m.PaymentCharged(numSymbols)).Cmp(header.CumulativePayment) > 0 { return fmt.Errorf("insufficient cumulative payment increment") } // the current request must not break the payment magnitude for the next payment if the two requests were delivered out-of-order - if nextPmt != 0 && header.CumulativePayment.Uint64()+m.PaymentCharged(uint(nextPmtnumSymbols)) > nextPmt { + if nextPmt.Cmp(big.NewInt(0)) != 0 && header.CumulativePayment.Add(header.CumulativePayment, m.PaymentCharged(uint(nextPmtnumSymbols))).Cmp(nextPmt) > 0 { return fmt.Errorf("breaking cumulative payment invariants") } // check passed: blob can be safely inserted into the set of payments @@ -242,8 +243,10 @@ func (m *Meterer) ValidatePayment(ctx context.Context, header core.PaymentMetada } // PaymentCharged returns the chargeable price for a given data length -func (m *Meterer) PaymentCharged(numSymbols uint) uint64 { - return uint64(m.SymbolsCharged(numSymbols)) * uint64(m.ChainPaymentState.GetPricePerSymbol()) +func (m *Meterer) PaymentCharged(numSymbols uint) *big.Int { + symbolsCharged := big.NewInt(int64(m.SymbolsCharged(numSymbols))) + pricePerSymbol := big.NewInt(int64(m.ChainPaymentState.GetPricePerSymbol())) + return symbolsCharged.Mul(symbolsCharged, pricePerSymbol) } // SymbolsCharged returns the number of symbols charged for a given data length diff --git a/core/meterer/meterer_test.go b/core/meterer/meterer_test.go index 6e3f980db3..38132596b9 100644 --- a/core/meterer/meterer_test.go +++ b/core/meterer/meterer_test.go @@ -186,16 +186,16 @@ func TestMetererReservations(t *testing.T) { paymentChainState.On("GetActiveReservationByAccount", testifymock.Anything, testifymock.Anything).Return(&core.ActiveReservation{}, fmt.Errorf("reservation not found")) // test invalid quorom ID - header := createPaymentHeader(1, 0, accountID1) + header := createPaymentHeader(1, big.NewInt(0), accountID1) err := mt.MeterRequest(ctx, *header, 1000, []uint8{0, 1, 2}) assert.ErrorContains(t, err, "quorum number mismatch") // overwhelming bin overflow for empty bins - header = createPaymentHeader(reservationPeriod-1, 0, accountID2) + header = createPaymentHeader(reservationPeriod-1, big.NewInt(0), accountID2) err = mt.MeterRequest(ctx, *header, 10, quoromNumbers) assert.NoError(t, err) // overwhelming bin overflow for empty bins - header = createPaymentHeader(reservationPeriod-1, 0, accountID2) + header = createPaymentHeader(reservationPeriod-1, big.NewInt(0), accountID2) err = mt.MeterRequest(ctx, *header, 1000, quoromNumbers) assert.ErrorContains(t, err, "overflow usage exceeds bin limit") @@ -204,13 +204,13 @@ func TestMetererReservations(t *testing.T) { if err != nil { t.Fatalf("Failed to generate key: %v", err) } - header = createPaymentHeader(1, 0, crypto.PubkeyToAddress(unregisteredUser.PublicKey)) + header = createPaymentHeader(1, big.NewInt(0), crypto.PubkeyToAddress(unregisteredUser.PublicKey)) assert.NoError(t, err) err = mt.MeterRequest(ctx, *header, 1000, []uint8{0, 1, 2}) assert.ErrorContains(t, err, "failed to get active reservation by account: reservation not found") // test invalid bin index - header = createPaymentHeader(reservationPeriod, 0, accountID1) + header = createPaymentHeader(reservationPeriod, big.NewInt(0), accountID1) err = mt.MeterRequest(ctx, *header, 2000, quoromNumbers) assert.ErrorContains(t, err, "invalid bin index for reservation") @@ -218,7 +218,7 @@ func TestMetererReservations(t *testing.T) { symbolLength := uint(20) requiredLength := uint(21) // 21 should be charged for length of 20 since minNumSymbols is 3 for i := 0; i < 9; i++ { - header = createPaymentHeader(reservationPeriod, 0, accountID2) + header = createPaymentHeader(reservationPeriod, big.NewInt(0), accountID2) err = mt.MeterRequest(ctx, *header, symbolLength, quoromNumbers) assert.NoError(t, err) item, err := dynamoClient.GetItem(ctx, reservationTableName, commondynamodb.Key{ @@ -232,7 +232,7 @@ func TestMetererReservations(t *testing.T) { } // first over flow is allowed - header = createPaymentHeader(reservationPeriod, 0, accountID2) + header = createPaymentHeader(reservationPeriod, big.NewInt(0), accountID2) assert.NoError(t, err) err = mt.MeterRequest(ctx, *header, 25, quoromNumbers) assert.NoError(t, err) @@ -248,7 +248,7 @@ func TestMetererReservations(t *testing.T) { assert.Equal(t, strconv.Itoa(int(16)), item["BinUsage"].(*types.AttributeValueMemberN).Value) // second over flow - header = createPaymentHeader(reservationPeriod, 0, accountID2) + header = createPaymentHeader(reservationPeriod, big.NewInt(0), accountID2) assert.NoError(t, err) err = mt.MeterRequest(ctx, *header, 1, quoromNumbers) assert.ErrorContains(t, err, "bin has already been filled") @@ -275,18 +275,18 @@ func TestMetererOnDemand(t *testing.T) { if err != nil { t.Fatalf("Failed to generate key: %v", err) } - header := createPaymentHeader(reservationPeriod, 2, crypto.PubkeyToAddress(unregisteredUser.PublicKey)) + header := createPaymentHeader(reservationPeriod, big.NewInt(2), crypto.PubkeyToAddress(unregisteredUser.PublicKey)) assert.NoError(t, err) err = mt.MeterRequest(ctx, *header, 1000, quorumNumbers) assert.ErrorContains(t, err, "failed to get on-demand payment by account: payment not found") // test invalid quorom ID - header = createPaymentHeader(reservationPeriod, 1, accountID1) + header = createPaymentHeader(reservationPeriod, big.NewInt(2), accountID1) err = mt.MeterRequest(ctx, *header, 1000, []uint8{0, 1, 2}) assert.ErrorContains(t, err, "invalid quorum for On-Demand Request") // test insufficient cumulative payment - header = createPaymentHeader(reservationPeriod, 1, accountID1) + header = createPaymentHeader(reservationPeriod, big.NewInt(1), accountID1) err = mt.MeterRequest(ctx, *header, 1000, quorumNumbers) assert.ErrorContains(t, err, "insufficient cumulative payment increment") // No rollback after meter request @@ -300,7 +300,7 @@ func TestMetererOnDemand(t *testing.T) { // test duplicated cumulative payments symbolLength := uint(100) priceCharged := mt.PaymentCharged(symbolLength) - assert.Equal(t, uint64(102*mt.ChainPaymentState.GetPricePerSymbol()), priceCharged) + assert.Equal(t, big.NewInt(int64(102*mt.ChainPaymentState.GetPricePerSymbol())), priceCharged) header = createPaymentHeader(reservationPeriod, priceCharged, accountID2) err = mt.MeterRequest(ctx, *header, symbolLength, quorumNumbers) assert.NoError(t, err) @@ -310,24 +310,24 @@ func TestMetererOnDemand(t *testing.T) { // test valid payments for i := 1; i < 9; i++ { - header = createPaymentHeader(reservationPeriod, uint64(priceCharged)*uint64(i+1), accountID2) + header = createPaymentHeader(reservationPeriod, new(big.Int).Mul(priceCharged, big.NewInt(int64(i+1))), accountID2) err = mt.MeterRequest(ctx, *header, symbolLength, quorumNumbers) assert.NoError(t, err) } // test cumulative payment on-chain constraint - header = createPaymentHeader(reservationPeriod, 2023, accountID2) + header = createPaymentHeader(reservationPeriod, big.NewInt(2023), accountID2) err = mt.MeterRequest(ctx, *header, 1, quorumNumbers) assert.ErrorContains(t, err, "invalid on-demand payment: request claims a cumulative payment greater than the on-chain deposit") // test insufficient increment in cumulative payment - previousCumulativePayment := uint64(priceCharged) * uint64(9) + previousCumulativePayment := priceCharged.Mul(priceCharged, big.NewInt(9)) symbolLength = uint(2) priceCharged = mt.PaymentCharged(symbolLength) - header = createPaymentHeader(reservationPeriod, previousCumulativePayment+priceCharged-1, accountID2) + header = createPaymentHeader(reservationPeriod, big.NewInt(0).Add(previousCumulativePayment, big.NewInt(0).Sub(priceCharged, big.NewInt(1))), accountID2) err = mt.MeterRequest(ctx, *header, symbolLength, quorumNumbers) assert.ErrorContains(t, err, "invalid on-demand payment: insufficient cumulative payment increment") - previousCumulativePayment = previousCumulativePayment + priceCharged + previousCumulativePayment = big.NewInt(0).Add(previousCumulativePayment, priceCharged) // test cannot insert cumulative payment in out of order header = createPaymentHeader(reservationPeriod, mt.PaymentCharged(50), accountID2) @@ -342,7 +342,7 @@ func TestMetererOnDemand(t *testing.T) { assert.NoError(t, err) assert.Equal(t, numPrevRecords, len(result)) // test failed global rate limit (previously payment recorded: 2, global limit: 1009) - header = createPaymentHeader(reservationPeriod, previousCumulativePayment+mt.PaymentCharged(1010), accountID1) + header = createPaymentHeader(reservationPeriod, big.NewInt(0).Add(previousCumulativePayment, mt.PaymentCharged(1010)), accountID1) err = mt.MeterRequest(ctx, *header, 1010, quorumNumbers) assert.ErrorContains(t, err, "failed global rate limiting") // Correct rollback @@ -360,42 +360,42 @@ func TestMeterer_paymentCharged(t *testing.T) { symbolLength uint pricePerSymbol uint32 minNumSymbols uint32 - expected uint64 + expected *big.Int }{ { name: "Data length equal to min chargeable size", symbolLength: 1024, pricePerSymbol: 1, minNumSymbols: 1024, - expected: 1024, + expected: big.NewInt(1024), }, { name: "Data length less than min chargeable size", symbolLength: 512, pricePerSymbol: 1, minNumSymbols: 1024, - expected: 1024, + expected: big.NewInt(1024), }, { name: "Data length greater than min chargeable size", symbolLength: 2048, pricePerSymbol: 1, minNumSymbols: 1024, - expected: 2048, + expected: big.NewInt(2048), }, { name: "Large data length", symbolLength: 1 << 20, // 1 MB pricePerSymbol: 1, minNumSymbols: 1024, - expected: 1 << 20, + expected: big.NewInt(1 << 20), }, { name: "Price not evenly divisible by min chargeable size", symbolLength: 1536, pricePerSymbol: 1, minNumSymbols: 1024, - expected: 2048, + expected: big.NewInt(2048), }, } @@ -465,10 +465,10 @@ func TestMeterer_symbolsCharged(t *testing.T) { } } -func createPaymentHeader(reservationPeriod uint32, cumulativePayment uint64, accountID gethcommon.Address) *core.PaymentMetadata { +func createPaymentHeader(reservationPeriod uint32, cumulativePayment *big.Int, accountID gethcommon.Address) *core.PaymentMetadata { return &core.PaymentMetadata{ AccountID: accountID.Hex(), ReservationPeriod: reservationPeriod, - CumulativePayment: big.NewInt(int64(cumulativePayment)), + CumulativePayment: cumulativePayment, } } diff --git a/core/meterer/offchain_store.go b/core/meterer/offchain_store.go index f80ddd8910..3c3116f1b7 100644 --- a/core/meterer/offchain_store.go +++ b/core/meterer/offchain_store.go @@ -6,7 +6,6 @@ import ( "fmt" "math/big" "strconv" - "time" pb "github.com/Layr-Labs/eigenda/api/grpc/disperser/v2" commonaws "github.com/Layr-Labs/eigenda/common/aws" @@ -65,24 +64,6 @@ func NewOffchainStore( }, nil } -type ReservationBin struct { - AccountID string - ReservationPeriod uint32 - BinUsage uint32 - UpdatedAt time.Time -} - -type PaymentTuple struct { - CumulativePayment uint64 - DataLength uint32 -} - -type GlobalBin struct { - ReservationPeriod uint32 - BinUsage uint64 - UpdatedAt time.Time -} - func (s *OffchainStore) UpdateReservationBin(ctx context.Context, accountID string, reservationPeriod uint64, size uint64) (uint64, error) { key := map[string]types.AttributeValue{ "AccountID": &types.AttributeValueMemberS{Value: accountID}, @@ -185,7 +166,7 @@ func (s *OffchainStore) RemoveOnDemandPayment(ctx context.Context, accountID str // GetRelevantOnDemandRecords gets previous cumulative payment, next cumulative payment, blob size of next payment // The queries are done sequentially instead of one-go for efficient querying and would not cause race condition errors for honest requests -func (s *OffchainStore) GetRelevantOnDemandRecords(ctx context.Context, accountID string, cumulativePayment *big.Int) (uint64, uint64, uint32, error) { +func (s *OffchainStore) GetRelevantOnDemandRecords(ctx context.Context, accountID string, cumulativePayment *big.Int) (*big.Int, *big.Int, uint32, error) { // Fetch the largest entry smaller than the given cumulativePayment queryInput := &dynamodb.QueryInput{ TableName: aws.String(s.onDemandTableName), @@ -199,14 +180,23 @@ func (s *OffchainStore) GetRelevantOnDemandRecords(ctx context.Context, accountI } smallerResult, err := s.dynamoClient.QueryWithInput(ctx, queryInput) if err != nil { - return 0, 0, 0, fmt.Errorf("failed to query smaller payments for account: %w", err) + return nil, nil, 0, fmt.Errorf("failed to query smaller payments for account: %w", err) } - var prevPayment uint64 + prevPayment := big.NewInt(0) if len(smallerResult) > 0 { - prevPayment, err = strconv.ParseUint(smallerResult[0]["CumulativePayments"].(*types.AttributeValueMemberN).Value, 10, 64) - if err != nil { - return 0, 0, 0, fmt.Errorf("failed to parse previous payment: %w", err) + cumulativePaymentsAttr, ok := smallerResult[0]["CumulativePayments"] + if !ok { + return nil, nil, 0, fmt.Errorf("CumulativePayments field not found in result") + } + cumulativePaymentsNum, ok := cumulativePaymentsAttr.(*types.AttributeValueMemberN) + if !ok { + return nil, nil, 0, fmt.Errorf("CumulativePayments has invalid type") } + setPrevPayment, success := prevPayment.SetString(cumulativePaymentsNum.Value, 10) + if !success { + return nil, nil, 0, fmt.Errorf("failed to parse previous payment: %w", err) + } + prevPayment = setPrevPayment } // Fetch the smallest entry larger than the given cumulativePayment @@ -222,18 +212,36 @@ func (s *OffchainStore) GetRelevantOnDemandRecords(ctx context.Context, accountI } largerResult, err := s.dynamoClient.QueryWithInput(ctx, queryInput) if err != nil { - return 0, 0, 0, fmt.Errorf("failed to query the next payment for account: %w", err) + return nil, nil, 0, fmt.Errorf("failed to query the next payment for account: %w", err) } - var nextPayment uint64 - var nextDataLength uint32 + nextPayment := big.NewInt(0) + nextDataLength := uint32(0) if len(largerResult) > 0 { - nextPayment, err = strconv.ParseUint(largerResult[0]["CumulativePayments"].(*types.AttributeValueMemberN).Value, 10, 64) - if err != nil { - return 0, 0, 0, fmt.Errorf("failed to parse next payment: %w", err) + cumulativePaymentsAttr, ok := largerResult[0]["CumulativePayments"] + if !ok { + return nil, nil, 0, fmt.Errorf("CumulativePayments field not found in result") + } + cumulativePaymentsNum, ok := cumulativePaymentsAttr.(*types.AttributeValueMemberN) + if !ok { + return nil, nil, 0, fmt.Errorf("CumulativePayments has invalid type") + } + setNextPayment, success := nextPayment.SetString(cumulativePaymentsNum.Value, 10) + if !success { + return nil, nil, 0, fmt.Errorf("failed to parse previous payment: %w", err) } - dataLength, err := strconv.ParseUint(largerResult[0]["DataLength"].(*types.AttributeValueMemberN).Value, 10, 32) + nextPayment = setNextPayment + + dataLengthAttr, ok := largerResult[0]["DataLength"] + if !ok { + return nil, nil, 0, fmt.Errorf("DataLength field not found in result") + } + dataLengthNum, ok := dataLengthAttr.(*types.AttributeValueMemberN) + if !ok { + return nil, nil, 0, fmt.Errorf("DataLength has invalid type") + } + dataLength, err := strconv.ParseUint(dataLengthNum.Value, 10, 32) if err != nil { - return 0, 0, 0, fmt.Errorf("failed to parse blob size: %w", err) + return nil, nil, 0, fmt.Errorf("failed to parse data length: %w", err) } nextDataLength = uint32(dataLength) } @@ -290,12 +298,13 @@ func (s *OffchainStore) GetLargestCumulativePayment(ctx context.Context, account return nil, nil } - payment, err := strconv.ParseUint(payments[0]["CumulativePayments"].(*types.AttributeValueMemberN).Value, 10, 64) - if err != nil { + var payment *big.Int + _, success := payment.SetString(payments[0]["CumulativePayments"].(*types.AttributeValueMemberN).Value, 10) + if !success { return nil, fmt.Errorf("failed to parse payment: %w", err) } - return new(big.Int).SetUint64(payment), nil + return payment, nil } func parseBinRecord(bin map[string]types.AttributeValue) (*pb.BinRecord, error) { From 1f84a216fc71b821b51a4e9bd54b63b97feedd12 Mon Sep 17 00:00:00 2001 From: Ian Shim <100327837+ian-shim@users.noreply.github.com> Date: Thu, 12 Dec 2024 09:21:33 -0800 Subject: [PATCH 3/8] [v2] Update blob header hasher (#962) --- core/chainio.go | 4 +- core/eth/reader.go | 8 +- core/mock/writer.go | 6 +- core/v2/serialization.go | 205 ++++++++++++++-------------- core/v2/serialization_test.go | 11 +- core/v2/types.go | 2 +- disperser/dataapi/server_v2_test.go | 2 +- relay/relay_test_utils.go | 2 +- 8 files changed, 124 insertions(+), 116 deletions(-) diff --git a/core/chainio.go b/core/chainio.go index 121aacdeeb..6ced6ed4c9 100644 --- a/core/chainio.go +++ b/core/chainio.go @@ -107,10 +107,10 @@ type Reader interface { GetNumBlobVersions(ctx context.Context) (uint16, error) // GetVersionedBlobParams returns the blob version parameters for the given block number and blob version. - GetVersionedBlobParams(ctx context.Context, blobVersion uint8) (*BlobVersionParameters, error) + GetVersionedBlobParams(ctx context.Context, blobVersion uint16) (*BlobVersionParameters, error) // GetAllVersionedBlobParams returns the blob version parameters for all blob versions at the given block number. - GetAllVersionedBlobParams(ctx context.Context) (map[uint8]*BlobVersionParameters, error) + GetAllVersionedBlobParams(ctx context.Context) (map[uint16]*BlobVersionParameters, error) // GetActiveReservations returns active reservations (end timestamp > current timestamp) GetActiveReservations(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*ActiveReservation, error) diff --git a/core/eth/reader.go b/core/eth/reader.go index af13ae4410..b2e59dfb4a 100644 --- a/core/eth/reader.go +++ b/core/eth/reader.go @@ -647,7 +647,7 @@ func (t *Reader) GetNumBlobVersions(ctx context.Context) (uint16, error) { }) } -func (t *Reader) GetVersionedBlobParams(ctx context.Context, blobVersion uint8) (*core.BlobVersionParameters, error) { +func (t *Reader) GetVersionedBlobParams(ctx context.Context, blobVersion uint16) (*core.BlobVersionParameters, error) { params, err := t.bindings.EigenDAServiceManager.GetBlobParams(&bind.CallOpts{ Context: ctx, }, uint16(blobVersion)) @@ -661,7 +661,7 @@ func (t *Reader) GetVersionedBlobParams(ctx context.Context, blobVersion uint8) }, nil } -func (t *Reader) GetAllVersionedBlobParams(ctx context.Context) (map[uint8]*core.BlobVersionParameters, error) { +func (t *Reader) GetAllVersionedBlobParams(ctx context.Context) (map[uint16]*core.BlobVersionParameters, error) { if t.bindings.ThresholdRegistry == nil { return nil, errors.New("threshold registry not deployed") } @@ -671,8 +671,8 @@ func (t *Reader) GetAllVersionedBlobParams(ctx context.Context) (map[uint8]*core return nil, err } - res := make(map[uint8]*core.BlobVersionParameters) - for version := uint8(0); version < uint8(numBlobVersions); version++ { + res := make(map[uint16]*core.BlobVersionParameters) + for version := uint16(0); version < uint16(numBlobVersions); version++ { params, err := t.GetVersionedBlobParams(ctx, version) if err != nil && strings.Contains(err.Error(), "execution reverted") { break diff --git a/core/mock/writer.go b/core/mock/writer.go index a6939554ed..87384401bf 100644 --- a/core/mock/writer.go +++ b/core/mock/writer.go @@ -203,7 +203,7 @@ func (t *MockWriter) GetNumBlobVersions(ctx context.Context) (uint16, error) { return result.(uint16), args.Error(1) } -func (t *MockWriter) GetVersionedBlobParams(ctx context.Context, blobVersion uint8) (*core.BlobVersionParameters, error) { +func (t *MockWriter) GetVersionedBlobParams(ctx context.Context, blobVersion uint16) (*core.BlobVersionParameters, error) { args := t.Called() if args.Get(0) == nil { return nil, args.Error(1) @@ -212,13 +212,13 @@ func (t *MockWriter) GetVersionedBlobParams(ctx context.Context, blobVersion uin return result.(*core.BlobVersionParameters), args.Error(1) } -func (t *MockWriter) GetAllVersionedBlobParams(ctx context.Context) (map[uint8]*core.BlobVersionParameters, error) { +func (t *MockWriter) GetAllVersionedBlobParams(ctx context.Context) (map[uint16]*core.BlobVersionParameters, error) { args := t.Called() result := args.Get(0) if result == nil { return nil, args.Error(1) } - return result.(map[uint8]*core.BlobVersionParameters), args.Error(1) + return result.(map[uint16]*core.BlobVersionParameters), args.Error(1) } func (t *MockWriter) PubkeyHashToOperator(ctx context.Context, operatorId core.OperatorID) (gethcommon.Address, error) { diff --git a/core/v2/serialization.go b/core/v2/serialization.go index 5acff91152..fa36850b1a 100644 --- a/core/v2/serialization.go +++ b/core/v2/serialization.go @@ -24,92 +24,85 @@ type abiBlobCommitments struct { Commitment abiG1Commit LengthCommitment abiG2Commit LengthProof abiG2Commit - Length uint32 -} -type abiBlobHeader struct { - BlobVersion uint8 - BlobCommitments abiBlobCommitments - QuorumNumbers []byte - PaymentMetadataHash [32]byte + DataLength uint32 } -func blobHeaderArgMarshaling() []abi.ArgumentMarshaling { - return []abi.ArgumentMarshaling{ +func (b *BlobHeader) BlobKey() (BlobKey, error) { + versionType, err := abi.NewType("uint16", "", nil) + if err != nil { + return [32]byte{}, err + } + quorumNumbersType, err := abi.NewType("bytes", "", nil) + if err != nil { + return [32]byte{}, err + } + commitmentType, err := abi.NewType("tuple", "", []abi.ArgumentMarshaling{ { - Name: "blobVersion", - Type: "uint8", + Name: "commitment", + Type: "tuple", + Components: []abi.ArgumentMarshaling{ + { + Name: "X", + Type: "uint256", + }, + { + Name: "Y", + Type: "uint256", + }, + }, }, { - Name: "blobCommitments", + Name: "lengthCommitment", Type: "tuple", Components: []abi.ArgumentMarshaling{ { - Name: "commitment", - Type: "tuple", - Components: []abi.ArgumentMarshaling{ - { - Name: "X", - Type: "uint256", - }, - { - Name: "Y", - Type: "uint256", - }, - }, + Name: "X", + Type: "uint256[2]", }, { - Name: "lengthCommitment", - Type: "tuple", - Components: []abi.ArgumentMarshaling{ - { - Name: "X", - Type: "uint256[2]", - }, - { - Name: "Y", - Type: "uint256[2]", - }, - }, + Name: "Y", + Type: "uint256[2]", }, + }, + }, + { + Name: "lengthProof", + Type: "tuple", + Components: []abi.ArgumentMarshaling{ { - Name: "lengthProof", - Type: "tuple", - Components: []abi.ArgumentMarshaling{ - { - Name: "X", - Type: "uint256[2]", - }, - { - Name: "Y", - Type: "uint256[2]", - }, - }, + Name: "X", + Type: "uint256[2]", }, { - Name: "length", - Type: "uint32", + Name: "Y", + Type: "uint256[2]", }, }, }, { - Name: "quorumNumbers", - Type: "bytes", + Name: "dataLength", + Type: "uint32", + }, + }) + if err != nil { + return [32]byte{}, err + } + arguments := abi.Arguments{ + { + Type: versionType, }, { - Name: "paymentMetadataHash", - Type: "bytes32", + Type: quorumNumbersType, + }, + { + Type: commitmentType, }, } -} -func (b *BlobHeader) toABIStruct() (abiBlobHeader, error) { - paymentHash, err := b.PaymentMetadata.Hash() - if err != nil { - return abiBlobHeader{}, err - } - return abiBlobHeader{ - BlobVersion: uint8(b.BlobVersion), - BlobCommitments: abiBlobCommitments{ + packedBytes, err := arguments.Pack( + b.BlobVersion, + b.QuorumNumbers, + abiBlobCommitments{ Commitment: abiG1Commit{ X: b.BlobCommitments.Commitment.X.BigInt(new(big.Int)), Y: b.BlobCommitments.Commitment.Y.BigInt(new(big.Int)), @@ -134,41 +127,62 @@ func (b *BlobHeader) toABIStruct() (abiBlobHeader, error) { b.BlobCommitments.LengthProof.Y.A1.BigInt(new(big.Int)), }, }, - Length: uint32(b.BlobCommitments.Length), + DataLength: uint32(b.BlobCommitments.Length), }, - QuorumNumbers: b.QuorumNumbers, - PaymentMetadataHash: paymentHash, - }, nil -} + ) + if err != nil { + return [32]byte{}, err + } -func (b *BlobHeader) BlobKey() (BlobKey, error) { - blobHeaderType, err := abi.NewType("tuple", "", blobHeaderArgMarshaling()) + var headerHash [32]byte + hasher := sha3.NewLegacyKeccak256() + hasher.Write(packedBytes) + copy(headerHash[:], hasher.Sum(nil)[:32]) + + blobKeyType, err := abi.NewType("tuple", "", []abi.ArgumentMarshaling{ + { + Name: "blobHeaderHash", + Type: "bytes32", + }, + { + Name: "paymentMetadataHash", + Type: "bytes32", + }, + }) if err != nil { return [32]byte{}, err } - arguments := abi.Arguments{ + arguments = abi.Arguments{ { - Type: blobHeaderType, + Type: blobKeyType, }, } - s, err := b.toABIStruct() + paymentMetadataHash, err := b.PaymentMetadata.Hash() if err != nil { return [32]byte{}, err } - bytes, err := arguments.Pack(s) + s2 := struct { + BlobHeaderHash [32]byte + PaymentMetadataHash [32]byte + }{ + BlobHeaderHash: headerHash, + PaymentMetadataHash: paymentMetadataHash, + } + + packedBytes, err = arguments.Pack(s2) if err != nil { return [32]byte{}, err } - var headerHash [32]byte - hasher := sha3.NewLegacyKeccak256() - hasher.Write(bytes) - copy(headerHash[:], hasher.Sum(nil)[:32]) + var blobKey [32]byte + hasher = sha3.NewLegacyKeccak256() + hasher.Write(packedBytes) + copy(blobKey[:], hasher.Sum(nil)[:32]) - return headerHash, nil + return blobKey, nil } func (c *BlobCertificate) Hash() ([32]byte, error) { @@ -176,40 +190,31 @@ func (c *BlobCertificate) Hash() ([32]byte, error) { return [32]byte{}, fmt.Errorf("blob header is nil") } - blobCertType, err := abi.NewType("tuple", "", []abi.ArgumentMarshaling{ - { - Name: "blobHeader", - Type: "tuple", - Components: blobHeaderArgMarshaling(), - }, - { - Name: "relayKeys", - Type: "uint32[]", - }, - }) + blobKeyType, err := abi.NewType("bytes32", "", nil) + if err != nil { + return [32]byte{}, err + } + + relayKeysType, err := abi.NewType("uint32[]", "", nil) if err != nil { return [32]byte{}, err } arguments := abi.Arguments{ { - Type: blobCertType, + Type: blobKeyType, + }, + { + Type: relayKeysType, }, } - bh, err := c.BlobHeader.toABIStruct() + blobKey, err := c.BlobHeader.BlobKey() if err != nil { return [32]byte{}, err } - s := struct { - BlobHeader abiBlobHeader - RelayKeys []RelayKey - }{ - BlobHeader: bh, - RelayKeys: c.RelayKeys, - } - bytes, err := arguments.Pack(s) + bytes, err := arguments.Pack(blobKey, c.RelayKeys) if err != nil { return [32]byte{}, err } diff --git a/core/v2/serialization_test.go b/core/v2/serialization_test.go index 78e97d1904..4f77438ade 100644 --- a/core/v2/serialization_test.go +++ b/core/v2/serialization_test.go @@ -48,13 +48,14 @@ func TestBlobKeyFromHeader(t *testing.T) { AccountID: "0x123", ReservationPeriod: 5, CumulativePayment: big.NewInt(100), + Salt: 42, }, Signature: []byte{1, 2, 3}, } blobKey, err := bh.BlobKey() assert.NoError(t, err) - // 0x1354b29d9dd9a332959795d17f456c219566417fdbf1a7b4f5d118f5c2a36bbd verified in solidity - assert.Equal(t, "1354b29d9dd9a332959795d17f456c219566417fdbf1a7b4f5d118f5c2a36bbd", blobKey.Hex()) + // 0x22c9e31c3d79c7c4085b564113f488019cbae18198c9a4fc4ecd70a5742e8638 verified in solidity + assert.Equal(t, "22c9e31c3d79c7c4085b564113f488019cbae18198c9a4fc4ecd70a5742e8638", blobKey.Hex()) } func TestBatchHeaderHash(t *testing.T) { @@ -102,6 +103,7 @@ func TestBlobCertHash(t *testing.T) { AccountID: "0x123", ReservationPeriod: 5, CumulativePayment: big.NewInt(100), + Salt: 42, }, Signature: []byte{1, 2, 3}, }, @@ -110,8 +112,8 @@ func TestBlobCertHash(t *testing.T) { hash, err := blobCert.Hash() assert.NoError(t, err) - // 0xad938e477d0bc1f9f4e8de7c5cd837560bdbb2dc7094207a7ad53e7442611a43 verified in solidity - assert.Equal(t, "ad938e477d0bc1f9f4e8de7c5cd837560bdbb2dc7094207a7ad53e7442611a43", hex.EncodeToString(hash[:])) + // 0x182087a394c8aab23e8da107c820679333c1efee66fd4380ba283c0e4c09efd6 verified in solidity + assert.Equal(t, "182087a394c8aab23e8da107c820679333c1efee66fd4380ba283c0e4c09efd6", hex.EncodeToString(hash[:])) } func TestBlobCertSerialization(t *testing.T) { @@ -130,6 +132,7 @@ func TestBlobCertSerialization(t *testing.T) { AccountID: "0x123", ReservationPeriod: 5, CumulativePayment: big.NewInt(100), + Salt: 42, }, Signature: []byte{1, 2, 3}, }, diff --git a/core/v2/types.go b/core/v2/types.go index 90cb5524a4..f1c129a8df 100644 --- a/core/v2/types.go +++ b/core/v2/types.go @@ -14,7 +14,7 @@ import ( gethcommon "github.com/ethereum/go-ethereum/common" ) -type BlobVersion = uint8 +type BlobVersion = uint16 // Assignment contains information about the set of chunks that a specific node will receive type Assignment struct { diff --git a/disperser/dataapi/server_v2_test.go b/disperser/dataapi/server_v2_test.go index 1e1da41146..2fe042058b 100644 --- a/disperser/dataapi/server_v2_test.go +++ b/disperser/dataapi/server_v2_test.go @@ -190,7 +190,7 @@ func TestFetchBlobHandlerV2(t *testing.T) { assert.Equal(t, http.StatusOK, res.StatusCode) assert.Equal(t, "Queued", response.Status) - assert.Equal(t, uint8(0), response.BlobHeader.BlobVersion) + assert.Equal(t, uint16(0), response.BlobHeader.BlobVersion) assert.Equal(t, blobHeader.Signature, response.BlobHeader.Signature) assert.Equal(t, blobHeader.PaymentMetadata.AccountID, response.BlobHeader.PaymentMetadata.AccountID) assert.Equal(t, blobHeader.PaymentMetadata.ReservationPeriod, response.BlobHeader.PaymentMetadata.ReservationPeriod) diff --git a/relay/relay_test_utils.go b/relay/relay_test_utils.go index 375e823b6a..efaca396c8 100644 --- a/relay/relay_test_utils.go +++ b/relay/relay_test_utils.go @@ -183,7 +183,7 @@ func newMockChainReader() *coremock.MockWriter { return w } -func mockBlobParamsMap() map[uint8]*core.BlobVersionParameters { +func mockBlobParamsMap() map[v2.BlobVersion]*core.BlobVersionParameters { blobParams := &core.BlobVersionParameters{ NumChunks: 8192, CodingRate: 8, From 67658ace52ed434c4d505022faf44a327e22136b Mon Sep 17 00:00:00 2001 From: leopardracer <136604165+leopardracer@users.noreply.github.com> Date: Thu, 12 Dec 2024 20:42:52 +0200 Subject: [PATCH 4/8] fix: typos in documentation files (#989) --- api/docs/disperser.md | 2 +- api/proto/disperser/disperser.proto | 2 +- disperser/dataapi/server_test.go | 2 +- test/synthetic-test/synthetic_client_test.go | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api/docs/disperser.md b/api/docs/disperser.md index 39f01e439a..7a7975e80a 100644 --- a/api/docs/disperser.md +++ b/api/docs/disperser.md @@ -369,7 +369,7 @@ Disperser defines the public APIs for dispersing blobs. | DisperseBlob | [DisperseBlobRequest](#disperser-DisperseBlobRequest) | [DisperseBlobReply](#disperser-DisperseBlobReply) | DisperseBlob accepts a single blob to be dispersed. This executes the dispersal async, i.e. it returns once the request is accepted. The client should use GetBlobStatus() API to poll the processing status of the blob. If DisperseBlob returns the following error codes: INVALID_ARGUMENT (400): request is invalid for a reason specified in the error msg. RESOURCE_EXHAUSTED (429): request is rate limited for the quorum specified in the error msg. user should retry after the specified duration. INTERNAL (500): serious error, user should NOT retry. | -| DisperseBlobAuthenticated | [AuthenticatedRequest](#disperser-AuthenticatedRequest) stream | [AuthenticatedReply](#disperser-AuthenticatedReply) stream | DisperseBlobAuthenticated is similar to DisperseBlob, except that it requires the client to authenticate itself via the AuthenticationData message. The protoco is as follows: 1. The client sends a DisperseBlobAuthenticated request with the DisperseBlobRequest message 2. The Disperser sends back a BlobAuthHeader message containing information for the client to verify and sign. 3. The client verifies the BlobAuthHeader and sends back the signed BlobAuthHeader in an AuthenticationData message. 4. The Disperser verifies the signature and returns a DisperseBlobReply message. | +| DisperseBlobAuthenticated | [AuthenticatedRequest](#disperser-AuthenticatedRequest) stream | [AuthenticatedReply](#disperser-AuthenticatedReply) stream | DisperseBlobAuthenticated is similar to DisperseBlob, except that it requires the client to authenticate itself via the AuthenticationData message. The protocol is as follows: 1. The client sends a DisperseBlobAuthenticated request with the DisperseBlobRequest message 2. The Disperser sends back a BlobAuthHeader message containing information for the client to verify and sign. 3. The client verifies the BlobAuthHeader and sends back the signed BlobAuthHeader in an AuthenticationData message. 4. The Disperser verifies the signature and returns a DisperseBlobReply message. | | GetBlobStatus | [BlobStatusRequest](#disperser-BlobStatusRequest) | [BlobStatusReply](#disperser-BlobStatusReply) | This API is meant to be polled for the blob status. | | RetrieveBlob | [RetrieveBlobRequest](#disperser-RetrieveBlobRequest) | [RetrieveBlobReply](#disperser-RetrieveBlobReply) | This retrieves the requested blob from the Disperser's backend. This is a more efficient way to retrieve blobs than directly retrieving from the DA Nodes (see detail about this approach in api/proto/retriever/retriever.proto). The blob should have been initially dispersed via this Disperser service for this API to work. | diff --git a/api/proto/disperser/disperser.proto b/api/proto/disperser/disperser.proto index ee537ee156..d4ca1faee4 100644 --- a/api/proto/disperser/disperser.proto +++ b/api/proto/disperser/disperser.proto @@ -18,7 +18,7 @@ service Disperser { rpc DisperseBlob(DisperseBlobRequest) returns (DisperseBlobReply) {} // DisperseBlobAuthenticated is similar to DisperseBlob, except that it requires the - // client to authenticate itself via the AuthenticationData message. The protoco is as follows: + // client to authenticate itself via the AuthenticationData message. The protocol is as follows: // 1. The client sends a DisperseBlobAuthenticated request with the DisperseBlobRequest message // 2. The Disperser sends back a BlobAuthHeader message containing information for the client to // verify and sign. diff --git a/disperser/dataapi/server_test.go b/disperser/dataapi/server_test.go index 05c23efc3f..50c5723954 100644 --- a/disperser/dataapi/server_test.go +++ b/disperser/dataapi/server_test.go @@ -1240,7 +1240,7 @@ func TestFetchDeregisteredOperatorOnline(t *testing.T) { } func TestFetchDeregisteredOperatorsMultipleOfflineOnline(t *testing.T) { - // Skipping this test as repported being flaky but could not reproduce it locally + // Skipping this test as reported being flaky but could not reproduce it locally t.Skip("Skipping testing in CI environment") r := setUpRouter() diff --git a/test/synthetic-test/synthetic_client_test.go b/test/synthetic-test/synthetic_client_test.go index 3854a214e9..4f3d1a8e94 100644 --- a/test/synthetic-test/synthetic_client_test.go +++ b/test/synthetic-test/synthetic_client_test.go @@ -297,7 +297,7 @@ func TestDisperseBlobEndToEnd(t *testing.T) { // For now log....later we can define a baseline value for this logger.Printf("Time to Disperse Blob %s", disperseBlobStopTime.String()) - // Set Confirmation DeaLine For Confirmation of Dispersed Blob + // Set Confirmation Deadline For Confirmation of Dispersed Blob // Update this to a minute over Batcher_Pull_Interval confirmationDeadline, err := time.ParseDuration(testSuite.BatcherPullInterval) From 289747788a1205717c30078888d1e956dbb5edee Mon Sep 17 00:00:00 2001 From: Ian Shim <100327837+ian-shim@users.noreply.github.com> Date: Thu, 12 Dec 2024 11:01:57 -0800 Subject: [PATCH 5/8] Update proto docs (#993) --- api/docs/disperser.html | 2 +- api/docs/eigenda-protos.html | 2 +- api/docs/eigenda-protos.md | 2 +- api/grpc/disperser/disperser_grpc.pb.go | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/api/docs/disperser.html b/api/docs/disperser.html index fba274aa78..4ea56b620d 100644 --- a/api/docs/disperser.html +++ b/api/docs/disperser.html @@ -1009,7 +1009,7 @@

Disperser

AuthenticatedRequest stream AuthenticatedReply stream

DisperseBlobAuthenticated is similar to DisperseBlob, except that it requires the -client to authenticate itself via the AuthenticationData message. The protoco is as follows: +client to authenticate itself via the AuthenticationData message. The protocol is as follows: 1. The client sends a DisperseBlobAuthenticated request with the DisperseBlobRequest message 2. The Disperser sends back a BlobAuthHeader message containing information for the client to verify and sign. diff --git a/api/docs/eigenda-protos.html b/api/docs/eigenda-protos.html index 987f3250d2..4e1a8cb48d 100644 --- a/api/docs/eigenda-protos.html +++ b/api/docs/eigenda-protos.html @@ -1891,7 +1891,7 @@

Disperser

AuthenticatedRequest stream AuthenticatedReply stream

DisperseBlobAuthenticated is similar to DisperseBlob, except that it requires the -client to authenticate itself via the AuthenticationData message. The protoco is as follows: +client to authenticate itself via the AuthenticationData message. The protocol is as follows: 1. The client sends a DisperseBlobAuthenticated request with the DisperseBlobRequest message 2. The Disperser sends back a BlobAuthHeader message containing information for the client to verify and sign. diff --git a/api/docs/eigenda-protos.md b/api/docs/eigenda-protos.md index 732b7dc77c..4baff482a1 100644 --- a/api/docs/eigenda-protos.md +++ b/api/docs/eigenda-protos.md @@ -721,7 +721,7 @@ Disperser defines the public APIs for dispersing blobs. | DisperseBlob | [DisperseBlobRequest](#disperser-DisperseBlobRequest) | [DisperseBlobReply](#disperser-DisperseBlobReply) | DisperseBlob accepts a single blob to be dispersed. This executes the dispersal async, i.e. it returns once the request is accepted. The client should use GetBlobStatus() API to poll the processing status of the blob. If DisperseBlob returns the following error codes: INVALID_ARGUMENT (400): request is invalid for a reason specified in the error msg. RESOURCE_EXHAUSTED (429): request is rate limited for the quorum specified in the error msg. user should retry after the specified duration. INTERNAL (500): serious error, user should NOT retry. | -| DisperseBlobAuthenticated | [AuthenticatedRequest](#disperser-AuthenticatedRequest) stream | [AuthenticatedReply](#disperser-AuthenticatedReply) stream | DisperseBlobAuthenticated is similar to DisperseBlob, except that it requires the client to authenticate itself via the AuthenticationData message. The protoco is as follows: 1. The client sends a DisperseBlobAuthenticated request with the DisperseBlobRequest message 2. The Disperser sends back a BlobAuthHeader message containing information for the client to verify and sign. 3. The client verifies the BlobAuthHeader and sends back the signed BlobAuthHeader in an AuthenticationData message. 4. The Disperser verifies the signature and returns a DisperseBlobReply message. | +| DisperseBlobAuthenticated | [AuthenticatedRequest](#disperser-AuthenticatedRequest) stream | [AuthenticatedReply](#disperser-AuthenticatedReply) stream | DisperseBlobAuthenticated is similar to DisperseBlob, except that it requires the client to authenticate itself via the AuthenticationData message. The protocol is as follows: 1. The client sends a DisperseBlobAuthenticated request with the DisperseBlobRequest message 2. The Disperser sends back a BlobAuthHeader message containing information for the client to verify and sign. 3. The client verifies the BlobAuthHeader and sends back the signed BlobAuthHeader in an AuthenticationData message. 4. The Disperser verifies the signature and returns a DisperseBlobReply message. | | GetBlobStatus | [BlobStatusRequest](#disperser-BlobStatusRequest) | [BlobStatusReply](#disperser-BlobStatusReply) | This API is meant to be polled for the blob status. | | RetrieveBlob | [RetrieveBlobRequest](#disperser-RetrieveBlobRequest) | [RetrieveBlobReply](#disperser-RetrieveBlobReply) | This retrieves the requested blob from the Disperser's backend. This is a more efficient way to retrieve blobs than directly retrieving from the DA Nodes (see detail about this approach in api/proto/retriever/retriever.proto). The blob should have been initially dispersed via this Disperser service for this API to work. | diff --git a/api/grpc/disperser/disperser_grpc.pb.go b/api/grpc/disperser/disperser_grpc.pb.go index c6bb719a2b..412d28c866 100644 --- a/api/grpc/disperser/disperser_grpc.pb.go +++ b/api/grpc/disperser/disperser_grpc.pb.go @@ -43,7 +43,7 @@ type DisperserClient interface { // INTERNAL (500): serious error, user should NOT retry. DisperseBlob(ctx context.Context, in *DisperseBlobRequest, opts ...grpc.CallOption) (*DisperseBlobReply, error) // DisperseBlobAuthenticated is similar to DisperseBlob, except that it requires the - // client to authenticate itself via the AuthenticationData message. The protoco is as follows: + // client to authenticate itself via the AuthenticationData message. The protocol is as follows: // 1. The client sends a DisperseBlobAuthenticated request with the DisperseBlobRequest message // 2. The Disperser sends back a BlobAuthHeader message containing information for the client to // verify and sign. @@ -146,7 +146,7 @@ type DisperserServer interface { // INTERNAL (500): serious error, user should NOT retry. DisperseBlob(context.Context, *DisperseBlobRequest) (*DisperseBlobReply, error) // DisperseBlobAuthenticated is similar to DisperseBlob, except that it requires the - // client to authenticate itself via the AuthenticationData message. The protoco is as follows: + // client to authenticate itself via the AuthenticationData message. The protocol is as follows: // 1. The client sends a DisperseBlobAuthenticated request with the DisperseBlobRequest message // 2. The Disperser sends back a BlobAuthHeader message containing information for the client to // verify and sign. From 6588b33d67495bfe177f4a554c6af475a93dbec8 Mon Sep 17 00:00:00 2001 From: hopeyen <60078528+hopeyen@users.noreply.github.com> Date: Fri, 13 Dec 2024 02:25:48 +0700 Subject: [PATCH 6/8] payments - global reservation uses interval config (#980) --- core/eth/reader.go | 4 ++-- core/meterer/meterer.go | 18 +++--------------- core/meterer/meterer_test.go | 2 +- core/meterer/offchain_store.go | 4 ++-- core/meterer/onchain_state.go | 6 +++--- core/mock/payment_state.go | 4 ++-- disperser/apiserver/server_v2_test.go | 1 + 7 files changed, 14 insertions(+), 25 deletions(-) diff --git a/core/eth/reader.go b/core/eth/reader.go index b2e59dfb4a..73c2789420 100644 --- a/core/eth/reader.go +++ b/core/eth/reader.go @@ -786,7 +786,7 @@ func (t *Reader) GetGlobalSymbolsPerSecond(ctx context.Context) (uint64, error) return globalSymbolsPerSecond.Uint64(), nil } -func (t *Reader) GetGlobalRateBinInterval(ctx context.Context) (uint64, error) { +func (t *Reader) GetGlobalRateBinInterval(ctx context.Context) (uint32, error) { if t.bindings.PaymentVault == nil { return 0, errors.New("payment vault not deployed") } @@ -796,7 +796,7 @@ func (t *Reader) GetGlobalRateBinInterval(ctx context.Context) (uint64, error) { if err != nil { return 0, err } - return globalRateBinInterval.Uint64(), nil + return uint32(globalRateBinInterval.Uint64()), nil } func (t *Reader) GetMinNumSymbols(ctx context.Context) (uint32, error) { diff --git a/core/meterer/meterer.go b/core/meterer/meterer.go index a27e7853e0..1f0e1c5aeb 100644 --- a/core/meterer/meterer.go +++ b/core/meterer/meterer.go @@ -259,23 +259,11 @@ func (m *Meterer) SymbolsCharged(numSymbols uint) uint32 { return uint32(core.RoundUpDivide(uint(numSymbols), uint(m.ChainPaymentState.GetMinNumSymbols()))) * m.ChainPaymentState.GetMinNumSymbols() } -// ValidateReservationPeriod checks if the provided bin index is valid -func (m *Meterer) ValidateGlobalReservationPeriod(header core.PaymentMetadata) (uint32, error) { - // Deterministic function: local clock -> index (1second intervals) - currentReservationPeriod := uint32(time.Now().Unix()) - - // Valid bin indexes are either the current bin or the previous bin (allow this second or prev sec) - if header.ReservationPeriod != currentReservationPeriod && header.ReservationPeriod != (currentReservationPeriod-1) { - return 0, fmt.Errorf("invalid bin index for on-demand request") - } - return currentReservationPeriod, nil -} - // IncrementBinUsage increments the bin usage atomically and checks for overflow func (m *Meterer) IncrementGlobalBinUsage(ctx context.Context, symbolsCharged uint64) error { - //TODO: edit globalIndex based on bin interval in a subsequent PR - globalIndex := uint64(time.Now().Unix()) - newUsage, err := m.OffchainStore.UpdateGlobalBin(ctx, globalIndex, symbolsCharged) + globalPeriod := GetReservationPeriod(uint64(time.Now().Unix()), m.ChainPaymentState.GetGlobalRateBinInterval()) + + newUsage, err := m.OffchainStore.UpdateGlobalBin(ctx, globalPeriod, symbolsCharged) if err != nil { return fmt.Errorf("failed to increment global bin usage: %w", err) } diff --git a/core/meterer/meterer_test.go b/core/meterer/meterer_test.go index 38132596b9..30a03c9737 100644 --- a/core/meterer/meterer_test.go +++ b/core/meterer/meterer_test.go @@ -171,7 +171,7 @@ func TestMetererReservations(t *testing.T) { ctx := context.Background() paymentChainState.On("GetReservationWindow", testifymock.Anything).Return(uint32(1), nil) paymentChainState.On("GetGlobalSymbolsPerSecond", testifymock.Anything).Return(uint64(1009), nil) - paymentChainState.On("GetGlobalRateBinInterval", testifymock.Anything).Return(uint64(1), nil) + paymentChainState.On("GetGlobalRateBinInterval", testifymock.Anything).Return(uint32(1), nil) paymentChainState.On("GetMinNumSymbols", testifymock.Anything).Return(uint32(3), nil) reservationPeriod := meterer.GetReservationPeriod(uint64(time.Now().Unix()), mt.ChainPaymentState.GetReservationWindow()) diff --git a/core/meterer/offchain_store.go b/core/meterer/offchain_store.go index 3c3116f1b7..6b213a495e 100644 --- a/core/meterer/offchain_store.go +++ b/core/meterer/offchain_store.go @@ -93,9 +93,9 @@ func (s *OffchainStore) UpdateReservationBin(ctx context.Context, accountID stri return binUsageValue, nil } -func (s *OffchainStore) UpdateGlobalBin(ctx context.Context, reservationPeriod uint64, size uint64) (uint64, error) { +func (s *OffchainStore) UpdateGlobalBin(ctx context.Context, reservationPeriod uint32, size uint64) (uint64, error) { key := map[string]types.AttributeValue{ - "ReservationPeriod": &types.AttributeValueMemberN{Value: strconv.FormatUint(reservationPeriod, 10)}, + "ReservationPeriod": &types.AttributeValueMemberN{Value: strconv.FormatUint(uint64(reservationPeriod), 10)}, } res, err := s.dynamoClient.IncrementBy(ctx, s.globalBinTableName, key, "BinUsage", size) diff --git a/core/meterer/onchain_state.go b/core/meterer/onchain_state.go index cdfaef457b..3a9ba34f3d 100644 --- a/core/meterer/onchain_state.go +++ b/core/meterer/onchain_state.go @@ -20,7 +20,7 @@ type OnchainPayment interface { GetOnDemandPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.OnDemandPayment, error) GetOnDemandQuorumNumbers(ctx context.Context) ([]uint8, error) GetGlobalSymbolsPerSecond() uint64 - GetGlobalRateBinInterval() uint64 + GetGlobalRateBinInterval() uint32 GetMinNumSymbols() uint32 GetPricePerSymbol() uint32 GetReservationWindow() uint32 @@ -42,7 +42,7 @@ type OnchainPaymentState struct { type PaymentVaultParams struct { GlobalSymbolsPerSecond uint64 - GlobalRateBinInterval uint64 + GlobalRateBinInterval uint32 MinNumSymbols uint32 PricePerSymbol uint32 ReservationWindow uint32 @@ -211,7 +211,7 @@ func (pcs *OnchainPaymentState) GetGlobalSymbolsPerSecond() uint64 { return pcs.PaymentVaultParams.Load().GlobalSymbolsPerSecond } -func (pcs *OnchainPaymentState) GetGlobalRateBinInterval() uint64 { +func (pcs *OnchainPaymentState) GetGlobalRateBinInterval() uint32 { return pcs.PaymentVaultParams.Load().GlobalRateBinInterval } diff --git a/core/mock/payment_state.go b/core/mock/payment_state.go index e4c89784d3..00c34b326e 100644 --- a/core/mock/payment_state.go +++ b/core/mock/payment_state.go @@ -62,9 +62,9 @@ func (m *MockOnchainPaymentState) GetGlobalSymbolsPerSecond() uint64 { return args.Get(0).(uint64) } -func (m *MockOnchainPaymentState) GetGlobalRateBinInterval() uint64 { +func (m *MockOnchainPaymentState) GetGlobalRateBinInterval() uint32 { args := m.Called() - return args.Get(0).(uint64) + return args.Get(0).(uint32) } func (m *MockOnchainPaymentState) GetMinNumSymbols() uint32 { diff --git a/disperser/apiserver/server_v2_test.go b/disperser/apiserver/server_v2_test.go index 4fa233d3d4..0bd8b5997e 100644 --- a/disperser/apiserver/server_v2_test.go +++ b/disperser/apiserver/server_v2_test.go @@ -443,6 +443,7 @@ func newTestServerV2(t *testing.T) *testComponents { mockState.On("GetReservationWindow", tmock.Anything).Return(uint32(1), nil) mockState.On("GetPricePerSymbol", tmock.Anything).Return(uint32(2), nil) mockState.On("GetGlobalSymbolsPerSecond", tmock.Anything).Return(uint64(1009), nil) + mockState.On("GetGlobalRateBinInterval", tmock.Anything).Return(uint32(1), nil) mockState.On("GetMinNumSymbols", tmock.Anything).Return(uint32(3), nil) now := uint64(time.Now().Unix()) From 06e88b311ef27618b4b878a78cbbb24358104ab9 Mon Sep 17 00:00:00 2001 From: quaq <56312047+0x0aa0@users.noreply.github.com> Date: Thu, 12 Dec 2024 15:40:59 -0600 Subject: [PATCH 7/8] update blob verifier (#992) --- contracts/src/core/EigenDABlobVerifier.sol | 2 +- contracts/src/interfaces/IEigenDAStructs.sol | 1 - .../EigenDABlobVerificationUtils.sol | 26 +++++++++---------- 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/contracts/src/core/EigenDABlobVerifier.sol b/contracts/src/core/EigenDABlobVerifier.sol index 34c93e9d6d..d89198e391 100644 --- a/contracts/src/core/EigenDABlobVerifier.sol +++ b/contracts/src/core/EigenDABlobVerifier.sol @@ -134,7 +134,7 @@ contract EigenDABlobVerifier is IEigenDABlobVerifier { return EigenDABlobVerificationUtils._getNonSignerStakesAndSignature( operatorStateRetriever, registryCoordinator, - signedBatch.attestation + signedBatch ); } diff --git a/contracts/src/interfaces/IEigenDAStructs.sol b/contracts/src/interfaces/IEigenDAStructs.sol index faa5eab441..5413929dba 100644 --- a/contracts/src/interfaces/IEigenDAStructs.sol +++ b/contracts/src/interfaces/IEigenDAStructs.sol @@ -103,7 +103,6 @@ struct Attestation { BN254.G1Point sigma; BN254.G2Point apkG2; uint32[] quorumNumbers; - uint32 referenceBlockNumber; } ///////////////////////// SIGNATURE VERIFIER /////////////////////////////// diff --git a/contracts/src/libraries/EigenDABlobVerificationUtils.sol b/contracts/src/libraries/EigenDABlobVerificationUtils.sol index 694c01d5d4..0a08de8844 100644 --- a/contracts/src/libraries/EigenDABlobVerificationUtils.sol +++ b/contracts/src/libraries/EigenDABlobVerificationUtils.sol @@ -307,7 +307,7 @@ library EigenDABlobVerificationUtils { NonSignerStakesAndSignature memory nonSignerStakesAndSignature = _getNonSignerStakesAndSignature( operatorStateRetriever, registryCoordinator, - signedBatch.attestation + signedBatch ); _verifyBlobV2ForQuorums( @@ -336,7 +336,7 @@ library EigenDABlobVerificationUtils { NonSignerStakesAndSignature memory nonSignerStakesAndSignature = _getNonSignerStakesAndSignature( operatorStateRetriever, registryCoordinator, - signedBatch.attestation + signedBatch ); _verifyBlobV2ForQuorumsForThresholds( @@ -354,30 +354,30 @@ library EigenDABlobVerificationUtils { function _getNonSignerStakesAndSignature( OperatorStateRetriever operatorStateRetriever, IRegistryCoordinator registryCoordinator, - Attestation memory attestation + SignedBatch memory signedBatch ) internal view returns (NonSignerStakesAndSignature memory nonSignerStakesAndSignature) { - bytes32[] memory nonSignerOperatorIds = new bytes32[](attestation.nonSignerPubkeys.length); - for (uint i = 0; i < attestation.nonSignerPubkeys.length; ++i) { - nonSignerOperatorIds[i] = BN254.hashG1Point(attestation.nonSignerPubkeys[i]); + bytes32[] memory nonSignerOperatorIds = new bytes32[](signedBatch.attestation.nonSignerPubkeys.length); + for (uint i = 0; i < signedBatch.attestation.nonSignerPubkeys.length; ++i) { + nonSignerOperatorIds[i] = BN254.hashG1Point(signedBatch.attestation.nonSignerPubkeys[i]); } bytes memory quorumNumbers; - for (uint i = 0; i < attestation.quorumNumbers.length; ++i) { - quorumNumbers = abi.encodePacked(quorumNumbers, uint8(attestation.quorumNumbers[i])); + for (uint i = 0; i < signedBatch.attestation.quorumNumbers.length; ++i) { + quorumNumbers = abi.encodePacked(quorumNumbers, uint8(signedBatch.attestation.quorumNumbers[i])); } OperatorStateRetriever.CheckSignaturesIndices memory checkSignaturesIndices = operatorStateRetriever.getCheckSignaturesIndices( registryCoordinator, - attestation.referenceBlockNumber, + signedBatch.batchHeader.referenceBlockNumber, quorumNumbers, nonSignerOperatorIds ); nonSignerStakesAndSignature.nonSignerQuorumBitmapIndices = checkSignaturesIndices.nonSignerQuorumBitmapIndices; - nonSignerStakesAndSignature.nonSignerPubkeys = attestation.nonSignerPubkeys; - nonSignerStakesAndSignature.quorumApks = attestation.quorumApks; - nonSignerStakesAndSignature.apkG2 = attestation.apkG2; - nonSignerStakesAndSignature.sigma = attestation.sigma; + nonSignerStakesAndSignature.nonSignerPubkeys = signedBatch.attestation.nonSignerPubkeys; + nonSignerStakesAndSignature.quorumApks = signedBatch.attestation.quorumApks; + nonSignerStakesAndSignature.apkG2 = signedBatch.attestation.apkG2; + nonSignerStakesAndSignature.sigma = signedBatch.attestation.sigma; nonSignerStakesAndSignature.quorumApkIndices = checkSignaturesIndices.quorumApkIndices; nonSignerStakesAndSignature.totalStakeIndices = checkSignaturesIndices.totalStakeIndices; nonSignerStakesAndSignature.nonSignerStakeIndices = checkSignaturesIndices.nonSignerStakeIndices; From 7b12ebf8ce6fbdb90f0c12b3a975973e916811ce Mon Sep 17 00:00:00 2001 From: hopeyen <60078528+hopeyen@users.noreply.github.com> Date: Fri, 13 Dec 2024 05:23:48 +0700 Subject: [PATCH 8/8] reservation timestamp check (#990) --- api/clients/accountant.go | 6 +-- api/clients/accountant_test.go | 18 +++---- api/docs/disperser_v2.html | 2 +- api/docs/disperser_v2.md | 2 +- api/docs/eigenda-protos.html | 2 +- api/docs/eigenda-protos.md | 2 +- api/grpc/disperser/v2/disperser_v2.pb.go | 2 +- api/proto/disperser/v2/disperser_v2.proto | 2 +- core/chainio.go | 8 +-- core/data.go | 29 +++++++---- core/data_test.go | 62 +++++++++++++++++++++++ core/eth/reader.go | 10 ++-- core/eth/utils.go | 6 +-- core/meterer/meterer.go | 23 +++++---- core/meterer/meterer_test.go | 37 ++++++++++---- core/meterer/onchain_state.go | 43 +++++----------- core/meterer/onchain_state_test.go | 10 ++-- core/mock/payment_state.go | 6 +-- core/mock/writer.go | 8 +-- disperser/apiserver/server_test.go | 2 +- disperser/apiserver/server_v2.go | 5 +- disperser/apiserver/server_v2_test.go | 2 +- test/integration_test.go | 6 +-- 23 files changed, 184 insertions(+), 109 deletions(-) diff --git a/api/clients/accountant.go b/api/clients/accountant.go index 045877bd7e..6b923d51b4 100644 --- a/api/clients/accountant.go +++ b/api/clients/accountant.go @@ -18,7 +18,7 @@ var requiredQuorums = []uint8{0, 1} type Accountant struct { // on-chain states accountID string - reservation *core.ActiveReservation + reservation *core.ReservedPayment onDemand *core.OnDemandPayment reservationWindow uint32 pricePerSymbol uint32 @@ -39,7 +39,7 @@ type BinRecord struct { Usage uint64 } -func NewAccountant(accountID string, reservation *core.ActiveReservation, onDemand *core.OnDemandPayment, reservationWindow uint32, pricePerSymbol uint32, minNumSymbols uint32, numBins uint32) *Accountant { +func NewAccountant(accountID string, reservation *core.ReservedPayment, onDemand *core.OnDemandPayment, reservationWindow uint32, pricePerSymbol uint32, minNumSymbols uint32, numBins uint32) *Accountant { //TODO: client storage; currently every instance starts fresh but on-chain or a small store makes more sense // Also client is currently responsible for supplying network params, we need to add RPC in order to be automatic // There's a subsequent PR that handles populating the accountant with on-chain state from the disperser @@ -65,7 +65,7 @@ func NewAccountant(accountID string, reservation *core.ActiveReservation, onDema // BlobPaymentInfo calculates and records payment information. The accountant // will attempt to use the active reservation first and check for quorum settings, // then on-demand if the reservation is not available. The returned values are -// bin index for reservation payments and cumulative payment for on-demand payments, +// reservation period for reservation payments and cumulative payment for on-demand payments, // and both fields are used to create the payment header and signature func (a *Accountant) BlobPaymentInfo(ctx context.Context, numSymbols uint64, quorumNumbers []uint8) (uint32, *big.Int, error) { now := time.Now().Unix() diff --git a/api/clients/accountant_test.go b/api/clients/accountant_test.go index c6dc3fa692..d28dd9f16b 100644 --- a/api/clients/accountant_test.go +++ b/api/clients/accountant_test.go @@ -18,7 +18,7 @@ const numBins = uint32(3) const salt = uint32(0) func TestNewAccountant(t *testing.T) { - reservation := &core.ActiveReservation{ + reservation := &core.ReservedPayment{ SymbolsPerSecond: 100, StartTimestamp: 100, EndTimestamp: 200, @@ -48,7 +48,7 @@ func TestNewAccountant(t *testing.T) { } func TestAccountBlob_Reservation(t *testing.T) { - reservation := &core.ActiveReservation{ + reservation := &core.ReservedPayment{ SymbolsPerSecond: 200, StartTimestamp: 100, EndTimestamp: 200, @@ -96,7 +96,7 @@ func TestAccountBlob_Reservation(t *testing.T) { } func TestAccountBlob_OnDemand(t *testing.T) { - reservation := &core.ActiveReservation{ + reservation := &core.ReservedPayment{ SymbolsPerSecond: 200, StartTimestamp: 100, EndTimestamp: 200, @@ -130,7 +130,7 @@ func TestAccountBlob_OnDemand(t *testing.T) { } func TestAccountBlob_InsufficientOnDemand(t *testing.T) { - reservation := &core.ActiveReservation{} + reservation := &core.ReservedPayment{} onDemand := &core.OnDemandPayment{ CumulativePayment: big.NewInt(500), } @@ -152,7 +152,7 @@ func TestAccountBlob_InsufficientOnDemand(t *testing.T) { } func TestAccountBlobCallSeries(t *testing.T) { - reservation := &core.ActiveReservation{ + reservation := &core.ReservedPayment{ SymbolsPerSecond: 200, StartTimestamp: 100, EndTimestamp: 200, @@ -200,7 +200,7 @@ func TestAccountBlobCallSeries(t *testing.T) { } func TestAccountBlob_BinRotation(t *testing.T) { - reservation := &core.ActiveReservation{ + reservation := &core.ReservedPayment{ SymbolsPerSecond: 1000, StartTimestamp: 100, EndTimestamp: 200, @@ -242,7 +242,7 @@ func TestAccountBlob_BinRotation(t *testing.T) { } func TestConcurrentBinRotationAndAccountBlob(t *testing.T) { - reservation := &core.ActiveReservation{ + reservation := &core.ReservedPayment{ SymbolsPerSecond: 1000, StartTimestamp: 100, EndTimestamp: 200, @@ -284,7 +284,7 @@ func TestConcurrentBinRotationAndAccountBlob(t *testing.T) { } func TestAccountBlob_ReservationWithOneOverflow(t *testing.T) { - reservation := &core.ActiveReservation{ + reservation := &core.ReservedPayment{ SymbolsPerSecond: 200, StartTimestamp: 100, EndTimestamp: 200, @@ -332,7 +332,7 @@ func TestAccountBlob_ReservationWithOneOverflow(t *testing.T) { } func TestAccountBlob_ReservationOverflowReset(t *testing.T) { - reservation := &core.ActiveReservation{ + reservation := &core.ReservedPayment{ SymbolsPerSecond: 1000, StartTimestamp: 100, EndTimestamp: 200, diff --git a/api/docs/disperser_v2.html b/api/docs/disperser_v2.html index 2435d12a47..03a4c5533b 100644 --- a/api/docs/disperser_v2.html +++ b/api/docs/disperser_v2.html @@ -632,7 +632,7 @@

GetPaymentStateRequest

bytes

Signature over the account ID -TODO: sign over a bin index or a nonce to mitigate signature replay attacks

+TODO: sign over a reservation period or a nonce to mitigate signature replay attacks

diff --git a/api/docs/disperser_v2.md b/api/docs/disperser_v2.md index 4a52f3ccb8..bd1bc66acc 100644 --- a/api/docs/disperser_v2.md +++ b/api/docs/disperser_v2.md @@ -210,7 +210,7 @@ GetPaymentStateRequest contains parameters to query the payment state of an acco | Field | Type | Label | Description | | ----- | ---- | ----- | ----------- | | account_id | [string](#string) | | | -| signature | [bytes](#bytes) | | Signature over the account ID TODO: sign over a bin index or a nonce to mitigate signature replay attacks | +| signature | [bytes](#bytes) | | Signature over the account ID TODO: sign over a reservation period or a nonce to mitigate signature replay attacks | diff --git a/api/docs/eigenda-protos.html b/api/docs/eigenda-protos.html index 4e1a8cb48d..9a6b53af5b 100644 --- a/api/docs/eigenda-protos.html +++ b/api/docs/eigenda-protos.html @@ -2303,7 +2303,7 @@

GetPaymentStateRequest

bytes

Signature over the account ID -TODO: sign over a bin index or a nonce to mitigate signature replay attacks

+TODO: sign over a reservation period or a nonce to mitigate signature replay attacks

diff --git a/api/docs/eigenda-protos.md b/api/docs/eigenda-protos.md index 4baff482a1..7f90353af8 100644 --- a/api/docs/eigenda-protos.md +++ b/api/docs/eigenda-protos.md @@ -912,7 +912,7 @@ GetPaymentStateRequest contains parameters to query the payment state of an acco | Field | Type | Label | Description | | ----- | ---- | ----- | ----------- | | account_id | [string](#string) | | | -| signature | [bytes](#bytes) | | Signature over the account ID TODO: sign over a bin index or a nonce to mitigate signature replay attacks | +| signature | [bytes](#bytes) | | Signature over the account ID TODO: sign over a reservation period or a nonce to mitigate signature replay attacks | diff --git a/api/grpc/disperser/v2/disperser_v2.pb.go b/api/grpc/disperser/v2/disperser_v2.pb.go index ffbfb3c178..e420648ee5 100644 --- a/api/grpc/disperser/v2/disperser_v2.pb.go +++ b/api/grpc/disperser/v2/disperser_v2.pb.go @@ -424,7 +424,7 @@ type GetPaymentStateRequest struct { AccountId string `protobuf:"bytes,1,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` // Signature over the account ID - // TODO: sign over a bin index or a nonce to mitigate signature replay attacks + // TODO: sign over a reservation period or a nonce to mitigate signature replay attacks Signature []byte `protobuf:"bytes,2,opt,name=signature,proto3" json:"signature,omitempty"` } diff --git a/api/proto/disperser/v2/disperser_v2.proto b/api/proto/disperser/v2/disperser_v2.proto index 3038fef7e8..fb6386c724 100644 --- a/api/proto/disperser/v2/disperser_v2.proto +++ b/api/proto/disperser/v2/disperser_v2.proto @@ -70,7 +70,7 @@ message BlobCommitmentReply { message GetPaymentStateRequest { string account_id = 1; // Signature over the account ID - // TODO: sign over a bin index or a nonce to mitigate signature replay attacks + // TODO: sign over a reservation period or a nonce to mitigate signature replay attacks bytes signature = 2; } diff --git a/core/chainio.go b/core/chainio.go index 6ced6ed4c9..e28572832a 100644 --- a/core/chainio.go +++ b/core/chainio.go @@ -112,11 +112,11 @@ type Reader interface { // GetAllVersionedBlobParams returns the blob version parameters for all blob versions at the given block number. GetAllVersionedBlobParams(ctx context.Context) (map[uint16]*BlobVersionParameters, error) - // GetActiveReservations returns active reservations (end timestamp > current timestamp) - GetActiveReservations(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*ActiveReservation, error) + // GetReservedPayments returns active reservations (end timestamp > current timestamp) + GetReservedPayments(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*ReservedPayment, error) - // GetActiveReservationByAccount returns active reservation by account ID - GetActiveReservationByAccount(ctx context.Context, accountID gethcommon.Address) (*ActiveReservation, error) + // GetReservedPaymentByAccount returns active reservation by account ID + GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*ReservedPayment, error) // GetOnDemandPayments returns all on-demand payments GetOnDemandPayments(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*OnDemandPayment, error) diff --git a/core/data.go b/core/data.go index 551260ec24..367faad32d 100644 --- a/core/data.go +++ b/core/data.go @@ -604,20 +604,24 @@ func ConvertToPaymentMetadata(ph *commonpb.PaymentHeader) *PaymentMetadata { } } -// OperatorInfo contains information about an operator which is stored on the blockchain state, -// corresponding to a particular quorum -type ActiveReservation struct { - SymbolsPerSecond uint64 // reserve number of symbols per second - //TODO: we are not using start and end timestamp, add check or remove - StartTimestamp uint64 // Unix timestamp that's valid for basically eternity - EndTimestamp uint64 +// ReservedPayment contains information the onchain state about a reserved payment +type ReservedPayment struct { + // reserve number of symbols per second + SymbolsPerSecond uint64 + // reservation activation timestamp + StartTimestamp uint64 + // reservation expiration timestamp + EndTimestamp uint64 - QuorumNumbers []uint8 // allowed quorums - QuorumSplits []byte // ordered mapping of quorum number to payment split; on-chain validation should ensure split <= 100 + // allowed quorums + QuorumNumbers []uint8 + // ordered mapping of quorum number to payment split; on-chain validation should ensure split <= 100 + QuorumSplits []byte } type OnDemandPayment struct { - CumulativePayment *big.Int // Total amount deposited by the user + // Total amount deposited by the user + CumulativePayment *big.Int } type BlobVersionParameters struct { @@ -625,3 +629,8 @@ type BlobVersionParameters struct { MaxNumOperators uint32 NumChunks uint32 } + +// IsActive returns true if the reservation is active at the given timestamp +func (ar *ReservedPayment) IsActive(currentTimestamp uint64) bool { + return ar.StartTimestamp <= currentTimestamp && ar.EndTimestamp >= currentTimestamp +} diff --git a/core/data_test.go b/core/data_test.go index 84cb5097e9..61ffd37241 100644 --- a/core/data_test.go +++ b/core/data_test.go @@ -217,3 +217,65 @@ func TestChunksData(t *testing.T) { assert.EqualError(t, err, "unsupported chunk encoding format: 3") } } + +func TestReservedPayment_IsActive(t *testing.T) { + tests := []struct { + name string + reservedPayment core.ReservedPayment + currentTimestamp uint64 + wantActive bool + }{ + { + name: "active - current time in middle of range", + reservedPayment: core.ReservedPayment{ + StartTimestamp: 100, + EndTimestamp: 200, + }, + currentTimestamp: 150, + wantActive: true, + }, + { + name: "active - current time at start", + reservedPayment: core.ReservedPayment{ + StartTimestamp: 100, + EndTimestamp: 200, + }, + currentTimestamp: 100, + wantActive: true, + }, + { + name: "active - current time at end", + reservedPayment: core.ReservedPayment{ + StartTimestamp: 100, + EndTimestamp: 200, + }, + currentTimestamp: 200, + wantActive: true, + }, + { + name: "inactive - current time before start", + reservedPayment: core.ReservedPayment{ + StartTimestamp: 100, + EndTimestamp: 200, + }, + currentTimestamp: 99, + wantActive: false, + }, + { + name: "inactive - current time after end", + reservedPayment: core.ReservedPayment{ + StartTimestamp: 100, + EndTimestamp: 200, + }, + currentTimestamp: 201, + wantActive: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isActive := tt.reservedPayment.IsActive(tt.currentTimestamp) + assert.Equal(t, tt.wantActive, isActive) + }) + } +} diff --git a/core/eth/reader.go b/core/eth/reader.go index 73c2789420..6390f221f0 100644 --- a/core/eth/reader.go +++ b/core/eth/reader.go @@ -690,11 +690,11 @@ func (t *Reader) GetAllVersionedBlobParams(ctx context.Context) (map[uint16]*cor return res, nil } -func (t *Reader) GetActiveReservations(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*core.ActiveReservation, error) { +func (t *Reader) GetReservedPayments(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*core.ReservedPayment, error) { if t.bindings.PaymentVault == nil { return nil, errors.New("payment vault not deployed") } - reservationsMap := make(map[gethcommon.Address]*core.ActiveReservation) + reservationsMap := make(map[gethcommon.Address]*core.ReservedPayment) reservations, err := t.bindings.PaymentVault.GetReservations(&bind.CallOpts{ Context: ctx, }, accountIDs) @@ -704,7 +704,7 @@ func (t *Reader) GetActiveReservations(ctx context.Context, accountIDs []gethcom // since reservations are returned in the same order as the accountIDs, we can directly map them for i, reservation := range reservations { - res, err := ConvertToActiveReservation(reservation) + res, err := ConvertToReservedPayment(reservation) if err != nil { t.logger.Warn("failed to get active reservation", "account", accountIDs[i], "err", err) continue @@ -716,7 +716,7 @@ func (t *Reader) GetActiveReservations(ctx context.Context, accountIDs []gethcom return reservationsMap, nil } -func (t *Reader) GetActiveReservationByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ActiveReservation, error) { +func (t *Reader) GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ReservedPayment, error) { if t.bindings.PaymentVault == nil { return nil, errors.New("payment vault not deployed") } @@ -726,7 +726,7 @@ func (t *Reader) GetActiveReservationByAccount(ctx context.Context, accountID ge if err != nil { return nil, err } - return ConvertToActiveReservation(reservation) + return ConvertToReservedPayment(reservation) } func (t *Reader) GetOnDemandPayments(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*core.OnDemandPayment, error) { diff --git a/core/eth/utils.go b/core/eth/utils.go index d98b6def2a..7334f62f98 100644 --- a/core/eth/utils.go +++ b/core/eth/utils.go @@ -137,14 +137,14 @@ func isZeroValuedReservation(reservation paymentvault.IPaymentVaultReservation) len(reservation.QuorumSplits) == 0 } -// ConvertToActiveReservation converts a upstream binding data structure to local definition. +// ConvertToReservedPayment converts a upstream binding data structure to local definition. // Returns an error if the input reservation is zero-valued. -func ConvertToActiveReservation(reservation paymentvault.IPaymentVaultReservation) (*core.ActiveReservation, error) { +func ConvertToReservedPayment(reservation paymentvault.IPaymentVaultReservation) (*core.ReservedPayment, error) { if isZeroValuedReservation(reservation) { return nil, fmt.Errorf("reservation is not a valid active reservation") } - return &core.ActiveReservation{ + return &core.ReservedPayment{ SymbolsPerSecond: reservation.SymbolsPerSecond, StartTimestamp: reservation.StartTimestamp, EndTimestamp: reservation.EndTimestamp, diff --git a/core/meterer/meterer.go b/core/meterer/meterer.go index 1f0e1c5aeb..681c6d8c6c 100644 --- a/core/meterer/meterer.go +++ b/core/meterer/meterer.go @@ -76,7 +76,7 @@ func (m *Meterer) MeterRequest(ctx context.Context, header core.PaymentMetadata, accountID := gethcommon.HexToAddress(header.AccountID) // Validate against the payment method if header.CumulativePayment.Sign() == 0 { - reservation, err := m.ChainPaymentState.GetActiveReservationByAccount(ctx, accountID) + reservation, err := m.ChainPaymentState.GetReservedPaymentByAccount(ctx, accountID) if err != nil { return fmt.Errorf("failed to get active reservation by account: %w", err) } @@ -97,12 +97,15 @@ func (m *Meterer) MeterRequest(ctx context.Context, header core.PaymentMetadata, } // ServeReservationRequest handles the rate limiting logic for incoming requests -func (m *Meterer) ServeReservationRequest(ctx context.Context, header core.PaymentMetadata, reservation *core.ActiveReservation, numSymbols uint, quorumNumbers []uint8) error { +func (m *Meterer) ServeReservationRequest(ctx context.Context, header core.PaymentMetadata, reservation *core.ReservedPayment, numSymbols uint, quorumNumbers []uint8) error { + if !reservation.IsActive(uint64(time.Now().Unix())) { + return fmt.Errorf("reservation not active") + } if err := m.ValidateQuorum(quorumNumbers, reservation.QuorumNumbers); err != nil { return fmt.Errorf("invalid quorum for reservation: %w", err) } if !m.ValidateReservationPeriod(header, reservation) { - return fmt.Errorf("invalid bin index for reservation") + return fmt.Errorf("invalid reservation period for reservation") } // Update bin usage atomically and check against reservation's data rate as the bin limit @@ -122,7 +125,7 @@ func (m *Meterer) ValidateQuorum(headerQuorums []uint8, allowedQuorums []uint8) return fmt.Errorf("no quorum params in blob header") } - // check that all the quorum ids are in ActiveReservation's + // check that all the quorum ids are in ReservedPayment's for _, q := range headerQuorums { if !slices.Contains(allowedQuorums, q) { // fail the entire request if there's a quorum number mismatch @@ -132,12 +135,12 @@ func (m *Meterer) ValidateQuorum(headerQuorums []uint8, allowedQuorums []uint8) return nil } -// ValidateReservationPeriod checks if the provided bin index is valid -func (m *Meterer) ValidateReservationPeriod(header core.PaymentMetadata, reservation *core.ActiveReservation) bool { +// ValidateReservationPeriod checks if the provided reservation period is valid +func (m *Meterer) ValidateReservationPeriod(header core.PaymentMetadata, reservation *core.ReservedPayment) bool { now := uint64(time.Now().Unix()) reservationWindow := m.ChainPaymentState.GetReservationWindow() currentReservationPeriod := GetReservationPeriod(now, reservationWindow) - // Valid bin indexes are either the current bin or the previous bin + // Valid reservation periodes are either the current bin or the previous bin if (header.ReservationPeriod != currentReservationPeriod && header.ReservationPeriod != (currentReservationPeriod-1)) || (GetReservationPeriod(reservation.StartTimestamp, reservationWindow) > header.ReservationPeriod || header.ReservationPeriod > GetReservationPeriod(reservation.EndTimestamp, reservationWindow)) { return false } @@ -145,7 +148,7 @@ func (m *Meterer) ValidateReservationPeriod(header core.PaymentMetadata, reserva } // IncrementBinUsage increments the bin usage atomically and checks for overflow -func (m *Meterer) IncrementBinUsage(ctx context.Context, header core.PaymentMetadata, reservation *core.ActiveReservation, numSymbols uint) error { +func (m *Meterer) IncrementBinUsage(ctx context.Context, header core.PaymentMetadata, reservation *core.ReservedPayment, numSymbols uint) error { symbolsCharged := m.SymbolsCharged(numSymbols) newUsage, err := m.OffchainStore.UpdateReservationBin(ctx, header.AccountID, uint64(header.ReservationPeriod), uint64(symbolsCharged)) if err != nil { @@ -170,7 +173,7 @@ func (m *Meterer) IncrementBinUsage(ctx context.Context, header core.PaymentMeta return fmt.Errorf("overflow usage exceeds bin limit") } -// GetReservationPeriod returns the current bin index by chunking time by the bin interval; +// GetReservationPeriod returns the current reservation period by chunking time by the bin interval; // bin interval used by the disperser should be public information func GetReservationPeriod(timestamp uint64, binInterval uint32) uint32 { return uint32(timestamp) / binInterval @@ -274,6 +277,6 @@ func (m *Meterer) IncrementGlobalBinUsage(ctx context.Context, symbolsCharged ui } // GetReservationBinLimit returns the bin limit for a given reservation -func (m *Meterer) GetReservationBinLimit(reservation *core.ActiveReservation) uint64 { +func (m *Meterer) GetReservationBinLimit(reservation *core.ReservedPayment) uint64 { return reservation.SymbolsPerSecond * uint64(m.ChainPaymentState.GetReservationWindow()) } diff --git a/core/meterer/meterer_test.go b/core/meterer/meterer_test.go index 30a03c9737..9ab35df545 100644 --- a/core/meterer/meterer_test.go +++ b/core/meterer/meterer_test.go @@ -32,11 +32,13 @@ var ( dynamoClient commondynamodb.Client clientConfig commonaws.ClientConfig accountID1 gethcommon.Address - account1Reservations *core.ActiveReservation + account1Reservations *core.ReservedPayment account1OnDemandPayments *core.OnDemandPayment accountID2 gethcommon.Address - account2Reservations *core.ActiveReservation + account2Reservations *core.ReservedPayment account2OnDemandPayments *core.OnDemandPayment + accountID3 gethcommon.Address + account3Reservations *core.ReservedPayment mt *meterer.Meterer deployLocalStack bool @@ -100,6 +102,11 @@ func setup(_ *testing.M) { teardown() panic("failed to generate private key") } + privateKey3, err := crypto.GenerateKey() + if err != nil { + teardown() + panic("failed to generate private key") + } logger = logging.NewNoopLogger() config := meterer.Config{ @@ -126,8 +133,10 @@ func setup(_ *testing.M) { now := uint64(time.Now().Unix()) accountID1 = crypto.PubkeyToAddress(privateKey1.PublicKey) accountID2 = crypto.PubkeyToAddress(privateKey2.PublicKey) - account1Reservations = &core.ActiveReservation{SymbolsPerSecond: 100, StartTimestamp: now + 1200, EndTimestamp: now + 1800, QuorumSplits: []byte{50, 50}, QuorumNumbers: []uint8{0, 1}} - account2Reservations = &core.ActiveReservation{SymbolsPerSecond: 200, StartTimestamp: now - 120, EndTimestamp: now + 180, QuorumSplits: []byte{30, 70}, QuorumNumbers: []uint8{0, 1}} + accountID3 = crypto.PubkeyToAddress(privateKey3.PublicKey) + account1Reservations = &core.ReservedPayment{SymbolsPerSecond: 100, StartTimestamp: now - 120, EndTimestamp: now + 180, QuorumSplits: []byte{50, 50}, QuorumNumbers: []uint8{0, 1}} + account2Reservations = &core.ReservedPayment{SymbolsPerSecond: 200, StartTimestamp: now - 120, EndTimestamp: now + 180, QuorumSplits: []byte{30, 70}, QuorumNumbers: []uint8{0, 1}} + account3Reservations = &core.ReservedPayment{SymbolsPerSecond: 200, StartTimestamp: now + 120, EndTimestamp: now + 180, QuorumSplits: []byte{30, 70}, QuorumNumbers: []uint8{0, 1}} account1OnDemandPayments = &core.OnDemandPayment{CumulativePayment: big.NewInt(3864)} account2OnDemandPayments = &core.OnDemandPayment{CumulativePayment: big.NewInt(2000)} @@ -177,13 +186,16 @@ func TestMetererReservations(t *testing.T) { reservationPeriod := meterer.GetReservationPeriod(uint64(time.Now().Unix()), mt.ChainPaymentState.GetReservationWindow()) quoromNumbers := []uint8{0, 1} - paymentChainState.On("GetActiveReservationByAccount", testifymock.Anything, testifymock.MatchedBy(func(account gethcommon.Address) bool { + paymentChainState.On("GetReservedPaymentByAccount", testifymock.Anything, testifymock.MatchedBy(func(account gethcommon.Address) bool { return account == accountID1 })).Return(account1Reservations, nil) - paymentChainState.On("GetActiveReservationByAccount", testifymock.Anything, testifymock.MatchedBy(func(account gethcommon.Address) bool { + paymentChainState.On("GetReservedPaymentByAccount", testifymock.Anything, testifymock.MatchedBy(func(account gethcommon.Address) bool { return account == accountID2 })).Return(account2Reservations, nil) - paymentChainState.On("GetActiveReservationByAccount", testifymock.Anything, testifymock.Anything).Return(&core.ActiveReservation{}, fmt.Errorf("reservation not found")) + paymentChainState.On("GetReservedPaymentByAccount", testifymock.Anything, testifymock.MatchedBy(func(account gethcommon.Address) bool { + return account == accountID3 + })).Return(account3Reservations, nil) + paymentChainState.On("GetReservedPaymentByAccount", testifymock.Anything, testifymock.Anything).Return(&core.ReservedPayment{}, fmt.Errorf("reservation not found")) // test invalid quorom ID header := createPaymentHeader(1, big.NewInt(0), accountID1) @@ -209,10 +221,15 @@ func TestMetererReservations(t *testing.T) { err = mt.MeterRequest(ctx, *header, 1000, []uint8{0, 1, 2}) assert.ErrorContains(t, err, "failed to get active reservation by account: reservation not found") - // test invalid bin index - header = createPaymentHeader(reservationPeriod, big.NewInt(0), accountID1) + // test inactive reservation + header = createPaymentHeader(reservationPeriod, big.NewInt(0), accountID3) + err = mt.MeterRequest(ctx, *header, 1000, []uint8{0}) + assert.ErrorContains(t, err, "reservation not active") + + // test invalid reservation period + header = createPaymentHeader(reservationPeriod-3, big.NewInt(0), accountID1) err = mt.MeterRequest(ctx, *header, 2000, quoromNumbers) - assert.ErrorContains(t, err, "invalid bin index for reservation") + assert.ErrorContains(t, err, "invalid reservation period for reservation") // test bin usage metering symbolLength := uint(20) diff --git a/core/meterer/onchain_state.go b/core/meterer/onchain_state.go index 3a9ba34f3d..85d6a9c932 100644 --- a/core/meterer/onchain_state.go +++ b/core/meterer/onchain_state.go @@ -2,7 +2,6 @@ package meterer import ( "context" - "fmt" "sync" "sync/atomic" @@ -16,7 +15,7 @@ import ( // OnchainPaymentState is an interface for getting information about the current chain state for payments. type OnchainPayment interface { RefreshOnchainPaymentState(ctx context.Context, tx *eth.Reader) error - GetActiveReservationByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ActiveReservation, error) + GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ReservedPayment, error) GetOnDemandPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.OnDemandPayment, error) GetOnDemandQuorumNumbers(ctx context.Context) ([]uint8, error) GetGlobalSymbolsPerSecond() uint64 @@ -31,8 +30,8 @@ var _ OnchainPayment = (*OnchainPaymentState)(nil) type OnchainPaymentState struct { tx *eth.Reader - ActiveReservations map[gethcommon.Address]*core.ActiveReservation - OnDemandPayments map[gethcommon.Address]*core.OnDemandPayment + ReservedPayments map[gethcommon.Address]*core.ReservedPayment + OnDemandPayments map[gethcommon.Address]*core.OnDemandPayment ReservationsLock sync.RWMutex OnDemandLocks sync.RWMutex @@ -57,7 +56,7 @@ func NewOnchainPaymentState(ctx context.Context, tx *eth.Reader) (*OnchainPaymen state := OnchainPaymentState{ tx: tx, - ActiveReservations: make(map[gethcommon.Address]*core.ActiveReservation), + ReservedPayments: make(map[gethcommon.Address]*core.ReservedPayment), OnDemandPayments: make(map[gethcommon.Address]*core.OnDemandPayment), PaymentVaultParams: atomic.Pointer[PaymentVaultParams]{}, } @@ -116,16 +115,16 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, pcs.PaymentVaultParams.Store(paymentVaultParams) pcs.ReservationsLock.Lock() - accountIDs := make([]gethcommon.Address, 0, len(pcs.ActiveReservations)) - for accountID := range pcs.ActiveReservations { + accountIDs := make([]gethcommon.Address, 0, len(pcs.ReservedPayments)) + for accountID := range pcs.ReservedPayments { accountIDs = append(accountIDs, accountID) } - activeReservations, err := tx.GetActiveReservations(ctx, accountIDs) + reservedPayments, err := tx.GetReservedPayments(ctx, accountIDs) if err != nil { return err } - pcs.ActiveReservations = activeReservations + pcs.ReservedPayments = reservedPayments pcs.ReservationsLock.Unlock() pcs.OnDemandLocks.Lock() @@ -144,31 +143,23 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, return nil } -// GetActiveReservationByAccount returns a pointer to the active reservation for the given account ID; no writes will be made to the reservation -func (pcs *OnchainPaymentState) GetActiveReservationByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ActiveReservation, error) { +// GetReservedPaymentByAccount returns a pointer to the active reservation for the given account ID; no writes will be made to the reservation +func (pcs *OnchainPaymentState) GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ReservedPayment, error) { pcs.ReservationsLock.RLock() defer pcs.ReservationsLock.RUnlock() - if reservation, ok := (pcs.ActiveReservations)[accountID]; ok { + if reservation, ok := (pcs.ReservedPayments)[accountID]; ok { return reservation, nil } // pulls the chain state - res, err := pcs.tx.GetActiveReservationByAccount(ctx, accountID) + res, err := pcs.tx.GetReservedPaymentByAccount(ctx, accountID) if err != nil { return nil, err } pcs.ReservationsLock.Lock() - (pcs.ActiveReservations)[accountID] = res + (pcs.ReservedPayments)[accountID] = res pcs.ReservationsLock.Unlock() - return res, nil -} -// GetActiveReservationByAccountOnChain returns on-chain reservation for the given account ID -func (pcs *OnchainPaymentState) GetActiveReservationByAccountOnChain(ctx context.Context, accountID gethcommon.Address) (*core.ActiveReservation, error) { - res, err := pcs.tx.GetActiveReservationByAccount(ctx, accountID) - if err != nil { - return nil, fmt.Errorf("reservation account not found on-chain: %w", err) - } return res, nil } @@ -191,14 +182,6 @@ func (pcs *OnchainPaymentState) GetOnDemandPaymentByAccount(ctx context.Context, return res, nil } -func (pcs *OnchainPaymentState) GetOnDemandPaymentByAccountOnChain(ctx context.Context, accountID gethcommon.Address) (*core.OnDemandPayment, error) { - res, err := pcs.tx.GetOnDemandPaymentByAccount(ctx, accountID) - if err != nil { - return nil, fmt.Errorf("on-demand not found on-chain: %w", err) - } - return res, nil -} - func (pcs *OnchainPaymentState) GetOnDemandQuorumNumbers(ctx context.Context) ([]uint8, error) { blockNumber, err := pcs.tx.GetCurrentBlockNumber(ctx) if err != nil { diff --git a/core/meterer/onchain_state_test.go b/core/meterer/onchain_state_test.go index d7fca84845..468296be87 100644 --- a/core/meterer/onchain_state_test.go +++ b/core/meterer/onchain_state_test.go @@ -14,7 +14,7 @@ import ( ) var ( - dummyActiveReservation = &core.ActiveReservation{ + dummyReservedPayment = &core.ReservedPayment{ SymbolsPerSecond: 100, StartTimestamp: 1000, EndTimestamp: 2000, @@ -43,14 +43,14 @@ func TestGetCurrentBlockNumber(t *testing.T) { assert.Equal(t, uint32(1000), blockNumber) } -func TestGetActiveReservationByAccount(t *testing.T) { +func TestGetReservedPaymentByAccount(t *testing.T) { mockState := &mock.MockOnchainPaymentState{} ctx := context.Background() - mockState.On("GetActiveReservationByAccount", testifymock.Anything, testifymock.Anything).Return(dummyActiveReservation, nil) + mockState.On("GetReservedPaymentByAccount", testifymock.Anything, testifymock.Anything).Return(dummyReservedPayment, nil) - reservation, err := mockState.GetActiveReservationByAccount(ctx, gethcommon.Address{}) + reservation, err := mockState.GetReservedPaymentByAccount(ctx, gethcommon.Address{}) assert.NoError(t, err) - assert.Equal(t, dummyActiveReservation, reservation) + assert.Equal(t, dummyReservedPayment, reservation) } func TestGetOnDemandPaymentByAccount(t *testing.T) { diff --git a/core/mock/payment_state.go b/core/mock/payment_state.go index 00c34b326e..8af76628f0 100644 --- a/core/mock/payment_state.go +++ b/core/mock/payment_state.go @@ -30,11 +30,11 @@ func (m *MockOnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context return args.Error(0) } -func (m *MockOnchainPaymentState) GetActiveReservationByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ActiveReservation, error) { +func (m *MockOnchainPaymentState) GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ReservedPayment, error) { args := m.Called(ctx, accountID) - var value *core.ActiveReservation + var value *core.ReservedPayment if args.Get(0) != nil { - value = args.Get(0).(*core.ActiveReservation) + value = args.Get(0).(*core.ReservedPayment) } return value, args.Error(1) } diff --git a/core/mock/writer.go b/core/mock/writer.go index 87384401bf..b28f88b5f1 100644 --- a/core/mock/writer.go +++ b/core/mock/writer.go @@ -227,16 +227,16 @@ func (t *MockWriter) PubkeyHashToOperator(ctx context.Context, operatorId core.O return result.(gethcommon.Address), args.Error(1) } -func (t *MockWriter) GetActiveReservations(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*core.ActiveReservation, error) { +func (t *MockWriter) GetReservedPayments(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*core.ReservedPayment, error) { args := t.Called() result := args.Get(0) - return result.(map[gethcommon.Address]*core.ActiveReservation), args.Error(1) + return result.(map[gethcommon.Address]*core.ReservedPayment), args.Error(1) } -func (t *MockWriter) GetActiveReservationByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ActiveReservation, error) { +func (t *MockWriter) GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ReservedPayment, error) { args := t.Called() result := args.Get(0) - return result.(*core.ActiveReservation), args.Error(1) + return result.(*core.ReservedPayment), args.Error(1) } func (t *MockWriter) GetOnDemandPayments(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*core.OnDemandPayment, error) { diff --git a/disperser/apiserver/server_test.go b/disperser/apiserver/server_test.go index 29f4f74a94..1bf6629d34 100644 --- a/disperser/apiserver/server_test.go +++ b/disperser/apiserver/server_test.go @@ -761,7 +761,7 @@ func newTestServer(transactor core.Writer, testName string) *apiserver.Dispersal mockState.On("GetOnDemandPaymentByAccount", tmock.Anything, tmock.Anything).Return(&core.OnDemandPayment{ CumulativePayment: big.NewInt(3000), }, nil) - mockState.On("GetActiveReservationByAccount", tmock.Anything, tmock.Anything).Return(&core.ActiveReservation{ + mockState.On("GetReservedPaymentByAccount", tmock.Anything, tmock.Anything).Return(&core.ReservedPayment{ SymbolsPerSecond: 2048, StartTimestamp: 0, EndTimestamp: math.MaxUint32, diff --git a/disperser/apiserver/server_v2.go b/disperser/apiserver/server_v2.go index e808a813c9..bd45248f1d 100644 --- a/disperser/apiserver/server_v2.go +++ b/disperser/apiserver/server_v2.go @@ -4,11 +4,12 @@ import ( "context" "errors" "fmt" - "github.com/prometheus/client_golang/prometheus" "net" "sync/atomic" "time" + "github.com/prometheus/client_golang/prometheus" + "github.com/Layr-Labs/eigenda/api" pbcommon "github.com/Layr-Labs/eigenda/api/grpc/common" pbv1 "github.com/Layr-Labs/eigenda/api/grpc/disperser" @@ -260,7 +261,7 @@ func (s *DispersalServerV2) GetPaymentState(ctx context.Context, req *pb.GetPaym return nil, api.NewErrorNotFound("failed to get largest cumulative payment") } // on-Chain account state - reservation, err := s.meterer.ChainPaymentState.GetActiveReservationByAccount(ctx, accountID) + reservation, err := s.meterer.ChainPaymentState.GetReservedPaymentByAccount(ctx, accountID) if err != nil { return nil, api.NewErrorNotFound("failed to get active reservation") } diff --git a/disperser/apiserver/server_v2_test.go b/disperser/apiserver/server_v2_test.go index 0bd8b5997e..41409bf552 100644 --- a/disperser/apiserver/server_v2_test.go +++ b/disperser/apiserver/server_v2_test.go @@ -447,7 +447,7 @@ func newTestServerV2(t *testing.T) *testComponents { mockState.On("GetMinNumSymbols", tmock.Anything).Return(uint32(3), nil) now := uint64(time.Now().Unix()) - mockState.On("GetActiveReservationByAccount", tmock.Anything, tmock.Anything).Return(&core.ActiveReservation{SymbolsPerSecond: 100, StartTimestamp: now + 1200, EndTimestamp: now + 1800, QuorumSplits: []byte{50, 50}, QuorumNumbers: []uint8{0, 1}}, nil) + mockState.On("GetReservedPaymentByAccount", tmock.Anything, tmock.Anything).Return(&core.ReservedPayment{SymbolsPerSecond: 100, StartTimestamp: now + 1200, EndTimestamp: now + 1800, QuorumSplits: []byte{50, 50}, QuorumNumbers: []uint8{0, 1}}, nil) mockState.On("GetOnDemandPaymentByAccount", tmock.Anything, tmock.Anything).Return(&core.OnDemandPayment{CumulativePayment: big.NewInt(3864)}, nil) mockState.On("GetOnDemandQuorumNumbers", tmock.Anything).Return([]uint8{0, 1}, nil) diff --git a/test/integration_test.go b/test/integration_test.go index 5016f3598f..cb83bdf621 100644 --- a/test/integration_test.go +++ b/test/integration_test.go @@ -221,10 +221,10 @@ func mustMakeDisperser(t *testing.T, cst core.IndexedChainState, store disperser mockState := &coremock.MockOnchainPaymentState{} reservationLimit := uint64(1024) paymentLimit := big.NewInt(512) - mockState.On("GetActiveReservationByAccount", mock.Anything, mock.MatchedBy(func(account gethcommon.Address) bool { + mockState.On("GetReservedPaymentByAccount", mock.Anything, mock.MatchedBy(func(account gethcommon.Address) bool { return account == publicKey - })).Return(&core.ActiveReservation{SymbolsPerSecond: reservationLimit, StartTimestamp: 0, EndTimestamp: math.MaxUint32, QuorumSplits: []byte{50, 50}, QuorumNumbers: []uint8{0, 1}}, nil) - mockState.On("GetActiveReservationByAccount", mock.Anything, mock.Anything).Return(&core.ActiveReservation{}, errors.New("reservation not found")) + })).Return(&core.ReservedPayment{SymbolsPerSecond: reservationLimit, StartTimestamp: 0, EndTimestamp: math.MaxUint32, QuorumSplits: []byte{50, 50}, QuorumNumbers: []uint8{0, 1}}, nil) + mockState.On("GetReservedPaymentByAccount", mock.Anything, mock.Anything).Return(&core.ReservedPayment{}, errors.New("reservation not found")) mockState.On("GetOnDemandPaymentByAccount", mock.Anything, mock.MatchedBy(func(account gethcommon.Address) bool { return account == publicKey