Skip to content

Commit

Permalink
refactor: rename ActiveReservation to ReservedPayment
Browse files Browse the repository at this point in the history
  • Loading branch information
hopeyen committed Dec 12, 2024
1 parent 6588b33 commit b41ff1d
Show file tree
Hide file tree
Showing 16 changed files with 70 additions and 87 deletions.
4 changes: 2 additions & 2 deletions api/clients/accountant.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ var requiredQuorums = []uint8{0, 1}
type Accountant struct {
// on-chain states
accountID string
reservation *core.ActiveReservation
reservation *core.ReservedPayment
onDemand *core.OnDemandPayment
reservationWindow uint32
pricePerSymbol uint32
Expand All @@ -39,7 +39,7 @@ type BinRecord struct {
Usage uint64
}

func NewAccountant(accountID string, reservation *core.ActiveReservation, onDemand *core.OnDemandPayment, reservationWindow uint32, pricePerSymbol uint32, minNumSymbols uint32, numBins uint32) *Accountant {
func NewAccountant(accountID string, reservation *core.ReservedPayment, onDemand *core.OnDemandPayment, reservationWindow uint32, pricePerSymbol uint32, minNumSymbols uint32, numBins uint32) *Accountant {
//TODO: client storage; currently every instance starts fresh but on-chain or a small store makes more sense
// Also client is currently responsible for supplying network params, we need to add RPC in order to be automatic
// There's a subsequent PR that handles populating the accountant with on-chain state from the disperser
Expand Down
18 changes: 9 additions & 9 deletions api/clients/accountant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ const numBins = uint32(3)
const salt = uint32(0)

func TestNewAccountant(t *testing.T) {
reservation := &core.ActiveReservation{
reservation := &core.ReservedPayment{
SymbolsPerSecond: 100,
StartTimestamp: 100,
EndTimestamp: 200,
Expand Down Expand Up @@ -48,7 +48,7 @@ func TestNewAccountant(t *testing.T) {
}

func TestAccountBlob_Reservation(t *testing.T) {
reservation := &core.ActiveReservation{
reservation := &core.ReservedPayment{
SymbolsPerSecond: 200,
StartTimestamp: 100,
EndTimestamp: 200,
Expand Down Expand Up @@ -96,7 +96,7 @@ func TestAccountBlob_Reservation(t *testing.T) {
}

func TestAccountBlob_OnDemand(t *testing.T) {
reservation := &core.ActiveReservation{
reservation := &core.ReservedPayment{
SymbolsPerSecond: 200,
StartTimestamp: 100,
EndTimestamp: 200,
Expand Down Expand Up @@ -130,7 +130,7 @@ func TestAccountBlob_OnDemand(t *testing.T) {
}

func TestAccountBlob_InsufficientOnDemand(t *testing.T) {
reservation := &core.ActiveReservation{}
reservation := &core.ReservedPayment{}
onDemand := &core.OnDemandPayment{
CumulativePayment: big.NewInt(500),
}
Expand All @@ -152,7 +152,7 @@ func TestAccountBlob_InsufficientOnDemand(t *testing.T) {
}

func TestAccountBlobCallSeries(t *testing.T) {
reservation := &core.ActiveReservation{
reservation := &core.ReservedPayment{
SymbolsPerSecond: 200,
StartTimestamp: 100,
EndTimestamp: 200,
Expand Down Expand Up @@ -200,7 +200,7 @@ func TestAccountBlobCallSeries(t *testing.T) {
}

func TestAccountBlob_BinRotation(t *testing.T) {
reservation := &core.ActiveReservation{
reservation := &core.ReservedPayment{
SymbolsPerSecond: 1000,
StartTimestamp: 100,
EndTimestamp: 200,
Expand Down Expand Up @@ -242,7 +242,7 @@ func TestAccountBlob_BinRotation(t *testing.T) {
}

func TestConcurrentBinRotationAndAccountBlob(t *testing.T) {
reservation := &core.ActiveReservation{
reservation := &core.ReservedPayment{
SymbolsPerSecond: 1000,
StartTimestamp: 100,
EndTimestamp: 200,
Expand Down Expand Up @@ -284,7 +284,7 @@ func TestConcurrentBinRotationAndAccountBlob(t *testing.T) {
}

func TestAccountBlob_ReservationWithOneOverflow(t *testing.T) {
reservation := &core.ActiveReservation{
reservation := &core.ReservedPayment{
SymbolsPerSecond: 200,
StartTimestamp: 100,
EndTimestamp: 200,
Expand Down Expand Up @@ -332,7 +332,7 @@ func TestAccountBlob_ReservationWithOneOverflow(t *testing.T) {
}

func TestAccountBlob_ReservationOverflowReset(t *testing.T) {
reservation := &core.ActiveReservation{
reservation := &core.ReservedPayment{
SymbolsPerSecond: 1000,
StartTimestamp: 100,
EndTimestamp: 200,
Expand Down
8 changes: 4 additions & 4 deletions core/chainio.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,11 @@ type Reader interface {
// GetAllVersionedBlobParams returns the blob version parameters for all blob versions at the given block number.
GetAllVersionedBlobParams(ctx context.Context) (map[uint16]*BlobVersionParameters, error)

// GetActiveReservations returns active reservations (end timestamp > current timestamp)
GetActiveReservations(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*ActiveReservation, error)
// GetReservedPayments returns active reservations (end timestamp > current timestamp)
GetReservedPayments(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*ReservedPayment, error)

// GetActiveReservationByAccount returns active reservation by account ID
GetActiveReservationByAccount(ctx context.Context, accountID gethcommon.Address) (*ActiveReservation, error)
// GetReservedPaymentByAccount returns active reservation by account ID
GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*ReservedPayment, error)

// GetOnDemandPayments returns all on-demand payments
GetOnDemandPayments(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*OnDemandPayment, error)
Expand Down
2 changes: 1 addition & 1 deletion core/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ func ConvertToPaymentMetadata(ph *commonpb.PaymentHeader) *PaymentMetadata {

// OperatorInfo contains information about an operator which is stored on the blockchain state,
// corresponding to a particular quorum
type ActiveReservation struct {
type ReservedPayment struct {
SymbolsPerSecond uint64 // reserve number of symbols per second
//TODO: we are not using start and end timestamp, add check or remove
StartTimestamp uint64 // Unix timestamp that's valid for basically eternity
Expand Down
10 changes: 5 additions & 5 deletions core/eth/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -690,11 +690,11 @@ func (t *Reader) GetAllVersionedBlobParams(ctx context.Context) (map[uint16]*cor
return res, nil
}

func (t *Reader) GetActiveReservations(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*core.ActiveReservation, error) {
func (t *Reader) GetReservedPayments(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*core.ReservedPayment, error) {
if t.bindings.PaymentVault == nil {
return nil, errors.New("payment vault not deployed")
}
reservationsMap := make(map[gethcommon.Address]*core.ActiveReservation)
reservationsMap := make(map[gethcommon.Address]*core.ReservedPayment)
reservations, err := t.bindings.PaymentVault.GetReservations(&bind.CallOpts{
Context: ctx,
}, accountIDs)
Expand All @@ -704,7 +704,7 @@ func (t *Reader) GetActiveReservations(ctx context.Context, accountIDs []gethcom

// since reservations are returned in the same order as the accountIDs, we can directly map them
for i, reservation := range reservations {
res, err := ConvertToActiveReservation(reservation)
res, err := ConvertToReservedPayment(reservation)
if err != nil {
t.logger.Warn("failed to get active reservation", "account", accountIDs[i], "err", err)
continue
Expand All @@ -716,7 +716,7 @@ func (t *Reader) GetActiveReservations(ctx context.Context, accountIDs []gethcom
return reservationsMap, nil
}

func (t *Reader) GetActiveReservationByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ActiveReservation, error) {
func (t *Reader) GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ReservedPayment, error) {
if t.bindings.PaymentVault == nil {
return nil, errors.New("payment vault not deployed")
}
Expand All @@ -726,7 +726,7 @@ func (t *Reader) GetActiveReservationByAccount(ctx context.Context, accountID ge
if err != nil {
return nil, err
}
return ConvertToActiveReservation(reservation)
return ConvertToReservedPayment(reservation)
}

func (t *Reader) GetOnDemandPayments(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*core.OnDemandPayment, error) {
Expand Down
6 changes: 3 additions & 3 deletions core/eth/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,14 @@ func isZeroValuedReservation(reservation paymentvault.IPaymentVaultReservation)
len(reservation.QuorumSplits) == 0
}

// ConvertToActiveReservation converts a upstream binding data structure to local definition.
// ConvertToReservedPayment converts a upstream binding data structure to local definition.
// Returns an error if the input reservation is zero-valued.
func ConvertToActiveReservation(reservation paymentvault.IPaymentVaultReservation) (*core.ActiveReservation, error) {
func ConvertToReservedPayment(reservation paymentvault.IPaymentVaultReservation) (*core.ReservedPayment, error) {
if isZeroValuedReservation(reservation) {
return nil, fmt.Errorf("reservation is not a valid active reservation")
}

return &core.ActiveReservation{
return &core.ReservedPayment{
SymbolsPerSecond: reservation.SymbolsPerSecond,
StartTimestamp: reservation.StartTimestamp,
EndTimestamp: reservation.EndTimestamp,
Expand Down
12 changes: 6 additions & 6 deletions core/meterer/meterer.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (m *Meterer) MeterRequest(ctx context.Context, header core.PaymentMetadata,
accountID := gethcommon.HexToAddress(header.AccountID)
// Validate against the payment method
if header.CumulativePayment.Sign() == 0 {
reservation, err := m.ChainPaymentState.GetActiveReservationByAccount(ctx, accountID)
reservation, err := m.ChainPaymentState.GetReservedPaymentByAccount(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get active reservation by account: %w", err)
}
Expand All @@ -97,7 +97,7 @@ func (m *Meterer) MeterRequest(ctx context.Context, header core.PaymentMetadata,
}

// ServeReservationRequest handles the rate limiting logic for incoming requests
func (m *Meterer) ServeReservationRequest(ctx context.Context, header core.PaymentMetadata, reservation *core.ActiveReservation, numSymbols uint, quorumNumbers []uint8) error {
func (m *Meterer) ServeReservationRequest(ctx context.Context, header core.PaymentMetadata, reservation *core.ReservedPayment, numSymbols uint, quorumNumbers []uint8) error {
if err := m.ValidateQuorum(quorumNumbers, reservation.QuorumNumbers); err != nil {
return fmt.Errorf("invalid quorum for reservation: %w", err)
}
Expand All @@ -122,7 +122,7 @@ func (m *Meterer) ValidateQuorum(headerQuorums []uint8, allowedQuorums []uint8)
return fmt.Errorf("no quorum params in blob header")
}

// check that all the quorum ids are in ActiveReservation's
// check that all the quorum ids are in ReservedPayment's
for _, q := range headerQuorums {
if !slices.Contains(allowedQuorums, q) {
// fail the entire request if there's a quorum number mismatch
Expand All @@ -133,7 +133,7 @@ func (m *Meterer) ValidateQuorum(headerQuorums []uint8, allowedQuorums []uint8)
}

// ValidateReservationPeriod checks if the provided bin index is valid
func (m *Meterer) ValidateReservationPeriod(header core.PaymentMetadata, reservation *core.ActiveReservation) bool {
func (m *Meterer) ValidateReservationPeriod(header core.PaymentMetadata, reservation *core.ReservedPayment) bool {
now := uint64(time.Now().Unix())
reservationWindow := m.ChainPaymentState.GetReservationWindow()
currentReservationPeriod := GetReservationPeriod(now, reservationWindow)
Expand All @@ -145,7 +145,7 @@ func (m *Meterer) ValidateReservationPeriod(header core.PaymentMetadata, reserva
}

// IncrementBinUsage increments the bin usage atomically and checks for overflow
func (m *Meterer) IncrementBinUsage(ctx context.Context, header core.PaymentMetadata, reservation *core.ActiveReservation, numSymbols uint) error {
func (m *Meterer) IncrementBinUsage(ctx context.Context, header core.PaymentMetadata, reservation *core.ReservedPayment, numSymbols uint) error {
symbolsCharged := m.SymbolsCharged(numSymbols)
newUsage, err := m.OffchainStore.UpdateReservationBin(ctx, header.AccountID, uint64(header.ReservationPeriod), uint64(symbolsCharged))
if err != nil {
Expand Down Expand Up @@ -274,6 +274,6 @@ func (m *Meterer) IncrementGlobalBinUsage(ctx context.Context, symbolsCharged ui
}

// GetReservationBinLimit returns the bin limit for a given reservation
func (m *Meterer) GetReservationBinLimit(reservation *core.ActiveReservation) uint64 {
func (m *Meterer) GetReservationBinLimit(reservation *core.ReservedPayment) uint64 {
return reservation.SymbolsPerSecond * uint64(m.ChainPaymentState.GetReservationWindow())
}
14 changes: 7 additions & 7 deletions core/meterer/meterer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ var (
dynamoClient commondynamodb.Client
clientConfig commonaws.ClientConfig
accountID1 gethcommon.Address
account1Reservations *core.ActiveReservation
account1Reservations *core.ReservedPayment
account1OnDemandPayments *core.OnDemandPayment
accountID2 gethcommon.Address
account2Reservations *core.ActiveReservation
account2Reservations *core.ReservedPayment
account2OnDemandPayments *core.OnDemandPayment
mt *meterer.Meterer

Expand Down Expand Up @@ -126,8 +126,8 @@ func setup(_ *testing.M) {
now := uint64(time.Now().Unix())
accountID1 = crypto.PubkeyToAddress(privateKey1.PublicKey)
accountID2 = crypto.PubkeyToAddress(privateKey2.PublicKey)
account1Reservations = &core.ActiveReservation{SymbolsPerSecond: 100, StartTimestamp: now + 1200, EndTimestamp: now + 1800, QuorumSplits: []byte{50, 50}, QuorumNumbers: []uint8{0, 1}}
account2Reservations = &core.ActiveReservation{SymbolsPerSecond: 200, StartTimestamp: now - 120, EndTimestamp: now + 180, QuorumSplits: []byte{30, 70}, QuorumNumbers: []uint8{0, 1}}
account1Reservations = &core.ReservedPayment{SymbolsPerSecond: 100, StartTimestamp: now + 1200, EndTimestamp: now + 1800, QuorumSplits: []byte{50, 50}, QuorumNumbers: []uint8{0, 1}}
account2Reservations = &core.ReservedPayment{SymbolsPerSecond: 200, StartTimestamp: now - 120, EndTimestamp: now + 180, QuorumSplits: []byte{30, 70}, QuorumNumbers: []uint8{0, 1}}
account1OnDemandPayments = &core.OnDemandPayment{CumulativePayment: big.NewInt(3864)}
account2OnDemandPayments = &core.OnDemandPayment{CumulativePayment: big.NewInt(2000)}

Expand Down Expand Up @@ -177,13 +177,13 @@ func TestMetererReservations(t *testing.T) {
reservationPeriod := meterer.GetReservationPeriod(uint64(time.Now().Unix()), mt.ChainPaymentState.GetReservationWindow())
quoromNumbers := []uint8{0, 1}

paymentChainState.On("GetActiveReservationByAccount", testifymock.Anything, testifymock.MatchedBy(func(account gethcommon.Address) bool {
paymentChainState.On("GetReservedPaymentByAccount", testifymock.Anything, testifymock.MatchedBy(func(account gethcommon.Address) bool {
return account == accountID1
})).Return(account1Reservations, nil)
paymentChainState.On("GetActiveReservationByAccount", testifymock.Anything, testifymock.MatchedBy(func(account gethcommon.Address) bool {
paymentChainState.On("GetReservedPaymentByAccount", testifymock.Anything, testifymock.MatchedBy(func(account gethcommon.Address) bool {
return account == accountID2
})).Return(account2Reservations, nil)
paymentChainState.On("GetActiveReservationByAccount", testifymock.Anything, testifymock.Anything).Return(&core.ActiveReservation{}, fmt.Errorf("reservation not found"))
paymentChainState.On("GetReservedPaymentByAccount", testifymock.Anything, testifymock.Anything).Return(&core.ReservedPayment{}, fmt.Errorf("reservation not found"))

// test invalid quorom ID
header := createPaymentHeader(1, big.NewInt(0), accountID1)
Expand Down
44 changes: 13 additions & 31 deletions core/meterer/onchain_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package meterer

import (
"context"
"fmt"
"sync"
"sync/atomic"

Expand All @@ -16,7 +15,7 @@ import (
// OnchainPaymentState is an interface for getting information about the current chain state for payments.
type OnchainPayment interface {
RefreshOnchainPaymentState(ctx context.Context, tx *eth.Reader) error
GetActiveReservationByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ActiveReservation, error)
GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ReservedPayment, error)
GetOnDemandPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.OnDemandPayment, error)
GetOnDemandQuorumNumbers(ctx context.Context) ([]uint8, error)
GetGlobalSymbolsPerSecond() uint64
Expand All @@ -31,8 +30,8 @@ var _ OnchainPayment = (*OnchainPaymentState)(nil)
type OnchainPaymentState struct {
tx *eth.Reader

ActiveReservations map[gethcommon.Address]*core.ActiveReservation
OnDemandPayments map[gethcommon.Address]*core.OnDemandPayment
ReservedPayments map[gethcommon.Address]*core.ReservedPayment
OnDemandPayments map[gethcommon.Address]*core.OnDemandPayment

ReservationsLock sync.RWMutex
OnDemandLocks sync.RWMutex
Expand All @@ -57,7 +56,7 @@ func NewOnchainPaymentState(ctx context.Context, tx *eth.Reader) (*OnchainPaymen

state := OnchainPaymentState{
tx: tx,
ActiveReservations: make(map[gethcommon.Address]*core.ActiveReservation),
ReservedPayments: make(map[gethcommon.Address]*core.ReservedPayment),
OnDemandPayments: make(map[gethcommon.Address]*core.OnDemandPayment),
PaymentVaultParams: atomic.Pointer[PaymentVaultParams]{},
}
Expand Down Expand Up @@ -116,16 +115,16 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context,
pcs.PaymentVaultParams.Store(paymentVaultParams)

pcs.ReservationsLock.Lock()
accountIDs := make([]gethcommon.Address, 0, len(pcs.ActiveReservations))
for accountID := range pcs.ActiveReservations {
accountIDs := make([]gethcommon.Address, 0, len(pcs.ReservedPayments))
for accountID := range pcs.ReservedPayments {
accountIDs = append(accountIDs, accountID)
}

activeReservations, err := tx.GetActiveReservations(ctx, accountIDs)
reservedPayments, err := tx.GetReservedPayments(ctx, accountIDs)
if err != nil {
return err
}
pcs.ActiveReservations = activeReservations
pcs.ReservedPayments = reservedPayments
pcs.ReservationsLock.Unlock()

pcs.OnDemandLocks.Lock()
Expand All @@ -144,34 +143,25 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context,
return nil
}

// GetActiveReservationByAccount returns a pointer to the active reservation for the given account ID; no writes will be made to the reservation
func (pcs *OnchainPaymentState) GetActiveReservationByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ActiveReservation, error) {
// GetReservedPaymentByAccount returns a pointer to the active reservation for the given account ID; no writes will be made to the reservation
func (pcs *OnchainPaymentState) GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ReservedPayment, error) {
pcs.ReservationsLock.RLock()
defer pcs.ReservationsLock.RUnlock()
if reservation, ok := (pcs.ActiveReservations)[accountID]; ok {
if reservation, ok := (pcs.ReservedPayments)[accountID]; ok {
return reservation, nil
}

// pulls the chain state
res, err := pcs.tx.GetActiveReservationByAccount(ctx, accountID)
res, err := pcs.tx.GetReservedPaymentByAccount(ctx, accountID)
if err != nil {
return nil, err
}
pcs.ReservationsLock.Lock()
(pcs.ActiveReservations)[accountID] = res
(pcs.ReservedPayments)[accountID] = res
pcs.ReservationsLock.Unlock()
return res, nil
}

// GetActiveReservationByAccountOnChain returns on-chain reservation for the given account ID
func (pcs *OnchainPaymentState) GetActiveReservationByAccountOnChain(ctx context.Context, accountID gethcommon.Address) (*core.ActiveReservation, error) {
res, err := pcs.tx.GetActiveReservationByAccount(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("reservation account not found on-chain: %w", err)
}
return res, nil
}

// GetOnDemandPaymentByAccount returns a pointer to the on-demand payment for the given account ID; no writes will be made to the payment
func (pcs *OnchainPaymentState) GetOnDemandPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.OnDemandPayment, error) {
pcs.OnDemandLocks.RLock()
Expand All @@ -191,14 +181,6 @@ func (pcs *OnchainPaymentState) GetOnDemandPaymentByAccount(ctx context.Context,
return res, nil
}

func (pcs *OnchainPaymentState) GetOnDemandPaymentByAccountOnChain(ctx context.Context, accountID gethcommon.Address) (*core.OnDemandPayment, error) {
res, err := pcs.tx.GetOnDemandPaymentByAccount(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("on-demand not found on-chain: %w", err)
}
return res, nil
}

func (pcs *OnchainPaymentState) GetOnDemandQuorumNumbers(ctx context.Context) ([]uint8, error) {
blockNumber, err := pcs.tx.GetCurrentBlockNumber(ctx)
if err != nil {
Expand Down
Loading

0 comments on commit b41ff1d

Please sign in to comment.