From df4c330dca83e372c95efaaf9ad2d91cb82c2457 Mon Sep 17 00:00:00 2001 From: hopeyen Date: Thu, 12 Dec 2024 09:34:33 -0800 Subject: [PATCH] feat: global reservation period --- core/eth/reader.go | 4 ++-- core/meterer/meterer.go | 18 +++--------------- core/meterer/meterer_test.go | 2 +- core/meterer/offchain_store.go | 4 ++-- core/meterer/onchain_state.go | 6 +++--- core/mock/payment_state.go | 4 ++-- disperser/apiserver/server_v2_test.go | 1 + 7 files changed, 14 insertions(+), 25 deletions(-) diff --git a/core/eth/reader.go b/core/eth/reader.go index af13ae441..ac38950fd 100644 --- a/core/eth/reader.go +++ b/core/eth/reader.go @@ -786,7 +786,7 @@ func (t *Reader) GetGlobalSymbolsPerSecond(ctx context.Context) (uint64, error) return globalSymbolsPerSecond.Uint64(), nil } -func (t *Reader) GetGlobalRateBinInterval(ctx context.Context) (uint64, error) { +func (t *Reader) GetGlobalRateBinInterval(ctx context.Context) (uint32, error) { if t.bindings.PaymentVault == nil { return 0, errors.New("payment vault not deployed") } @@ -796,7 +796,7 @@ func (t *Reader) GetGlobalRateBinInterval(ctx context.Context) (uint64, error) { if err != nil { return 0, err } - return globalRateBinInterval.Uint64(), nil + return uint32(globalRateBinInterval.Uint64()), nil } func (t *Reader) GetMinNumSymbols(ctx context.Context) (uint32, error) { diff --git a/core/meterer/meterer.go b/core/meterer/meterer.go index a27e7853e..1f0e1c5ae 100644 --- a/core/meterer/meterer.go +++ b/core/meterer/meterer.go @@ -259,23 +259,11 @@ func (m *Meterer) SymbolsCharged(numSymbols uint) uint32 { return uint32(core.RoundUpDivide(uint(numSymbols), uint(m.ChainPaymentState.GetMinNumSymbols()))) * m.ChainPaymentState.GetMinNumSymbols() } -// ValidateReservationPeriod checks if the provided bin index is valid -func (m *Meterer) ValidateGlobalReservationPeriod(header core.PaymentMetadata) (uint32, error) { - // Deterministic function: local clock -> index (1second intervals) - currentReservationPeriod := uint32(time.Now().Unix()) - - // Valid bin indexes are either the current bin or the previous bin (allow this second or prev sec) - if header.ReservationPeriod != currentReservationPeriod && header.ReservationPeriod != (currentReservationPeriod-1) { - return 0, fmt.Errorf("invalid bin index for on-demand request") - } - return currentReservationPeriod, nil -} - // IncrementBinUsage increments the bin usage atomically and checks for overflow func (m *Meterer) IncrementGlobalBinUsage(ctx context.Context, symbolsCharged uint64) error { - //TODO: edit globalIndex based on bin interval in a subsequent PR - globalIndex := uint64(time.Now().Unix()) - newUsage, err := m.OffchainStore.UpdateGlobalBin(ctx, globalIndex, symbolsCharged) + globalPeriod := GetReservationPeriod(uint64(time.Now().Unix()), m.ChainPaymentState.GetGlobalRateBinInterval()) + + newUsage, err := m.OffchainStore.UpdateGlobalBin(ctx, globalPeriod, symbolsCharged) if err != nil { return fmt.Errorf("failed to increment global bin usage: %w", err) } diff --git a/core/meterer/meterer_test.go b/core/meterer/meterer_test.go index 38132596b..30a03c973 100644 --- a/core/meterer/meterer_test.go +++ b/core/meterer/meterer_test.go @@ -171,7 +171,7 @@ func TestMetererReservations(t *testing.T) { ctx := context.Background() paymentChainState.On("GetReservationWindow", testifymock.Anything).Return(uint32(1), nil) paymentChainState.On("GetGlobalSymbolsPerSecond", testifymock.Anything).Return(uint64(1009), nil) - paymentChainState.On("GetGlobalRateBinInterval", testifymock.Anything).Return(uint64(1), nil) + paymentChainState.On("GetGlobalRateBinInterval", testifymock.Anything).Return(uint32(1), nil) paymentChainState.On("GetMinNumSymbols", testifymock.Anything).Return(uint32(3), nil) reservationPeriod := meterer.GetReservationPeriod(uint64(time.Now().Unix()), mt.ChainPaymentState.GetReservationWindow()) diff --git a/core/meterer/offchain_store.go b/core/meterer/offchain_store.go index 3c3116f1b..6b213a495 100644 --- a/core/meterer/offchain_store.go +++ b/core/meterer/offchain_store.go @@ -93,9 +93,9 @@ func (s *OffchainStore) UpdateReservationBin(ctx context.Context, accountID stri return binUsageValue, nil } -func (s *OffchainStore) UpdateGlobalBin(ctx context.Context, reservationPeriod uint64, size uint64) (uint64, error) { +func (s *OffchainStore) UpdateGlobalBin(ctx context.Context, reservationPeriod uint32, size uint64) (uint64, error) { key := map[string]types.AttributeValue{ - "ReservationPeriod": &types.AttributeValueMemberN{Value: strconv.FormatUint(reservationPeriod, 10)}, + "ReservationPeriod": &types.AttributeValueMemberN{Value: strconv.FormatUint(uint64(reservationPeriod), 10)}, } res, err := s.dynamoClient.IncrementBy(ctx, s.globalBinTableName, key, "BinUsage", size) diff --git a/core/meterer/onchain_state.go b/core/meterer/onchain_state.go index cdfaef457..3a9ba34f3 100644 --- a/core/meterer/onchain_state.go +++ b/core/meterer/onchain_state.go @@ -20,7 +20,7 @@ type OnchainPayment interface { GetOnDemandPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.OnDemandPayment, error) GetOnDemandQuorumNumbers(ctx context.Context) ([]uint8, error) GetGlobalSymbolsPerSecond() uint64 - GetGlobalRateBinInterval() uint64 + GetGlobalRateBinInterval() uint32 GetMinNumSymbols() uint32 GetPricePerSymbol() uint32 GetReservationWindow() uint32 @@ -42,7 +42,7 @@ type OnchainPaymentState struct { type PaymentVaultParams struct { GlobalSymbolsPerSecond uint64 - GlobalRateBinInterval uint64 + GlobalRateBinInterval uint32 MinNumSymbols uint32 PricePerSymbol uint32 ReservationWindow uint32 @@ -211,7 +211,7 @@ func (pcs *OnchainPaymentState) GetGlobalSymbolsPerSecond() uint64 { return pcs.PaymentVaultParams.Load().GlobalSymbolsPerSecond } -func (pcs *OnchainPaymentState) GetGlobalRateBinInterval() uint64 { +func (pcs *OnchainPaymentState) GetGlobalRateBinInterval() uint32 { return pcs.PaymentVaultParams.Load().GlobalRateBinInterval } diff --git a/core/mock/payment_state.go b/core/mock/payment_state.go index e4c89784d..00c34b326 100644 --- a/core/mock/payment_state.go +++ b/core/mock/payment_state.go @@ -62,9 +62,9 @@ func (m *MockOnchainPaymentState) GetGlobalSymbolsPerSecond() uint64 { return args.Get(0).(uint64) } -func (m *MockOnchainPaymentState) GetGlobalRateBinInterval() uint64 { +func (m *MockOnchainPaymentState) GetGlobalRateBinInterval() uint32 { args := m.Called() - return args.Get(0).(uint64) + return args.Get(0).(uint32) } func (m *MockOnchainPaymentState) GetMinNumSymbols() uint32 { diff --git a/disperser/apiserver/server_v2_test.go b/disperser/apiserver/server_v2_test.go index 4fa233d3d..0bd8b5997 100644 --- a/disperser/apiserver/server_v2_test.go +++ b/disperser/apiserver/server_v2_test.go @@ -443,6 +443,7 @@ func newTestServerV2(t *testing.T) *testComponents { mockState.On("GetReservationWindow", tmock.Anything).Return(uint32(1), nil) mockState.On("GetPricePerSymbol", tmock.Anything).Return(uint32(2), nil) mockState.On("GetGlobalSymbolsPerSecond", tmock.Anything).Return(uint64(1009), nil) + mockState.On("GetGlobalRateBinInterval", tmock.Anything).Return(uint32(1), nil) mockState.On("GetMinNumSymbols", tmock.Anything).Return(uint32(3), nil) now := uint64(time.Now().Unix())