Skip to content

Commit

Permalink
feat: global reservation period
Browse files Browse the repository at this point in the history
  • Loading branch information
hopeyen committed Dec 12, 2024
1 parent be47a6c commit df4c330
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 25 deletions.
4 changes: 2 additions & 2 deletions core/eth/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ func (t *Reader) GetGlobalSymbolsPerSecond(ctx context.Context) (uint64, error)
return globalSymbolsPerSecond.Uint64(), nil
}

func (t *Reader) GetGlobalRateBinInterval(ctx context.Context) (uint64, error) {
func (t *Reader) GetGlobalRateBinInterval(ctx context.Context) (uint32, error) {
if t.bindings.PaymentVault == nil {
return 0, errors.New("payment vault not deployed")
}
Expand All @@ -796,7 +796,7 @@ func (t *Reader) GetGlobalRateBinInterval(ctx context.Context) (uint64, error) {
if err != nil {
return 0, err
}
return globalRateBinInterval.Uint64(), nil
return uint32(globalRateBinInterval.Uint64()), nil
}

func (t *Reader) GetMinNumSymbols(ctx context.Context) (uint32, error) {
Expand Down
18 changes: 3 additions & 15 deletions core/meterer/meterer.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,23 +259,11 @@ func (m *Meterer) SymbolsCharged(numSymbols uint) uint32 {
return uint32(core.RoundUpDivide(uint(numSymbols), uint(m.ChainPaymentState.GetMinNumSymbols()))) * m.ChainPaymentState.GetMinNumSymbols()
}

// ValidateReservationPeriod checks if the provided bin index is valid
func (m *Meterer) ValidateGlobalReservationPeriod(header core.PaymentMetadata) (uint32, error) {
// Deterministic function: local clock -> index (1second intervals)
currentReservationPeriod := uint32(time.Now().Unix())

// Valid bin indexes are either the current bin or the previous bin (allow this second or prev sec)
if header.ReservationPeriod != currentReservationPeriod && header.ReservationPeriod != (currentReservationPeriod-1) {
return 0, fmt.Errorf("invalid bin index for on-demand request")
}
return currentReservationPeriod, nil
}

// IncrementBinUsage increments the bin usage atomically and checks for overflow
func (m *Meterer) IncrementGlobalBinUsage(ctx context.Context, symbolsCharged uint64) error {
//TODO: edit globalIndex based on bin interval in a subsequent PR
globalIndex := uint64(time.Now().Unix())
newUsage, err := m.OffchainStore.UpdateGlobalBin(ctx, globalIndex, symbolsCharged)
globalPeriod := GetReservationPeriod(uint64(time.Now().Unix()), m.ChainPaymentState.GetGlobalRateBinInterval())

newUsage, err := m.OffchainStore.UpdateGlobalBin(ctx, globalPeriod, symbolsCharged)
if err != nil {
return fmt.Errorf("failed to increment global bin usage: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion core/meterer/meterer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func TestMetererReservations(t *testing.T) {
ctx := context.Background()
paymentChainState.On("GetReservationWindow", testifymock.Anything).Return(uint32(1), nil)
paymentChainState.On("GetGlobalSymbolsPerSecond", testifymock.Anything).Return(uint64(1009), nil)
paymentChainState.On("GetGlobalRateBinInterval", testifymock.Anything).Return(uint64(1), nil)
paymentChainState.On("GetGlobalRateBinInterval", testifymock.Anything).Return(uint32(1), nil)
paymentChainState.On("GetMinNumSymbols", testifymock.Anything).Return(uint32(3), nil)

reservationPeriod := meterer.GetReservationPeriod(uint64(time.Now().Unix()), mt.ChainPaymentState.GetReservationWindow())
Expand Down
4 changes: 2 additions & 2 deletions core/meterer/offchain_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ func (s *OffchainStore) UpdateReservationBin(ctx context.Context, accountID stri
return binUsageValue, nil
}

func (s *OffchainStore) UpdateGlobalBin(ctx context.Context, reservationPeriod uint64, size uint64) (uint64, error) {
func (s *OffchainStore) UpdateGlobalBin(ctx context.Context, reservationPeriod uint32, size uint64) (uint64, error) {
key := map[string]types.AttributeValue{
"ReservationPeriod": &types.AttributeValueMemberN{Value: strconv.FormatUint(reservationPeriod, 10)},
"ReservationPeriod": &types.AttributeValueMemberN{Value: strconv.FormatUint(uint64(reservationPeriod), 10)},
}

res, err := s.dynamoClient.IncrementBy(ctx, s.globalBinTableName, key, "BinUsage", size)
Expand Down
6 changes: 3 additions & 3 deletions core/meterer/onchain_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type OnchainPayment interface {
GetOnDemandPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.OnDemandPayment, error)
GetOnDemandQuorumNumbers(ctx context.Context) ([]uint8, error)
GetGlobalSymbolsPerSecond() uint64
GetGlobalRateBinInterval() uint64
GetGlobalRateBinInterval() uint32
GetMinNumSymbols() uint32
GetPricePerSymbol() uint32
GetReservationWindow() uint32
Expand All @@ -42,7 +42,7 @@ type OnchainPaymentState struct {

type PaymentVaultParams struct {
GlobalSymbolsPerSecond uint64
GlobalRateBinInterval uint64
GlobalRateBinInterval uint32
MinNumSymbols uint32
PricePerSymbol uint32
ReservationWindow uint32
Expand Down Expand Up @@ -211,7 +211,7 @@ func (pcs *OnchainPaymentState) GetGlobalSymbolsPerSecond() uint64 {
return pcs.PaymentVaultParams.Load().GlobalSymbolsPerSecond
}

func (pcs *OnchainPaymentState) GetGlobalRateBinInterval() uint64 {
func (pcs *OnchainPaymentState) GetGlobalRateBinInterval() uint32 {
return pcs.PaymentVaultParams.Load().GlobalRateBinInterval
}

Expand Down
4 changes: 2 additions & 2 deletions core/mock/payment_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ func (m *MockOnchainPaymentState) GetGlobalSymbolsPerSecond() uint64 {
return args.Get(0).(uint64)
}

func (m *MockOnchainPaymentState) GetGlobalRateBinInterval() uint64 {
func (m *MockOnchainPaymentState) GetGlobalRateBinInterval() uint32 {
args := m.Called()
return args.Get(0).(uint64)
return args.Get(0).(uint32)
}

func (m *MockOnchainPaymentState) GetMinNumSymbols() uint32 {
Expand Down
1 change: 1 addition & 0 deletions disperser/apiserver/server_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ func newTestServerV2(t *testing.T) *testComponents {
mockState.On("GetReservationWindow", tmock.Anything).Return(uint32(1), nil)
mockState.On("GetPricePerSymbol", tmock.Anything).Return(uint32(2), nil)
mockState.On("GetGlobalSymbolsPerSecond", tmock.Anything).Return(uint64(1009), nil)
mockState.On("GetGlobalRateBinInterval", tmock.Anything).Return(uint32(1), nil)
mockState.On("GetMinNumSymbols", tmock.Anything).Return(uint32(3), nil)

now := uint64(time.Now().Unix())
Expand Down

0 comments on commit df4c330

Please sign in to comment.