Skip to content

Commit

Permalink
feat: added test for transformation pipelines. Revised some code
Browse files Browse the repository at this point in the history
  • Loading branch information
wwoytenko committed Oct 27, 2024
1 parent d9fc636 commit 378a420
Show file tree
Hide file tree
Showing 11 changed files with 371 additions and 39 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,3 @@ venv
.cache
# Binaries
cmd/greenmask/greenmask
pkg/toolkit/test/test
40 changes: 20 additions & 20 deletions internal/db/postgres/cmd/validate_utils/json_document_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,6 @@ import (
"github.com/greenmaskio/greenmask/pkg/toolkit"
)

// testTransformer is a stateless stub transformer for tests: every method is
// a no-op that returns fixed values.
type testTransformer struct{}

// Init does nothing and always returns nil.
func (*testTransformer) Init(_ context.Context) error { return nil }

// Done does nothing and always returns nil.
func (*testTransformer) Done(_ context.Context) error { return nil }

// Transform ignores its arguments and returns a nil record with a nil error.
func (*testTransformer) Transform(_ context.Context, _ *toolkit.Record) (*toolkit.Record, error) {
	return nil, nil
}

// GetAffectedColumns reports one fixed entry: key 1 mapped to "name".
func (*testTransformer) GetAffectedColumns() map[int]string {
	affected := make(map[int]string, 1)
	affected[1] = "name"
	return affected
}

func TestJsonDocument_GetAffectedColumns(t *testing.T) {
tab, _, _ := getTableAndRows()
jd := NewJsonDocument(tab, true, true)
Expand Down Expand Up @@ -87,6 +67,26 @@ func TestJsonDocument_GetRecords(t *testing.T) {
//r.SetRow(row)
}

// testTransformer is a stateless stub transformer for tests: every method is
// a no-op that returns fixed values.
type testTransformer struct{}

// Init does nothing and always returns nil.
func (*testTransformer) Init(_ context.Context) error { return nil }

// Done does nothing and always returns nil.
func (*testTransformer) Done(_ context.Context) error { return nil }

// Transform ignores its arguments and returns a nil record with a nil error.
func (*testTransformer) Transform(_ context.Context, _ *toolkit.Record) (*toolkit.Record, error) {
	return nil, nil
}

// GetAffectedColumns reports one fixed entry: key 1 mapped to "name".
func (*testTransformer) GetAffectedColumns() map[int]string {
	affected := make(map[int]string, 1)
	affected[1] = "name"
	return affected
}

func getTableAndRows() (table *entries.Table, original, transformed [][]byte) {

tableDef := `
Expand Down
108 changes: 108 additions & 0 deletions internal/db/postgres/dumpers/transformation_pipeline_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package dumpers

import (
"bytes"
"context"
"testing"
"time"

"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"

"github.com/greenmaskio/greenmask/internal/db/postgres/transformers/utils"
"github.com/greenmaskio/greenmask/pkg/toolkit"
)

// TestTransformationPipeline_Dump feeds one raw line through a pipeline whose
// single transformer has an empty when condition. It checks that the
// transformer ran exactly once and that the rewritten row (id changed to 2),
// followed by the "\." terminator, was written to the output buffer.
//
// NOTE(review): the expected timestamp has one digit fewer than the input
// line; presumably Dump expects data to end with a newline and drops the last
// byte — confirm this is intended.
func TestTransformationPipeline_Dump(t *testing.T) {
	termCtx, termCancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer termCancel()
	table := getTable()
	ctx := context.Background()
	eg, gtx := errgroup.WithContext(ctx)
	driver := getDriver(table.Table)
	table.Driver = driver
	when, warns := toolkit.NewWhenCond("", driver, nil)
	require.Empty(t, warns)
	tt := &testTransformer{}
	tc := &utils.TransformerContext{
		Transformer: tt,
		When:        when,
	}
	table.TransformersContext = []*utils.TransformerContext{tc}

	buf := bytes.NewBuffer(nil)

	pipeline, err := NewTransformationPipeline(gtx, eg, table, buf)
	require.NoError(t, err)
	require.NoError(t, pipeline.Init(termCtx))
	data := []byte("1\t2023-08-27 00:00:00.000000")
	err = pipeline.Dump(ctx, data)
	require.NoError(t, err)
	require.NoError(t, pipeline.Done(termCtx))
	require.NoError(t, pipeline.CompleteDump())
	// testify's Equal takes (expected, actual); this order keeps the failure
	// report's "expected"/"actual" labels truthful.
	require.Equal(t, 1, tt.callsCount)
	require.Equal(t, "2\t2023-08-27 00:00:00.00000\n\\.\n\n", buf.String())
}

// TestTransformationPipeline_Dump_with_transformer_cond attaches the when
// condition "record.id != 1" to the transformer and dumps a row whose id is 1.
// The transformer must not be called and the row must reach the output buffer
// with its id unchanged, followed by the "\." terminator.
//
// NOTE(review): the expected timestamp has one digit fewer than the input
// line; presumably Dump expects data to end with a newline and drops the last
// byte — confirm this is intended.
func TestTransformationPipeline_Dump_with_transformer_cond(t *testing.T) {
	termCtx, termCancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer termCancel()
	table := getTable()
	ctx := context.Background()
	eg, gtx := errgroup.WithContext(ctx)
	driver := getDriver(table.Table)
	table.Driver = driver
	when, warns := toolkit.NewWhenCond("record.id != 1", driver, make(map[string]any))
	require.Empty(t, warns)
	tt := &testTransformer{}
	tc := &utils.TransformerContext{
		Transformer: tt,
		When:        when,
	}
	table.TransformersContext = []*utils.TransformerContext{tc}

	buf := bytes.NewBuffer(nil)

	pipeline, err := NewTransformationPipeline(gtx, eg, table, buf)
	require.NoError(t, err)
	require.NoError(t, pipeline.Init(termCtx))
	data := []byte("1\t2023-08-27 00:00:00.000000")
	err = pipeline.Dump(ctx, data)
	require.NoError(t, err)
	require.NoError(t, pipeline.Done(termCtx))
	require.NoError(t, pipeline.CompleteDump())
	// testify's Equal takes (expected, actual); this order keeps the failure
	// report's "expected"/"actual" labels truthful.
	require.Equal(t, 0, tt.callsCount)
	require.Equal(t, "1\t2023-08-27 00:00:00.00000\n\\.\n\n", buf.String())
}

// TestTransformationPipeline_Dump_with_table_cond leaves the transformer's own
// when condition empty but sets the table-level condition to "record.id != 1",
// then dumps a row whose id is 1. The transformer must not be called and the
// row must reach the output buffer with its id unchanged, followed by the
// "\." terminator.
//
// NOTE(review): the expected timestamp has one digit fewer than the input
// line; presumably Dump expects data to end with a newline and drops the last
// byte — confirm this is intended.
func TestTransformationPipeline_Dump_with_table_cond(t *testing.T) {
	termCtx, termCancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer termCancel()
	table := getTable()
	ctx := context.Background()
	eg, gtx := errgroup.WithContext(ctx)
	driver := getDriver(table.Table)
	table.Driver = driver
	when, warns := toolkit.NewWhenCond("", driver, make(map[string]any))
	require.Empty(t, warns)
	tt := &testTransformer{}
	tc := &utils.TransformerContext{
		Transformer: tt,
		When:        when,
	}
	table.TransformersContext = []*utils.TransformerContext{tc}
	table.When = "record.id != 1"

	buf := bytes.NewBuffer(nil)

	pipeline, err := NewTransformationPipeline(gtx, eg, table, buf)
	require.NoError(t, err)
	require.NoError(t, pipeline.Init(termCtx))
	data := []byte("1\t2023-08-27 00:00:00.000000")
	err = pipeline.Dump(ctx, data)
	require.NoError(t, err)
	require.NoError(t, pipeline.Done(termCtx))
	require.NoError(t, pipeline.CompleteDump())
	// testify's Equal takes (expected, actual); this order keeps the failure
	// report's "expected"/"actual" labels truthful.
	require.Equal(t, 0, tt.callsCount)
	require.Equal(t, "1\t2023-08-27 00:00:00.00000\n\\.\n\n", buf.String())
}
4 changes: 4 additions & 0 deletions internal/db/postgres/dumpers/transformation_window.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func (tw *transformationWindow) tryAdd(table *entries.Table, t *utils.Transforme
return true
}

// init starts a goroutine for each transformer in the window; each goroutine waits for the ac.ch signal before running its transformer.
func (tw *transformationWindow) init() {
for _, ac := range tw.window {
func(ac *asyncContext) {
Expand All @@ -105,10 +106,13 @@ func (tw *transformationWindow) init() {
}
}

// close closes tw.done, the channel that tells the transformer goroutines to
// stop.
// NOTE(review): closing an already-closed channel panics in Go, so this
// presumably must be called at most once per window — confirm against callers.
func (tw *transformationWindow) close() {
close(tw.done)
}

// Transform runs the window's transformers against the record. For each transformer it evaluates the
// when condition and, if it is true, signals that transformer's goroutine to run the transformation.
func (tw *transformationWindow) Transform(ctx context.Context, r *toolkit.Record) (*toolkit.Record, error) {
tw.r = r
for _, ac := range tw.window {
Expand Down
158 changes: 158 additions & 0 deletions internal/db/postgres/dumpers/transformation_window_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package dumpers

import (
"context"
"testing"
"time"

"github.com/jackc/pgx/v5/pgtype"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"

"github.com/greenmaskio/greenmask/internal/db/postgres/entries"
"github.com/greenmaskio/greenmask/internal/db/postgres/transformers/utils"
"github.com/greenmaskio/greenmask/pkg/toolkit"
"github.com/greenmaskio/greenmask/pkg/toolkit/testutils"
)

// TestTransformationWindow_tryAdd checks that the first tryAdd of a
// transformer context is accepted and that adding the same context a second
// time is rejected.
func TestTransformationWindow_tryAdd(t *testing.T) {
	group, groupCtx := errgroup.WithContext(context.Background())
	window := newTransformationWindow(groupCtx, group)
	trCtx := &utils.TransformerContext{
		Transformer: &testTransformer{},
	}
	tbl := getTable()

	firstAdd := window.tryAdd(tbl, trCtx)
	require.True(t, firstAdd)

	secondAdd := window.tryAdd(tbl, trCtx)
	require.False(t, secondAdd)
}

// TestTransformationWindow_Transform sends one record through a window that
// holds a single transformer with an empty when condition. It expects the raw
// "id" column to be non-null and equal to "2" afterwards, and the errgroup to
// finish without error once the window is closed.
func TestTransformationWindow_Transform(t *testing.T) {
	rootCtx, rootCancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer rootCancel()
	group, groupCtx := errgroup.WithContext(rootCtx)
	window := newTransformationWindow(groupCtx, group)

	cond, condWarns := toolkit.NewWhenCond("", nil, nil)
	require.Empty(t, condWarns)
	trCtx := &utils.TransformerContext{
		Transformer: &testTransformer{},
		When:        cond,
	}
	tbl := getTable()
	require.True(t, window.tryAdd(tbl, trCtx))

	rec := toolkit.NewRecord(getDriver(tbl.Table))
	rec.SetRow(testutils.NewTestRowDriver([]string{"1", "2023-08-27 00:00:00.000000"}))
	window.init()

	// The transformation itself gets a tighter deadline than the errgroup.
	callCtx, callCancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer callCancel()
	_, err := window.Transform(callCtx, rec)
	require.NoError(t, err)

	idValue, err := rec.GetRawColumnValueByName("id")
	require.NoError(t, err)
	require.False(t, idValue.IsNull)
	require.Equal(t, []byte("2"), idValue.Data)

	window.close()
	require.NoError(t, group.Wait())
}

// TestTransformationWindow_Transform_with_cond gives the transformer the when
// condition "record.id != 1" and transforms a record whose id is 1. It expects
// zero Transform calls on the stub and the raw "id" column to still be "1".
func TestTransformationWindow_Transform_with_cond(t *testing.T) {
	tbl := getTable()
	drv := getDriver(tbl.Table)
	rec := toolkit.NewRecord(drv)
	cond, condWarns := toolkit.NewWhenCond("record.id != 1", drv, make(map[string]any))
	require.Empty(t, condWarns)

	rootCtx, rootCancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer rootCancel()
	group, groupCtx := errgroup.WithContext(rootCtx)
	window := newTransformationWindow(groupCtx, group)

	stub := &testTransformer{}
	trCtx := &utils.TransformerContext{
		Transformer: stub,
		When:        cond,
	}
	require.True(t, window.tryAdd(tbl, trCtx))

	rec.SetRow(testutils.NewTestRowDriver([]string{"1", "2023-08-27 00:00:00.000000"}))
	window.init()

	// The transformation itself gets a tighter deadline than the errgroup.
	callCtx, callCancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer callCancel()
	_, err := window.Transform(callCtx, rec)
	require.NoError(t, err)
	require.Equal(t, 0, stub.callsCount)

	idValue, err := rec.GetRawColumnValueByName("id")
	require.NoError(t, err)
	require.False(t, idValue.IsNull)
	require.Equal(t, []byte("1"), idValue.Data)

	window.close()
	require.NoError(t, group.Wait())
}

// testTransformer is a test double that counts how many times Transform was
// invoked and sets the "id" column of every record it receives to 2.
type testTransformer struct {
	// callsCount is incremented on every Transform call.
	callsCount int
}

// Init does nothing and always returns nil.
func (*testTransformer) Init(_ context.Context) error { return nil }

// Done does nothing and always returns nil.
func (*testTransformer) Done(_ context.Context) error { return nil }

// Transform bumps callsCount, writes 2 into the record's "id" column and
// returns the same record. If the column cannot be set it returns a nil
// record together with that error.
func (tr *testTransformer) Transform(_ context.Context, r *toolkit.Record) (*toolkit.Record, error) {
	tr.callsCount++
	if setErr := r.SetColumnValueByName("id", 2); setErr != nil {
		return nil, setErr
	}
	return r, nil
}

// GetAffectedColumns reports one fixed entry: key 1 mapped to "name".
// NOTE(review): Transform modifies "id", and the table built by getTable has
// no "name" column — presumably carried over from another test; confirm
// whether this should describe "id" instead.
func (*testTransformer) GetAffectedColumns() map[int]string {
	affected := make(map[int]string, 1)
	affected[1] = "name"
	return affected
}

// getDriver builds a toolkit.Driver for table, passing nil as the second
// NewDriver argument and discarding NewDriver's second return value. It
// panics with the error text if the driver cannot be created.
func getDriver(table *toolkit.Table) *toolkit.Driver {
	drv, _, buildErr := toolkit.NewDriver(table, nil)
	if buildErr != nil {
		panic(buildErr.Error())
	}
	return drv
}

// getTable returns the table fixture shared by the tests: public.test with a
// NOT NULL int2 column "id" and a NOT NULL timestamp column "created_at", and
// no constraints.
func getTable() *entries.Table {
	return &entries.Table{
		Table: &toolkit.Table{
			Schema: "public",
			Name:   "test",
			Oid:    1224,
			Columns: []*toolkit.Column{
				{
					Name:     "id",
					TypeName: "int2",
					TypeOid:  pgtype.Int2OID,
					Num:      1,
					NotNull:  true,
					Length:   -1,
				},
				{
					Name:     "created_at",
					TypeName: "timestamp",
					TypeOid:  pgtype.TimestampOID,
					// Second column gets its own number; it previously
					// duplicated the "id" column's Num of 1.
					// NOTE(review): assumes Num mirrors pg_attribute.attnum,
					// which is unique per table — confirm.
					Num:     2,
					NotNull: true,
					Length:  -1,
				},
			},
			Constraints: []toolkit.Constraint{},
		},
	}
}
6 changes: 3 additions & 3 deletions internal/db/postgres/transformers/utils/definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,17 @@ func (d *TransformerDefinition) Instance(
Transformer: t,
StaticParameters: staticParams,
DynamicParameters: dynamicParams,
when: when,
When: when,
}, res, nil
}

// TransformerContext groups a transformer with its static parameters, its
// dynamic parameters and the when condition that EvaluateWhen evaluates.
// NOTE(review): this view appears to interleave the removed (`when`) and added
// (`When`) sides of a field rename; only one of the two is expected in the
// real file — confirm.
type TransformerContext struct {
Transformer Transformer
StaticParameters map[string]*toolkit.StaticParameter
DynamicParameters map[string]*toolkit.DynamicParameter
when *toolkit.WhenCond
When *toolkit.WhenCond
}

// EvaluateWhen evaluates the context's when condition against record r and
// returns the condition's result and error unchanged.
// NOTE(review): two return statements are shown because this view appears to
// interleave the removed (tc.when) and added (tc.When) sides of a field
// rename; only one is expected in the real file — confirm.
func (tc *TransformerContext) EvaluateWhen(r *toolkit.Record) (bool, error) {
return tc.when.Evaluate(r)
return tc.When.Evaluate(r)
}
4 changes: 1 addition & 3 deletions pkg/toolkit/expt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@ import (
func TestWhenCond_Evaluate(t *testing.T) {
driver := getDriver()
record := NewRecord(driver)
row := &TestRowDriver{
row: []string{"1", "2023-08-27 00:00:00.000000", testNullSeq, `{"a": 1}`, "123.0"},
}
row := newTestRowDriver([]string{"1", "2023-08-27 00:00:00.000000", testNullSeq, `{"a": 1}`, "123.0"})
record.SetRow(row)

type test struct {
Expand Down
Loading

0 comments on commit 378a420

Please sign in to comment.