diff --git a/internal/db/postgres/dumpers/transformation_pipeline.go b/internal/db/postgres/dumpers/transformation_pipeline.go index e2f8758a..ece19c13 100644 --- a/internal/db/postgres/dumpers/transformation_pipeline.go +++ b/internal/db/postgres/dumpers/transformation_pipeline.go @@ -30,11 +30,9 @@ import ( "github.com/greenmaskio/greenmask/pkg/toolkit" ) -const tmpFilePath = "/tmp" - var endOfLineSeq = []byte("\n") -type TransformationFunc func(ctx context.Context, r *toolkit.Record) (*toolkit.Record, error) +type transformationFunc func(ctx context.Context, r *toolkit.Record) (*toolkit.Record, error) type TransformationPipeline struct { table *entries.Table @@ -42,8 +40,8 @@ type TransformationPipeline struct { w io.Writer line uint64 row *pgcopy.Row - transformationWindows []*TransformationWindow - Transform TransformationFunc + transformationWindows []*transformationWindow + Transform transformationFunc isAsync bool record *toolkit.Record // when - table level when condition @@ -52,7 +50,7 @@ type TransformationPipeline struct { func NewTransformationPipeline(ctx context.Context, eg *errgroup.Group, table *entries.Table, w io.Writer) (*TransformationPipeline, error) { - var tws []*TransformationWindow + var tws []*transformationWindow var isAsync bool // TODO: Fix this hint. Async execution cannot be performed with template record because it is unsafe. @@ -64,13 +62,13 @@ func NewTransformationPipeline(ctx context.Context, eg *errgroup.Group, table *e if !hasTemplateRecordTransformer && table.HasCustomTransformer() && len(table.TransformersContext) > 1 { isAsync = true - tw := NewTransformationWindow(ctx, eg) + tw := newTransformationWindow(ctx, eg) tws = append(tws, tw) for _, t := range table.TransformersContext { - if !tw.TryAdd(table, t.Transformer) { - tw = NewTransformationWindow(ctx, eg) + if !tw.tryAdd(table, t) { + tw = newTransformationWindow(ctx, eg) tws = append(tws, tw) - tw.TryAdd(table, t.Transformer) + tw.tryAdd(table, t) } } } @@ -92,7 +90,7 @@ func NewTransformationPipeline(ctx context.Context, eg *errgroup.Group, table *e record: record, } - var tf TransformationFunc = tp.TransformSync + var tf transformationFunc = tp.TransformSync if isAsync { tf = tp.TransformAsync } @@ -139,7 +137,7 @@ func (tp *TransformationPipeline) Init(ctx context.Context) error { } if tp.isAsync { for _, w := range tp.transformationWindows { - w.Init() + w.init() } } @@ -195,7 +193,7 @@ func (tp *TransformationPipeline) Dump(ctx context.Context, data []byte) (err er rowDriver, err := tp.record.Encode() if err != nil { - return NewDumpError(tp.table.Schema, tp.table.Name, tp.line, fmt.Errorf("error enocding to RowDriver: %w", err)) + return NewDumpError(tp.table.Schema, tp.table.Name, tp.line, fmt.Errorf("error enocding Record to RowDriver: %w", err)) } res, err := rowDriver.Encode() if err != nil { @@ -234,7 +232,7 @@ func (tp *TransformationPipeline) Done(ctx context.Context) error { } if tp.isAsync { for _, w := range tp.transformationWindows { - w.Done() + w.close() } } diff --git a/internal/db/postgres/dumpers/transformation_window.go b/internal/db/postgres/dumpers/transformation_window.go index eb1a8a4e..6424c201 100644 --- a/internal/db/postgres/dumpers/transformation_window.go +++ b/internal/db/postgres/dumpers/transformation_window.go @@ -16,6 +16,7 @@ package dumpers import ( "context" + "fmt" "sync" "golang.org/x/sync/errgroup" @@ -25,20 +26,23 @@ import ( "github.com/greenmaskio/greenmask/pkg/toolkit" ) -type TransformationWindow struct { +type asyncContext struct { + tc *utils.TransformerContext + ch chan struct{} +} + +type transformationWindow struct { affectedColumns map[string]struct{} - transformers []utils.Transformer - chs []chan struct{} + window []*asyncContext done chan struct{} wg *sync.WaitGroup eg *errgroup.Group r *toolkit.Record ctx context.Context - size int } -func NewTransformationWindow(ctx context.Context, eg *errgroup.Group) *TransformationWindow { - return &TransformationWindow{ +func newTransformationWindow(ctx context.Context, eg *errgroup.Group) *transformationWindow { + return &transformationWindow{ affectedColumns: map[string]struct{}{}, done: make(chan struct{}, 1), wg: &sync.WaitGroup{}, @@ -47,12 +51,11 @@ func NewTransformationWindow(ctx context.Context, eg *errgroup.Group) *Transform } } -func (tw *TransformationWindow) TryAdd(table *entries.Table, t utils.Transformer) bool { +func (tw *transformationWindow) tryAdd(table *entries.Table, t *utils.TransformerContext) bool { - affectedColumn := t.GetAffectedColumns() + affectedColumn := t.Transformer.GetAffectedColumns() if len(affectedColumn) == 0 { - if len(tw.transformers) == 0 { - tw.transformers = append(tw.transformers, t) + if len(tw.window) == 0 { for _, c := range table.Columns { tw.affectedColumns[c.Name] = struct{}{} } @@ -68,20 +71,19 @@ func (tw *TransformationWindow) TryAdd(table *entries.Table, t utils.Transformer for _, name := range affectedColumn { tw.affectedColumns[name] = struct{}{} } - tw.transformers = append(tw.transformers, t) } - ch := make(chan struct{}, 1) - tw.chs = append(tw.chs, ch) - tw.size++ + tw.window = append(tw.window, &asyncContext{ + tc: t, + ch: make(chan struct{}, 1), + }) return true } -func (tw *TransformationWindow) Init() { - for idx, t := range tw.transformers { - ch := tw.chs[idx] - func(t utils.Transformer, ch chan struct{}) { +func (tw *transformationWindow) init() { + for _, ac := range tw.window { + func(ac *asyncContext) { tw.eg.Go(func() error { for { select { @@ -89,9 +91,9 @@ func (tw *TransformationWindow) Init() { return tw.ctx.Err() case <-tw.done: return nil - case <-ch: + case <-ac.ch: } - _, err := t.Transform(tw.ctx, tw.r) + _, err := ac.tc.Transformer.Transform(tw.ctx, tw.r) if err != nil { tw.wg.Done() return err @@ -99,26 +101,33 @@ func (tw *TransformationWindow) Init() { tw.wg.Done() } }) - }(t, ch) + }(ac) } } -func (tw *TransformationWindow) Done() { +func (tw *transformationWindow) close() { close(tw.done) } -func (tw *TransformationWindow) Transform(ctx context.Context, r *toolkit.Record) (*toolkit.Record, error) { - - tw.wg.Add(tw.size) +func (tw *transformationWindow) Transform(ctx context.Context, r *toolkit.Record) (*toolkit.Record, error) { tw.r = r - for _, ch := range tw.chs { + for _, ac := range tw.window { + needTransform, err := ac.tc.EvaluateWhen(r) + if err != nil { + return nil, fmt.Errorf("error evaluating when condition: %w", err) + } + if !needTransform { + continue + } + + tw.wg.Add(1) select { case <-ctx.Done(): return nil, ctx.Err() case <-tw.ctx.Done(): return nil, tw.ctx.Err() - case ch <- struct{}{}: + case ac.ch <- struct{}{}: } }