Skip to content

Commit

Permalink
Re-trigger domain verification
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkRunWu committed Dec 14, 2023
1 parent 4e72cba commit fae9f73
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 5 deletions.
1 change: 1 addition & 0 deletions internal/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ type DomainsDB interface {

type DomainVerificationDB interface {
CreateDomainVerification(ctx context.Context, domainVerification *models.DomainVerification) error
ScheduleDomainVerificationAt(ctx context.Context, id string, time time.Time) error
GetDomainVerificationByName(ctx context.Context, domain string, appID string) (*models.DomainVerification, error)
DeleteDomainVerification(ctx context.Context, id string, now time.Time) error
ListDomainVerifications(ctx context.Context, appID string) ([]*models.DomainVerification, error)
Expand Down
9 changes: 9 additions & 0 deletions internal/db/postgres/domain_verification.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,12 @@ func (q query[T]) ListLeastRecentlyCheckedDomain(ctx context.Context, time time.
}
return domainVerifications, err
}

func (q query[T]) ScheduleDomainVerificationAt(ctx context.Context, id string, time time.Time) error {
_, err := q.ext.ExecContext(ctx, `
UPDATE domain_verification SET
will_check_at = $1
WHERE id = $2
`, time, id)
return err
}
9 changes: 9 additions & 0 deletions internal/db/sqlite/domain_verification.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,12 @@ func (q query[T]) ListLeastRecentlyCheckedDomain(ctx context.Context, time time.
}
return domainVerifications, err
}

func (q query[T]) ScheduleDomainVerificationAt(ctx context.Context, id string, time time.Time) error {
_, err := q.ext.ExecContext(ctx, `
UPDATE domain_verification SET
will_check_at = ?
WHERE id = ?
`, time, id)
return err
}
15 changes: 12 additions & 3 deletions internal/handler/controller/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ func (c *Controller) handleDomainVerification(w http.ResponseWriter, r *http.Req
}
respond(w, withTx(r.Context(), c.DB, func(tx db.Tx) (any, error) {
var domainVerification *models.DomainVerification
domain, _ := tx.GetDomainByName(r.Context(), domainName)
domainVerification, _ = tx.GetDomainVerificationByName(r.Context(), domainName, app.ID)
if domainVerification == nil {
domainVerification = models.NewDomainVerification(c.Clock.Now().UTC(), domainName, app.ID)
err := tx.CreateDomainVerification(r.Context(), domainVerification)
if err != nil {
Expand All @@ -75,7 +75,16 @@ func (c *Controller) handleDomainVerification(w http.ResponseWriter, r *http.Req
log(r).Info("creating domain verification",
zap.String("domain", domainName),
zap.String("site", config.Site))
} else if domainVerification.WillCheckAt == nil {
err := tx.ScheduleDomainVerificationAt(r.Context(), domainVerification.ID, c.Clock.Now().UTC())
if err != nil {
return nil, err
}
log(r).Info("triggering domain verification",
zap.String("domain", domainName),
zap.String("site", config.Site))
}
domainVerification, _ = tx.GetDomainVerificationByName(r.Context(), domainName, app.ID)
return c.makeAPIDomain(domain, domainVerification), nil
}))
}
Expand All @@ -99,7 +108,7 @@ func (c *Controller) handleDomainCreate(w http.ResponseWriter, r *http.Request)
return nil, err
} else {
if domain.AppID == app.ID {
return c.makeAPIDomain(domain, nil), nil
return c.makeAPIDomain(domain, domainVerification), nil
} else if replaceApp != domain.AppID {
return nil, models.ErrDomainUsedName
}
Expand All @@ -120,7 +129,7 @@ func (c *Controller) handleDomainCreate(w http.ResponseWriter, r *http.Request)
zap.String("domain", domain.Domain),
zap.String("site", domain.SiteName))

return c.makeAPIDomain(domain, nil), nil
return c.makeAPIDomain(domain, domainVerification), nil
}))
}

Expand Down
40 changes: 38 additions & 2 deletions internal/handler/controller/domain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,14 +353,50 @@ func TestDomainVerification(t *testing.T) {
})
})
})
t.Run("Should not add a pending domain for existing domain", func(t *testing.T) {
testutil.WithTestController(func(c *testutil.TestController) {
token := setupDomainVerification(c)
db.WithTx(c.Context, c.DB, func(tx db.Tx) error {
return tx.CreateDomain(c.Context, models.NewDomain(time.Now(), "test.com", "test", "main"))
})
req := httptest.NewRequest("POST", "http://localtest.me/api/v1/apps/test/domains/test.com", nil)
req.Header.Add("Authorization", "bearer "+token)
w := httptest.NewRecorder()
c.ServeHTTP(w, req)
domain, err := testutil.DecodeJSONResponse[*api.APIDomain](w.Result())
if assert.NoError(t, err) {
assert.Nil(t, domain.DomainVerification)
assert.NotNil(t, domain.Domain)
assert.Equal(t, "test.com", domain.Domain.Domain)
}
})
})
t.Run("Should retrigger domain verification", func(t *testing.T) {
testutil.WithTestController(func(c *testutil.TestController) {
token := setupDomainVerification(c)
req := httptest.NewRequest("POST", "http://localtest.me/api/v1/apps/test/domains/test.com", nil)
req.Header.Add("Authorization", "bearer "+token)
w := httptest.NewRecorder()
c.ServeHTTP(w, req)
domain, err := testutil.DecodeJSONResponse[*api.APIDomain](w.Result())
now := time.Now()
if assert.NoError(t, err) {
assert.NotNil(t, domain.DomainVerification)
assert.NotNil(t, domain.DomainVerification.WillCheckAt)
assert.True(t, domain.DomainVerification.WillCheckAt.Before(now))
}
db.WithTx(c.Context, c.DB, func(tx db.Tx) error {
return tx.SetDomainIsInvalid(c.Context, domain.DomainVerification.ID, now)
})
req = httptest.NewRequest("POST", "http://localtest.me/api/v1/apps/test/domains/test.com", nil)
req.Header.Add("Authorization", "bearer "+token)
w = httptest.NewRecorder()
c.ServeHTTP(w, req)
domain, err = testutil.DecodeJSONResponse[*api.APIDomain](w.Result())
if assert.NoError(t, err) {
assert.Nil(t, domain.Domain)
assert.Equal(t, "test.com", domain.DomainVerification.Domain)
assert.NotNil(t, domain.DomainVerification)
assert.NotNil(t, domain.DomainVerification.WillCheckAt)
assert.True(t, domain.DomainVerification.WillCheckAt.After(now))
}
})
})
Expand Down

0 comments on commit fae9f73

Please sign in to comment.