From 0adb9f2021df9d01b4ecb5eee9a5d7f69a129a4b Mon Sep 17 00:00:00 2001 From: Vadim Voitenko Date: Sat, 26 Oct 2024 16:25:57 +0300 Subject: [PATCH] feat: implemented expression patcher * users can access simply to the record using record.column_name to get the driver encoded value or raw_record.column_name to get a raw string record --- pkg/toolkit/expr.go | 90 ++++++++++++++++++++++------------------ pkg/toolkit/expt_test.go | 54 ++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 40 deletions(-) create mode 100644 pkg/toolkit/expt_test.go diff --git a/pkg/toolkit/expr.go b/pkg/toolkit/expr.go index 8c356764..8582686f 100644 --- a/pkg/toolkit/expr.go +++ b/pkg/toolkit/expr.go @@ -1,3 +1,6 @@ +// Ane expression handler for the toolkit package. It is used to evaluate the when condition of the record. +// Might be used in transformation conditions and other places where the record is used. + package toolkit import ( @@ -8,12 +11,16 @@ import ( "github.com/expr-lang/expr/vm" ) +// WhenCond - A condition that should be evaluated to determine if the record should be processed. type WhenCond struct { rc *RecordContext whenCond *vm.Program when string } +// NewWhenCond - creates a new WhenCond object. It compiles the when condition and returns the compiled program +// and the record context with the functions for the columns. The functions represent the column names and return the +// column values. If the when condition is empty, the WhenCond object will always return true. func NewWhenCond(when string, driver *Driver) (*WhenCond, ValidationWarnings) { var ( rc *RecordContext @@ -33,6 +40,7 @@ func NewWhenCond(when string, driver *Driver) (*WhenCond, ValidationWarnings) { }, nil } +// Evaluate - evaluates the when condition. If the when condition is empty, it will always return true. func (wc *WhenCond) Evaluate(r *Record) (bool, error) { if wc.whenCond == nil { return true, nil @@ -74,65 +82,50 @@ func compileCond(whenCond string, driver *Driver) (*vm.Program, *RecordContext, return cond, rc, nil } +// newRecordContext creates a new record context and create kind of column descriptors for the record to access the +// column values by the column name. For instance if the column name is "name", the function __name will return +// the value func newRecordContext(driver *Driver) (*RecordContext, []expr.Option) { var funcs []expr.Option rctx := NewRecordContext() for _, c := range driver.Table.Columns { - f := expr.Function( + // create a function that returns the column value by the column name. The returned value is encoded using + // pgx driver + typedFunc := expr.Function( fmt.Sprintf("__%s", c.Name), + func(name string) func(params ...any) (any, error) { + return func(params ...any) (any, error) { + return rctx.GetColumnValue(name) + } + }(c.Name), + ) + funcs = append(funcs, typedFunc) + + rawFunc := expr.Function( + fmt.Sprintf("__raw__%s", c.Name), func(name string) func(params ...any) (any, error) { return func(params ...any) (any, error) { return rctx.GetColumnRawValue(name) } }(c.Name), ) - funcs = append(funcs, f) + funcs = append(funcs, rawFunc) } return rctx, funcs } -func valIsNotNull(params ...any) (any, error) { - return !valueIsNull(params[0]), nil -} - -func valIsNull(params ...any) (any, error) { - vv, ok := params[0].(NullType) - if !ok { - return false, nil - } - return vv == NullValue, nil -} - // exprPatcher - patcher for the expression compiler. It patches the expression tree by some identifiers to // function calls. For instance is null, is not null, records address type exprPatcher struct{} func (exprPatcher) Visit(node *ast.Node) { - switch { - case isNullOp(node): - case isNotNullOp(node): - case isRecordOp(node): + if isRecordOp(node) { patchRecordOp(node) } } -func isNullOp(node *ast.Node) bool { - _, ok := (*node).(*ast.IdentifierNode) - if !ok { - return false - } - return false -} - -func isNotNullOp(node *ast.Node) bool { - _, ok := (*node).(*ast.IdentifierNode) - if !ok { - return false - } - return false -} - +// isRecordOp checks if the node is a record operation func isRecordOp(node *ast.Node) bool { mn, ok := (*node).(*ast.MemberNode) if !ok { @@ -146,21 +139,38 @@ func isRecordOp(node *ast.Node) bool { if !ok { return false } - return owner.Value == "record" + return owner.Value == "record" || owner.Value == "raw_record" } +// patchRecordOp patches the record access operation +// 1. record.id -> __id() function call for decoding the column value into type using pgx driver +// 2. raw_record.id -> __raw_id() function call getting a raw value as a string func patchRecordOp(node *ast.Node) { mn, ok := (*node).(*ast.MemberNode) if !ok { return } + owner, ok := (mn.Node).(*ast.IdentifierNode) + if !ok { + return + } attr, ok := (mn.Property).(*ast.StringNode) if !ok { return } - ast.Patch(node, &ast.CallNode{ - Callee: &ast.IdentifierNode{ - Value: fmt.Sprintf("__%s", attr.Value), - }, - }) + switch owner.Value { + case "record": + ast.Patch(node, &ast.CallNode{ + Callee: &ast.IdentifierNode{ + Value: fmt.Sprintf("__%s", attr.Value), + }, + }) + case "raw_record": + ast.Patch(node, &ast.CallNode{ + Callee: &ast.IdentifierNode{ + Value: fmt.Sprintf("__raw__%s", attr.Value), + }, + }) + } + } diff --git a/pkg/toolkit/expt_test.go b/pkg/toolkit/expt_test.go new file mode 100644 index 00000000..4a456d18 --- /dev/null +++ b/pkg/toolkit/expt_test.go @@ -0,0 +1,54 @@ +package toolkit + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWhenCond_Evaluate(t *testing.T) { + driver := getDriver() + record := NewRecord(driver) + row := &TestRowDriver{ + row: []string{"1", "2023-08-27 00:00:00.000000", "\\N"}, + } + record.SetRow(row) + + type test struct { + name string + when string + expected bool + } + tests := []test{ + { + name: "int value equal", + when: "record.id == 1", + expected: true, + }, + { + name: "raw int value equal", + when: "raw_record.id == \"1\"", + expected: true, + }, + { + name: "is null value check", + when: "record.title == null", + expected: true, + }, + { + name: "test date cmp", + when: "record.created_at > now()", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + whenCond, warns := NewWhenCond(tt.when, driver) + require.Empty(t, warns) + res, err := whenCond.Evaluate(record) + require.NoError(t, err) + require.Equal(t, tt.expected, res) + }) + } +}