diff --git a/internal/db/db.go b/internal/db/db.go index d2dccf0..20ff5f2 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -86,6 +86,7 @@ type DomainsDB interface { type DomainVerificationDB interface { CreateDomainVerification(ctx context.Context, domainVerification *models.DomainVerification) error + ScheduleDomainVerificationAt(ctx context.Context, id string, time time.Time) error GetDomainVerificationByName(ctx context.Context, domain string, appID string) (*models.DomainVerification, error) DeleteDomainVerification(ctx context.Context, id string, now time.Time) error ListDomainVerifications(ctx context.Context, appID string) ([]*models.DomainVerification, error) diff --git a/internal/db/postgres/domain_verification.go b/internal/db/postgres/domain_verification.go index bcec5c2..62ca0e6 100644 --- a/internal/db/postgres/domain_verification.go +++ b/internal/db/postgres/domain_verification.go @@ -126,3 +126,12 @@ func (q query[T]) ListLeastRecentlyCheckedDomain(ctx context.Context, time time. } return domainVerifications, err } + +func (q query[T]) ScheduleDomainVerificationAt(ctx context.Context, id string, time time.Time) error { + _, err := q.ext.ExecContext(ctx, ` + UPDATE domain_verification SET + will_check_at = $1 + WHERE id = $2 + `, time, id) + return err +} diff --git a/internal/db/sqlite/domain_verification.go b/internal/db/sqlite/domain_verification.go index 535e555..a4ef77a 100644 --- a/internal/db/sqlite/domain_verification.go +++ b/internal/db/sqlite/domain_verification.go @@ -126,3 +126,12 @@ func (q query[T]) ListLeastRecentlyCheckedDomain(ctx context.Context, time time. } return domainVerifications, err } + +func (q query[T]) ScheduleDomainVerificationAt(ctx context.Context, id string, time time.Time) error { + _, err := q.ext.ExecContext(ctx, ` + UPDATE domain_verification SET + will_check_at = ? + WHERE id = ? + `, time, id) + return err +} diff --git a/internal/handler/controller/domain.go b/internal/handler/controller/domain.go index 5a8584c..b706799 100644 --- a/internal/handler/controller/domain.go +++ b/internal/handler/controller/domain.go @@ -65,8 +65,8 @@ func (c *Controller) handleDomainVerification(w http.ResponseWriter, r *http.Req } respond(w, withTx(r.Context(), c.DB, func(tx db.Tx) (any, error) { var domainVerification *models.DomainVerification - domain, _ := tx.GetDomainByName(r.Context(), domainName) domainVerification, _ = tx.GetDomainVerificationByName(r.Context(), domainName, app.ID) + if domainVerification == nil { domainVerification = models.NewDomainVerification(c.Clock.Now().UTC(), domainName, app.ID) err := tx.CreateDomainVerification(r.Context(), domainVerification) if err != nil { @@ -75,7 +75,16 @@ func (c *Controller) handleDomainVerification(w http.ResponseWriter, r *http.Req log(r).Info("creating domain verification", zap.String("domain", domainName), zap.String("site", config.Site)) + } else if domainVerification.WillCheckAt == nil { + err := tx.ScheduleDomainVerificationAt(r.Context(), domainVerification.ID, c.Clock.Now().UTC()) + if err != nil { + return nil, err + } + log(r).Info("triggering domain verification", + zap.String("domain", domainName), + zap.String("site", config.Site)) } + domainVerification, _ = tx.GetDomainVerificationByName(r.Context(), domainName, app.ID) return c.makeAPIDomain(domain, domainVerification), nil })) } @@ -99,7 +108,7 @@ func (c *Controller) handleDomainCreate(w http.ResponseWriter, r *http.Request) return nil, err } else { if domain.AppID == app.ID { - return c.makeAPIDomain(domain, nil), nil + return c.makeAPIDomain(domain, domainVerification), nil } else if replaceApp != domain.AppID { return nil, models.ErrDomainUsedName } @@ -120,7 +129,7 @@ func (c *Controller) handleDomainCreate(w http.ResponseWriter, r *http.Request) zap.String("domain", domain.Domain), zap.String("site", domain.SiteName)) - return c.makeAPIDomain(domain, nil), nil + return c.makeAPIDomain(domain, domainVerification), nil })) } diff --git a/internal/handler/controller/domain_test.go b/internal/handler/controller/domain_test.go index 1be6b14..6c573ac 100644 --- a/internal/handler/controller/domain_test.go +++ b/internal/handler/controller/domain_test.go @@ -353,14 +353,50 @@ func TestDomainVerification(t *testing.T) { }) }) }) + t.Run("Should not add a pending domain for existing domain", func(t *testing.T) { + testutil.WithTestController(func(c *testutil.TestController) { + token := setupDomainVerification(c) + db.WithTx(c.Context, c.DB, func(tx db.Tx) error { + return tx.CreateDomain(c.Context, models.NewDomain(time.Now(), "test.com", "test", "main")) + }) + req := httptest.NewRequest("POST", "http://localtest.me/api/v1/apps/test/domains/test.com", nil) + req.Header.Add("Authorization", "bearer "+token) + w := httptest.NewRecorder() + c.ServeHTTP(w, req) + domain, err := testutil.DecodeJSONResponse[*api.APIDomain](w.Result()) + if assert.NoError(t, err) { + assert.Nil(t, domain.DomainVerification) + assert.NotNil(t, domain.Domain) + assert.Equal(t, "test.com", domain.Domain.Domain) + } + }) + }) + t.Run("Should retrigger domain verification", func(t *testing.T) { + testutil.WithTestController(func(c *testutil.TestController) { + token := setupDomainVerification(c) req := httptest.NewRequest("POST", "http://localtest.me/api/v1/apps/test/domains/test.com", nil) req.Header.Add("Authorization", "bearer "+token) w := httptest.NewRecorder() c.ServeHTTP(w, req) domain, err := testutil.DecodeJSONResponse[*api.APIDomain](w.Result()) + now := time.Now() + if assert.NoError(t, err) { + assert.NotNil(t, domain.DomainVerification) + assert.NotNil(t, domain.DomainVerification.WillCheckAt) + assert.True(t, domain.DomainVerification.WillCheckAt.Before(now)) + } + db.WithTx(c.Context, c.DB, func(tx db.Tx) error { + return tx.SetDomainIsInvalid(c.Context, domain.DomainVerification.ID, now) + }) + req = httptest.NewRequest("POST", "http://localtest.me/api/v1/apps/test/domains/test.com", nil) + req.Header.Add("Authorization", "bearer "+token) + w = httptest.NewRecorder() + c.ServeHTTP(w, req) + domain, err = testutil.DecodeJSONResponse[*api.APIDomain](w.Result()) if assert.NoError(t, err) { - assert.Nil(t, domain.Domain) - assert.Equal(t, "test.com", domain.DomainVerification.Domain) + assert.NotNil(t, domain.DomainVerification) + assert.NotNil(t, domain.DomainVerification.WillCheckAt) + assert.True(t, domain.DomainVerification.WillCheckAt.After(now)) } }) })