diff --git a/e2e/aws/databases_test.go b/e2e/aws/databases_test.go index fd985cdbc356b..e78d5aeeda9f1 100644 --- a/e2e/aws/databases_test.go +++ b/e2e/aws/databases_test.go @@ -22,17 +22,22 @@ import ( "context" "crypto/tls" "encoding/json" + "errors" "fmt" + "log/slog" "net" "os" "strconv" + "strings" "testing" "time" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/secretsmanager" mysqlclient "github.com/go-mysql-org/go-mysql/client" + "github.com/gravitational/trace" "github.com/jackc/pgconn" + "github.com/jackc/pgerrcode" "github.com/jackc/pgx/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -40,6 +45,7 @@ import ( "github.com/gravitational/teleport" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/integration/helpers" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/client" @@ -49,6 +55,7 @@ import ( "github.com/gravitational/teleport/lib/srv/db/common" "github.com/gravitational/teleport/lib/srv/db/postgres" "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/teleport/lib/utils" ) func TestDatabases(t *testing.T) { @@ -139,29 +146,14 @@ func postgresConnTest(t *testing.T, cluster *helpers.TeleInstance, user string, assert.NotNil(t, pgConn) }, waitForConnTimeout, connRetryTick, "connecting to postgres") - // dont wait forever on the exec or close. - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - // Execute a query. - results, err := pgConn.Exec(ctx, query).ReadAll() - require.NoError(t, err) - for i, r := range results { - require.NoError(t, r.Err, "error in result %v", i) - } - - // Disconnect. - err = pgConn.Close(ctx) - require.NoError(t, err) + execPGTestQuery(t, pgConn, query) } // postgresLocalProxyConnTest tests connection to a postgres database via // local proxy tunnel. func postgresLocalProxyConnTest(t *testing.T, cluster *helpers.TeleInstance, user string, route tlsca.RouteToDatabase, query string) { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), 2*waitForConnTimeout) - defer cancel() - lp := startLocalALPNProxy(t, ctx, user, cluster, route) + lp := startLocalALPNProxy(t, user, cluster, route) pgconnConfig, err := pgconn.ParseConfig(fmt.Sprintf("postgres://%v/", lp.GetAddr())) require.NoError(t, err) @@ -179,30 +171,36 @@ func postgresLocalProxyConnTest(t *testing.T, cluster *helpers.TeleInstance, use assert.NotNil(t, pgConn) }, waitForConnTimeout, connRetryTick, "connecting to postgres") - // dont wait forever on the exec or close. - ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second) + execPGTestQuery(t, pgConn, query) +} + +func execPGTestQuery(t *testing.T, conn *pgconn.PgConn, query string) { + t.Helper() + defer func() { + // dont wait forever to gracefully terminate. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + // Disconnect. + require.NoError(t, conn.Close(ctx)) + }() + + // dont wait forever on the exec. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() // Execute a query. - results, err := pgConn.Exec(ctx, query).ReadAll() + results, err := conn.Exec(ctx, query).ReadAll() require.NoError(t, err) for i, r := range results { require.NoError(t, r.Err, "error in result %v", i) } - - // Disconnect. - err = pgConn.Close(ctx) - require.NoError(t, err) } // mysqlLocalProxyConnTest tests connection to a MySQL database via // local proxy tunnel. func mysqlLocalProxyConnTest(t *testing.T, cluster *helpers.TeleInstance, user string, route tlsca.RouteToDatabase, query string) { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), 2*waitForConnTimeout) - defer cancel() - - lp := startLocalALPNProxy(t, ctx, user, cluster, route) + lp := startLocalALPNProxy(t, user, cluster, route) var conn *mysqlclient.Conn // retry for a while, the database service might need time to give @@ -222,19 +220,22 @@ func mysqlLocalProxyConnTest(t *testing.T, cluster *helpers.TeleInstance, user s assert.NoError(t, err) assert.NotNil(t, conn) }, waitForConnTimeout, connRetryTick, "connecting to mysql") + defer func() { + // Disconnect. + require.NoError(t, conn.Close()) + }() // Execute a query. require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second))) _, err := conn.Execute(query) require.NoError(t, err) - - // Disconnect. - require.NoError(t, conn.Close()) } // startLocalALPNProxy starts local ALPN proxy for the specified database. -func startLocalALPNProxy(t *testing.T, ctx context.Context, user string, cluster *helpers.TeleInstance, route tlsca.RouteToDatabase) *alpnproxy.LocalProxy { +func startLocalALPNProxy(t *testing.T, user string, cluster *helpers.TeleInstance, route tlsca.RouteToDatabase) *alpnproxy.LocalProxy { t.Helper() + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) proto, err := alpncommon.ToALPNProtocol(route.Protocol) require.NoError(t, err) @@ -333,7 +334,7 @@ type dbUserLogin struct { port int } -func connectPostgres(t *testing.T, ctx context.Context, info dbUserLogin, dbName string) *pgx.Conn { +func connectPostgres(t *testing.T, ctx context.Context, info dbUserLogin, dbName string) *pgConn { pgCfg, err := pgx.ParseConfig(fmt.Sprintf("postgres://%s:%d/?sslmode=verify-full", info.address, info.port)) require.NoError(t, err) pgCfg.User = info.username @@ -349,7 +350,10 @@ func connectPostgres(t *testing.T, ctx context.Context, info dbUserLogin, dbName t.Cleanup(func() { _ = conn.Close(ctx) }) - return conn + return &pgConn{ + logger: utils.NewSlogLoggerForTests(), + Conn: conn, + } } // secretPassword is used to unmarshal an AWS Secrets Manager @@ -391,3 +395,77 @@ func getSecretValue(t *testing.T, ctx context.Context, secretID string) secretsm require.NotNil(t, secretVal) return *secretVal } + +// pgConn wraps a [pgx.Conn] and adds retries to all Exec calls. +type pgConn struct { + logger *slog.Logger + *pgx.Conn +} + +func (c *pgConn) Exec(ctx context.Context, sql string, args ...interface{}) (pgconn.CommandTag, error) { + var out pgconn.CommandTag + err := withRetry(ctx, c.logger, func() error { + var err error + out, err = c.Conn.Exec(ctx, sql, args...) + return trace.Wrap(err) + }) + return out, trace.Wrap(err) +} + +// withRetry runs a given func a finite number of times until it returns nil +// error or the given context is done. +func withRetry(ctx context.Context, log *slog.Logger, f func() error) error { + linear, err := retryutils.NewLinear(retryutils.LinearConfig{ + First: 0, + Step: 500 * time.Millisecond, + Max: 5 * time.Second, + Jitter: retryutils.NewHalfJitter(), + }) + if err != nil { + return trace.Wrap(err) + } + + // retry a finite number of times before giving up. + const retries = 10 + for i := 0; i < retries; i++ { + err := f() + if err == nil { + return nil + } + + if isRetryable(err) { + log.DebugContext(ctx, "operation failed, retrying", "error", err) + } else { + return trace.Wrap(err) + } + + linear.Inc() + select { + case <-linear.After(): + case <-ctx.Done(): + return trace.Wrap(ctx.Err()) + } + } + return trace.Wrap(err, "too many retries") +} + +// isRetryable returns true if an error can be retried. +func isRetryable(err error) bool { + var pgErr *pgconn.PgError + err = trace.Unwrap(err) + if errors.As(err, &pgErr) { + // https://www.postgresql.org/docs/current/mvcc-serialization-failure-handling.html + switch pgErr.Code { + case pgerrcode.DeadlockDetected, pgerrcode.SerializationFailure, + pgerrcode.UniqueViolation, pgerrcode.ExclusionViolation: + return true + } + } + // Redshift reports this with a vague SQLSTATE XX000, which is the internal + // error code, but this is a serialization error that rolls back the + // transaction, so it should be retried. + if strings.Contains(err.Error(), "conflict with concurrent transaction") { + return true + } + return pgconn.SafeToRetry(err) +} diff --git a/e2e/aws/fixtures_test.go b/e2e/aws/fixtures_test.go index c196af60b663d..9831e1d502cf3 100644 --- a/e2e/aws/fixtures_test.go +++ b/e2e/aws/fixtures_test.go @@ -241,10 +241,6 @@ func withDiscoveryService(t *testing.T, discoveryGroup string, awsMatchers ...ty options.serviceConfigFuncs = append(options.serviceConfigFuncs, func(cfg *servicecfg.Config) { cfg.Discovery.Enabled = true cfg.Discovery.DiscoveryGroup = discoveryGroup - // Reduce the polling interval to speed up the test execution - // in the case of a failure of the first attempt. - // The default polling interval is 5 minutes. - cfg.Discovery.PollInterval = 1 * time.Minute cfg.Discovery.AWSMatchers = append(cfg.Discovery.AWSMatchers, awsMatchers...) }) } diff --git a/e2e/aws/rds_test.go b/e2e/aws/rds_test.go index a43ab478b47de..71e06d323342c 100644 --- a/e2e/aws/rds_test.go +++ b/e2e/aws/rds_test.go @@ -30,7 +30,6 @@ import ( "github.com/aws/aws-sdk-go-v2/service/rds" mysqlclient "github.com/go-mysql-org/go-mysql/client" "github.com/go-mysql-org/go-mysql/mysql" - "github.com/jackc/pgx/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -440,7 +439,7 @@ func testRDS(t *testing.T) { }) } -func connectAsRDSPostgresAdmin(t *testing.T, ctx context.Context, instanceID string) *pgx.Conn { +func connectAsRDSPostgresAdmin(t *testing.T, ctx context.Context, instanceID string) *pgConn { t.Helper() info := getRDSAdminInfo(t, ctx, instanceID) const dbName = "postgres" @@ -508,7 +507,7 @@ func getRDSAdminInfo(t *testing.T, ctx context.Context, instanceID string) dbUse // provisionRDSPostgresAutoUsersAdmin provisions an admin user suitable for auto-user // provisioning. -func provisionRDSPostgresAutoUsersAdmin(t *testing.T, ctx context.Context, conn *pgx.Conn, adminUser string) { +func provisionRDSPostgresAutoUsersAdmin(t *testing.T, ctx context.Context, conn *pgConn, adminUser string) { t.Helper() // Create the admin user and grant rds_iam so Teleport can auth // with IAM as an existing user. @@ -599,7 +598,7 @@ const ( autoUserWaitStep = 10 * time.Second ) -func waitForPostgresAutoUserDeactivate(t *testing.T, ctx context.Context, conn *pgx.Conn, user string) { +func waitForPostgresAutoUserDeactivate(t *testing.T, ctx context.Context, conn *pgConn, user string) { t.Helper() require.EventuallyWithT(t, func(c *assert.CollectT) { // `Query` documents that it is always safe to attempt to read from the @@ -640,7 +639,7 @@ func waitForPostgresAutoUserDeactivate(t *testing.T, ctx context.Context, conn * }, autoUserWaitDur, autoUserWaitStep, "waiting for auto user %q to be deactivated", user) } -func waitForPostgresAutoUserDrop(t *testing.T, ctx context.Context, conn *pgx.Conn, user string) { +func waitForPostgresAutoUserDrop(t *testing.T, ctx context.Context, conn *pgConn, user string) { t.Helper() require.EventuallyWithT(t, func(c *assert.CollectT) { // `Query` documents that it is always safe to attempt to read from the diff --git a/e2e/aws/redshift_test.go b/e2e/aws/redshift_test.go index 6009e3c9df7af..c8e9bbf418c20 100644 --- a/e2e/aws/redshift_test.go +++ b/e2e/aws/redshift_test.go @@ -27,7 +27,6 @@ import ( "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/redshift" - "github.com/jackc/pgx/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -96,7 +95,7 @@ func testRedshiftCluster(t *testing.T) { // eachother. labels := db.GetStaticLabels() labels[types.DatabaseAdminLabel] = "test_admin_" + randASCII(t, 6) - cluster.Process.GetAuthServer().UpdateDatabase(ctx, db) + err = cluster.Process.GetAuthServer().UpdateDatabase(ctx, db) require.NoError(t, err) adminUser := mustGetDBAdmin(t, db) @@ -213,7 +212,7 @@ func testRedshiftCluster(t *testing.T) { } } -func connectAsRedshiftClusterAdmin(t *testing.T, ctx context.Context, clusterID string) *pgx.Conn { +func connectAsRedshiftClusterAdmin(t *testing.T, ctx context.Context, clusterID string) *pgConn { t.Helper() info := getRedshiftAdminInfo(t, ctx, clusterID) const dbName = "dev" @@ -247,7 +246,7 @@ func getRedshiftAdminInfo(t *testing.T, ctx context.Context, clusterID string) d // provisionRedshiftAutoUsersAdmin provisions an admin user suitable for auto-user // provisioning. -func provisionRedshiftAutoUsersAdmin(t *testing.T, ctx context.Context, conn *pgx.Conn, adminUser string) { +func provisionRedshiftAutoUsersAdmin(t *testing.T, ctx context.Context, conn *pgConn, adminUser string) { t.Helper() // Don't cleanup the db admin after, because test runs would interfere // with each other. @@ -261,7 +260,7 @@ func provisionRedshiftAutoUsersAdmin(t *testing.T, ctx context.Context, conn *pg } } -func waitForRedshiftAutoUserDeactivate(t *testing.T, ctx context.Context, conn *pgx.Conn, user string) { +func waitForRedshiftAutoUserDeactivate(t *testing.T, ctx context.Context, conn *pgConn, user string) { t.Helper() require.EventuallyWithT(t, func(c *assert.CollectT) { // `Query` documents that it is always safe to attempt to read from the @@ -300,7 +299,7 @@ func waitForRedshiftAutoUserDeactivate(t *testing.T, ctx context.Context, conn * }, autoUserWaitDur, autoUserWaitStep, "waiting for auto user %q to be deactivated", user) } -func waitForRedshiftAutoUserDrop(t *testing.T, ctx context.Context, conn *pgx.Conn, user string) { +func waitForRedshiftAutoUserDrop(t *testing.T, ctx context.Context, conn *pgConn, user string) { t.Helper() require.EventuallyWithT(t, func(c *assert.CollectT) { // `Query` documents that it is always safe to attempt to read from the