diff --git a/CHANGELOG.md b/CHANGELOG.md index c8385b24..6e287b34 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Changed - *BREAKING* partiql-parser: Added a source location to `ParseError::UnexpectedEndOfInput` +- *BREAKING* partiql-ast: Changed the modelling of parsed literals. +- *BREAKING* partiql-logical: Changed the modelling of logical plan literals. - partiql-eval: Fixed behavior of comparison and `BETWEEN` operations w.r.t. type mismatches ### Added diff --git a/extension/partiql-extension-visualize/src/ast_to_dot.rs b/extension/partiql-extension-visualize/src/ast_to_dot.rs index 3c641eb9..74dd1283 100644 --- a/extension/partiql-extension-visualize/src/ast_to_dot.rs +++ b/extension/partiql-extension-visualize/src/ast_to_dot.rs @@ -185,7 +185,7 @@ fn lit_to_str(ast: &ast::Lit) -> String { Lit::FloatLit(l) => l.to_string(), Lit::DoubleLit(l) => l.to_string(), Lit::BoolLit(l) => (if *l { "TRUE" } else { "FALSE" }).to_string(), - Lit::IonStringLit(l) => format!("`{}`", l), + Lit::EmbeddedDocLit(l) => format!("`{}`", l), Lit::CharStringLit(l) => format!("'{}'", l), Lit::NationalCharStringLit(l) => format!("'{}'", l), Lit::BitStringLit(l) => format!("b'{}'", l), diff --git a/partiql-ast/src/ast.rs b/partiql-ast/src/ast.rs index 782ddeb6..25475279 100644 --- a/partiql-ast/src/ast.rs +++ b/partiql-ast/src/ast.rs @@ -444,7 +444,7 @@ pub enum Lit { #[visit(skip)] BoolLit(bool), #[visit(skip)] - IonStringLit(String), + EmbeddedDocLit(String), #[visit(skip)] CharStringLit(String), #[visit(skip)] @@ -454,16 +454,41 @@ pub enum Lit { #[visit(skip)] HexStringLit(String), #[visit(skip)] - StructLit(AstNode), + StructLit(AstNode), #[visit(skip)] - BagLit(AstNode), + BagLit(AstNode), #[visit(skip)] - ListLit(AstNode), + ListLit(AstNode), /// E.g. `TIME WITH TIME ZONE` in `SELECT TIME WITH TIME ZONE '12:00' FROM ...` #[visit(skip)] TypedLit(String, Type), } +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct LitField { + pub first: String, + pub second: AstNode, +} + +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct StructLit { + pub fields: Vec, +} + +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct BagLit { + pub values: Vec, +} + +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct ListLit { + pub values: Vec, +} + #[derive(Visit, Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct VarRef { diff --git a/partiql-ast/src/pretty.rs b/partiql-ast/src/pretty.rs index e7d6ba29..3f1c818f 100644 --- a/partiql-ast/src/pretty.rs +++ b/partiql-ast/src/pretty.rs @@ -394,7 +394,7 @@ impl PrettyDoc for Lit { Lit::FloatLit(inner) => arena.text(inner.to_string()), Lit::DoubleLit(inner) => arena.text(inner.to_string()), Lit::BoolLit(inner) => arena.text(inner.to_string()), - Lit::IonStringLit(inner) => inner.pretty_doc(arena), + Lit::EmbeddedDocLit(inner) => inner.pretty_doc(arena), // TODO better pretty for embedded doc: https://github.com/partiql/partiql-lang-rust/issues/508 Lit::CharStringLit(inner) => inner.pretty_doc(arena), Lit::NationalCharStringLit(inner) => inner.pretty_doc(arena), Lit::BitStringLit(inner) => inner.pretty_doc(arena), @@ -699,6 +699,38 @@ impl PrettyDoc for StructExprPair { } } +impl PrettyDoc for StructLit { + fn pretty_doc<'b, D, A>(&'b self, arena: &'b D) -> DocBuilder<'b, D, A> + where + D: DocAllocator<'b, A>, + D::Doc: Clone, + A: Clone, + { + let wrapped = self.fields.iter().map(|p| unsafe { + let x: &'b StructLitField = std::mem::transmute(p); + x + }); + pretty_seq(wrapped, "{", "}", ",", PRETTY_INDENT_MINOR_NEST, arena) + } +} + +pub struct StructLitField(pub LitField); + +impl PrettyDoc for StructLitField { + fn pretty_doc<'b, D, A>(&'b self, arena: &'b D) -> DocBuilder<'b, D, A> + where + D: DocAllocator<'b, A>, + D::Doc: Clone, + A: Clone, + { + let k = self.0.first.pretty_doc(arena); + let v = self.0.second.pretty_doc(arena); + let sep = arena.text(": "); + + k.append(sep).group().append(v).group() + } +} + impl PrettyDoc for Bag { fn pretty_doc<'b, D, A>(&'b self, arena: &'b D) -> DocBuilder<'b, D, A> where @@ -728,6 +760,35 @@ impl PrettyDoc for List { } } +impl PrettyDoc for BagLit { + fn pretty_doc<'b, D, A>(&'b self, arena: &'b D) -> DocBuilder<'b, D, A> + where + D: DocAllocator<'b, A>, + D::Doc: Clone, + A: Clone, + { + pretty_seq( + &self.values, + "<<", + ">>", + ",", + PRETTY_INDENT_MINOR_NEST, + arena, + ) + } +} + +impl PrettyDoc for ListLit { + fn pretty_doc<'b, D, A>(&'b self, arena: &'b D) -> DocBuilder<'b, D, A> + where + D: DocAllocator<'b, A>, + D::Doc: Clone, + A: Clone, + { + pretty_seq(&self.values, "[", "]", ",", PRETTY_INDENT_MINOR_NEST, arena) + } +} + impl PrettyDoc for Sexp { fn pretty_doc<'b, D, A>(&'b self, _arena: &'b D) -> DocBuilder<'b, D, A> where diff --git a/partiql-eval/src/lib.rs b/partiql-eval/src/lib.rs index 69240b44..45370836 100644 --- a/partiql-eval/src/lib.rs +++ b/partiql-eval/src/lib.rs @@ -162,7 +162,12 @@ mod tests { // TODO: once eval conformance tests added and/or modified evaluation API (to support other values // in evaluator output), change or delete tests using this function #[track_caller] - fn eval_bin_op(op: BinaryOp, lhs: Value, rhs: Value, expected_first_elem: Value) { + fn eval_bin_op>( + op: BinaryOp, + lhs: Value, + rhs_lit: I, + expected_first_elem: Value, + ) { let mut plan = LogicalPlan::new(); let scan = plan.add_operator(BindingsOp::Scan(logical::Scan { expr: ValueExpr::VarRef( @@ -187,7 +192,7 @@ mod tests { "lhs".to_string().into(), ))], )), - Box::new(ValueExpr::Lit(Box::new(rhs))), + Box::new(ValueExpr::Lit(Box::new(rhs_lit.into()))), ), )]), })); @@ -643,13 +648,13 @@ mod tests { #[test] fn and_or_null() { #[track_caller] - fn eval_to_null(op: BinaryOp, lhs: Value, rhs: Value) { + fn eval_to_null>(op: BinaryOp, lhs: I, rhs: I) { let mut plan = LogicalPlan::new(); let expq = plan.add_operator(BindingsOp::ExprQuery(ExprQuery { expr: ValueExpr::BinaryExpr( op, - Box::new(ValueExpr::Lit(Box::new(lhs))), - Box::new(ValueExpr::Lit(Box::new(rhs))), + Box::new(ValueExpr::Lit(Box::new(lhs.into()))), + Box::new(ValueExpr::Lit(Box::new(rhs.into()))), ), })); @@ -697,8 +702,8 @@ mod tests { "value".to_string().into(), ))], )), - from: Box::new(ValueExpr::Lit(Box::new(from))), - to: Box::new(ValueExpr::Lit(Box::new(to))), + from: Box::new(ValueExpr::Lit(Box::new(from.into()))), + to: Box::new(ValueExpr::Lit(Box::new(to.into()))), }), )]), })); @@ -908,7 +913,7 @@ mod tests { kind: JoinKind::Left, left: Box::new(from_lhs), right: Box::new(from_rhs), - on: Some(ValueExpr::Lit(Box::new(Value::from(true)))), + on: Some(ValueExpr::Lit(Box::new(Value::from(true).into()))), })); let sink = lg.add_operator(BindingsOp::Sink); @@ -936,17 +941,21 @@ mod tests { expr: Box::new(path_var("n", "a")), cases: vec![ ( - Box::new(ValueExpr::Lit(Box::new(Value::Integer(1)))), - Box::new(ValueExpr::Lit(Box::new(Value::from("one".to_string())))), + Box::new(ValueExpr::Lit(Box::new(Value::Integer(1).into()))), + Box::new(ValueExpr::Lit(Box::new( + Value::from("one".to_string()).into(), + ))), ), ( - Box::new(ValueExpr::Lit(Box::new(Value::Integer(2)))), - Box::new(ValueExpr::Lit(Box::new(Value::from("two".to_string())))), + Box::new(ValueExpr::Lit(Box::new(Value::Integer(2).into()))), + Box::new(ValueExpr::Lit(Box::new( + Value::from("two".to_string()).into(), + ))), ), ], - default: Some(Box::new(ValueExpr::Lit(Box::new(Value::from( - "other".to_string(), - ))))), + default: Some(Box::new(ValueExpr::Lit(Box::new( + Value::from("other".to_string()).into(), + )))), } } @@ -957,22 +966,26 @@ mod tests { Box::new(ValueExpr::BinaryExpr( BinaryOp::Eq, Box::new(path_var("n", "a")), - Box::new(ValueExpr::Lit(Box::new(Value::Integer(1)))), + Box::new(ValueExpr::Lit(Box::new(Value::Integer(1).into()))), )), - Box::new(ValueExpr::Lit(Box::new(Value::from("one".to_string())))), + Box::new(ValueExpr::Lit(Box::new( + Value::from("one".to_string()).into(), + ))), ), ( Box::new(ValueExpr::BinaryExpr( BinaryOp::Eq, Box::new(path_var("n", "a")), - Box::new(ValueExpr::Lit(Box::new(Value::Integer(2)))), + Box::new(ValueExpr::Lit(Box::new(Value::Integer(2).into()))), )), - Box::new(ValueExpr::Lit(Box::new(Value::from("two".to_string())))), + Box::new(ValueExpr::Lit(Box::new( + Value::from("two".to_string()).into(), + ))), ), ], - default: Some(Box::new(ValueExpr::Lit(Box::new(Value::from( - "other".to_string(), - ))))), + default: Some(Box::new(ValueExpr::Lit(Box::new( + Value::from("other".to_string()).into(), + )))), } } @@ -1236,7 +1249,7 @@ mod tests { "lhs".to_string().into(), ))], )), - rhs: Box::new(ValueExpr::Lit(Box::new(rhs))), + rhs: Box::new(ValueExpr::Lit(Box::new(rhs.into()))), }), )]), })); @@ -1387,7 +1400,7 @@ mod tests { println!("{:?}", &out); assert_eq!(out, expected); } - let list = ValueExpr::Lit(Box::new(Value::List(Box::new(list![1, 2, 3])))); + let list = ValueExpr::Lit(Box::new(Value::List(Box::new(list![1, 2, 3])).into())); // `[1,2,3][0]` -> `1` let index = ValueExpr::Path(Box::new(list.clone()), vec![PathComponent::Index(0)]); @@ -1406,7 +1419,7 @@ mod tests { test(index, Value::Integer(3)); // `{'a':10}[''||'a']` -> `10` - let tuple = ValueExpr::Lit(Box::new(Value::Tuple(Box::new(tuple![("a", 10)])))); + let tuple = ValueExpr::Lit(Box::new(Value::Tuple(Box::new(tuple![("a", 10)])).into())); let index_expr = ValueExpr::BinaryExpr( BinaryOp::Concat, Box::new(ValueExpr::Lit(Box::new("".into()))), @@ -1499,7 +1512,7 @@ mod tests { expr: ValueExpr::BinaryExpr( BinaryOp::Mul, Box::new(va), - Box::new(ValueExpr::Lit(Box::new(Value::Integer(2)))), + Box::new(ValueExpr::Lit(Box::new(Value::Integer(2).into()))), ), })); @@ -1578,7 +1591,7 @@ mod tests { tuple_expr.values.push(ValueExpr::BinaryExpr( BinaryOp::Mul, Box::new(va), - Box::new(ValueExpr::Lit(Box::new(Value::Integer(2)))), + Box::new(ValueExpr::Lit(Box::new(Value::Integer(2).into()))), )); let project = lg.add_operator(ProjectValue(logical::ProjectValue { @@ -1740,7 +1753,7 @@ mod tests { list_expr.elements.push(ValueExpr::BinaryExpr( BinaryOp::Mul, Box::new(va), - Box::new(ValueExpr::Lit(Box::new(Value::Integer(2)))), + Box::new(ValueExpr::Lit(Box::new(Value::Integer(2).into()))), )); let select_value = lg.add_operator(ProjectValue(logical::ProjectValue { @@ -1966,7 +1979,7 @@ mod tests { "balance".to_string().into(), ))], )), - Box::new(ValueExpr::Lit(Box::new(Value::Integer(0)))), + Box::new(ValueExpr::Lit(Box::new(Value::Integer(0).into()))), ), })); @@ -2096,7 +2109,7 @@ mod tests { ValueExpr::BinaryExpr( BinaryOp::Mul, Box::new(va), - Box::new(ValueExpr::Lit(Box::new(Value::Integer(2)))), + Box::new(ValueExpr::Lit(Box::new(Value::Integer(2).into()))), ), )]), })); @@ -2170,7 +2183,7 @@ mod tests { ValueExpr::BinaryExpr( BinaryOp::Mul, Box::new(va), - Box::new(ValueExpr::Lit(Box::new(Value::Integer(2)))), + Box::new(ValueExpr::Lit(Box::new(Value::Integer(2).into()))), ), )]), })); diff --git a/partiql-eval/src/plan.rs b/partiql-eval/src/plan.rs index 961d6187..f269fe34 100644 --- a/partiql-eval/src/plan.rs +++ b/partiql-eval/src/plan.rs @@ -1,13 +1,12 @@ use itertools::{Either, Itertools}; use partiql_logical as logical; -use petgraph::prelude::StableGraph; -use std::collections::HashMap; - use partiql_logical::{ AggFunc, BagOperator, BinaryOp, BindingsOp, CallName, GroupingStrategy, IsTypeExpr, JoinKind, LogicalPlan, OpId, PathComponent, Pattern, PatternMatchExpr, SearchedCase, SetQuantifier, SortSpecNullOrder, SortSpecOrder, Type, UnaryOp, ValueExpr, VarRefType, }; +use petgraph::prelude::StableGraph; +use std::collections::HashMap; use crate::error::{ErrorNode, PlanErr, PlanningError}; use crate::eval; @@ -25,6 +24,7 @@ use crate::eval::expr::{ }; use crate::eval::EvalPlan; use partiql_catalog::catalog::{Catalog, FunctionEntryFunction}; +use partiql_value::Value; use partiql_value::Value::Null; #[macro_export] @@ -392,7 +392,10 @@ impl<'c> EvaluatorPlanner<'c> { ), ValueExpr::Lit(lit) => ( "literal", - EvalLitExpr { lit: *lit.clone() }.bind::<{ STRICT }>(vec![]), + EvalLitExpr { + lit: Value::from(lit.as_ref().clone()), + } + .bind::<{ STRICT }>(vec![]), ), ValueExpr::Path(expr, components) => ( "path", @@ -565,7 +568,7 @@ impl<'c> EvaluatorPlanner<'c> { n.lhs.clone(), n.rhs.clone(), )), - Box::new(ValueExpr::Lit(Box::new(Null))), + Box::new(ValueExpr::Lit(Box::new(logical::Lit::Null))), )], default: Some(n.lhs.clone()), }); @@ -783,7 +786,6 @@ mod tests { use partiql_catalog::catalog::PartiqlCatalog; use partiql_logical::CallExpr; use partiql_logical::ExprQuery; - use partiql_value::Value; #[test] fn test_logical_to_eval_plan_bad_num_arguments() { @@ -794,7 +796,7 @@ mod tests { // report the error. let mut logical = LogicalPlan::new(); fn lit_int(i: usize) -> ValueExpr { - ValueExpr::Lit(Box::new(Value::from(i))) + ValueExpr::Lit(Box::new(logical::Lit::Int64(i as i64))) } let expq = logical.add_operator(BindingsOp::ExprQuery(ExprQuery { diff --git a/partiql-logical-planner/Cargo.toml b/partiql-logical-planner/Cargo.toml index 3d3f0f04..710376e8 100644 --- a/partiql-logical-planner/Cargo.toml +++ b/partiql-logical-planner/Cargo.toml @@ -25,13 +25,11 @@ partiql-ast = { path = "../partiql-ast", version = "0.11.*" } partiql-ast-passes = { path = "../partiql-ast-passes", version = "0.11.*" } partiql-catalog = { path = "../partiql-catalog", version = "0.11.*" } partiql-common = { path = "../partiql-common", version = "0.11.*" } -partiql-extension-ion = { path = "../extension/partiql-extension-ion", version = "0.11.*" } partiql-parser = { path = "../partiql-parser", version = "0.11.*" } partiql-logical = { path = "../partiql-logical", version = "0.11.*" } partiql-types = { path = "../partiql-types", version = "0.11.*" } partiql-value = { path = "../partiql-value", version = "0.11.*" } -ion-rs_old = { version = "0.18", package = "ion-rs" } ordered-float = "4" itertools = "0.13" unicase = "2.7" diff --git a/partiql-logical-planner/src/builtins.rs b/partiql-logical-planner/src/builtins.rs index 86e83bcc..ce577895 100644 --- a/partiql-logical-planner/src/builtins.rs +++ b/partiql-logical-planner/src/builtins.rs @@ -2,7 +2,6 @@ use itertools::Itertools; use once_cell::sync::Lazy; use partiql_logical as logical; use partiql_logical::{SetQuantifier, ValueExpr}; -use partiql_value::Value; use std::collections::HashMap; use std::fmt::Debug; @@ -135,7 +134,7 @@ fn function_call_def_substring() -> CallDef { CallSpec { input: vec![CallSpecArg::Positional, CallSpecArg::Named("for".into())], output: Box::new(|mut args| { - args.insert(1, ValueExpr::Lit(Box::new(Value::Integer(0)))); + args.insert(1, ValueExpr::Lit(Box::new(logical::Lit::Int8(0)))); logical::ValueExpr::Call(logical::CallExpr { name: logical::CallName::Substring, arguments: args, @@ -241,7 +240,7 @@ fn function_call_def_trim() -> CallDef { output: Box::new(|mut args| { args.insert( 0, - ValueExpr::Lit(Box::new(Value::String(" ".to_string().into()))), + ValueExpr::Lit(Box::new(logical::Lit::String(" ".to_string()))), ); logical::ValueExpr::Call(logical::CallExpr { @@ -255,7 +254,7 @@ fn function_call_def_trim() -> CallDef { output: Box::new(|mut args| { args.insert( 0, - ValueExpr::Lit(Box::new(Value::String(" ".to_string().into()))), + ValueExpr::Lit(Box::new(logical::Lit::String(" ".to_string()))), ); logical::ValueExpr::Call(logical::CallExpr { name: logical::CallName::BTrim, diff --git a/partiql-logical-planner/src/lower.rs b/partiql-logical-planner/src/lower.rs index 1860027b..cfbd19ff 100644 --- a/partiql-logical-planner/src/lower.rs +++ b/partiql-logical-planner/src/lower.rs @@ -21,7 +21,7 @@ use partiql_logical::{ }; use std::borrow::Cow; -use partiql_value::{BindingsName, Value}; +use partiql_value::BindingsName; use std::collections::{HashMap, HashSet}; @@ -36,8 +36,6 @@ use crate::functions::Function; use partiql_ast_passes::name_resolver::NameRef; use partiql_catalog::catalog::Catalog; use partiql_common::node::NodeId; -use partiql_extension_ion::decode::{IonDecoderBuilder, IonDecoderConfig}; -use partiql_extension_ion::Encoding; use partiql_logical::AggFunc::{AggAny, AggAvg, AggCount, AggEvery, AggMax, AggMin, AggSum}; use partiql_logical::ValueExpr::DynamicLookup; use std::sync::atomic::{AtomicU32, Ordering}; @@ -541,8 +539,8 @@ impl<'a> AstToLogical<'a> { } #[inline] - fn push_value(&mut self, val: Value) { - self.push_vexpr(ValueExpr::Lit(Box::new(val))); + fn push_lit(&mut self, lit: logical::Lit) { + self.push_vexpr(ValueExpr::Lit(Box::new(lit))); } #[inline] @@ -832,7 +830,7 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> { { let exprs = HashMap::from([( "$__gk".to_string(), - ValueExpr::Lit(Box::new(Value::from(true))), + ValueExpr::Lit(Box::new(logical::Lit::Bool(true))), )]); let group_by: BindingsOp = BindingsOp::GroupBy(logical::GroupBy { strategy: logical::GroupingStrategy::GroupFull, @@ -889,7 +887,7 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> { let alias = iter.next().unwrap(); let alias = match alias { ValueExpr::Lit(lit) => match *lit { - Value::String(s) => (*s).clone(), + logical::Lit::String(s) => s.clone(), _ => { // Report error but allow visitor to continue self.errors.push(AstTransformError::IllegalState( @@ -951,7 +949,7 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> { name_resolver::Symbol::Known(sym) => sym.value.clone(), name_resolver::Symbol::Unknown(id) => format!("_{id}"), }; - self.push_value(as_key.into()); + self.push_lit(logical::Lit::String(as_key)); Traverse::Continue } @@ -969,8 +967,8 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> { if _bin_op.kind == BinOpKind::Is { let is_type = match rhs { ValueExpr::Lit(lit) => match lit.as_ref() { - Value::Null => logical::Type::NullType, - Value::Missing => logical::Type::MissingType, + logical::Lit::Null => logical::Type::NullType, + logical::Lit::Missing => logical::Type::MissingType, _ => { not_yet_implemented_fault!( self, @@ -1079,7 +1077,7 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> { let escape_ve = if env.len() == 3 { env.pop().unwrap() } else { - ValueExpr::Lit(Box::new(Value::String(Box::default()))) + ValueExpr::Lit(Box::new(logical::Lit::String(String::default()))) }; let pattern_ve = env.pop().unwrap(); let value = Box::new(env.pop().unwrap()); @@ -1087,10 +1085,12 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> { let pattern = match (&pattern_ve, &escape_ve) { (ValueExpr::Lit(pattern_lit), ValueExpr::Lit(escape_lit)) => { match (pattern_lit.as_ref(), escape_lit.as_ref()) { - (Value::String(pattern), Value::String(escape)) => Pattern::Like(LikeMatch { - pattern: pattern.to_string(), - escape: escape.to_string(), - }), + (logical::Lit::String(pattern), logical::Lit::String(escape)) => { + Pattern::Like(LikeMatch { + pattern: pattern.to_string(), + escape: escape.to_string(), + }) + } _ => Pattern::LikeNonStringNonLiteral(LikeNonStringNonLiteralMatch { pattern: Box::new(pattern_ve), escape: Box::new(escape_ve), @@ -1136,7 +1136,7 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> { Ok(expr) => expr, Err(err) => { self.errors.push(err); - ValueExpr::Lit(Box::new(Value::Missing)) // dummy expression to allow lowering to continue + ValueExpr::Lit(Box::new(logical::Lit::Missing)) // dummy expression to allow lowering to continue } }; self.push_vexpr(expr); @@ -1178,15 +1178,15 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> { // Values & Value Constructors fn enter_lit(&mut self, lit: &'ast Lit) -> Traverse { - let val = match lit_to_value(lit) { + let val = match lit_to_lit(lit) { Ok(v) => v, Err(e) => { // Report error but allow visitor to continue self.errors.push(e); - Value::Missing + logical::Lit::Missing } }; - self.push_value(val); + self.push_lit(val); Traverse::Continue } @@ -1281,7 +1281,7 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> { }, CallArgument::Star => ( logical::SetQuantifier::All, - ValueExpr::Lit(Box::new(Value::Integer(1))), + ValueExpr::Lit(Box::new(logical::Lit::Int8(1))), ), }; @@ -1417,9 +1417,12 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> { let path = env.pop().unwrap(); match path { ValueExpr::Lit(val) => match *val { - Value::Integer(idx) => logical::PathComponent::Index(idx), - Value::String(k) => logical::PathComponent::Key( - BindingsName::CaseInsensitive(Cow::Owned(*k)), + logical::Lit::Int8(idx) => logical::PathComponent::Index(idx.into()), + logical::Lit::Int16(idx) => logical::PathComponent::Index(idx.into()), + logical::Lit::Int32(idx) => logical::PathComponent::Index(idx.into()), + logical::Lit::Int64(idx) => logical::PathComponent::Index(idx), + logical::Lit::String(k) => logical::PathComponent::Key( + BindingsName::CaseInsensitive(Cow::Owned(k)), ), expr => logical::PathComponent::IndexExpr(Box::new(ValueExpr::Lit( Box::new(expr), @@ -1678,7 +1681,7 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> { let alias = iter.next().unwrap(); let alias = match alias { ValueExpr::Lit(lit) => match *lit { - Value::String(s) => (*s).clone(), + logical::Lit::String(s) => s.clone(), _ => { // Report error but allow visitor to continue self.errors.push(AstTransformError::IllegalState( @@ -1727,7 +1730,7 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> { name_resolver::Symbol::Known(sym) => sym.value.clone(), name_resolver::Symbol::Unknown(id) => format!("_{id}"), }; - self.push_value(as_key.into()); + self.push_lit(logical::Lit::String(as_key)); Traverse::Continue } @@ -1892,50 +1895,38 @@ impl<'ast> Visitor<'ast> for AstToLogical<'_> { } } -fn lit_to_value(lit: &Lit) -> Result { - fn expect_lit(v: &Expr) -> Result { - match v { - Expr::Lit(l) => lit_to_value(&l.node), - _ => Err(AstTransformError::IllegalState( - "non literal in literal aggregate".to_string(), - )), - } - } - - fn tuple_pair(pair: &ast::ExprPair) -> Option> { - let key = match expect_lit(pair.first.as_ref()) { - Ok(Value::String(s)) => s.as_ref().clone(), - Ok(_) => { - return Some(Err(AstTransformError::IllegalState( - "non string literal in literal struct key".to_string(), - ))) - } - Err(e) => return Some(Err(e)), - }; - - match expect_lit(pair.second.as_ref()) { - Ok(Value::Missing) => None, - Ok(val) => Some(Ok((key, val))), - Err(e) => Some(Err(e)), +fn lit_to_lit(lit: &Lit) -> Result { + fn tuple_pair( + field: &ast::LitField, + ) -> Option> { + let key = field.first.clone(); + match &field.second.node { + Lit::Missing => None, + value => match lit_to_lit(value) { + Ok(value) => Some(Ok((key, value))), + Err(e) => Some(Err(e)), + }, } } let val = match lit { - Lit::Null => Value::Null, - Lit::Missing => Value::Missing, - Lit::Int8Lit(n) => Value::Integer(i64::from(*n)), - Lit::Int16Lit(n) => Value::Integer(i64::from(*n)), - Lit::Int32Lit(n) => Value::Integer(i64::from(*n)), - Lit::Int64Lit(n) => Value::Integer(*n), - Lit::DecimalLit(d) => Value::Decimal(Box::new(*d)), - Lit::NumericLit(n) => Value::Decimal(Box::new(*n)), - Lit::RealLit(f) => Value::Real(OrderedFloat::from(f64::from(*f))), - Lit::FloatLit(f) => Value::Real(OrderedFloat::from(f64::from(*f))), - Lit::DoubleLit(f) => Value::Real(OrderedFloat::from(*f)), - Lit::BoolLit(b) => Value::Boolean(*b), - Lit::IonStringLit(s) => parse_embedded_ion_str(s)?, - Lit::CharStringLit(s) => Value::String(Box::new(s.clone())), - Lit::NationalCharStringLit(s) => Value::String(Box::new(s.clone())), + Lit::Null => logical::Lit::Null, + Lit::Missing => logical::Lit::Missing, + Lit::Int8Lit(n) => logical::Lit::Int8(*n), + Lit::Int16Lit(n) => logical::Lit::Int16(*n), + Lit::Int32Lit(n) => logical::Lit::Int32(*n), + Lit::Int64Lit(n) => logical::Lit::Int64(*n), + Lit::DecimalLit(d) => logical::Lit::Decimal(*d), + Lit::NumericLit(n) => logical::Lit::Decimal(*n), + Lit::RealLit(f) => logical::Lit::Double(OrderedFloat::from(*f as f64)), + Lit::FloatLit(f) => logical::Lit::Double(OrderedFloat::from(*f as f64)), + Lit::DoubleLit(f) => logical::Lit::Double(OrderedFloat::from(*f)), + Lit::BoolLit(b) => logical::Lit::Bool(*b), + Lit::EmbeddedDocLit(s) => { + logical::Lit::BoxDocument(s.clone().into_bytes(), "Ion".to_string()) + } + Lit::CharStringLit(s) => logical::Lit::String(s.clone()), + Lit::NationalCharStringLit(s) => logical::Lit::String(s.clone()), Lit::BitStringLit(_) => { return Err(AstTransformError::NotYetImplemented( "Lit::BitStringLit".to_string(), @@ -1947,27 +1938,16 @@ fn lit_to_value(lit: &Lit) -> Result { )) } Lit::BagLit(b) => { - let bag: Result = b - .node - .values - .iter() - .map(|l| expect_lit(l.as_ref())) - .collect(); - Value::from(bag?) + let bag: Result<_, _> = b.node.values.iter().map(lit_to_lit).collect(); + logical::Lit::Bag(bag?) } Lit::ListLit(l) => { - let l: Result = l - .node - .values - .iter() - .map(|l| expect_lit(l.as_ref())) - .collect(); - Value::from(l?) + let l: Result<_, _> = l.node.values.iter().map(lit_to_lit).collect(); + logical::Lit::List(l?) } Lit::StructLit(s) => { - let tuple: Result = - s.node.fields.iter().filter_map(tuple_pair).collect(); - Value::from(tuple?) + let tuple: Result<_, _> = s.node.fields.iter().filter_map(tuple_pair).collect(); + logical::Lit::Struct(tuple?) } Lit::TypedLit(_, _) => { return Err(AstTransformError::NotYetImplemented( @@ -1978,33 +1958,11 @@ fn lit_to_value(lit: &Lit) -> Result { Ok(val) } -fn parse_embedded_ion_str(contents: &str) -> Result { - fn lit_err(literal: &str, err: impl std::error::Error) -> AstTransformError { - AstTransformError::Literal { - literal: literal.into(), - error: err.to_string(), - } - } - - let reader = ion_rs_old::ReaderBuilder::new() - .build(contents) - .map_err(|e| lit_err(contents, e))?; - let mut iter = IonDecoderBuilder::new(IonDecoderConfig::default().with_mode(Encoding::Ion)) - .build(reader) - .map_err(|e| lit_err(contents, e))?; - - iter.next() - .ok_or_else(|| AstTransformError::Literal { - literal: contents.into(), - error: "Contains no value".into(), - })? - .map_err(|e| lit_err(contents, e)) -} - #[cfg(test)] mod tests { use super::*; use crate::LogicalPlanner; + use assert_matches::assert_matches; use partiql_catalog::catalog::{PartiqlCatalog, TypeEnvEntry}; use partiql_logical::BindingsOp::Project; use partiql_logical::ValueExpr; @@ -2022,13 +1980,13 @@ mod tests { assert!(logical.is_err()); let lowering_errs = logical.expect_err("Expect errs").errors; assert_eq!(lowering_errs.len(), 2); - assert_eq!( + assert_matches!( lowering_errs.first(), - Some(&AstTransformError::UnsupportedFunction("foo".to_string())) + Some(AstTransformError::UnsupportedFunction(fnc)) if fnc == "foo" ); - assert_eq!( + assert_matches!( lowering_errs.get(1), - Some(&AstTransformError::UnsupportedFunction("bar".to_string())) + Some(AstTransformError::UnsupportedFunction(fnc)) if fnc == "bar" ); } @@ -2044,17 +2002,13 @@ mod tests { assert!(logical.is_err()); let lowering_errs = logical.expect_err("Expect errs").errors; assert_eq!(lowering_errs.len(), 2); - assert_eq!( + assert_matches!( lowering_errs.first(), - Some(&AstTransformError::InvalidNumberOfArguments( - "abs".to_string() - )) + Some(AstTransformError::InvalidNumberOfArguments(fnc)) if fnc == "abs" ); - assert_eq!( + assert_matches!( lowering_errs.get(1), - Some(&AstTransformError::InvalidNumberOfArguments( - "mod".to_string() - )) + Some(AstTransformError::InvalidNumberOfArguments(fnc)) if fnc == "mod" ); } diff --git a/partiql-logical-planner/src/typer.rs b/partiql-logical-planner/src/typer.rs index 84fc38ff..6f524103 100644 --- a/partiql-logical-planner/src/typer.rs +++ b/partiql-logical-planner/src/typer.rs @@ -2,13 +2,13 @@ use crate::typer::LookupOrder::{GlobalLocal, LocalGlobal}; use indexmap::{IndexMap, IndexSet}; use partiql_ast::ast::{CaseSensitivity, SymbolPrimitive}; use partiql_catalog::catalog::Catalog; -use partiql_logical::{BindingsOp, LogicalPlan, OpId, PathComponent, ValueExpr, VarRefType}; +use partiql_logical::{BindingsOp, Lit, LogicalPlan, OpId, PathComponent, ValueExpr, VarRefType}; use partiql_types::{ - type_array, type_bag, type_bool, type_decimal, type_dynamic, type_int, type_string, - type_struct, type_undefined, ArrayType, BagType, PartiqlShape, PartiqlShapeBuilder, - ShapeResultError, Static, StructConstraint, StructField, StructType, + type_array, type_bag, type_bool, type_decimal, type_dynamic, type_float64, type_int, + type_string, type_struct, type_undefined, ArrayType, BagType, PartiqlShape, + PartiqlShapeBuilder, ShapeResultError, Static, StructConstraint, StructField, StructType, }; -use partiql_value::{BindingsName, Value}; +use partiql_value::BindingsName; use petgraph::algo::toposort; use petgraph::graph::NodeIndex; use petgraph::prelude::StableGraph; @@ -333,15 +333,17 @@ impl<'c> PlanTyper<'c> { } ValueExpr::Lit(v) => { let ty = match **v { - Value::Null => type_undefined!(), - Value::Missing => type_undefined!(), - Value::Integer(_) => type_int!(), - Value::Decimal(_) => type_decimal!(), - Value::Boolean(_) => type_bool!(), - Value::String(_) => type_string!(), - Value::Tuple(_) => type_struct!(), - Value::List(_) => type_array!(), - Value::Bag(_) => type_bag!(), + Lit::Null | Lit::Missing => type_undefined!(), + Lit::Int8(_) | Lit::Int16(_) | Lit::Int32(_) | Lit::Int64(_) => { + type_int!() + } + Lit::Decimal(_) => type_decimal!(), + Lit::Double(_) => type_float64!(), + Lit::Bool(_) => type_bool!(), + Lit::String(_) => type_string!(), + Lit::Struct(_) => type_struct!(), + Lit::Bag(_) => type_bag!(), + Lit::List(_) => type_array!(), _ => { self.errors.push(TypingError::NotYetImplemented( "Unsupported Literal".to_string(), diff --git a/partiql-logical/Cargo.toml b/partiql-logical/Cargo.toml index af61b289..c926c7ef 100644 --- a/partiql-logical/Cargo.toml +++ b/partiql-logical/Cargo.toml @@ -23,9 +23,15 @@ bench = false [dependencies] partiql-value = { path = "../partiql-value", version = "0.11.*" } partiql-common = { path = "../partiql-common", version = "0.11.*" } +partiql-extension-ion = { path = "../extension/partiql-extension-ion", version = "0.11.*" } + +ion-rs_old = { version = "0.18", package = "ion-rs" } ordered-float = "4" itertools = "0.13" +rust_decimal = { version = "1.36.0", default-features = false, features = ["std"] } +rust_decimal_macros = "1.36" unicase = "2.7" +thiserror = "1" serde = { version = "1", features = ["derive"], optional = true } diff --git a/partiql-logical/src/lib.rs b/partiql-logical/src/lib.rs index 675f7757..c8deaab0 100644 --- a/partiql-logical/src/lib.rs +++ b/partiql-logical/src/lib.rs @@ -12,6 +12,11 @@ //! Plan graph nodes are called _operators_ and edges are called _flows_ re-instating the fact that //! the plan captures data flows for a given `PartiQL` statement. //! + +mod util; + +use ordered_float::OrderedFloat; +use partiql_common::catalog::ObjectId; /// # Examples /// ``` /// use partiql_logical::{BinaryOp, BindingsOp, LogicalPlan, PathComponent, ProjectValue, Scan, ValueExpr, VarRefType}; @@ -38,7 +43,7 @@ /// expr: ValueExpr::BinaryExpr( /// BinaryOp::Mul, /// Box::new(va), -/// Box::new(ValueExpr::Lit(Box::new(Value::Integer(2)))), +/// Box::new(ValueExpr::Lit(Box::new(Value::Integer(2).into()))), /// ), /// })); /// @@ -51,12 +56,11 @@ /// assert_eq!(3, p.operators().len()); /// assert_eq!(2, p.flows().len()); /// ``` -use partiql_value::{BindingsName, Value}; +use partiql_value::BindingsName; +use rust_decimal::Decimal as RustDecimal; use std::collections::HashMap; use std::fmt::{Debug, Display, Formatter}; -use partiql_common::catalog::ObjectId; - #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -423,7 +427,7 @@ pub struct ExprQuery { pub enum ValueExpr { UnExpr(UnaryOp, Box), BinaryExpr(BinaryOp, Box, Box), - Lit(Box), + Lit(Box), DynamicLookup(Box>), Path(Box, Vec), VarRef(BindingsName<'static>, VarRefType), @@ -441,6 +445,26 @@ pub enum ValueExpr { Call(CallExpr), } +/// Represents a `PartiQL` literal value. +#[derive(Debug, Clone, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum Lit { + Null, + Missing, + Int8(i8), + Int16(i16), + Int32(i32), + Int64(i64), + Decimal(RustDecimal), + Double(OrderedFloat), + Bool(bool), + String(String), + BoxDocument(Vec, String), // (bytes, type-name as string) TODO replace with strongly typed box name. + Struct(Vec<(String, Lit)>), + Bag(Vec), + List(Vec), +} + // TODO we should replace this enum with some identifier that can be looked up in a symtab/funcregistry? /// Represents logical plan's unary operators. #[derive(Debug, Clone, Eq, PartialEq)] diff --git a/partiql-logical/src/util.rs b/partiql-logical/src/util.rs new file mode 100644 index 00000000..6f7f1c9f --- /dev/null +++ b/partiql-logical/src/util.rs @@ -0,0 +1,190 @@ +use crate::Lit; +use partiql_extension_ion::decode::{IonDecoderBuilder, IonDecoderConfig}; +use partiql_extension_ion::Encoding; +use partiql_value::{Bag, List, Tuple, Value}; +use thiserror::Error; + +impl From for Lit { + fn from(value: Value) -> Self { + match value { + Value::Null => Lit::Null, + Value::Missing => Lit::Missing, + Value::Boolean(b) => Lit::Bool(b), + Value::Integer(n) => Lit::Int64(n), + Value::Real(f) => Lit::Double(f), + Value::Decimal(d) => Lit::Decimal(*d), + Value::String(s) => Lit::String(*s), + Value::Blob(_bytes) => { + todo!("Value to Lit: Blob") + } + Value::DateTime(_dt) => { + todo!("Value to Lit: DateTime") + } + Value::List(list) => (*list).into(), + Value::Bag(bag) => (*bag).into(), + Value::Tuple(tuple) => (*tuple).into(), + } + } +} + +impl From for Lit { + fn from(list: List) -> Self { + Lit::List(list.into_iter().map(Lit::from).collect()) + } +} + +impl From for Lit { + fn from(bag: Bag) -> Self { + Lit::Bag(bag.into_iter().map(Lit::from).collect()) + } +} + +impl From for Lit { + fn from(tuple: Tuple) -> Self { + Lit::Struct(tuple.into_iter().map(|(k, v)| (k, Lit::from(v))).collect()) + } +} + +impl From for Value { + fn from(lit: Lit) -> Self { + match lit { + Lit::Null => Value::Null, + Lit::Missing => Value::Missing, + Lit::Int8(n) => Value::Integer(n.into()), + Lit::Int16(n) => Value::Integer(n.into()), + Lit::Int32(n) => Value::Integer(n.into()), + Lit::Int64(n) => Value::Integer(n), + Lit::Decimal(d) => Value::Decimal(d.into()), + Lit::Double(f) => Value::Real(f), + Lit::Bool(b) => Value::Boolean(b), + Lit::String(s) => Value::String(s.into()), + Lit::BoxDocument(contents, _typ) => { + parse_embedded_ion_str(&String::from_utf8_lossy(contents.as_slice())) + .expect("TODO ion parsing error") + } + Lit::Struct(strct) => Value::from(Tuple::from_iter( + strct.into_iter().map(|(k, v)| (k, Value::from(v))), + )), + Lit::Bag(bag) => Value::from(Bag::from_iter(bag.into_iter().map(Value::from))), + Lit::List(list) => Value::from(List::from_iter(list.into_iter().map(Value::from))), + } + } +} + +/// Represents a Literal Value Error +#[derive(Error, Debug, Clone, PartialEq, Eq, Hash)] +#[non_exhaustive] +pub enum LiteralError { + /// Indicates that there was an error interpreting a literal value. + #[error("Error with literal: {literal}: {error}")] + Literal { literal: String, error: String }, +} + +// TODO remove parsing in favor of embedding +fn parse_embedded_ion_str(contents: &str) -> Result { + fn lit_err(literal: &str, err: impl std::error::Error) -> LiteralError { + LiteralError::Literal { + literal: literal.into(), + error: err.to_string(), + } + } + + let reader = ion_rs_old::ReaderBuilder::new() + .build(contents) + .map_err(|e| lit_err(contents, e))?; + let mut iter = IonDecoderBuilder::new(IonDecoderConfig::default().with_mode(Encoding::Ion)) + .build(reader) + .map_err(|e| lit_err(contents, e))?; + + iter.next() + .ok_or_else(|| LiteralError::Literal { + literal: contents.into(), + error: "Contains no value".into(), + })? + .map_err(|e| lit_err(contents, e)) +} + +impl From for Lit { + #[inline] + fn from(b: bool) -> Self { + Lit::Bool(b) + } +} + +impl From for Lit { + #[inline] + fn from(s: String) -> Self { + Lit::String(s) + } +} + +impl From<&str> for Lit { + #[inline] + fn from(s: &str) -> Self { + Lit::String(s.to_string()) + } +} + +impl From for Lit { + #[inline] + fn from(n: i64) -> Self { + Lit::Int64(n) + } +} + +impl From for Lit { + #[inline] + fn from(n: i32) -> Self { + i64::from(n).into() + } +} + +impl From for Lit { + #[inline] + fn from(n: i16) -> Self { + i64::from(n).into() + } +} + +impl From for Lit { + #[inline] + fn from(n: i8) -> Self { + i64::from(n).into() + } +} + +impl From for Lit { + #[inline] + fn from(n: usize) -> Self { + // TODO overflow to bigint/decimal + Lit::Int64(n as i64) + } +} + +impl From for Lit { + #[inline] + fn from(n: u8) -> Self { + (n as usize).into() + } +} + +impl From for Lit { + #[inline] + fn from(n: u16) -> Self { + (n as usize).into() + } +} + +impl From for Lit { + #[inline] + fn from(n: u32) -> Self { + (n as usize).into() + } +} + +impl From for Lit { + #[inline] + fn from(n: u64) -> Self { + (n as usize).into() + } +} diff --git a/partiql-parser/Cargo.toml b/partiql-parser/Cargo.toml index 012379ad..66e68354 100644 --- a/partiql-parser/Cargo.toml +++ b/partiql-parser/Cargo.toml @@ -47,6 +47,7 @@ serde = { version = "1", features = ["derive"], optional = true } [dev-dependencies] criterion = "0.5" +assert_matches = "1" [features] default = [] diff --git a/partiql-parser/src/error.rs b/partiql-parser/src/error.rs index 5c5fad16..7760001c 100644 --- a/partiql-parser/src/error.rs +++ b/partiql-parser/src/error.rs @@ -23,8 +23,8 @@ pub enum LexError<'input> { #[error("Lexing error: invalid input `{}`", .0)] InvalidInput(Cow<'input, str>), /// Embedded Ion value is not properly terminated. - #[error("Lexing error: unterminated ion literal")] - UnterminatedIonLiteral, + #[error("Lexing error: unterminated embedded document literal")] + UnterminatedDocLiteral, /// Comment is not properly terminated. #[error("Lexing error: unterminated comment")] UnterminatedComment, diff --git a/partiql-parser/src/lexer/embedded_doc.rs b/partiql-parser/src/lexer/embedded_doc.rs new file mode 100644 index 00000000..ad2e2932 --- /dev/null +++ b/partiql-parser/src/lexer/embedded_doc.rs @@ -0,0 +1,112 @@ +use crate::error::LexError; +use crate::lexer::SpannedResult; +use logos::{Logos, Span}; +use partiql_common::syntax::line_offset_tracker::LineOffsetTracker; +use partiql_common::syntax::location::ByteOffset; + +/// An embedded Doc string (e.g. `[{a: 1}, {b: 2}]`) with [`ByteOffset`] span +/// relative to lexed source. +/// +/// Note: +/// - The lexer parses the embedded Doc value enclosed in backticks. +/// - The returned string *does not* include the backticks +/// - The returned `ByteOffset` span *does* include the backticks +type EmbeddedDocStringResult<'input> = SpannedResult<&'input str, ByteOffset, LexError<'input>>; + +/// Tokens used to parse Doc literals embedded in backticks (\`) +#[derive(Logos, Debug, Clone, PartialEq)] +#[logos(skip r#"[^/*'"`\r\n\u0085\u2028\u2029]+"#)] // skip things that aren't newlines or backticks +enum EmbeddedDocToken { + // Skip newlines, but record their position. + // For line break recommendations, + // see https://www.unicode.org/standard/reports/tr13/tr13-5.html + #[regex(r"(([\r])?[\n])|\u0085|\u2028|\u2029")] + Newline, + + // An embed open/close tag is a (greedily-captured) odd-number of backticks + #[regex(r"`(``)*")] + Embed, +} + +/// A Lexer for Doc literals embedded in backticks (\`) that returns the parsed [`EmbeddedDocString`] +/// +/// Parses just enough Doc to make sure not to include a backtick that is inside a string or comment. +pub struct EmbeddedDocLexer<'input, 'tracker> { + /// Wrap a logos-generated lexer + lexer: logos::Lexer<'input, EmbeddedDocToken>, + tracker: &'tracker mut LineOffsetTracker, +} + +impl<'input, 'tracker> EmbeddedDocLexer<'input, 'tracker> { + /// Creates a new embedded Doc lexer over `input` text. + #[inline] + pub fn new(input: &'input str, tracker: &'tracker mut LineOffsetTracker) -> Self { + EmbeddedDocLexer { + lexer: EmbeddedDocToken::lexer(input), + tracker, + } + } + + /// Parses a single embedded Doc value, quoted between backticks (`), and returns it + fn next_internal(&mut self) -> Option> { + let next_token = self.lexer.next(); + match next_token { + Some(Ok(EmbeddedDocToken::Embed)) => { + let Span { + start: b_start, + end: b_end, + } = self.lexer.span(); + let start_quote_len = b_end - b_start; + loop { + let next_tok = self.lexer.next(); + match next_tok { + Some(Ok(EmbeddedDocToken::Newline)) => { + // track the newline, and keep accumulating + self.tracker.record(self.lexer.span().end.into()); + } + Some(Ok(EmbeddedDocToken::Embed)) => { + let Span { + start: e_start, + end: e_end, + } = self.lexer.span(); + let end_quote_len = e_end - e_start; + if end_quote_len >= start_quote_len { + let backup = end_quote_len - start_quote_len; + let (str_start, str_end) = + (b_start + start_quote_len, e_end - end_quote_len); + let doc_value = &self.lexer.source()[str_start..str_end]; + + return Some(Ok(( + b_start.into(), + doc_value, + (e_end - backup).into(), + ))); + } + } + Some(_) => { + // just consume all other tokens + } + None => { + let Span { end, .. } = self.lexer.span(); + return Some(Err(( + b_start.into(), + LexError::UnterminatedDocLiteral, + end.into(), + ))); + } + } + } + } + _ => None, + } + } +} + +impl<'input> Iterator for EmbeddedDocLexer<'input, '_> { + type Item = EmbeddedDocStringResult<'input>; + + #[inline(always)] + fn next(&mut self) -> Option { + self.next_internal() + } +} diff --git a/partiql-parser/src/lexer/embedded_ion.rs b/partiql-parser/src/lexer/embedded_ion.rs deleted file mode 100644 index f98359df..00000000 --- a/partiql-parser/src/lexer/embedded_ion.rs +++ /dev/null @@ -1,135 +0,0 @@ -use crate::error::LexError; -use crate::lexer::{CommentLexer, SpannedResult}; -use logos::{Logos, Span}; -use partiql_common::syntax::line_offset_tracker::LineOffsetTracker; -use partiql_common::syntax::location::ByteOffset; - -/// An embedded Ion string (e.g. `[{a: 1}, {b: 2}]`) with [`ByteOffset`] span -/// relative to lexed source. -/// -/// Note: -/// - The lexer parses the embedded ion value enclosed in backticks. -/// - The returned string *does not* include the backticks -/// - The returned `ByteOffset` span *does* include the backticks -type EmbeddedIonStringResult<'input> = SpannedResult<&'input str, ByteOffset, LexError<'input>>; - -/// Tokens used to parse Ion literals embedded in backticks (\`) -#[derive(Logos, Debug, Clone, PartialEq)] -#[logos(skip r#"[^/*'"`\r\n\u0085\u2028\u2029]+"#)] -enum EmbeddedIonToken { - // Skip newlines, but record their position. - // For line break recommendations, - // see https://www.unicode.org/standard/reports/tr13/tr13-5.html - #[regex(r"(([\r])?[\n])|\u0085|\u2028|\u2029")] - Newline, - - #[token("`")] - Embed, - - #[regex(r"//[^\n]*")] - CommentLine, - #[token("/*")] - CommentBlock, - - #[regex(r#""([^"\\]|\\t|\\u|\\")*""#)] - String, - #[regex(r#"'([^'\\]|\\t|\\u|\\')*'"#)] - Symbol, - #[token("'''")] - LongString, -} - -/// A Lexer for Ion literals embedded in backticks (\`) that returns the parsed [`EmbeddedIonString`] -/// -/// Parses just enough Ion to make sure not to include a backtick that is inside a string or comment. -pub struct EmbeddedIonLexer<'input, 'tracker> { - /// Wrap a logos-generated lexer - lexer: logos::Lexer<'input, EmbeddedIonToken>, - tracker: &'tracker mut LineOffsetTracker, -} - -impl<'input, 'tracker> EmbeddedIonLexer<'input, 'tracker> { - /// Creates a new embedded ion lexer over `input` text. - #[inline] - pub fn new(input: &'input str, tracker: &'tracker mut LineOffsetTracker) -> Self { - EmbeddedIonLexer { - lexer: EmbeddedIonToken::lexer(input), - tracker, - } - } - - /// Parses a single embedded ion value, quoted between backticks (`), and returns it - fn next_internal(&mut self) -> Option> { - let next_token = self.lexer.next(); - match next_token { - Some(Ok(EmbeddedIonToken::Embed)) => { - let Span { start, .. } = self.lexer.span(); - 'ion_value: loop { - let next_tok = self.lexer.next(); - match next_tok { - Some(Ok(EmbeddedIonToken::Newline)) => { - self.tracker.record(self.lexer.span().end.into()); - } - Some(Ok(EmbeddedIonToken::Embed)) => { - break 'ion_value; - } - Some(Ok(EmbeddedIonToken::CommentBlock)) => { - let embed = self.lexer.span(); - let remaining = &self.lexer.source()[embed.start..]; - let mut comment_tracker = LineOffsetTracker::default(); - let mut comment_lexer = - CommentLexer::new(remaining, &mut comment_tracker); - match comment_lexer.next() { - Some(Ok((s, _c, e))) => { - self.tracker.append(&comment_tracker, embed.start.into()); - self.lexer.bump((e - s).to_usize() - embed.len()); - } - Some(Err((s, err, e))) => { - let offset: ByteOffset = embed.start.into(); - return Some(Err((s + offset, err, e + offset))); - } - None => unreachable!(), - } - } - Some(Ok(EmbeddedIonToken::LongString)) => { - 'triple_quote: loop { - let next_tok = self.lexer.next(); - match next_tok { - Some(Ok(EmbeddedIonToken::LongString)) => break 'triple_quote, - Some(_) => (), // just consume all other tokens - None => continue 'ion_value, - } - } - } - Some(_) => { - // just consume all other tokens - } - None => { - let Span { end, .. } = self.lexer.span(); - return Some(Err(( - start.into(), - LexError::UnterminatedIonLiteral, - end.into(), - ))); - } - } - } - let Span { end, .. } = self.lexer.span(); - let (str_start, str_end) = (start + 1, end - 1); - let ion_value = &self.lexer.source()[str_start..str_end]; - - Some(Ok((start.into(), ion_value, end.into()))) - } - _ => None, - } - } -} - -impl<'input> Iterator for EmbeddedIonLexer<'input, '_> { - type Item = EmbeddedIonStringResult<'input>; - - #[inline(always)] - fn next(&mut self) -> Option { - self.next_internal() - } -} diff --git a/partiql-parser/src/lexer/mod.rs b/partiql-parser/src/lexer/mod.rs index 7a81fefb..f48c953d 100644 --- a/partiql-parser/src/lexer/mod.rs +++ b/partiql-parser/src/lexer/mod.rs @@ -3,14 +3,14 @@ use partiql_common::syntax::location::{ByteOffset, BytePosition, ToLocated}; use crate::error::{LexError, ParseError}; mod comment; -mod embedded_ion; +mod embedded_doc; mod partiql; pub use comment::*; -pub use embedded_ion::*; +pub use embedded_doc::*; pub use partiql::*; -/// A 3-tuple of (start, `Tok`, end) denoting a token and it start and end offsets. +/// A 3-tuple of (start, `Tok`, end) denoting a token and its start and end offsets. pub type Spanned = (Loc, Tok, Loc); /// A [`Result`] of a [`Spanned`] token. pub(crate) type SpannedResult = Result, Spanned>; @@ -72,6 +72,7 @@ where #[cfg(test)] mod tests { use super::*; + use assert_matches::assert_matches; use partiql_common::syntax::line_offset_tracker::{LineOffsetError, LineOffsetTracker}; use partiql_common::syntax::location::{ CharOffset, LineAndCharPosition, LineAndColumn, LineOffset, Located, Location, @@ -126,7 +127,7 @@ mod tests { let ion_value = r" `{'input':1, 'b':1}`--comment "; let mut offset_tracker = LineOffsetTracker::default(); - let ion_lexer = EmbeddedIonLexer::new(ion_value.trim(), &mut offset_tracker); + let ion_lexer = EmbeddedDocLexer::new(ion_value.trim(), &mut offset_tracker); assert_eq!(ion_lexer.into_iter().count(), 1); assert_eq!(offset_tracker.num_lines(), 1); @@ -134,9 +135,7 @@ mod tests { let mut lexer = PartiqlLexer::new(ion_value, &mut offset_tracker); let tok = lexer.next().unwrap().unwrap(); - assert!( - matches!(tok, (ByteOffset(5), Token::Ion(ion_str), ByteOffset(24)) if ion_str == "{'input':1, 'b':1}") - ); + assert_matches!(tok, (ByteOffset(4), Token::EmbeddedDoc(ion_str), ByteOffset(25)) if ion_str == "{'input':1, 'b':1}"); let tok = lexer.next().unwrap().unwrap(); assert!( matches!(tok, (ByteOffset(25), Token::CommentLine(cmt_str), ByteOffset(35)) if cmt_str == "--comment ") @@ -145,27 +144,47 @@ mod tests { #[test] fn ion() { - let ion_value = r#" `{'input' // comment ' " + let embedded_ion_doc = r#" `{'input' // comment ' " :1, /* comment */ 'b':1}` "#; - let mut offset_tracker = LineOffsetTracker::default(); - let ion_lexer = EmbeddedIonLexer::new(ion_value.trim(), &mut offset_tracker); - assert_eq!(ion_lexer.into_iter().count(), 1); + let mut lexer = PartiqlLexer::new(embedded_ion_doc, &mut offset_tracker); + + let next_tok = lexer.next(); + let tok = next_tok.unwrap().unwrap(); + assert_matches!(tok, (ByteOffset(1), Token::EmbeddedDoc(ion_str), ByteOffset(159)) if ion_str == embedded_ion_doc.trim().trim_matches('`')); assert_eq!(offset_tracker.num_lines(), 5); + } + #[test] + fn ion_5_backticks() { + let embedded_ion_doc = r#" `````{'input' // comment ' " + :1, /* + comment + */ + 'b':1}````` "#; let mut offset_tracker = LineOffsetTracker::default(); - let mut lexer = PartiqlLexer::new(ion_value, &mut offset_tracker); + let mut lexer = PartiqlLexer::new(embedded_ion_doc, &mut offset_tracker); - let tok = lexer.next().unwrap().unwrap(); - assert!( - matches!(tok, (ByteOffset(2), Token::Ion(ion_str), ByteOffset(158)) if ion_str == ion_value.trim().trim_matches('`')) - ); + let next_tok = lexer.next(); + let tok = next_tok.unwrap().unwrap(); + assert_matches!(tok, (ByteOffset(1), Token::EmbeddedDoc(ion_str), ByteOffset(165)) if ion_str == embedded_ion_doc.trim().trim_matches('`')); assert_eq!(offset_tracker.num_lines(), 5); } + #[test] + fn empty_doc() { + let embedded_empty_doc = r#" `````` "#; + let mut offset_tracker = LineOffsetTracker::default(); + let mut lexer = PartiqlLexer::new(embedded_empty_doc, &mut offset_tracker); + + let next_tok = lexer.next(); + let tok = next_tok.unwrap().unwrap(); + assert_matches!(tok, (ByteOffset(1), Token::EmbeddedDoc(empty_str), ByteOffset(7)) if empty_str.is_empty()); + } + #[test] fn nested_comments() { let comments = r#"/* @@ -188,14 +207,14 @@ mod tests { let toks: Result, Spanned, ByteOffset>> = nonnested_lex.collect(); assert!(toks.is_err()); let error = toks.unwrap_err(); - assert!(matches!( + assert_matches!( error, ( ByteOffset(187), LexError::UnterminatedComment, ByteOffset(189) ) - )); + ); assert_eq!(error.1.to_string(), "Lexing error: unterminated comment"); } @@ -320,16 +339,16 @@ mod tests { lexer.count(); let last = offset_tracker.at(query, ByteOffset(query.len() as u32).into()); - assert!(matches!( + assert_matches!( last, Ok(LineAndCharPosition { line: LineOffset(4), char: CharOffset(10) }) - )); + ); let overflow = offset_tracker.at(query, ByteOffset(1 + query.len() as u32).into()); - assert!(matches!(overflow, Err(LineOffsetError::EndOfInput))); + assert_matches!(overflow, Err(LineOffsetError::EndOfInput)); } #[test] @@ -433,11 +452,11 @@ mod tests { error.to_string(), r"Lexing error: invalid input `#` at `(b7..b8)`" ); - assert!(matches!(error, + assert_matches!(error, ParseError::LexicalError(Located { inner: LexError::InvalidInput(s), location: Location{start: BytePosition(ByteOffset(7)), end: BytePosition(ByteOffset(8))} - }) if s == "#")); + }) if s == "#"); assert_eq!(offset_tracker.num_lines(), 1); assert_eq!( LineAndColumn::from(offset_tracker.at(query, 7.into()).unwrap()), @@ -446,31 +465,12 @@ mod tests { } #[test] - fn err_unterminated_ion() { + fn unterminated_ion() { let query = r#" ` "fooo` "#; let mut offset_tracker = LineOffsetTracker::default(); let toks: Result, _> = PartiqlLexer::new(query, &mut offset_tracker).collect(); - assert!(toks.is_err()); - let error = toks.unwrap_err(); - - assert!(matches!( - error, - ParseError::LexicalError(Located { - inner: LexError::UnterminatedIonLiteral, - location: Location { - start: BytePosition(ByteOffset(1)), - end: BytePosition(ByteOffset(10)) - } - }) - )); - assert_eq!( - error.to_string(), - "Lexing error: unterminated ion literal at `(b1..b10)`" - ); - assert_eq!( - LineAndColumn::from(offset_tracker.at(query, BytePosition::from(1)).unwrap()), - LineAndColumn::new(1, 2).unwrap() - ); + // ion is not eagerly parsed, so unterminated ion does not cause a lex/parse error + assert!(toks.is_ok()); } #[test] @@ -480,7 +480,7 @@ mod tests { let toks: Result, _> = PartiqlLexer::new(query, &mut offset_tracker).collect(); assert!(toks.is_err()); let error = toks.unwrap_err(); - assert!(matches!( + assert_matches!( error, ParseError::LexicalError(Located { inner: LexError::UnterminatedComment, @@ -489,7 +489,7 @@ mod tests { end: BytePosition(ByteOffset(11)) } }) - )); + ); assert_eq!( error.to_string(), "Lexing error: unterminated comment at `(b1..b11)`" @@ -501,21 +501,12 @@ mod tests { } #[test] - fn err_unterminated_ion_comment() { + fn unterminated_ion_comment() { let query = r" `/*12345678`"; let mut offset_tracker = LineOffsetTracker::default(); - let ion_lexer = EmbeddedIonLexer::new(query, &mut offset_tracker); + let ion_lexer = EmbeddedDocLexer::new(query, &mut offset_tracker); let toks: Result, Spanned, ByteOffset>> = ion_lexer.collect(); - assert!(toks.is_err()); - let error = toks.unwrap_err(); - assert!(matches!( - error, - (ByteOffset(2), LexError::UnterminatedComment, ByteOffset(13)) - )); - assert_eq!(error.1.to_string(), "Lexing error: unterminated comment"); - assert_eq!( - LineAndColumn::from(offset_tracker.at(query, BytePosition::from(2)).unwrap()), - LineAndColumn::new(1, 3).unwrap() - ); + // ion is not eagerly parsed, so unterminated ion does not cause a lex/parse error + assert!(toks.is_ok()); } } diff --git a/partiql-parser/src/lexer/partiql.rs b/partiql-parser/src/lexer/partiql.rs index c6b5c9f9..3440a92a 100644 --- a/partiql-parser/src/lexer/partiql.rs +++ b/partiql-parser/src/lexer/partiql.rs @@ -1,5 +1,5 @@ use crate::error::LexError; -use crate::lexer::{CommentLexer, EmbeddedIonLexer, InternalLexResult, LexResult}; +use crate::lexer::{CommentLexer, EmbeddedDocLexer, InternalLexResult, LexResult}; use logos::{Logos, Span}; use partiql_common::syntax::line_offset_tracker::LineOffsetTracker; use partiql_common::syntax::location::ByteOffset; @@ -35,6 +35,7 @@ impl<'input, 'tracker> PartiqlLexer<'input, 'tracker> { Err((start.into(), err_ctor(region.into()), end.into())) } + #[inline(always)] pub fn slice(&self) -> &'input str { self.lexer.slice() } @@ -59,7 +60,8 @@ impl<'input, 'tracker> PartiqlLexer<'input, 'tracker> { continue 'next_tok; } - Token::EmbeddedIonQuote => self.parse_embedded_ion(), + Token::EmbeddedDocQuote => self.parse_embedded_doc(), + Token::EmptyEmbeddedDocQuote => self.parse_empty_embedded_doc(), Token::CommentBlockStart => self.parse_block_comment(), @@ -92,20 +94,20 @@ impl<'input, 'tracker> PartiqlLexer<'input, 'tracker> { }) } - /// Uses [`EmbeddedIonLexer`] to parse an embedded ion value - fn parse_embedded_ion(&mut self) -> Option> { + /// Uses [`EmbeddedDocLexer`] to parse an embedded doc value + fn parse_embedded_doc(&mut self) -> Option> { let embed = self.lexer.span(); let remaining = &self.lexer.source()[embed.start..]; - let mut ion_tracker = LineOffsetTracker::default(); - let mut ion_lexer = EmbeddedIonLexer::new(remaining, &mut ion_tracker); - ion_lexer.next().map(|res| match res { - Ok((s, ion, e)) => { + let mut doc_tracker = LineOffsetTracker::default(); + let mut doc_lexer = EmbeddedDocLexer::new(remaining, &mut doc_tracker); + doc_lexer.next().map(|res| match res { + Ok((s, doc, e)) => { let val_len = e - s; - let val_start = embed.end.into(); // embed end is 1 past the starting '`' - let val_end = val_start + val_len - 2; // sub 2 to remove surrounding '`' - self.tracker.append(&ion_tracker, embed.start.into()); + let val_start = embed.start.into(); // embed end is 1 past the starting '/*' + let val_end = val_start + val_len; + self.tracker.append(&doc_tracker, embed.start.into()); self.lexer.bump(val_len.to_usize() - embed.len()); - Ok((val_start, Token::Ion(ion), val_end)) + Ok((val_start, Token::EmbeddedDoc(doc), val_end)) } Err((s, err, e)) => { let offset: ByteOffset = embed.start.into(); @@ -113,6 +115,14 @@ impl<'input, 'tracker> PartiqlLexer<'input, 'tracker> { } }) } + + #[inline] + fn parse_empty_embedded_doc(&mut self) -> Option> { + let embed = self.lexer.span(); + let mid = embed.start + ((embed.end - embed.start) / 2); + let doc = &self.lexer.source()[mid..mid]; + Some(self.wrap(Token::EmbeddedDoc(doc))) + } } impl<'input> Iterator for PartiqlLexer<'input, '_> { @@ -241,9 +251,13 @@ pub enum Token<'input> { |lex| lex.slice().trim_matches('\''))] String(&'input str), - #[token("`")] - EmbeddedIonQuote, - Ion(&'input str), + // An embed open/close tag is a (greedily-captured) odd-number of backticks + #[regex(r"`(``)*")] + EmbeddedDocQuote, + // An empty embedded doc is a (greedily-captured) even-number of backticks + #[regex(r"(``)+")] + EmptyEmbeddedDocQuote, + EmbeddedDoc(&'input str), // Keywords #[regex("(?i:All)")] @@ -492,8 +506,9 @@ impl fmt::Display for Token<'_> { Token::ExpReal(txt) => write!(f, "<{txt}:REAL>"), Token::Real(txt) => write!(f, "<{txt}:REAL>"), Token::String(txt) => write!(f, "<{txt}:STRING>"), - Token::EmbeddedIonQuote => write!(f, ""), - Token::Ion(txt) => write!(f, "<{txt}:ION>"), + Token::EmbeddedDocQuote => write!(f, ""), + Token::EmbeddedDoc(txt) => write!(f, "<```{txt}```:DOC>"), + Token::EmptyEmbeddedDocQuote => write!(f, "<``:DOC>"), Token::All | Token::Asc diff --git a/partiql-parser/src/parse/mod.rs b/partiql-parser/src/parse/mod.rs index 31396661..f8407a80 100644 --- a/partiql-parser/src/parse/mod.rs +++ b/partiql-parser/src/parse/mod.rs @@ -211,7 +211,9 @@ mod tests { #[test] fn ion() { parse!(r#" `[{'a':1, 'b':1}, {'a':2}, "foo"]` "#); - parse!(r#" `[{'a':1, 'b':1}, {'a':2}, "foo", 'a`b', "a`b", '''`s''', {{"a`b"}}]` "#); + parse!( + r#" ```[{'a':1, 'b':1}, {'a':2}, "foo", 'a`b', "a`b", '''`s''', {{"a`b"}}]``` "# + ); parse!( r#" `{'a':1, // comment ' " 'b':1} ` "# @@ -798,7 +800,7 @@ mod tests { assert_eq!( err_data.errors[1], ParseError::LexicalError(Located { - inner: LexError::UnterminatedIonLiteral, + inner: LexError::UnterminatedDocLiteral, location: Location { start: BytePosition::from(1), end: BytePosition::from(4), diff --git a/partiql-parser/src/parse/parse_util.rs b/partiql-parser/src/parse/parse_util.rs index 2bb1da0f..f98fcaf5 100644 --- a/partiql-parser/src/parse/parse_util.rs +++ b/partiql-parser/src/parse/parse_util.rs @@ -1,9 +1,11 @@ use partiql_ast::ast; use crate::parse::parser_state::ParserState; +use crate::ParseError; use bitflags::bitflags; +use partiql_ast::ast::{Expr, Lit}; use partiql_common::node::NodeIdGenerator; -use partiql_common::syntax::location::ByteOffset; +use partiql_common::syntax::location::{ByteOffset, BytePosition}; bitflags! { /// Set of AST node attributes to use as synthesized attributes. @@ -33,7 +35,7 @@ pub(crate) struct Synth { impl Synth { #[inline] - pub fn new(data: T, attrs: Attrs) -> Self { + fn new(data: T, attrs: Attrs) -> Self { Synth { data, attrs } } @@ -41,6 +43,17 @@ impl Synth { pub fn empty(data: T) -> Self { Self::new(data, Attrs::empty()) } + + #[inline] + pub fn lit(data: T) -> Self { + Self::new(data, Attrs::LIT) + } + + pub fn map_data(self, f: impl FnOnce(T) -> U) -> Synth { + let Self { data, attrs } = self; + let data = f(data); + Synth::new(data, attrs) + } } impl FromIterator> for Synth> { @@ -170,3 +183,61 @@ pub(crate) fn strip_expr(q: ast::AstNode) -> Box { Box::new(ast::Expr::Query(q)) } } + +#[inline] +#[track_caller] +fn illegal_literal<'a, T>() -> Result> { + Err(ParseError::IllegalState("Expected literal".to_string())) +} + +pub(crate) type LitFlattenResult<'a, T> = Result>; +#[inline] +pub(crate) fn struct_to_lit<'a>(strct: ast::Struct) -> LitFlattenResult<'a, ast::StructLit> { + strct + .fields + .into_iter() + .map(exprpair_to_lit) + .collect::>>() + .map(|fields| ast::StructLit { fields }) +} + +#[inline] +pub(crate) fn bag_to_lit<'a>(bag: ast::Bag) -> LitFlattenResult<'a, ast::BagLit> { + bag.values + .into_iter() + .map(|v| expr_to_lit(*v).map(|n| n.node)) + .collect::>>() + .map(|values| ast::BagLit { values }) +} + +#[inline] +pub(crate) fn list_to_lit<'a>(list: ast::List) -> LitFlattenResult<'a, ast::ListLit> { + list.values + .into_iter() + .map(|v| expr_to_lit(*v).map(|n| n.node)) + .collect::>>() + .map(|values| ast::ListLit { values }) +} + +#[inline] +pub(crate) fn exprpair_to_lit<'a>(pair: ast::ExprPair) -> LitFlattenResult<'a, ast::LitField> { + let ast::ExprPair { first, second } = pair; + let (first, second) = (expr_to_litstr(*first)?, expr_to_lit(*second)?); + Ok(ast::LitField { first, second }) +} + +#[inline] +pub(crate) fn expr_to_litstr<'a>(expr: ast::Expr) -> LitFlattenResult<'a, String> { + match expr_to_lit(expr)?.node { + Lit::CharStringLit(s) | Lit::NationalCharStringLit(s) => Ok(s), + _ => illegal_literal(), + } +} + +#[inline] +pub(crate) fn expr_to_lit<'a>(expr: ast::Expr) -> LitFlattenResult<'a, ast::AstNode> { + match expr { + Expr::Lit(lit) => Ok(lit), + _ => illegal_literal(), + } +} diff --git a/partiql-parser/src/parse/partiql.lalrpop b/partiql-parser/src/parse/partiql.lalrpop index d65ca1b0..7bfc0175 100644 --- a/partiql-parser/src/parse/partiql.lalrpop +++ b/partiql-parser/src/parse/partiql.lalrpop @@ -9,7 +9,17 @@ use partiql_ast::ast; use partiql_common::syntax::location::{ByteOffset, BytePosition, Location, ToLocated}; -use crate::parse::parse_util::{strip_expr, strip_query, strip_query_set, CallSite, Attrs, Synth}; +use crate::parse::parse_util::{ + strip_expr, + strip_query, + strip_query_set, + struct_to_lit, + bag_to_lit, + list_to_lit, + CallSite, + Attrs, + Synth +}; use crate::parse::parser_state::ParserState; use partiql_common::node::NodeIdGenerator; @@ -568,8 +578,7 @@ ExprQuery: Box = { ExprQuerySynth: Synth> = { => { - let Synth{data, attrs} = e; - Synth::new(Box::new(data), attrs) + e.map_data(|e| Box::new(e)) } } @@ -865,13 +874,39 @@ ExprPrecedence01: Synth = { ExprTerm: Synth = { => Synth::empty(s), - => Synth::new(ast::Expr::Lit( state.node(lit, lo..hi) ), Attrs::LIT), + => Synth::lit(ast::Expr::Lit( state.node(lit, lo..hi) )), => Synth::empty(v), => { if c.attrs.contains(Attrs::LIT) { match c.data { - ast::Expr::List(l) => Synth::new(ast::Expr::Lit( state.node(ast::Lit::ListLit(l), lo..hi) ), Attrs::LIT), - ast::Expr::Bag(b) => Synth::new(ast::Expr::Lit( state.node(ast::Lit::BagLit(b), lo..hi) ), Attrs::LIT), + ast::Expr::List(l) => { + match list_to_lit(l.node) { + Ok(list_lit) => { + let list_lit = state.node(list_lit, lo..hi); + let lit = state.node(ast::Lit::ListLit(list_lit), lo..hi); + Synth::lit(ast::Expr::Lit( lit )) + }, + Err(e) => { + let err = lpop::ErrorRecovery{error: e.into(), dropped_tokens: Default::default()}; + state.errors.push(err); + Synth::empty(ast::Expr::Error) + } + } + }, + ast::Expr::Bag(b) => { + match bag_to_lit(b.node) { + Ok(bag_lit) => { + let bag_lit = state.node(bag_lit, lo..hi); + let lit = state.node(ast::Lit::BagLit(bag_lit), lo..hi); + Synth::lit(ast::Expr::Lit( lit )) + }, + Err(e) => { + let err = lpop::ErrorRecovery{error: e.into(), dropped_tokens: Default::default()}; + state.errors.push(err); + Synth::empty(ast::Expr::Error) + } + } + }, _ => unreachable!(), } } else { @@ -881,7 +916,20 @@ ExprTerm: Synth = { => { if t.attrs.contains(Attrs::LIT) { match t.data { - ast::Expr::Struct(s) => Synth::new(ast::Expr::Lit( state.node(ast::Lit::StructLit(s), lo..hi) ), Attrs::LIT), + ast::Expr::Struct(s) => { + match struct_to_lit(s.node) { + Ok(struct_lit) => { + let struct_lit = state.node(struct_lit, lo..hi); + let lit = state.node(ast::Lit::StructLit(struct_lit), lo..hi); + Synth::lit(ast::Expr::Lit( lit )) + }, + Err(e) => { + let err = lpop::ErrorRecovery{error: e.into(), dropped_tokens: Default::default()}; + state.errors.push(err); + Synth::empty(ast::Expr::Error) + } + } + }, _ => unreachable!(), } } else { @@ -1197,7 +1245,7 @@ ExcludePathStep: ast::ExcludePathStep = { Literal: ast::Lit = { , , - , + , , } @@ -1250,11 +1298,13 @@ LiteralNumber: ast::Lit = { }) }, } + #[inline] -LiteralIon: ast::Lit = { - => ast::Lit::IonStringLit(ion.to_owned()), +LiteralEmbeddedDoc: ast::Lit = { + => ast::Lit::EmbeddedDocLit(ion.to_owned()), } + #[inline] TypeKeywordStr: &'static str = { "DATE" => "DATE", @@ -1425,7 +1475,7 @@ extern { "Real" => lexer::Token::Real(<&'input str>), "ExpReal" => lexer::Token::ExpReal(<&'input str>), "String" => lexer::Token::String(<&'input str>), - "Ion" => lexer::Token::Ion(<&'input str>), + "EmbeddedDoc" => lexer::Token::EmbeddedDoc(<&'input str>), // Keywords "ALL" => lexer::Token::All, diff --git a/partiql-value/src/bag.rs b/partiql-value/src/bag.rs index a846aa85..9fc7ca68 100644 --- a/partiql-value/src/bag.rs +++ b/partiql-value/src/bag.rs @@ -8,7 +8,8 @@ use std::hash::{Hash, Hasher}; use std::{slice, vec}; -use crate::{EqualityValue, List, NullSortedValue, NullableEq, Value}; +use crate::sort::NullSortedValue; +use crate::{EqualityValue, List, NullableEq, Value}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; diff --git a/partiql-value/src/bindings.rs b/partiql-value/src/bindings.rs new file mode 100644 index 00000000..ede44318 --- /dev/null +++ b/partiql-value/src/bindings.rs @@ -0,0 +1,70 @@ +use crate::{PairsIntoIter, PairsIter, Value}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; +use std::borrow::Cow; +use std::iter::Once; + +#[derive(Clone, Hash, Debug, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum BindingsName<'s> { + CaseSensitive(Cow<'s, str>), + CaseInsensitive(Cow<'s, str>), +} + +#[derive(Debug, Clone)] +pub enum BindingIter<'a> { + Tuple(PairsIter<'a>), + Single(Once<&'a Value>), + Empty, +} + +impl<'a> Iterator for BindingIter<'a> { + type Item = (Option<&'a String>, &'a Value); + + #[inline] + fn next(&mut self) -> Option { + match self { + BindingIter::Tuple(t) => t.next().map(|(k, v)| (Some(k), v)), + BindingIter::Single(single) => single.next().map(|v| (None, v)), + BindingIter::Empty => None, + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + match self { + BindingIter::Tuple(t) => t.size_hint(), + BindingIter::Single(_single) => (1, Some(1)), + BindingIter::Empty => (0, Some(0)), + } + } +} + +#[derive(Debug)] +pub enum BindingIntoIter { + Tuple(PairsIntoIter), + Single(Once), + Empty, +} + +impl Iterator for BindingIntoIter { + type Item = (Option, Value); + + #[inline] + fn next(&mut self) -> Option { + match self { + BindingIntoIter::Tuple(t) => t.next().map(|(k, v)| (Some(k), v)), + BindingIntoIter::Single(single) => single.next().map(|v| (None, v)), + BindingIntoIter::Empty => None, + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + match self { + BindingIntoIter::Tuple(t) => t.size_hint(), + BindingIntoIter::Single(_single) => (1, Some(1)), + BindingIntoIter::Empty => (0, Some(0)), + } + } +} diff --git a/partiql-value/src/comparison.rs b/partiql-value/src/comparison.rs new file mode 100644 index 00000000..d37b5ae4 --- /dev/null +++ b/partiql-value/src/comparison.rs @@ -0,0 +1,165 @@ +use crate::util; +use crate::Value; + +pub trait Comparable { + fn is_comparable_to(&self, rhs: &Self) -> bool; +} + +impl Comparable for Value { + /// Returns true if and only if `self` is comparable to `rhs` + fn is_comparable_to(&self, rhs: &Self) -> bool { + match (self, rhs) { + // Null/Missing compare to anything + (Value::Missing | Value::Null, _) + | (_, Value::Missing | Value::Null) + // Everything compares to its own type + | (Value::Boolean(_), Value::Boolean(_)) + | (Value::String(_), Value::String(_)) + | (Value::Blob(_), Value::Blob(_)) + | (Value::List(_), Value::List(_)) + | (Value::Bag(_), Value::Bag(_)) + | (Value::Tuple(_), Value::Tuple(_)) + // Numerics compare to each other + | ( + Value::Integer(_) | Value::Real(_) | Value::Decimal(_), + Value::Integer(_) | Value::Real(_) | Value::Decimal(_), + )=> true, + (_, _) => false, + } + } +} + +// `Value` `eq` and `neq` with Missing and Null propagation +pub trait NullableEq { + type Output; + fn eq(&self, rhs: &Self) -> Self::Output; + fn neq(&self, rhs: &Self) -> Self::Output; +} + +/// A wrapper on [`T`] that specifies if missing and null values should be equal. +#[derive(Eq, PartialEq)] +pub struct EqualityValue<'a, const NULLS_EQUAL: bool, T>(pub &'a T); + +impl NullableEq for EqualityValue<'_, GROUP_NULLS, Value> { + type Output = Value; + + fn eq(&self, rhs: &Self) -> Self::Output { + if GROUP_NULLS { + if let (Value::Missing | Value::Null, Value::Missing | Value::Null) = (self.0, rhs.0) { + return Value::Boolean(true); + } + } else if matches!(self.0, Value::Missing) || matches!(rhs.0, Value::Missing) { + return Value::Missing; + } else if matches!(self.0, Value::Null) || matches!(rhs.0, Value::Null) { + return Value::Null; + } + + match (self.0, rhs.0) { + (Value::Integer(_), Value::Real(_)) => { + Value::from(&util::coerce_int_to_real(self.0) == rhs.0) + } + (Value::Integer(_), Value::Decimal(_)) => { + Value::from(&util::coerce_int_or_real_to_decimal(self.0) == rhs.0) + } + (Value::Real(_), Value::Decimal(_)) => { + Value::from(&util::coerce_int_or_real_to_decimal(self.0) == rhs.0) + } + (Value::Real(_), Value::Integer(_)) => { + Value::from(self.0 == &util::coerce_int_to_real(rhs.0)) + } + (Value::Decimal(_), Value::Integer(_)) => { + Value::from(self.0 == &util::coerce_int_or_real_to_decimal(rhs.0)) + } + (Value::Decimal(_), Value::Real(_)) => { + Value::from(self.0 == &util::coerce_int_or_real_to_decimal(rhs.0)) + } + (_, _) => Value::from(self.0 == rhs.0), + } + } + + fn neq(&self, rhs: &Self) -> Self::Output { + let eq_result = NullableEq::eq(self, rhs); + match eq_result { + Value::Boolean(_) | Value::Null => !eq_result, + _ => Value::Missing, + } + } +} + +// `Value` comparison with Missing and Null propagation +pub trait NullableOrd { + type Output; + + fn lt(&self, rhs: &Self) -> Self::Output; + fn gt(&self, rhs: &Self) -> Self::Output; + fn lteq(&self, rhs: &Self) -> Self::Output; + fn gteq(&self, rhs: &Self) -> Self::Output; +} + +impl NullableOrd for Value { + type Output = Self; + + fn lt(&self, rhs: &Self) -> Self::Output { + match (self, rhs) { + (Value::Missing, _) => Value::Missing, + (_, Value::Missing) => Value::Missing, + (Value::Null, _) => Value::Null, + (_, Value::Null) => Value::Null, + (_, _) => { + if self.is_comparable_to(rhs) { + Value::from(self < rhs) + } else { + Value::Missing + } + } + } + } + + fn gt(&self, rhs: &Self) -> Self::Output { + match (self, rhs) { + (Value::Missing, _) => Value::Missing, + (_, Value::Missing) => Value::Missing, + (Value::Null, _) => Value::Null, + (_, Value::Null) => Value::Null, + (_, _) => { + if self.is_comparable_to(rhs) { + Value::from(self > rhs) + } else { + Value::Missing + } + } + } + } + + fn lteq(&self, rhs: &Self) -> Self::Output { + match (self, rhs) { + (Value::Missing, _) => Value::Missing, + (_, Value::Missing) => Value::Missing, + (Value::Null, _) => Value::Null, + (_, Value::Null) => Value::Null, + (_, _) => { + if self.is_comparable_to(rhs) { + Value::from(self <= rhs) + } else { + Value::Missing + } + } + } + } + + fn gteq(&self, rhs: &Self) -> Self::Output { + match (self, rhs) { + (Value::Missing, _) => Value::Missing, + (_, Value::Missing) => Value::Missing, + (Value::Null, _) => Value::Null, + (_, Value::Null) => Value::Null, + (_, _) => { + if self.is_comparable_to(rhs) { + Value::from(self >= rhs) + } else { + Value::Missing + } + } + } + } +} diff --git a/partiql-value/src/lib.rs b/partiql-value/src/lib.rs index 561c6d07..6510a9c7 100644 --- a/partiql-value/src/lib.rs +++ b/partiql-value/src/lib.rs @@ -1,1211 +1,41 @@ #![deny(rust_2018_idioms)] #![deny(clippy::all)] -use ordered_float::OrderedFloat; -use std::cmp::Ordering; - -use std::borrow::Cow; - -use std::fmt::{Debug, Display, Formatter}; -use std::hash::Hash; - -use std::iter::Once; -use std::{ops, vec}; - -use rust_decimal::prelude::FromPrimitive; -use rust_decimal::{Decimal as RustDecimal, Decimal}; - mod bag; +mod bindings; +pub mod comparison; mod datetime; mod list; mod pretty; +mod sort; mod tuple; +mod util; +mod value; pub use bag::*; +pub use bindings::*; +pub use comparison::*; pub use datetime::*; pub use list::*; pub use pretty::*; +pub use sort::*; pub use tuple::*; +pub use value::*; -use partiql_common::pretty::ToPretty; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -#[derive(Clone, Hash, Debug, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub enum BindingsName<'s> { - CaseSensitive(Cow<'s, str>), - CaseInsensitive(Cow<'s, str>), -} - -// TODO these are all quite simplified for PoC/demonstration -// TODO have an optional-like wrapper for null/missing instead of inlined here? -#[derive(Hash, PartialEq, Eq, Clone)] -#[allow(dead_code)] // TODO remove once out of PoC -#[derive(Default)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub enum Value { - Null, - #[default] - Missing, - Boolean(bool), - Integer(i64), - Real(OrderedFloat), - Decimal(Box), - String(Box), - Blob(Box>), - DateTime(Box), - List(Box), - Bag(Box), - Tuple(Box), - // TODO: add other supported PartiQL values -- sexp -} - -impl Display for Value { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self.to_pretty_string(f.width().unwrap_or(80)) { - Ok(pretty) => f.write_str(&pretty), - Err(_) => f.write_str(""), - } - } -} - -impl ops::Add for &Value { - type Output = Value; - - fn add(self, rhs: Self) -> Self::Output { - match (&self, &rhs) { - // TODO: edge cases dealing with overflow - (Value::Missing, _) => Value::Missing, - (_, Value::Missing) => Value::Missing, - (Value::Null, _) => Value::Null, - (_, Value::Null) => Value::Null, - (Value::Integer(l), Value::Integer(r)) => Value::Integer(l + r), - (Value::Real(l), Value::Real(r)) => Value::Real(*l + *r), - (Value::Decimal(l), Value::Decimal(r)) => { - Value::Decimal(Box::new(l.as_ref() + r.as_ref())) - } - (Value::Integer(_), Value::Real(_)) => &coerce_int_to_real(self) + rhs, - (Value::Integer(_), Value::Decimal(_)) => &coerce_int_or_real_to_decimal(self) + rhs, - (Value::Real(_), Value::Decimal(_)) => &coerce_int_or_real_to_decimal(self) + rhs, - (Value::Real(_), Value::Integer(_)) => self + &coerce_int_to_real(rhs), - (Value::Decimal(_), Value::Integer(_)) => self + &coerce_int_or_real_to_decimal(rhs), - (Value::Decimal(_), Value::Real(_)) => self + &coerce_int_or_real_to_decimal(rhs), - _ => Value::Missing, // data type mismatch => Missing - } - } -} - -impl ops::AddAssign<&Value> for Value { - fn add_assign(&mut self, rhs: &Value) { - match (self, &rhs) { - // TODO: edge cases dealing with overflow - (Value::Missing, _) => {} - (this, Value::Missing) => *this = Value::Missing, - (Value::Null, _) => {} - (this, Value::Null) => *this = Value::Null, - - (Value::Integer(l), Value::Integer(r)) => l.add_assign(r), - - (Value::Real(l), Value::Real(r)) => l.add_assign(r), - (Value::Real(l), Value::Integer(i)) => l.add_assign(*i as f64), - - (Value::Decimal(l), Value::Decimal(r)) => l.add_assign(r.as_ref()), - (Value::Decimal(l), Value::Integer(i)) => l.add_assign(rust_decimal::Decimal::from(*i)), - (Value::Decimal(l), Value::Real(r)) => match coerce_f64_to_decimal(r) { - Some(d) => l.add_assign(d), - None => todo!(), - }, - - (this, Value::Real(r)) => { - *this = match &this { - Value::Integer(l) => Value::from((*l as f64) + r.0), - _ => Value::Missing, - }; - } - (this, Value::Decimal(r)) => { - *this = match &this { - Value::Integer(l) => { - Value::Decimal(Box::new(rust_decimal::Decimal::from(*l) + r.as_ref())) - } - Value::Real(l) => match coerce_f64_to_decimal(&l.0) { - None => Value::Missing, - Some(d) => Value::Decimal(Box::new(d + r.as_ref())), - }, - _ => Value::Missing, - }; - } - (this, _) => *this = Value::Missing, // data type mismatch => Missing - } - } -} - -impl ops::Sub for &Value { - type Output = Value; - - fn sub(self, rhs: Self) -> Self::Output { - match (&self, &rhs) { - // TODO: edge cases dealing with overflow - (Value::Missing, _) => Value::Missing, - (_, Value::Missing) => Value::Missing, - (Value::Null, _) => Value::Null, - (_, Value::Null) => Value::Null, - (Value::Integer(l), Value::Integer(r)) => Value::Integer(l - r), - (Value::Real(l), Value::Real(r)) => Value::Real(*l - *r), - (Value::Decimal(l), Value::Decimal(r)) => { - Value::Decimal(Box::new(l.as_ref() - r.as_ref())) - } - (Value::Integer(_), Value::Real(_)) => &coerce_int_to_real(self) - rhs, - (Value::Integer(_), Value::Decimal(_)) => &coerce_int_or_real_to_decimal(self) - rhs, - (Value::Real(_), Value::Decimal(_)) => &coerce_int_or_real_to_decimal(self) - rhs, - (Value::Real(_), Value::Integer(_)) => self - &coerce_int_to_real(rhs), - (Value::Decimal(_), Value::Integer(_)) => self - &coerce_int_or_real_to_decimal(rhs), - (Value::Decimal(_), Value::Real(_)) => self - &coerce_int_or_real_to_decimal(rhs), - _ => Value::Missing, // data type mismatch => Missing - } - } -} - -impl ops::Mul for &Value { - type Output = Value; - - fn mul(self, rhs: Self) -> Self::Output { - match (&self, &rhs) { - // TODO: edge cases dealing with overflow - (Value::Missing, _) => Value::Missing, - (_, Value::Missing) => Value::Missing, - (Value::Null, _) => Value::Null, - (_, Value::Null) => Value::Null, - (Value::Integer(l), Value::Integer(r)) => Value::Integer(l * r), - (Value::Real(l), Value::Real(r)) => Value::Real(*l * *r), - (Value::Decimal(l), Value::Decimal(r)) => { - Value::Decimal(Box::new(l.as_ref() * r.as_ref())) - } - (Value::Integer(_), Value::Real(_)) => &coerce_int_to_real(self) * rhs, - (Value::Integer(_), Value::Decimal(_)) => &coerce_int_or_real_to_decimal(self) * rhs, - (Value::Real(_), Value::Decimal(_)) => &coerce_int_or_real_to_decimal(self) * rhs, - (Value::Real(_), Value::Integer(_)) => self * &coerce_int_to_real(rhs), - (Value::Decimal(_), Value::Integer(_)) => self * &coerce_int_or_real_to_decimal(rhs), - (Value::Decimal(_), Value::Real(_)) => self * &coerce_int_or_real_to_decimal(rhs), - _ => Value::Missing, // data type mismatch => Missing - } - } -} - -impl ops::Div for &Value { - type Output = Value; - - fn div(self, rhs: Self) -> Self::Output { - match (&self, &rhs) { - // TODO: edge cases dealing with division by 0 - (Value::Missing, _) => Value::Missing, - (_, Value::Missing) => Value::Missing, - (Value::Null, _) => Value::Null, - (_, Value::Null) => Value::Null, - (Value::Integer(l), Value::Integer(r)) => Value::Integer(l / r), - (Value::Real(l), Value::Real(r)) => Value::Real(*l / *r), - (Value::Decimal(l), Value::Decimal(r)) => { - Value::Decimal(Box::new(l.as_ref() / r.as_ref())) - } - (Value::Integer(_), Value::Real(_)) => &coerce_int_to_real(self) / rhs, - (Value::Integer(_), Value::Decimal(_)) => &coerce_int_or_real_to_decimal(self) / rhs, - (Value::Real(_), Value::Decimal(_)) => &coerce_int_or_real_to_decimal(self) / rhs, - (Value::Real(_), Value::Integer(_)) => self / &coerce_int_to_real(rhs), - (Value::Decimal(_), Value::Integer(_)) => self / &coerce_int_or_real_to_decimal(rhs), - (Value::Decimal(_), Value::Real(_)) => self / &coerce_int_or_real_to_decimal(rhs), - _ => Value::Missing, // data type mismatch => Missing - } - } -} - -impl ops::Rem for &Value { - type Output = Value; - - fn rem(self, rhs: Self) -> Self::Output { - match (&self, &rhs) { - // TODO: edge cases dealing with division by 0 - (Value::Missing, _) => Value::Missing, - (_, Value::Missing) => Value::Missing, - (Value::Null, _) => Value::Null, - (_, Value::Null) => Value::Null, - (Value::Integer(l), Value::Integer(r)) => Value::Integer(l % r), - (Value::Real(l), Value::Real(r)) => Value::Real(*l % *r), - (Value::Decimal(l), Value::Decimal(r)) => { - Value::Decimal(Box::new(l.as_ref() % r.as_ref())) - } - (Value::Integer(_), Value::Real(_)) => &coerce_int_to_real(self) % rhs, - (Value::Integer(_), Value::Decimal(_)) => &coerce_int_or_real_to_decimal(self) % rhs, - (Value::Real(_), Value::Decimal(_)) => &coerce_int_or_real_to_decimal(self) % rhs, - (Value::Real(_), Value::Integer(_)) => self % &coerce_int_to_real(rhs), - (Value::Decimal(_), Value::Integer(_)) => self % &coerce_int_or_real_to_decimal(rhs), - (Value::Decimal(_), Value::Real(_)) => self % &coerce_int_or_real_to_decimal(rhs), - _ => Value::Missing, // data type mismatch => Missing - } - } -} - -pub trait UnaryPlus { - type Output; - - fn positive(self) -> Self::Output; -} - -impl UnaryPlus for Value { - type Output = Self; - fn positive(self) -> Self::Output { - match self { - Value::Null => Value::Null, - Value::Missing => Value::Missing, - Value::Integer(_) | Value::Real(_) | Value::Decimal(_) => self, - _ => Value::Missing, // data type mismatch => Missing - } - } -} - -impl ops::Neg for &Value { - type Output = Value; - - fn neg(self) -> Self::Output { - match self { - // TODO: handle overflow for negation - Value::Null => Value::Null, - Value::Missing => Value::Missing, - Value::Integer(i) => Value::from(-i), - Value::Real(f) => Value::Real(-f), - Value::Decimal(d) => Value::from(-d.as_ref()), - _ => Value::Missing, // data type mismatch => Missing - } - } -} - -impl ops::Neg for Value { - type Output = Value; - - fn neg(self) -> Self::Output { - match self { - // TODO: handle overflow for negation - Value::Null => self, - Value::Missing => self, - Value::Integer(i) => Value::from(-i), - Value::Real(f) => Value::Real(-f), - Value::Decimal(d) => Value::from(-d.as_ref()), - _ => Value::Missing, // data type mismatch => Missing - } - } -} - -pub trait BinaryAnd { - type Output; - - fn and(&self, rhs: &Self) -> Self::Output; -} - -impl BinaryAnd for Value { - type Output = Self; - fn and(&self, rhs: &Self) -> Self::Output { - match (self, rhs) { - (Value::Boolean(l), Value::Boolean(r)) => Value::from(*l && *r), - (Value::Null | Value::Missing, Value::Boolean(false)) - | (Value::Boolean(false), Value::Null | Value::Missing) => Value::from(false), - _ => { - if matches!(self, Value::Missing | Value::Null | Value::Boolean(true)) - && matches!(rhs, Value::Missing | Value::Null | Value::Boolean(true)) - { - Value::Null - } else { - Value::Missing - } - } - } - } -} - -pub trait BinaryOr { - type Output; - - fn or(&self, rhs: &Self) -> Self::Output; -} - -impl BinaryOr for Value { - type Output = Self; - fn or(&self, rhs: &Self) -> Self::Output { - match (self, rhs) { - (Value::Boolean(l), Value::Boolean(r)) => Value::from(*l || *r), - (Value::Null | Value::Missing, Value::Boolean(true)) - | (Value::Boolean(true), Value::Null | Value::Missing) => Value::from(true), - _ => { - if matches!(self, Value::Missing | Value::Null | Value::Boolean(false)) - && matches!(rhs, Value::Missing | Value::Null | Value::Boolean(false)) - { - Value::Null - } else { - Value::Missing - } - } - } - } -} - -impl ops::Not for &Value { - type Output = Value; - - fn not(self) -> Self::Output { - match self { - Value::Boolean(b) => Value::from(!b), - Value::Null | Value::Missing => Value::Null, - _ => Value::Missing, // data type mismatch => Missing - } - } -} - -impl ops::Not for Value { - type Output = Self; - - fn not(self) -> Self::Output { - match self { - Value::Boolean(b) => Value::from(!b), - Value::Null | Value::Missing => Value::Null, - _ => Value::Missing, // data type mismatch => Missing - } - } -} - -pub trait Comparable { - fn is_comparable_to(&self, rhs: &Self) -> bool; -} - -impl Comparable for Value { - /// Returns true if and only if `self` is comparable to `rhs` - fn is_comparable_to(&self, rhs: &Self) -> bool { - match (self, rhs) { - // Null/Missing compare to anything - (Value::Missing | Value::Null, _) - | (_, Value::Missing | Value::Null) - // Everything compares to its own type - | (Value::Boolean(_), Value::Boolean(_)) - | (Value::String(_), Value::String(_)) - | (Value::Blob(_), Value::Blob(_)) - | (Value::List(_), Value::List(_)) - | (Value::Bag(_), Value::Bag(_)) - | (Value::Tuple(_), Value::Tuple(_)) - // Numerics compare to each other - | ( - Value::Integer(_) | Value::Real(_) | Value::Decimal(_), - Value::Integer(_) | Value::Real(_) | Value::Decimal(_), - )=> true, - (_, _) => false, - } - } -} - -// `Value` `eq` and `neq` with Missing and Null propagation -pub trait NullableEq { - type Output; - fn eq(&self, rhs: &Self) -> Self::Output; - fn neq(&self, rhs: &Self) -> Self::Output; -} - -// `Value` comparison with Missing and Null propagation -pub trait NullableOrd { - type Output; - - fn lt(&self, rhs: &Self) -> Self::Output; - fn gt(&self, rhs: &Self) -> Self::Output; - fn lteq(&self, rhs: &Self) -> Self::Output; - fn gteq(&self, rhs: &Self) -> Self::Output; -} - -/// A wrapper on [`T`] that specifies if missing and null values should be equal. -#[derive(Eq, PartialEq)] -pub struct EqualityValue<'a, const NULLS_EQUAL: bool, T>(pub &'a T); - -impl NullableEq for EqualityValue<'_, GROUP_NULLS, Value> { - type Output = Value; - - fn eq(&self, rhs: &Self) -> Self::Output { - if GROUP_NULLS { - if let (Value::Missing | Value::Null, Value::Missing | Value::Null) = (self.0, rhs.0) { - return Value::Boolean(true); - } - } else if matches!(self.0, Value::Missing) || matches!(rhs.0, Value::Missing) { - return Value::Missing; - } else if matches!(self.0, Value::Null) || matches!(rhs.0, Value::Null) { - return Value::Null; - } - - match (self.0, rhs.0) { - (Value::Integer(_), Value::Real(_)) => { - Value::from(&coerce_int_to_real(self.0) == rhs.0) - } - (Value::Integer(_), Value::Decimal(_)) => { - Value::from(&coerce_int_or_real_to_decimal(self.0) == rhs.0) - } - (Value::Real(_), Value::Decimal(_)) => { - Value::from(&coerce_int_or_real_to_decimal(self.0) == rhs.0) - } - (Value::Real(_), Value::Integer(_)) => { - Value::from(self.0 == &coerce_int_to_real(rhs.0)) - } - (Value::Decimal(_), Value::Integer(_)) => { - Value::from(self.0 == &coerce_int_or_real_to_decimal(rhs.0)) - } - (Value::Decimal(_), Value::Real(_)) => { - Value::from(self.0 == &coerce_int_or_real_to_decimal(rhs.0)) - } - (_, _) => Value::from(self.0 == rhs.0), - } - } - - fn neq(&self, rhs: &Self) -> Self::Output { - let eq_result = NullableEq::eq(self, rhs); - match eq_result { - Value::Boolean(_) | Value::Null => !eq_result, - _ => Value::Missing, - } - } -} - -impl NullableOrd for Value { - type Output = Self; - - fn lt(&self, rhs: &Self) -> Self::Output { - match (self, rhs) { - (Value::Missing, _) => Value::Missing, - (_, Value::Missing) => Value::Missing, - (Value::Null, _) => Value::Null, - (_, Value::Null) => Value::Null, - (_, _) => { - if self.is_comparable_to(rhs) { - Value::from(self < rhs) - } else { - Value::Missing - } - } - } - } - - fn gt(&self, rhs: &Self) -> Self::Output { - match (self, rhs) { - (Value::Missing, _) => Value::Missing, - (_, Value::Missing) => Value::Missing, - (Value::Null, _) => Value::Null, - (_, Value::Null) => Value::Null, - (_, _) => { - if self.is_comparable_to(rhs) { - Value::from(self > rhs) - } else { - Value::Missing - } - } - } - } - - fn lteq(&self, rhs: &Self) -> Self::Output { - match (self, rhs) { - (Value::Missing, _) => Value::Missing, - (_, Value::Missing) => Value::Missing, - (Value::Null, _) => Value::Null, - (_, Value::Null) => Value::Null, - (_, _) => { - if self.is_comparable_to(rhs) { - Value::from(self <= rhs) - } else { - Value::Missing - } - } - } - } - - fn gteq(&self, rhs: &Self) -> Self::Output { - match (self, rhs) { - (Value::Missing, _) => Value::Missing, - (_, Value::Missing) => Value::Missing, - (Value::Null, _) => Value::Null, - (_, Value::Null) => Value::Null, - (_, _) => { - if self.is_comparable_to(rhs) { - Value::from(self >= rhs) - } else { - Value::Missing - } - } - } - } -} - -fn coerce_f64_to_decimal(real_value: &f64) -> Option { - if !real_value.is_finite() { - None - } else { - Decimal::from_f64(*real_value) - } -} - -fn coerce_int_or_real_to_decimal(value: &Value) -> Value { - match value { - Value::Integer(int_value) => Value::from(rust_decimal::Decimal::from(*int_value)), - Value::Real(real_value) => { - if !real_value.is_finite() { - Value::Missing - } else { - match Decimal::from_f64(real_value.0) { - Some(d_from_r) => Value::from(d_from_r), - None => Value::Missing, // TODO: decide on behavior when float cannot be coerced to Decimal - } - } - } - _ => todo!("Unsupported coercion to Decimal"), - } -} - -fn coerce_int_to_real(value: &Value) -> Value { - match value { - Value::Integer(int_value) => Value::Real(OrderedFloat(*int_value as f64)), - _ => todo!("Unsupported coercion to Real"), - } -} - -impl Value { - #[inline] - #[must_use] - pub fn is_tuple(&self) -> bool { - matches!(self, Value::Tuple(_)) - } - - #[inline] - #[must_use] - pub fn is_list(&self) -> bool { - matches!(self, Value::List(_)) - } - - #[inline] - #[must_use] - pub fn is_bag(&self) -> bool { - matches!(self, Value::Bag(_)) - } - - #[inline] - #[must_use] - pub fn is_sequence(&self) -> bool { - self.is_bag() || self.is_list() - } - - #[inline] - /// Returns true if and only if Value is an integer, real, or decimal - #[must_use] - pub fn is_number(&self) -> bool { - matches!(self, Value::Integer(_) | Value::Real(_) | Value::Decimal(_)) - } - #[inline] - /// Returns true if and only if Value is null or missing - #[must_use] - pub fn is_absent(&self) -> bool { - matches!(self, Value::Missing | Value::Null) - } - - #[inline] - /// Returns true if Value is neither null nor missing - #[must_use] - pub fn is_present(&self) -> bool { - !self.is_absent() - } - - #[inline] - #[must_use] - pub fn is_ordered(&self) -> bool { - self.is_list() - } - - #[inline] - #[must_use] - pub fn coerce_into_tuple(self) -> Tuple { - match self { - Value::Tuple(t) => *t, - _ => self - .into_bindings() - .map(|(k, v)| (k.unwrap_or_else(|| "_1".to_string()), v)) - .collect(), - } - } - - #[inline] - #[must_use] - pub fn coerce_to_tuple(&self) -> Tuple { - match self { - Value::Tuple(t) => t.as_ref().clone(), - _ => { - let fresh = "_1".to_string(); - self.as_bindings() - .map(|(k, v)| (k.unwrap_or(&fresh), v.clone())) - .collect() - } - } - } - - #[inline] - #[must_use] - pub fn as_tuple_ref(&self) -> Cow<'_, Tuple> { - if let Value::Tuple(t) = self { - Cow::Borrowed(t) - } else { - Cow::Owned(self.coerce_to_tuple()) - } - } - - #[inline] - #[must_use] - pub fn as_bindings(&self) -> BindingIter<'_> { - match self { - Value::Tuple(t) => BindingIter::Tuple(t.pairs()), - Value::Missing => BindingIter::Empty, - _ => BindingIter::Single(std::iter::once(self)), - } - } - - #[inline] - #[must_use] - pub fn into_bindings(self) -> BindingIntoIter { - match self { - Value::Tuple(t) => BindingIntoIter::Tuple(t.into_pairs()), - Value::Missing => BindingIntoIter::Empty, - _ => BindingIntoIter::Single(std::iter::once(self)), - } - } - - #[inline] - #[must_use] - pub fn coerce_into_bag(self) -> Bag { - if let Value::Bag(b) = self { - *b - } else { - Bag::from(vec![self]) - } - } - - #[inline] - #[must_use] - pub fn as_bag_ref(&self) -> Cow<'_, Bag> { - if let Value::Bag(b) = self { - Cow::Borrowed(b) - } else { - Cow::Owned(self.clone().coerce_into_bag()) - } - } - - #[inline] - #[must_use] - pub fn coerce_into_list(self) -> List { - if let Value::List(b) = self { - *b - } else { - List::from(vec![self]) - } - } - - #[inline] - #[must_use] - pub fn as_list_ref(&self) -> Cow<'_, List> { - if let Value::List(l) = self { - Cow::Borrowed(l) - } else { - Cow::Owned(self.clone().coerce_into_list()) - } - } - - #[inline] - #[must_use] - pub fn iter(&self) -> ValueIter<'_> { - match self { - Value::Null | Value::Missing => ValueIter::Single(None), - Value::List(list) => ValueIter::List(list.iter()), - Value::Bag(bag) => ValueIter::Bag(bag.iter()), - other => ValueIter::Single(Some(other)), - } - } - - #[inline] - #[must_use] - pub fn sequence_iter(&self) -> Option> { - if self.is_sequence() { - Some(self.iter()) - } else { - None - } - } -} - -#[derive(Debug, Clone)] -pub enum BindingIter<'a> { - Tuple(PairsIter<'a>), - Single(Once<&'a Value>), - Empty, -} - -impl<'a> Iterator for BindingIter<'a> { - type Item = (Option<&'a String>, &'a Value); - - #[inline] - fn next(&mut self) -> Option { - match self { - BindingIter::Tuple(t) => t.next().map(|(k, v)| (Some(k), v)), - BindingIter::Single(single) => single.next().map(|v| (None, v)), - BindingIter::Empty => None, - } - } - - #[inline] - fn size_hint(&self) -> (usize, Option) { - match self { - BindingIter::Tuple(t) => t.size_hint(), - BindingIter::Single(_single) => (1, Some(1)), - BindingIter::Empty => (0, Some(0)), - } - } -} - -#[derive(Debug)] -pub enum BindingIntoIter { - Tuple(PairsIntoIter), - Single(Once), - Empty, -} - -impl Iterator for BindingIntoIter { - type Item = (Option, Value); - - #[inline] - fn next(&mut self) -> Option { - match self { - BindingIntoIter::Tuple(t) => t.next().map(|(k, v)| (Some(k), v)), - BindingIntoIter::Single(single) => single.next().map(|v| (None, v)), - BindingIntoIter::Empty => None, - } - } - - #[inline] - fn size_hint(&self) -> (usize, Option) { - match self { - BindingIntoIter::Tuple(t) => t.size_hint(), - BindingIntoIter::Single(_single) => (1, Some(1)), - BindingIntoIter::Empty => (0, Some(0)), - } - } -} - -#[derive(Debug, Clone)] -pub enum ValueIter<'a> { - List(ListIter<'a>), - Bag(BagIter<'a>), - Single(Option<&'a Value>), -} - -impl<'a> Iterator for ValueIter<'a> { - type Item = &'a Value; - - #[inline] - fn next(&mut self) -> Option { - match self { - ValueIter::List(list) => list.next(), - ValueIter::Bag(bag) => bag.next(), - ValueIter::Single(v) => v.take(), - } - } - - #[inline] - fn size_hint(&self) -> (usize, Option) { - match self { - ValueIter::List(list) => list.size_hint(), - ValueIter::Bag(bag) => bag.size_hint(), - ValueIter::Single(_) => (1, Some(1)), - } - } -} - -impl IntoIterator for Value { - type Item = Value; - type IntoIter = ValueIntoIterator; - - #[inline] - fn into_iter(self) -> ValueIntoIterator { - match self { - Value::List(list) => ValueIntoIterator::List(list.into_iter()), - Value::Bag(bag) => ValueIntoIterator::Bag(bag.into_iter()), - other => ValueIntoIterator::Single(Some(other)), - } - } -} - -pub enum ValueIntoIterator { - List(ListIntoIterator), - Bag(BagIntoIterator), - Single(Option), -} - -impl Iterator for ValueIntoIterator { - type Item = Value; - - #[inline] - fn next(&mut self) -> Option { - match self { - ValueIntoIterator::List(list) => list.next(), - ValueIntoIterator::Bag(bag) => bag.next(), - ValueIntoIterator::Single(v) => v.take(), - } - } - - #[inline] - fn size_hint(&self) -> (usize, Option) { - match self { - ValueIntoIterator::List(list) => list.size_hint(), - ValueIntoIterator::Bag(bag) => bag.size_hint(), - ValueIntoIterator::Single(_) => (1, Some(1)), - } - } -} - -// TODO make debug emit proper PartiQL notation -// TODO perhaps this should be display as well? -impl Debug for Value { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Value::Null => write!(f, "NULL"), - Value::Missing => write!(f, "MISSING"), - Value::Boolean(b) => write!(f, "{b}"), - Value::Integer(i) => write!(f, "{i}"), - Value::Real(r) => write!(f, "{}", r.0), - Value::Decimal(d) => write!(f, "{d}"), - Value::String(s) => write!(f, "'{s}'"), - Value::Blob(s) => write!(f, "'{s:?}'"), - Value::DateTime(t) => t.fmt(f), - Value::List(l) => l.fmt(f), - Value::Bag(b) => b.fmt(f), - Value::Tuple(t) => t.fmt(f), - } - } -} - -impl PartialOrd for Value { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -/// A wrapper on [`T`] that specifies if a null or missing value should be ordered before -/// ([`NULLS_FIRST`] is true) or after ([`NULLS_FIRST`] is false) other values. -#[derive(Eq, PartialEq)] -pub struct NullSortedValue<'a, const NULLS_FIRST: bool, T>(pub &'a T); - -impl<'a, const NULLS_FIRST: bool, T> PartialOrd for NullSortedValue<'a, NULLS_FIRST, T> -where - T: PartialOrd, - NullSortedValue<'a, NULLS_FIRST, T>: Ord, -{ - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for NullSortedValue<'_, NULLS_FIRST, Value> { - fn cmp(&self, other: &Self) -> Ordering { - let wrap_list = NullSortedValue::<{ NULLS_FIRST }, List>; - let wrap_tuple = NullSortedValue::<{ NULLS_FIRST }, Tuple>; - let wrap_bag = NullSortedValue::<{ NULLS_FIRST }, Bag>; - let null_cond = |order: Ordering| { - if NULLS_FIRST { - order - } else { - order.reverse() - } - }; - - match (self.0, other.0) { - (Value::Null, Value::Null) => Ordering::Equal, - (Value::Missing, Value::Null) => Ordering::Equal, - - (Value::Null, Value::Missing) => Ordering::Equal, - (Value::Null, _) => null_cond(Ordering::Less), - (_, Value::Null) => null_cond(Ordering::Greater), - - (Value::Missing, Value::Missing) => Ordering::Equal, - (Value::Missing, _) => null_cond(Ordering::Less), - (_, Value::Missing) => null_cond(Ordering::Greater), - - (Value::List(l), Value::List(r)) => wrap_list(l.as_ref()).cmp(&wrap_list(r.as_ref())), - - (Value::Tuple(l), Value::Tuple(r)) => { - wrap_tuple(l.as_ref()).cmp(&wrap_tuple(r.as_ref())) - } - - (Value::Bag(l), Value::Bag(r)) => wrap_bag(l.as_ref()).cmp(&wrap_bag(r.as_ref())), - (l, r) => l.cmp(r), - } - } -} - -/// Implementation of spec's `order-by less-than` assuming nulls first. -/// TODO: more tests for Ord on Value -impl Ord for Value { - fn cmp(&self, other: &Self) -> Ordering { - match (self, other) { - (Value::Null, Value::Null) => Ordering::Equal, - (Value::Missing, Value::Null) => Ordering::Equal, - - (Value::Null, Value::Missing) => Ordering::Equal, - (Value::Null, _) => Ordering::Less, - (_, Value::Null) => Ordering::Greater, - - (Value::Missing, Value::Missing) => Ordering::Equal, - (Value::Missing, _) => Ordering::Less, - (_, Value::Missing) => Ordering::Greater, - - (Value::Boolean(l), Value::Boolean(r)) => match (l, r) { - (false, true) => Ordering::Less, - (true, false) => Ordering::Greater, - (_, _) => Ordering::Equal, - }, - (Value::Boolean(_), _) => Ordering::Less, - (_, Value::Boolean(_)) => Ordering::Greater, - - // TODO: `OrderedFloat`'s implementation of `Ord` slightly differs from what we want in - // the PartiQL spec. See https://partiql.org/assets/PartiQL-Specification.pdf#subsection.12.2 - // point 3. In PartiQL, `nan`, comes before `-inf` which comes before all numeric - // values, which are followed by `+inf`. `OrderedFloat` places `NaN` as greater than - // all other `OrderedFloat` values. We could consider creating our own float type - // to get around this annoyance. - (Value::Real(l), Value::Real(r)) => { - if l.is_nan() { - if r.is_nan() { - Ordering::Equal - } else { - Ordering::Less - } - } else if r.is_nan() { - Ordering::Greater - } else { - l.cmp(r) - } - } - (Value::Integer(l), Value::Integer(r)) => l.cmp(r), - (Value::Decimal(l), Value::Decimal(r)) => l.cmp(r), - (Value::Integer(l), Value::Real(_)) => { - Value::Real(ordered_float::OrderedFloat(*l as f64)).cmp(other) - } - (Value::Real(_), Value::Integer(r)) => { - self.cmp(&Value::Real(ordered_float::OrderedFloat(*r as f64))) - } - (Value::Integer(l), Value::Decimal(r)) => RustDecimal::from(*l).cmp(r), - (Value::Decimal(l), Value::Integer(r)) => l.as_ref().cmp(&RustDecimal::from(*r)), - (Value::Real(l), Value::Decimal(r)) => { - if l.is_nan() || l.0 == f64::NEG_INFINITY { - Ordering::Less - } else if l.0 == f64::INFINITY { - Ordering::Greater - } else { - match RustDecimal::from_f64(l.0) { - Some(l_d) => l_d.cmp(r), - None => todo!( - "Decide default behavior when f64 can't be converted to RustDecimal" - ), - } - } - } - (Value::Decimal(l), Value::Real(r)) => { - if r.is_nan() || r.0 == f64::NEG_INFINITY { - Ordering::Greater - } else if r.0 == f64::INFINITY { - Ordering::Less - } else { - match RustDecimal::from_f64(r.0) { - Some(r_d) => l.as_ref().cmp(&r_d), - None => todo!( - "Decide default behavior when f64 can't be converted to RustDecimal" - ), - } - } - } - (Value::Integer(_), _) => Ordering::Less, - (Value::Real(_), _) => Ordering::Less, - (Value::Decimal(_), _) => Ordering::Less, - (_, Value::Integer(_)) => Ordering::Greater, - (_, Value::Real(_)) => Ordering::Greater, - (_, Value::Decimal(_)) => Ordering::Greater, - - (Value::DateTime(l), Value::DateTime(r)) => l.cmp(r), - (Value::DateTime(_), _) => Ordering::Less, - (_, Value::DateTime(_)) => Ordering::Greater, - - (Value::String(l), Value::String(r)) => l.cmp(r), - (Value::String(_), _) => Ordering::Less, - (_, Value::String(_)) => Ordering::Greater, - - (Value::Blob(l), Value::Blob(r)) => l.cmp(r), - (Value::Blob(_), _) => Ordering::Less, - (_, Value::Blob(_)) => Ordering::Greater, - - (Value::List(l), Value::List(r)) => l.cmp(r), - (Value::List(_), _) => Ordering::Less, - (_, Value::List(_)) => Ordering::Greater, - - (Value::Tuple(l), Value::Tuple(r)) => l.cmp(r), - (Value::Tuple(_), _) => Ordering::Less, - (_, Value::Tuple(_)) => Ordering::Greater, - - (Value::Bag(l), Value::Bag(r)) => l.cmp(r), - } - } -} - -impl From<&T> for Value -where - T: Copy, - Value: From, -{ - #[inline] - fn from(t: &T) -> Self { - Value::from(*t) - } -} - -impl From for Value { - #[inline] - fn from(b: bool) -> Self { - Value::Boolean(b) - } -} - -impl From for Value { - #[inline] - fn from(s: String) -> Self { - Value::String(Box::new(s)) - } -} - -impl From<&str> for Value { - #[inline] - fn from(s: &str) -> Self { - Value::String(Box::new(s.to_string())) - } -} - -impl From for Value { - #[inline] - fn from(n: i64) -> Self { - Value::Integer(n) - } -} - -impl From for Value { - #[inline] - fn from(n: i32) -> Self { - i64::from(n).into() - } -} - -impl From for Value { - #[inline] - fn from(n: i16) -> Self { - i64::from(n).into() - } -} - -impl From for Value { - #[inline] - fn from(n: i8) -> Self { - i64::from(n).into() - } -} - -impl From for Value { - #[inline] - fn from(n: usize) -> Self { - // TODO overflow to bigint/decimal - Value::Integer(n as i64) - } -} - -impl From for Value { - #[inline] - fn from(n: u8) -> Self { - (n as usize).into() - } -} - -impl From for Value { - #[inline] - fn from(n: u16) -> Self { - (n as usize).into() - } -} - -impl From for Value { - #[inline] - fn from(n: u32) -> Self { - (n as usize).into() - } -} - -impl From for Value { - #[inline] - fn from(n: u64) -> Self { - (n as usize).into() - } -} - -impl From for Value { - #[inline] - fn from(n: u128) -> Self { - (n as usize).into() - } -} - -impl From for Value { - #[inline] - fn from(f: f64) -> Self { - Value::Real(OrderedFloat(f)) - } -} - -impl From for Value { - #[inline] - fn from(d: RustDecimal) -> Self { - Value::Decimal(Box::new(d)) - } -} - -impl From for Value { - #[inline] - fn from(t: DateTime) -> Self { - Value::DateTime(Box::new(t)) - } -} - -impl From for Value { - #[inline] - fn from(v: List) -> Self { - Value::List(Box::new(v)) - } -} - -impl From for Value { - #[inline] - fn from(v: Tuple) -> Self { - Value::Tuple(Box::new(v)) - } -} - -impl From for Value { - #[inline] - fn from(v: Bag) -> Self { - Value::Bag(Box::new(v)) - } -} - #[cfg(test)] mod tests { use super::*; + use crate::comparison::{EqualityValue, NullableEq, NullableOrd}; + use crate::sort::NullSortedValue; + use ordered_float::OrderedFloat; + use rust_decimal::Decimal as RustDecimal; use rust_decimal_macros::dec; use std::borrow::Cow; use std::cell::RefCell; + use std::cmp::Ordering; use std::collections::HashSet; use std::mem; use std::rc::Rc; @@ -1445,25 +275,25 @@ mod tests { assert_eq!(Value::Integer(2), &Value::Integer(1) * &Value::Integer(2)); assert_eq!(Value::from(3.75), &Value::from(1.5) * &Value::from(2.5)); assert_eq!( - Value::from(Decimal::new(2, 0)), + Value::from(RustDecimal::new(2, 0)), &Value::Decimal(Box::new(dec!(1))) * &Value::from(dec!(2)) ); assert_eq!(Value::from(2.5), &Value::Integer(1) * &Value::from(2.5)); assert_eq!(Value::from(2.), &Value::from(1.) * &Value::from(2.)); assert_eq!( - Value::from(Decimal::new(2, 0)), + Value::from(RustDecimal::new(2, 0)), &Value::Integer(1) * &Value::Decimal(Box::new(dec!(2))) ); assert_eq!( - Value::from(Decimal::new(2, 0)), + Value::from(RustDecimal::new(2, 0)), &Value::Decimal(Box::new(dec!(1))) * &Value::Integer(2) ); assert_eq!( - Value::from(Decimal::new(2, 0)), + Value::from(RustDecimal::new(2, 0)), &Value::from(1.) * &Value::Decimal(Box::new(dec!(2))) ); assert_eq!( - Value::from(Decimal::new(2, 0)), + Value::from(RustDecimal::new(2, 0)), &Value::Decimal(Box::new(dec!(1))) * &Value::from(2.) ); diff --git a/partiql-value/src/list.rs b/partiql-value/src/list.rs index a7d66de9..224b16e6 100644 --- a/partiql-value/src/list.rs +++ b/partiql-value/src/list.rs @@ -5,7 +5,8 @@ use std::hash::{Hash, Hasher}; use std::{slice, vec}; -use crate::{Bag, EqualityValue, NullSortedValue, NullableEq, Value}; +use crate::sort::NullSortedValue; +use crate::{Bag, EqualityValue, NullableEq, Value}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; diff --git a/partiql-value/src/sort.rs b/partiql-value/src/sort.rs new file mode 100644 index 00000000..96faa0cf --- /dev/null +++ b/partiql-value/src/sort.rs @@ -0,0 +1,54 @@ +use crate::{Bag, List, Tuple, Value}; +use std::cmp::Ordering; + +/// A wrapper on [`T`] that specifies if a null or missing value should be ordered before +/// ([`NULLS_FIRST`] is true) or after ([`NULLS_FIRST`] is false) other values. +#[derive(Eq, PartialEq)] +pub struct NullSortedValue<'a, const NULLS_FIRST: bool, T>(pub &'a T); + +impl<'a, const NULLS_FIRST: bool, T> PartialOrd for NullSortedValue<'a, NULLS_FIRST, T> +where + T: PartialOrd, + NullSortedValue<'a, NULLS_FIRST, T>: Ord, +{ + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for NullSortedValue<'_, NULLS_FIRST, Value> { + fn cmp(&self, other: &Self) -> Ordering { + let wrap_list = NullSortedValue::<{ NULLS_FIRST }, List>; + let wrap_tuple = NullSortedValue::<{ NULLS_FIRST }, Tuple>; + let wrap_bag = NullSortedValue::<{ NULLS_FIRST }, Bag>; + let null_cond = |order: Ordering| { + if NULLS_FIRST { + order + } else { + order.reverse() + } + }; + + match (self.0, other.0) { + (Value::Null, Value::Null) => Ordering::Equal, + (Value::Missing, Value::Null) => Ordering::Equal, + + (Value::Null, Value::Missing) => Ordering::Equal, + (Value::Null, _) => null_cond(Ordering::Less), + (_, Value::Null) => null_cond(Ordering::Greater), + + (Value::Missing, Value::Missing) => Ordering::Equal, + (Value::Missing, _) => null_cond(Ordering::Less), + (_, Value::Missing) => null_cond(Ordering::Greater), + + (Value::List(l), Value::List(r)) => wrap_list(l.as_ref()).cmp(&wrap_list(r.as_ref())), + + (Value::Tuple(l), Value::Tuple(r)) => { + wrap_tuple(l.as_ref()).cmp(&wrap_tuple(r.as_ref())) + } + + (Value::Bag(l), Value::Bag(r)) => wrap_bag(l.as_ref()).cmp(&wrap_bag(r.as_ref())), + (l, r) => l.cmp(r), + } + } +} diff --git a/partiql-value/src/tuple.rs b/partiql-value/src/tuple.rs index 2c7edf9a..9559eb10 100644 --- a/partiql-value/src/tuple.rs +++ b/partiql-value/src/tuple.rs @@ -9,7 +9,8 @@ use std::vec; use unicase::UniCase; -use crate::{BindingsName, EqualityValue, NullSortedValue, NullableEq, Value}; +use crate::sort::NullSortedValue; +use crate::{BindingsName, EqualityValue, NullableEq, Value}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; diff --git a/partiql-value/src/util.rs b/partiql-value/src/util.rs new file mode 100644 index 00000000..d97c6cba --- /dev/null +++ b/partiql-value/src/util.rs @@ -0,0 +1,36 @@ +use crate::Value; +use ordered_float::OrderedFloat; +use rust_decimal::prelude::FromPrimitive; +use rust_decimal::Decimal; + +pub fn coerce_f64_to_decimal(real_value: &f64) -> Option { + if !real_value.is_finite() { + None + } else { + Decimal::from_f64(*real_value) + } +} + +pub fn coerce_int_or_real_to_decimal(value: &Value) -> Value { + match value { + Value::Integer(int_value) => Value::from(rust_decimal::Decimal::from(*int_value)), + Value::Real(real_value) => { + if !real_value.is_finite() { + Value::Missing + } else { + match Decimal::from_f64(real_value.0) { + Some(d_from_r) => Value::from(d_from_r), + None => Value::Missing, // TODO: decide on behavior when float cannot be coerced to Decimal + } + } + } + _ => todo!("Unsupported coercion to Decimal"), + } +} + +pub fn coerce_int_to_real(value: &Value) -> Value { + match value { + Value::Integer(int_value) => Value::Real(OrderedFloat(*int_value as f64)), + _ => todo!("Unsupported coercion to Real"), + } +} diff --git a/partiql-value/src/value.rs b/partiql-value/src/value.rs new file mode 100644 index 00000000..7022f37d --- /dev/null +++ b/partiql-value/src/value.rs @@ -0,0 +1,501 @@ +use ordered_float::OrderedFloat; +use std::borrow::Cow; +use std::fmt::{Debug, Display, Formatter}; +use std::hash::Hash; + +use rust_decimal::Decimal as RustDecimal; + +use crate::{Bag, BindingIntoIter, BindingIter, DateTime, List, Tuple}; +use rust_decimal::prelude::FromPrimitive; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +mod iter; +mod logic; +mod math; + +pub use iter::*; +pub use logic::*; +pub use math::*; +use partiql_common::pretty::ToPretty; +use std::cmp::Ordering; + +#[derive(Hash, PartialEq, Eq, Clone, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum Value { + Null, + #[default] + Missing, + Boolean(bool), + Integer(i64), + Real(OrderedFloat), + Decimal(Box), + String(Box), + Blob(Box>), + DateTime(Box), + List(Box), + Bag(Box), + Tuple(Box), + // TODO: add other supported PartiQL values -- sexp +} + +impl Value { + #[inline] + #[must_use] + pub fn is_tuple(&self) -> bool { + matches!(self, Value::Tuple(_)) + } + + #[inline] + #[must_use] + pub fn is_list(&self) -> bool { + matches!(self, Value::List(_)) + } + + #[inline] + #[must_use] + pub fn is_bag(&self) -> bool { + matches!(self, Value::Bag(_)) + } + + #[inline] + #[must_use] + pub fn is_sequence(&self) -> bool { + self.is_bag() || self.is_list() + } + + #[inline] + /// Returns true if and only if Value is an integer, real, or decimal + #[must_use] + pub fn is_number(&self) -> bool { + matches!(self, Value::Integer(_) | Value::Real(_) | Value::Decimal(_)) + } + #[inline] + /// Returns true if and only if Value is null or missing + #[must_use] + pub fn is_absent(&self) -> bool { + matches!(self, Value::Missing | Value::Null) + } + + #[inline] + /// Returns true if Value is neither null nor missing + #[must_use] + pub fn is_present(&self) -> bool { + !self.is_absent() + } + + #[inline] + #[must_use] + pub fn is_ordered(&self) -> bool { + self.is_list() + } + + #[inline] + #[must_use] + pub fn coerce_into_tuple(self) -> Tuple { + match self { + Value::Tuple(t) => *t, + _ => self + .into_bindings() + .map(|(k, v)| (k.unwrap_or_else(|| "_1".to_string()), v)) + .collect(), + } + } + + #[inline] + #[must_use] + pub fn coerce_to_tuple(&self) -> Tuple { + match self { + Value::Tuple(t) => t.as_ref().clone(), + _ => { + let fresh = "_1".to_string(); + self.as_bindings() + .map(|(k, v)| (k.unwrap_or(&fresh), v.clone())) + .collect() + } + } + } + + #[inline] + #[must_use] + pub fn as_tuple_ref(&self) -> Cow<'_, Tuple> { + if let Value::Tuple(t) = self { + Cow::Borrowed(t) + } else { + Cow::Owned(self.coerce_to_tuple()) + } + } + + #[inline] + #[must_use] + pub fn as_bindings(&self) -> BindingIter<'_> { + match self { + Value::Tuple(t) => BindingIter::Tuple(t.pairs()), + Value::Missing => BindingIter::Empty, + _ => BindingIter::Single(std::iter::once(self)), + } + } + + #[inline] + #[must_use] + pub fn into_bindings(self) -> BindingIntoIter { + match self { + Value::Tuple(t) => BindingIntoIter::Tuple(t.into_pairs()), + Value::Missing => BindingIntoIter::Empty, + _ => BindingIntoIter::Single(std::iter::once(self)), + } + } + + #[inline] + #[must_use] + pub fn coerce_into_bag(self) -> Bag { + if let Value::Bag(b) = self { + *b + } else { + Bag::from(vec![self]) + } + } + + #[inline] + #[must_use] + pub fn as_bag_ref(&self) -> Cow<'_, Bag> { + if let Value::Bag(b) = self { + Cow::Borrowed(b) + } else { + Cow::Owned(self.clone().coerce_into_bag()) + } + } + + #[inline] + #[must_use] + pub fn coerce_into_list(self) -> List { + if let Value::List(b) = self { + *b + } else { + List::from(vec![self]) + } + } + + #[inline] + #[must_use] + pub fn as_list_ref(&self) -> Cow<'_, List> { + if let Value::List(l) = self { + Cow::Borrowed(l) + } else { + Cow::Owned(self.clone().coerce_into_list()) + } + } + + #[inline] + #[must_use] + pub fn iter(&self) -> ValueIter<'_> { + match self { + Value::Null | Value::Missing => ValueIter::Single(None), + Value::List(list) => ValueIter::List(list.iter()), + Value::Bag(bag) => ValueIter::Bag(bag.iter()), + other => ValueIter::Single(Some(other)), + } + } + + #[inline] + #[must_use] + pub fn sequence_iter(&self) -> Option> { + if self.is_sequence() { + Some(self.iter()) + } else { + None + } + } +} + +impl Display for Value { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self.to_pretty_string(f.width().unwrap_or(80)) { + Ok(pretty) => f.write_str(&pretty), + Err(_) => f.write_str(""), + } + } +} + +impl Debug for Value { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Value::Null => write!(f, "NULL"), + Value::Missing => write!(f, "MISSING"), + Value::Boolean(b) => write!(f, "{b}"), + Value::Integer(i) => write!(f, "{i}"), + Value::Real(r) => write!(f, "{}", r.0), + Value::Decimal(d) => write!(f, "{d}"), + Value::String(s) => write!(f, "'{s}'"), + Value::Blob(s) => write!(f, "'{s:?}'"), + Value::DateTime(t) => t.fmt(f), + Value::List(l) => l.fmt(f), + Value::Bag(b) => b.fmt(f), + Value::Tuple(t) => t.fmt(f), + } + } +} + +impl PartialOrd for Value { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +/// Implementation of spec's `order-by less-than` assuming nulls first. +/// TODO: more tests for Ord on Value +impl Ord for Value { + fn cmp(&self, other: &Self) -> Ordering { + match (self, other) { + (Value::Null, Value::Null) => Ordering::Equal, + (Value::Missing, Value::Null) => Ordering::Equal, + + (Value::Null, Value::Missing) => Ordering::Equal, + (Value::Null, _) => Ordering::Less, + (_, Value::Null) => Ordering::Greater, + + (Value::Missing, Value::Missing) => Ordering::Equal, + (Value::Missing, _) => Ordering::Less, + (_, Value::Missing) => Ordering::Greater, + + (Value::Boolean(l), Value::Boolean(r)) => match (l, r) { + (false, true) => Ordering::Less, + (true, false) => Ordering::Greater, + (_, _) => Ordering::Equal, + }, + (Value::Boolean(_), _) => Ordering::Less, + (_, Value::Boolean(_)) => Ordering::Greater, + + // TODO: `OrderedFloat`'s implementation of `Ord` slightly differs from what we want in + // the PartiQL spec. See https://partiql.org/assets/PartiQL-Specification.pdf#subsection.12.2 + // point 3. In PartiQL, `nan`, comes before `-inf` which comes before all numeric + // values, which are followed by `+inf`. `OrderedFloat` places `NaN` as greater than + // all other `OrderedFloat` values. We could consider creating our own float type + // to get around this annoyance. + (Value::Real(l), Value::Real(r)) => { + if l.is_nan() { + if r.is_nan() { + Ordering::Equal + } else { + Ordering::Less + } + } else if r.is_nan() { + Ordering::Greater + } else { + l.cmp(r) + } + } + (Value::Integer(l), Value::Integer(r)) => l.cmp(r), + (Value::Decimal(l), Value::Decimal(r)) => l.cmp(r), + (Value::Integer(l), Value::Real(_)) => { + Value::Real(ordered_float::OrderedFloat(*l as f64)).cmp(other) + } + (Value::Real(_), Value::Integer(r)) => { + self.cmp(&Value::Real(ordered_float::OrderedFloat(*r as f64))) + } + (Value::Integer(l), Value::Decimal(r)) => RustDecimal::from(*l).cmp(r), + (Value::Decimal(l), Value::Integer(r)) => l.as_ref().cmp(&RustDecimal::from(*r)), + (Value::Real(l), Value::Decimal(r)) => { + if l.is_nan() || l.0 == f64::NEG_INFINITY { + Ordering::Less + } else if l.0 == f64::INFINITY { + Ordering::Greater + } else { + match RustDecimal::from_f64(l.0) { + Some(l_d) => l_d.cmp(r), + None => todo!( + "Decide default behavior when f64 can't be converted to RustDecimal" + ), + } + } + } + (Value::Decimal(l), Value::Real(r)) => { + if r.is_nan() || r.0 == f64::NEG_INFINITY { + Ordering::Greater + } else if r.0 == f64::INFINITY { + Ordering::Less + } else { + match RustDecimal::from_f64(r.0) { + Some(r_d) => l.as_ref().cmp(&r_d), + None => todo!( + "Decide default behavior when f64 can't be converted to RustDecimal" + ), + } + } + } + (Value::Integer(_), _) => Ordering::Less, + (Value::Real(_), _) => Ordering::Less, + (Value::Decimal(_), _) => Ordering::Less, + (_, Value::Integer(_)) => Ordering::Greater, + (_, Value::Real(_)) => Ordering::Greater, + (_, Value::Decimal(_)) => Ordering::Greater, + + (Value::DateTime(l), Value::DateTime(r)) => l.cmp(r), + (Value::DateTime(_), _) => Ordering::Less, + (_, Value::DateTime(_)) => Ordering::Greater, + + (Value::String(l), Value::String(r)) => l.cmp(r), + (Value::String(_), _) => Ordering::Less, + (_, Value::String(_)) => Ordering::Greater, + + (Value::Blob(l), Value::Blob(r)) => l.cmp(r), + (Value::Blob(_), _) => Ordering::Less, + (_, Value::Blob(_)) => Ordering::Greater, + + (Value::List(l), Value::List(r)) => l.cmp(r), + (Value::List(_), _) => Ordering::Less, + (_, Value::List(_)) => Ordering::Greater, + + (Value::Tuple(l), Value::Tuple(r)) => l.cmp(r), + (Value::Tuple(_), _) => Ordering::Less, + (_, Value::Tuple(_)) => Ordering::Greater, + + (Value::Bag(l), Value::Bag(r)) => l.cmp(r), + } + } +} + +impl From<&T> for Value +where + T: Copy, + Value: From, +{ + #[inline] + fn from(t: &T) -> Self { + Value::from(*t) + } +} + +impl From for Value { + #[inline] + fn from(b: bool) -> Self { + Value::Boolean(b) + } +} + +impl From for Value { + #[inline] + fn from(s: String) -> Self { + Value::String(Box::new(s)) + } +} + +impl From<&str> for Value { + #[inline] + fn from(s: &str) -> Self { + Value::String(Box::new(s.to_string())) + } +} + +impl From for Value { + #[inline] + fn from(n: i64) -> Self { + Value::Integer(n) + } +} + +impl From for Value { + #[inline] + fn from(n: i32) -> Self { + i64::from(n).into() + } +} + +impl From for Value { + #[inline] + fn from(n: i16) -> Self { + i64::from(n).into() + } +} + +impl From for Value { + #[inline] + fn from(n: i8) -> Self { + i64::from(n).into() + } +} + +impl From for Value { + #[inline] + fn from(n: usize) -> Self { + // TODO overflow to bigint/decimal + Value::Integer(n as i64) + } +} + +impl From for Value { + #[inline] + fn from(n: u8) -> Self { + (n as usize).into() + } +} + +impl From for Value { + #[inline] + fn from(n: u16) -> Self { + (n as usize).into() + } +} + +impl From for Value { + #[inline] + fn from(n: u32) -> Self { + (n as usize).into() + } +} + +impl From for Value { + #[inline] + fn from(n: u64) -> Self { + (n as usize).into() + } +} + +impl From for Value { + #[inline] + fn from(n: u128) -> Self { + (n as usize).into() + } +} + +impl From for Value { + #[inline] + fn from(f: f64) -> Self { + Value::Real(OrderedFloat(f)) + } +} + +impl From for Value { + #[inline] + fn from(d: RustDecimal) -> Self { + Value::Decimal(Box::new(d)) + } +} + +impl From for Value { + #[inline] + fn from(t: DateTime) -> Self { + Value::DateTime(Box::new(t)) + } +} + +impl From for Value { + #[inline] + fn from(v: List) -> Self { + Value::List(Box::new(v)) + } +} + +impl From for Value { + #[inline] + fn from(v: Tuple) -> Self { + Value::Tuple(Box::new(v)) + } +} + +impl From for Value { + #[inline] + fn from(v: Bag) -> Self { + Value::Bag(Box::new(v)) + } +} diff --git a/partiql-value/src/value/iter.rs b/partiql-value/src/value/iter.rs new file mode 100644 index 00000000..81ff9d4a --- /dev/null +++ b/partiql-value/src/value/iter.rs @@ -0,0 +1,72 @@ +use crate::{BagIntoIterator, BagIter, ListIntoIterator, ListIter, Value}; + +#[derive(Debug, Clone)] +pub enum ValueIter<'a> { + List(ListIter<'a>), + Bag(BagIter<'a>), + Single(Option<&'a Value>), +} + +impl<'a> Iterator for ValueIter<'a> { + type Item = &'a Value; + + #[inline] + fn next(&mut self) -> Option { + match self { + ValueIter::List(list) => list.next(), + ValueIter::Bag(bag) => bag.next(), + ValueIter::Single(v) => v.take(), + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + match self { + ValueIter::List(list) => list.size_hint(), + ValueIter::Bag(bag) => bag.size_hint(), + ValueIter::Single(_) => (1, Some(1)), + } + } +} + +impl IntoIterator for Value { + type Item = Value; + type IntoIter = ValueIntoIterator; + + #[inline] + fn into_iter(self) -> ValueIntoIterator { + match self { + Value::List(list) => ValueIntoIterator::List(list.into_iter()), + Value::Bag(bag) => ValueIntoIterator::Bag(bag.into_iter()), + other => ValueIntoIterator::Single(Some(other)), + } + } +} + +pub enum ValueIntoIterator { + List(ListIntoIterator), + Bag(BagIntoIterator), + Single(Option), +} + +impl Iterator for ValueIntoIterator { + type Item = Value; + + #[inline] + fn next(&mut self) -> Option { + match self { + ValueIntoIterator::List(list) => list.next(), + ValueIntoIterator::Bag(bag) => bag.next(), + ValueIntoIterator::Single(v) => v.take(), + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + match self { + ValueIntoIterator::List(list) => list.size_hint(), + ValueIntoIterator::Bag(bag) => bag.size_hint(), + ValueIntoIterator::Single(_) => (1, Some(1)), + } + } +} diff --git a/partiql-value/src/value/logic.rs b/partiql-value/src/value/logic.rs new file mode 100644 index 00000000..51f9c7e6 --- /dev/null +++ b/partiql-value/src/value/logic.rs @@ -0,0 +1,78 @@ +use crate::Value; +use std::ops; + +impl ops::Not for &Value { + type Output = Value; + + fn not(self) -> Self::Output { + match self { + Value::Boolean(b) => Value::from(!b), + Value::Null | Value::Missing => Value::Null, + _ => Value::Missing, // data type mismatch => Missing + } + } +} + +impl ops::Not for Value { + type Output = Self; + + fn not(self) -> Self::Output { + match self { + Value::Boolean(b) => Value::from(!b), + Value::Null | Value::Missing => Value::Null, + _ => Value::Missing, // data type mismatch => Missing + } + } +} + +pub trait BinaryAnd { + type Output; + + fn and(&self, rhs: &Self) -> Self::Output; +} + +impl BinaryAnd for Value { + type Output = Self; + fn and(&self, rhs: &Self) -> Self::Output { + match (self, rhs) { + (Value::Boolean(l), Value::Boolean(r)) => Value::from(*l && *r), + (Value::Null | Value::Missing, Value::Boolean(false)) + | (Value::Boolean(false), Value::Null | Value::Missing) => Value::from(false), + _ => { + if matches!(self, Value::Missing | Value::Null | Value::Boolean(true)) + && matches!(rhs, Value::Missing | Value::Null | Value::Boolean(true)) + { + Value::Null + } else { + Value::Missing + } + } + } + } +} + +pub trait BinaryOr { + type Output; + + fn or(&self, rhs: &Self) -> Self::Output; +} + +impl BinaryOr for Value { + type Output = Self; + fn or(&self, rhs: &Self) -> Self::Output { + match (self, rhs) { + (Value::Boolean(l), Value::Boolean(r)) => Value::from(*l || *r), + (Value::Null | Value::Missing, Value::Boolean(true)) + | (Value::Boolean(true), Value::Null | Value::Missing) => Value::from(true), + _ => { + if matches!(self, Value::Missing | Value::Null | Value::Boolean(false)) + && matches!(rhs, Value::Missing | Value::Null | Value::Boolean(false)) + { + Value::Null + } else { + Value::Missing + } + } + } + } +} diff --git a/partiql-value/src/value/math.rs b/partiql-value/src/value/math.rs new file mode 100644 index 00000000..a8be2ffd --- /dev/null +++ b/partiql-value/src/value/math.rs @@ -0,0 +1,247 @@ +use crate::util; +use crate::Value; +use std::ops; + +impl ops::Add for &Value { + type Output = Value; + + fn add(self, rhs: Self) -> Self::Output { + match (&self, &rhs) { + // TODO: edge cases dealing with overflow + (Value::Missing, _) => Value::Missing, + (_, Value::Missing) => Value::Missing, + (Value::Null, _) => Value::Null, + (_, Value::Null) => Value::Null, + (Value::Integer(l), Value::Integer(r)) => Value::Integer(l + r), + (Value::Real(l), Value::Real(r)) => Value::Real(*l + *r), + (Value::Decimal(l), Value::Decimal(r)) => { + Value::Decimal(Box::new(l.as_ref() + r.as_ref())) + } + (Value::Integer(_), Value::Real(_)) => &util::coerce_int_to_real(self) + rhs, + (Value::Integer(_), Value::Decimal(_)) => { + &util::coerce_int_or_real_to_decimal(self) + rhs + } + (Value::Real(_), Value::Decimal(_)) => &util::coerce_int_or_real_to_decimal(self) + rhs, + (Value::Real(_), Value::Integer(_)) => self + &util::coerce_int_to_real(rhs), + (Value::Decimal(_), Value::Integer(_)) => { + self + &util::coerce_int_or_real_to_decimal(rhs) + } + (Value::Decimal(_), Value::Real(_)) => self + &util::coerce_int_or_real_to_decimal(rhs), + _ => Value::Missing, // data type mismatch => Missing + } + } +} + +impl ops::AddAssign<&Value> for Value { + fn add_assign(&mut self, rhs: &Value) { + match (self, &rhs) { + // TODO: edge cases dealing with overflow + (Value::Missing, _) => {} + (this, Value::Missing) => *this = Value::Missing, + (Value::Null, _) => {} + (this, Value::Null) => *this = Value::Null, + + (Value::Integer(l), Value::Integer(r)) => l.add_assign(r), + + (Value::Real(l), Value::Real(r)) => l.add_assign(r), + (Value::Real(l), Value::Integer(i)) => l.add_assign(*i as f64), + + (Value::Decimal(l), Value::Decimal(r)) => l.add_assign(r.as_ref()), + (Value::Decimal(l), Value::Integer(i)) => l.add_assign(rust_decimal::Decimal::from(*i)), + (Value::Decimal(l), Value::Real(r)) => match util::coerce_f64_to_decimal(r) { + Some(d) => l.add_assign(d), + None => todo!(), + }, + + (this, Value::Real(r)) => { + *this = match &this { + Value::Integer(l) => Value::from((*l as f64) + r.0), + _ => Value::Missing, + }; + } + (this, Value::Decimal(r)) => { + *this = match &this { + Value::Integer(l) => { + Value::Decimal(Box::new(rust_decimal::Decimal::from(*l) + r.as_ref())) + } + Value::Real(l) => match util::coerce_f64_to_decimal(&l.0) { + None => Value::Missing, + Some(d) => Value::Decimal(Box::new(d + r.as_ref())), + }, + _ => Value::Missing, + }; + } + (this, _) => *this = Value::Missing, // data type mismatch => Missing + } + } +} + +pub trait UnaryPlus { + type Output; + + fn positive(self) -> Self::Output; +} + +impl UnaryPlus for Value { + type Output = Self; + fn positive(self) -> Self::Output { + match self { + Value::Null => Value::Null, + Value::Missing => Value::Missing, + Value::Integer(_) | Value::Real(_) | Value::Decimal(_) => self, + _ => Value::Missing, // data type mismatch => Missing + } + } +} + +impl ops::Sub for &Value { + type Output = Value; + + fn sub(self, rhs: Self) -> Self::Output { + match (&self, &rhs) { + // TODO: edge cases dealing with overflow + (Value::Missing, _) => Value::Missing, + (_, Value::Missing) => Value::Missing, + (Value::Null, _) => Value::Null, + (_, Value::Null) => Value::Null, + (Value::Integer(l), Value::Integer(r)) => Value::Integer(l - r), + (Value::Real(l), Value::Real(r)) => Value::Real(*l - *r), + (Value::Decimal(l), Value::Decimal(r)) => { + Value::Decimal(Box::new(l.as_ref() - r.as_ref())) + } + (Value::Integer(_), Value::Real(_)) => &util::coerce_int_to_real(self) - rhs, + (Value::Integer(_), Value::Decimal(_)) => { + &util::coerce_int_or_real_to_decimal(self) - rhs + } + (Value::Real(_), Value::Decimal(_)) => &util::coerce_int_or_real_to_decimal(self) - rhs, + (Value::Real(_), Value::Integer(_)) => self - &util::coerce_int_to_real(rhs), + (Value::Decimal(_), Value::Integer(_)) => { + self - &util::coerce_int_or_real_to_decimal(rhs) + } + (Value::Decimal(_), Value::Real(_)) => self - &util::coerce_int_or_real_to_decimal(rhs), + _ => Value::Missing, // data type mismatch => Missing + } + } +} + +impl ops::Mul for &Value { + type Output = Value; + + fn mul(self, rhs: Self) -> Self::Output { + match (&self, &rhs) { + // TODO: edge cases dealing with overflow + (Value::Missing, _) => Value::Missing, + (_, Value::Missing) => Value::Missing, + (Value::Null, _) => Value::Null, + (_, Value::Null) => Value::Null, + (Value::Integer(l), Value::Integer(r)) => Value::Integer(l * r), + (Value::Real(l), Value::Real(r)) => Value::Real(*l * *r), + (Value::Decimal(l), Value::Decimal(r)) => { + Value::Decimal(Box::new(l.as_ref() * r.as_ref())) + } + (Value::Integer(_), Value::Real(_)) => &util::coerce_int_to_real(self) * rhs, + (Value::Integer(_), Value::Decimal(_)) => { + &util::coerce_int_or_real_to_decimal(self) * rhs + } + (Value::Real(_), Value::Decimal(_)) => &util::coerce_int_or_real_to_decimal(self) * rhs, + (Value::Real(_), Value::Integer(_)) => self * &util::coerce_int_to_real(rhs), + (Value::Decimal(_), Value::Integer(_)) => { + self * &util::coerce_int_or_real_to_decimal(rhs) + } + (Value::Decimal(_), Value::Real(_)) => self * &util::coerce_int_or_real_to_decimal(rhs), + _ => Value::Missing, // data type mismatch => Missing + } + } +} + +impl ops::Div for &Value { + type Output = Value; + + fn div(self, rhs: Self) -> Self::Output { + match (&self, &rhs) { + // TODO: edge cases dealing with division by 0 + (Value::Missing, _) => Value::Missing, + (_, Value::Missing) => Value::Missing, + (Value::Null, _) => Value::Null, + (_, Value::Null) => Value::Null, + (Value::Integer(l), Value::Integer(r)) => Value::Integer(l / r), + (Value::Real(l), Value::Real(r)) => Value::Real(*l / *r), + (Value::Decimal(l), Value::Decimal(r)) => { + Value::Decimal(Box::new(l.as_ref() / r.as_ref())) + } + (Value::Integer(_), Value::Real(_)) => &util::coerce_int_to_real(self) / rhs, + (Value::Integer(_), Value::Decimal(_)) => { + &util::coerce_int_or_real_to_decimal(self) / rhs + } + (Value::Real(_), Value::Decimal(_)) => &util::coerce_int_or_real_to_decimal(self) / rhs, + (Value::Real(_), Value::Integer(_)) => self / &util::coerce_int_to_real(rhs), + (Value::Decimal(_), Value::Integer(_)) => { + self / &util::coerce_int_or_real_to_decimal(rhs) + } + (Value::Decimal(_), Value::Real(_)) => self / &util::coerce_int_or_real_to_decimal(rhs), + _ => Value::Missing, // data type mismatch => Missing + } + } +} + +impl ops::Rem for &Value { + type Output = Value; + + fn rem(self, rhs: Self) -> Self::Output { + match (&self, &rhs) { + // TODO: edge cases dealing with division by 0 + (Value::Missing, _) => Value::Missing, + (_, Value::Missing) => Value::Missing, + (Value::Null, _) => Value::Null, + (_, Value::Null) => Value::Null, + (Value::Integer(l), Value::Integer(r)) => Value::Integer(l % r), + (Value::Real(l), Value::Real(r)) => Value::Real(*l % *r), + (Value::Decimal(l), Value::Decimal(r)) => { + Value::Decimal(Box::new(l.as_ref() % r.as_ref())) + } + (Value::Integer(_), Value::Real(_)) => &util::coerce_int_to_real(self) % rhs, + (Value::Integer(_), Value::Decimal(_)) => { + &util::coerce_int_or_real_to_decimal(self) % rhs + } + (Value::Real(_), Value::Decimal(_)) => &util::coerce_int_or_real_to_decimal(self) % rhs, + (Value::Real(_), Value::Integer(_)) => self % &util::coerce_int_to_real(rhs), + (Value::Decimal(_), Value::Integer(_)) => { + self % &util::coerce_int_or_real_to_decimal(rhs) + } + (Value::Decimal(_), Value::Real(_)) => self % &util::coerce_int_or_real_to_decimal(rhs), + _ => Value::Missing, // data type mismatch => Missing + } + } +} + +impl ops::Neg for &Value { + type Output = Value; + + fn neg(self) -> Self::Output { + match self { + // TODO: handle overflow for negation + Value::Null => Value::Null, + Value::Missing => Value::Missing, + Value::Integer(i) => Value::from(-i), + Value::Real(f) => Value::Real(-f), + Value::Decimal(d) => Value::from(-d.as_ref()), + _ => Value::Missing, // data type mismatch => Missing + } + } +} + +impl ops::Neg for Value { + type Output = Value; + + fn neg(self) -> Self::Output { + match self { + // TODO: handle overflow for negation + Value::Null => self, + Value::Missing => self, + Value::Integer(i) => Value::from(-i), + Value::Real(f) => Value::Real(-f), + Value::Decimal(d) => Value::from(-d.as_ref()), + _ => Value::Missing, // data type mismatch => Missing + } + } +} diff --git a/partiql/Cargo.toml b/partiql/Cargo.toml index 5cea77e3..e4ff2850 100644 --- a/partiql/Cargo.toml +++ b/partiql/Cargo.toml @@ -46,7 +46,7 @@ time = { version = "0.3", features = ["macros"] } criterion = "0.5" rand = "0.8" -assert_matches = "1.5" +assert_matches = "1" [[bench]] name = "bench_eval_multi_like" diff --git a/partiql/src/subquery_tests.rs b/partiql/src/subquery_tests.rs index e8d167f3..951df941 100644 --- a/partiql/src/subquery_tests.rs +++ b/partiql/src/subquery_tests.rs @@ -40,11 +40,16 @@ mod tests { fn locals_in_subqueries() { // `SELECT VALUE _1 from (SELECT VALUE foo from <<{'a': 'b'}>> AS foo) AS _1;` let mut sub_query = LogicalPlan::new(); + + let data = Box::new(partiql_logical::Lit::Bag(vec![ + partiql_logical::Lit::Struct(vec![( + "a".to_string(), + partiql_logical::Lit::String("b".to_string()), + )]), + ])); let scan_op_id = sub_query.add_operator(partiql_logical::BindingsOp::Scan(partiql_logical::Scan { - expr: partiql_logical::ValueExpr::Lit(Box::new(Value::Bag(Box::new(Bag::from( - vec![tuple![("a", "b")].into()], - ))))), + expr: partiql_logical::ValueExpr::Lit(data), as_key: "foo".into(), at_key: None, }));