From ddce74a03f54e88d1609e3ff54fc7a51ee35516c Mon Sep 17 00:00:00 2001 From: Brennan Lamey Date: Sun, 13 Oct 2024 18:32:09 -0500 Subject: [PATCH] updated for v0.10 --- core/types/schema.go | 33 +- core/types/transactions/payload_schema.go | 2 +- parse/antlr.go | 3 +- parse/ast.go | 2 + parse/common/datetime.go | 85 --- parse/common/datetime_test.go | 108 ---- parse/common/errors.go | 7 +- parse/common/functions.go | 425 --------------- parse/common/functions_test.go | 446 ---------------- parse/interpreter/README.md | 12 + parse/interpreter/benchmark_test.go | 3 +- parse/interpreter/interpreter_test.go | 9 +- .../{interpreter.go => planner.go} | 411 ++++++++------- parse/interpreter/types.go | 103 ++-- parse/{common => interpreter}/values.go | 498 ++++++++++++++---- 15 files changed, 696 insertions(+), 1451 deletions(-) delete mode 100644 parse/common/datetime.go delete mode 100644 parse/common/datetime_test.go create mode 100644 parse/interpreter/README.md rename parse/interpreter/{interpreter.go => planner.go} (71%) rename parse/{common => interpreter}/values.go (75%) diff --git a/core/types/schema.go b/core/types/schema.go index bb4033c54..7691b542d 100644 --- a/core/types/schema.go +++ b/core/types/schema.go @@ -1099,7 +1099,7 @@ type DataType struct { // IsArray is true if the type is an array. IsArray bool `json:"is_array"` // Metadata is the metadata of the type. - Metadata [2]uint16 `json:"metadata"` + Metadata *[2]uint16 `json:"metadata"` } // String returns the string representation of the type. @@ -1121,8 +1121,6 @@ func (c *DataType) String() string { return str.String() } -var ZeroMetadata = [2]uint16{} - // PGString returns the string representation of the type in Postgres. 
func (c *DataType) PGString() (string, error) { var scalar string @@ -1140,11 +1138,11 @@ func (c *DataType) PGString() (string, error) { case uint256Str: scalar = "UINT256" case DecimalStr: - if c.Metadata == ZeroMetadata { - return "", fmt.Errorf("decimal type must have metadata") + if c.Metadata == nil { + scalar = "NUMERIC" + } else { + scalar = fmt.Sprintf("NUMERIC(%d,%d)", c.Metadata[0], c.Metadata[1]) } - - scalar = fmt.Sprintf("NUMERIC(%d,%d)", c.Metadata[0], c.Metadata[1]) case nullStr: return "", fmt.Errorf("cannot have null column type") case unknownStr: @@ -1164,14 +1162,15 @@ func (c *DataType) Clean() error { c.Name = strings.ToLower(c.Name) switch c.Name { case intStr, textStr, boolStr, blobStr, uuidStr, uint256Str: // ok - if c.Metadata != ZeroMetadata { + if c.Metadata != nil { return fmt.Errorf("type %s cannot have metadata", c.Name) } return nil case DecimalStr: - if c.Metadata == ZeroMetadata { - return fmt.Errorf("decimal type must have metadata") + if c.Metadata == nil { + // numeric can have unspecified precision and scale + return nil } err := decimal.CheckPrecisionAndScale(c.Metadata[0], c.Metadata[1]) @@ -1185,7 +1184,7 @@ func (c *DataType) Clean() error { return fmt.Errorf("type %s cannot be an array", c.Name) } - if c.Metadata != ZeroMetadata { + if c.Metadata != nil { return fmt.Errorf("type %s cannot have metadata", c.Name) } @@ -1219,10 +1218,10 @@ func (c *DataType) EqualsStrict(other *DataType) bool { return false } - if (c.Metadata == ZeroMetadata) != (other.Metadata == ZeroMetadata) { + if (c.Metadata == nil) != (other.Metadata == nil) { return false } - if c.Metadata != ZeroMetadata { + if c.Metadata != nil { if c.Metadata[0] != other.Metadata[0] || c.Metadata[1] != other.Metadata[1] { return false } @@ -1275,8 +1274,8 @@ var ( // For type detection, users should prefer compare a datatype // name with the DecimalStr constant. 
DecimalType = &DataType{ - Name: DecimalStr, - Metadata: [2]uint16{1, 0}, // the minimum precision and scale + Name: DecimalStr, + // unspecified precision and scale } DecimalArrayType = ArrayType(DecimalType) Uint256Type = &DataType{ @@ -1332,8 +1331,10 @@ func NewDecimalType(precision, scale uint16) (*DataType, error) { return nil, err } + met := [2]uint16{precision, scale} + return &DataType{ Name: DecimalStr, - Metadata: [2]uint16{precision, scale}, + Metadata: &met, }, nil } diff --git a/core/types/transactions/payload_schema.go b/core/types/transactions/payload_schema.go index 2a199bfc6..685e758bb 100644 --- a/core/types/transactions/payload_schema.go +++ b/core/types/transactions/payload_schema.go @@ -179,7 +179,7 @@ type DataType struct { // IsArray is true if the type is an array. IsArray bool // Metadata is the metadata of the type. - Metadata [2]uint16 `rlp:"optional"` + Metadata *[2]uint16 `rlp:"optional"` } // ForeignProcedure is a foreign procedure call in a database diff --git a/parse/antlr.go b/parse/antlr.go index e6afe41a1..ac923ecfd 100644 --- a/parse/antlr.go +++ b/parse/antlr.go @@ -271,7 +271,8 @@ func (s *schemaVisitor) VisitType(ctx *gen.TypeContext) any { return types.UnknownType } - dt.Metadata = [2]uint16{uint16(prec), uint16(scale)} + met := [2]uint16{uint16(prec), uint16(scale)} + dt.Metadata = &met } if ctx.LBRACKET() != nil { diff --git a/parse/ast.go b/parse/ast.go index 5a3078398..840dbc490 100644 --- a/parse/ast.go +++ b/parse/ast.go @@ -100,6 +100,8 @@ func literalToString(value any) (string, error) { return str.String(), nil } +// TODO: we can remove this interface since in v0.10, we are getting rid of foreign calls +// this means that there will be only 1 implementation of ExpressionCall type ExpressionCall interface { Expression Cast(*types.DataType) diff --git a/parse/common/datetime.go b/parse/common/datetime.go deleted file mode 100644 index 69aefbfac..000000000 --- a/parse/common/datetime.go +++ /dev/null @@ -1,85 +0,0 @@ 
-package common - -import ( - "fmt" - "strings" - "time" -) - -/* - Kwil's formatting specifiers: - - YYYY: 4-digit year - - YY: 2-digit year - - MM: 2-digit month - - DD: 2-digit day - - HH: 2-digit hour (24-hour clock) - - HH12: 2-digit hour (12-hour clock) - - MI: 2-digit minute - - SS: 2-digit second - - MS: 3-digit millisecond - - US: 6-digit microsecond - - A.M.: AM/PM indicator (upper case) - - a.m.: AM/PM indicator (lower case) - - P.M.: AM/PM indicator (upper case) - - p.m.: AM/PM indicator (lower case) -*/ - -var strftimeReplacer = strings.NewReplacer( - "YYYY", "2006", - "YY", "06", - "MM", "01", - "DD", "02", - "HH12", "03", - "HH", "15", - "MI", "04", - "SS", "05", - "MS", "000", - "US", "000000", - "A.M.", "PM", - "a.m.", "pm", - "P.M.", "PM", - "p.m.", "pm", -) - -// parseTimestamp converts a timestamp to a microsecond Unix timestamp. -func parseTimestamp(format, value string) (int64, error) { - layout := strftimeReplacer.Replace(format) - - isPM := false - if strings.Contains(value, "P.M.") { - value = strings.ReplaceAll(value, "P.M.", "PM") - isPM = true - } - if strings.Contains(value, "p.m.") { - value = strings.ReplaceAll(value, "p.m.", "pm") - isPM = true - } - if strings.Contains(value, "A.M.") { - value = strings.ReplaceAll(value, "A.M.", "AM") - } - if strings.Contains(value, "a.m.") { - value = strings.ReplaceAll(value, "a.m.", "am") - } - - if !isPM && (strings.Contains(value, "PM") || strings.Contains(value, "pm")) { - isPM = true - } - - t, err := time.Parse(layout, value) - if err != nil { - return -1, fmt.Errorf("failed to parse timestamp: %w", err) - } - - if isPM && t.Hour() < 12 { - t = t.Add(time.Hour * 12) - } - - return t.UnixMicro(), nil -} - -// formatUnixMicro converts a Unix timestamp in microseconds to a formatted string -func formatUnixMicro(unixMicro int64, format string) string { - t := time.UnixMicro(unixMicro) - layout := strftimeReplacer.Replace(format) - return t.UTC().Format(layout) -} diff --git 
a/parse/common/datetime_test.go b/parse/common/datetime_test.go deleted file mode 100644 index 5b4d463f7..000000000 --- a/parse/common/datetime_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package common - -import ( - "testing" - "time" -) - -func TestParseTimestamp(t *testing.T) { - tests := []struct { - name string - format string - value string - expected int64 - }{ - { - name: "Basic date", - format: "YYYY-MM-DD", - value: "2023-05-15", - expected: time.Date(2023, 5, 15, 0, 0, 0, 0, time.UTC).UnixMicro(), - }, - { - name: "Date and time", - format: "YYYY-MM-DD HH:MI:SS", - value: "2023-05-15 14:30:45", - expected: time.Date(2023, 5, 15, 14, 30, 45, 0, time.UTC).UnixMicro(), - }, - { - name: "Date and time with microseconds", - format: "YYYY-MM-DD HH:MI:SS.US", - value: "2023-05-15 14:30:45.123456", - expected: time.Date(2023, 5, 15, 14, 30, 45, 123456000, time.UTC).UnixMicro(), - }, - { - name: "12-hour clock PM", - format: "YYYY-MM-DD HH12:MI:SS A.M.", - value: "2023-05-15 02:30:45 P.M.", - expected: time.Date(2023, 5, 15, 14, 30, 45, 0, time.UTC).UnixMicro(), - }, - { - name: "12-hour clock AM", - format: "YYYY-MM-DD HH12:MI:SS a.m.", - value: "2023-05-15 10:30:45 a.m.", - expected: time.Date(2023, 5, 15, 10, 30, 45, 0, time.UTC).UnixMicro(), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := parseTimestamp(tt.format, tt.value) - if err != nil { - t.Errorf("parseTimestamp() error = %v", err) - return - } - if result != tt.expected { - t.Errorf("parseTimestamp() = %v, want %v", result, tt.expected) - } - }) - } -} - -func TestFormatUnixMicro(t *testing.T) { - tests := []struct { - name string - unixMicro int64 - format string - expected string - }{ - { - name: "Basic date", - unixMicro: time.Date(2023, 5, 15, 0, 0, 0, 0, time.UTC).UnixMicro(), - format: "YYYY-MM-DD", - expected: "2023-05-15", - }, - { - name: "Date and time", - unixMicro: time.Date(2023, 5, 15, 14, 30, 45, 0, time.UTC).UnixMicro(), - format: "YYYY-MM-DD 
HH:MI:SS", - expected: "2023-05-15 14:30:45", - }, - { - name: "Date and time with microseconds", - unixMicro: time.Date(2023, 5, 15, 14, 30, 45, 123456000, time.UTC).UnixMicro(), - format: "YYYY-MM-DD HH:MI:SS.US", - expected: "2023-05-15 14:30:45.123456", - }, - { - name: "12-hour clock PM", - unixMicro: time.Date(2023, 5, 15, 14, 30, 45, 0, time.UTC).UnixMicro(), - format: "YYYY-MM-DD HH12:MI:SS P.M.", - expected: "2023-05-15 02:30:45 PM", - }, - { - name: "12-hour clock AM", - unixMicro: time.Date(2023, 5, 15, 10, 30, 45, 0, time.UTC).UnixMicro(), - format: "YYYY-MM-DD HH12:MI:SS a.m.", - expected: "2023-05-15 10:30:45 am", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := formatUnixMicro(tt.unixMicro, tt.format) - if result != tt.expected { - t.Errorf("formatUnixMicro() = %v, want %v", result, tt.expected) - } - }) - } -} diff --git a/parse/common/errors.go b/parse/common/errors.go index e97a63a7e..930047fd9 100644 --- a/parse/common/errors.go +++ b/parse/common/errors.go @@ -3,9 +3,6 @@ package common import "errors" var ( - ErrTypeMismatch = errors.New("type mismatch") - ErrNotArray = errors.New("not an array") - ErrArithmeticOnArray = errors.New("cannot perform arithmetic operation on array") - ErrIndexOutOfBounds = errors.New("index out of bounds") - ErrNegativeSubstringLength = errors.New("negative substring length not allowed") + ErrTypeMismatch = errors.New("type mismatch") + ErrIndexOutOfBounds = errors.New("index out of bounds") ) diff --git a/parse/common/functions.go b/parse/common/functions.go index 98a9b15d6..5c23f7a28 100644 --- a/parse/common/functions.go +++ b/parse/common/functions.go @@ -1,18 +1,10 @@ package common import ( - "crypto" - "crypto/md5" - "crypto/sha1" - "encoding/base64" - "encoding/hex" "fmt" "strings" - "unicode/utf8" "github.com/kwilteam/kwil-db/core/types" - "github.com/kwilteam/kwil-db/core/types/decimal" - "github.com/kwilteam/kwil-db/core/utils" ) var ( @@ -30,27 +22,6 @@ var ( 
return args[0], nil }, PGFormatFunc: defaultFormat("abs"), - EvaluateFunc: func(_ Interpreter, args []Value) (Value, error) { - switch arg := args[0].(type) { - case *IntValue: - if arg.Val < 0 { - return &IntValue{Val: -arg.Val}, nil - } - return arg, nil - case *DecimalValue: - if arg.Dec.Sign() < 0 { - arg2 := arg.Dec.Copy() - err := arg2.Neg() - if err != nil { - return nil, err - } - return &DecimalValue{Dec: arg2}, nil - } - return arg, nil - } - - return nil, fmt.Errorf("unexpected type %T in abs", args[0]) - }, }, "error": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -67,18 +38,6 @@ var ( return types.NullType, nil }, PGFormatFunc: defaultFormat("error"), - EvaluateFunc: func(_ Interpreter, args []Value) (Value, error) { - if len(args) != 1 { - return nil, fmt.Errorf("error function expects 1 argument, got %d", len(args)) - } - - text, ok := args[0].(*TextValue) - if !ok { - return nil, fmt.Errorf("error function expects a text argument, got %T", args[0]) - } - - return nil, fmt.Errorf("%s", text.Val) - }, }, "parse_unix_timestamp": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -98,31 +57,6 @@ var ( return decimal16_6, nil }, PGFormatFunc: defaultFormat("parse_unix_timestamp"), - EvaluateFunc: func(_ Interpreter, args []Value) (Value, error) { - // Kwil's parseTimestamp takes a timestamp and a format string - // The first arg is the timestamp, the second arg is the format string - res, err := parseTimestamp(args[1].Value().(string), args[0].Value().(string)) - if err != nil { - return nil, err - } - - // we now need to convert the unix timestamp to a decimal(16, 6) - // We start with 22,6 since the current int64 is in microseconds (16 digits). 
- // We make this a decimal(22, 6), and then divide by 10^6 to get a decimal(16, 6) - dec16, err := decimal.NewExplicit(fmt.Sprintf("%d", res), 22, 6) - if err != nil { - return nil, err - } - - dec16, err = dec16.Div(dec16, dec10ToThe6th) - if err != nil { - return nil, err - } - - err = dec16.SetPrecisionAndScale(16, 6) - - return &DecimalValue{Dec: dec16}, err - }, }, "format_unix_timestamp": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -142,29 +76,6 @@ var ( return types.TextType, nil }, PGFormatFunc: defaultFormat("format_unix_timestamp"), - EvaluateFunc: func(spender Interpreter, args []Value) (Value, error) { - // the inverse of parse_unix_timestamp, we need to convert a decimal(16, 6) to a unix timestamp - // by multiplying by 10^6 and converting to an int64 - dec := args[0].(*DecimalValue).Dec - - err := dec.SetPrecisionAndScale(22, 6) - if err != nil { - return nil, err - } - - dec, err = dec.Mul(dec, dec10ToThe6th) - if err != nil { - return nil, err - } - - i64Microseconds, err := dec.Int64() - if err != nil { - return nil, err - } - - ts := formatUnixMicro(i64Microseconds, args[1].Value().(string)) - return &TextValue{Val: ts}, nil - }, }, "notice": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -186,12 +97,6 @@ var ( // v0.9 changes, so leaving it here for now. 
return fmt.Sprintf("notice('txid:' || current_setting('ctx.txid') || ' ' || %s)", inputs[0]), nil }, - EvaluateFunc: func(i Interpreter, args []Value) (Value, error) { - i.Notice(args[0].Value().(string)) - return &NullValue{ - DataType: types.NullType.Copy(), - }, nil - }, }, "uuid_generate_v5": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -211,11 +116,6 @@ var ( return types.UUIDType, nil }, PGFormatFunc: defaultFormat("uuid_generate_v5"), - EvaluateFunc: func(interp Interpreter, args []Value) (Value, error) { - // uuidv5 uses sha1 to hash the text input - u := types.NewUUIDV5WithNamespace(types.UUID(args[0].(*UUIDValue).Val), []byte(args[1].(*TextValue).Val)) - return &UUIDValue{Val: u}, nil - }, }, "encode": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -235,20 +135,6 @@ var ( return types.TextType, nil }, PGFormatFunc: defaultFormat("encode"), - EvaluateFunc: func(interp Interpreter, args []Value) (Value, error) { - // postgres supports hex, base64, and escape. - // we won't support escape. - switch args[1].(*TextValue).Val { - case "hex": - return &TextValue{Val: hex.EncodeToString(args[0].Value().([]byte))}, nil - case "base64": - return &TextValue{Val: base64.StdEncoding.EncodeToString(args[0].Value().([]byte))}, nil - case "escape": - return nil, fmt.Errorf("procedures do not support escape encoding") - default: - return nil, fmt.Errorf("unknown encoding: %s", args[1].Value()) - } - }, }, "decode": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -268,28 +154,6 @@ var ( return types.BlobType, nil }, PGFormatFunc: defaultFormat("decode"), - EvaluateFunc: func(interp Interpreter, args []Value) (Value, error) { - // postgres supports hex and base64. - // we won't support escape. 
- switch args[1].(*TextValue).Val { - case "hex": - b, err := hex.DecodeString(args[0].Value().(string)) - if err != nil { - return nil, err - } - return &BlobValue{Val: b}, nil - case "base64": - b, err := base64.StdEncoding.DecodeString(args[0].Value().(string)) - if err != nil { - return nil, err - } - return &BlobValue{Val: b}, nil - case "escape": - return nil, fmt.Errorf("procedures do not support escape encoding") - default: - return nil, fmt.Errorf("unknown encoding: %s", args[1].Value()) - } - }, }, "digest": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -309,25 +173,6 @@ var ( return types.BlobType, nil }, PGFormatFunc: defaultFormat("digest"), - EvaluateFunc: func(interp Interpreter, args []Value) (Value, error) { - // supports md5, sha1, sha224, sha256, sha384 and sha512 - switch args[1].(*TextValue).Val { - case "md5": - return &BlobValue{Val: md5.New().Sum([]byte(args[0].Value().(string)))}, nil - case "sha1": - return &BlobValue{Val: sha1.New().Sum([]byte(args[0].Value().(string)))}, nil - case "sha224": - return &BlobValue{Val: crypto.SHA224.New().Sum([]byte(args[0].Value().(string)))}, nil - case "sha256": - return &BlobValue{Val: crypto.SHA256.New().Sum([]byte(args[0].Value().(string)))}, nil - case "sha384": - return &BlobValue{Val: crypto.SHA384.New().Sum([]byte(args[0].Value().(string)))}, nil - case "sha512": - return &BlobValue{Val: crypto.SHA512.New().Sum([]byte(args[0].Value().(string)))}, nil - default: - return nil, fmt.Errorf("unknown digest: %s", args[1].Value()) - } - }, }, "generate_dbid": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -349,9 +194,6 @@ var ( PGFormatFunc: func(inputs []string) (string, error) { return fmt.Sprintf(`(select 'x' || encode(sha224(lower(%s)::bytea || %s), 'hex'))`, inputs[0], inputs[1]), nil }, - EvaluateFunc: func(interp Interpreter, args []Value) (Value, error) { - return &TextValue{Val: 
utils.GenerateDBID(args[0].Value().(string), args[1].Value().([]byte))}, nil - }, }, // array functions "array_append": &ScalarFunctionDefinition{ @@ -375,12 +217,6 @@ var ( return args[0], nil }, PGFormatFunc: defaultFormat("array_append"), - EvaluateFunc: func(interp Interpreter, args []Value) (Value, error) { - arr := args[0].(ArrayValue) - // all Kuneiform arrays are 1-indexed - err := arr.Set(int64(arr.Len()+1), args[1].(ScalarValue)) - return arr, err - }, }, "array_prepend": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -403,22 +239,6 @@ var ( return args[1], nil }, PGFormatFunc: defaultFormat("array_prepend"), - EvaluateFunc: func(interp Interpreter, args []Value) (Value, error) { - scal := args[0].(ScalarValue) - arr := args[1].(ArrayValue) - - var scalars []ScalarValue - // 1-indexed - for i := 1; i <= arr.Len(); i++ { - newScal, err := arr.Index(int64(i)) - if err != nil { - return nil, err - } - scalars = append(scalars, newScal) - } - - return scal.Array(scalars...) 
- }, }, "array_cat": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -441,24 +261,6 @@ var ( return args[0], nil }, PGFormatFunc: defaultFormat("array_cat"), - EvaluateFunc: func(interp Interpreter, args []Value) (Value, error) { - arr1 := args[0].(ArrayValue) - arr2 := args[1].(ArrayValue) - - startIdx := arr1.Len() - for i := 1; i <= arr2.Len(); i++ { - newScal, err := arr2.Index(int64(i)) - if err != nil { - return nil, err - } - err = arr1.Set(int64(startIdx+i), newScal) - if err != nil { - return nil, err - } - } - - return arr1, nil - }, }, "array_length": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -475,10 +277,6 @@ var ( PGFormatFunc: func(inputs []string) (string, error) { return fmt.Sprintf("array_length(%s, 1)", inputs[0]), nil }, - EvaluateFunc: func(interp Interpreter, args []Value) (Value, error) { - arr := args[0].(ArrayValue) - return &IntValue{Val: int64(arr.Len())}, nil - }, }, // string functions // the main SQL string functions defined here: https://www.postgresql.org/docs/16.1/functions-string.html @@ -495,10 +293,6 @@ var ( return types.IntType, nil }, PGFormatFunc: defaultFormat("bit_length"), - EvaluateFunc: func(_ Interpreter, args []Value) (Value, error) { - text := args[0].(*TextValue).Val - return &IntValue{Val: int64(len(text) * 8)}, nil - }, }, "char_length": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -513,10 +307,6 @@ var ( return types.IntType, nil }, PGFormatFunc: defaultFormat("char_length"), - EvaluateFunc: func(_ Interpreter, args []Value) (Value, error) { - text := args[0].(*TextValue).Val - return &IntValue{Val: int64(utf8.RuneCountInString(text))}, nil - }, }, "character_length": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -531,10 +321,6 @@ var ( return types.IntType, nil }, PGFormatFunc: 
defaultFormat("character_length"), - EvaluateFunc: func(_ Interpreter, args []Value) (Value, error) { - text := args[0].(*TextValue).Val - return &IntValue{Val: int64(utf8.RuneCountInString(text))}, nil - }, }, "length": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -549,10 +335,6 @@ var ( return types.IntType, nil }, PGFormatFunc: defaultFormat("length"), - EvaluateFunc: func(_ Interpreter, args []Value) (Value, error) { - text := args[0].(*TextValue).Val - return &IntValue{Val: int64(utf8.RuneCountInString(text))}, nil - }, }, "lower": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -567,10 +349,6 @@ var ( return types.TextType, nil }, PGFormatFunc: defaultFormat("lower"), - EvaluateFunc: func(_ Interpreter, args []Value) (Value, error) { - text := args[0].(*TextValue).Val - return &TextValue{Val: strings.ToLower(text)}, nil - }, }, "lpad": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -594,16 +372,6 @@ var ( return types.TextType, nil }, PGFormatFunc: defaultFormat("lpad"), - EvaluateFunc: func(_ Interpreter, args []Value) (Value, error) { - text := args[0].(*TextValue).Val - length := args[1].(*IntValue).Val - padStr := " " - if len(args) == 3 { - padStr = args[2].(*TextValue).Val - } - - return &TextValue{Val: pad(text, int(length), padStr, true)}, nil - }, }, "ltrim": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -621,14 +389,6 @@ var ( return types.TextType, nil }, PGFormatFunc: defaultFormat("ltrim"), - EvaluateFunc: func(_ Interpreter, args []Value) (Value, error) { - text := args[0].(*TextValue).Val - chars := " " - if len(args) == 2 { - chars = args[1].(*TextValue).Val - } - return &TextValue{Val: strings.TrimLeft(text, chars)}, nil - }, }, "octet_length": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) 
(*types.DataType, error) { @@ -643,10 +403,6 @@ var ( return types.IntType, nil }, PGFormatFunc: defaultFormat("octet_length"), - EvaluateFunc: func(_ Interpreter, args []Value) (Value, error) { - text := args[0].(*TextValue).Val - return &IntValue{Val: int64(len(text))}, nil - }, }, "overlay": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -689,22 +445,6 @@ var ( return str.String(), nil }, - EvaluateFunc: func(_ Interpreter, args []Value) (Value, error) { - input := args[0].(*TextValue).Val - replace := args[1].(*TextValue).Val - start := args[2].(*IntValue).Val - - if start < 0 { - return nil, ErrNegativeSubstringLength - } - - length := int64(len(replace)) - if len(args) == 4 { - length = args[3].(*IntValue).Val - } - - return &TextValue{Val: overlay(input, replace, int(start), int(length))}, nil - }, }, "position": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -724,21 +464,6 @@ var ( PGFormatFunc: func(inputs []string) (string, error) { return fmt.Sprintf("position(%s in %s)", inputs[0], inputs[1]), nil }, - EvaluateFunc: func(interp Interpreter, args []Value) (Value, error) { - substr := args[0].(*TextValue).Val - str := args[1].(*TextValue).Val - - pos := strings.Index(str, substr) - - var res int64 - if pos == -1 { - res = 0 - } else { - res = int64(utf8.RuneCountInString(str[:pos])) + 1 - } - - return &IntValue{Val: res}, nil - }, }, "rpad": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -762,16 +487,6 @@ var ( return types.TextType, nil }, PGFormatFunc: defaultFormat("rpad"), - EvaluateFunc: func(_ Interpreter, args []Value) (Value, error) { - text := args[0].(*TextValue).Val - length := args[1].(*IntValue).Val - padStr := " " - if len(args) == 3 { - padStr = args[2].(*TextValue).Val - } - - return &TextValue{Val: pad(text, int(length), padStr, false)}, nil - }, }, "rtrim": 
&ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -789,14 +504,6 @@ var ( return types.TextType, nil }, PGFormatFunc: defaultFormat("rtrim"), - EvaluateFunc: func(_ Interpreter, args []Value) (Value, error) { - text := args[0].(*TextValue).Val - chars := " " - if len(args) == 2 { - chars = args[1].(*TextValue).Val - } - return &TextValue{Val: strings.TrimRight(text, chars)}, nil - }, }, "substring": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -837,46 +544,6 @@ var ( return str.String(), nil }, - EvaluateFunc: func(_ Interpreter, args []Value) (v Value, err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("panic: %v", r) - } - }() - - text := args[0].(*TextValue).Val - start := args[1].(*IntValue).Val - - if start > int64(len(text)) { - // not sure why Postgres does this, but it does. - return &TextValue{Val: ""}, nil - } - - length := int64(len(text)) - - if len(args) == 3 { - length = args[2].(*IntValue).Val - } - - if length < 0 { - return nil, ErrNegativeSubstringLength - } - - runes := []rune(text) - if start < 1 { - // if start is negative, then we subtract the difference from 1 - // from the length. I don't know why Postgres does this, but it does. - length -= 1 - start - start = 1 - } - if length < 0 { - // if length is negative, then we set it to 0. - // Not sure why Postgres does this, but it does. 
- length = 0 - } - end := min(int64(len(runes)), start-1+length) - return &TextValue{Val: string(runes[start-1 : end])}, nil - }, }, "trim": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -894,14 +561,6 @@ var ( return types.TextType, nil }, PGFormatFunc: defaultFormat("trim"), - EvaluateFunc: func(_ Interpreter, args []Value) (Value, error) { - text := args[0].(*TextValue).Val - chars := " " - if len(args) == 2 { - chars = args[1].(*TextValue).Val - } - return &TextValue{Val: strings.Trim(text, chars)}, nil - }, }, "upper": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -916,10 +575,6 @@ var ( return types.TextType, nil }, PGFormatFunc: defaultFormat("upper"), - EvaluateFunc: func(_ Interpreter, args []Value) (Value, error) { - text := args[0].(*TextValue).Val - return &TextValue{Val: strings.ToUpper(text)}, nil - }, }, "format": &ScalarFunctionDefinition{ ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) { @@ -934,16 +589,6 @@ var ( return types.TextType, nil }, PGFormatFunc: defaultFormat("format"), - EvaluateFunc: func(_ Interpreter, args []Value) (Value, error) { - format := args[0].(*TextValue).Val - - values := []any{} - for _, arg := range args[1:] { - values = append(values, arg.Value()) - } - - return &TextValue{Val: positionalSprintf(format, values...)}, nil - }, }, // Aggregate functions "count": &AggregateFunctionDefinition{ @@ -1054,68 +699,6 @@ var ( } ) -// pad pads either side of a string. 
The side can be specified with the side parameter (left is true, right is false) -func pad(input string, length int, padStr string, side bool) string { - inputLength := len(input) - if inputLength >= length { - return input[:length] // Truncate if the input string is longer than the desired length - } - - padLength := len(padStr) - if padLength == 0 { - return input // If padStr is empty, return the input as is - } - - // Calculate the number of times the padStr needs to be repeated - repeatCount := (length - inputLength) / padLength - remainder := (length - inputLength) % padLength - - // Build the left padding - p := strings.Repeat(padStr, repeatCount) + padStr[:remainder] - - if side { - return p + input - } - return input + p -} - -// overlay function mimics the behavior of the PostgreSQL overlay function -func overlay(input, replace string, start, forInt int) string { - if start < 1 { - start = 1 - } - - // Convert start and length to rune-based indices - startIndex := start - 1 - endIndex := startIndex + forInt - - // Get the slice indices in bytes - inputRunes := []rune(input) - replaceRunes := []rune(replace) - - // Adjust indices if they go beyond the string length - if startIndex > len(inputRunes) { - startIndex = len(inputRunes) - } - if endIndex > len(inputRunes) { - endIndex = len(inputRunes) - } - - // Replace the specified section of the string with the replacement string - resultRunes := append(inputRunes[:startIndex], append(replaceRunes, inputRunes[endIndex:]...)...) - return string(resultRunes) -} - -// positionalSprintf is a version of fmt.Sprintf that supports positional arguments. -// It mimics Postgres's "format" -func positionalSprintf(format string, args ...interface{}) string { - for i, arg := range args { - placeholder := fmt.Sprintf("%%%d$s", i+1) - format = strings.ReplaceAll(format, placeholder, fmt.Sprintf("%v", arg)) - } - return fmt.Sprintf(format, args...) 
-} - // defaultFormat is the default PGFormat function for functions that do not have a custom one. func defaultFormat(name string) func(inputs []string) (string, error) { return func(inputs []string) (string, error) { @@ -1130,8 +713,6 @@ var ( // it is used to represent UNIX timestamps, allowing microsecond precision. // see internal/sql/pg/sql.go/sqlCreateParseUnixTimestampFunc for more info decimal16_6 *types.DataType - // dec10ToThe6th is 10^6 - dec10ToThe6th *decimal.Decimal ) func init() { @@ -1145,11 +726,6 @@ func init() { if err != nil { panic(fmt.Sprintf("failed to create decimal type: 16, 6: %v", err)) } - - dec10ToThe6th, err = decimal.NewFromString("1000000") - if err != nil { - panic(fmt.Sprintf("failed to create decimal type: 10^6: %v", err)) - } } // FunctionDefinition if a definition of a function. @@ -1166,7 +742,6 @@ type FunctionDefinition interface { type ScalarFunctionDefinition struct { ValidateArgsFunc func(args []*types.DataType) (*types.DataType, error) PGFormatFunc func(inputs []string) (string, error) - EvaluateFunc func(interp Interpreter, args []Value) (Value, error) } func (s *ScalarFunctionDefinition) ValidateArgs(args []*types.DataType) (*types.DataType, error) { diff --git a/parse/common/functions_test.go b/parse/common/functions_test.go index e38e0beaf..5015e0ec4 100644 --- a/parse/common/functions_test.go +++ b/parse/common/functions_test.go @@ -3,10 +3,7 @@ package common_test import ( "testing" - "github.com/kwilteam/kwil-db/core/types" - "github.com/kwilteam/kwil-db/core/types/decimal" "github.com/kwilteam/kwil-db/parse/common" - "github.com/stretchr/testify/require" ) // tests that we have implemented all functions @@ -14,9 +11,6 @@ func Test_AllFunctionsImplemented(t *testing.T) { for name, fn := range common.Functions { scalar, ok := fn.(*common.ScalarFunctionDefinition) if ok { - if scalar.EvaluateFunc == nil { - t.Errorf("function %s has no EvaluateFunc", name) - } if scalar.PGFormatFunc == nil { t.Errorf("function %s 
has no PGFormatFunc", name) } @@ -37,443 +31,3 @@ func Test_AllFunctionsImplemented(t *testing.T) { } } } - -func Test_ScalarFunctions(t *testing.T) { - type testcase struct { - name string - functionName string - input []any // will be converted to []Value - expected any // result of Value.Value() - err error - } - - tests := []testcase{ - { - name: "parse_unix_timestamp", - functionName: "parse_unix_timestamp", - input: []any{"2023-05-15 14:30:45", "YYYY-MM-DD HH:MI:SS"}, - expected: mustDecimal("1684161045.000000"), - }, - { - name: "format_unix_timestamp", - functionName: "format_unix_timestamp", - input: []any{mustDecimal("1684161045.000000"), "YYYY-MM-DD HH:MI:SS"}, - expected: "2023-05-15 14:30:45", - }, - { - // checking that this matches Postgres's. - // select uuid_generate_v5('9ed26752-08bc-44c5-83ef-3c6df734c4e7'::uuid, 'a'); - // yields 32d43be3-8591-5849-946b-6f5268aef4ae - name: "uuid_v5", - functionName: "uuid_generate_v5", - input: []any{mustUUID("9ed26752-08bc-44c5-83ef-3c6df734c4e7"), "a"}, - expected: mustUUID("32d43be3-8591-5849-946b-6f5268aef4ae"), - }, - { - name: "array_append", - functionName: "array_append", - input: []any{[]int64{1, 2, 3}, int64(4)}, - expected: []*int64{intRef(1), intRef(2), intRef(3), intRef(4)}, - }, - { - name: "array_prepend", - functionName: "array_prepend", - input: []any{int64(1), []int64{2, 3, 4}}, - expected: []*int64{intRef(1), intRef(2), intRef(3), intRef(4)}, - }, - { - name: "array_cat", - functionName: "array_cat", - input: []any{[]int64{1, 2, 3}, []int64{4, 5, 6}}, - expected: []*int64{intRef(1), intRef(2), intRef(3), intRef(4), intRef(5), intRef(6)}, - }, - { - name: "array_length", - functionName: "array_length", - input: []any{[]int64{1, 2, 3}}, - expected: int64(3), - }, - { - name: "bit_length", - functionName: "bit_length", - input: []any{"hello"}, - expected: int64(40), // 5 characters * 8 bits - }, - { - name: "char_length", - functionName: "char_length", - input: []any{"hello世界"}, - expected: 
int64(7), - }, - { - name: "character_length", - functionName: "character_length", - input: []any{"hello世界"}, - expected: int64(7), - }, - { - name: "length", - functionName: "length", - input: []any{"hello世界"}, - expected: int64(7), - }, - { - name: "lower", - functionName: "lower", - input: []any{"HeLLo"}, - expected: "hello", - }, - { - name: "lpad", - functionName: "lpad", - input: []any{"hi", int64(5), "xy"}, - expected: "xyxhi", - }, - { - name: "lpad_with_space", - functionName: "lpad", - input: []any{"hi", int64(4)}, - expected: " hi", - }, - { - name: "lpad_longer_input", - functionName: "lpad", - input: []any{"hello", int64(4), "xy"}, - expected: "hell", - }, - { - name: "lpad_empty_pad_string", - functionName: "lpad", - input: []any{"hi", int64(5), ""}, - expected: "hi", - }, - { - name: "lpad_single_char_pad", - functionName: "lpad", - input: []any{"hi", int64(5), "x"}, - expected: "xxxhi", - }, - { - name: "ltrim", - functionName: "ltrim", - input: []any{" hello "}, - expected: "hello ", - }, - { - name: "ltrim_with_chars", - functionName: "ltrim", - input: []any{"xxhelloxx", "x"}, - expected: "helloxx", - }, - { - name: "octet_length", - functionName: "octet_length", - input: []any{"hello世界"}, - expected: int64(11), - }, - { - name: "rpad", - functionName: "rpad", - input: []any{"hi", int64(5), "xy"}, - expected: "hixyx", - }, - { - name: "rpad_with_space", - functionName: "rpad", - input: []any{"hi", int64(4)}, - expected: "hi ", - }, - { - name: "rtrim", - functionName: "rtrim", - input: []any{" hello "}, - expected: " hello", - }, - { - name: "rtrim_with_chars", - functionName: "rtrim", - input: []any{"xxhelloxx", "x"}, - expected: "xxhello", - }, - { - name: "substring", - functionName: "substring", - input: []any{"hello", int64(2), int64(3)}, - expected: "ell", - }, - { - name: "substring_without_length", - functionName: "substring", - input: []any{"hello", int64(2)}, - expected: "ello", - }, - { - name: "substring_full_string", - functionName: 
"substring", - input: []any{"hello", int64(1), int64(5)}, - expected: "hello", - }, - { - name: "substring_beyond_end", - functionName: "substring", - input: []any{"hello", int64(2), int64(10)}, - expected: "ello", - }, - { - name: "substring_zero_length", - functionName: "substring", - input: []any{"hello", int64(2), int64(0)}, - expected: "", - }, - { - name: "substring_negative_start", - functionName: "substring", - input: []any{"hello", int64(-3), int64(2)}, - expected: "", - }, - { - name: "substring_negative_length", - functionName: "substring", - input: []any{"hello1", int64(2), int64(-1)}, - err: common.ErrNegativeSubstringLength, - }, - { - name: "substring_start_beyond_end", - functionName: "substring", - input: []any{"hello", int64(10), int64(2)}, - expected: "", - }, - { - name: "substring_unicode", - functionName: "substring", - input: []any{"hello世界", int64(6), int64(2)}, - expected: "世界", - }, - { - name: "substring_unicode_partial", - functionName: "substring", - input: []any{"hello世界", int64(7), int64(1)}, - expected: "界", - }, - { - name: "trim", - functionName: "trim", - input: []any{" hello "}, - expected: "hello", - }, - { - name: "trim_with_chars", - functionName: "trim", - input: []any{"xxhelloxx", "x"}, - expected: "hello", - }, - { - name: "upper", - functionName: "upper", - input: []any{"HeLLo"}, - expected: "HELLO", - }, - { - name: "format", - functionName: "format", - input: []any{"Hello %s, %1$s", "World"}, - expected: "Hello World, World", - }, - // Overlay tests - { - name: "overlay_basic", - functionName: "overlay", - input: []any{"Txxxxas", "hom", int64(2), int64(4)}, - expected: "Thomas", - }, - { - name: "overlay_without_length", - functionName: "overlay", - input: []any{"Txxxxas", "hom", int64(2)}, - expected: "Thomxas", - }, - { - name: "overlay_beyond_end", - functionName: "overlay", - input: []any{"Hello", "world", int64(6)}, - expected: "Helloworld", - }, - { - name: "overlay_at_start", - functionName: "overlay", - input: 
[]any{"ello", "H", int64(1)}, - expected: "Hllo", - }, - { - name: "overlay_empty_string", - functionName: "overlay", - input: []any{"Hello", "", int64(2), int64(2)}, - expected: "Hlo", - }, - { - name: "overlay_entire_string", - functionName: "overlay", - input: []any{"Hello", "World", int64(1), int64(5)}, - expected: "World", - }, - { - name: "overlay_zero_length", - functionName: "overlay", - input: []any{"Hello", "x", int64(3), int64(0)}, - expected: "Hexllo", - }, - { - name: "overlay_negative_start", - functionName: "overlay", - input: []any{"Hello", "x", int64(-1), int64(2)}, - err: common.ErrNegativeSubstringLength, - }, - { - name: "overlay_negative_length", - functionName: "overlay", - input: []any{"Hello", "x", int64(2), int64(-1)}, - expected: "HxHello", - }, - { - name: "overlay_long_replacement", - functionName: "overlay", - input: []any{"Hello", "Beautiful World", int64(2), int64(4)}, - expected: "HBeautiful World", - }, - { - name: "overlay_unicode_mixed", - functionName: "overlay", - input: []any{"Hello 世界!", "こんにちは", int64(7), int64(2)}, - expected: "Hello こんにちは!", - }, - // Position tests - { - name: "position_basic", - functionName: "position", - input: []any{"hi", "hello world"}, - expected: int64(0), // PostgreSQL returns 0 when substring is not found - }, - { - name: "position_found", - functionName: "position", - input: []any{"world", "hello world"}, - expected: int64(7), - }, - { - name: "position_at_start", - functionName: "position", - input: []any{"hello", "hello world"}, - expected: int64(1), - }, - { - name: "position_at_end", - functionName: "position", - input: []any{"world", "hello world"}, - expected: int64(7), - }, - { - name: "position_empty_substring", - functionName: "position", - input: []any{"", "hello world"}, - expected: int64(1), // PostgreSQL returns 1 for empty substring - }, - { - name: "position_empty_string", - functionName: "position", - input: []any{"hello", ""}, - expected: int64(0), - }, - { - name: 
"position_both_empty", - functionName: "position", - input: []any{"", ""}, - expected: int64(1), - }, - { - name: "position_case_sensitive", - functionName: "position", - input: []any{"WORLD", "hello world"}, - expected: int64(0), - }, - { - name: "position_multiple_occurrences", - functionName: "position", - input: []any{"o", "hello world"}, - expected: int64(5), // Returns the first occurrence - }, - { - name: "position_unicode", - functionName: "position", - input: []any{"世界", "你好世界"}, - expected: int64(3), - }, - { - name: "position_substring_longer", - functionName: "position", - input: []any{"hello world plus", "hello world"}, - expected: int64(0), - }, - { - name: "position_special_chars", - functionName: "position", - input: []any{"lo_", "hello_world"}, - expected: int64(4), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - vars := make([]common.Value, len(tt.input)) - - var err error - for i, in := range tt.input { - vars[i], err = common.NewVariable(in) - require.NoError(t, err) - } - - fn, ok := common.Functions[tt.functionName] - require.True(t, ok) - - scalar, ok := fn.(*common.ScalarFunctionDefinition) - require.True(t, ok) - - res, err := scalar.EvaluateFunc(&mockInterpreter{}, vars) - if tt.err != nil { - require.Error(t, err) - require.ErrorIs(t, err, tt.err) - } else { - require.NoError(t, err) - require.EqualValues(t, tt.expected, res.Value()) - } - }) - } -} - -func mustDecimal(s string) *decimal.Decimal { - d, err := decimal.NewFromString(s) - if err != nil { - panic(err) - } - return d -} - -func mustUUID(s string) *types.UUID { - u, err := types.ParseUUID(s) - if err != nil { - panic(err) - } - return u -} - -func intRef(i int64) *int64 { - return &i -} - -type mockInterpreter struct{} - -func (m *mockInterpreter) Spend(_ int64) error { - return nil -} - -func (m *mockInterpreter) Notice(_ string) { -} diff --git a/parse/interpreter/README.md b/parse/interpreter/README.md new file mode 100644 index 
000000000..057bd04bc --- /dev/null +++ b/parse/interpreter/README.md @@ -0,0 +1,12 @@ +# Kuneiform Interpreter + +The Kuneiform interpreter is meant to be a simple interpreter for performing basic arithmetic and access control logic. It is capable of: + +- if/then/else statements +- for loops +- basic arithmetic +- executing functions / other actions +- executing SQL statements + +For all function calls, it will make a call to Postgres, so that it can 100% match the functionality provided by Postgres. This is obviously very inefficient, +but it can be optimized later by mirroring Postgres functionality in Go. For now, we are prioritizing speed of development and breadth of supported functions. \ No newline at end of file diff --git a/parse/interpreter/benchmark_test.go b/parse/interpreter/benchmark_test.go index 2235b243f..5eaa172d2 100644 --- a/parse/interpreter/benchmark_test.go +++ b/parse/interpreter/benchmark_test.go @@ -2,7 +2,6 @@ package interpreter import ( "context" - "math" "testing" "github.com/kwilteam/kwil-db/core/types" @@ -64,7 +63,7 @@ func BenchmarkLoops(b *testing.B) { schema.Procedures = []*types.Procedure{tt.proc} b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := Run(ctx, tt.proc, schema, tt.args, math.MaxInt64, ZeroCostTable()) + _, err := Run(ctx, tt.proc, schema, tt.args) if err != nil { b.Fatal(err) } diff --git a/parse/interpreter/interpreter_test.go b/parse/interpreter/interpreter_test.go index 907265a79..13e592eb3 100644 --- a/parse/interpreter/interpreter_test.go +++ b/parse/interpreter/interpreter_test.go @@ -2,7 +2,6 @@ package interpreter_test import ( "context" - "math" "testing" "github.com/kwilteam/kwil-db/parse" @@ -100,8 +99,7 @@ func Test_Interpeter(t *testing.T) { t.Fatalf("procedure %s not found", test.procName) } - // start := time.Now() - res, err := interpreter.Run(ctx, proc, schema, test.inputVals, math.MaxInt64, interpreter.ZeroCostTable()) + res, err := interpreter.Run(ctx, proc, schema, test.inputVals) if test.err 
!= nil { require.Error(t, err) require.ErrorIs(t, err, test.err) @@ -109,11 +107,6 @@ func Test_Interpeter(t *testing.T) { } else { require.NoError(t, err) } - // fmt.Println(time.Since(start)) - // panic("") - - // since we are using the ZeroCostTable, the cost should be 0 - require.Equalf(t, res.Cost, int64(0), "cost is not 0") require.Equal(t, len(test.expected), len(res.Values)) for i, row := range res.Values { diff --git a/parse/interpreter/interpreter.go b/parse/interpreter/planner.go similarity index 71% rename from parse/interpreter/interpreter.go rename to parse/interpreter/planner.go index f8b20e30c..e8f4dd0e7 100644 --- a/parse/interpreter/interpreter.go +++ b/parse/interpreter/planner.go @@ -13,7 +13,7 @@ import ( "github.com/kwilteam/kwil-db/parse/common" ) -func Run(ctx context.Context, proc *types.Procedure, schema *types.Schema, args []any, maxCost int64, costTable *CostTable) (*ProcedureRunResult, error) { +func Run(ctx context.Context, proc *types.Procedure, schema *types.Schema, args []any) (*ProcedureRunResult, error) { parseResult, err := parse.ParseProcedure(proc, schema) if err != nil { return nil, err @@ -25,9 +25,7 @@ func Run(ctx context.Context, proc *types.Procedure, schema *types.Schema, args i := &interpreterPlanner{} exec := &executionContext{ - maxCost: maxCost, - scope: newScope(), - costTable: costTable, + scope: newScope(), } if len(proc.Parameters) != len(args) { @@ -35,7 +33,7 @@ func Run(ctx context.Context, proc *types.Procedure, schema *types.Schema, args } for j, arg := range args { - val, err := common.NewVariable(arg) + val, err := NewVariable(arg) if err != nil { return nil, err } @@ -60,10 +58,6 @@ func Run(ctx context.Context, proc *types.Procedure, schema *types.Schema, args res := newReturnableCursor(expectedShape) procRes := &ProcedureRunResult{} - defer func() { - procRes.Cost = exec.currentCost - }() - go func() { for _, stmt := range parseResult.AST { err := stmt.Accept(i).(stmtFunc)(ctx, exec, res) @@ -148,7 
+142,7 @@ func Run(ctx context.Context, proc *types.Procedure, schema *types.Schema, args return procRes, nil } -func makeNamedReturns(expected []*types.NamedType, record common.RecordValue) ([]*NamedValue, error) { +func makeNamedReturns(expected []*types.NamedType, record RecordValue) ([]*NamedValue, error) { if len(expected) != len(record.Fields) { return nil, fmt.Errorf("expected %d return fields, got %d", len(expected), len(record.Fields)) } @@ -176,34 +170,33 @@ func makeNamedReturns(expected []*types.NamedType, record common.RecordValue) ([ type NamedValue struct { Name string - Value common.Value + Value Value } type ProcedureRunResult struct { - Cost int64 Values [][]*NamedValue } -// procedureCallFunc is a function that is generated for a procedure call. -// It is used to call a procedure. -type procedureCallFunc func(ctx context.Context, exec *executionContext, args []common.Value) (Cursor, error) +// functionCall contains logic for either a user-defined PL/pgSQL function, a built-in function, +// or an action. 
+type functionCall func(ctx context.Context, exec *executionContext, args []Value) (Cursor, error) -func (i *interpreterPlanner) makeProcedureCallFunc(procAst []parse.ProcedureStmt, procParams []*types.NamedType, procReturns *types.ProcedureReturn) procedureCallFunc { - stmtFns := make([]stmtFunc, len(procAst)) - for j, stmt := range procAst { +func (i *interpreterPlanner) makeActionCallFunc(ast []parse.ProcedureStmt, params []*types.NamedType, returns *types.ProcedureReturn) functionCall { + stmtFns := make([]stmtFunc, len(ast)) + for j, stmt := range ast { stmtFns[j] = stmt.Accept(i).(stmtFunc) } var expectedShape []*types.DataType - if procReturns != nil { - for _, f := range procReturns.Fields { + if returns != nil { + for _, f := range returns.Fields { expectedShape = append(expectedShape, f.Type) } } - return func(ctx context.Context, exec *executionContext, args []common.Value) (Cursor, error) { - if len(procParams) != len(args) { - return nil, fmt.Errorf("expected %d arguments, got %d", len(procParams), len(args)) + return func(ctx context.Context, exec *executionContext, args []Value) (Cursor, error) { + if len(params) != len(args) { + return nil, fmt.Errorf("expected %d arguments, got %d", len(params), len(args)) } ret := newReturnableCursor(expectedShape) @@ -214,15 +207,14 @@ func (i *interpreterPlanner) makeProcedureCallFunc(procAst []parse.ProcedureStmt }() // procedures cannot access variables from the parent scope, so we create a new scope - // TODO: handle @foreign_caller exec.scope = newScope() for j, arg := range args { - if !procParams[j].Type.EqualsStrict(arg.Type()) { - return nil, fmt.Errorf("expected argument %d to be %s, got %s", j+1, procParams[j].Type, arg.Type()) + if !params[j].Type.EqualsStrict(arg.Type()) { + return nil, fmt.Errorf("expected argument %d to be %s, got %s", j+1, params[j].Type, arg.Type()) } - err := exec.allocateVariable(procParams[j].Name, arg) + err := exec.allocateVariable(params[j].Name, arg) if err != nil { return 
nil, err } @@ -237,12 +229,18 @@ func (i *interpreterPlanner) makeProcedureCallFunc(procAst []parse.ProcedureStmt } } -// interpreterPlanner is a basic interpreterPlanner for Kuneiform procedures. -type interpreterPlanner struct { - // schema is the database schema. - schema *types.Schema - // procedures are the asts - // TODO: make a map +// interpreterPlanner creates functions for running Kuneiform logic. +type interpreterPlanner struct{} + +// FunctionSignature is the signature for either a user-defined PL/pgSQL function, a built-in function, +// or an action. +type FunctionSignature struct { + // Name is the name of the function. + Name string + // Parameters are the parameters of the function. + Parameters []*types.NamedType + // Returns are the return values of the function. + Returns *types.ProcedureReturn } var ( @@ -257,7 +255,7 @@ type stmtFunc func(ctx context.Context, exec *executionContext, ret returnChans) func (i *interpreterPlanner) VisitProcedureStmtDeclaration(p0 *parse.ProcedureStmtDeclaration) any { return stmtFunc(func(ctx context.Context, exec *executionContext, ret returnChans) error { - return exec.allocateVariable(p0.Variable.Name, common.NewNullValue(p0.Type)) + return exec.allocateVariable(p0.Variable.Name, NewNullValue(p0.Type)) }) } @@ -280,7 +278,7 @@ func (i *interpreterPlanner) VisitProcedureStmtAssignment(p0 *parse.ProcedureStm case *parse.ExpressionVariable: return exec.setVariable(a.Name, val) case *parse.ExpressionArrayAccess: - scalarVal, ok := val.(common.ScalarValue) + scalarVal, ok := val.(ScalarValue) if !ok { return fmt.Errorf("expected scalar value, got %T", val) } @@ -290,16 +288,11 @@ func (i *interpreterPlanner) VisitProcedureStmtAssignment(p0 *parse.ProcedureStm return err } - arr, ok := arrVal.(common.ArrayValue) + arr, ok := arrVal.(ArrayValue) if !ok { return fmt.Errorf("expected array, got %T", arrVal) } - err = exec.Spend(exec.costTable.ArrayAccessCost + exec.costTable.SetVariableCost) - if err != nil { - return 
err - } - index, err := indexFn(ctx, exec) if err != nil { return err @@ -317,7 +310,99 @@ func (i *interpreterPlanner) VisitProcedureStmtAssignment(p0 *parse.ProcedureStm } func (i *interpreterPlanner) VisitProcedureStmtCall(p0 *parse.ProcedureStmtCall) any { - panic("TODO: Implement") + fnCall, ok := p0.Call.(*parse.ExpressionFunctionCall) + if !ok { + // this will get removed once we update the AST with v0.10 changes + panic("expected function call") + } + + // we cannot simply use the same visitor as the expression function call, because expression function + // calls always return exactly one value. Here, we can return 0 values, many values, or a table. + + receivers := make([]string, len(p0.Receivers)) + for j, r := range p0.Receivers { + receivers[j] = r.Name + } + + args := make([]exprFunc, len(fnCall.Args)) + for j, arg := range fnCall.Args { + args[j] = arg.Accept(i).(exprFunc) + } + + return stmtFunc(func(ctx context.Context, exec *executionContext, ret returnChans) error { + funcDef, ok := exec.availableFunctions[fnCall.Name] + if !ok { + return fmt.Errorf(`action "%s" no longer exists`, fnCall.Name) + } + + // verify that the args match the function signature + if len(funcDef.Signature.Parameters) != len(args) { + return fmt.Errorf("expected %d arguments, got %d", len(funcDef.Signature.Parameters), len(args)) + } + + // verify the returns. + // If the user expects values, then it must exactly match the number of returns. + // If the user does not expect values, then the function can return anything / return nothing. 
+ if len(receivers) != 0 { + if funcDef.Signature.Returns == nil { + return fmt.Errorf(`expected function "%s" to return %d values, but it does not return anything`, funcDef.Signature.Name, len(receivers)) + } + + if len(funcDef.Signature.Returns.Fields) != len(receivers) { + return fmt.Errorf(`expected function "%s" to return %d values, but it returns %d`, funcDef.Signature.Name, len(receivers), len(funcDef.Signature.Returns.Fields)) + } + + if funcDef.Signature.Returns.IsTable { + return fmt.Errorf(`expected function "%s" to return %d values, but it returns a table`, funcDef.Signature.Name, len(receivers)) + } + } + + vals := make([]Value, len(args)) + for j, valFn := range args { + val, err := valFn(ctx, exec) + if err != nil { + return err + } + + vals[j] = val + } + + cursor, err := funcDef.Func(ctx, exec, vals) + if err != nil { + return err + } + + defer cursor.Close() + + for { + rec, done, err := cursor.Next(ctx) + if err != nil { + return err + } + if done { + break + } + + if len(receivers) != 0 { + // since cursors return a map, we need to match up + // the expected return field names with the actual field names, + // and then assign the values to the receivers in the correct order. + for j, sigField := range funcDef.Signature.Returns.Fields { + val, ok := rec.Fields[sigField.Name] + if !ok { + return fmt.Errorf(`expected return value "%s" not found`, sigField.Name) + } + + err = exec.setVariable(receivers[j], val) + if err != nil { + return err + } + } + } + } + + return nil + }) } // executeBlock executes a block of statements with their own sub-scope. 
@@ -371,11 +456,6 @@ func (i *interpreterPlanner) VisitProcedureStmtForLoop(p0 *parse.ProcedureStmtFo break } - err = exec.Spend(exec.costTable.LoopCost) - if err != nil { - return err - } - err = executeBlock(ctx, exec, ret, []*NamedValue{ { Name: p0.Receiver.Name, @@ -434,14 +514,14 @@ type rangeLooper struct { current int64 } -func (r *rangeLooper) Next(ctx context.Context) (common.Value, bool, error) { +func (r *rangeLooper) Next(ctx context.Context) (Value, bool, error) { if r.current > r.end { return nil, true, nil } ret := r.current r.current++ - return &common.IntValue{ + return &IntValue{ Val: ret, }, false, nil } @@ -461,7 +541,7 @@ func (i *interpreterPlanner) VisitLoopTermVariable(p0 *parse.LoopTermVariable) a return nil, err } - arr, ok := val.(common.ArrayValue) + arr, ok := val.(ArrayValue) if !ok { return nil, fmt.Errorf("expected array, got %T", val) } @@ -475,16 +555,16 @@ func (i *interpreterPlanner) VisitLoopTermVariable(p0 *parse.LoopTermVariable) a // loopReturn is an interface for iterating over the result of a loop term. 
type loopReturn interface { - Next(ctx context.Context) (common.Value, bool, error) + Next(ctx context.Context) (Value, bool, error) Close() error } type arrayLooper struct { - arr common.ArrayValue + arr ArrayValue index int64 } -func (a *arrayLooper) Next(ctx context.Context) (common.Value, bool, error) { +func (a *arrayLooper) Next(ctx context.Context) (Value, bool, error) { ret, err := a.arr.Index(a.index) if err != nil { if err == common.ErrIndexOutOfBounds { @@ -543,11 +623,11 @@ func (i *interpreterPlanner) VisitProcedureStmtIf(p0 *parse.ProcedureStmtIf) any } switch c := cond.(type) { - case *common.BoolValue: + case *BoolValue: if !c.Val { continue } - case *common.NullValue: + case *NullValue: continue default: return fmt.Errorf("expected bool, got %s", c.Type()) @@ -578,10 +658,6 @@ func (i *interpreterPlanner) VisitProcedureStmtSQL(p0 *parse.ProcedureStmtSQL) a func (i *interpreterPlanner) VisitProcedureStmtBreak(p0 *parse.ProcedureStmtBreak) any { return stmtFunc(func(ctx context.Context, exec *executionContext, ret returnChans) error { - if err := exec.Spend(exec.costTable.BreakCost); err != nil { - return err - } - return errBreak }) } @@ -593,11 +669,7 @@ func (i *interpreterPlanner) VisitProcedureStmtReturn(p0 *parse.ProcedureStmtRet } return stmtFunc(func(ctx context.Context, exec *executionContext, ret returnChans) error { - if err := exec.Spend(exec.costTable.ReturnCost); err != nil { - return err - } - - vals := make([]common.Value, len(p0.Values)) + vals := make([]Value, len(p0.Values)) for j, valFn := range valFns { val, err := valFn(ctx, exec) if err != nil { @@ -620,11 +692,7 @@ func (i *interpreterPlanner) VisitProcedureStmtReturnNext(p0 *parse.ProcedureStm } return stmtFunc(func(ctx context.Context, exec *executionContext, ret returnChans) error { - if err := exec.Spend(exec.costTable.ReturnCost); err != nil { - return err - } - - vals := make([]common.Value, len(p0.Values)) + vals := make([]Value, len(p0.Values)) for j, valFn := range 
valFns { val, err := valFn(ctx, exec) if err != nil { @@ -645,11 +713,11 @@ func (i *interpreterPlanner) VisitProcedureStmtReturnNext(p0 *parse.ProcedureStm // everything in this section is for expressions, which evaluate to exactly one value. // exprFunc is a function that returns a value. -type exprFunc func(ctx context.Context, exec *executionContext) (common.Value, error) +type exprFunc func(ctx context.Context, exec *executionContext) (Value, error) func (i *interpreterPlanner) VisitExpressionLiteral(p0 *parse.ExpressionLiteral) any { - return exprFunc(func(ctx context.Context, exec *executionContext) (common.Value, error) { - return common.NewVariable(p0.Value) + return exprFunc(func(ctx context.Context, exec *executionContext) (Value, error) { + return NewVariable(p0.Value) }) } @@ -659,55 +727,74 @@ func (i *interpreterPlanner) VisitExpressionFunctionCall(p0 *parse.ExpressionFun args[j] = arg.Accept(i).(exprFunc) } - // can be a built-in function or a user-defined procedure - funcDef, ok := common.Functions[p0.Name] - if ok { - scalarFunc, ok := funcDef.(*common.ScalarFunctionDefinition) + return exprFunc(func(ctx context.Context, exec *executionContext) (Value, error) { + // we check again because the action might have been dropped + funcDef, ok := exec.availableFunctions[p0.Name] if !ok { - panic("cannot call non-scalar function in procedure") + return nil, fmt.Errorf(`function "%s" no longer exists`, p0.Name) } - return makeBuiltInFunctionCall(scalarFunc, args) - } - panic("TODO: Implement") - // // otherwise, it must be a user defined procedure. 
- // proc, ok := i.schema.FindProcedure(p0.Name) - // if !ok { - // panic("procedure not found") - // } - - // return exprFunc(func(ctx context.Context, exec *executionContext) (common.Value, error) { - - // }) -} + if len(funcDef.Signature.Parameters) != len(args) { + return nil, fmt.Errorf("expected %d arguments, got %d", len(funcDef.Signature.Parameters), len(args)) + } -func makeBuiltInFunctionCall(funcDef *common.ScalarFunctionDefinition, args []exprFunc) exprFunc { - return exprFunc(func(ctx context.Context, exec *executionContext) (common.Value, error) { - err := exec.Spend(exec.costTable.CallBuiltInFunctionCost) - if err != nil { - return nil, err + if funcDef.Signature.Returns == nil { + return nil, fmt.Errorf(`cannot call function "%s" in an expression because it returns nothing`, p0.Name) + } + if funcDef.Signature.Returns.IsTable { + return nil, fmt.Errorf(`cannot call function "%s" in an expression because it returns a table`, p0.Name) + } + if len(funcDef.Signature.Returns.Fields) != 1 { + return nil, fmt.Errorf(`cannot call function "%s" in an expression because it returns multiple values`, p0.Name) } - vals := make([]common.Value, len(args)) - for i, arg := range args { + vals := make([]Value, len(args)) + for j, arg := range args { val, err := arg(ctx, exec) if err != nil { return nil, err } - vals[i] = val + if !val.Type().EqualsStrict(funcDef.Signature.Parameters[j].Type) { + return nil, fmt.Errorf("expected argument %d to be %s, got %s", j+1, funcDef.Signature.Parameters[j].Type, val.Type()) + } + + vals[j] = val + } + + cursor, err := funcDef.Func(ctx, exec, vals) + if err != nil { + return nil, err + } + + defer cursor.Close() + + rec, done, err := cursor.Next(ctx) + if err != nil { + return nil, err + } + + if done { + return nil, fmt.Errorf("expected scalar value, got nothing") } - return funcDef.EvaluateFunc(exec, vals) + if len(rec.Fields) != 1 { + return nil, fmt.Errorf("expected scalar value, got record with %d fields", len(rec.Fields)) 
+ } + + return rec.Fields[rec.Order[0]], nil }) } func (i *interpreterPlanner) VisitExpressionForeignCall(p0 *parse.ExpressionForeignCall) any { - panic("TODO: Implement") + // since v0.10 is single-schema, we don't need to support foreign calls + // This should be caught at a higher level, but we panic here just in case. + // Will probably remove this in the future. + panic("foreign calls are no longer supported as of Kwil v0.10") } func (i *interpreterPlanner) VisitExpressionVariable(p0 *parse.ExpressionVariable) any { - return exprFunc(func(ctx context.Context, exec *executionContext) (common.Value, error) { + return exprFunc(func(ctx context.Context, exec *executionContext) (Value, error) { return exec.getVariable(p0.Name) }) } @@ -716,18 +803,13 @@ func (i *interpreterPlanner) VisitExpressionArrayAccess(p0 *parse.ExpressionArra arrFn := p0.Array.Accept(i).(exprFunc) indexFn := p0.Index.Accept(i).(exprFunc) - return exprFunc(func(ctx context.Context, exec *executionContext) (common.Value, error) { - err := exec.Spend(exec.costTable.ArrayAccessCost) - if err != nil { - return nil, err - } - + return exprFunc(func(ctx context.Context, exec *executionContext) (Value, error) { arrVal, err := arrFn(ctx, exec) if err != nil { return nil, err } - arr, ok := arrVal.(common.ArrayValue) + arr, ok := arrVal.(ArrayValue) if !ok { return nil, fmt.Errorf("expected array, got %T", arrVal) } @@ -751,12 +833,7 @@ func (i *interpreterPlanner) VisitExpressionMakeArray(p0 *parse.ExpressionMakeAr valFns[j] = v.Accept(i).(exprFunc) } - return exprFunc(func(ctx context.Context, exec *executionContext) (common.Value, error) { - err := exec.Spend(exec.costTable.MakeArrayCost) - if err != nil { - return nil, err - } - + return exprFunc(func(ctx context.Context, exec *executionContext) (Value, error) { if len(valFns) == 0 { return nil, fmt.Errorf("array must have at least one element") } @@ -766,12 +843,12 @@ func (i *interpreterPlanner) VisitExpressionMakeArray(p0 
*parse.ExpressionMakeAr return nil, err } - scal, ok := val0.(common.ScalarValue) + scal, ok := val0.(ScalarValue) if !ok { return nil, fmt.Errorf("expected scalar value, got %T", val0) } - var vals []common.ScalarValue + var vals []ScalarValue for j, valFn := range valFns { if j == 0 { continue @@ -782,7 +859,7 @@ func (i *interpreterPlanner) VisitExpressionMakeArray(p0 *parse.ExpressionMakeAr return nil, err } - scal, ok := val.(common.ScalarValue) + scal, ok := val.(ScalarValue) if !ok { return nil, fmt.Errorf("expected scalar value, got %T", val) } @@ -797,18 +874,13 @@ func (i *interpreterPlanner) VisitExpressionMakeArray(p0 *parse.ExpressionMakeAr func (i *interpreterPlanner) VisitExpressionFieldAccess(p0 *parse.ExpressionFieldAccess) any { recordFn := p0.Record.Accept(i).(exprFunc) - return exprFunc(func(ctx context.Context, exec *executionContext) (common.Value, error) { - err := exec.Spend(exec.costTable.GetVariableCost) - if err != nil { - return nil, err - } - + return exprFunc(func(ctx context.Context, exec *executionContext) (Value, error) { objVal, err := recordFn(ctx, exec) if err != nil { return nil, err } - obj, ok := objVal.(*common.RecordValue) + obj, ok := objVal.(*RecordValue) if !ok { return nil, fmt.Errorf("expected object, got %T", objVal) } @@ -847,11 +919,7 @@ func (i *interpreterPlanner) VisitExpressionComparison(p0 *parse.ExpressionCompa // makeComparisonFunc returns a function that compares two values. 
func makeComparisonFunc(left, right exprFunc, cmpOps common.ComparisonOp) exprFunc { - return func(ctx context.Context, exec *executionContext) (common.Value, error) { - if err := exec.Spend(exec.costTable.ComparisonCost); err != nil { - return nil, err - } - + return func(ctx context.Context, exec *executionContext) (Value, error) { leftVal, err := left(ctx, exec) if err != nil { return nil, err @@ -877,11 +945,7 @@ func (i *interpreterPlanner) VisitExpressionLogical(p0 *parse.ExpressionLogical) // makeLogicalFunc returns a function that performs a logical operation. // If and is true, it performs an AND operation, otherwise it performs an OR operation. func makeLogicalFunc(left, right exprFunc, and bool) exprFunc { - return func(ctx context.Context, exec *executionContext) (common.Value, error) { - if err := exec.Spend(exec.costTable.LogicalCost); err != nil { - return nil, err - } - + return func(ctx context.Context, exec *executionContext) (Value, error) { leftVal, err := left(ctx, exec) if err != nil { return nil, err @@ -896,37 +960,32 @@ func makeLogicalFunc(left, right exprFunc, and bool) exprFunc { return nil, fmt.Errorf("expected bools, got %s and %s", leftVal.Type(), rightVal.Type()) } - if _, ok := leftVal.(*common.NullValue); ok { + if _, ok := leftVal.(*NullValue); ok { return leftVal, nil } - if _, ok := rightVal.(*common.NullValue); ok { + if _, ok := rightVal.(*NullValue); ok { return rightVal, nil } if and { - return &common.BoolValue{ + return &BoolValue{ Val: leftVal.Value().(bool) && rightVal.Value().(bool), }, nil } - return &common.BoolValue{ + return &BoolValue{ Val: leftVal.Value().(bool) || rightVal.Value().(bool), }, nil } } func (i *interpreterPlanner) VisitExpressionArithmetic(p0 *parse.ExpressionArithmetic) any { - op := parse.ConvertArithmeticOp(p0.Operator) leftFn := p0.Left.Accept(i).(exprFunc) rightFn := p0.Right.Accept(i).(exprFunc) - return exprFunc(func(ctx context.Context, exec *executionContext) (common.Value, error) { - if 
err := exec.Spend(exec.costTable.ArithmeticCost); err != nil { - return nil, err - } - + return exprFunc(func(ctx context.Context, exec *executionContext) (Value, error) { left, err := leftFn(ctx, exec) if err != nil { return nil, err @@ -937,12 +996,12 @@ func (i *interpreterPlanner) VisitExpressionArithmetic(p0 *parse.ExpressionArith return nil, err } - leftScalar, ok := left.(common.ScalarValue) + leftScalar, ok := left.(ScalarValue) if !ok { return nil, fmt.Errorf("expected scalar, got %T", left) } - rightScalar, ok := right.(common.ScalarValue) + rightScalar, ok := right.(ScalarValue) if !ok { return nil, fmt.Errorf("expected scalar, got %T", right) } @@ -959,17 +1018,13 @@ func (i *interpreterPlanner) VisitExpressionUnary(p0 *parse.ExpressionUnary) any // makeUnaryFunc returns a function that performs a unary operation. func makeUnaryFunc(val exprFunc, op common.UnaryOp) exprFunc { - return exprFunc(func(ctx context.Context, exec *executionContext) (common.Value, error) { - if err := exec.Spend(exec.costTable.UnaryCost); err != nil { - return nil, err - } - + return exprFunc(func(ctx context.Context, exec *executionContext) (Value, error) { v, err := val(ctx, exec) if err != nil { return nil, err } - vScalar, ok := v.(common.ScalarValue) + vScalar, ok := v.(ScalarValue) if !ok { return nil, fmt.Errorf("%w: expected scalar, got %T", ErrUnaryOnNonScalar, v) } @@ -1000,107 +1055,107 @@ func (i *interpreterPlanner) VisitExpressionIs(p0 *parse.ExpressionIs) any { // since we will have separate handling for SQL statements at a later stage. 
func (i *interpreterPlanner) VisitExpressionColumn(p0 *parse.ExpressionColumn) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitExpressionCollate(p0 *parse.ExpressionCollate) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitExpressionStringComparison(p0 *parse.ExpressionStringComparison) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitExpressionIn(p0 *parse.ExpressionIn) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitExpressionBetween(p0 *parse.ExpressionBetween) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitExpressionSubquery(p0 *parse.ExpressionSubquery) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitExpressionCase(p0 *parse.ExpressionCase) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitCommonTableExpression(p0 *parse.CommonTableExpression) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitSQLStatement(p0 *parse.SQLStatement) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitSelectStatement(p0 *parse.SelectStatement) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitSelectCore(p0 *parse.SelectCore) any { - panic("TODO: Implement") + panic("interpreter planner
should not be called for SQL expressions") } func (i *interpreterPlanner) VisitResultColumnExpression(p0 *parse.ResultColumnExpression) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitResultColumnWildcard(p0 *parse.ResultColumnWildcard) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitRelationTable(p0 *parse.RelationTable) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitRelationSubquery(p0 *parse.RelationSubquery) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitRelationFunctionCall(p0 *parse.RelationFunctionCall) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitJoin(p0 *parse.Join) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitUpdateStatement(p0 *parse.UpdateStatement) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitUpdateSetClause(p0 *parse.UpdateSetClause) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitDeleteStatement(p0 *parse.DeleteStatement) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitInsertStatement(p0 *parse.InsertStatement) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitUpsertClause(p0 *parse.UpsertClause) any { - panic("TODO: Implement") +
panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitOrderingTerm(p0 *parse.OrderingTerm) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitActionStmtSQL(p0 *parse.ActionStmtSQL) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitActionStmtExtensionCall(p0 *parse.ActionStmtExtensionCall) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitActionStmtActionCall(p0 *parse.ActionStmtActionCall) any { - panic("TODO: Implement") + panic("interpreter planner should not be called for SQL expressions") } func (i *interpreterPlanner) VisitIfThen(p0 *parse.IfThen) any { diff --git a/parse/interpreter/types.go b/parse/interpreter/types.go index 8a99bdb36..380036afe 100644 --- a/parse/interpreter/types.go +++ b/parse/interpreter/types.go @@ -6,27 +6,33 @@ import ( "fmt" "github.com/kwilteam/kwil-db/core/types" - "github.com/kwilteam/kwil-db/parse/common" ) // executionContext is the context of the entire execution. type executionContext struct { - // maxCost is the maximum allowable cost of the execution. - maxCost int64 - // currentCost is the current cost of the execution. - currentCost int64 // scope is the current scope. scope *scopeContext - // costTable is the cost table for the execution. - costTable *CostTable - // procedures maps local procedure names to their procedure func. - procedures map[string]procedureCallFunc + // availableFunctions is a map of both built-in functions and user-defined PL/pgSQL functions.
+ // When the interpreter planner is created, it will be populated with all built-in functions, + // and then it will be updated with user-defined functions, effectively allowing users to override + // some function name with their own implementation. This allows Kwil to add new built-in + // functions without worrying about breaking user schemas. + // This will not include aggregate and window functions, as those can only be used in SQL. + // availableFunctions maps local action names to their execution func. + availableFunctions map[string]*executable +} + +// executable is the interface and function to call a built-in Postgres function, +// a user-defined Postgres procedure, or a user-defined Kwil action. +type executable struct { + Signature *FunctionSignature + Func functionCall } // newScope creates a new scope. func newScope() *scopeContext { return &scopeContext{ - variables: make(map[string]common.Value), + variables: make(map[string]Value), } } @@ -34,35 +40,13 @@ func newScope() *scopeContext { func (s *scopeContext) subScope() *scopeContext { return &scopeContext{ parent: s, - variables: make(map[string]common.Value), + variables: make(map[string]Value), } } -// Spend spends a certain amount of cost. -// If the cost exceeds the maximum cost, it returns an error. -func (e *executionContext) Spend(cost int64) error { - if e.currentCost+cost > e.maxCost { - - e.currentCost = e.maxCost - return fmt.Errorf("exceeded maximum cost: %d", e.maxCost) - } - e.currentCost += cost - return nil -} - -// Notice logs a notice. -func (e *executionContext) Notice(format string) { - panic("notice not implemented") -} - // setVariable sets a variable in the current scope. // It will allocate the variable if it does not exist. 
-func (e *executionContext) setVariable(name string, value common.Value) error { - err := e.Spend(e.costTable.GetVariableCost) - if err != nil { - return err - } - +func (e *executionContext) setVariable(name string, value Value) error { _, foundScope, err := getVarFromScope(name, e.scope) if err != nil { if errors.Is(err, ErrVariableNotFound) { @@ -72,22 +56,12 @@ func (e *executionContext) setVariable(name string, value common.Value) error { } } - err = e.Spend(e.costTable.SetVariableCost + e.costTable.SizeCostConstant*int64(value.Size())) - if err != nil { - return err - } - foundScope.variables[name] = value return nil } // allocateVariable allocates a variable in the current scope. -func (e *executionContext) allocateVariable(name string, value common.Value) error { - err := e.Spend(e.costTable.AllocateVariableCost + e.costTable.SizeCostConstant*int64(value.Size())) - if err != nil { - return err - } - +func (e *executionContext) allocateVariable(name string, value Value) error { _, ok := e.scope.variables[name] if ok { return fmt.Errorf(`variable "%s" already exists`, name) @@ -99,19 +73,14 @@ func (e *executionContext) allocateVariable(name string, value common.Value) err // getVariable gets a variable from the current scope. // It searches the parent scopes if the variable is not found. -func (e *executionContext) getVariable(name string) (common.Value, error) { - err := e.Spend(e.costTable.GetVariableCost) - if err != nil { - return nil, err - } - +func (e *executionContext) getVariable(name string) (Value, error) { v, _, err := getVarFromScope(name, e.scope) return v, err } // getVarFromScope recursively searches the scopes for a variable. // It returns the value, as well as the scope it was found in. 
-func getVarFromScope(variable string, scope *scopeContext) (common.Value, *scopeContext, error) { +func getVarFromScope(variable string, scope *scopeContext) (Value, *scopeContext, error) { if v, ok := scope.variables[variable]; ok { return v, scope, nil } @@ -127,7 +96,7 @@ type scopeContext struct { // if the parent is nil, this is the root parent *scopeContext // variables are the variables stored in memory. - variables map[string]common.Value + variables map[string]Value } // Cursor is the cursor for the current execution. @@ -136,7 +105,7 @@ type Cursor interface { // Next moves the cursor to the next result. // It returns the value returned, if the cursor is done, and an error. // If the cursor is done, the value returned is not valid. - Next(context.Context) (common.RecordValue, bool, error) + Next(context.Context) (RecordValue, bool, error) // Close closes the cursor. Close() error } @@ -146,19 +115,19 @@ type Cursor interface { // results. type returnableCursor struct { expectedShape []*types.DataType - recordChan chan []common.Value + recordChan chan []Value errChan chan error } func newReturnableCursor(expectedShape []*types.DataType) *returnableCursor { return &returnableCursor{ expectedShape: expectedShape, - recordChan: make(chan []common.Value), + recordChan: make(chan []Value), errChan: make(chan error), } } -func (r *returnableCursor) Record() chan<- []common.Value { +func (r *returnableCursor) Record() chan<- []Value { return r.recordChan } @@ -172,24 +141,24 @@ func (r *returnableCursor) Close() error { return nil } -func (r *returnableCursor) Next(ctx context.Context) (common.RecordValue, bool, error) { +func (r *returnableCursor) Next(ctx context.Context) (RecordValue, bool, error) { select { case rec, ok := <-r.recordChan: if !ok { - return common.RecordValue{}, true, nil + return RecordValue{}, true, nil } else { // check if the shape is correct if len(r.expectedShape) != len(rec) { - return common.RecordValue{}, false, fmt.Errorf("expected 
%d columns, got %d", len(r.expectedShape), len(rec)) + return RecordValue{}, false, fmt.Errorf("expected %d columns, got %d", len(r.expectedShape), len(rec)) } - record := common.RecordValue{ - Fields: map[string]common.Value{}, + record := RecordValue{ + Fields: map[string]Value{}, } for i, expected := range r.expectedShape { if !expected.EqualsStrict(rec[i].Type()) { - return common.RecordValue{}, false, fmt.Errorf("expected type %s, got %s", expected, rec[i].Type()) + return RecordValue{}, false, fmt.Errorf("expected type %s, got %s", expected, rec[i].Type()) } record.Fields[expected.Name] = rec[i] @@ -200,17 +169,17 @@ func (r *returnableCursor) Next(ctx context.Context) (common.RecordValue, bool, } case err := <-r.errChan: if err == errReturn { - return common.RecordValue{}, true, nil + return RecordValue{}, true, nil } - return common.RecordValue{}, false, err + return RecordValue{}, false, err case <-ctx.Done(): - return common.RecordValue{}, false, ctx.Err() + return RecordValue{}, false, ctx.Err() } } // returnChans is a helper interface for returning values to channels. type returnChans interface { - Record() chan<- []common.Value + Record() chan<- []Value Err() chan<- error } diff --git a/parse/common/values.go b/parse/interpreter/values.go similarity index 75% rename from parse/common/values.go rename to parse/interpreter/values.go index 9e241cc8b..09ab69eb8 100644 --- a/parse/common/values.go +++ b/parse/interpreter/values.go @@ -1,4 +1,4 @@ -package common +package interpreter import ( "fmt" @@ -8,6 +8,7 @@ import ( "github.com/kwilteam/kwil-db/core/types" "github.com/kwilteam/kwil-db/core/types/decimal" + "github.com/kwilteam/kwil-db/parse/common" ) // Value is a value that can be compared, used in arithmetic operations, @@ -15,26 +16,26 @@ import ( type Value interface { // Compare compares the variable with another variable using the given comparison operator. // It will return a boolean value, or null either of the variables is null. 
- Compare(v Value, op ComparisonOp) (Value, error) + Compare(v Value, op common.ComparisonOp) (Value, error) // Type returns the type of the variable. Type() *types.DataType // Value returns the value of the variable. Value() any // Size is the size of the variable in bytes. Size() int + // Cast casts the variable to the given type. + Cast(t *types.DataType) (Value, error) } // ScalarValue is a scalar value that can be computed on and have unary operations applied to it. type ScalarValue interface { Value // Arithmetic performs an arithmetic operation on the variable with another variable. - Arithmetic(v ScalarValue, op ArithmeticOp) (ScalarValue, error) + Arithmetic(v ScalarValue, op common.ArithmeticOp) (ScalarValue, error) // Unary applies a unary operation to the variable. - Unary(op UnaryOp) (ScalarValue, error) + Unary(op common.UnaryOp) (ScalarValue, error) // Array creates an array from this scalar value and any other scalar values. Array(v ...ScalarValue) (ArrayValue, error) - // Cast casts the variable to the given type. - Cast(t *types.DataType) (Value, error) } // ArrayValue is an array value that can be compared and have unary operations applied to it. 
@@ -146,14 +147,14 @@ func NewNullValue(t *types.DataType) Value { } func makeTypeErr(left, right Value) error { - return fmt.Errorf("%w: left: %s right: %s", ErrTypeMismatch, left.Type(), right.Type()) + return fmt.Errorf("%w: left: %s right: %s", common.ErrTypeMismatch, left.Type(), right.Type()) } type IntValue struct { Val int64 } -func (v *IntValue) Compare(v2 Value, op ComparisonOp) (Value, error) { +func (v *IntValue) Compare(v2 Value, op common.ComparisonOp) (Value, error) { if res, early := nullCmp(v2, op); early { return res, nil } @@ -165,13 +166,13 @@ func (v *IntValue) Compare(v2 Value, op ComparisonOp) (Value, error) { var b bool switch op { - case Equal: + case common.Equal: b = v.Val == val2.Val - case LessThan: + case common.LessThan: b = v.Val < val2.Val - case GreaterThan: + case common.GreaterThan: b = v.Val > val2.Val - case IsDistinctFrom: + case common.IsDistinctFrom: b = v.Val != val2.Val default: return nil, fmt.Errorf("cannot compare int with operator id %d", op) @@ -183,23 +184,23 @@ func (v *IntValue) Compare(v2 Value, op ComparisonOp) (Value, error) { // nullCmp is a helper function for comparing null values. // It returns a Value, and a boolean as to whether the caller should return early. // It is meant to be called from methods for non-null values that might need to compare with null. 
-func nullCmp(v Value, op ComparisonOp) (Value, bool) { +func nullCmp(v Value, op common.ComparisonOp) (Value, bool) { if _, ok := v.(*NullValue); !ok { return nil, false } - if op == IsDistinctFrom { + if op == common.IsDistinctFrom { return &BoolValue{Val: true}, true } - if op == Is { + if op == common.Is { return &BoolValue{Val: false}, true } return &NullValue{DataType: v.Type()}, true } -func (i *IntValue) Arithmetic(v ScalarValue, op ArithmeticOp) (ScalarValue, error) { +func (i *IntValue) Arithmetic(v ScalarValue, op common.ArithmeticOp) (ScalarValue, error) { if _, ok := v.(*NullValue); ok { return &NullValue{DataType: types.IntType}, nil } @@ -210,18 +211,18 @@ func (i *IntValue) Arithmetic(v ScalarValue, op ArithmeticOp) (ScalarValue, erro } switch op { - case Add: + case common.Add: return &IntValue{Val: i.Val + val2.Val}, nil - case Sub: + case common.Sub: return &IntValue{Val: i.Val - val2.Val}, nil - case Mul: + case common.Mul: return &IntValue{Val: i.Val * val2.Val}, nil - case Div: + case common.Div: if val2.Val == 0 { return nil, fmt.Errorf("cannot divide by zero") } return &IntValue{Val: i.Val / val2.Val}, nil - case Mod: + case common.Mod: if val2.Val == 0 { return nil, fmt.Errorf("cannot modulo by zero") } @@ -231,13 +232,13 @@ func (i *IntValue) Arithmetic(v ScalarValue, op ArithmeticOp) (ScalarValue, erro } } -func (i *IntValue) Unary(op UnaryOp) (ScalarValue, error) { +func (i *IntValue) Unary(op common.UnaryOp) (ScalarValue, error) { switch op { - case Neg: + case common.Neg: return &IntValue{Val: -i.Val}, nil - case Not: + case common.Not: return nil, fmt.Errorf("cannot apply logical NOT to an integer") - case Pos: + case common.Pos: return i, nil default: return nil, fmt.Errorf("unknown unary operator: %d", op) @@ -313,7 +314,7 @@ type TextValue struct { Val string } -func (s *TextValue) Compare(v Value, op ComparisonOp) (Value, error) { +func (s *TextValue) Compare(v Value, op common.ComparisonOp) (Value, error) { if res, early := 
nullCmp(v, op); early { return res, nil } @@ -325,13 +326,13 @@ func (s *TextValue) Compare(v Value, op ComparisonOp) (Value, error) { var b bool switch op { - case Equal: + case common.Equal: b = s.Val == val2.Val - case LessThan: + case common.LessThan: b = s.Val < val2.Val - case GreaterThan: + case common.GreaterThan: b = s.Val > val2.Val - case IsDistinctFrom: + case common.IsDistinctFrom: b = s.Val != val2.Val default: return nil, fmt.Errorf("unknown comparison operator: %d", op) @@ -340,20 +341,20 @@ func (s *TextValue) Compare(v Value, op ComparisonOp) (Value, error) { return &BoolValue{Val: b}, nil } -func (s *TextValue) Arithmetic(v ScalarValue, op ArithmeticOp) (ScalarValue, error) { +func (s *TextValue) Arithmetic(v ScalarValue, op common.ArithmeticOp) (ScalarValue, error) { val2, ok := v.(*TextValue) if !ok { return nil, makeTypeErr(s, v) } - if op == Concat { + if op == common.Concat { return &TextValue{Val: s.Val + val2.Val}, nil } return nil, fmt.Errorf("cannot perform arithmetic operation id %d on type string", op) } -func (s *TextValue) Unary(op UnaryOp) (ScalarValue, error) { +func (s *TextValue) Unary(op common.UnaryOp) (ScalarValue, error) { return nil, fmt.Errorf("cannot perform unary operation on string") } @@ -440,7 +441,7 @@ type BoolValue struct { Val bool } -func (b *BoolValue) Compare(v Value, op ComparisonOp) (Value, error) { +func (b *BoolValue) Compare(v Value, op common.ComparisonOp) (Value, error) { if res, early := nullCmp(v, op); early { return res, nil } @@ -452,9 +453,9 @@ func (b *BoolValue) Compare(v Value, op ComparisonOp) (Value, error) { var b2 bool switch op { - case Equal, Is: + case common.Equal, common.Is: b2 = b.Val == val2.Val - case IsDistinctFrom: + case common.IsDistinctFrom: b2 = b.Val != val2.Val default: return nil, fmt.Errorf("unknown comparison operator: %d", op) @@ -463,13 +464,13 @@ func (b *BoolValue) Compare(v Value, op ComparisonOp) (Value, error) { return &BoolValue{Val: b2}, nil } -func (b *BoolValue) 
Arithmetic(v ScalarValue, op ArithmeticOp) (ScalarValue, error) { +func (b *BoolValue) Arithmetic(v ScalarValue, op common.ArithmeticOp) (ScalarValue, error) { return nil, fmt.Errorf("cannot perform arithmetic operation on bool") } -func (b *BoolValue) Unary(op UnaryOp) (ScalarValue, error) { +func (b *BoolValue) Unary(op common.UnaryOp) (ScalarValue, error) { switch op { - case Not: + case common.Not: return &BoolValue{Val: !b.Val}, nil default: return nil, fmt.Errorf("unexpected operator id %d for bool", op) @@ -528,7 +529,7 @@ type BlobValue struct { Val []byte } -func (b *BlobValue) Compare(v Value, op ComparisonOp) (Value, error) { +func (b *BlobValue) Compare(v Value, op common.ComparisonOp) (Value, error) { if res, early := nullCmp(v, op); early { return res, nil } @@ -540,9 +541,9 @@ func (b *BlobValue) Compare(v Value, op ComparisonOp) (Value, error) { var b2 bool switch op { - case Equal: + case common.Equal: b2 = string(b.Val) == string(val2.Val) - case IsDistinctFrom: + case common.IsDistinctFrom: b2 = string(b.Val) != string(val2.Val) default: return nil, fmt.Errorf("unknown comparison operator: %d", op) @@ -551,11 +552,11 @@ func (b *BlobValue) Compare(v Value, op ComparisonOp) (Value, error) { return &BoolValue{Val: b2}, nil } -func (b *BlobValue) Arithmetic(v ScalarValue, op ArithmeticOp) (ScalarValue, error) { +func (b *BlobValue) Arithmetic(v ScalarValue, op common.ArithmeticOp) (ScalarValue, error) { return nil, fmt.Errorf("cannot perform arithmetic operation on blob") } -func (b *BlobValue) Unary(op UnaryOp) (ScalarValue, error) { +func (b *BlobValue) Unary(op common.UnaryOp) (ScalarValue, error) { return nil, fmt.Errorf("cannot perform unary operation on blob") } @@ -610,7 +611,7 @@ type UUIDValue struct { Val types.UUID } -func (u *UUIDValue) Compare(v Value, op ComparisonOp) (Value, error) { +func (u *UUIDValue) Compare(v Value, op common.ComparisonOp) (Value, error) { if res, early := nullCmp(v, op); early { return res, nil } @@ -622,9 
+623,9 @@ func (u *UUIDValue) Compare(v Value, op ComparisonOp) (Value, error) { var b bool switch op { - case Equal: + case common.Equal: b = u.Val == val2.Val - case IsDistinctFrom: + case common.IsDistinctFrom: b = u.Val != val2.Val default: return nil, fmt.Errorf("unknown comparison operator: %d", op) @@ -633,11 +634,11 @@ func (u *UUIDValue) Compare(v Value, op ComparisonOp) (Value, error) { return &BoolValue{Val: b}, nil } -func (u *UUIDValue) Arithmetic(v ScalarValue, op ArithmeticOp) (ScalarValue, error) { +func (u *UUIDValue) Arithmetic(v ScalarValue, op common.ArithmeticOp) (ScalarValue, error) { return nil, fmt.Errorf("cannot perform arithmetic operation on uuid") } -func (u *UUIDValue) Unary(op UnaryOp) (ScalarValue, error) { +func (u *UUIDValue) Unary(op common.UnaryOp) (ScalarValue, error) { return nil, fmt.Errorf("cannot perform unary operation on uuid") } @@ -688,7 +689,7 @@ type DecimalValue struct { Dec *decimal.Decimal } -func (d *DecimalValue) Compare(v Value, op ComparisonOp) (Value, error) { +func (d *DecimalValue) Compare(v Value, op common.ComparisonOp) (Value, error) { if res, early := nullCmp(v, op); early { return res, nil } @@ -706,7 +707,7 @@ func (d *DecimalValue) Compare(v Value, op ComparisonOp) (Value, error) { return cmpIntegers(res, 0, op) } -func (d *DecimalValue) Arithmetic(v ScalarValue, op ArithmeticOp) (ScalarValue, error) { +func (d *DecimalValue) Arithmetic(v ScalarValue, op common.ArithmeticOp) (ScalarValue, error) { // we perform an extra check here to ensure scale and precision are the same if !v.Type().EqualsStrict(d.Type()) { return nil, makeTypeErr(d, v) @@ -720,15 +721,15 @@ func (d *DecimalValue) Arithmetic(v ScalarValue, op ArithmeticOp) (ScalarValue, var d2 *decimal.Decimal var err error switch op { - case Add: + case common.Add: d2, err = decimal.Add(d.Dec, val2.Dec) - case Sub: + case common.Sub: d2, err = decimal.Sub(d.Dec, val2.Dec) - case Mul: + case common.Mul: d2, err = decimal.Mul(d.Dec, val2.Dec) - case 
Div: + case common.Div: d2, err = decimal.Div(d.Dec, val2.Dec) - case Mod: + case common.Mod: d2, err = decimal.Mod(d.Dec, val2.Dec) default: return nil, fmt.Errorf("unexpected operator id %d for decimal", op) @@ -743,15 +744,15 @@ func (d *DecimalValue) Arithmetic(v ScalarValue, op ArithmeticOp) (ScalarValue, } -func (d *DecimalValue) Unary(op UnaryOp) (ScalarValue, error) { +func (d *DecimalValue) Unary(op common.UnaryOp) (ScalarValue, error) { switch op { - case Neg: + case common.Neg: dec2 := d.Dec.Copy() err := dec2.Neg() return &DecimalValue{ Dec: dec2, }, err - case Pos: + case common.Pos: return d, nil default: return nil, fmt.Errorf("unexpected operator id %d for decimal", op) @@ -795,7 +796,7 @@ type Uint256Value struct { Val *types.Uint256 } -func (u *Uint256Value) Compare(v Value, op ComparisonOp) (Value, error) { +func (u *Uint256Value) Compare(v Value, op common.ComparisonOp) (Value, error) { if res, early := nullCmp(v, op); early { return res, nil } @@ -810,7 +811,7 @@ func (u *Uint256Value) Compare(v Value, op ComparisonOp) (Value, error) { return cmpIntegers(c, 0, op) } -func (u *Uint256Value) Arithmetic(v ScalarValue, op ArithmeticOp) (ScalarValue, error) { +func (u *Uint256Value) Arithmetic(v ScalarValue, op common.ArithmeticOp) (ScalarValue, error) { if _, ok := v.(*NullValue); ok { return &NullValue{DataType: types.Uint256Type}, nil } @@ -821,22 +822,22 @@ func (u *Uint256Value) Arithmetic(v ScalarValue, op ArithmeticOp) (ScalarValue, } switch op { - case Add: + case common.Add: res := u.Val.Add(val2.Val) return &Uint256Value{Val: res}, nil - case Sub: + case common.Sub: res, err := u.Val.Sub(val2.Val) return &Uint256Value{Val: res}, err - case Mul: + case common.Mul: res, err := u.Val.Mul(val2.Val) return &Uint256Value{Val: res}, err - case Div: + case common.Div: if val2.Val.Cmp(types.Uint256FromInt(0)) == 0 { return nil, fmt.Errorf("cannot divide by zero") } res := u.Val.Div(val2.Val) return &Uint256Value{Val: res}, nil - case Mod: + case 
common.Mod: if val2.Val.Cmp(types.Uint256FromInt(0)) == 0 { return nil, fmt.Errorf("cannot divide by zero") } @@ -847,13 +848,13 @@ func (u *Uint256Value) Arithmetic(v ScalarValue, op ArithmeticOp) (ScalarValue, } } -func (u *Uint256Value) Unary(op UnaryOp) (ScalarValue, error) { +func (u *Uint256Value) Unary(op common.UnaryOp) (ScalarValue, error) { switch op { - case Neg: + case common.Neg: return nil, fmt.Errorf("cannot apply unary negation to a uint256") - case Not: + case common.Not: return nil, fmt.Errorf("cannot apply logical NOT to a uint256") - case Pos: + case common.Pos: return u, nil default: return nil, fmt.Errorf("unknown unary operator: %d", op) @@ -913,7 +914,7 @@ type IntArrayValue struct { Val []*int64 } -func (a *IntArrayValue) Compare(v Value, op ComparisonOp) (Value, error) { +func (a *IntArrayValue) Compare(v Value, op common.ComparisonOp) (Value, error) { if res, early := nullCmp(v, op); early { return res, nil } @@ -938,13 +939,13 @@ func (a *IntArrayValue) Compare(v Value, op ComparisonOp) (Value, error) { var b bool switch op { - case Equal: + case common.Equal: b = *v1 == *v2 - case LessThan: + case common.LessThan: b = *v1 < *v2 - case GreaterThan: + case common.GreaterThan: b = *v1 > *v2 - case IsDistinctFrom: + case common.IsDistinctFrom: b = *v1 != *v2 default: return nil, fmt.Errorf("unknown comparison operator: %d", op) @@ -1015,11 +1016,77 @@ func (a *IntArrayValue) Size() int { return size } +func (a *IntArrayValue) Cast(t *types.DataType) (Value, error) { + switch t { + case types.TextArrayType: + res := make([]*string, len(a.Val)) + for i, v := range a.Val { + if v == nil { + res[i] = nil + } else { + res[i] = new(string) + *res[i] = strconv.FormatInt(*v, 10) + } + } + + return &TextArrayValue{ + Val: res, + }, nil + case types.BoolArrayType: + res := make([]*bool, len(a.Val)) + for i, v := range a.Val { + if v == nil { + res[i] = nil + } else { + b := *v != 0 + res[i] = &b + } + } + + return &BoolArrayValue{ + Val: res, + }, 
nil + case types.Uint256ArrayType: + res := make([]*types.Uint256, len(a.Val)) + for i, v := range a.Val { + if v == nil { + res[i] = nil + } else { + res[i] = types.Uint256FromInt(uint64(*v)) + } + } + + return &Uint256ArrayValue{ + Val: res, + }, nil + case types.DecimalArrayType: + res := make([]*decimal.Decimal, len(a.Val)) + for i, v := range a.Val { + if v == nil { + res[i] = nil + } else { + dec, err := decimal.NewFromBigInt(big.NewInt(*v), 0) + if err != nil { + return nil, err + } + res[i] = dec + } + } + + return &DecimalArrayValue{ + Val: res, + DataType: types.DecimalType, + }, nil + default: + return nil, fmt.Errorf("cannot cast int array to %s", t) + } +} + type TextArrayValue struct { Val []*string } -func (a *TextArrayValue) Compare(v Value, op ComparisonOp) (Value, error) { +func (a *TextArrayValue) Compare(v Value, op common.ComparisonOp) (Value, error) { if res, early := nullCmp(v, op); early { return res, nil } @@ -1044,13 +1111,13 @@ func (a *TextArrayValue) Compare(v Value, op ComparisonOp) (Value, error) { var b bool switch op { - case Equal: + case common.Equal: b = *v1 == *v2 - case LessThan: + case common.LessThan: b = *v1 < *v2 - case GreaterThan: + case common.GreaterThan: b = *v1 > *v2 - case IsDistinctFrom: + case common.IsDistinctFrom: b = *v1 != *v2 default: return nil, fmt.Errorf("unknown comparison operator: %d", op) @@ -1121,11 +1188,104 @@ func (a *TextArrayValue) Size() int { return size } +func (a *TextArrayValue) Cast(t *types.DataType) (Value, error) { + switch t { + case types.IntArrayType: + res := make([]*int64, len(a.Val)) + for i, v := range a.Val { + if v == nil { + res[i] = nil + } else { + parsed, err := strconv.ParseInt(*v, 10, 64) // must not shadow loop index i + if err != nil { + return nil, fmt.Errorf("cannot cast text array to int array: %w", err) + } + res[i] = &parsed + } + } + + return &IntArrayValue{ + Val: res, + }, nil + case types.BoolArrayType: + res := make([]*bool, len(a.Val)) + for i, v := range a.Val { + if v == nil { + res[i] = nil + } else
{ + b, err := strconv.ParseBool(*v) + if err != nil { + return nil, fmt.Errorf("cannot cast text array to bool array: %w", err) + } + res[i] = &b + } + } + + return &BoolArrayValue{ + Val: res, + }, nil + case types.Uint256ArrayType: + res := make([]*types.Uint256, len(a.Val)) + for i, v := range a.Val { + if v == nil { + res[i] = nil + } else { + u, err := types.Uint256FromString(*v) + if err != nil { + return nil, fmt.Errorf("cannot cast text array to uint256 array: %w", err) + } + res[i] = u + } + } + + return &Uint256ArrayValue{ + Val: res, + }, nil + case types.DecimalArrayType: + res := make([]*decimal.Decimal, len(a.Val)) + for i, v := range a.Val { + if v == nil { + res[i] = nil + } else { + dec, err := decimal.NewFromString(*v) + if err != nil { + return nil, fmt.Errorf("cannot cast text array to decimal array: %w", err) + } + res[i] = dec + } + } + + return &DecimalArrayValue{ + Val: res, + DataType: types.DecimalType, + }, nil + case types.UUIDArrayType: + res := make([]*types.UUID, len(a.Val)) + for i, v := range a.Val { + if v == nil { + res[i] = nil + } else { + u, err := types.ParseUUID(*v) + if err != nil { + return nil, fmt.Errorf("cannot cast text array to uuid array: %w", err) + } + res[i] = u + } + } + + return &UuidArrayValue{ + Val: res, + }, nil + default: + return nil, fmt.Errorf("cannot cast text array to %s", t) + } +} + type BoolArrayValue struct { Val []*bool } -func (a *BoolArrayValue) Compare(v Value, op ComparisonOp) (Value, error) { +func (a *BoolArrayValue) Compare(v Value, op common.ComparisonOp) (Value, error) { if res, early := nullCmp(v, op); early { return res, nil } @@ -1150,9 +1310,9 @@ func (a *BoolArrayValue) Compare(v Value, op ComparisonOp) (Value, error) { var b bool switch op { - case Equal: + case common.Equal: b = *v1 == *v2 - case IsDistinctFrom: + case common.IsDistinctFrom: b = *v1 != *v2 default: return nil, fmt.Errorf("unknown comparison operator: %d", op) @@ -1217,12 +1377,44 @@ func (a *BoolArrayValue) Size() 
int { return len(a.Val) } +func (a *BoolArrayValue) Cast(t *types.DataType) (Value, error) { + switch t { + case types.TextArrayType: + res := make([]*string, len(a.Val)) + for i, v := range a.Val { + if v == nil { + res[i] = nil + } else { + s := strconv.FormatBool(*v) + res[i] = &s + } + } + return &TextArrayValue{Val: res}, nil + case types.IntArrayType: + res := make([]*int64, len(a.Val)) + for i, v := range a.Val { + if v == nil { + res[i] = nil + } else { + var n int64 /* named n, not i: must not shadow the loop index used to address res */ + if *v { + n = 1 + } + res[i] = &n + } + } + return &IntArrayValue{Val: res}, nil + default: + return nil, fmt.Errorf("cannot cast bool array to %s", t) + } +} + type DecimalArrayValue struct { Val []*decimal.Decimal DataType *types.DataType } -func (a *DecimalArrayValue) Compare(v Value, op ComparisonOp) (Value, error) { +func (a *DecimalArrayValue) Compare(v Value, op common.ComparisonOp) (Value, error) { if res, early := nullCmp(v, op); early { return res, nil } @@ -1360,11 +1552,43 @@ func (a *DecimalArrayValue) Size() int { return size } +func (a *DecimalArrayValue) Cast(t *types.DataType) (Value, error) { + switch t { + case types.TextArrayType: + res := make([]*string, len(a.Val)) + for i, v := range a.Val { + if v == nil { + res[i] = nil + } else { + s := v.String() + res[i] = &s + } + } + return &TextArrayValue{Val: res}, nil + case types.IntArrayType: + res := make([]*int64, len(a.Val)) + for i, v := range a.Val { + if v == nil { + res[i] = nil + } else { + n, err := v.Int64() /* named n, not i: must not shadow the loop index used to address res */ + if err != nil { + return nil, fmt.Errorf("cannot cast decimal to int: %w", err) + } + res[i] = &n + } + } + return &IntArrayValue{Val: res}, nil + default: + return nil, fmt.Errorf("cannot cast decimal array to %s", t) + } +} + type Uint256ArrayValue struct { Val []*types.Uint256 } -func (a *Uint256ArrayValue) Compare(v Value, op ComparisonOp) (Value, error) { +func (a *Uint256ArrayValue) Compare(v Value, op common.ComparisonOp) (Value, error) { if res, early := nullCmp(v, op); early { return res, nil } @@
-1389,13 +1613,13 @@ func (a *Uint256ArrayValue) Compare(v Value, op ComparisonOp) (Value, error) { var b bool switch op { - case Equal: + case common.Equal: b = v1.Cmp(v2) == 0 - case LessThan: + case common.LessThan: b = v1.Cmp(v2) < 0 - case GreaterThan: + case common.GreaterThan: b = v1.Cmp(v2) > 0 - case IsDistinctFrom: + case common.IsDistinctFrom: b = v1.Cmp(v2) != 0 default: return nil, fmt.Errorf("unknown comparison operator: %d", op) @@ -1466,11 +1690,29 @@ func (a *Uint256ArrayValue) Size() int { return size } +func (a *Uint256ArrayValue) Cast(t *types.DataType) (Value, error) { + switch t { + case types.TextArrayType: + res := make([]*string, len(a.Val)) + for i, v := range a.Val { + if v == nil { + res[i] = nil + } else { + s := v.String() + res[i] = &s + } + } + return &TextArrayValue{Val: res}, nil + default: + return nil, fmt.Errorf("cannot cast uint256 array to %s", t) + } +} + type BlobArrayValue struct { Val [][]byte } -func (a *BlobArrayValue) Compare(v Value, op ComparisonOp) (Value, error) { +func (a *BlobArrayValue) Compare(v Value, op common.ComparisonOp) (Value, error) { if res, early := nullCmp(v, op); early { return res, nil } @@ -1495,9 +1737,9 @@ func (a *BlobArrayValue) Compare(v Value, op ComparisonOp) (Value, error) { var b bool switch op { - case Equal: + case common.Equal: b = string(v1) == string(v2) - case IsDistinctFrom: + case common.IsDistinctFrom: b = string(v1) != string(v2) default: return nil, fmt.Errorf("unknown comparison operator: %d", op) @@ -1568,11 +1810,29 @@ func (a *BlobArrayValue) Size() int { return size } +func (a *BlobArrayValue) Cast(t *types.DataType) (Value, error) { + switch t { + case types.TextArrayType: + res := make([]*string, len(a.Val)) + for i, v := range a.Val { + if v == nil { + res[i] = nil + } else { + s := string(v) + res[i] = &s + } + } + return &TextArrayValue{Val: res}, nil + default: + return nil, fmt.Errorf("cannot cast blob array to %s", t) + } +} + type UuidArrayValue struct { Val 
[]*types.UUID } -func (a *UuidArrayValue) Compare(v Value, op ComparisonOp) (Value, error) { +func (a *UuidArrayValue) Compare(v Value, op common.ComparisonOp) (Value, error) { if res, early := nullCmp(v, op); early { return res, nil } @@ -1597,9 +1857,9 @@ func (a *UuidArrayValue) Compare(v Value, op ComparisonOp) (Value, error) { var b bool switch op { - case Equal: + case common.Equal: b = *v1 == *v2 - case IsDistinctFrom: + case common.IsDistinctFrom: b = *v1 != *v2 default: return nil, fmt.Errorf("unknown comparison operator: %d", op) @@ -1670,31 +1930,49 @@ func (a *UuidArrayValue) Size() int { return size } +func (a *UuidArrayValue) Cast(t *types.DataType) (Value, error) { + switch t { + case types.TextArrayType: + res := make([]*string, len(a.Val)) + for i, v := range a.Val { + if v == nil { + res[i] = nil + } else { + s := v.String() + res[i] = &s + } + } + return &TextArrayValue{Val: res}, nil + default: + return nil, fmt.Errorf("cannot cast uuid array to %s", t) + } +} + type NullValue struct { DataType *types.DataType } -func (n *NullValue) Compare(v Value, op ComparisonOp) (Value, error) { +func (n *NullValue) Compare(v Value, op common.ComparisonOp) (Value, error) { if _, ok := v.(*NullValue); !ok { return &NullValue{DataType: n.DataType}, nil } - if op == IsDistinctFrom { + if op == common.IsDistinctFrom { return &BoolValue{Val: false}, nil } - if op == Is { + if op == common.Is { return &BoolValue{Val: true}, nil } return &NullValue{DataType: n.DataType}, nil } -func (n *NullValue) Arithmetic(v ScalarValue, op ArithmeticOp) (ScalarValue, error) { +func (n *NullValue) Arithmetic(v ScalarValue, op common.ArithmeticOp) (ScalarValue, error) { return &NullValue{DataType: n.DataType}, nil } -func (n *NullValue) Unary(op UnaryOp) (ScalarValue, error) { +func (n *NullValue) Unary(op common.UnaryOp) (ScalarValue, error) { return &NullValue{DataType: n.DataType}, nil } @@ -1735,7 +2013,7 @@ type RecordValue struct { Order []string } -func (o *RecordValue) 
Compare(v Value, op ComparisonOp) (Value, error) { +func (o *RecordValue) Compare(v Value, op common.ComparisonOp) (Value, error) { if res, early := nullCmp(v, op); early { return res, nil } @@ -1758,7 +2036,7 @@ func (o *RecordValue) Compare(v Value, op ComparisonOp) (Value, error) { break } - eq, err := o.Fields[field].Compare(v2, Equal) + eq, err := o.Fields[field].Compare(v2, common.Equal) if err != nil { return nil, err } @@ -1777,7 +2055,7 @@ func (o *RecordValue) Compare(v Value, op ComparisonOp) (Value, error) { } switch op { - case Equal: + case common.Equal: return &BoolValue{Val: isSame}, nil default: return nil, fmt.Errorf("unknown comparison operator: %d", op) @@ -1801,17 +2079,19 @@ func (o *RecordValue) Size() int { return size } -const nullSize = 1 +func (o *RecordValue) Cast(t *types.DataType) (Value, error) { + return nil, fmt.Errorf("cannot cast record to %s", t) +} -func cmpIntegers(a, b int, op ComparisonOp) (Value, error) { +func cmpIntegers(a, b int, op common.ComparisonOp) (Value, error) { switch op { - case Equal: + case common.Equal: return &BoolValue{Val: a == b}, nil - case LessThan: + case common.LessThan: return &BoolValue{Val: a < b}, nil - case GreaterThan: + case common.GreaterThan: return &BoolValue{Val: a > b}, nil - case IsDistinctFrom: + case common.IsDistinctFrom: return &BoolValue{Val: a != b}, nil default: return nil, fmt.Errorf("unknown comparison operator: %d", op)