Skip to content

Commit

Permalink
Merge branch 'main' into stable
Browse files Browse the repository at this point in the history
  • Loading branch information
Amogh-Bharadwaj committed Sep 30, 2024
2 parents 3a8e4a8 + 1faa86c commit de6cf0e
Show file tree
Hide file tree
Showing 41 changed files with 873 additions and 557 deletions.
174 changes: 156 additions & 18 deletions flow/activities/flowable.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,23 +130,85 @@ func (a *FlowableActivity) CreateRawTable(
return res, nil
}

// GetTableSchema returns the schema of a table.
func (a *FlowableActivity) GetTableSchema(
func (a *FlowableActivity) MigrateTableSchema(
ctx context.Context,
config *protos.GetTableSchemaBatchInput,
) (*protos.GetTableSchemaBatchOutput, error) {
flowName string,
schemas map[string]*protos.TableSchema,
) error {
logger := activity.GetLogger(ctx)
tx, err := a.CatalogPool.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
return err
}
defer shared.RollbackTx(tx, logger)

for k, v := range schemas {
processedBytes, err := proto.Marshal(v)
if err != nil {
return err
}
if _, err := tx.Exec(
ctx,
"insert into table_schema_mapping(flow_name, table_name, table_schema) values ($1, $2, $3) "+
"on conflict (flow_name, table_name) do update set table_schema = $3",
flowName,
k,
processedBytes,
); err != nil {
return err
}
}

return tx.Commit(ctx)
}

// SetupTableSchema populates table_schema_mapping
func (a *FlowableActivity) SetupTableSchema(
ctx context.Context,
config *protos.SetupTableSchemaBatchInput,
) error {
logger := activity.GetLogger(ctx)
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowName)
srcConn, err := connectors.GetByNameAs[connectors.GetTableSchemaConnector](ctx, config.Env, a.CatalogPool, config.PeerName)
if err != nil {
return nil, fmt.Errorf("failed to get GetTableSchemaConnector: %w", err)
return fmt.Errorf("failed to get GetTableSchemaConnector: %w", err)
}
defer connectors.CloseConnector(ctx, srcConn)

heartbeatRoutine(ctx, func() string {
return "getting table schema"
})

return srcConn.GetTableSchema(ctx, config)
tableNameSchemaMapping, err := srcConn.GetTableSchema(ctx, config.Env, config.System, config.TableIdentifiers)
if err != nil {
return fmt.Errorf("failed to get GetTableSchemaConnector: %w", err)
}
processed := shared.BuildProcessedSchemaMapping(config.TableMappings, tableNameSchemaMapping, logger)

tx, err := a.CatalogPool.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
return err
}
defer shared.RollbackTx(tx, logger)

for k, v := range processed {
processedBytes, err := proto.Marshal(v)
if err != nil {
return err
}
if _, err := tx.Exec(
ctx,
"insert into table_schema_mapping(flow_name, table_name, table_schema) values ($1, $2, $3) "+
"on conflict (flow_name, table_name) do update set table_schema = $3",
config.FlowName,
k,
processedBytes,
); err != nil {
return err
}
}

return tx.Commit(ctx)
}

// CreateNormalizedTable creates normalized tables in destination.
Expand All @@ -172,21 +234,26 @@ func (a *FlowableActivity) CreateNormalizedTable(
}
defer conn.CleanupSetupNormalizedTables(ctx, tx)

tableNameSchemaMapping, err := a.getTableNameSchemaMapping(ctx, config.FlowName)
if err != nil {
return nil, err
}

numTablesSetup := atomic.Uint32{}
totalTables := uint32(len(config.TableNameSchemaMapping))
shutdown := heartbeatRoutine(ctx, func() string {
return fmt.Sprintf("setting up normalized tables - %d of %d done",
numTablesSetup.Load(), totalTables)
numTablesSetup.Load(), len(tableNameSchemaMapping))
})
defer shutdown()

tableExistsMapping := make(map[string]bool, len(config.TableNameSchemaMapping))
for tableIdentifier := range config.TableNameSchemaMapping {
tableExistsMapping := make(map[string]bool, len(tableNameSchemaMapping))
for tableIdentifier, tableSchema := range tableNameSchemaMapping {
existing, err := conn.SetupNormalizedTable(
ctx,
tx,
config,
tableIdentifier,
tableSchema,
)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowName, err)
Expand Down Expand Up @@ -348,10 +415,15 @@ func (a *FlowableActivity) StartNormalize(
})
defer shutdown()

tableNameSchemaMapping, err := a.getTableNameSchemaMapping(ctx, input.FlowConnectionConfigs.FlowJobName)
if err != nil {
return nil, err
}

res, err := dstConn.NormalizeRecords(ctx, &model.NormalizeRecordsRequest{
FlowJobName: input.FlowConnectionConfigs.FlowJobName,
Env: input.FlowConnectionConfigs.Env,
TableNameSchemaMapping: input.TableNameSchemaMapping,
TableNameSchemaMapping: tableNameSchemaMapping,
TableMappings: input.FlowConnectionConfigs.TableMappings,
SyncBatchID: input.SyncBatchID,
SoftDeleteColName: input.FlowConnectionConfigs.SoftDeleteColName,
Expand All @@ -361,6 +433,17 @@ func (a *FlowableActivity) StartNormalize(
a.Alerter.LogFlowError(ctx, input.FlowConnectionConfigs.FlowJobName, err)
return nil, fmt.Errorf("failed to normalized records: %w", err)
}
dstType, err := connectors.LoadPeerType(ctx, a.CatalogPool, input.FlowConnectionConfigs.DestinationName)
if err != nil {
return nil, err
}
if dstType == protos.DBType_POSTGRES {
err = monitoring.UpdateEndTimeForCDCBatch(ctx, a.CatalogPool, input.FlowConnectionConfigs.FlowJobName,
input.SyncBatchID)
if err != nil {
return nil, err
}
}

// log the number of batches normalized
logger.Info(fmt.Sprintf("normalized records from batch %d to batch %d",
Expand Down Expand Up @@ -735,9 +818,7 @@ func (a *FlowableActivity) QRepHasNewRows(ctx context.Context,
return result, nil
}

func (a *FlowableActivity) RenameTables(ctx context.Context, config *protos.RenameTablesInput) (
*protos.RenameTablesOutput, error,
) {
func (a *FlowableActivity) RenameTables(ctx context.Context, config *protos.RenameTablesInput) (*protos.RenameTablesOutput, error) {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName)
conn, err := connectors.GetByNameAs[connectors.RenameTablesConnector](ctx, nil, a.CatalogPool, config.PeerName)
if err != nil {
Expand All @@ -751,13 +832,46 @@ func (a *FlowableActivity) RenameTables(ctx context.Context, config *protos.Rena
})
defer shutdown()

renameOutput, err := conn.RenameTables(ctx, config)
tableNameSchemaMapping := make(map[string]*protos.TableSchema, len(config.RenameTableOptions))
for _, option := range config.RenameTableOptions {
schema, err := shared.LoadTableSchemaFromCatalog(
ctx,
a.CatalogPool,
config.FlowJobName,
option.CurrentName,
)
if err != nil {
return nil, fmt.Errorf("failed to load schema to rename tables: %w", err)
}
tableNameSchemaMapping[option.CurrentName] = schema
}

renameOutput, err := conn.RenameTables(ctx, config, tableNameSchemaMapping)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return nil, fmt.Errorf("failed to rename tables: %w", err)
}

return renameOutput, nil
tx, err := a.CatalogPool.Begin(ctx)
if err != nil {
return nil, fmt.Errorf("failed to begin updating table_schema_mapping: %w", err)
}
logger := log.With(activity.GetLogger(ctx), slog.String(string(shared.FlowNameKey), config.FlowJobName))
defer shared.RollbackTx(tx, logger)

for _, option := range config.RenameTableOptions {
if _, err := tx.Exec(
ctx,
"update table_schema_mapping set table_name = $3 where flow_name = $1 and table_name = $2",
config.FlowJobName,
option.CurrentName,
option.NewName,
); err != nil {
return nil, fmt.Errorf("failed to update table_schema_mapping: %w", err)
}
}

return renameOutput, tx.Commit(ctx)
}

func (a *FlowableActivity) DeleteMirrorStats(ctx context.Context, flowName string) error {
Expand Down Expand Up @@ -833,7 +947,9 @@ func (a *FlowableActivity) AddTablesToPublication(ctx context.Context, cfg *prot
return err
}

func (a *FlowableActivity) RemoveTablesFromPublication(ctx context.Context, cfg *protos.FlowConnectionConfigs,
func (a *FlowableActivity) RemoveTablesFromPublication(
ctx context.Context,
cfg *protos.FlowConnectionConfigs,
removedTablesMapping []*protos.TableMapping,
) error {
ctx = context.WithValue(ctx, shared.FlowNameKey, cfg.FlowJobName)
Expand All @@ -854,7 +970,9 @@ func (a *FlowableActivity) RemoveTablesFromPublication(ctx context.Context, cfg
return err
}

func (a *FlowableActivity) RemoveTablesFromRawTable(ctx context.Context, cfg *protos.FlowConnectionConfigs,
func (a *FlowableActivity) RemoveTablesFromRawTable(
ctx context.Context,
cfg *protos.FlowConnectionConfigs,
tablesToRemove []*protos.TableMapping,
) error {
ctx = context.WithValue(ctx, shared.FlowNameKey, cfg.FlowJobName)
Expand Down Expand Up @@ -898,3 +1016,23 @@ func (a *FlowableActivity) RemoveTablesFromRawTable(ctx context.Context, cfg *pr
}
return err
}

func (a *FlowableActivity) RemoveTablesFromCatalog(
ctx context.Context,
cfg *protos.FlowConnectionConfigs,
tablesToRemove []*protos.TableMapping,
) error {
removedTables := make([]string, 0, len(tablesToRemove))
for _, tm := range tablesToRemove {
removedTables = append(removedTables, tm.DestinationTableIdentifier)
}

_, err := a.CatalogPool.Exec(
ctx,
"delete from table_schema_mapping where flow_name = $1 and table_name = ANY($2)",
cfg.FlowJobName,
removedTables,
)

return err
}
33 changes: 30 additions & 3 deletions flow/activities/flowable_core.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,28 @@ func waitForCdcCache[TPull connectors.CDCPullConnectorCore](ctx context.Context,
}
}

func (a *FlowableActivity) getTableNameSchemaMapping(ctx context.Context, flowName string) (map[string]*protos.TableSchema, error) {
rows, err := a.CatalogPool.Query(ctx, "select table_name, table_schema from table_schema_mapping where flow_name = $1", flowName)
if err != nil {
return nil, err
}

var tableName string
var tableSchemaBytes []byte
tableNameSchemaMapping := make(map[string]*protos.TableSchema)
if _, err := pgx.ForEachRow(rows, []any{&tableName, &tableSchemaBytes}, func() error {
tableSchema := &protos.TableSchema{}
if err := proto.Unmarshal(tableSchemaBytes, tableSchema); err != nil {
return err
}
tableNameSchemaMapping[tableName] = tableSchema
return nil
}); err != nil {
return nil, fmt.Errorf("failed to deserialize table schema proto: %w", err)
}
return tableNameSchemaMapping, nil
}

func syncCore[TPull connectors.CDCPullConnectorCore, TSync connectors.CDCSyncConnectorCore, Items model.Items](
ctx context.Context,
a *FlowableActivity,
Expand Down Expand Up @@ -141,8 +163,13 @@ func syncCore[TPull connectors.CDCPullConnectorCore, TSync connectors.CDCSyncCon
return nil, err
}
}
startTime := time.Now()

tableNameSchemaMapping, err := a.getTableNameSchemaMapping(ctx, flowName)
if err != nil {
return nil, err
}

startTime := time.Now()
errGroup, errCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
return pull(srcConn, errCtx, a.CatalogPool, &model.PullRecordsRequest[Items]{
Expand All @@ -155,7 +182,7 @@ func syncCore[TPull connectors.CDCPullConnectorCore, TSync connectors.CDCSyncCon
IdleTimeout: peerdbenv.PeerDBCDCIdleTimeoutSeconds(
int(options.IdleTimeoutSeconds),
),
TableNameSchemaMapping: options.TableNameSchemaMapping,
TableNameSchemaMapping: tableNameSchemaMapping,
OverridePublicationName: config.PublicationName,
OverrideReplicationSlotName: config.ReplicationSlotName,
RecordStream: recordBatchPull,
Expand Down Expand Up @@ -237,7 +264,7 @@ func syncCore[TPull connectors.CDCPullConnectorCore, TSync connectors.CDCSyncCon
TableMappings: options.TableMappings,
StagingPath: config.CdcStagingPath,
Script: config.Script,
TableNameSchemaMapping: options.TableNameSchemaMapping,
TableNameSchemaMapping: tableNameSchemaMapping,
})
if err != nil {
a.Alerter.LogFlowError(ctx, flowName, err)
Expand Down
8 changes: 8 additions & 0 deletions flow/activities/snapshot_activity.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,11 @@ func (a *SnapshotActivity) WaitForExportSnapshot(ctx context.Context, sessionID
time.Sleep(time.Second)
}
}

func (a *SnapshotActivity) LoadTableSchema(
ctx context.Context,
flowName string,
tableName string,
) (*protos.TableSchema, error) {
return shared.LoadTableSchemaFromCatalog(ctx, a.CatalogPool, flowName, tableName)
}
15 changes: 14 additions & 1 deletion flow/cmd/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,24 @@ func (h *FlowRequestHandler) removeFlowEntryInCatalog(
ctx context.Context,
flowName string,
) error {
_, err := h.pool.Exec(ctx, "DELETE FROM flows WHERE name=$1", flowName)
tx, err := h.pool.Begin(ctx)
if err != nil {
return fmt.Errorf("unable to begin tx to remove flow entry in catalog: %w", err)
}
defer shared.RollbackTx(tx, slog.Default())

if _, err := tx.Exec(ctx, "DELETE FROM table_schema_mapping WHERE flow_name=$1", flowName); err != nil {
return fmt.Errorf("unable to clear table_schema_mapping to remove flow entry in catalog: %w", err)
}

if _, err := tx.Exec(ctx, "DELETE FROM flows WHERE name=$1", flowName); err != nil {
return fmt.Errorf("unable to remove flow entry in catalog: %w", err)
}

if err := tx.Commit(ctx); err != nil {
return fmt.Errorf("unable to commit remove flow entry in catalog: %w", err)
}

return nil
}

Expand Down
7 changes: 2 additions & 5 deletions flow/cmd/validate_mirror.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,7 @@ func (h *FlowRequestHandler) ValidateCDCMirror(
}
defer chPeer.Close()

res, err := pgPeer.GetTableSchema(ctx, &protos.GetTableSchemaBatchInput{
TableIdentifiers: srcTableNames,
System: protos.TypeSystem_PG,
})
res, err := pgPeer.GetTableSchema(ctx, nil, protos.TypeSystem_PG, srcTableNames)
if err != nil {
displayErr := fmt.Errorf("failed to get source table schema: %v", err)
h.alerter.LogNonFlowWarning(ctx, telemetry.CreateMirror, req.ConnectionConfigs.FlowJobName,
Expand All @@ -204,7 +201,7 @@ func (h *FlowRequestHandler) ValidateCDCMirror(
}, displayErr
}

err = chPeer.CheckDestinationTables(ctx, req.ConnectionConfigs, res.TableNameSchemaMapping)
err = chPeer.CheckDestinationTables(ctx, req.ConnectionConfigs, res)
if err != nil {
h.alerter.LogNonFlowWarning(ctx, telemetry.CreateMirror, req.ConnectionConfigs.FlowJobName,
fmt.Sprint(err),
Expand Down
Loading

0 comments on commit de6cf0e

Please sign in to comment.