Skip to content

Commit

Permalink
Revised transformer conds implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
wwoytenko committed Oct 25, 2024
1 parent 66accc5 commit 3b3694a
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 127 deletions.
36 changes: 7 additions & 29 deletions internal/db/postgres/cmd/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,15 @@ func (v *Validate) Run(ctx context.Context) (int, error) {
return nonZeroExitCode, fmt.Errorf("unable to build runtime context: %w", err)
}

if err = v.printValidationWarnings(); err != nil {
err = toolkit.PrintValidationWarnings(
v.context.Warnings, v.config.Validate.ResolvedWarnings, v.config.Validate.Warnings,
)
if err != nil {
return nonZeroExitCode, err
}
if v.context.IsFatal() {
return nonZeroExitCode, fmt.Errorf("fatal validation error")
}

if err = v.diffWithPreviousSchema(ctx); err != nil {
return nonZeroExitCode, err
Expand Down Expand Up @@ -280,34 +286,6 @@ func (v *Validate) createDocument(ctx context.Context, t *entries.Table) (valida
return doc, nil
}

func (v *Validate) printValidationWarnings() error {
// TODO: Implement warnings hook, such as logging and HTTP sender
for _, w := range v.context.Warnings {
w.MakeHash()
if idx := slices.Index(v.config.Validate.ResolvedWarnings, w.Hash); idx != -1 {
log.Debug().Str("hash", w.Hash).Msg("resolved warning has been excluded")
if w.Severity == toolkit.ErrorValidationSeverity {
return fmt.Errorf("warning with hash %s cannot be excluded because it is an error", w.Hash)
}
continue
}

if w.Severity == toolkit.ErrorValidationSeverity {
// The warnings with error severity must be printed anyway
log.Error().Any("ValidationWarning", w).Msg("")
} else {
// Print warnings with severity level lower than ErrorValidationSeverity only if requested
if v.config.Validate.Warnings {
log.Warn().Any("ValidationWarning", w).Msg("")
}
}
}
if v.context.IsFatal() {
return fmt.Errorf("fatal validation error")
}
return nil
}

func (v *Validate) getTablesToValidate() ([]*domains.Table, error) {
var tablesToValidate []*domains.Table
for _, tv := range v.config.Validate.Tables {
Expand Down
2 changes: 2 additions & 0 deletions internal/db/postgres/context/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ func validateAndBuildTablesConfig(
func getTable(ctx context.Context, tx pgx.Tx, t *domains.Table) ([]*entries.Table, toolkit.ValidationWarnings, error) {
table := &entries.Table{
Table: &toolkit.Table{},
When: t.When,
}
var warnings toolkit.ValidationWarnings
var tables []*entries.Table
Expand Down Expand Up @@ -204,6 +205,7 @@ func getTable(ctx context.Context, tx pgx.Tx, t *domains.Table) ([]*entries.Tabl
RootPtSchema: table.Schema,
RootPtName: table.Name,
RootOid: table.Oid,
When: table.When,
}
if err = rows.Scan(&pt.Oid, &pt.Schema, &pt.Name); err != nil {
return nil, nil, fmt.Errorf("error scanning TableGetChildPatsQuery: %w", err)
Expand Down
11 changes: 5 additions & 6 deletions internal/db/postgres/context/transformers.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,17 @@ func initTransformer(
c *domains.TransformerConfig,
r *transformersUtils.TransformerRegistry,
) (*transformersUtils.TransformerContext, toolkit.ValidationWarnings, error) {
// TODO: Create
var totalWarnings toolkit.ValidationWarnings
td, ok := r.Get(c.Name)
if !ok {
totalWarnings = append(totalWarnings,
toolkit.NewValidationWarning().
SetMsg("transformer not found").
SetSeverity(toolkit.ErrorValidationSeverity).SetTrace(&toolkit.Trace{
SchemaName: d.Table.Schema,
TableName: d.Table.Name,
TransformerName: c.Name,
}))
AddMeta("SchemaName", d.Table.Schema).
AddMeta("TableName", d.Table.Name).
AddMeta("TransformerName", c.Name).
SetSeverity(toolkit.ErrorValidationSeverity),
)
return nil, totalWarnings, nil
}
transformer, warnings, err := td.Instance(ctx, d, c.Params, c.DynamicParams, c.When)
Expand Down
27 changes: 22 additions & 5 deletions internal/db/postgres/dumpers/transformation_pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ type TransformationPipeline struct {
Transform TransformationFunc
isAsync bool
record *toolkit.Record
// when - table level when condition
when *toolkit.WhenCond
}

func NewTransformationPipeline(ctx context.Context, eg *errgroup.Group, table *entries.Table, w io.Writer) (*TransformationPipeline, error) {
Expand Down Expand Up @@ -96,6 +98,15 @@ func NewTransformationPipeline(ctx context.Context, eg *errgroup.Group, table *e
}
tp.Transform = tf

whenCond, warnings := toolkit.NewWhenCond(table.When, table.Driver)
if err := toolkit.PrintValidationWarnings(warnings, nil, true); err != nil {
return nil, err
}
if warnings.IsFatal() {
return nil, fmt.Errorf("unable to compile when condition: fatal error")
}
tp.when = whenCond

return tp, nil
}

Expand Down Expand Up @@ -131,10 +142,8 @@ func (tp *TransformationPipeline) Init(ctx context.Context) error {
}

func (tp *TransformationPipeline) TransformSync(ctx context.Context, r *toolkit.Record) (*toolkit.Record, error) {
var err error
var needTransform bool
for _, t := range tp.table.TransformersContext {
needTransform, err = t.EvaluateWhen()
needTransform, err := t.EvaluateWhen(r)
if err != nil {
return nil, NewDumpError(tp.table.Schema, tp.table.Name, tp.line, fmt.Errorf("error evaluating when condition: %w", err))
}
Expand Down Expand Up @@ -167,10 +176,18 @@ func (tp *TransformationPipeline) Dump(ctx context.Context, data []byte) (err er
}
tp.record.SetRow(tp.row)

_, err = tp.Transform(ctx, tp.record)
needTransform, err := tp.when.Evaluate(tp.record)
if err != nil {
return NewDumpError(tp.table.Schema, tp.table.Name, tp.line, err)
return NewDumpError(tp.table.Schema, tp.table.Name, tp.line, fmt.Errorf("error evaluating when condition: %w", err))
}

if needTransform {
_, err = tp.Transform(ctx, tp.record)
if err != nil {
return NewDumpError(tp.table.Schema, tp.table.Name, tp.line, err)
}
}

rowDriver, err := tp.record.Encode()
if err != nil {
return NewDumpError(tp.table.Schema, tp.table.Name, tp.line, fmt.Errorf("error enocding to RowDriver: %w", err))
Expand Down
1 change: 1 addition & 0 deletions internal/db/postgres/entries/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type Table struct {
Driver *toolkit.Driver
Scores int64
SubsetConds []string
When string
}

func (t *Table) HasCustomTransformer() bool {
Expand Down
6 changes: 4 additions & 2 deletions internal/db/postgres/transformers/random_person_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ import (
"strings"
"testing"

"github.com/rs/zerolog/log"
"github.com/stretchr/testify/require"

"github.com/greenmaskio/greenmask/internal/db/postgres/transformers/utils"
"github.com/greenmaskio/greenmask/internal/generators/transformers"
"github.com/greenmaskio/greenmask/pkg/toolkit"
"github.com/rs/zerolog/log"
"github.com/stretchr/testify/require"
)

func TestRandomPersonTransformer_Transform_static_fullname(t *testing.T) {
Expand Down Expand Up @@ -153,6 +154,7 @@ func TestRandomPersonTransformer_Transform_static_nullable(t *testing.T) {
driver,
params,
nil,
"",
)
require.NoError(t, err)
require.Empty(t, warnings)
Expand Down
79 changes: 9 additions & 70 deletions internal/db/postgres/transformers/utils/definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ import (
"context"
"fmt"

"github.com/expr-lang/expr"
"github.com/expr-lang/expr/vm"

"github.com/greenmaskio/greenmask/pkg/toolkit"
)

Expand Down Expand Up @@ -55,37 +52,6 @@ func (d *TransformerDefinition) SetSchemaValidator(v SchemaValidationFunc) *Tran
return d
}

type TransformerContext struct {
Transformer Transformer
StaticParameters map[string]*toolkit.StaticParameter
DynamicParameters map[string]*toolkit.DynamicParameter
whenCond *vm.Program
whenEnv expr.Option
rc *toolkit.RecordContext
}

func (tc *TransformerContext) SetRecord(r *toolkit.Record) {
tc.rc.SetRecord(r)
}

func (tc *TransformerContext) EvaluateWhen() (bool, error) {
if tc.whenCond == nil {
return true, nil
}

output, err := expr.Run(tc.whenCond, nil)
if err != nil {
return false, fmt.Errorf("unable to evaluate when condition: %w", err)
}

cond, ok := output.(bool)
if ok {
return cond, nil
}

return false, fmt.Errorf("when condition should return boolean, got (%T) and value %+v", cond, cond)
}

func (d *TransformerDefinition) Instance(
ctx context.Context, driver *toolkit.Driver, rawParams map[string]toolkit.ParamsValue, dynamicParameters map[string]*toolkit.DynamicParamValue,
whenCond string,
Expand Down Expand Up @@ -132,51 +98,24 @@ func (d *TransformerDefinition) Instance(
res = append(res, schemaWarnings...)
res = append(res, transformerWarnings...)

cond, rc, condWarns := compileCond(whenCond, driver)
when, condWarns := toolkit.NewWhenCond(whenCond, driver)
res = append(res, condWarns...)

return &TransformerContext{
Transformer: t,
StaticParameters: staticParams,
DynamicParameters: dynamicParams,
whenCond: cond,
rc: rc,
when: when,
}, res, nil
}

func compileCond(whenCond string, driver *toolkit.Driver) (*vm.Program, *toolkit.RecordContext, toolkit.ValidationWarnings) {
if whenCond == "" {
return nil, nil, nil
}
rc, funcs := newRecordContext(driver)

cond, err := expr.Compile(whenCond, funcs...)
if err != nil {
return nil, nil, toolkit.ValidationWarnings{
toolkit.NewValidationWarning().
SetSeverity(toolkit.ErrorValidationSeverity).
AddMeta("Error", err.Error()).
SetMsg("unable to compile when condition"),
}
}

return cond, rc, nil
type TransformerContext struct {
Transformer Transformer
StaticParameters map[string]*toolkit.StaticParameter
DynamicParameters map[string]*toolkit.DynamicParameter
when *toolkit.WhenCond
}

func newRecordContext(driver *toolkit.Driver) (*toolkit.RecordContext, []expr.Option) {
var funcs []expr.Option
rctx := toolkit.NewRecordContext()
for _, c := range driver.Table.Columns {

f := expr.Function(
c.Name,
func(name string) func(params ...any) (any, error) {
return func(params ...any) (any, error) {
return rctx.GetColumnRawValue(name)
}
}(c.Name),
)
funcs = append(funcs, f)
}
return rctx, funcs
func (tc *TransformerContext) EvaluateWhen(r *toolkit.Record) (bool, error) {
return tc.when.Evaluate(r)
}
92 changes: 92 additions & 0 deletions pkg/toolkit/expr.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package toolkit

import (
"fmt"

"github.com/expr-lang/expr"
"github.com/expr-lang/expr/vm"
)

type WhenCond struct {
rc *RecordContext
whenCond *vm.Program
when string
}

func NewWhenCond(when string, driver *Driver) (*WhenCond, ValidationWarnings) {
var (
rc *RecordContext
whenCond *vm.Program
warnings ValidationWarnings
)
if when != "" {
whenCond, rc, warnings = compileCond(when, driver)
if warnings.IsFatal() {
return nil, warnings
}
}
return &WhenCond{
rc: rc,
whenCond: whenCond,
when: when,
}, nil
}

func (wc *WhenCond) Evaluate(r *Record) (bool, error) {
if wc.whenCond == nil {
return true, nil
}
wc.rc.SetRecord(r)

output, err := expr.Run(wc.whenCond, nil)
if err != nil {
return false, fmt.Errorf("unable to evaluate when condition: %w", err)
}

cond, ok := output.(bool)
if ok {
return cond, nil
}

return false, fmt.Errorf("when condition should return boolean, got (%T) and value %+v", cond, cond)
}

// compileCond compiles the when condition and returns the compiled program and the record context
// with the functions for the columns. The functions represent the column names and return the column values.
func compileCond(whenCond string, driver *Driver) (*vm.Program, *RecordContext, ValidationWarnings) {
if whenCond == "" {
return nil, nil, nil
}
rc, funcs := newRecordContext(driver)

cond, err := expr.Compile(whenCond, funcs...)
if err != nil {
return nil, nil, ValidationWarnings{
NewValidationWarning().
SetSeverity(ErrorValidationSeverity).
AddMeta("Error", err.Error()).
SetMsg("unable to compile when condition"),
}
}

return cond, rc, nil
}

func newRecordContext(driver *Driver) (*RecordContext, []expr.Option) {
var funcs []expr.Option
rctx := NewRecordContext()
for _, c := range driver.Table.Columns {

f := expr.Function(
fmt.Sprintf("__%s", c.Name),
func(name string) func(params ...any) (any, error) {
return func(params ...any) (any, error) {
return rctx.GetColumnRawValue(name)
}
}(c.Name),
)
funcs = append(funcs, f)
}

return rctx, funcs
}
Loading

0 comments on commit 3b3694a

Please sign in to comment.