From 0c551fb09136797d63e10cbccac67f4c4d857c40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Thu, 19 Dec 2024 16:40:22 +0000 Subject: [PATCH] use errgroup over a mess of channels --- flow/activities/flowable.go | 63 ++++++++++++++++++-------------- flow/activities/flowable_core.go | 2 - 2 files changed, 35 insertions(+), 30 deletions(-) diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go index 1203452d67..ea3be2bce8 100644 --- a/flow/activities/flowable.go +++ b/flow/activities/flowable.go @@ -15,6 +15,7 @@ import ( "go.opentelemetry.io/otel/metric" "go.temporal.io/sdk/activity" "go.temporal.io/sdk/log" + "golang.org/x/sync/errgroup" "google.golang.org/protobuf/proto" "github.com/PeerDB-io/peer-flow/alerting" @@ -42,10 +43,10 @@ type NormalizeBatchRequest struct { } type CdcState struct { - connector connectors.CDCPullConnectorCore - syncDone chan struct{} - normalize chan NormalizeBatchRequest - normalizeDone chan struct{} + connector connectors.CDCPullConnectorCore + syncDone chan struct{} + normalize chan NormalizeBatchRequest + errGroup *errgroup.Group } type FlowableActivity struct { @@ -252,25 +253,25 @@ func (a *FlowableActivity) CreateNormalizedTable( func (a *FlowableActivity) maintainPull( ctx context.Context, config *protos.FlowConnectionConfigs, -) (CdcState, error) { +) (CdcState, context.Context, error) { ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName) srcConn, err := connectors.GetByNameAs[connectors.CDCPullConnector](ctx, config.Env, a.CatalogPool, config.SourceName) if err != nil { a.Alerter.LogFlowError(ctx, config.FlowJobName, err) - return CdcState{}, err + return CdcState{}, nil, err } if err := srcConn.SetupReplConn(ctx); err != nil { a.Alerter.LogFlowError(ctx, config.FlowJobName, err) connectors.CloseConnector(ctx, srcConn) - return CdcState{}, err + return CdcState{}, nil, err } normalizeBufferSize, err := peerdbenv.PeerDBNormalizeChannelBufferSize(ctx, config.Env) if err != nil { a.Alerter.LogFlowError(ctx, config.FlowJobName, err) connectors.CloseConnector(ctx, srcConn) - return CdcState{}, err + return CdcState{}, nil, err } // syncDone will be closed by SyncFlow, @@ -278,24 +279,28 @@ func (a *FlowableActivity) maintainPull( // Wait on normalizeDone at end to not interrupt final normalize syncDone := make(chan struct{}) normalize := make(chan NormalizeBatchRequest, normalizeBufferSize) - normalizeDone := make(chan struct{}) - go a.normalizeLoop(ctx, config, syncDone, normalize, normalizeDone) - go func() { - defer connectors.CloseConnector(ctx, srcConn) - err := a.maintainReplConn(ctx, config.FlowJobName, srcConn, syncDone) - if err != nil { - a.Alerter.LogFlowError(ctx, config.FlowJobName, err) + group, groupCtx := errgroup.WithContext(ctx) + group.Go(func() error { + // returning error signals sync to stop, normalize can recover connections without interrupting sync, so never return error + a.normalizeLoop(groupCtx, config, syncDone, normalize) + return nil + }) + group.Go(func() error { + defer connectors.CloseConnector(groupCtx, srcConn) + if err := a.maintainReplConn(groupCtx, config.FlowJobName, srcConn, syncDone); err != nil { + a.Alerter.LogFlowError(groupCtx, config.FlowJobName, err) + return err } - // TODO propagate error - }() + return nil + }) return CdcState{ - connector: srcConn, - syncDone: syncDone, - normalize: normalize, - normalizeDone: normalizeDone, - }, nil + connector: srcConn, + syncDone: syncDone, + normalize: normalize, + errGroup: group, + }, groupCtx, nil } func (a *FlowableActivity) SyncFlow( @@ -306,7 +311,7 @@ func (a *FlowableActivity) SyncFlow( ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName) logger := activity.GetLogger(ctx) - cdcState, err := a.maintainPull(ctx, config) + cdcState, groupCtx, err := a.maintainPull(ctx, config) if err != nil { logger.Error("MaintainPull failed", slog.Any("error", err)) return err @@ -315,16 +320,16 @@ func (a *FlowableActivity) SyncFlow( currentSyncFlowNum := int32(0) totalRecordsSynced := int64(0) - for ctx.Err() == nil { + for groupCtx.Err() == nil { currentSyncFlowNum += 1 logger.Info("executing sync flow", slog.Int("count", int(currentSyncFlowNum))) var numRecordsSynced int64 var syncErr error if config.System == protos.TypeSystem_Q { - numRecordsSynced, syncErr = a.SyncRecords(ctx, config, options, cdcState) + numRecordsSynced, syncErr = a.SyncRecords(groupCtx, config, options, cdcState) } else { - numRecordsSynced, syncErr = a.SyncPg(ctx, config, options, cdcState) + numRecordsSynced, syncErr = a.SyncPg(groupCtx, config, options, cdcState) } if syncErr != nil { @@ -343,11 +348,13 @@ func (a *FlowableActivity) SyncFlow( } close(cdcState.syncDone) - <-cdcState.normalizeDone - + waitErr := cdcState.errGroup.Wait() if err := ctx.Err(); err != nil { logger.Info("sync canceled", slog.Any("error", err)) return err + } else if waitErr != nil { + logger.Error("sync failed", slog.Any("error", waitErr)) + return waitErr } return nil } diff --git a/flow/activities/flowable_core.go b/flow/activities/flowable_core.go index 1423d5f9b8..b3717c5559 100644 --- a/flow/activities/flowable_core.go +++ b/flow/activities/flowable_core.go @@ -618,9 +618,7 @@ func (a *FlowableActivity) normalizeLoop( config *protos.FlowConnectionConfigs, syncDone <-chan struct{}, normalize <-chan NormalizeBatchRequest, - normalizeDone chan struct{}, ) { - defer close(normalizeDone) logger := activity.GetLogger(ctx) for {