diff --git a/internal/db/postgres/cmd/validate.go b/internal/db/postgres/cmd/validate.go index 252866c3..7a7a61da 100644 --- a/internal/db/postgres/cmd/validate.go +++ b/internal/db/postgres/cmd/validate.go @@ -130,9 +130,15 @@ func (v *Validate) Run(ctx context.Context) (int, error) { return nonZeroExitCode, fmt.Errorf("unable to build runtime context: %w", err) } - if err = v.printValidationWarnings(); err != nil { + err = toolkit.PrintValidationWarnings( + v.context.Warnings, v.config.Validate.ResolvedWarnings, v.config.Validate.Warnings, + ) + if err != nil { return nonZeroExitCode, err } + if v.context.IsFatal() { + return nonZeroExitCode, fmt.Errorf("fatal validation error") + } if err = v.diffWithPreviousSchema(ctx); err != nil { return nonZeroExitCode, err @@ -280,34 +286,6 @@ func (v *Validate) createDocument(ctx context.Context, t *entries.Table) (valida return doc, nil } -func (v *Validate) printValidationWarnings() error { - // TODO: Implement warnings hook, such as logging and HTTP sender - for _, w := range v.context.Warnings { - w.MakeHash() - if idx := slices.Index(v.config.Validate.ResolvedWarnings, w.Hash); idx != -1 { - log.Debug().Str("hash", w.Hash).Msg("resolved warning has been excluded") - if w.Severity == toolkit.ErrorValidationSeverity { - return fmt.Errorf("warning with hash %s cannot be excluded because it is an error", w.Hash) - } - continue - } - - if w.Severity == toolkit.ErrorValidationSeverity { - // The warnings with error severity must be printed anyway - log.Error().Any("ValidationWarning", w).Msg("") - } else { - // Print warnings with severity level lower than ErrorValidationSeverity only if requested - if v.config.Validate.Warnings { - log.Warn().Any("ValidationWarning", w).Msg("") - } - } - } - if v.context.IsFatal() { - return fmt.Errorf("fatal validation error") - } - return nil -} - func (v *Validate) getTablesToValidate() ([]*domains.Table, error) { var tablesToValidate []*domains.Table for _, tv := range v.config.Validate.Tables { diff --git a/internal/db/postgres/context/table.go b/internal/db/postgres/context/table.go index 808b810d..f61c312f 100644 --- a/internal/db/postgres/context/table.go +++ b/internal/db/postgres/context/table.go @@ -155,6 +155,7 @@ func validateAndBuildTablesConfig( func getTable(ctx context.Context, tx pgx.Tx, t *domains.Table) ([]*entries.Table, toolkit.ValidationWarnings, error) { table := &entries.Table{ Table: &toolkit.Table{}, + When: t.When, } var warnings toolkit.ValidationWarnings var tables []*entries.Table @@ -204,6 +205,7 @@ func getTable(ctx context.Context, tx pgx.Tx, t *domains.Table) ([]*entries.Tabl RootPtSchema: table.Schema, RootPtName: table.Name, RootOid: table.Oid, + When: table.When, } if err = rows.Scan(&pt.Oid, &pt.Schema, &pt.Name); err != nil { return nil, nil, fmt.Errorf("error scanning TableGetChildPatsQuery: %w", err) diff --git a/internal/db/postgres/context/transformers.go b/internal/db/postgres/context/transformers.go index b32d53d0..dc6739cd 100644 --- a/internal/db/postgres/context/transformers.go +++ b/internal/db/postgres/context/transformers.go @@ -28,18 +28,17 @@ func initTransformer( c *domains.TransformerConfig, r *transformersUtils.TransformerRegistry, ) (*transformersUtils.TransformerContext, toolkit.ValidationWarnings, error) { - // TODO: Create var totalWarnings toolkit.ValidationWarnings td, ok := r.Get(c.Name) if !ok { totalWarnings = append(totalWarnings, toolkit.NewValidationWarning(). SetMsg("transformer not found"). - SetSeverity(toolkit.ErrorValidationSeverity).SetTrace(&toolkit.Trace{ - SchemaName: d.Table.Schema, - TableName: d.Table.Name, - TransformerName: c.Name, - })) + AddMeta("SchemaName", d.Table.Schema). + AddMeta("TableName", d.Table.Name). + AddMeta("TransformerName", c.Name). + SetSeverity(toolkit.ErrorValidationSeverity), + ) return nil, totalWarnings, nil } transformer, warnings, err := td.Instance(ctx, d, c.Params, c.DynamicParams, c.When) diff --git a/internal/db/postgres/dumpers/transformation_pipeline.go b/internal/db/postgres/dumpers/transformation_pipeline.go index 0d185018..7c660dc4 100644 --- a/internal/db/postgres/dumpers/transformation_pipeline.go +++ b/internal/db/postgres/dumpers/transformation_pipeline.go @@ -46,6 +46,8 @@ type TransformationPipeline struct { Transform TransformationFunc isAsync bool record *toolkit.Record + // when - table level when condition + when *toolkit.WhenCond } func NewTransformationPipeline(ctx context.Context, eg *errgroup.Group, table *entries.Table, w io.Writer) (*TransformationPipeline, error) { @@ -96,6 +98,15 @@ func NewTransformationPipeline(ctx context.Context, eg *errgroup.Group, table *e } tp.Transform = tf + whenCond, warnings := toolkit.NewWhenCond(table.When, table.Driver) + if err := toolkit.PrintValidationWarnings(warnings, nil, true); err != nil { + return nil, err + } + if warnings.IsFatal() { + return nil, fmt.Errorf("unable to compile when condition: fatal error") + } + tp.when = whenCond + return tp, nil } @@ -131,10 +142,8 @@ func (tp *TransformationPipeline) Init(ctx context.Context) error { } func (tp *TransformationPipeline) TransformSync(ctx context.Context, r *toolkit.Record) (*toolkit.Record, error) { - var err error - var needTransform bool for _, t := range tp.table.TransformersContext { - needTransform, err = t.EvaluateWhen() + needTransform, err := t.EvaluateWhen(r) if err != nil { return nil, NewDumpError(tp.table.Schema, tp.table.Name, tp.line, fmt.Errorf("error evaluating when condition: %w", err)) } @@ -167,10 +176,18 @@ func (tp *TransformationPipeline) Dump(ctx context.Context, data []byte) (err er } tp.record.SetRow(tp.row) - _, err = tp.Transform(ctx, tp.record) + needTransform, err := tp.when.Evaluate(tp.record) if err != nil { - return NewDumpError(tp.table.Schema, tp.table.Name, tp.line, err) + return NewDumpError(tp.table.Schema, tp.table.Name, tp.line, fmt.Errorf("error evaluating when condition: %w", err)) } + + if needTransform { + _, err = tp.Transform(ctx, tp.record) + if err != nil { + return NewDumpError(tp.table.Schema, tp.table.Name, tp.line, err) + } + } + rowDriver, err := tp.record.Encode() if err != nil { return NewDumpError(tp.table.Schema, tp.table.Name, tp.line, fmt.Errorf("error enocding to RowDriver: %w", err)) diff --git a/internal/db/postgres/entries/table.go b/internal/db/postgres/entries/table.go index 629fb4a9..ef5d9db9 100644 --- a/internal/db/postgres/entries/table.go +++ b/internal/db/postgres/entries/table.go @@ -46,6 +46,7 @@ type Table struct { Driver *toolkit.Driver Scores int64 SubsetConds []string + When string } func (t *Table) HasCustomTransformer() bool { diff --git a/internal/db/postgres/transformers/random_person_test.go b/internal/db/postgres/transformers/random_person_test.go index 8165c107..cf40040d 100644 --- a/internal/db/postgres/transformers/random_person_test.go +++ b/internal/db/postgres/transformers/random_person_test.go @@ -6,11 +6,12 @@ import ( "strings" "testing" + "github.com/rs/zerolog/log" + "github.com/stretchr/testify/require" + "github.com/greenmaskio/greenmask/internal/db/postgres/transformers/utils" "github.com/greenmaskio/greenmask/internal/generators/transformers" "github.com/greenmaskio/greenmask/pkg/toolkit" - "github.com/rs/zerolog/log" - "github.com/stretchr/testify/require" ) func TestRandomPersonTransformer_Transform_static_fullname(t *testing.T) { @@ -153,6 +154,7 @@ func TestRandomPersonTransformer_Transform_static_nullable(t *testing.T) { driver, params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/utils/definition.go b/internal/db/postgres/transformers/utils/definition.go index 6db05e50..bb0f6c8c 100644 --- a/internal/db/postgres/transformers/utils/definition.go +++ b/internal/db/postgres/transformers/utils/definition.go @@ -18,9 +18,6 @@ import ( "context" "fmt" - "github.com/expr-lang/expr" - "github.com/expr-lang/expr/vm" - "github.com/greenmaskio/greenmask/pkg/toolkit" ) @@ -55,37 +52,6 @@ func (d *TransformerDefinition) SetSchemaValidator(v SchemaValidationFunc) *Tran return d } -type TransformerContext struct { - Transformer Transformer - StaticParameters map[string]*toolkit.StaticParameter - DynamicParameters map[string]*toolkit.DynamicParameter - whenCond *vm.Program - whenEnv expr.Option - rc *toolkit.RecordContext -} - -func (tc *TransformerContext) SetRecord(r *toolkit.Record) { - tc.rc.SetRecord(r) -} - -func (tc *TransformerContext) EvaluateWhen() (bool, error) { - if tc.whenCond == nil { - return true, nil - } - - output, err := expr.Run(tc.whenCond, nil) - if err != nil { - return false, fmt.Errorf("unable to evaluate when condition: %w", err) - } - - cond, ok := output.(bool) - if ok { - return cond, nil - } - - return false, fmt.Errorf("when condition should return boolean, got (%T) and value %+v", cond, cond) -} - func (d *TransformerDefinition) Instance( ctx context.Context, driver *toolkit.Driver, rawParams map[string]toolkit.ParamsValue, dynamicParameters map[string]*toolkit.DynamicParamValue, whenCond string, @@ -132,51 +98,24 @@ func (d *TransformerDefinition) Instance( res = append(res, schemaWarnings...) res = append(res, transformerWarnings...) - cond, rc, condWarns := compileCond(whenCond, driver) + when, condWarns := toolkit.NewWhenCond(whenCond, driver) res = append(res, condWarns...) return &TransformerContext{ Transformer: t, StaticParameters: staticParams, DynamicParameters: dynamicParams, - whenCond: cond, - rc: rc, + when: when, }, res, nil } -func compileCond(whenCond string, driver *toolkit.Driver) (*vm.Program, *toolkit.RecordContext, toolkit.ValidationWarnings) { - if whenCond == "" { - return nil, nil, nil - } - rc, funcs := newRecordContext(driver) - - cond, err := expr.Compile(whenCond, funcs...) - if err != nil { - return nil, nil, toolkit.ValidationWarnings{ - toolkit.NewValidationWarning(). - SetSeverity(toolkit.ErrorValidationSeverity). - AddMeta("Error", err.Error()). - SetMsg("unable to compile when condition"), - } - } - - return cond, rc, nil +type TransformerContext struct { + Transformer Transformer + StaticParameters map[string]*toolkit.StaticParameter + DynamicParameters map[string]*toolkit.DynamicParameter + when *toolkit.WhenCond } -func newRecordContext(driver *toolkit.Driver) (*toolkit.RecordContext, []expr.Option) { - var funcs []expr.Option - rctx := toolkit.NewRecordContext() - for _, c := range driver.Table.Columns { - - f := expr.Function( - c.Name, - func(name string) func(params ...any) (any, error) { - return func(params ...any) (any, error) { - return rctx.GetColumnRawValue(name) - } - }(c.Name), - ) - funcs = append(funcs, f) - } - return rctx, funcs +func (tc *TransformerContext) EvaluateWhen(r *toolkit.Record) (bool, error) { + return tc.when.Evaluate(r) } diff --git a/pkg/toolkit/expr.go b/pkg/toolkit/expr.go new file mode 100644 index 00000000..09d7f0bb --- /dev/null +++ b/pkg/toolkit/expr.go @@ -0,0 +1,92 @@ +package toolkit + +import ( + "fmt" + + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" +) + +type WhenCond struct { + rc *RecordContext + whenCond *vm.Program + when string +} + +func NewWhenCond(when string, driver *Driver) (*WhenCond, ValidationWarnings) { + var ( + rc *RecordContext + whenCond *vm.Program + warnings ValidationWarnings + ) + if when != "" { + whenCond, rc, warnings = compileCond(when, driver) + if warnings.IsFatal() { + return nil, warnings + } + } + return &WhenCond{ + rc: rc, + whenCond: whenCond, + when: when, + }, nil +} + +func (wc *WhenCond) Evaluate(r *Record) (bool, error) { + if wc.whenCond == nil { + return true, nil + } + wc.rc.SetRecord(r) + + output, err := expr.Run(wc.whenCond, nil) + if err != nil { + return false, fmt.Errorf("unable to evaluate when condition: %w", err) + } + + cond, ok := output.(bool) + if ok { + return cond, nil + } + + return false, fmt.Errorf("when condition should return boolean, got (%T) and value %+v", cond, cond) +} + +// compileCond compiles the when condition and returns the compiled program and the record context +// with the functions for the columns. The functions represent the column names and return the column values. +func compileCond(whenCond string, driver *Driver) (*vm.Program, *RecordContext, ValidationWarnings) { + if whenCond == "" { + return nil, nil, nil + } + rc, funcs := newRecordContext(driver) + + cond, err := expr.Compile(whenCond, funcs...) + if err != nil { + return nil, nil, ValidationWarnings{ + NewValidationWarning(). + SetSeverity(ErrorValidationSeverity). + AddMeta("Error", err.Error()). + SetMsg("unable to compile when condition"), + } + } + + return cond, rc, nil +} + +func newRecordContext(driver *Driver) (*RecordContext, []expr.Option) { + var funcs []expr.Option + rctx := NewRecordContext() + for _, c := range driver.Table.Columns { + + f := expr.Function( + fmt.Sprintf("__%s", c.Name), + func(name string) func(params ...any) (any, error) { + return func(params ...any) (any, error) { + return rctx.GetColumnRawValue(name) + } + }(c.Name), + ) + funcs = append(funcs, f) + } + + return rctx, funcs +} diff --git a/pkg/toolkit/validation_warning.go b/pkg/toolkit/validation_warning.go index e35db8e9..d48cc547 100644 --- a/pkg/toolkit/validation_warning.go +++ b/pkg/toolkit/validation_warning.go @@ -19,6 +19,8 @@ import ( "encoding/hex" "fmt" "slices" + + "github.com/rs/zerolog/log" ) const ( @@ -28,15 +30,6 @@ const ( DebugValidationSeverity = "debug" ) -// deprecated -type Trace struct { - SchemaName string `json:"schemaName,omitempty"` - TableName string `json:"tableName,omitempty"` - TransformerName string `json:"transformerName,omitempty"` - ParameterName string `json:"parameterName,omitempty"` - Msg string `json:"msg,omitempty"` -} - type ValidationWarnings []*ValidationWarning func (re ValidationWarnings) IsFatal() bool { @@ -48,7 +41,6 @@ func (re ValidationWarnings) IsFatal() bool { type ValidationWarning struct { Msg string `json:"msg,omitempty"` Severity string `json:"severity,omitempty"` - Trace *Trace `json:"trace,omitempty"` Meta map[string]any `json:"meta,omitempty"` Hash string `json:"hash"` } @@ -80,11 +72,6 @@ func (re *ValidationWarning) AddMeta(key string, value any) *ValidationWarning { return re } -func (re *ValidationWarning) SetTrace(value *Trace) *ValidationWarning { - re.Trace = value - return re -} - func (re *ValidationWarning) MakeHash() { var meta string keys := make([]string, 0, len(re.Meta)) @@ -102,3 +89,28 @@ func (re *ValidationWarning) MakeHash() { hash := md5.Sum([]byte(signature)) re.Hash = hex.EncodeToString(hash[:]) } + +func PrintValidationWarnings(warns ValidationWarnings, resolvedWarnings []string, printAll bool) error { + // TODO: Implement warnings hook, such as logging and HTTP sender + for _, w := range warns { + w.MakeHash() + if idx := slices.Index(resolvedWarnings, w.Hash); idx != -1 { + log.Debug().Str("hash", w.Hash).Msg("resolved warning has been excluded") + if w.Severity == ErrorValidationSeverity { + return fmt.Errorf("warning with hash %s cannot be excluded because it is an error", w.Hash) + } + continue + } + + if w.Severity == ErrorValidationSeverity { + // The warnings with error severity must be printed anyway + log.Error().Any("ValidationWarning", w).Msg("") + } else { + // Print warnings with severity level lower than ErrorValidationSeverity only if requested + if printAll { + log.Warn().Any("ValidationWarning", w).Msg("") + } + } + } + return nil +}