Skip to content

Commit

Permalink
Handle duplicate domain verification for different app
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkRunWu committed Dec 14, 2023
1 parent fff095c commit 3072735
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 34 deletions.
13 changes: 13 additions & 0 deletions internal/cron/verify_domain_ownership.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,19 @@ func (v *VerifyDomainOwnership) Run(ctx context.Context, logger *zap.Logger) err
now, domainName, domainVerification.AppID, config.Site,
))
}
} else if domain.AppID != domainVerification.AppID {
domainName := domainVerification.Domain
app, err := c.GetApp(ctx, domainVerification.AppID)
if err != nil {
continue
}
config, ok := app.Config.ResolveDomain(domainName)
if ok {
c.DeleteDomain(ctx, domain.ID, now)
c.CreateDomain(ctx, models.NewDomain(
now, domainName, domainVerification.AppID, config.Site,
))
}
}
err = c.SetDomainIsVerified(ctx, domainVerification.ID, now, now.Add(v.RevalidatePeriod))
if err != nil {
Expand Down
67 changes: 61 additions & 6 deletions internal/cron/verify_domain_ownership_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ import (
"go.uber.org/zap"
)

func setupDB(now time.Time, ctx context.Context, database db.DB) {
var userId = ""
type DBData struct {
userId string
}

func setupDB(now time.Time, ctx context.Context, database db.DB) DBData {
userId := ""
err := db.WithTx(ctx, database, func(tx db.Tx) error {
user := models.NewUser(now, "mock_user")
userId = user.ID
Expand Down Expand Up @@ -49,6 +53,9 @@ func setupDB(now time.Time, ctx context.Context, database db.DB) {
if err != nil {
panic(err)
}
return DBData{
userId: userId,
}
}

type RaiseErrorDNSResolver struct {
Expand All @@ -70,7 +77,7 @@ func TestVerifyDomainOwnership(t *testing.T) {
t.Run("Should verify valid domain", func(t *testing.T) {
testutil.WithTestDB(func(database db.DB) {
setupDB(now, ctx, database)
domainVerification, err := database.GetDomainVerificationByName(ctx, "test.com")
domainVerification, err := database.GetDomainVerificationByName(ctx, "test.com", "test")
if assert.NoError(t, err) {
assert.Nil(t, domainVerification.VerifiedAt)
}
Expand All @@ -91,7 +98,7 @@ func TestVerifyDomainOwnership(t *testing.T) {
}
job.Run(ctx, logger)

domainVerification, err = database.GetDomainVerificationByName(ctx, "test.com")
domainVerification, err = database.GetDomainVerificationByName(ctx, "test.com", "test")
if assert.NoError(t, err) {
assert.NotNil(t, domainVerification.VerifiedAt)
assert.True(t, domainVerification.VerifiedAt.After(now))
Expand All @@ -108,7 +115,7 @@ func TestVerifyDomainOwnership(t *testing.T) {
testutil.WithTestDB(func(database db.DB) {

setupDB(now, ctx, database)
domainVerification, err := database.GetDomainVerificationByName(ctx, "test.com")
domainVerification, err := database.GetDomainVerificationByName(ctx, "test.com", "test")
if assert.NoError(t, err) {
assert.Nil(t, domainVerification.VerifiedAt)
}
Expand Down Expand Up @@ -147,7 +154,7 @@ func TestVerifyDomainOwnership(t *testing.T) {
testutil.WithTestDB(func(database db.DB) {

setupDB(now, ctx, database)
domainVerification, err := database.GetDomainVerificationByName(ctx, "test.com")
domainVerification, err := database.GetDomainVerificationByName(ctx, "test.com", "test")
if assert.NoError(t, err) {
assert.Nil(t, domainVerification.VerifiedAt)
}
Expand Down Expand Up @@ -178,4 +185,52 @@ func TestVerifyDomainOwnership(t *testing.T) {
}
})
})
t.Run("Should replace the conflict domain", func(t *testing.T) {
testutil.WithTestDB(func(database db.DB) {
data := setupDB(now, ctx, database)
err := db.WithTx(ctx, database, func(tx db.Tx) error {
return tx.CreateApp(ctx, models.NewApp(
now,
"test2",
data.userId,
))
})
assert.NoError(t, err)
err = db.WithTx(ctx, database, func(tx db.Tx) error {
return tx.CreateDomain(ctx, models.NewDomain(
now,
"test.com",
"test2",
"main",
))
})
assert.NoError(t, err)
domainVerification, err := database.GetDomainVerificationByName(ctx, "test.com", "test")
txtDomain, value := domainVerification.GetTxtRecord()
r := mockdns.Resolver{
Zones: map[string]mockdns.Zone{
txtDomain + ".": {
TXT: []string{
value,
},
},
},
}
job := cron.VerifyDomainOwnership{
DB: database,
Resolver: &r,
MaxConsumeActiveDomainCount: 1,
MaxConsumePendingDomainCount: 1,
RevalidatePeriod: time.Hour,
}
err = job.Run(ctx, logger)
assert.NoError(t, err)

domain, err := database.GetDomainByName(ctx, "test.com")
if assert.NoError(t, err) {
assert.Equal(t, "test.com", domain.Domain)
assert.Equal(t, "test", domain.AppID)
}
})
})
}
2 changes: 1 addition & 1 deletion internal/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ type DomainsDB interface {

type DomainVerificationDB interface {
CreateDomainVerification(ctx context.Context, domainVerification *models.DomainVerification) error
GetDomainVerificationByName(ctx context.Context, domain string) (*models.DomainVerification, error)
GetDomainVerificationByName(ctx context.Context, domain string, appID string) (*models.DomainVerification, error)
DeleteDomainVerification(ctx context.Context, id string, now time.Time) error
ListDomainVerifications(ctx context.Context, appID string) ([]*models.DomainVerification, error)
ListLeastRecentlyCheckedDomain(ctx context.Context, now time.Time, isVerified bool, count uint) ([]*models.DomainVerification, error)
Expand Down
16 changes: 10 additions & 6 deletions internal/db/postgres/domain_verification.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,16 @@ func (q query[T]) CreateDomainVerification(ctx context.Context, domainVerificati
return nil
}

func (q query[T]) GetDomainVerificationByName(ctx context.Context, domainName string) (*models.DomainVerification, error) {
func (q query[T]) GetDomainVerificationByName(ctx context.Context, domainName string, appId string) (*models.DomainVerification, error) {
var domainVerification models.DomainVerification

err := sqlx.GetContext(ctx, q.ext, &domainVerification, `
SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.domain, d.app_id, d.value, d.verified_at, d.domain_prefix FROM domain_verification d
SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.domain, d.app_id, d.value, d.verified_at, d.domain_prefix,
d.verified_at, d.domain_prefix, d.will_check_at, d.last_checked_at
FROM domain_verification d
JOIN app a ON (a.id = d.app_id AND a.deleted_at IS NULL)
WHERE d.domain = $1 AND d.deleted_at IS NULL
`, domainName)
WHERE d.domain = $1 AND d.deleted_at IS NULL AND d.app_id = $2
`, domainName, appId)
if errors.Is(err, sql.ErrNoRows) {
return nil, models.ErrDomainNotFound
} else if err != nil {
Expand All @@ -64,7 +66,8 @@ func (q query[T]) DeleteDomainVerification(ctx context.Context, id string, now t
func (q query[T]) ListDomainVerifications(ctx context.Context, appID string) ([]*models.DomainVerification, error) {
var domainVerifications []*models.DomainVerification
stmt := `
SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.domain, d.app_id, d.value, d.verified_at, d.domain_prefix
SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.domain, d.app_id, d.value, d.verified_at, d.domain_prefix,
d.verified_at, d.domain_prefix, d.will_check_at, d.last_checked_at
FROM domain_verification d
WHERE d.deleted_at IS NULL AND d.app_id = $1
ORDER BY d.domain, d.created_at
Expand Down Expand Up @@ -103,7 +106,8 @@ func (q query[T]) SetDomainIsInvalid(ctx context.Context, id string, now time.Ti
func (q query[T]) ListLeastRecentlyCheckedDomain(ctx context.Context, time time.Time, isVerified bool, count uint) ([]*models.DomainVerification, error) {
var domainVerifications []*models.DomainVerification
stmt := `
SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.verified_at, d.last_checked_at, d.will_check_at, d.domain, d.domain_prefix, d.app_id, d.value
SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.verified_at, d.last_checked_at, d.will_check_at, d.domain, d.domain_prefix, d.app_id, d.value,
d.verified_at, d.domain_prefix, d.will_check_at, d.last_checked_at
FROM domain_verification d
WHERE %s
ORDER BY d.will_check_at, d.last_checked_at NULLS FIRST
Expand Down
15 changes: 9 additions & 6 deletions internal/db/sqlite/domain_verification.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,16 @@ func (q query[T]) CreateDomainVerification(ctx context.Context, domainVerificati
return nil
}

func (q query[T]) GetDomainVerificationByName(ctx context.Context, domainName string) (*models.DomainVerification, error) {
func (q query[T]) GetDomainVerificationByName(ctx context.Context, domainName string, appId string) (*models.DomainVerification, error) {
var domainVerification models.DomainVerification

err := sqlx.GetContext(ctx, q.ext, &domainVerification, `
SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.verified_at, d.last_checked_at, d.will_check_at, d.domain, d.domain_prefix, d.app_id, d.value
SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.verified_at, d.last_checked_at, d.will_check_at, d.domain, d.domain_prefix, d.app_id, d.value,
d.verified_at, d.domain_prefix, d.will_check_at, d.last_checked_at
FROM domain_verification d
JOIN app a ON (a.id = d.app_id AND a.deleted_at IS NULL)
WHERE d.domain = ? AND d.deleted_at IS NULL
`, domainName)
WHERE d.domain = ? AND d.deleted_at IS NULL AND d.app_id = ?
`, domainName, appId)
if errors.Is(err, sql.ErrNoRows) {
return nil, models.ErrDomainNotFound
} else if err != nil {
Expand All @@ -65,7 +66,8 @@ func (q query[T]) DeleteDomainVerification(ctx context.Context, id string, now t
func (q query[T]) ListDomainVerifications(ctx context.Context, appID string) ([]*models.DomainVerification, error) {
var domainVerifications []*models.DomainVerification
stmt := `
SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.verified_at, d.last_checked_at, d.will_check_at, d.domain, d.domain_prefix, d.app_id, d.value
SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.verified_at, d.last_checked_at, d.will_check_at, d.domain, d.domain_prefix, d.app_id, d.value,
d.verified_at, d.domain_prefix, d.will_check_at, d.last_checked_at
FROM domain_verification d
WHERE d.deleted_at IS NULL AND d.app_id = ?
ORDER BY d.domain, d.created_at
Expand Down Expand Up @@ -104,7 +106,8 @@ func (q query[T]) SetDomainIsInvalid(ctx context.Context, id string, now time.Ti
func (q query[T]) ListLeastRecentlyCheckedDomain(ctx context.Context, time time.Time, isVerified bool, count uint) ([]*models.DomainVerification, error) {
var domainVerifications []*models.DomainVerification
stmt := `
SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.verified_at, d.last_checked_at, d.will_check_at, d.domain, d.domain_prefix, d.app_id, d.value
SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.verified_at, d.last_checked_at, d.will_check_at, d.domain, d.domain_prefix, d.app_id, d.value,
d.verified_at, d.domain_prefix, d.will_check_at, d.last_checked_at
FROM domain_verification d
WHERE %s
ORDER BY d.will_check_at, d.last_checked_at NULLS FIRST
Expand Down
6 changes: 3 additions & 3 deletions internal/handler/controller/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (c *Controller) handleDomainList(w http.ResponseWriter, r *http.Request) {
if err != nil && !errors.Is(err, models.ErrDomainNotFound) {
return nil, err
}
domainVerification, err := c.DB.GetDomainVerificationByName(r.Context(), dconf.Domain)
domainVerification, err := c.DB.GetDomainVerificationByName(r.Context(), dconf.Domain, app.ID)
if err != nil && !errors.Is(err, models.ErrDomainNotFound) {
return nil, err
}
Expand All @@ -60,8 +60,7 @@ func (c *Controller) handleDomainVerification(w http.ResponseWriter, r *http.Req
respond(w, withTx(r.Context(), c.DB, func(tx db.Tx) (any, error) {
var domainVerification *models.DomainVerification
domain, _ := tx.GetDomainByName(r.Context(), domainName)
domainVerification, _ = tx.GetDomainVerificationByName(r.Context(), domainName)
if domain == nil && domainVerification == nil {
domainVerification, _ = tx.GetDomainVerificationByName(r.Context(), domainName, app.ID)
domainVerification = models.NewDomainVerification(c.Clock.Now().UTC(), domainName, app.ID)
err := tx.CreateDomainVerification(r.Context(), domainVerification)
if err != nil {
Expand All @@ -87,6 +86,7 @@ func (c *Controller) handleDomainCreate(w http.ResponseWriter, r *http.Request)
return nil, models.ErrUndefinedDomain
}
domain, err := tx.GetDomainByName(r.Context(), domainName)
domainVerification, _ := tx.GetDomainVerificationByName(r.Context(), domainName, app.ID)
if errors.Is(err, models.ErrDomainNotFound) {
// Continue create new domain.
} else if err != nil {
Expand Down
40 changes: 28 additions & 12 deletions internal/handler/controller/domain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,36 @@ func TestDomainVerification(t *testing.T) {
}
})
})
t.Run("Should add a pending activating domain with app A", func(t *testing.T) {
t.Run("Should add a pending activating domain for same domain with different Apps", func(t *testing.T) {
testutil.WithTestController(func(c *testutil.TestController) {
token := setupDomainVerification(c)

t.Run("Should add a pending activating domain with app B", func(t *testing.T) {
req := httptest.NewRequest("POST", "http://localtest.me/api/v1/apps/test/domains/test.com", nil)
req.Header.Add("Authorization", "bearer "+token)
w := httptest.NewRecorder()
c.ServeHTTP(w, req)
domain, err := testutil.DecodeJSONResponse[*api.APIDomain](w.Result())
if assert.NoError(t, err) {
assert.Nil(t, domain.Domain)
assert.Equal(t, "test.com", domain.DomainVerification.Domain)
assert.Equal(t, "test", domain.DomainVerification.AppID)
}
})
t.Run("Should add a pending activating domain with app B", func(t *testing.T) {
req := httptest.NewRequest("POST", "http://localtest.me/api/v1/apps/test2/domains/test.com", nil)
req.Header.Add("Authorization", "bearer "+token)
w := httptest.NewRecorder()
c.ServeHTTP(w, req)
domain, err := testutil.DecodeJSONResponse[*api.APIDomain](w.Result())
if assert.NoError(t, err) {
assert.Nil(t, domain.Domain)
assert.Equal(t, "test.com", domain.DomainVerification.Domain)
assert.Equal(t, "test2", domain.DomainVerification.AppID)
}
})
})
})
req := httptest.NewRequest("POST", "http://localtest.me/api/v1/apps/test/domains/test.com", nil)
req.Header.Add("Authorization", "bearer "+token)
w := httptest.NewRecorder()
Expand All @@ -337,9 +364,6 @@ func TestDomainVerification(t *testing.T) {
}
})
})
t.Run("Should add a pending activating domain with app B", func(t *testing.T) {
testutil.WithTestController(func(c *testutil.TestController) {
token := setupDomainVerification(c)
req := httptest.NewRequest("POST", "http://localtest.me/api/v1/apps/test2/domains/test.com", nil)
req.Header.Add("Authorization", "bearer "+token)
w := httptest.NewRecorder()
Expand All @@ -350,14 +374,6 @@ func TestDomainVerification(t *testing.T) {
assert.Equal(t, "test.com", domain.DomainVerification.Domain)
}
})
})
t.Run("Should not add a pending domain for existing domain", func(t *testing.T) {
testutil.WithTestController(func(c *testutil.TestController) {
token := setupDomainVerification(c)
db.WithTx(c.Context, c.DB, func(tx db.Tx) error {
return tx.CreateDomain(c.Context, models.NewDomain(time.Now(), "test.com", "test", "main"))
})
req := httptest.NewRequest("POST", "http://localtest.me/api/v1/apps/test/domains/test.com", nil)
req.Header.Add("Authorization", "bearer "+token)
w := httptest.NewRecorder()
c.ServeHTTP(w, req)
Expand Down

0 comments on commit 3072735

Please sign in to comment.