diff --git a/pkg/toolkit/expr.go b/pkg/toolkit/expr.go index 09d7f0bb..8c356764 100644 --- a/pkg/toolkit/expr.go +++ b/pkg/toolkit/expr.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/expr-lang/expr" + "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/vm" ) @@ -38,7 +39,7 @@ func (wc *WhenCond) Evaluate(r *Record) (bool, error) { } wc.rc.SetRecord(r) - output, err := expr.Run(wc.whenCond, nil) + output, err := expr.Run(wc.whenCond, map[string]any{"null": NullValue}) if err != nil { return false, fmt.Errorf("unable to evaluate when condition: %w", err) } @@ -57,9 +58,10 @@ func compileCond(whenCond string, driver *Driver) (*vm.Program, *RecordContext, if whenCond == "" { return nil, nil, nil } - rc, funcs := newRecordContext(driver) + rc, ops := newRecordContext(driver) + ops = append(ops, expr.Patch(exprPatcher{})) - cond, err := expr.Compile(whenCond, funcs...) + cond, err := expr.Compile(whenCond, ops...) if err != nil { return nil, nil, ValidationWarnings{ NewValidationWarning(). @@ -87,6 +89,78 @@ func newRecordContext(driver *Driver) (*RecordContext, []expr.Option) { ) funcs = append(funcs, f) } - return rctx, funcs } + +func valIsNotNull(params ...any) (any, error) { + return !valueIsNull(params[0]), nil +} + +func valIsNull(params ...any) (any, error) { + vv, ok := params[0].(NullType) + if !ok { + return false, nil + } + return vv == NullValue, nil +} + +// exprPatcher - patcher for the expression compiler. It patches the expression tree by some identifiers to +// function calls. For instance is null, is not null, records address +type exprPatcher struct{} + +func (exprPatcher) Visit(node *ast.Node) { + switch { + case isNullOp(node): + case isNotNullOp(node): + case isRecordOp(node): + patchRecordOp(node) + } +} + +func isNullOp(node *ast.Node) bool { + _, ok := (*node).(*ast.IdentifierNode) + if !ok { + return false + } + return false +} + +func isNotNullOp(node *ast.Node) bool { + _, ok := (*node).(*ast.IdentifierNode) + if !ok { + return false + } + return false +} + +func isRecordOp(node *ast.Node) bool { + mn, ok := (*node).(*ast.MemberNode) + if !ok { + return false + } + owner, ok := (mn.Node).(*ast.IdentifierNode) + if !ok { + return false + } + _, ok = (mn.Property).(*ast.StringNode) + if !ok { + return false + } + return owner.Value == "record" +} + +func patchRecordOp(node *ast.Node) { + mn, ok := (*node).(*ast.MemberNode) + if !ok { + return + } + attr, ok := (mn.Property).(*ast.StringNode) + if !ok { + return + } + ast.Patch(node, &ast.CallNode{ + Callee: &ast.IdentifierNode{ + Value: fmt.Sprintf("__%s", attr.Value), + }, + }) +}