From f8807ea97d63317c278bfeb3ededb4fa028f2742 Mon Sep 17 00:00:00 2001 From: svenefftinge Date: Tue, 26 Sep 2023 14:22:35 +0000 Subject: [PATCH] [stripe] reconcile missing invoices --- components/gitpod-db/go/cost_center.go | 14 ++-- components/gitpod-db/go/cost_center_test.go | 19 +++++ .../public-api-server/pkg/webhooks/stripe.go | 23 +++--- components/usage/pkg/apiv1/billing.go | 70 ++++++++++++++++++- components/usage/pkg/apiv1/billing_test.go | 4 +- components/usage/pkg/stripe/stripe.go | 20 ++++++ 6 files changed, 132 insertions(+), 18 deletions(-) diff --git a/components/gitpod-db/go/cost_center.go b/components/gitpod-db/go/cost_center.go index 011d184fefbe8d..c0802348429efe 100644 --- a/components/gitpod-db/go/cost_center.go +++ b/components/gitpod-db/go/cost_center.go @@ -209,17 +209,23 @@ func (c *CostCenterManager) IncrementBillingCycle(ctx context.Context, attributi log.Infof("Cost center %s is not yet expired. Skipping increment.", attributionId) return cc, nil } - billingCycleStart := NewVarCharTime(now) + billingCycleStart := now if cc.NextBillingTime.IsSet() { - billingCycleStart = cc.NextBillingTime + billingCycleStart = cc.NextBillingTime.Time() + } + nextBillingTime := billingCycleStart.AddDate(0, 1, 0) + for nextBillingTime.Before(now) { + log.Warnf("Billing cycle for %s is lagging behind. Incrementing by one month.", attributionId) + billingCycleStart = billingCycleStart.AddDate(0, 1, 0) + nextBillingTime = billingCycleStart.AddDate(0, 1, 0) } // All fields on the new cost center remain the same, except for BillingCycleStart, NextBillingTime, and CreationTime newCostCenter := CostCenter{ ID: cc.ID, SpendingLimit: cc.SpendingLimit, BillingStrategy: cc.BillingStrategy, - BillingCycleStart: billingCycleStart, - NextBillingTime: NewVarCharTime(billingCycleStart.Time().AddDate(0, 1, 0)), + BillingCycleStart: NewVarCharTime(billingCycleStart), + NextBillingTime: NewVarCharTime(nextBillingTime), CreationTime: NewVarCharTime(now), } err = c.conn.Save(&newCostCenter).Error diff --git a/components/gitpod-db/go/cost_center_test.go b/components/gitpod-db/go/cost_center_test.go index d3598b18d14d5c..36921df16aab9d 100644 --- a/components/gitpod-db/go/cost_center_test.go +++ b/components/gitpod-db/go/cost_center_test.go @@ -198,6 +198,25 @@ func TestCostCenterManager_UpdateCostCenter(t *testing.T) { }) require.NoError(t, err) }) + + t.Run("increment billing cycle should always increment to now", func(t *testing.T) { + mnr := db.NewCostCenterManager(conn, limits) + teamAttributionID := db.NewTeamAttributionID(uuid.New().String()) + cleanUp(t, conn, teamAttributionID) + + res, err := mnr.GetOrCreateCostCenter(context.Background(), teamAttributionID) + require.NoError(t, err) + + // set res.nextBillingTime to two months ago + res.NextBillingTime = db.NewVarCharTime(time.Now().AddDate(0, -2, 0)) + conn.Save(res) + + cc, err := mnr.IncrementBillingCycle(context.Background(), teamAttributionID) + require.NoError(t, err) + + require.True(t, cc.NextBillingTime.Time().After(time.Now()), "The next billing time should be in the future") + require.True(t, cc.BillingCycleStart.Time().Before(time.Now()), "The next billing time should be in the future") + }) } func TestSaveCostCenterMovedToStripe(t *testing.T) { diff --git a/components/public-api-server/pkg/webhooks/stripe.go b/components/public-api-server/pkg/webhooks/stripe.go index a765ccf4377e74..172ce29770769d 100644 --- a/components/public-api-server/pkg/webhooks/stripe.go +++ b/components/public-api-server/pkg/webhooks/stripe.go @@ -10,6 +10,7 @@ import ( "github.com/gitpod-io/gitpod/common-go/log" "github.com/gitpod-io/gitpod/public-api-server/pkg/billingservice" + "github.com/sirupsen/logrus" "github.com/stripe/stripe-go/v72/webhook" ) @@ -56,57 +57,61 @@ func (h *webhookHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { } // https://stripe.com/docs/webhooks/signatures#verify-official-libraries - event, err := webhook.ConstructEvent(payload, req.Header.Get("Stripe-Signature"), h.stripeWebhookSignature) + event, err := webhook.ConstructEvent(payload, stripeSignature, h.stripeWebhookSignature) if err != nil { log.WithError(err).Error("Failed to verify webhook signature.") w.WriteHeader(http.StatusBadRequest) return } + logger := log.WithFields(logrus.Fields{"event": event.ID}) + switch event.Type { case InvoiceFinalizedEventType: + logger.Info("Handling invoice finalization") invoiceID, ok := event.Data.Object["id"].(string) if !ok { - log.Error("failed to find invoice id in Stripe event payload") + logger.Error("failed to find invoice id in Stripe event payload") w.WriteHeader(http.StatusBadRequest) return } err = h.billingService.FinalizeInvoice(req.Context(), invoiceID) if err != nil { - log.WithError(err).Error("Failed to finalize invoice") + logger.WithError(err).Error("Failed to finalize invoice") w.WriteHeader(http.StatusInternalServerError) return } case CustomerSubscriptionDeletedEventType: + logger.Info("Handling subscription cancellation") subscriptionID, ok := event.Data.Object["id"].(string) if !ok { - log.Error("failed to find subscriptionId id in Stripe event payload") + logger.Error("failed to find subscriptionId id in Stripe event payload") w.WriteHeader(http.StatusBadRequest) return } err = h.billingService.CancelSubscription(req.Context(), subscriptionID) if err != nil { - log.WithError(err).Error("Failed to cancel subscription") + logger.WithError(err).Error("Failed to cancel subscription") w.WriteHeader(http.StatusInternalServerError) return } case ChargeDisputeCreatedEventType: - log.Info("Handling charge dispute") + logger.Info("Handling charge dispute") disputeID, ok := event.Data.Object["id"].(string) if !ok { - log.Error("Failed to identify dispute ID from Stripe webhook.") + logger.Error("Failed to identify dispute ID from Stripe webhook.") w.WriteHeader(http.StatusBadRequest) return } if err := h.billingService.OnChargeDispute(req.Context(), disputeID); err != nil { - log.WithError(err).Errorf("Failed to handle charge dispute event for dispute ID: %s", disputeID) + logger.WithError(err).Errorf("Failed to handle charge dispute event for dispute ID: %s", disputeID) w.WriteHeader(http.StatusInternalServerError) return } default: - log.Errorf("Unexpected Stripe event type: %s", event.Type) + logger.Errorf("Unexpected Stripe event type: %s", event.Type) w.WriteHeader(http.StatusBadRequest) return } diff --git a/components/usage/pkg/apiv1/billing.go b/components/usage/pkg/apiv1/billing.go index 82fb88c1230370..fe046e959c95d8 100644 --- a/components/usage/pkg/apiv1/billing.go +++ b/components/usage/pkg/apiv1/billing.go @@ -326,12 +326,76 @@ func (s *BillingService) ReconcileInvoices(ctx context.Context, in *v1.Reconcile log.WithError(err).Errorf("Failed to udpate usage in stripe.") return nil, status.Errorf(codes.Internal, "Failed to update usage in stripe") } + err = s.ReconcileStripeCustomers(ctx) + if err != nil { + log.WithError(err).Errorf("Failed to reconcile stripe customers.") + } return &v1.ReconcileInvoicesResponse{}, nil } +func (s *BillingService) ReconcileStripeCustomers(ctx context.Context) error { + log.Info("Reconciling stripe customers") + var costCenters []db.CostCenter + result := s.conn.Raw("SELECT * from d_b_cost_center where creationTime in (SELECT max(creationTime) from d_b_cost_center group by id) and nextBillingTime < creationTime and billingStrategy='stripe'").Scan(&costCenters) + if result.Error != nil { + return result.Error + } + + log.Infof("Found %d cost centers to reconcile", len(costCenters)) + + for _, costCenter := range costCenters { + log.Infof("Reconciling stripe invoices for cost center %s", costCenter.ID) + err := s.reconcileStripeInvoices(ctx, costCenter.ID) + if err != nil { + return err + } + _, err = s.ccManager.IncrementBillingCycle(ctx, costCenter.ID) + if err != nil { + // we are just logging at this point, so that we don't see the event again as the usage has been recorded. + log.WithError(err).Errorf("Failed to increment billing cycle.") + } + } + return nil +} + +func (s *BillingService) reconcileStripeInvoices(ctx context.Context, id db.AttributionID) error { + cust, err := s.stripeClient.GetCustomerByAttributionID(ctx, string(id)) + if err != nil { + return err + } + invoices, err := s.stripeClient.ListInvoices(ctx, cust.ID) + if err != nil { + return err + } + for _, invoice := range invoices { + if invoice.Status == "paid" { + usage, err := InternalComputeInvoiceUsage(ctx, invoice, invoice.Customer) + if err != nil { + return err + } + // check if a usage entry exists for this invoice + var existingUsage db.Usage + result := s.conn.First(existingUsage, "description = ?", usage.Description) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + log.Infof("No usage entry found for invoice %s. Inserting one now.", invoice.ID) + err = db.InsertUsage(ctx, s.conn, usage) + if err != nil { + return err + } + } else { + return result.Error + } + } + } + } + return nil +} + func (s *BillingService) FinalizeInvoice(ctx context.Context, in *v1.FinalizeInvoiceRequest) (*v1.FinalizeInvoiceResponse, error) { logger := log.WithField("invoice_id", in.GetInvoiceId()) + logger.Info("Invoice finalized. Recording usage.") if in.GetInvoiceId() == "" { return nil, status.Errorf(codes.InvalidArgument, "Missing InvoiceID") @@ -342,7 +406,7 @@ func (s *BillingService) FinalizeInvoice(ctx context.Context, in *v1.FinalizeInv logger.WithError(err).Error("Failed to retrieve invoice from Stripe.") return nil, status.Errorf(codes.NotFound, "Failed to get invoice with ID %s: %s", in.GetInvoiceId(), err.Error()) } - usage, err := InternalComputeInvoiceUsage(ctx, invoice) + usage, err := InternalComputeInvoiceUsage(ctx, invoice, invoice.Customer) if err != nil { return nil, err } @@ -376,9 +440,9 @@ func (s *BillingService) FinalizeInvoice(ctx context.Context, in *v1.FinalizeInv return &v1.FinalizeInvoiceResponse{}, nil } -func InternalComputeInvoiceUsage(ctx context.Context, invoice *stripe_api.Invoice) (db.Usage, error) { +func InternalComputeInvoiceUsage(ctx context.Context, invoice *stripe_api.Invoice, customer *stripe_api.Customer) (db.Usage, error) { logger := log.WithField("invoice_id", invoice.ID) - attributionID, err := stripe.GetAttributionID(ctx, invoice.Customer) + attributionID, err := stripe.GetAttributionID(ctx, customer) if err != nil { return db.Usage{}, err } diff --git a/components/usage/pkg/apiv1/billing_test.go b/components/usage/pkg/apiv1/billing_test.go index e910efa73de4d5..a750c0f106087f 100644 --- a/components/usage/pkg/apiv1/billing_test.go +++ b/components/usage/pkg/apiv1/billing_test.go @@ -55,7 +55,7 @@ func NewStripeRecorder(t *testing.T, name string) *recorder.Recorder { require.NoError(t, err) t.Cleanup(func() { - r.Stop() + err = r.Stop() }) // Add a hook which removes Authorization headers from all requests @@ -191,7 +191,7 @@ func TestBalancesForStripeCostCenters(t *testing.T) { func TestFinalizeInvoiceForIndividual(t *testing.T) { invoice := stripe_api.Invoice{} require.NoError(t, json.Unmarshal([]byte(IndiInvoiceTestData), &invoice)) - usage, err := InternalComputeInvoiceUsage(context.Background(), &invoice) + usage, err := InternalComputeInvoiceUsage(context.Background(), &invoice, invoice.Customer) require.NoError(t, err) require.Equal(t, usage.CreditCents, db.CreditCents(-103100)) } diff --git a/components/usage/pkg/stripe/stripe.go b/components/usage/pkg/stripe/stripe.go index 5eaada760b9732..ab88319083c17f 100644 --- a/components/usage/pkg/stripe/stripe.go +++ b/components/usage/pkg/stripe/stripe.go @@ -335,6 +335,26 @@ func (c *Client) GetInvoiceWithCustomer(ctx context.Context, invoiceID string) ( return invoice, nil } +func (c *Client) ListInvoices(ctx context.Context, customerId string) (invoices []*stripe.Invoice, err error) { + if customerId == "" { + return nil, fmt.Errorf("no customer ID specified") + } + + now := time.Now() + reportStripeRequestStarted("invoice_list") + defer func() { + reportStripeRequestCompleted("invoice_list", err, time.Since(now)) + }() + + invoicesResponse := c.sc.Invoices.List(&stripe.InvoiceListParams{ + Customer: stripe.String(customerId), + }) + if invoicesResponse.Err() != nil { + return nil, fmt.Errorf("failed to get invoices for customer %s: %w", customerId, invoicesResponse.Err()) + } + return invoicesResponse.InvoiceList().Data, nil +} + func (c *Client) GetSubscriptionWithCustomer(ctx context.Context, subscriptionID string) (subscription *stripe.Subscription, err error) { if subscriptionID == "" { return nil, fmt.Errorf("no subscriptionID specified")