diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go
index 3d8e5083f9..9bb4f06fc2 100644
--- a/flow/activities/flowable.go
+++ b/flow/activities/flowable.go
@@ -489,7 +489,17 @@ func (a *FlowableActivity) GetQRepPartitions(ctx context.Context,
 		return "getting partitions for job"
 	})
 	defer shutdown()
-	partitions, err := srcConn.GetQRepPartitions(ctx, config, last)
+
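+	// load the snapshot exported for the parent mirror from the catalog,
+	// so partition computation sees the same consistent view of the table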
"[replicateXminPartition] "+config.FlowJobName, err) + return fmt.Errorf("[replicateXminPartition] failed to LoadSnapshotNameFromCatalog: %w", err) + } + } + srcConn, err := connectors.GetByNameAs[*connpostgres.PostgresConnector](ctx, config.Env, a.CatalogPool, config.SourceName) if err != nil { return fmt.Errorf("failed to get qrep source connector: %w", err) @@ -509,10 +532,10 @@ func replicateXminPartition[TRead any, TWrite any, TSync connectors.QRepSyncConn var pullErr error var numRecords int - numRecords, currentSnapshotXmin, pullErr = pullRecords(srcConn, ctx, config, partition, stream) + numRecords, currentSnapshotXmin, pullErr = pullRecords(srcConn, ctx, config, partition, snapshotName, stream) if pullErr != nil { a.Alerter.LogFlowError(ctx, config.FlowJobName, pullErr) - logger.Warn(fmt.Sprintf("[xmin] failed to pull recordS: %v", pullErr)) + logger.Warn(fmt.Sprintf("[xmin] failed to pull records: %v", pullErr)) return pullErr } diff --git a/flow/activities/snapshot_activity.go b/flow/activities/snapshot_activity.go index 01c9e748e6..a789c3e9ae 100644 --- a/flow/activities/snapshot_activity.go +++ b/flow/activities/snapshot_activity.go @@ -8,6 +8,7 @@ import ( "sync" "time" + "github.com/jackc/pgerrcode" "github.com/jackc/pgx/v5/pgxpool" "go.temporal.io/sdk/activity" @@ -19,21 +20,14 @@ import ( ) type SlotSnapshotState struct { - connector connectors.CDCPullConnector - signal connpostgres.SlotSignal - snapshotName string -} - -type TxSnapshotState struct { - SnapshotName string - SupportsTIDScans bool + connector connectors.CDCPullConnector + signal connpostgres.SlotSignal } type SnapshotActivity struct { Alerter *alerting.Alerter CatalogPool *pgxpool.Pool SlotSnapshotStates map[string]SlotSnapshotState - TxSnapshotStates map[string]TxSnapshotState SnapshotStatesMutex sync.Mutex } @@ -55,7 +49,7 @@ func (a *SnapshotActivity) CloseSlotKeepAlive(ctx context.Context, flowJobName s func (a *SnapshotActivity) SetupReplication( ctx context.Context, config *protos.SetupReplicationInput, -) (*protos.SetupReplicationOutput, error) { +) error { ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName) logger := activity.GetLogger(ctx) @@ -65,9 +59,9 @@ func (a *SnapshotActivity) SetupReplication( if err != nil { if errors.Is(err, errors.ErrUnsupported) { logger.Info("setup replication is no-op for non-postgres source") - return nil, nil + return nil } - return nil, fmt.Errorf("failed to get connector: %w", err) + return fmt.Errorf("failed to get connector: %w", err) } closeConnectionForError := func(err error) { @@ -84,46 +78,64 @@ func (a *SnapshotActivity) SetupReplication( if slotInfo.Err != nil { closeConnectionForError(slotInfo.Err) - return nil, fmt.Errorf("slot error: %w", slotInfo.Err) + return fmt.Errorf("slot error: %w", slotInfo.Err) } else { logger.Info("slot created", slog.String("SlotName", slotInfo.SlotName)) } - a.SnapshotStatesMutex.Lock() - defer a.SnapshotStatesMutex.Unlock() + for { + var slotName string + if err := a.CatalogPool.QueryRow( + ctx, + "select slot_name from snapshot_names where flow_name = $1", + config.FlowJobName, + ).Scan(&slotName); err == nil && slotName != "" { + if err := conn.ExecuteCommand( + ctx, + "select pg_drop_replication_slot($1)", + slotName, + ); err != nil && !shared.IsSQLStateError(err, pgerrcode.UndefinedObject) { + if shared.IsSQLStateError(err, pgerrcode.ObjectInUse) { + a.Alerter.LogFlowError(ctx, "[SetupReplication] "+config.FlowJobName, err) + time.Sleep(time.Second * 15) + continue + } + return fmt.Errorf("failed to 
+	for {
+		var slotName string
+		if err := a.CatalogPool.QueryRow(
+			ctx,
+			"select slot_name from snapshot_names where flow_name = $1",
+			config.FlowJobName,
+		).Scan(&slotName); err == nil && slotName != "" {
+			if err := conn.ExecuteCommand(
+				ctx,
+				"select pg_drop_replication_slot($1)",
+				slotName,
+			); err != nil && !shared.IsSQLStateError(err, pgerrcode.UndefinedObject) {
+				if shared.IsSQLStateError(err, pgerrcode.ObjectInUse) {
+					a.Alerter.LogFlowError(ctx, "[SetupReplication] "+config.FlowJobName, err)
+					time.Sleep(time.Second * 15)
+					continue
+				}
+				return fmt.Errorf("failed to drop slot from previous run: %w", err)
+			}
+		}
+		break
+	}
+
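+	// record slot/snapshot metadata in the catalog so QRep activities can
+	// load it from there instead of from state held by this worker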
w.RegisterActivity(&activities.SnapshotActivity{ SlotSnapshotStates: make(map[string]activities.SlotSnapshotState), - TxSnapshotStates: make(map[string]activities.TxSnapshotState), Alerter: alerting.NewAlerter(context.Background(), conn), CatalogPool: conn, }) diff --git a/flow/connectors/core.go b/flow/connectors/core.go index 75d52506b5..9bc6c5c9d1 100644 --- a/flow/connectors/core.go +++ b/flow/connectors/core.go @@ -62,7 +62,7 @@ type CDCPullConnectorCore interface { // For InitialSnapshotOnly correctness without replication slot // `any` is for returning transaction if necessary - ExportTxSnapshot(context.Context) (*protos.ExportTxSnapshotOutput, any, error) + ExportTxSnapshot(ctx context.Context, flowJobName string) (any, error) // `any` from ExportSnapshot passed here when done, allowing transaction to commit FinishExport(any) error @@ -201,21 +201,22 @@ type QRepPullConnectorCore interface { Connector // GetQRepPartitions returns the partitions for a given table that haven't been synced yet. - GetQRepPartitions(ctx context.Context, config *protos.QRepConfig, last *protos.QRepPartition) ([]*protos.QRepPartition, error) + GetQRepPartitions(ctx context.Context, config *protos.QRepConfig, last *protos.QRepPartition, + snapshotName string) ([]*protos.QRepPartition, error) } type QRepPullConnector interface { QRepPullConnectorCore // PullQRepRecords returns the records for a given partition. - PullQRepRecords(context.Context, *protos.QRepConfig, *protos.QRepPartition, *model.QRecordStream) (int, error) + PullQRepRecords(context.Context, *protos.QRepConfig, *protos.QRepPartition, string, *model.QRecordStream) (int, error) } type QRepPullPgConnector interface { QRepPullConnectorCore // PullPgQRepRecords returns the records for a given partition. - PullPgQRepRecords(context.Context, *protos.QRepConfig, *protos.QRepPartition, connpostgres.PgCopyWriter) (int, error) + PullPgQRepRecords(context.Context, *protos.QRepConfig, *protos.QRepPartition, string, connpostgres.PgCopyWriter) (int, error) } type QRepSyncConnectorCore interface { diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index 0d163f173f..c164b5a62e 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -668,8 +668,8 @@ func (c *PostgresConnector) checkIfTableExistsWithTx( return result.Bool, nil } -func (c *PostgresConnector) ExecuteCommand(ctx context.Context, command string) error { - _, err := c.conn.Exec(ctx, command) +func (c *PostgresConnector) ExecuteCommand(ctx context.Context, command string, args ...any) error { + _, err := c.conn.Exec(ctx, command, args...) 
+	pool, err := peerdbenv.GetCatalogConnectionPoolFromEnv(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	if _, err := pool.Exec(ctx,
+		`insert into snapshot_names (flow_name, slot_name, snapshot_name, supports_tid_scan) values ($1, '', $2, $3)
+		on conflict (flow_name) do update set slot_name = '', snapshot_name = $2, supports_tid_scan = $3`,
+		flowJobName, snapshotName, pgversion >= shared.POSTGRES_13,
+	); err != nil {
+		return nil, err
+	}
 
+	txNeedsRollback = false
+	return tx, nil
 }
 
 func (c *PostgresConnector) FinishExport(tx any) error {
+	// we could delete the snapshot_names row here, but that risks a race
+	// if FinishExport runs after a retry has already overwritten the row
 	pgtx := tx.(pgx.Tx)
 	timeout, cancel := context.WithTimeout(context.Background(), time.Minute)
 	defer cancel()
diff --git a/flow/connectors/postgres/qrep.go b/flow/connectors/postgres/qrep.go
index 48864b0c38..a5361cee34 100644
--- a/flow/connectors/postgres/qrep.go
+++ b/flow/connectors/postgres/qrep.go
@@ -41,6 +41,7 @@ func (c *PostgresConnector) GetQRepPartitions(
 	ctx context.Context,
 	config *protos.QRepConfig,
 	last *protos.QRepPartition,
+	snapshotName string,
 ) ([]*protos.QRepPartition, error) {
 	if config.WatermarkColumn == "" {
 		// if no watermark column is specified, return a single partition
@@ -62,7 +63,7 @@ func (c *PostgresConnector) GetQRepPartitions(
 	}
 	defer shared.RollbackTx(getPartitionsTx, c.logger)
 
-	if err := c.setTransactionSnapshot(ctx, getPartitionsTx, config.SnapshotName); err != nil {
+	if err := c.setTransactionSnapshot(ctx, getPartitionsTx, snapshotName); err != nil {
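+	// the snapshot_names row is written by SetupReplication / ExportTxSnapshot,
+	// which may still be in flight when a clone starts, so callers poll for it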
 		return nil, fmt.Errorf("failed to set transaction snapshot: %w", err)
 	}
@@ -310,9 +311,10 @@ func (c *PostgresConnector) PullQRepRecords(
 	ctx context.Context,
 	config *protos.QRepConfig,
 	partition *protos.QRepPartition,
+	snapshotName string,
 	stream *model.QRecordStream,
 ) (int, error) {
-	return corePullQRepRecords(c, ctx, config, partition, &RecordStreamSink{
+	return corePullQRepRecords(c, ctx, config, partition, snapshotName, &RecordStreamSink{
 		QRecordStream: stream,
 	})
 }
@@ -321,9 +323,10 @@ func (c *PostgresConnector) PullPgQRepRecords(
 	ctx context.Context,
 	config *protos.QRepConfig,
 	partition *protos.QRepPartition,
+	snapshotName string,
 	stream PgCopyWriter,
 ) (int, error) {
-	return corePullQRepRecords(c, ctx, config, partition, stream)
+	return corePullQRepRecords(c, ctx, config, partition, snapshotName, stream)
 }
 
 func corePullQRepRecords(
@@ -331,12 +334,13 @@ func corePullQRepRecords(
 	ctx context.Context,
 	config *protos.QRepConfig,
 	partition *protos.QRepPartition,
+	snapshotName string,
 	sink QRepPullSink,
 ) (int, error) {
 	partitionIdLog := slog.String(string(shared.PartitionIDKey), partition.PartitionId)
 
 	if partition.FullTablePartition {
 		c.logger.Info("pulling full table partition", partitionIdLog)
-		executor := c.NewQRepQueryExecutorSnapshot(config.SnapshotName, config.FlowJobName, partition.PartitionId)
+		executor := c.NewQRepQueryExecutorSnapshot(snapshotName, config.FlowJobName, partition.PartitionId)
 		_, err := executor.ExecuteQueryIntoSink(ctx, sink, config.Query)
 		return 0, err
 	}
@@ -375,7 +379,7 @@ func corePullQRepRecords(
 		return 0, err
 	}
 
-	executor := c.NewQRepQueryExecutorSnapshot(config.SnapshotName, config.FlowJobName, partition.PartitionId)
+	executor := c.NewQRepQueryExecutorSnapshot(snapshotName, config.FlowJobName, partition.PartitionId)
 
 	numRecords, err := executor.ExecuteQueryIntoSink(ctx, sink, query, rangeStart, rangeEnd)
 	if err != nil {
@@ -645,9 +649,10 @@ func (c *PostgresConnector) PullXminRecordStream(
 	ctx context.Context,
 	config *protos.QRepConfig,
 	partition *protos.QRepPartition,
+	snapshotName string,
 	stream *model.QRecordStream,
 ) (int, int64, error) {
-	return pullXminRecordStream(c, ctx, config, partition, RecordStreamSink{
+	return pullXminRecordStream(c, ctx, config, partition, snapshotName, RecordStreamSink{
 		QRecordStream: stream,
 	})
 }
@@ -656,9 +661,10 @@ func (c *PostgresConnector) PullXminPgRecordStream(
 	ctx context.Context,
 	config *protos.QRepConfig,
 	partition *protos.QRepPartition,
+	snapshotName string,
 	pipe PgCopyWriter,
 ) (int, int64, error) {
-	return pullXminRecordStream(c, ctx, config, partition, pipe)
+	return pullXminRecordStream(c, ctx, config, partition, snapshotName, pipe)
 }
 
 func pullXminRecordStream(
@@ -666,6 +672,7 @@ func pullXminRecordStream(
 	ctx context.Context,
 	config *protos.QRepConfig,
 	partition *protos.QRepPartition,
+	snapshotName string,
 	sink QRepPullSink,
 ) (int, int64, error) {
 	query := config.Query
@@ -675,7 +682,7 @@ func pullXminRecordStream(
 		queryArgs = []interface{}{strconv.FormatInt(partition.Range.Range.(*protos.PartitionRange_IntRange).IntRange.Start&0xffffffff, 10)}
 	}
 
-	executor := c.NewQRepQueryExecutorSnapshot(config.SnapshotName, config.FlowJobName, partition.PartitionId)
+	executor := c.NewQRepQueryExecutorSnapshot(snapshotName, config.FlowJobName, partition.PartitionId)
 
 	numRecords, currentSnapshotXmin, err := executor.ExecuteQueryIntoSinkGettingCurrentSnapshotXmin(
 		ctx,
diff --git a/flow/connectors/postgres/qrep_partition_test.go b/flow/connectors/postgres/qrep_partition_test.go
index 0249b75fc1..d4b6c94344 100644
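+	// full-table partitions and ranged partitions both execute against the
+	// snapshot passed down from the activity, keeping the clone consistent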
--- a/flow/connectors/postgres/qrep_partition_test.go
+++ b/flow/connectors/postgres/qrep_partition_test.go
@@ -176,7 +176,7 @@ func TestGetQRepPartitions(t *testing.T) {
 				logger: log.NewStructuredLogger(slog.With(slog.String(string(shared.FlowNameKey), "testGetQRepPartitions"))),
 			}
 
-			got, err := c.GetQRepPartitions(context.Background(), tc.config, tc.last)
+			got, err := c.GetQRepPartitions(context.Background(), tc.config, tc.last, "")
 			if (err != nil) != tc.wantErr {
 				t.Fatalf("GetQRepPartitions() error = %v, wantErr %v", err, tc.wantErr)
 			}
diff --git a/flow/connectors/sqlserver/qrep.go b/flow/connectors/sqlserver/qrep.go
index 2dcc7ac1c5..7eeffca6a0 100644
--- a/flow/connectors/sqlserver/qrep.go
+++ b/flow/connectors/sqlserver/qrep.go
@@ -18,7 +18,7 @@ import (
 )
 
 func (c *SQLServerConnector) GetQRepPartitions(
-	ctx context.Context, config *protos.QRepConfig, last *protos.QRepPartition,
+	ctx context.Context, config *protos.QRepConfig, last *protos.QRepPartition, snapshotName string,
 ) ([]*protos.QRepPartition, error) {
 	if config.WatermarkTable == "" {
 		c.logger.Info("watermark table is empty, doing full table refresh")
@@ -162,6 +162,7 @@ func (c *SQLServerConnector) PullQRepRecords(
 	ctx context.Context,
 	config *protos.QRepConfig,
 	partition *protos.QRepPartition,
+	snapshotName string,
 	stream *model.QRecordStream,
 ) (int, error) {
 	// Build the query to pull records within the range from the source table
diff --git a/flow/shared/postgres.go b/flow/shared/postgres.go
index be3cf7d07d..c957e921a5 100644
--- a/flow/shared/postgres.go
+++ b/flow/shared/postgres.go
@@ -127,6 +127,30 @@ func UpdateCDCConfigInCatalog(ctx context.Context, pool *pgxpool.Pool,
 	return nil
 }
 
+func LoadSnapshotNameFromCatalog(
+	ctx context.Context,
+	pool *pgxpool.Pool,
+	flowName string,
+) (string, string, bool, error) {
+	for {
+		var slotName string
+		var snapshotName string
+		var supportsTidScan bool
+		if err := pool.QueryRow(
+			ctx,
+			"select slot_name, snapshot_name, supports_tid_scan from snapshot_names where flow_name = $1",
+			flowName,
+		).Scan(&slotName, &snapshotName, &supportsTidScan); err != nil {
+			if err == pgx.ErrNoRows {
+				time.Sleep(time.Second)
+				continue
+			}
+			return "", "", false, err
+		}
+		return slotName, snapshotName, supportsTidScan, nil
+	}
+}
+
 func LoadTableSchemaFromCatalog(
 	ctx context.Context,
 	pool *pgxpool.Pool,
diff --git a/flow/workflows/snapshot_flow.go b/flow/workflows/snapshot_flow.go
index c8b6a3fd29..0dd1032b95 100644
--- a/flow/workflows/snapshot_flow.go
+++ b/flow/workflows/snapshot_flow.go
@@ -11,7 +11,6 @@ import (
 	"go.temporal.io/sdk/temporal"
 	"go.temporal.io/sdk/workflow"
 
-	"github.com/PeerDB-io/peer-flow/activities"
 	connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
 	"github.com/PeerDB-io/peer-flow/connectors/utils"
 	"github.com/PeerDB-io/peer-flow/generated/protos"
@@ -19,23 +18,13 @@ import (
 	"github.com/PeerDB-io/peer-flow/shared"
 )
 
-type snapshotType int8
-
-const (
-	SNAPSHOT_TYPE_UNKNOWN snapshotType = iota
-	SNAPSHOT_TYPE_SLOT
-	SNAPSHOT_TYPE_TX
-)
-
 type SnapshotFlowExecution struct {
 	config *protos.FlowConnectionConfigs
 	logger log.Logger
 }
 
 // ensurePullability ensures that the source peer is pullable.
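+// setupReplication no longer returns slot/snapshot info; it lands in the catalog instead.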
-func (s *SnapshotFlowExecution) setupReplication(
-	ctx workflow.Context,
-) (*protos.SetupReplicationOutput, error) {
+func (s *SnapshotFlowExecution) setupReplication(ctx workflow.Context) error {
 	flowName := s.config.FlowJobName
 	s.logger.Info("setting up replication on source for peer flow")
 
@@ -60,20 +49,17 @@ func (s *SnapshotFlowExecution) setupReplication(
 		ExistingReplicationSlotName: s.config.ReplicationSlotName,
 	}
 
-	res := &protos.SetupReplicationOutput{}
 	setupReplicationFuture := workflow.ExecuteActivity(ctx, snapshot.SetupReplication, setupReplicationInput)
-	if err := setupReplicationFuture.Get(ctx, &res); err != nil {
-		return nil, fmt.Errorf("failed to setup replication on source peer: %w", err)
+	if err := setupReplicationFuture.Get(ctx, nil); err != nil {
+		return fmt.Errorf("failed to setup replication on source peer: %w", err)
 	}
 
 	s.logger.Info("replication slot live on source for peer flow")
 
-	return res, nil
+	return nil
 }
 
-func (s *SnapshotFlowExecution) closeSlotKeepAlive(
-	ctx workflow.Context,
-) error {
+func (s *SnapshotFlowExecution) closeSlotKeepAlive(ctx workflow.Context) error {
 	s.logger.Info("closing slot keep alive for peer flow")
 
 	ctx = workflow.WithActivityOptions(ctx, workflow.ActivityOptions{
@@ -92,13 +78,10 @@ func (s *SnapshotFlowExecution) closeSlotKeepAlive(
 func (s *SnapshotFlowExecution) cloneTable(
 	ctx workflow.Context,
 	boundSelector *shared.BoundSelector,
-	snapshotName string,
 	mapping *protos.TableMapping,
 ) error {
 	flowName := s.config.FlowJobName
-	cloneLog := slog.Group("clone-log",
-		slog.String(string(shared.FlowNameKey), flowName),
-		slog.String("snapshotName", snapshotName))
+	cloneLog := slog.String(string(shared.FlowNameKey), flowName)
 
 	srcName := mapping.SourceTableIdentifier
 	dstName := mapping.DestinationTableIdentifier
@@ -198,7 +181,6 @@ func (s *SnapshotFlowExecution) cloneTable(
 		WatermarkColumn:            mapping.PartitionKey,
 		WatermarkTable:             srcName,
 		InitialCopyOnly:            true,
-		SnapshotName:               snapshotName,
 		DestinationTableIdentifier: dstName,
 		NumRowsPerPartition:        numRowsPerPartition,
 		MaxParallelWorkers:         numWorkers,
@@ -215,24 +197,22 @@ func (s *SnapshotFlowExecution) cloneTable(
 	return nil
 }
 
-func (s *SnapshotFlowExecution) cloneTables(
-	ctx workflow.Context,
-	snapshotType snapshotType,
-	slotName string,
-	snapshotName string,
-	supportsTIDScans bool,
-	maxParallelClones int,
-) error {
-	if snapshotType == SNAPSHOT_TYPE_SLOT {
-		s.logger.Info(fmt.Sprintf("cloning tables for slot name %s and snapshotName %s",
-			slotName, snapshotName))
-	} else if snapshotType == SNAPSHOT_TYPE_TX {
-		s.logger.Info("cloning tables in txn snapshot mode with snapshotName " +
-			snapshotName)
-	}
-
+func (s *SnapshotFlowExecution) cloneTables(ctx workflow.Context, maxParallelClones int) error {
 	boundSelector := shared.NewBoundSelector(ctx, "CloneTablesSelector", maxParallelClones)
 
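+	// SetupReplicationOutput is gone, so ask an activity whether the source
+	// supports TID scans before choosing the default partition column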
destination table name %s", source, destination), - slog.String("snapshotName", snapshotName), ) if v.PartitionKey == "" { v.PartitionKey = defaultPartitionCol } - err := s.cloneTable(ctx, boundSelector, snapshotName, v) - if err != nil { + if err := s.cloneTable(ctx, boundSelector, v); err != nil { s.logger.Error("failed to start clone child workflow", slog.Any("error", err)) continue } @@ -270,19 +248,12 @@ func (s *SnapshotFlowExecution) cloneTablesWithSlot( sessionCtx workflow.Context, numTablesInParallel int, ) error { - slotInfo, err := s.setupReplication(sessionCtx) - if err != nil { + if err := s.setupReplication(sessionCtx); err != nil { return fmt.Errorf("failed to setup replication: %w", err) } s.logger.Info(fmt.Sprintf("cloning %d tables in parallel", numTablesInParallel)) - if err := s.cloneTables(ctx, - SNAPSHOT_TYPE_SLOT, - slotInfo.SlotName, - slotInfo.SnapshotName, - slotInfo.SupportsTidScans, - numTablesInParallel, - ); err != nil { + if err := s.cloneTables(ctx, numTablesInParallel); err != nil { return fmt.Errorf("failed to clone tables: %w", err) } @@ -307,8 +278,7 @@ func SnapshotFlowWorkflow( numTablesInParallel := int(max(config.SnapshotNumTablesInParallel, 1)) if !config.DoInitialSnapshot { - _, err := se.setupReplication(ctx) - if err != nil { + if err := se.setupReplication(ctx); err != nil { return fmt.Errorf("failed to setup replication: %w", err) } @@ -344,43 +314,24 @@ func SnapshotFlowWorkflow( exportCtx, snapshot.MaintainTx, sessionInfo.SessionID, + config.FlowJobName, config.SourceName, ) - fExportSnapshot := workflow.ExecuteActivity( - exportCtx, - snapshot.WaitForExportSnapshot, - sessionInfo.SessionID, - ) - var sessionError error - var txnSnapshotState *activities.TxSnapshotState - sessionSelector := workflow.NewNamedSelector(ctx, "ExportSnapshotSetup") - sessionSelector.AddFuture(fMaintain, func(f workflow.Future) { + cancelCtx, cancel := workflow.WithCancel(ctx) + workflow.GoNamed(ctx, "ExportSnapshotGoroutine", func(ctx workflow.Context) { // MaintainTx should never exit without an error before this point - sessionError = f.Get(exportCtx, nil) - }) - sessionSelector.AddFuture(fExportSnapshot, func(f workflow.Future) { - // Happy path is waiting for this to return without error - sessionError = f.Get(exportCtx, &txnSnapshotState) + sessionError = fMaintain.Get(ctx, nil) + cancel() }) - sessionSelector.AddReceive(ctx.Done(), func(_ workflow.ReceiveChannel, _ bool) { - sessionError = ctx.Err() - }) - sessionSelector.Select(ctx) - if sessionError != nil { - return sessionError - } - if err := se.cloneTables(ctx, - SNAPSHOT_TYPE_TX, - "", - txnSnapshotState.SnapshotName, - txnSnapshotState.SupportsTIDScans, - numTablesInParallel, - ); err != nil { + if err := se.cloneTables(cancelCtx, numTablesInParallel); err != nil { return fmt.Errorf("failed to clone tables: %w", err) } + if sessionError != nil { + return sessionError + } } else if err := se.cloneTablesWithSlot(ctx, sessionCtx, numTablesInParallel); err != nil { return fmt.Errorf("failed to clone slots and create replication slot: %w", err) } diff --git a/nexus/catalog/migrations/V40__snapshot_names.sql b/nexus/catalog/migrations/V40__snapshot_names.sql new file mode 100644 index 0000000000..041f1eedd9 --- /dev/null +++ b/nexus/catalog/migrations/V40__snapshot_names.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS snapshot_names ( + flow_name varchar(255) primary key, + slot_name text not null, + snapshot_name text not null, + supports_tid_scan bool not null +); + diff --git a/protos/flow.proto 
diff --git a/protos/flow.proto b/protos/flow.proto
index 7e24cfc528..4d859d0cfd 100644
--- a/protos/flow.proto
+++ b/protos/flow.proto
@@ -149,12 +149,6 @@ message SetupReplicationInput {
   string destination_name = 9;
 }
 
-message SetupReplicationOutput {
-  string slot_name = 1;
-  string snapshot_name = 2;
-  bool supports_tid_scans = 3;
-}
-
 message CreateRawTableInput {
   string flow_job_name = 2;
   map<string, string> table_name_mapping = 3;
@@ -310,7 +304,6 @@ message QRepConfig {
 
   string source_name = 20;
   string destination_name = 21;
-  string snapshot_name = 23;
 
   map<string, string> env = 24;
 
@@ -428,11 +421,6 @@ message IsQRepPartitionSyncedInput {
   string partition_id = 2;
 }
 
-message ExportTxSnapshotOutput {
-  string snapshot_name = 1;
-  bool supports_tid_scans = 2;
-}
-
 enum DynconfValueType {
   UNKNOWN = 0;
   STRING = 1;
diff --git a/ui/app/mirrors/create/helpers/common.ts b/ui/app/mirrors/create/helpers/common.ts
index d4ba5747ad..af4653a763 100644
--- a/ui/app/mirrors/create/helpers/common.ts
+++ b/ui/app/mirrors/create/helpers/common.ts
@@ -62,7 +62,6 @@ export const blankQRepSetting: QRepConfig = {
   numRowsPerPartition: 100000,
   setupWatermarkTableOnDestination: false,
   dstTableFullResync: false,
-  snapshotName: '',
   softDeleteColName: '_PEERDB_IS_DELETED',
   syncedAtColName: '',
   script: '',