Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor replica identity type and primary key column retrieval in Postgres #860

Merged
merged 3 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 65 additions & 22 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/jackc/pglogrepl"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/lib/pq/oid"
"golang.org/x/exp/maps"
)

Expand Down Expand Up @@ -77,6 +78,15 @@ const (
deleteJobMetadataSQL = "DELETE FROM %s.%s WHERE MIRROR_JOB_NAME=$1"
)

type ReplicaIdentityType rune

const (
ReplicaIdentityDefault ReplicaIdentityType = 'd'
ReplicaIdentityFull = 'f'
ReplicaIdentityIndex = 'i'
ReplicaIdentityNothing = 'n'
)

// getRelIDForTable returns the relation ID for a table.
func (c *PostgresConnector) getRelIDForTable(schemaTable *utils.SchemaTable) (uint32, error) {
var relID pgtype.Uint32
Expand All @@ -92,54 +102,87 @@ func (c *PostgresConnector) getRelIDForTable(schemaTable *utils.SchemaTable) (ui
}

// getReplicaIdentity returns the replica identity for a table.
func (c *PostgresConnector) isTableFullReplica(schemaTable *utils.SchemaTable) (bool, error) {
func (c *PostgresConnector) getReplicaIdentityType(schemaTable *utils.SchemaTable) (ReplicaIdentityType, error) {
relID, relIDErr := c.getRelIDForTable(schemaTable)
if relIDErr != nil {
return false, fmt.Errorf("failed to get relation id for table %s: %w", schemaTable, relIDErr)
return ReplicaIdentityDefault, fmt.Errorf("failed to get relation id for table %s: %w", schemaTable, relIDErr)
}

var replicaIdentity rune
err := c.pool.QueryRow(c.ctx,
`SELECT relreplident FROM pg_class WHERE oid = $1;`,
relID).Scan(&replicaIdentity)
if err != nil {
return false, fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, err)
return ReplicaIdentityDefault, fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, err)
}
return string(replicaIdentity) == "f", nil

return ReplicaIdentityType(replicaIdentity), nil
}

// getPrimaryKeyColumns for table returns the primary key column for a given table
// errors if there is no primary key column or if there is more than one primary key column.
func (c *PostgresConnector) getPrimaryKeyColumns(schemaTable *utils.SchemaTable) ([]string, error) {
// getPrimaryKeyColumns returns the primary key columns for a given table.
// Errors if there is no primary key column or if there is more than one primary key column.
func (c *PostgresConnector) getPrimaryKeyColumns(
replicaIdentity ReplicaIdentityType,
schemaTable *utils.SchemaTable,
) ([]string, error) {
relID, err := c.getRelIDForTable(schemaTable)
if err != nil {
return nil, fmt.Errorf("failed to get relation id for table %s: %w", schemaTable, err)
}

// Get the primary key column name
var pkCol pgtype.Text
pkCols := make([]string, 0)
if replicaIdentity == ReplicaIdentityIndex {
return c.getReplicaIdentityIndexColumns(relID, schemaTable)
}

// Find the primary key index OID
var pkIndexOID oid.Oid
err = c.pool.QueryRow(c.ctx,
`SELECT indexrelid FROM pg_index WHERE indrelid = $1 AND indisprimary`,
relID).Scan(&pkIndexOID)
if err != nil {
return nil, fmt.Errorf("error finding primary key index for table %s: %w", schemaTable, err)
}

return c.getColumnNamesForIndex(pkIndexOID)
}

// getReplicaIdentityIndexColumns returns the columns used in the replica identity index.
func (c *PostgresConnector) getReplicaIdentityIndexColumns(relID uint32, schemaTable *utils.SchemaTable) ([]string, error) {
var indexRelID oid.Oid
// Fetch the OID of the index used as the replica identity
err := c.pool.QueryRow(c.ctx,
`SELECT indexrelid FROM pg_index
WHERE indrelid = $1 AND indisreplident = true`,
relID).Scan(&indexRelID)
if err != nil {
return nil, fmt.Errorf("error finding replica identity index for table %s: %w", schemaTable, err)
}

return c.getColumnNamesForIndex(indexRelID)
}

// getColumnNamesForIndex returns the column names for a given index.
func (c *PostgresConnector) getColumnNamesForIndex(indexOID oid.Oid) ([]string, error) {
var col pgtype.Text
cols := make([]string, 0)
rows, err := c.pool.Query(c.ctx,
`SELECT a.attname FROM pg_index i
JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
WHERE i.indrelid = $1 AND i.indisprimary ORDER BY a.attname ASC`,
relID)
WHERE i.indexrelid = $1 ORDER BY a.attname ASC`,
indexOID)
if err != nil {
return nil, fmt.Errorf("error getting primary key column for table %s: %w", schemaTable, err)
return nil, fmt.Errorf("error getting columns for index %v: %w", indexOID, err)
}
defer rows.Close()
for {
if !rows.Next() {
break
}
err = rows.Scan(&pkCol)

for rows.Next() {
err = rows.Scan(&col)
if err != nil {
return nil, fmt.Errorf("error scanning primary key column for table %s: %w", schemaTable, err)
return nil, fmt.Errorf("error scanning column for index %v: %w", indexOID, err)
}
pkCols = append(pkCols, pkCol.String)
cols = append(cols, col.String)
}

return pkCols, nil
return cols, nil
}

func (c *PostgresConnector) tableExists(schemaTable *utils.SchemaTable) (bool, error) {
Expand Down
12 changes: 6 additions & 6 deletions flow/connectors/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -558,12 +558,12 @@ func (c *PostgresConnector) getTableSchemaForTable(
return nil, err
}

isFullReplica, replErr := c.isTableFullReplica(schemaTable)
replicaIdentityType, replErr := c.getReplicaIdentityType(schemaTable)
if replErr != nil {
return nil, fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, replErr)
}

pKeyCols, err := c.getPrimaryKeyColumns(schemaTable)
pKeyCols, err := c.getPrimaryKeyColumns(replicaIdentityType, schemaTable)
if err != nil {
return nil, fmt.Errorf("error getting primary key column for table %s: %w", schemaTable, err)
}
Expand All @@ -581,7 +581,7 @@ func (c *PostgresConnector) getTableSchemaForTable(
TableIdentifier: tableName,
Columns: make(map[string]string),
PrimaryKeyColumns: pKeyCols,
IsReplicaIdentityFull: isFullReplica,
IsReplicaIdentityFull: replicaIdentityType == ReplicaIdentityFull,
}

for _, fieldDescription := range rows.FieldDescriptions() {
Expand Down Expand Up @@ -731,18 +731,18 @@ func (c *PostgresConnector) EnsurePullability(req *protos.EnsurePullabilityBatch
return nil, err
}

isFullReplica, replErr := c.isTableFullReplica(schemaTable)
replicaIdentity, replErr := c.getReplicaIdentityType(schemaTable)
if replErr != nil {
return nil, fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, replErr)
}

pKeyCols, err := c.getPrimaryKeyColumns(schemaTable)
pKeyCols, err := c.getPrimaryKeyColumns(replicaIdentity, schemaTable)
if err != nil {
return nil, fmt.Errorf("error getting primary key column for table %s: %w", schemaTable, err)
}

// we only allow no primary key if the table has REPLICA IDENTITY FULL
if len(pKeyCols) == 0 && !isFullReplica {
if len(pKeyCols) == 0 && !(replicaIdentity == ReplicaIdentityFull) {
return nil, fmt.Errorf("table %s has no primary keys and does not have REPLICA IDENTITY FULL", schemaTable)
}

Expand Down
66 changes: 66 additions & 0 deletions flow/e2e/snowflake/peer_flow_sf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,72 @@ func (s PeerFlowE2ETestSuiteSF) Test_Complete_Simple_Flow_SF() {
env.AssertExpectations(s.t)
}

func (s PeerFlowE2ETestSuiteSF) Test_Flow_ReplicaIdentity_Index_No_Pkey() {
env := e2e.NewTemporalTestWorkflowEnvironment()
e2e.RegisterWorkflowsAndActivities(env, s.t)

srcTableName := s.attachSchemaSuffix("test_replica_identity_no_pkey")
dstTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, "test_replica_identity_no_pkey")

// Create a table without a primary key and create a named unique index
_, err := s.pool.Exec(context.Background(), fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
id SERIAL,
key TEXT NOT NULL,
value TEXT NOT NULL
);
CREATE UNIQUE INDEX unique_idx_on_id_key ON %s (id, key);
ALTER TABLE %s REPLICA IDENTITY USING INDEX unique_idx_on_id_key;
`, srcTableName, srcTableName, srcTableName))
require.NoError(s.t, err)

connectionGen := e2e.FlowConnectionGenerationConfig{
FlowJobName: s.attachSuffix("test_simple_flow"),
TableNameMapping: map[string]string{srcTableName: dstTableName},
PostgresPort: e2e.PostgresPort,
Destination: s.sfHelper.Peer,
}

flowConnConfig, err := connectionGen.GenerateFlowConnectionConfigs()
require.NoError(s.t, err)

limits := peerflow.CDCFlowLimits{
ExitAfterRecords: 20,
MaxBatchSize: 100,
}

// in a separate goroutine, wait for PeerFlowStatusQuery to finish setup
// and then insert 20 rows into the source table
go func() {
e2e.SetupCDCFlowStatusQuery(env, connectionGen)
// insert 20 rows into the source table
for i := 0; i < 20; i++ {
testKey := fmt.Sprintf("test_key_%d", i)
testValue := fmt.Sprintf("test_value_%d", i)
_, err = s.pool.Exec(context.Background(), fmt.Sprintf(`
INSERT INTO %s (id, key, value) VALUES ($1, $2, $3)
`, srcTableName), i, testKey, testValue)
require.NoError(s.t, err)
}
fmt.Println("Inserted 20 rows into the source table")
}()

env.ExecuteWorkflow(peerflow.CDCFlowWorkflowWithConfig, flowConnConfig, &limits, nil)

// Verify workflow completes without error
s.True(env.IsWorkflowCompleted())
err = env.GetWorkflowError()

// allow only continue as new error
require.Contains(s.t, err.Error(), "continue as new")

count, err := s.sfHelper.CountRows("test_replica_identity_no_pkey")
require.NoError(s.t, err)
s.Equal(20, count)

env.AssertExpectations(s.t)
}

func (s PeerFlowE2ETestSuiteSF) Test_Invalid_Geo_SF_Avro_CDC() {
env := e2e.NewTemporalTestWorkflowEnvironment()
e2e.RegisterWorkflowsAndActivities(env, s.t)
Expand Down
Loading