Skip to content

Commit

Permalink
use errgroup over a mess of channels
Browse files Browse the repository at this point in the history
  • Loading branch information
serprex committed Dec 19, 2024
1 parent 5783e9f commit b5b8c97
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 30 deletions.
63 changes: 35 additions & 28 deletions flow/activities/flowable.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"go.opentelemetry.io/otel/metric"
"go.temporal.io/sdk/activity"
"go.temporal.io/sdk/log"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"

"github.com/PeerDB-io/peer-flow/alerting"
Expand Down Expand Up @@ -42,10 +43,10 @@ type NormalizeBatchRequest struct {
}

type CdcState struct {
connector connectors.CDCPullConnectorCore
syncDone chan struct{}
normalize chan NormalizeBatchRequest
normalizeDone chan struct{}
connector connectors.CDCPullConnectorCore
syncDone chan struct{}
normalize chan NormalizeBatchRequest
errGroup *errgroup.Group
}

type FlowableActivity struct {
Expand Down Expand Up @@ -252,50 +253,54 @@ func (a *FlowableActivity) CreateNormalizedTable(
func (a *FlowableActivity) maintainPull(
ctx context.Context,
config *protos.FlowConnectionConfigs,
) (CdcState, error) {
) (CdcState, context.Context, error) {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName)
srcConn, err := connectors.GetByNameAs[connectors.CDCPullConnector](ctx, config.Env, a.CatalogPool, config.SourceName)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return CdcState{}, err
return CdcState{}, nil, err
}

if err := srcConn.SetupReplConn(ctx); err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
connectors.CloseConnector(ctx, srcConn)
return CdcState{}, err
return CdcState{}, nil, err
}

normalizeBufferSize, err := peerdbenv.PeerDBNormalizeChannelBufferSize(ctx, config.Env)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
connectors.CloseConnector(ctx, srcConn)
return CdcState{}, err
return CdcState{}, nil, err
}

// syncDone will be closed by SyncFlow,
// whereas normalizeDone will be closed by normalizing goroutine
// Wait on normalizeDone at end to not interrupt final normalize
syncDone := make(chan struct{})
normalize := make(chan NormalizeBatchRequest, normalizeBufferSize)
normalizeDone := make(chan struct{})

go a.normalizeLoop(ctx, config, syncDone, normalize, normalizeDone)
go func() {
defer connectors.CloseConnector(ctx, srcConn)
err := a.maintainReplConn(ctx, config.FlowJobName, srcConn, syncDone)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
group, groupCtx := errgroup.WithContext(ctx)
group.Go(func() error {
// returning error signals sync to stop, normalize can recover connections without interrupting sync, so never return error
a.normalizeLoop(groupCtx, config, syncDone, normalize)
return nil
})
group.Go(func() error {
defer connectors.CloseConnector(groupCtx, srcConn)
if err := a.maintainReplConn(groupCtx, config.FlowJobName, srcConn, syncDone); err != nil {
a.Alerter.LogFlowError(groupCtx, config.FlowJobName, err)
return err
}
// TODO propagate error
}()
return nil
})

return CdcState{
connector: srcConn,
syncDone: syncDone,
normalize: normalize,
normalizeDone: normalizeDone,
}, nil
connector: srcConn,
syncDone: syncDone,
normalize: normalize,
errGroup: group,
}, groupCtx, nil
}

func (a *FlowableActivity) SyncFlow(
Expand All @@ -306,7 +311,7 @@ func (a *FlowableActivity) SyncFlow(
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName)
logger := activity.GetLogger(ctx)

cdcState, err := a.maintainPull(ctx, config)
cdcState, groupCtx, err := a.maintainPull(ctx, config)
if err != nil {
logger.Error("MaintainPull failed", slog.Any("error", err))
return err
Expand All @@ -315,16 +320,16 @@ func (a *FlowableActivity) SyncFlow(
currentSyncFlowNum := int32(0)
totalRecordsSynced := int64(0)

for ctx.Err() == nil {
for groupCtx.Err() == nil {
currentSyncFlowNum += 1
logger.Info("executing sync flow", slog.Int("count", int(currentSyncFlowNum)))

var numRecordsSynced int64
var syncErr error
if config.System == protos.TypeSystem_Q {
numRecordsSynced, syncErr = a.SyncRecords(ctx, config, options, cdcState)
numRecordsSynced, syncErr = a.SyncRecords(groupCtx, config, options, cdcState)
} else {
numRecordsSynced, syncErr = a.SyncPg(ctx, config, options, cdcState)
numRecordsSynced, syncErr = a.SyncPg(groupCtx, config, options, cdcState)
}

if syncErr != nil {
Expand All @@ -343,11 +348,13 @@ func (a *FlowableActivity) SyncFlow(
}

close(cdcState.syncDone)
<-cdcState.normalizeDone

waitErr := cdcState.errGroup.Wait()
if err := ctx.Err(); err != nil {
logger.Info("sync canceled", slog.Any("error", err))
return err
} else if waitErr != nil {
logger.Error("sync failed", slog.Any("error", waitErr))
return waitErr
}
return nil
}
Expand Down
2 changes: 0 additions & 2 deletions flow/activities/flowable_core.go
Original file line number Diff line number Diff line change
Expand Up @@ -618,9 +618,7 @@ func (a *FlowableActivity) normalizeLoop(
config *protos.FlowConnectionConfigs,
syncDone <-chan struct{},
normalize <-chan NormalizeBatchRequest,
normalizeDone chan struct{},
) {
defer close(normalizeDone)
logger := activity.GetLogger(ctx)

for {
Expand Down

0 comments on commit b5b8c97

Please sign in to comment.