From 30727355cebd4db293c61a5bf0c1dcdc440dbe61 Mon Sep 17 00:00:00 2001 From: mark wu Date: Thu, 14 Dec 2023 15:35:10 +0800 Subject: [PATCH] Handle duplicate domain verification for different app --- internal/cron/verify_domain_ownership.go | 13 ++++ internal/cron/verify_domain_ownership_test.go | 67 +++++++++++++++++-- internal/db/db.go | 2 +- internal/db/postgres/domain_verification.go | 16 +++-- internal/db/sqlite/domain_verification.go | 15 +++-- internal/handler/controller/domain.go | 6 +- internal/handler/controller/domain_test.go | 40 +++++++---- 7 files changed, 125 insertions(+), 34 deletions(-) diff --git a/internal/cron/verify_domain_ownership.go b/internal/cron/verify_domain_ownership.go index 87c239f..9d1ed5b 100644 --- a/internal/cron/verify_domain_ownership.go +++ b/internal/cron/verify_domain_ownership.go @@ -90,6 +90,19 @@ func (v *VerifyDomainOwnership) Run(ctx context.Context, logger *zap.Logger) err now, domainName, domainVerification.AppID, config.Site, )) } + } else if domain.AppID != domainVerification.AppID { + domainName := domainVerification.Domain + app, err := c.GetApp(ctx, domainVerification.AppID) + if err != nil { + continue + } + config, ok := app.Config.ResolveDomain(domainName) + if ok { + c.DeleteDomain(ctx, domain.ID, now) + c.CreateDomain(ctx, models.NewDomain( + now, domainName, domainVerification.AppID, config.Site, + )) + } } err = c.SetDomainIsVerified(ctx, domainVerification.ID, now, now.Add(v.RevalidatePeriod)) if err != nil { diff --git a/internal/cron/verify_domain_ownership_test.go b/internal/cron/verify_domain_ownership_test.go index 04735d5..2dc419d 100644 --- a/internal/cron/verify_domain_ownership_test.go +++ b/internal/cron/verify_domain_ownership_test.go @@ -16,8 +16,12 @@ import ( "go.uber.org/zap" ) -func setupDB(now time.Time, ctx context.Context, database db.DB) { - var userId = "" +type DBData struct { + userId string +} + +func setupDB(now time.Time, ctx context.Context, database db.DB) DBData { + userId := "" err := db.WithTx(ctx, database, func(tx db.Tx) error { user := models.NewUser(now, "mock_user") userId = user.ID @@ -49,6 +53,9 @@ func setupDB(now time.Time, ctx context.Context, database db.DB) { if err != nil { panic(err) } + return DBData{ + userId: userId, + } } type RaiseErrorDNSResolver struct { @@ -70,7 +77,7 @@ func TestVerifyDomainOwnership(t *testing.T) { t.Run("Should verify valid domain", func(t *testing.T) { testutil.WithTestDB(func(database db.DB) { setupDB(now, ctx, database) - domainVerification, err := database.GetDomainVerificationByName(ctx, "test.com") + domainVerification, err := database.GetDomainVerificationByName(ctx, "test.com", "test") if assert.NoError(t, err) { assert.Nil(t, domainVerification.VerifiedAt) } @@ -91,7 +98,7 @@ func TestVerifyDomainOwnership(t *testing.T) { } job.Run(ctx, logger) - domainVerification, err = database.GetDomainVerificationByName(ctx, "test.com") + domainVerification, err = database.GetDomainVerificationByName(ctx, "test.com", "test") if assert.NoError(t, err) { assert.NotNil(t, domainVerification.VerifiedAt) assert.True(t, domainVerification.VerifiedAt.After(now)) @@ -108,7 +115,7 @@ func TestVerifyDomainOwnership(t *testing.T) { testutil.WithTestDB(func(database db.DB) { setupDB(now, ctx, database) - domainVerification, err := database.GetDomainVerificationByName(ctx, "test.com") + domainVerification, err := database.GetDomainVerificationByName(ctx, "test.com", "test") if assert.NoError(t, err) { assert.Nil(t, domainVerification.VerifiedAt) } @@ -147,7 +154,7 @@ func TestVerifyDomainOwnership(t *testing.T) { testutil.WithTestDB(func(database db.DB) { setupDB(now, ctx, database) - domainVerification, err := database.GetDomainVerificationByName(ctx, "test.com") + domainVerification, err := database.GetDomainVerificationByName(ctx, "test.com", "test") if assert.NoError(t, err) { assert.Nil(t, domainVerification.VerifiedAt) } @@ -178,4 +185,52 @@ func TestVerifyDomainOwnership(t *testing.T) { } }) }) + t.Run("Should replace the conflict domain", func(t *testing.T) { + testutil.WithTestDB(func(database db.DB) { + data := setupDB(now, ctx, database) + err := db.WithTx(ctx, database, func(tx db.Tx) error { + return tx.CreateApp(ctx, models.NewApp( + now, + "test2", + data.userId, + )) + }) + assert.NoError(t, err) + err = db.WithTx(ctx, database, func(tx db.Tx) error { + return tx.CreateDomain(ctx, models.NewDomain( + now, + "test.com", + "test2", + "main", + )) + }) + assert.NoError(t, err) + domainVerification, err := database.GetDomainVerificationByName(ctx, "test.com", "test") + txtDomain, value := domainVerification.GetTxtRecord() + r := mockdns.Resolver{ + Zones: map[string]mockdns.Zone{ + txtDomain + ".": { + TXT: []string{ + value, + }, + }, + }, + } + job := cron.VerifyDomainOwnership{ + DB: database, + Resolver: &r, + MaxConsumeActiveDomainCount: 1, + MaxConsumePendingDomainCount: 1, + RevalidatePeriod: time.Hour, + } + err = job.Run(ctx, logger) + assert.NoError(t, err) + + domain, err := database.GetDomainByName(ctx, "test.com") + if assert.NoError(t, err) { + assert.Equal(t, "test.com", domain.Domain) + assert.Equal(t, "test", domain.AppID) + } + }) + }) } diff --git a/internal/db/db.go b/internal/db/db.go index 847b9cb..d2dccf0 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -86,7 +86,7 @@ type DomainsDB interface { type DomainVerificationDB interface { CreateDomainVerification(ctx context.Context, domainVerification *models.DomainVerification) error - GetDomainVerificationByName(ctx context.Context, domain string) (*models.DomainVerification, error) + GetDomainVerificationByName(ctx context.Context, domain string, appID string) (*models.DomainVerification, error) DeleteDomainVerification(ctx context.Context, id string, now time.Time) error ListDomainVerifications(ctx context.Context, appID string) ([]*models.DomainVerification, error) ListLeastRecentlyCheckedDomain(ctx context.Context, now time.Time, isVerified bool, count uint) ([]*models.DomainVerification, error) diff --git a/internal/db/postgres/domain_verification.go b/internal/db/postgres/domain_verification.go index dde082e..bcec5c2 100644 --- a/internal/db/postgres/domain_verification.go +++ b/internal/db/postgres/domain_verification.go @@ -32,14 +32,16 @@ func (q query[T]) CreateDomainVerification(ctx context.Context, domainVerificati return nil } -func (q query[T]) GetDomainVerificationByName(ctx context.Context, domainName string) (*models.DomainVerification, error) { +func (q query[T]) GetDomainVerificationByName(ctx context.Context, domainName string, appId string) (*models.DomainVerification, error) { var domainVerification models.DomainVerification err := sqlx.GetContext(ctx, q.ext, &domainVerification, ` - SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.domain, d.app_id, d.value, d.verified_at, d.domain_prefix FROM domain_verification d + SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.domain, d.app_id, d.value, d.verified_at, d.domain_prefix, + d.verified_at, d.domain_prefix, d.will_check_at, d.last_checked_at + FROM domain_verification d JOIN app a ON (a.id = d.app_id AND a.deleted_at IS NULL) - WHERE d.domain = $1 AND d.deleted_at IS NULL - `, domainName) + WHERE d.domain = $1 AND d.deleted_at IS NULL AND d.app_id = $2 + `, domainName, appId) if errors.Is(err, sql.ErrNoRows) { return nil, models.ErrDomainNotFound } else if err != nil { @@ -64,7 +66,8 @@ func (q query[T]) DeleteDomainVerification(ctx context.Context, id string, now t func (q query[T]) ListDomainVerifications(ctx context.Context, appID string) ([]*models.DomainVerification, error) { var domainVerifications []*models.DomainVerification stmt := ` - SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.domain, d.app_id, d.value, d.verified_at, d.domain_prefix + SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.domain, d.app_id, d.value, d.verified_at, d.domain_prefix, + d.verified_at, d.domain_prefix, d.will_check_at, d.last_checked_at FROM domain_verification d WHERE d.deleted_at IS NULL AND d.app_id = $1 ORDER BY d.domain, d.created_at @@ -103,7 +106,8 @@ func (q query[T]) SetDomainIsInvalid(ctx context.Context, id string, now time.Ti func (q query[T]) ListLeastRecentlyCheckedDomain(ctx context.Context, time time.Time, isVerified bool, count uint) ([]*models.DomainVerification, error) { var domainVerifications []*models.DomainVerification stmt := ` - SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.verified_at, d.last_checked_at, d.will_check_at, d.domain, d.domain_prefix, d.app_id, d.value + SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.verified_at, d.last_checked_at, d.will_check_at, d.domain, d.domain_prefix, d.app_id, d.value, + d.verified_at, d.domain_prefix, d.will_check_at, d.last_checked_at FROM domain_verification d WHERE %s ORDER BY d.will_check_at, d.last_checked_at NULLS FIRST diff --git a/internal/db/sqlite/domain_verification.go b/internal/db/sqlite/domain_verification.go index d593dd5..535e555 100644 --- a/internal/db/sqlite/domain_verification.go +++ b/internal/db/sqlite/domain_verification.go @@ -32,15 +32,16 @@ func (q query[T]) CreateDomainVerification(ctx context.Context, domainVerificati return nil } -func (q query[T]) GetDomainVerificationByName(ctx context.Context, domainName string) (*models.DomainVerification, error) { +func (q query[T]) GetDomainVerificationByName(ctx context.Context, domainName string, appId string) (*models.DomainVerification, error) { var domainVerification models.DomainVerification err := sqlx.GetContext(ctx, q.ext, &domainVerification, ` - SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.verified_at, d.last_checked_at, d.will_check_at, d.domain, d.domain_prefix, d.app_id, d.value + SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.verified_at, d.last_checked_at, d.will_check_at, d.domain, d.domain_prefix, d.app_id, d.value, + d.verified_at, d.domain_prefix, d.will_check_at, d.last_checked_at FROM domain_verification d JOIN app a ON (a.id = d.app_id AND a.deleted_at IS NULL) - WHERE d.domain = ? AND d.deleted_at IS NULL - `, domainName) + WHERE d.domain = ? AND d.deleted_at IS NULL AND d.app_id = ? + `, domainName, appId) if errors.Is(err, sql.ErrNoRows) { return nil, models.ErrDomainNotFound } else if err != nil { @@ -65,7 +66,8 @@ func (q query[T]) DeleteDomainVerification(ctx context.Context, id string, now t func (q query[T]) ListDomainVerifications(ctx context.Context, appID string) ([]*models.DomainVerification, error) { var domainVerifications []*models.DomainVerification stmt := ` - SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.verified_at, d.last_checked_at, d.will_check_at, d.domain, d.domain_prefix, d.app_id, d.value + SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.verified_at, d.last_checked_at, d.will_check_at, d.domain, d.domain_prefix, d.app_id, d.value, + d.verified_at, d.domain_prefix, d.will_check_at, d.last_checked_at FROM domain_verification d WHERE d.deleted_at IS NULL AND d.app_id = ? ORDER BY d.domain, d.created_at @@ -104,7 +106,8 @@ func (q query[T]) SetDomainIsInvalid(ctx context.Context, id string, now time.Ti func (q query[T]) ListLeastRecentlyCheckedDomain(ctx context.Context, time time.Time, isVerified bool, count uint) ([]*models.DomainVerification, error) { var domainVerifications []*models.DomainVerification stmt := ` - SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.verified_at, d.last_checked_at, d.will_check_at, d.domain, d.domain_prefix, d.app_id, d.value + SELECT d.id, d.created_at, d.updated_at, d.deleted_at, d.verified_at, d.last_checked_at, d.will_check_at, d.domain, d.domain_prefix, d.app_id, d.value, + d.verified_at, d.domain_prefix, d.will_check_at, d.last_checked_at FROM domain_verification d WHERE %s ORDER BY d.will_check_at, d.last_checked_at NULLS FIRST diff --git a/internal/handler/controller/domain.go b/internal/handler/controller/domain.go index 8ea70ff..25d3c96 100644 --- a/internal/handler/controller/domain.go +++ b/internal/handler/controller/domain.go @@ -34,7 +34,7 @@ func (c *Controller) handleDomainList(w http.ResponseWriter, r *http.Request) { if err != nil && !errors.Is(err, models.ErrDomainNotFound) { return nil, err } - domainVerification, err := c.DB.GetDomainVerificationByName(r.Context(), dconf.Domain) + domainVerification, err := c.DB.GetDomainVerificationByName(r.Context(), dconf.Domain, app.ID) if err != nil && !errors.Is(err, models.ErrDomainNotFound) { return nil, err } @@ -60,8 +60,7 @@ func (c *Controller) handleDomainVerification(w http.ResponseWriter, r *http.Req respond(w, withTx(r.Context(), c.DB, func(tx db.Tx) (any, error) { var domainVerification *models.DomainVerification domain, _ := tx.GetDomainByName(r.Context(), domainName) - domainVerification, _ = tx.GetDomainVerificationByName(r.Context(), domainName) - if domain == nil && domainVerification == nil { + domainVerification, _ = tx.GetDomainVerificationByName(r.Context(), domainName, app.ID) domainVerification = models.NewDomainVerification(c.Clock.Now().UTC(), domainName, app.ID) err := tx.CreateDomainVerification(r.Context(), domainVerification) if err != nil { @@ -87,6 +86,7 @@ func (c *Controller) handleDomainCreate(w http.ResponseWriter, r *http.Request) return nil, models.ErrUndefinedDomain } domain, err := tx.GetDomainByName(r.Context(), domainName) + domainVerification, _ := tx.GetDomainVerificationByName(r.Context(), domainName, app.ID) if errors.Is(err, models.ErrDomainNotFound) { // Continue create new domain. } else if err != nil { diff --git a/internal/handler/controller/domain_test.go b/internal/handler/controller/domain_test.go index 7aa34d9..1cfa0ac 100644 --- a/internal/handler/controller/domain_test.go +++ b/internal/handler/controller/domain_test.go @@ -323,9 +323,36 @@ func TestDomainVerification(t *testing.T) { } }) }) - t.Run("Should add a pending activating domain with app A", func(t *testing.T) { + t.Run("Should add a pending activating domain for same domain with different Apps", func(t *testing.T) { testutil.WithTestController(func(c *testutil.TestController) { token := setupDomainVerification(c) + + t.Run("Should add a pending activating domain with app B", func(t *testing.T) { + req := httptest.NewRequest("POST", "http://localtest.me/api/v1/apps/test/domains/test.com", nil) + req.Header.Add("Authorization", "bearer "+token) + w := httptest.NewRecorder() + c.ServeHTTP(w, req) + domain, err := testutil.DecodeJSONResponse[*api.APIDomain](w.Result()) + if assert.NoError(t, err) { + assert.Nil(t, domain.Domain) + assert.Equal(t, "test.com", domain.DomainVerification.Domain) + assert.Equal(t, "test", domain.DomainVerification.AppID) + } + }) + t.Run("Should add a pending activating domain with app B", func(t *testing.T) { + req := httptest.NewRequest("POST", "http://localtest.me/api/v1/apps/test2/domains/test.com", nil) + req.Header.Add("Authorization", "bearer "+token) + w := httptest.NewRecorder() + c.ServeHTTP(w, req) + domain, err := testutil.DecodeJSONResponse[*api.APIDomain](w.Result()) + if assert.NoError(t, err) { + assert.Nil(t, domain.Domain) + assert.Equal(t, "test.com", domain.DomainVerification.Domain) + assert.Equal(t, "test2", domain.DomainVerification.AppID) + } + }) + }) + }) req := httptest.NewRequest("POST", "http://localtest.me/api/v1/apps/test/domains/test.com", nil) req.Header.Add("Authorization", "bearer "+token) w := httptest.NewRecorder() @@ -337,9 +364,6 @@ func TestDomainVerification(t *testing.T) { } }) }) - t.Run("Should add a pending activating domain with app B", func(t *testing.T) { - testutil.WithTestController(func(c *testutil.TestController) { - token := setupDomainVerification(c) req := httptest.NewRequest("POST", "http://localtest.me/api/v1/apps/test2/domains/test.com", nil) req.Header.Add("Authorization", "bearer "+token) w := httptest.NewRecorder() @@ -350,14 +374,6 @@ func TestDomainVerification(t *testing.T) { assert.Equal(t, "test.com", domain.DomainVerification.Domain) } }) - }) - t.Run("Should not add a pending domain for existing domain", func(t *testing.T) { - testutil.WithTestController(func(c *testutil.TestController) { - token := setupDomainVerification(c) - db.WithTx(c.Context, c.DB, func(tx db.Tx) error { - return tx.CreateDomain(c.Context, models.NewDomain(time.Now(), "test.com", "test", "main")) - }) - req := httptest.NewRequest("POST", "http://localtest.me/api/v1/apps/test/domains/test.com", nil) req.Header.Add("Authorization", "bearer "+token) w := httptest.NewRecorder() c.ServeHTTP(w, req)