From 592acc8a20cf455a321942d923e584980eb3a05c Mon Sep 17 00:00:00 2001 From: Kaushik Iska Date: Wed, 20 Dec 2023 08:31:44 -0500 Subject: [PATCH 1/2] Refactor replica identity type and primary key column retrieval in PostgreSQL connector This is to support `USING INDEX` replica identity types in postgres. Treating the index columns as primary key columns for now as it is the fastest way for us to support replica identity index. --- flow/connectors/postgres/client.go | 99 +++++++++++++++++++------ flow/connectors/postgres/postgres.go | 12 +-- flow/e2e/snowflake/peer_flow_sf_test.go | 66 +++++++++++++++++ 3 files changed, 149 insertions(+), 28 deletions(-) diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index 9aa05131c7..e9d8ad4910 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -15,6 +15,7 @@ import ( "github.com/jackc/pglogrepl" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" + "github.com/lib/pq/oid" "golang.org/x/exp/maps" ) @@ -77,6 +78,15 @@ const ( deleteJobMetadataSQL = "DELETE FROM %s.%s WHERE MIRROR_JOB_NAME=$1" ) +type ReplicaIdentityType string + +const ( + ReplicaIdentityDefault ReplicaIdentityType = "d" + ReplicaIdentityFull = "f" + ReplicaIdentityIndex = "i" + ReplicaIdentityNothing = "n" +) + // getRelIDForTable returns the relation ID for a table. func (c *PostgresConnector) getRelIDForTable(schemaTable *utils.SchemaTable) (uint32, error) { var relID pgtype.Uint32 @@ -92,10 +102,10 @@ func (c *PostgresConnector) getRelIDForTable(schemaTable *utils.SchemaTable) (ui } // getReplicaIdentity returns the replica identity for a table. -func (c *PostgresConnector) isTableFullReplica(schemaTable *utils.SchemaTable) (bool, error) { +func (c *PostgresConnector) getReplicaIdentityType(schemaTable *utils.SchemaTable) (ReplicaIdentityType, error) { relID, relIDErr := c.getRelIDForTable(schemaTable) if relIDErr != nil { - return false, fmt.Errorf("failed to get relation id for table %s: %w", schemaTable, relIDErr) + return ReplicaIdentityDefault, fmt.Errorf("failed to get relation id for table %s: %w", schemaTable, relIDErr) } var replicaIdentity rune @@ -103,43 +113,88 @@ func (c *PostgresConnector) isTableFullReplica(schemaTable *utils.SchemaTable) ( `SELECT relreplident FROM pg_class WHERE oid = $1;`, relID).Scan(&replicaIdentity) if err != nil { - return false, fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, err) + return ReplicaIdentityDefault, fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, err) + } + + var replicaIdentityMap = map[rune]ReplicaIdentityType{ + 'd': ReplicaIdentityDefault, + 'f': ReplicaIdentityFull, + 'i': ReplicaIdentityIndex, + 'n': ReplicaIdentityNothing, + } + + replicaType, ok := replicaIdentityMap[replicaIdentity] + if !ok { + return ReplicaIdentityDefault, fmt.Errorf("unknown replica identity type %c for table %s", replicaIdentity, schemaTable) } - return string(replicaIdentity) == "f", nil + + return replicaType, nil } -// getPrimaryKeyColumns for table returns the primary key column for a given table -// errors if there is no primary key column or if there is more than one primary key column. -func (c *PostgresConnector) getPrimaryKeyColumns(schemaTable *utils.SchemaTable) ([]string, error) { +// getPrimaryKeyColumns returns the primary key columns for a given table. +// Errors if there is no primary key column or if there is more than one primary key column. +func (c *PostgresConnector) getPrimaryKeyColumns( + replicaIdentity ReplicaIdentityType, + schemaTable *utils.SchemaTable, +) ([]string, error) { relID, err := c.getRelIDForTable(schemaTable) if err != nil { return nil, fmt.Errorf("failed to get relation id for table %s: %w", schemaTable, err) } - // Get the primary key column name - var pkCol pgtype.Text - pkCols := make([]string, 0) + if replicaIdentity == ReplicaIdentityIndex { + return c.getReplicaIdentityIndexColumns(relID, schemaTable) + } + + // Find the primary key index OID + var pkIndexOID oid.Oid + err = c.pool.QueryRow(c.ctx, + `SELECT indexrelid FROM pg_index WHERE indrelid = $1 AND indisprimary`, + relID).Scan(&pkIndexOID) + if err != nil { + return nil, fmt.Errorf("error finding primary key index for table %s: %w", schemaTable, err) + } + + return c.getColumnNamesForIndex(pkIndexOID) +} + +// getReplicaIdentityIndexColumns returns the columns used in the replica identity index. +func (c *PostgresConnector) getReplicaIdentityIndexColumns(relID uint32, schemaTable *utils.SchemaTable) ([]string, error) { + var indexRelID oid.Oid + // Fetch the OID of the index used as the replica identity + err := c.pool.QueryRow(c.ctx, + `SELECT indexrelid FROM pg_index + WHERE indrelid = $1 AND indisreplident = true`, + relID).Scan(&indexRelID) + if err != nil { + return nil, fmt.Errorf("error finding replica identity index for table %s: %w", schemaTable, err) + } + + return c.getColumnNamesForIndex(indexRelID) +} + +// getColumnNamesForIndex returns the column names for a given index. +func (c *PostgresConnector) getColumnNamesForIndex(indexOID oid.Oid) ([]string, error) { + var col pgtype.Text + cols := make([]string, 0) rows, err := c.pool.Query(c.ctx, `SELECT a.attname FROM pg_index i JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) - WHERE i.indrelid = $1 AND i.indisprimary ORDER BY a.attname ASC`, - relID) + WHERE i.indexrelid = $1 ORDER BY a.attname ASC`, + indexOID) if err != nil { - return nil, fmt.Errorf("error getting primary key column for table %s: %w", schemaTable, err) + return nil, fmt.Errorf("error getting columns for index %v: %w", indexOID, err) } defer rows.Close() - for { - if !rows.Next() { - break - } - err = rows.Scan(&pkCol) + + for rows.Next() { + err = rows.Scan(&col) if err != nil { - return nil, fmt.Errorf("error scanning primary key column for table %s: %w", schemaTable, err) + return nil, fmt.Errorf("error scanning column for index %v: %w", indexOID, err) } - pkCols = append(pkCols, pkCol.String) + cols = append(cols, col.String) } - - return pkCols, nil + return cols, nil } func (c *PostgresConnector) tableExists(schemaTable *utils.SchemaTable) (bool, error) { diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 82426b3e3f..cad914706f 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -558,12 +558,12 @@ func (c *PostgresConnector) getTableSchemaForTable( return nil, err } - isFullReplica, replErr := c.isTableFullReplica(schemaTable) + replicaIdentityType, replErr := c.getReplicaIdentityType(schemaTable) if replErr != nil { return nil, fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, replErr) } - pKeyCols, err := c.getPrimaryKeyColumns(schemaTable) + pKeyCols, err := c.getPrimaryKeyColumns(replicaIdentityType, schemaTable) if err != nil { return nil, fmt.Errorf("error getting primary key column for table %s: %w", schemaTable, err) } @@ -581,7 +581,7 @@ func (c *PostgresConnector) getTableSchemaForTable( TableIdentifier: tableName, Columns: make(map[string]string), PrimaryKeyColumns: pKeyCols, - IsReplicaIdentityFull: isFullReplica, + IsReplicaIdentityFull: replicaIdentityType == ReplicaIdentityFull, } for _, fieldDescription := range rows.FieldDescriptions() { @@ -731,18 +731,18 @@ func (c *PostgresConnector) EnsurePullability(req *protos.EnsurePullabilityBatch return nil, err } - isFullReplica, replErr := c.isTableFullReplica(schemaTable) + replicaIdentity, replErr := c.getReplicaIdentityType(schemaTable) if replErr != nil { return nil, fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, replErr) } - pKeyCols, err := c.getPrimaryKeyColumns(schemaTable) + pKeyCols, err := c.getPrimaryKeyColumns(replicaIdentity, schemaTable) if err != nil { return nil, fmt.Errorf("error getting primary key column for table %s: %w", schemaTable, err) } // we only allow no primary key if the table has REPLICA IDENTITY FULL - if len(pKeyCols) == 0 && !isFullReplica { + if len(pKeyCols) == 0 && !(replicaIdentity == ReplicaIdentityFull) { return nil, fmt.Errorf("table %s has no primary keys and does not have REPLICA IDENTITY FULL", schemaTable) } diff --git a/flow/e2e/snowflake/peer_flow_sf_test.go b/flow/e2e/snowflake/peer_flow_sf_test.go index d4ff50751f..8d521dbb72 100644 --- a/flow/e2e/snowflake/peer_flow_sf_test.go +++ b/flow/e2e/snowflake/peer_flow_sf_test.go @@ -198,6 +198,72 @@ func (s PeerFlowE2ETestSuiteSF) Test_Complete_Simple_Flow_SF() { env.AssertExpectations(s.t) } +func (s PeerFlowE2ETestSuiteSF) Test_Flow_ReplicaIdentity_Index_No_Pkey() { + env := e2e.NewTemporalTestWorkflowEnvironment() + e2e.RegisterWorkflowsAndActivities(env, s.t) + + srcTableName := s.attachSchemaSuffix("test_replica_identity_no_pkey") + dstTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, "test_replica_identity_no_pkey") + + // Create a table without a primary key and create a named unique index + _, err := s.pool.Exec(context.Background(), fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + id SERIAL, + key TEXT NOT NULL, + value TEXT NOT NULL + ); + CREATE UNIQUE INDEX unique_idx_on_id_key ON %s (id, key); + ALTER TABLE %s REPLICA IDENTITY USING INDEX unique_idx_on_id_key; + `, srcTableName, srcTableName, srcTableName)) + require.NoError(s.t, err) + + connectionGen := e2e.FlowConnectionGenerationConfig{ + FlowJobName: s.attachSuffix("test_simple_flow"), + TableNameMapping: map[string]string{srcTableName: dstTableName}, + PostgresPort: e2e.PostgresPort, + Destination: s.sfHelper.Peer, + } + + flowConnConfig, err := connectionGen.GenerateFlowConnectionConfigs() + require.NoError(s.t, err) + + limits := peerflow.CDCFlowLimits{ + ExitAfterRecords: 20, + MaxBatchSize: 100, + } + + // in a separate goroutine, wait for PeerFlowStatusQuery to finish setup + // and then insert 20 rows into the source table + go func() { + e2e.SetupCDCFlowStatusQuery(env, connectionGen) + // insert 20 rows into the source table + for i := 0; i < 20; i++ { + testKey := fmt.Sprintf("test_key_%d", i) + testValue := fmt.Sprintf("test_value_%d", i) + _, err = s.pool.Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s (id, key, value) VALUES ($1, $2, $3) + `, srcTableName), i, testKey, testValue) + require.NoError(s.t, err) + } + fmt.Println("Inserted 20 rows into the source table") + }() + + env.ExecuteWorkflow(peerflow.CDCFlowWorkflowWithConfig, flowConnConfig, &limits, nil) + + // Verify workflow completes without error + s.True(env.IsWorkflowCompleted()) + err = env.GetWorkflowError() + + // allow only continue as new error + require.Contains(s.t, err.Error(), "continue as new") + + count, err := s.sfHelper.CountRows("test_replica_identity_no_pkey") + require.NoError(s.t, err) + s.Equal(20, count) + + env.AssertExpectations(s.t) +} + func (s PeerFlowE2ETestSuiteSF) Test_Invalid_Geo_SF_Avro_CDC() { env := e2e.NewTemporalTestWorkflowEnvironment() e2e.RegisterWorkflowsAndActivities(env, s.t) From 4edc10645715e47e975ab5eec60209cad916dc5f Mon Sep 17 00:00:00 2001 From: Kaushik Iska Date: Wed, 20 Dec 2023 09:24:24 -0500 Subject: [PATCH 2/2] make replica identity type a rune --- flow/connectors/postgres/client.go | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index e9d8ad4910..5e00ca8a7f 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -78,13 +78,13 @@ const ( deleteJobMetadataSQL = "DELETE FROM %s.%s WHERE MIRROR_JOB_NAME=$1" ) -type ReplicaIdentityType string +type ReplicaIdentityType rune const ( - ReplicaIdentityDefault ReplicaIdentityType = "d" - ReplicaIdentityFull = "f" - ReplicaIdentityIndex = "i" - ReplicaIdentityNothing = "n" + ReplicaIdentityDefault ReplicaIdentityType = 'd' + ReplicaIdentityFull = 'f' + ReplicaIdentityIndex = 'i' + ReplicaIdentityNothing = 'n' ) // getRelIDForTable returns the relation ID for a table. @@ -116,19 +116,7 @@ func (c *PostgresConnector) getReplicaIdentityType(schemaTable *utils.SchemaTabl return ReplicaIdentityDefault, fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, err) } - var replicaIdentityMap = map[rune]ReplicaIdentityType{ - 'd': ReplicaIdentityDefault, - 'f': ReplicaIdentityFull, - 'i': ReplicaIdentityIndex, - 'n': ReplicaIdentityNothing, - } - - replicaType, ok := replicaIdentityMap[replicaIdentity] - if !ok { - return ReplicaIdentityDefault, fmt.Errorf("unknown replica identity type %c for table %s", replicaIdentity, schemaTable) - } - - return replicaType, nil + return ReplicaIdentityType(replicaIdentity), nil } // getPrimaryKeyColumns returns the primary key columns for a given table.