Skip to content

Commit

Permalink
[stripe] reconcile missing invoices (#18810)
Browse files Browse the repository at this point in the history
  • Loading branch information
svenefftinge authored Sep 26, 2023
1 parent ba516e7 commit c6e90d4
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 18 deletions.
14 changes: 10 additions & 4 deletions components/gitpod-db/go/cost_center.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,17 +209,23 @@ func (c *CostCenterManager) IncrementBillingCycle(ctx context.Context, attributi
log.Infof("Cost center %s is not yet expired. Skipping increment.", attributionId)
return cc, nil
}
billingCycleStart := NewVarCharTime(now)
billingCycleStart := now
if cc.NextBillingTime.IsSet() {
billingCycleStart = cc.NextBillingTime
billingCycleStart = cc.NextBillingTime.Time()
}
nextBillingTime := billingCycleStart.AddDate(0, 1, 0)
for nextBillingTime.Before(now) {
log.Warnf("Billing cycle for %s is lagging behind. Incrementing by one month.", attributionId)
billingCycleStart = billingCycleStart.AddDate(0, 1, 0)
nextBillingTime = billingCycleStart.AddDate(0, 1, 0)
}
// All fields on the new cost center remain the same, except for BillingCycleStart, NextBillingTime, and CreationTime
newCostCenter := CostCenter{
ID: cc.ID,
SpendingLimit: cc.SpendingLimit,
BillingStrategy: cc.BillingStrategy,
BillingCycleStart: billingCycleStart,
NextBillingTime: NewVarCharTime(billingCycleStart.Time().AddDate(0, 1, 0)),
BillingCycleStart: NewVarCharTime(billingCycleStart),
NextBillingTime: NewVarCharTime(nextBillingTime),
CreationTime: NewVarCharTime(now),
}
err = c.conn.Save(&newCostCenter).Error
Expand Down
19 changes: 19 additions & 0 deletions components/gitpod-db/go/cost_center_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,25 @@ func TestCostCenterManager_UpdateCostCenter(t *testing.T) {
})
require.NoError(t, err)
})

t.Run("increment billing cycle should always increment to now", func(t *testing.T) {
mnr := db.NewCostCenterManager(conn, limits)
teamAttributionID := db.NewTeamAttributionID(uuid.New().String())
cleanUp(t, conn, teamAttributionID)

res, err := mnr.GetOrCreateCostCenter(context.Background(), teamAttributionID)
require.NoError(t, err)

// set res.nextBillingTime to two months ago
res.NextBillingTime = db.NewVarCharTime(time.Now().AddDate(0, -2, 0))
conn.Save(res)

cc, err := mnr.IncrementBillingCycle(context.Background(), teamAttributionID)
require.NoError(t, err)

require.True(t, cc.NextBillingTime.Time().After(time.Now()), "The next billing time should be in the future")
require.True(t, cc.BillingCycleStart.Time().Before(time.Now()), "The next billing time should be in the future")
})
}

func TestSaveCostCenterMovedToStripe(t *testing.T) {
Expand Down
23 changes: 14 additions & 9 deletions components/public-api-server/pkg/webhooks/stripe.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/gitpod-io/gitpod/common-go/log"
"github.com/gitpod-io/gitpod/public-api-server/pkg/billingservice"
"github.com/sirupsen/logrus"
"github.com/stripe/stripe-go/v72/webhook"
)

Expand Down Expand Up @@ -56,57 +57,61 @@ func (h *webhookHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}

// https://stripe.com/docs/webhooks/signatures#verify-official-libraries
event, err := webhook.ConstructEvent(payload, req.Header.Get("Stripe-Signature"), h.stripeWebhookSignature)
event, err := webhook.ConstructEvent(payload, stripeSignature, h.stripeWebhookSignature)
if err != nil {
log.WithError(err).Error("Failed to verify webhook signature.")
w.WriteHeader(http.StatusBadRequest)
return
}

logger := log.WithFields(logrus.Fields{"event": event.ID})

switch event.Type {
case InvoiceFinalizedEventType:
logger.Info("Handling invoice finalization")
invoiceID, ok := event.Data.Object["id"].(string)
if !ok {
log.Error("failed to find invoice id in Stripe event payload")
logger.Error("failed to find invoice id in Stripe event payload")
w.WriteHeader(http.StatusBadRequest)
return
}

err = h.billingService.FinalizeInvoice(req.Context(), invoiceID)
if err != nil {
log.WithError(err).Error("Failed to finalize invoice")
logger.WithError(err).Error("Failed to finalize invoice")
w.WriteHeader(http.StatusInternalServerError)
return
}
case CustomerSubscriptionDeletedEventType:
logger.Info("Handling subscription cancellation")
subscriptionID, ok := event.Data.Object["id"].(string)
if !ok {
log.Error("failed to find subscriptionId id in Stripe event payload")
logger.Error("failed to find subscriptionId id in Stripe event payload")
w.WriteHeader(http.StatusBadRequest)
return
}
err = h.billingService.CancelSubscription(req.Context(), subscriptionID)
if err != nil {
log.WithError(err).Error("Failed to cancel subscription")
logger.WithError(err).Error("Failed to cancel subscription")
w.WriteHeader(http.StatusInternalServerError)
return
}
case ChargeDisputeCreatedEventType:
log.Info("Handling charge dispute")
logger.Info("Handling charge dispute")
disputeID, ok := event.Data.Object["id"].(string)
if !ok {
log.Error("Failed to identify dispute ID from Stripe webhook.")
logger.Error("Failed to identify dispute ID from Stripe webhook.")
w.WriteHeader(http.StatusBadRequest)
return
}

if err := h.billingService.OnChargeDispute(req.Context(), disputeID); err != nil {
log.WithError(err).Errorf("Failed to handle charge dispute event for dispute ID: %s", disputeID)
logger.WithError(err).Errorf("Failed to handle charge dispute event for dispute ID: %s", disputeID)
w.WriteHeader(http.StatusInternalServerError)
return
}
default:
log.Errorf("Unexpected Stripe event type: %s", event.Type)
logger.Errorf("Unexpected Stripe event type: %s", event.Type)
w.WriteHeader(http.StatusBadRequest)
return
}
Expand Down
70 changes: 67 additions & 3 deletions components/usage/pkg/apiv1/billing.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,12 +326,76 @@ func (s *BillingService) ReconcileInvoices(ctx context.Context, in *v1.Reconcile
log.WithError(err).Errorf("Failed to udpate usage in stripe.")
return nil, status.Errorf(codes.Internal, "Failed to update usage in stripe")
}
err = s.ReconcileStripeCustomers(ctx)
if err != nil {
log.WithError(err).Errorf("Failed to reconcile stripe customers.")
}

return &v1.ReconcileInvoicesResponse{}, nil
}

func (s *BillingService) ReconcileStripeCustomers(ctx context.Context) error {
log.Info("Reconciling stripe customers")
var costCenters []db.CostCenter
result := s.conn.Raw("SELECT * from d_b_cost_center where creationTime in (SELECT max(creationTime) from d_b_cost_center group by id) and nextBillingTime < creationTime and billingStrategy='stripe'").Scan(&costCenters)
if result.Error != nil {
return result.Error
}

log.Infof("Found %d cost centers to reconcile", len(costCenters))

for _, costCenter := range costCenters {
log.Infof("Reconciling stripe invoices for cost center %s", costCenter.ID)
err := s.reconcileStripeInvoices(ctx, costCenter.ID)
if err != nil {
return err
}
_, err = s.ccManager.IncrementBillingCycle(ctx, costCenter.ID)
if err != nil {
// we are just logging at this point, so that we don't see the event again as the usage has been recorded.
log.WithError(err).Errorf("Failed to increment billing cycle.")
}
}
return nil
}

func (s *BillingService) reconcileStripeInvoices(ctx context.Context, id db.AttributionID) error {
cust, err := s.stripeClient.GetCustomerByAttributionID(ctx, string(id))
if err != nil {
return err
}
invoices, err := s.stripeClient.ListInvoices(ctx, cust.ID)
if err != nil {
return err
}
for _, invoice := range invoices {
if invoice.Status == "paid" {
usage, err := InternalComputeInvoiceUsage(ctx, invoice, invoice.Customer)
if err != nil {
return err
}
// check if a usage entry exists for this invoice
var existingUsage db.Usage
result := s.conn.First(existingUsage, "description = ?", usage.Description)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.Infof("No usage entry found for invoice %s. Inserting one now.", invoice.ID)
err = db.InsertUsage(ctx, s.conn, usage)
if err != nil {
return err
}
} else {
return result.Error
}
}
}
}
return nil
}

func (s *BillingService) FinalizeInvoice(ctx context.Context, in *v1.FinalizeInvoiceRequest) (*v1.FinalizeInvoiceResponse, error) {
logger := log.WithField("invoice_id", in.GetInvoiceId())
logger.Info("Invoice finalized. Recording usage.")

if in.GetInvoiceId() == "" {
return nil, status.Errorf(codes.InvalidArgument, "Missing InvoiceID")
Expand All @@ -342,7 +406,7 @@ func (s *BillingService) FinalizeInvoice(ctx context.Context, in *v1.FinalizeInv
logger.WithError(err).Error("Failed to retrieve invoice from Stripe.")
return nil, status.Errorf(codes.NotFound, "Failed to get invoice with ID %s: %s", in.GetInvoiceId(), err.Error())
}
usage, err := InternalComputeInvoiceUsage(ctx, invoice)
usage, err := InternalComputeInvoiceUsage(ctx, invoice, invoice.Customer)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -376,9 +440,9 @@ func (s *BillingService) FinalizeInvoice(ctx context.Context, in *v1.FinalizeInv
return &v1.FinalizeInvoiceResponse{}, nil
}

func InternalComputeInvoiceUsage(ctx context.Context, invoice *stripe_api.Invoice) (db.Usage, error) {
func InternalComputeInvoiceUsage(ctx context.Context, invoice *stripe_api.Invoice, customer *stripe_api.Customer) (db.Usage, error) {
logger := log.WithField("invoice_id", invoice.ID)
attributionID, err := stripe.GetAttributionID(ctx, invoice.Customer)
attributionID, err := stripe.GetAttributionID(ctx, customer)
if err != nil {
return db.Usage{}, err
}
Expand Down
4 changes: 2 additions & 2 deletions components/usage/pkg/apiv1/billing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func NewStripeRecorder(t *testing.T, name string) *recorder.Recorder {
require.NoError(t, err)

t.Cleanup(func() {
r.Stop()
err = r.Stop()
})

// Add a hook which removes Authorization headers from all requests
Expand Down Expand Up @@ -191,7 +191,7 @@ func TestBalancesForStripeCostCenters(t *testing.T) {
func TestFinalizeInvoiceForIndividual(t *testing.T) {
invoice := stripe_api.Invoice{}
require.NoError(t, json.Unmarshal([]byte(IndiInvoiceTestData), &invoice))
usage, err := InternalComputeInvoiceUsage(context.Background(), &invoice)
usage, err := InternalComputeInvoiceUsage(context.Background(), &invoice, invoice.Customer)
require.NoError(t, err)
require.Equal(t, usage.CreditCents, db.CreditCents(-103100))
}
Expand Down
20 changes: 20 additions & 0 deletions components/usage/pkg/stripe/stripe.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,26 @@ func (c *Client) GetInvoiceWithCustomer(ctx context.Context, invoiceID string) (
return invoice, nil
}

func (c *Client) ListInvoices(ctx context.Context, customerId string) (invoices []*stripe.Invoice, err error) {
if customerId == "" {
return nil, fmt.Errorf("no customer ID specified")
}

now := time.Now()
reportStripeRequestStarted("invoice_list")
defer func() {
reportStripeRequestCompleted("invoice_list", err, time.Since(now))
}()

invoicesResponse := c.sc.Invoices.List(&stripe.InvoiceListParams{
Customer: stripe.String(customerId),
})
if invoicesResponse.Err() != nil {
return nil, fmt.Errorf("failed to get invoices for customer %s: %w", customerId, invoicesResponse.Err())
}
return invoicesResponse.InvoiceList().Data, nil
}

func (c *Client) GetSubscriptionWithCustomer(ctx context.Context, subscriptionID string) (subscription *stripe.Subscription, err error) {
if subscriptionID == "" {
return nil, fmt.Errorf("no subscriptionID specified")
Expand Down

0 comments on commit c6e90d4

Please sign in to comment.