diff --git a/api/clients/accountant.go b/api/clients/accountant.go index 045877bd7e..6b923d51b4 100644 --- a/api/clients/accountant.go +++ b/api/clients/accountant.go @@ -18,7 +18,7 @@ var requiredQuorums = []uint8{0, 1} type Accountant struct { // on-chain states accountID string - reservation *core.ActiveReservation + reservation *core.ReservedPayment onDemand *core.OnDemandPayment reservationWindow uint32 pricePerSymbol uint32 @@ -39,7 +39,7 @@ type BinRecord struct { Usage uint64 } -func NewAccountant(accountID string, reservation *core.ActiveReservation, onDemand *core.OnDemandPayment, reservationWindow uint32, pricePerSymbol uint32, minNumSymbols uint32, numBins uint32) *Accountant { +func NewAccountant(accountID string, reservation *core.ReservedPayment, onDemand *core.OnDemandPayment, reservationWindow uint32, pricePerSymbol uint32, minNumSymbols uint32, numBins uint32) *Accountant { //TODO: client storage; currently every instance starts fresh but on-chain or a small store makes more sense // Also client is currently responsible for supplying network params, we need to add RPC in order to be automatic // There's a subsequent PR that handles populating the accountant with on-chain state from the disperser @@ -65,7 +65,7 @@ func NewAccountant(accountID string, reservation *core.ActiveReservation, onDema // BlobPaymentInfo calculates and records payment information. The accountant // will attempt to use the active reservation first and check for quorum settings, // then on-demand if the reservation is not available. The returned values are -// bin index for reservation payments and cumulative payment for on-demand payments, +// reservation period for reservation payments and cumulative payment for on-demand payments, // and both fields are used to create the payment header and signature func (a *Accountant) BlobPaymentInfo(ctx context.Context, numSymbols uint64, quorumNumbers []uint8) (uint32, *big.Int, error) { now := time.Now().Unix() diff --git a/api/clients/accountant_test.go b/api/clients/accountant_test.go index c6dc3fa692..d28dd9f16b 100644 --- a/api/clients/accountant_test.go +++ b/api/clients/accountant_test.go @@ -18,7 +18,7 @@ const numBins = uint32(3) const salt = uint32(0) func TestNewAccountant(t *testing.T) { - reservation := &core.ActiveReservation{ + reservation := &core.ReservedPayment{ SymbolsPerSecond: 100, StartTimestamp: 100, EndTimestamp: 200, @@ -48,7 +48,7 @@ func TestNewAccountant(t *testing.T) { } func TestAccountBlob_Reservation(t *testing.T) { - reservation := &core.ActiveReservation{ + reservation := &core.ReservedPayment{ SymbolsPerSecond: 200, StartTimestamp: 100, EndTimestamp: 200, @@ -96,7 +96,7 @@ func TestAccountBlob_Reservation(t *testing.T) { } func TestAccountBlob_OnDemand(t *testing.T) { - reservation := &core.ActiveReservation{ + reservation := &core.ReservedPayment{ SymbolsPerSecond: 200, StartTimestamp: 100, EndTimestamp: 200, @@ -130,7 +130,7 @@ func TestAccountBlob_OnDemand(t *testing.T) { } func TestAccountBlob_InsufficientOnDemand(t *testing.T) { - reservation := &core.ActiveReservation{} + reservation := &core.ReservedPayment{} onDemand := &core.OnDemandPayment{ CumulativePayment: big.NewInt(500), } @@ -152,7 +152,7 @@ func TestAccountBlob_InsufficientOnDemand(t *testing.T) { } func TestAccountBlobCallSeries(t *testing.T) { - reservation := &core.ActiveReservation{ + reservation := &core.ReservedPayment{ SymbolsPerSecond: 200, StartTimestamp: 100, EndTimestamp: 200, @@ -200,7 +200,7 @@ func TestAccountBlobCallSeries(t *testing.T) { } func TestAccountBlob_BinRotation(t *testing.T) { - reservation := &core.ActiveReservation{ + reservation := &core.ReservedPayment{ SymbolsPerSecond: 1000, StartTimestamp: 100, EndTimestamp: 200, @@ -242,7 +242,7 @@ func TestAccountBlob_BinRotation(t *testing.T) { } func TestConcurrentBinRotationAndAccountBlob(t *testing.T) { - reservation := &core.ActiveReservation{ + reservation := &core.ReservedPayment{ SymbolsPerSecond: 1000, StartTimestamp: 100, EndTimestamp: 200, @@ -284,7 +284,7 @@ func TestConcurrentBinRotationAndAccountBlob(t *testing.T) { } func TestAccountBlob_ReservationWithOneOverflow(t *testing.T) { - reservation := &core.ActiveReservation{ + reservation := &core.ReservedPayment{ SymbolsPerSecond: 200, StartTimestamp: 100, EndTimestamp: 200, @@ -332,7 +332,7 @@ func TestAccountBlob_ReservationWithOneOverflow(t *testing.T) { } func TestAccountBlob_ReservationOverflowReset(t *testing.T) { - reservation := &core.ActiveReservation{ + reservation := &core.ReservedPayment{ SymbolsPerSecond: 1000, StartTimestamp: 100, EndTimestamp: 200, diff --git a/api/docs/disperser_v2.html b/api/docs/disperser_v2.html index 2435d12a47..03a4c5533b 100644 --- a/api/docs/disperser_v2.html +++ b/api/docs/disperser_v2.html @@ -632,7 +632,7 @@

GetPaymentStateRequest

bytes

Signature over the account ID -TODO: sign over a bin index or a nonce to mitigate signature replay attacks

+TODO: sign over a reservation period or a nonce to mitigate signature replay attacks

diff --git a/api/docs/disperser_v2.md b/api/docs/disperser_v2.md index 4a52f3ccb8..bd1bc66acc 100644 --- a/api/docs/disperser_v2.md +++ b/api/docs/disperser_v2.md @@ -210,7 +210,7 @@ GetPaymentStateRequest contains parameters to query the payment state of an acco | Field | Type | Label | Description | | ----- | ---- | ----- | ----------- | | account_id | [string](#string) | | | -| signature | [bytes](#bytes) | | Signature over the account ID TODO: sign over a bin index or a nonce to mitigate signature replay attacks | +| signature | [bytes](#bytes) | | Signature over the account ID TODO: sign over a reservation period or a nonce to mitigate signature replay attacks | diff --git a/api/docs/eigenda-protos.html b/api/docs/eigenda-protos.html index 4e1a8cb48d..9a6b53af5b 100644 --- a/api/docs/eigenda-protos.html +++ b/api/docs/eigenda-protos.html @@ -2303,7 +2303,7 @@

GetPaymentStateRequest

bytes

Signature over the account ID -TODO: sign over a bin index or a nonce to mitigate signature replay attacks

+TODO: sign over a reservation period or a nonce to mitigate signature replay attacks

diff --git a/api/docs/eigenda-protos.md b/api/docs/eigenda-protos.md index 4baff482a1..7f90353af8 100644 --- a/api/docs/eigenda-protos.md +++ b/api/docs/eigenda-protos.md @@ -912,7 +912,7 @@ GetPaymentStateRequest contains parameters to query the payment state of an acco | Field | Type | Label | Description | | ----- | ---- | ----- | ----------- | | account_id | [string](#string) | | | -| signature | [bytes](#bytes) | | Signature over the account ID TODO: sign over a bin index or a nonce to mitigate signature replay attacks | +| signature | [bytes](#bytes) | | Signature over the account ID TODO: sign over a reservation period or a nonce to mitigate signature replay attacks | diff --git a/api/grpc/disperser/v2/disperser_v2.pb.go b/api/grpc/disperser/v2/disperser_v2.pb.go index ffbfb3c178..e420648ee5 100644 --- a/api/grpc/disperser/v2/disperser_v2.pb.go +++ b/api/grpc/disperser/v2/disperser_v2.pb.go @@ -424,7 +424,7 @@ type GetPaymentStateRequest struct { AccountId string `protobuf:"bytes,1,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` // Signature over the account ID - // TODO: sign over a bin index or a nonce to mitigate signature replay attacks + // TODO: sign over a reservation period or a nonce to mitigate signature replay attacks Signature []byte `protobuf:"bytes,2,opt,name=signature,proto3" json:"signature,omitempty"` } diff --git a/api/proto/disperser/v2/disperser_v2.proto b/api/proto/disperser/v2/disperser_v2.proto index 3038fef7e8..fb6386c724 100644 --- a/api/proto/disperser/v2/disperser_v2.proto +++ b/api/proto/disperser/v2/disperser_v2.proto @@ -70,7 +70,7 @@ message BlobCommitmentReply { message GetPaymentStateRequest { string account_id = 1; // Signature over the account ID - // TODO: sign over a bin index or a nonce to mitigate signature replay attacks + // TODO: sign over a reservation period or a nonce to mitigate signature replay attacks bytes signature = 2; } diff --git a/core/chainio.go b/core/chainio.go index 6ced6ed4c9..e28572832a 100644 --- a/core/chainio.go +++ b/core/chainio.go @@ -112,11 +112,11 @@ type Reader interface { // GetAllVersionedBlobParams returns the blob version parameters for all blob versions at the given block number. GetAllVersionedBlobParams(ctx context.Context) (map[uint16]*BlobVersionParameters, error) - // GetActiveReservations returns active reservations (end timestamp > current timestamp) - GetActiveReservations(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*ActiveReservation, error) + // GetReservedPayments returns active reservations (end timestamp > current timestamp) + GetReservedPayments(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*ReservedPayment, error) - // GetActiveReservationByAccount returns active reservation by account ID - GetActiveReservationByAccount(ctx context.Context, accountID gethcommon.Address) (*ActiveReservation, error) + // GetReservedPaymentByAccount returns active reservation by account ID + GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*ReservedPayment, error) // GetOnDemandPayments returns all on-demand payments GetOnDemandPayments(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*OnDemandPayment, error) diff --git a/core/data.go b/core/data.go index 551260ec24..367faad32d 100644 --- a/core/data.go +++ b/core/data.go @@ -604,20 +604,24 @@ func ConvertToPaymentMetadata(ph *commonpb.PaymentHeader) *PaymentMetadata { } } -// OperatorInfo contains information about an operator which is stored on the blockchain state, -// corresponding to a particular quorum -type ActiveReservation struct { - SymbolsPerSecond uint64 // reserve number of symbols per second - //TODO: we are not using start and end timestamp, add check or remove - StartTimestamp uint64 // Unix timestamp that's valid for basically eternity - EndTimestamp uint64 +// ReservedPayment contains information the onchain state about a reserved payment +type ReservedPayment struct { + // reserve number of symbols per second + SymbolsPerSecond uint64 + // reservation activation timestamp + StartTimestamp uint64 + // reservation expiration timestamp + EndTimestamp uint64 - QuorumNumbers []uint8 // allowed quorums - QuorumSplits []byte // ordered mapping of quorum number to payment split; on-chain validation should ensure split <= 100 + // allowed quorums + QuorumNumbers []uint8 + // ordered mapping of quorum number to payment split; on-chain validation should ensure split <= 100 + QuorumSplits []byte } type OnDemandPayment struct { - CumulativePayment *big.Int // Total amount deposited by the user + // Total amount deposited by the user + CumulativePayment *big.Int } type BlobVersionParameters struct { @@ -625,3 +629,8 @@ type BlobVersionParameters struct { MaxNumOperators uint32 NumChunks uint32 } + +// IsActive returns true if the reservation is active at the given timestamp +func (ar *ReservedPayment) IsActive(currentTimestamp uint64) bool { + return ar.StartTimestamp <= currentTimestamp && ar.EndTimestamp >= currentTimestamp +} diff --git a/core/data_test.go b/core/data_test.go index 84cb5097e9..61ffd37241 100644 --- a/core/data_test.go +++ b/core/data_test.go @@ -217,3 +217,65 @@ func TestChunksData(t *testing.T) { assert.EqualError(t, err, "unsupported chunk encoding format: 3") } } + +func TestReservedPayment_IsActive(t *testing.T) { + tests := []struct { + name string + reservedPayment core.ReservedPayment + currentTimestamp uint64 + wantActive bool + }{ + { + name: "active - current time in middle of range", + reservedPayment: core.ReservedPayment{ + StartTimestamp: 100, + EndTimestamp: 200, + }, + currentTimestamp: 150, + wantActive: true, + }, + { + name: "active - current time at start", + reservedPayment: core.ReservedPayment{ + StartTimestamp: 100, + EndTimestamp: 200, + }, + currentTimestamp: 100, + wantActive: true, + }, + { + name: "active - current time at end", + reservedPayment: core.ReservedPayment{ + StartTimestamp: 100, + EndTimestamp: 200, + }, + currentTimestamp: 200, + wantActive: true, + }, + { + name: "inactive - current time before start", + reservedPayment: core.ReservedPayment{ + StartTimestamp: 100, + EndTimestamp: 200, + }, + currentTimestamp: 99, + wantActive: false, + }, + { + name: "inactive - current time after end", + reservedPayment: core.ReservedPayment{ + StartTimestamp: 100, + EndTimestamp: 200, + }, + currentTimestamp: 201, + wantActive: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isActive := tt.reservedPayment.IsActive(tt.currentTimestamp) + assert.Equal(t, tt.wantActive, isActive) + }) + } +} diff --git a/core/eth/reader.go b/core/eth/reader.go index 73c2789420..6390f221f0 100644 --- a/core/eth/reader.go +++ b/core/eth/reader.go @@ -690,11 +690,11 @@ func (t *Reader) GetAllVersionedBlobParams(ctx context.Context) (map[uint16]*cor return res, nil } -func (t *Reader) GetActiveReservations(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*core.ActiveReservation, error) { +func (t *Reader) GetReservedPayments(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*core.ReservedPayment, error) { if t.bindings.PaymentVault == nil { return nil, errors.New("payment vault not deployed") } - reservationsMap := make(map[gethcommon.Address]*core.ActiveReservation) + reservationsMap := make(map[gethcommon.Address]*core.ReservedPayment) reservations, err := t.bindings.PaymentVault.GetReservations(&bind.CallOpts{ Context: ctx, }, accountIDs) @@ -704,7 +704,7 @@ func (t *Reader) GetActiveReservations(ctx context.Context, accountIDs []gethcom // since reservations are returned in the same order as the accountIDs, we can directly map them for i, reservation := range reservations { - res, err := ConvertToActiveReservation(reservation) + res, err := ConvertToReservedPayment(reservation) if err != nil { t.logger.Warn("failed to get active reservation", "account", accountIDs[i], "err", err) continue @@ -716,7 +716,7 @@ func (t *Reader) GetActiveReservations(ctx context.Context, accountIDs []gethcom return reservationsMap, nil } -func (t *Reader) GetActiveReservationByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ActiveReservation, error) { +func (t *Reader) GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ReservedPayment, error) { if t.bindings.PaymentVault == nil { return nil, errors.New("payment vault not deployed") } @@ -726,7 +726,7 @@ func (t *Reader) GetActiveReservationByAccount(ctx context.Context, accountID ge if err != nil { return nil, err } - return ConvertToActiveReservation(reservation) + return ConvertToReservedPayment(reservation) } func (t *Reader) GetOnDemandPayments(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*core.OnDemandPayment, error) { diff --git a/core/eth/utils.go b/core/eth/utils.go index d98b6def2a..7334f62f98 100644 --- a/core/eth/utils.go +++ b/core/eth/utils.go @@ -137,14 +137,14 @@ func isZeroValuedReservation(reservation paymentvault.IPaymentVaultReservation) len(reservation.QuorumSplits) == 0 } -// ConvertToActiveReservation converts a upstream binding data structure to local definition. +// ConvertToReservedPayment converts a upstream binding data structure to local definition. // Returns an error if the input reservation is zero-valued. -func ConvertToActiveReservation(reservation paymentvault.IPaymentVaultReservation) (*core.ActiveReservation, error) { +func ConvertToReservedPayment(reservation paymentvault.IPaymentVaultReservation) (*core.ReservedPayment, error) { if isZeroValuedReservation(reservation) { return nil, fmt.Errorf("reservation is not a valid active reservation") } - return &core.ActiveReservation{ + return &core.ReservedPayment{ SymbolsPerSecond: reservation.SymbolsPerSecond, StartTimestamp: reservation.StartTimestamp, EndTimestamp: reservation.EndTimestamp, diff --git a/core/meterer/meterer.go b/core/meterer/meterer.go index 1f0e1c5aeb..681c6d8c6c 100644 --- a/core/meterer/meterer.go +++ b/core/meterer/meterer.go @@ -76,7 +76,7 @@ func (m *Meterer) MeterRequest(ctx context.Context, header core.PaymentMetadata, accountID := gethcommon.HexToAddress(header.AccountID) // Validate against the payment method if header.CumulativePayment.Sign() == 0 { - reservation, err := m.ChainPaymentState.GetActiveReservationByAccount(ctx, accountID) + reservation, err := m.ChainPaymentState.GetReservedPaymentByAccount(ctx, accountID) if err != nil { return fmt.Errorf("failed to get active reservation by account: %w", err) } @@ -97,12 +97,15 @@ func (m *Meterer) MeterRequest(ctx context.Context, header core.PaymentMetadata, } // ServeReservationRequest handles the rate limiting logic for incoming requests -func (m *Meterer) ServeReservationRequest(ctx context.Context, header core.PaymentMetadata, reservation *core.ActiveReservation, numSymbols uint, quorumNumbers []uint8) error { +func (m *Meterer) ServeReservationRequest(ctx context.Context, header core.PaymentMetadata, reservation *core.ReservedPayment, numSymbols uint, quorumNumbers []uint8) error { + if !reservation.IsActive(uint64(time.Now().Unix())) { + return fmt.Errorf("reservation not active") + } if err := m.ValidateQuorum(quorumNumbers, reservation.QuorumNumbers); err != nil { return fmt.Errorf("invalid quorum for reservation: %w", err) } if !m.ValidateReservationPeriod(header, reservation) { - return fmt.Errorf("invalid bin index for reservation") + return fmt.Errorf("invalid reservation period for reservation") } // Update bin usage atomically and check against reservation's data rate as the bin limit @@ -122,7 +125,7 @@ func (m *Meterer) ValidateQuorum(headerQuorums []uint8, allowedQuorums []uint8) return fmt.Errorf("no quorum params in blob header") } - // check that all the quorum ids are in ActiveReservation's + // check that all the quorum ids are in ReservedPayment's for _, q := range headerQuorums { if !slices.Contains(allowedQuorums, q) { // fail the entire request if there's a quorum number mismatch @@ -132,12 +135,12 @@ func (m *Meterer) ValidateQuorum(headerQuorums []uint8, allowedQuorums []uint8) return nil } -// ValidateReservationPeriod checks if the provided bin index is valid -func (m *Meterer) ValidateReservationPeriod(header core.PaymentMetadata, reservation *core.ActiveReservation) bool { +// ValidateReservationPeriod checks if the provided reservation period is valid +func (m *Meterer) ValidateReservationPeriod(header core.PaymentMetadata, reservation *core.ReservedPayment) bool { now := uint64(time.Now().Unix()) reservationWindow := m.ChainPaymentState.GetReservationWindow() currentReservationPeriod := GetReservationPeriod(now, reservationWindow) - // Valid bin indexes are either the current bin or the previous bin + // Valid reservation periodes are either the current bin or the previous bin if (header.ReservationPeriod != currentReservationPeriod && header.ReservationPeriod != (currentReservationPeriod-1)) || (GetReservationPeriod(reservation.StartTimestamp, reservationWindow) > header.ReservationPeriod || header.ReservationPeriod > GetReservationPeriod(reservation.EndTimestamp, reservationWindow)) { return false } @@ -145,7 +148,7 @@ func (m *Meterer) ValidateReservationPeriod(header core.PaymentMetadata, reserva } // IncrementBinUsage increments the bin usage atomically and checks for overflow -func (m *Meterer) IncrementBinUsage(ctx context.Context, header core.PaymentMetadata, reservation *core.ActiveReservation, numSymbols uint) error { +func (m *Meterer) IncrementBinUsage(ctx context.Context, header core.PaymentMetadata, reservation *core.ReservedPayment, numSymbols uint) error { symbolsCharged := m.SymbolsCharged(numSymbols) newUsage, err := m.OffchainStore.UpdateReservationBin(ctx, header.AccountID, uint64(header.ReservationPeriod), uint64(symbolsCharged)) if err != nil { @@ -170,7 +173,7 @@ func (m *Meterer) IncrementBinUsage(ctx context.Context, header core.PaymentMeta return fmt.Errorf("overflow usage exceeds bin limit") } -// GetReservationPeriod returns the current bin index by chunking time by the bin interval; +// GetReservationPeriod returns the current reservation period by chunking time by the bin interval; // bin interval used by the disperser should be public information func GetReservationPeriod(timestamp uint64, binInterval uint32) uint32 { return uint32(timestamp) / binInterval @@ -274,6 +277,6 @@ func (m *Meterer) IncrementGlobalBinUsage(ctx context.Context, symbolsCharged ui } // GetReservationBinLimit returns the bin limit for a given reservation -func (m *Meterer) GetReservationBinLimit(reservation *core.ActiveReservation) uint64 { +func (m *Meterer) GetReservationBinLimit(reservation *core.ReservedPayment) uint64 { return reservation.SymbolsPerSecond * uint64(m.ChainPaymentState.GetReservationWindow()) } diff --git a/core/meterer/meterer_test.go b/core/meterer/meterer_test.go index 30a03c9737..9ab35df545 100644 --- a/core/meterer/meterer_test.go +++ b/core/meterer/meterer_test.go @@ -32,11 +32,13 @@ var ( dynamoClient commondynamodb.Client clientConfig commonaws.ClientConfig accountID1 gethcommon.Address - account1Reservations *core.ActiveReservation + account1Reservations *core.ReservedPayment account1OnDemandPayments *core.OnDemandPayment accountID2 gethcommon.Address - account2Reservations *core.ActiveReservation + account2Reservations *core.ReservedPayment account2OnDemandPayments *core.OnDemandPayment + accountID3 gethcommon.Address + account3Reservations *core.ReservedPayment mt *meterer.Meterer deployLocalStack bool @@ -100,6 +102,11 @@ func setup(_ *testing.M) { teardown() panic("failed to generate private key") } + privateKey3, err := crypto.GenerateKey() + if err != nil { + teardown() + panic("failed to generate private key") + } logger = logging.NewNoopLogger() config := meterer.Config{ @@ -126,8 +133,10 @@ func setup(_ *testing.M) { now := uint64(time.Now().Unix()) accountID1 = crypto.PubkeyToAddress(privateKey1.PublicKey) accountID2 = crypto.PubkeyToAddress(privateKey2.PublicKey) - account1Reservations = &core.ActiveReservation{SymbolsPerSecond: 100, StartTimestamp: now + 1200, EndTimestamp: now + 1800, QuorumSplits: []byte{50, 50}, QuorumNumbers: []uint8{0, 1}} - account2Reservations = &core.ActiveReservation{SymbolsPerSecond: 200, StartTimestamp: now - 120, EndTimestamp: now + 180, QuorumSplits: []byte{30, 70}, QuorumNumbers: []uint8{0, 1}} + accountID3 = crypto.PubkeyToAddress(privateKey3.PublicKey) + account1Reservations = &core.ReservedPayment{SymbolsPerSecond: 100, StartTimestamp: now - 120, EndTimestamp: now + 180, QuorumSplits: []byte{50, 50}, QuorumNumbers: []uint8{0, 1}} + account2Reservations = &core.ReservedPayment{SymbolsPerSecond: 200, StartTimestamp: now - 120, EndTimestamp: now + 180, QuorumSplits: []byte{30, 70}, QuorumNumbers: []uint8{0, 1}} + account3Reservations = &core.ReservedPayment{SymbolsPerSecond: 200, StartTimestamp: now + 120, EndTimestamp: now + 180, QuorumSplits: []byte{30, 70}, QuorumNumbers: []uint8{0, 1}} account1OnDemandPayments = &core.OnDemandPayment{CumulativePayment: big.NewInt(3864)} account2OnDemandPayments = &core.OnDemandPayment{CumulativePayment: big.NewInt(2000)} @@ -177,13 +186,16 @@ func TestMetererReservations(t *testing.T) { reservationPeriod := meterer.GetReservationPeriod(uint64(time.Now().Unix()), mt.ChainPaymentState.GetReservationWindow()) quoromNumbers := []uint8{0, 1} - paymentChainState.On("GetActiveReservationByAccount", testifymock.Anything, testifymock.MatchedBy(func(account gethcommon.Address) bool { + paymentChainState.On("GetReservedPaymentByAccount", testifymock.Anything, testifymock.MatchedBy(func(account gethcommon.Address) bool { return account == accountID1 })).Return(account1Reservations, nil) - paymentChainState.On("GetActiveReservationByAccount", testifymock.Anything, testifymock.MatchedBy(func(account gethcommon.Address) bool { + paymentChainState.On("GetReservedPaymentByAccount", testifymock.Anything, testifymock.MatchedBy(func(account gethcommon.Address) bool { return account == accountID2 })).Return(account2Reservations, nil) - paymentChainState.On("GetActiveReservationByAccount", testifymock.Anything, testifymock.Anything).Return(&core.ActiveReservation{}, fmt.Errorf("reservation not found")) + paymentChainState.On("GetReservedPaymentByAccount", testifymock.Anything, testifymock.MatchedBy(func(account gethcommon.Address) bool { + return account == accountID3 + })).Return(account3Reservations, nil) + paymentChainState.On("GetReservedPaymentByAccount", testifymock.Anything, testifymock.Anything).Return(&core.ReservedPayment{}, fmt.Errorf("reservation not found")) // test invalid quorom ID header := createPaymentHeader(1, big.NewInt(0), accountID1) @@ -209,10 +221,15 @@ func TestMetererReservations(t *testing.T) { err = mt.MeterRequest(ctx, *header, 1000, []uint8{0, 1, 2}) assert.ErrorContains(t, err, "failed to get active reservation by account: reservation not found") - // test invalid bin index - header = createPaymentHeader(reservationPeriod, big.NewInt(0), accountID1) + // test inactive reservation + header = createPaymentHeader(reservationPeriod, big.NewInt(0), accountID3) + err = mt.MeterRequest(ctx, *header, 1000, []uint8{0}) + assert.ErrorContains(t, err, "reservation not active") + + // test invalid reservation period + header = createPaymentHeader(reservationPeriod-3, big.NewInt(0), accountID1) err = mt.MeterRequest(ctx, *header, 2000, quoromNumbers) - assert.ErrorContains(t, err, "invalid bin index for reservation") + assert.ErrorContains(t, err, "invalid reservation period for reservation") // test bin usage metering symbolLength := uint(20) diff --git a/core/meterer/onchain_state.go b/core/meterer/onchain_state.go index 3a9ba34f3d..85d6a9c932 100644 --- a/core/meterer/onchain_state.go +++ b/core/meterer/onchain_state.go @@ -2,7 +2,6 @@ package meterer import ( "context" - "fmt" "sync" "sync/atomic" @@ -16,7 +15,7 @@ import ( // OnchainPaymentState is an interface for getting information about the current chain state for payments. type OnchainPayment interface { RefreshOnchainPaymentState(ctx context.Context, tx *eth.Reader) error - GetActiveReservationByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ActiveReservation, error) + GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ReservedPayment, error) GetOnDemandPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.OnDemandPayment, error) GetOnDemandQuorumNumbers(ctx context.Context) ([]uint8, error) GetGlobalSymbolsPerSecond() uint64 @@ -31,8 +30,8 @@ var _ OnchainPayment = (*OnchainPaymentState)(nil) type OnchainPaymentState struct { tx *eth.Reader - ActiveReservations map[gethcommon.Address]*core.ActiveReservation - OnDemandPayments map[gethcommon.Address]*core.OnDemandPayment + ReservedPayments map[gethcommon.Address]*core.ReservedPayment + OnDemandPayments map[gethcommon.Address]*core.OnDemandPayment ReservationsLock sync.RWMutex OnDemandLocks sync.RWMutex @@ -57,7 +56,7 @@ func NewOnchainPaymentState(ctx context.Context, tx *eth.Reader) (*OnchainPaymen state := OnchainPaymentState{ tx: tx, - ActiveReservations: make(map[gethcommon.Address]*core.ActiveReservation), + ReservedPayments: make(map[gethcommon.Address]*core.ReservedPayment), OnDemandPayments: make(map[gethcommon.Address]*core.OnDemandPayment), PaymentVaultParams: atomic.Pointer[PaymentVaultParams]{}, } @@ -116,16 +115,16 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, pcs.PaymentVaultParams.Store(paymentVaultParams) pcs.ReservationsLock.Lock() - accountIDs := make([]gethcommon.Address, 0, len(pcs.ActiveReservations)) - for accountID := range pcs.ActiveReservations { + accountIDs := make([]gethcommon.Address, 0, len(pcs.ReservedPayments)) + for accountID := range pcs.ReservedPayments { accountIDs = append(accountIDs, accountID) } - activeReservations, err := tx.GetActiveReservations(ctx, accountIDs) + reservedPayments, err := tx.GetReservedPayments(ctx, accountIDs) if err != nil { return err } - pcs.ActiveReservations = activeReservations + pcs.ReservedPayments = reservedPayments pcs.ReservationsLock.Unlock() pcs.OnDemandLocks.Lock() @@ -144,31 +143,23 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, return nil } -// GetActiveReservationByAccount returns a pointer to the active reservation for the given account ID; no writes will be made to the reservation -func (pcs *OnchainPaymentState) GetActiveReservationByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ActiveReservation, error) { +// GetReservedPaymentByAccount returns a pointer to the active reservation for the given account ID; no writes will be made to the reservation +func (pcs *OnchainPaymentState) GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ReservedPayment, error) { pcs.ReservationsLock.RLock() defer pcs.ReservationsLock.RUnlock() - if reservation, ok := (pcs.ActiveReservations)[accountID]; ok { + if reservation, ok := (pcs.ReservedPayments)[accountID]; ok { return reservation, nil } // pulls the chain state - res, err := pcs.tx.GetActiveReservationByAccount(ctx, accountID) + res, err := pcs.tx.GetReservedPaymentByAccount(ctx, accountID) if err != nil { return nil, err } pcs.ReservationsLock.Lock() - (pcs.ActiveReservations)[accountID] = res + (pcs.ReservedPayments)[accountID] = res pcs.ReservationsLock.Unlock() - return res, nil -} -// GetActiveReservationByAccountOnChain returns on-chain reservation for the given account ID -func (pcs *OnchainPaymentState) GetActiveReservationByAccountOnChain(ctx context.Context, accountID gethcommon.Address) (*core.ActiveReservation, error) { - res, err := pcs.tx.GetActiveReservationByAccount(ctx, accountID) - if err != nil { - return nil, fmt.Errorf("reservation account not found on-chain: %w", err) - } return res, nil } @@ -191,14 +182,6 @@ func (pcs *OnchainPaymentState) GetOnDemandPaymentByAccount(ctx context.Context, return res, nil } -func (pcs *OnchainPaymentState) GetOnDemandPaymentByAccountOnChain(ctx context.Context, accountID gethcommon.Address) (*core.OnDemandPayment, error) { - res, err := pcs.tx.GetOnDemandPaymentByAccount(ctx, accountID) - if err != nil { - return nil, fmt.Errorf("on-demand not found on-chain: %w", err) - } - return res, nil -} - func (pcs *OnchainPaymentState) GetOnDemandQuorumNumbers(ctx context.Context) ([]uint8, error) { blockNumber, err := pcs.tx.GetCurrentBlockNumber(ctx) if err != nil { diff --git a/core/meterer/onchain_state_test.go b/core/meterer/onchain_state_test.go index d7fca84845..468296be87 100644 --- a/core/meterer/onchain_state_test.go +++ b/core/meterer/onchain_state_test.go @@ -14,7 +14,7 @@ import ( ) var ( - dummyActiveReservation = &core.ActiveReservation{ + dummyReservedPayment = &core.ReservedPayment{ SymbolsPerSecond: 100, StartTimestamp: 1000, EndTimestamp: 2000, @@ -43,14 +43,14 @@ func TestGetCurrentBlockNumber(t *testing.T) { assert.Equal(t, uint32(1000), blockNumber) } -func TestGetActiveReservationByAccount(t *testing.T) { +func TestGetReservedPaymentByAccount(t *testing.T) { mockState := &mock.MockOnchainPaymentState{} ctx := context.Background() - mockState.On("GetActiveReservationByAccount", testifymock.Anything, testifymock.Anything).Return(dummyActiveReservation, nil) + mockState.On("GetReservedPaymentByAccount", testifymock.Anything, testifymock.Anything).Return(dummyReservedPayment, nil) - reservation, err := mockState.GetActiveReservationByAccount(ctx, gethcommon.Address{}) + reservation, err := mockState.GetReservedPaymentByAccount(ctx, gethcommon.Address{}) assert.NoError(t, err) - assert.Equal(t, dummyActiveReservation, reservation) + assert.Equal(t, dummyReservedPayment, reservation) } func TestGetOnDemandPaymentByAccount(t *testing.T) { diff --git a/core/mock/payment_state.go b/core/mock/payment_state.go index 00c34b326e..8af76628f0 100644 --- a/core/mock/payment_state.go +++ b/core/mock/payment_state.go @@ -30,11 +30,11 @@ func (m *MockOnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context return args.Error(0) } -func (m *MockOnchainPaymentState) GetActiveReservationByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ActiveReservation, error) { +func (m *MockOnchainPaymentState) GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ReservedPayment, error) { args := m.Called(ctx, accountID) - var value *core.ActiveReservation + var value *core.ReservedPayment if args.Get(0) != nil { - value = args.Get(0).(*core.ActiveReservation) + value = args.Get(0).(*core.ReservedPayment) } return value, args.Error(1) } diff --git a/core/mock/writer.go b/core/mock/writer.go index 87384401bf..b28f88b5f1 100644 --- a/core/mock/writer.go +++ b/core/mock/writer.go @@ -227,16 +227,16 @@ func (t *MockWriter) PubkeyHashToOperator(ctx context.Context, operatorId core.O return result.(gethcommon.Address), args.Error(1) } -func (t *MockWriter) GetActiveReservations(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*core.ActiveReservation, error) { +func (t *MockWriter) GetReservedPayments(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*core.ReservedPayment, error) { args := t.Called() result := args.Get(0) - return result.(map[gethcommon.Address]*core.ActiveReservation), args.Error(1) + return result.(map[gethcommon.Address]*core.ReservedPayment), args.Error(1) } -func (t *MockWriter) GetActiveReservationByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ActiveReservation, error) { +func (t *MockWriter) GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ReservedPayment, error) { args := t.Called() result := args.Get(0) - return result.(*core.ActiveReservation), args.Error(1) + return result.(*core.ReservedPayment), args.Error(1) } func (t *MockWriter) GetOnDemandPayments(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*core.OnDemandPayment, error) { diff --git a/disperser/apiserver/server_test.go b/disperser/apiserver/server_test.go index 29f4f74a94..1bf6629d34 100644 --- a/disperser/apiserver/server_test.go +++ b/disperser/apiserver/server_test.go @@ -761,7 +761,7 @@ func newTestServer(transactor core.Writer, testName string) *apiserver.Dispersal mockState.On("GetOnDemandPaymentByAccount", tmock.Anything, tmock.Anything).Return(&core.OnDemandPayment{ CumulativePayment: big.NewInt(3000), }, nil) - mockState.On("GetActiveReservationByAccount", tmock.Anything, tmock.Anything).Return(&core.ActiveReservation{ + mockState.On("GetReservedPaymentByAccount", tmock.Anything, tmock.Anything).Return(&core.ReservedPayment{ SymbolsPerSecond: 2048, StartTimestamp: 0, EndTimestamp: math.MaxUint32, diff --git a/disperser/apiserver/server_v2.go b/disperser/apiserver/server_v2.go index e808a813c9..bd45248f1d 100644 --- a/disperser/apiserver/server_v2.go +++ b/disperser/apiserver/server_v2.go @@ -4,11 +4,12 @@ import ( "context" "errors" "fmt" - "github.com/prometheus/client_golang/prometheus" "net" "sync/atomic" "time" + "github.com/prometheus/client_golang/prometheus" + "github.com/Layr-Labs/eigenda/api" pbcommon "github.com/Layr-Labs/eigenda/api/grpc/common" pbv1 "github.com/Layr-Labs/eigenda/api/grpc/disperser" @@ -260,7 +261,7 @@ func (s *DispersalServerV2) GetPaymentState(ctx context.Context, req *pb.GetPaym return nil, api.NewErrorNotFound("failed to get largest cumulative payment") } // on-Chain account state - reservation, err := s.meterer.ChainPaymentState.GetActiveReservationByAccount(ctx, accountID) + reservation, err := s.meterer.ChainPaymentState.GetReservedPaymentByAccount(ctx, accountID) if err != nil { return nil, api.NewErrorNotFound("failed to get active reservation") } diff --git a/disperser/apiserver/server_v2_test.go b/disperser/apiserver/server_v2_test.go index 0bd8b5997e..41409bf552 100644 --- a/disperser/apiserver/server_v2_test.go +++ b/disperser/apiserver/server_v2_test.go @@ -447,7 +447,7 @@ func newTestServerV2(t *testing.T) *testComponents { mockState.On("GetMinNumSymbols", tmock.Anything).Return(uint32(3), nil) now := uint64(time.Now().Unix()) - mockState.On("GetActiveReservationByAccount", tmock.Anything, tmock.Anything).Return(&core.ActiveReservation{SymbolsPerSecond: 100, StartTimestamp: now + 1200, EndTimestamp: now + 1800, QuorumSplits: []byte{50, 50}, QuorumNumbers: []uint8{0, 1}}, nil) + mockState.On("GetReservedPaymentByAccount", tmock.Anything, tmock.Anything).Return(&core.ReservedPayment{SymbolsPerSecond: 100, StartTimestamp: now + 1200, EndTimestamp: now + 1800, QuorumSplits: []byte{50, 50}, QuorumNumbers: []uint8{0, 1}}, nil) mockState.On("GetOnDemandPaymentByAccount", tmock.Anything, tmock.Anything).Return(&core.OnDemandPayment{CumulativePayment: big.NewInt(3864)}, nil) mockState.On("GetOnDemandQuorumNumbers", tmock.Anything).Return([]uint8{0, 1}, nil) diff --git a/test/integration_test.go b/test/integration_test.go index 5016f3598f..cb83bdf621 100644 --- a/test/integration_test.go +++ b/test/integration_test.go @@ -221,10 +221,10 @@ func mustMakeDisperser(t *testing.T, cst core.IndexedChainState, store disperser mockState := &coremock.MockOnchainPaymentState{} reservationLimit := uint64(1024) paymentLimit := big.NewInt(512) - mockState.On("GetActiveReservationByAccount", mock.Anything, mock.MatchedBy(func(account gethcommon.Address) bool { + mockState.On("GetReservedPaymentByAccount", mock.Anything, mock.MatchedBy(func(account gethcommon.Address) bool { return account == publicKey - })).Return(&core.ActiveReservation{SymbolsPerSecond: reservationLimit, StartTimestamp: 0, EndTimestamp: math.MaxUint32, QuorumSplits: []byte{50, 50}, QuorumNumbers: []uint8{0, 1}}, nil) - mockState.On("GetActiveReservationByAccount", mock.Anything, mock.Anything).Return(&core.ActiveReservation{}, errors.New("reservation not found")) + })).Return(&core.ReservedPayment{SymbolsPerSecond: reservationLimit, StartTimestamp: 0, EndTimestamp: math.MaxUint32, QuorumSplits: []byte{50, 50}, QuorumNumbers: []uint8{0, 1}}, nil) + mockState.On("GetReservedPaymentByAccount", mock.Anything, mock.Anything).Return(&core.ReservedPayment{}, errors.New("reservation not found")) mockState.On("GetOnDemandPaymentByAccount", mock.Anything, mock.MatchedBy(func(account gethcommon.Address) bool { return account == publicKey