From ebd16e3fcf41e1f50c774d00ee21dc9b192cefb6 Mon Sep 17 00:00:00 2001 From: Arash M <27716912+am357@users.noreply.github.com> Date: Mon, 12 Aug 2024 13:08:18 -0700 Subject: [PATCH 1/4] Adds `NodeId` to the `StaticType` This PR - adds `NodeId` to `StaticType`. - makes `AutoNodeIdGenerator` thread-safe - adds `PartiqlShapeBuilder` and moves some `PartiqlShape` APIs to it; this is to be able to generate unique `NodeId`s for a `PartiqlShape` that includes static types that themselves can include other static types. - adds a static thread safe `shape_builder` function that provides a convenient way for using `PartiqlShapeBuilder` for creating new shapes. - prepends existing type macros with `type` such as `type_int!` to make macro names more friendly. - removes `const` PartiQL types under `partiql-types` in favor of `PartiqlShapeBuilder`. --- CHANGELOG.md | 16 +- extension/partiql-extension-ddl/src/ddl.rs | 55 +- .../partiql-extension-ddl/tests/ddl-tests.rs | 22 +- partiql-ast/src/builder.rs | 3 +- partiql-common/src/node.rs | 29 +- partiql-eval/src/eval/eval_expr_wrapper.rs | 4 +- partiql-eval/src/eval/expr/coll.rs | 28 +- partiql-eval/src/eval/expr/datetime.rs | 4 +- partiql-eval/src/eval/expr/operators.rs | 66 +- partiql-eval/src/eval/expr/pattern_match.rs | 13 +- partiql-eval/src/eval/expr/strings.rs | 16 +- partiql-logical-planner/src/lower.rs | 4 +- partiql-logical-planner/src/typer.rs | 212 +++-- partiql-types/Cargo.toml | 3 +- partiql-types/src/lib.rs | 740 ++++++++++++------ 15 files changed, 732 insertions(+), 483 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bfca598e..6262612a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,13 +11,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - partiql-ast: improved pretty-printing of `CASE` and various clauses - ### Added - -### Fixed +- Added `partiql-common` and moved node id generation and `partiql-source-map` code to it under `syntax` +- Added `NodeId` to 
`StaticType`. +- *BREAKING* Added thread-safe `PartiqlShapeBuilder` and automatic `NodeId` generation for the `StaticType`. +- *BREAKING* Moved some of the `PartiqlShape` APIs to the `PartiqlShapeBuilder`. +- *BREAKING* Prepended existing type macros with `type` such as `type_int!` to make macro names more friendly. +- Added a static thread safe `shape_builder` function that provides a convenient way for using `PartiqlShapeBuilder` for creating new shapes. + +### Removed +- *BREAKING* Removed `partiql-source-map`. +- *BREAKING* Removed `const` PartiQL types under `partiql-types` in favor of `PartiqlShapeBuilder`. +- *BREAKING* Removed `StaticType`'s `new`, `new_non_nullable`, and `as_non_nullable` APIs in favor of `PartiqlShapeBuilder`. ## [0.10.0] ### Changed - *BREAKING:* partiql-ast: added modeling of `EXCLUDE` -- *BREAKING:* partiql-ast: added pretty-printing of `EXCLUDE` +- *BREAKING:* partiql-ast: added pretty-printing of `EXCLUDE` +- Changed `AutoNodeIdGenerator` to a thread-safe version ### Added - *BREAKING:* partiql-parser: added parsing of `EXCLUDE` diff --git a/extension/partiql-extension-ddl/src/ddl.rs b/extension/partiql-extension-ddl/src/ddl.rs index b2f5a25d..ae16cd86 100644 --- a/extension/partiql-extension-ddl/src/ddl.rs +++ b/extension/partiql-extension-ddl/src/ddl.rs @@ -123,8 +123,8 @@ impl PartiqlBasicDdlEncoder { Static::Float64 => out.push_str("DOUBLE"), Static::String => out.push_str("VARCHAR"), Static::Struct(s) => out.push_str(&self.write_struct(s)?), - Static::Bag(b) => out.push_str(&self.write_bag(b)?), - Static::Array(a) => out.push_str(&self.write_array(a)?), + Static::Bag(b) => out.push_str(&self.write_type_bag(b)?), + Static::Array(a) => out.push_str(&self.write_type_array(a)?), // non-exhaustive catch-all _ => todo!("handle type for {}", ty), } @@ -136,12 +136,18 @@ impl PartiqlBasicDdlEncoder { Ok(out) } - fn write_bag(&self, bag: &BagType) -> ShapeDdlEncodeResult { - Ok(format!("BAG<{}>", self.write_shape(bag.element_type())?)) + 
fn write_type_bag(&self, type_bag: &BagType) -> ShapeDdlEncodeResult { + Ok(format!( + "type_bag<{}>", + self.write_shape(type_bag.element_type())? + )) } - fn write_array(&self, arr: &ArrayType) -> ShapeDdlEncodeResult { - Ok(format!("ARRAY<{}>", self.write_shape(arr.element_type())?)) + fn write_type_array(&self, arr: &ArrayType) -> ShapeDdlEncodeResult { + Ok(format!( + "type_array<{}>", + self.write_shape(arr.element_type())? + )) } fn write_struct(&self, strct: &StructType) -> ShapeDdlEncodeResult { @@ -189,8 +195,8 @@ impl PartiqlDdlEncoder for PartiqlBasicDdlEncoder { let mut output = String::new(); let ty = ty.expect_static()?; - if let Static::Bag(bag) = ty.ty() { - let s = bag.element_type().expect_struct()?; + if let Static::Bag(type_bag) = ty.ty() { + let s = type_bag.element_type().expect_struct()?; let mut fields = s.fields().peekable(); while let Some(field) = fields.next() { output.push_str(&format!("\"{}\" ", field.name())); @@ -223,41 +229,44 @@ impl PartiqlDdlEncoder for PartiqlBasicDdlEncoder { mod tests { use super::*; use indexmap::IndexSet; - use partiql_types::{array, bag, f64, int8, r#struct, str, struct_fields, StructConstraint}; + use partiql_types::{ + shape_builder, struct_fields, type_array, type_bag, type_float64, type_int8, type_string, + type_struct, StructConstraint, + }; #[test] fn ddl_test() { let nested_attrs = struct_fields![ ( "a", - PartiqlShape::any_of(vec![ - PartiqlShape::new(Static::DecimalP(5, 4)), - PartiqlShape::new(Static::Int8), + shape_builder().any_of(vec![ + shape_builder().new_static(Static::DecimalP(5, 4)), + shape_builder().new_static(Static::Int8), ]) ), - ("b", array![str![]]), - ("c", f64!()), + ("b", type_array![type_string![]]), + ("c", type_float64!()), ]; - let details = r#struct![IndexSet::from([nested_attrs])]; + let details = type_struct![IndexSet::from([nested_attrs])]; let fields = struct_fields![ - ("employee_id", int8![]), - ("full_name", str![]), - ("salary", PartiqlShape::new(Static::DecimalP(8, 
2))), + ("employee_id", type_int8![]), + ("full_name", type_string![]), + ("salary", shape_builder().new_static(Static::DecimalP(8, 2))), ("details", details), - ("dependents", array![str![]]) + ("dependents", type_array![type_string![]]) ]; - let ty = bag![r#struct![IndexSet::from([ + let ty = type_bag![type_struct![IndexSet::from([ fields, StructConstraint::Open(false) ])]]; - let expected_compact = r#""employee_id" TINYINT,"full_name" VARCHAR,"salary" DECIMAL(8, 2),"details" STRUCT<"a": UNION,"b": ARRAY,"c": DOUBLE>,"dependents" ARRAY"#; + let expected_compact = r#""employee_id" TINYINT,"full_name" VARCHAR,"salary" DECIMAL(8, 2),"details" STRUCT<"a": UNION,"b": type_array,"c": DOUBLE>,"dependents" type_array"#; let expected_pretty = r#""employee_id" TINYINT, "full_name" VARCHAR, "salary" DECIMAL(8, 2), -"details" STRUCT<"a": UNION,"b": ARRAY,"c": DOUBLE>, -"dependents" ARRAY"#; +"details" STRUCT<"a": UNION,"b": type_array,"c": DOUBLE>, +"dependents" type_array"#; let ddl_compact = PartiqlBasicDdlEncoder::new(DdlFormat::Compact); assert_eq!(ddl_compact.ddl(&ty).expect("write shape"), expected_compact); diff --git a/extension/partiql-extension-ddl/tests/ddl-tests.rs b/extension/partiql-extension-ddl/tests/ddl-tests.rs index 87c64d03..635de09b 100644 --- a/extension/partiql-extension-ddl/tests/ddl-tests.rs +++ b/extension/partiql-extension-ddl/tests/ddl-tests.rs @@ -1,20 +1,26 @@ use indexmap::IndexSet; use partiql_extension_ddl::ddl::{DdlFormat, PartiqlBasicDdlEncoder, PartiqlDdlEncoder}; -use partiql_types::{bag, int, r#struct, str, struct_fields, StructConstraint, StructField}; -use partiql_types::{BagType, PartiqlShape, Static, StructType}; +use partiql_types::{ + shape_builder, struct_fields, type_bag, type_int, type_string, type_struct, StructConstraint, + StructField, +}; +use partiql_types::{BagType, Static, StructType}; #[test] fn basic_ddl_test() { - let details_fields = struct_fields![("age", int!())]; - let details = 
r#struct![IndexSet::from([details_fields])]; + let details_fields = struct_fields![("age", type_int!())]; + let details = type_struct![IndexSet::from([details_fields])]; let fields = [ - StructField::new("id", int!()), - StructField::new("name", str!()), - StructField::new("address", PartiqlShape::new_non_nullable(Static::String)), + StructField::new("id", type_int!()), + StructField::new("name", type_string!()), + StructField::new( + "address", + shape_builder().new_non_nullable_static(Static::String), + ), StructField::new_optional("details", details.clone()), ] .into(); - let shape = bag![r#struct![IndexSet::from([ + let shape = type_bag![type_struct![IndexSet::from([ StructConstraint::Fields(fields), StructConstraint::Open(false) ])]]; diff --git a/partiql-ast/src/builder.rs b/partiql-ast/src/builder.rs index b7973204..a2f3219e 100644 --- a/partiql-ast/src/builder.rs +++ b/partiql-ast/src/builder.rs @@ -17,7 +17,8 @@ where pub fn node(&mut self, node: T) -> AstNode { let id = self.id_gen.id(); - AstNode { id, node } + let id = id.read().expect("NodeId read lock"); + AstNode { id: *id, node } } } diff --git a/partiql-common/src/node.rs b/partiql-common/src/node.rs index 30a81159..9cc3be78 100644 --- a/partiql-common/src/node.rs +++ b/partiql-common/src/node.rs @@ -1,5 +1,6 @@ use indexmap::IndexMap; use std::hash::Hash; +use std::sync::{Arc, RwLock}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -12,27 +13,37 @@ pub struct NodeId(pub u32); /// Auto-incrementing [`NodeIdGenerator`] pub struct AutoNodeIdGenerator { - next_id: NodeId, + next_id: Arc>, } impl Default for AutoNodeIdGenerator { fn default() -> Self { - AutoNodeIdGenerator { next_id: NodeId(1) } + AutoNodeIdGenerator { + next_id: Arc::new(RwLock::from(NodeId(1))), + } + } +} + +impl AutoNodeIdGenerator { + pub fn next_id(&self) -> Arc> { + self.id() } } /// A provider of 'fresh' [`NodeId`]s. pub trait NodeIdGenerator { /// Provides a 'fresh' [`NodeId`]. 
- fn id(&mut self) -> NodeId; + fn id(&self) -> Arc>; } impl NodeIdGenerator for AutoNodeIdGenerator { #[inline] - fn id(&mut self) -> NodeId { - let mut next = NodeId(&self.next_id.0 + 1); - std::mem::swap(&mut self.next_id, &mut next); - next + fn id(&self) -> Arc> { + let id = &self.next_id.read().expect("NodeId read lock"); + let next = NodeId(id.0 + 1); + let mut w = self.next_id.write().expect("NodeId write lock"); + *w = next; + Arc::clone(&self.next_id) } } @@ -41,7 +52,7 @@ impl NodeIdGenerator for AutoNodeIdGenerator { pub struct NullIdGenerator {} impl NodeIdGenerator for NullIdGenerator { - fn id(&mut self) -> NodeId { - NodeId(0) + fn id(&self) -> Arc> { + Arc::new(RwLock::new(NodeId(0))) } } diff --git a/partiql-eval/src/eval/eval_expr_wrapper.rs b/partiql-eval/src/eval/eval_expr_wrapper.rs index 7a0a2faa..b9a21d4f 100644 --- a/partiql-eval/src/eval/eval_expr_wrapper.rs +++ b/partiql-eval/src/eval/eval_expr_wrapper.rs @@ -4,7 +4,7 @@ use crate::eval::expr::{BindError, EvalExpr}; use crate::eval::EvalContext; use itertools::Itertools; -use partiql_types::{PartiqlShape, Static, TYPE_DYNAMIC}; +use partiql_types::{type_dynamic, PartiqlShape, Static, TYPE_DYNAMIC}; use partiql_value::Value::{Missing, Null}; use partiql_value::{Tuple, Value}; @@ -413,7 +413,7 @@ impl UnaryValueExpr { where F: 'static + Fn(&Value) -> Value, { - Self::create_typed::([TYPE_DYNAMIC; 1], args, f) + Self::create_typed::([type_dynamic!(); 1], args, f) } #[allow(dead_code)] diff --git a/partiql-eval/src/eval/expr/coll.rs b/partiql-eval/src/eval/expr/coll.rs index d3361187..0af297ff 100644 --- a/partiql-eval/src/eval/expr/coll.rs +++ b/partiql-eval/src/eval/expr/coll.rs @@ -4,7 +4,9 @@ use crate::eval::expr::{BindError, BindEvalExpr, EvalExpr}; use itertools::{Itertools, Unique}; -use partiql_types::{ArrayType, BagType, PartiqlShape, Static, TYPE_BOOL, TYPE_NUMERIC_TYPES}; +use partiql_types::{ + shape_builder, type_bool, type_numeric, ArrayType, BagType, PartiqlShape, Static, +}; 
use partiql_value::Value::{Missing, Null}; use partiql_value::{BinaryAnd, BinaryOr, Value, ValueIter}; @@ -49,21 +51,21 @@ impl BindEvalExpr for EvalCollFn { value.sequence_iter().map_or(Missing, &f) }) } - let boolean_elems = [PartiqlShape::any_of([ - PartiqlShape::new(Static::Array(ArrayType::new(Box::new(TYPE_BOOL)))), - PartiqlShape::new(Static::Bag(BagType::new(Box::new(TYPE_BOOL)))), + let boolean_elems = [shape_builder().any_of([ + shape_builder().new_static(Static::Array(ArrayType::new(Box::new(type_bool!())))), + shape_builder().new_static(Static::Bag(BagType::new(Box::new(type_bool!())))), ])]; - let numeric_elems = [PartiqlShape::any_of([ - PartiqlShape::new(Static::Array(ArrayType::new(Box::new( - PartiqlShape::any_of(TYPE_NUMERIC_TYPES), + let numeric_elems = [shape_builder().any_of([ + shape_builder().new_static(Static::Array(ArrayType::new(Box::new( + shape_builder().any_of(type_numeric!()), + )))), + shape_builder().new_static(Static::Bag(BagType::new(Box::new( + shape_builder().any_of(type_numeric!()), )))), - PartiqlShape::new(Static::Bag(BagType::new(Box::new(PartiqlShape::any_of( - TYPE_NUMERIC_TYPES, - ))))), ])]; - let any_elems = [PartiqlShape::any_of([ - PartiqlShape::new(Static::Array(ArrayType::new_any())), - PartiqlShape::new(Static::Bag(BagType::new_any())), + let any_elems = [shape_builder().any_of([ + shape_builder().new_static(Static::Array(ArrayType::new_any())), + shape_builder().new_static(Static::Bag(BagType::new_any())), ])]; match *self { diff --git a/partiql-eval/src/eval/expr/datetime.rs b/partiql-eval/src/eval/expr/datetime.rs index 0b140bff..c90dd459 100644 --- a/partiql-eval/src/eval/expr/datetime.rs +++ b/partiql-eval/src/eval/expr/datetime.rs @@ -1,6 +1,6 @@ use crate::eval::expr::{BindError, BindEvalExpr, EvalExpr}; -use partiql_types::TYPE_DATETIME; +use partiql_types::type_datetime; use partiql_value::Value::Missing; use partiql_value::{DateTime, Value}; @@ -43,7 +43,7 @@ impl BindEvalExpr for EvalExtractFn { } let 
create = |f: fn(&DateTime) -> Value| { - UnaryValueExpr::create_typed::<{ STRICT }, _>([TYPE_DATETIME], args, move |value| { + UnaryValueExpr::create_typed::<{ STRICT }, _>([type_datetime!()], args, move |value| { match value { Value::DateTime(dt) => f(dt.as_ref()), _ => Missing, diff --git a/partiql-eval/src/eval/expr/operators.rs b/partiql-eval/src/eval/expr/operators.rs index 76f067b5..bae68bc4 100644 --- a/partiql-eval/src/eval/expr/operators.rs +++ b/partiql-eval/src/eval/expr/operators.rs @@ -8,8 +8,8 @@ use crate::eval::expr::{BindError, BindEvalExpr, EvalExpr}; use crate::eval::EvalContext; use partiql_types::{ - ArrayType, BagType, PartiqlShape, Static, StructType, TYPE_BOOL, TYPE_DYNAMIC, - TYPE_NUMERIC_TYPES, + shape_builder, type_bool, type_dynamic, type_numeric, ArrayType, BagType, PartiqlShape, Static, + StructType, }; use partiql_value::Value::{Boolean, Missing, Null}; use partiql_value::{BinaryAnd, EqualityValue, NullableEq, NullableOrd, Tuple, Value}; @@ -80,7 +80,7 @@ impl BindEvalExpr for EvalOpUnary { &self, args: Vec>, ) -> Result, BindError> { - let any_num = PartiqlShape::any_of(TYPE_NUMERIC_TYPES); + let any_num = shape_builder().any_of(type_numeric!()); let unop = |types, f: fn(&Value) -> Value| { UnaryValueExpr::create_typed::<{ STRICT }, _>(types, args, f) @@ -89,7 +89,7 @@ impl BindEvalExpr for EvalOpUnary { match self { EvalOpUnary::Pos => unop([any_num], std::clone::Clone::clone), EvalOpUnary::Neg => unop([any_num], |operand| -operand), - EvalOpUnary::Not => unop([TYPE_BOOL], |operand| !operand), + EvalOpUnary::Not => unop([type_bool!()], |operand| !operand), } } } @@ -167,19 +167,19 @@ impl BindEvalExpr for EvalOpBinary { macro_rules! logical { ($check: ty, $f:expr) => { - create!($check, [TYPE_BOOL, TYPE_BOOL], $f) + create!($check, [type_bool!(), type_bool!()], $f) }; } macro_rules! 
equality { ($f:expr) => { - create!(EqCheck, [TYPE_DYNAMIC, TYPE_DYNAMIC], $f) + create!(EqCheck, [type_dynamic!(), type_dynamic!()], $f) }; } macro_rules! math { ($f:expr) => {{ - let nums = PartiqlShape::any_of(TYPE_NUMERIC_TYPES); + let nums = shape_builder().any_of(type_numeric!()); create!(MathCheck, [nums.clone(), nums], $f) }}; } @@ -209,10 +209,10 @@ impl BindEvalExpr for EvalOpBinary { create!( InCheck, [ - TYPE_DYNAMIC, - PartiqlShape::any_of([ - PartiqlShape::new(Static::Array(ArrayType::new_any())), - PartiqlShape::new(Static::Bag(BagType::new_any())), + type_dynamic!(), + shape_builder().any_of([ + shape_builder().new_static(Static::Array(ArrayType::new_any())), + shape_builder().new_static(Static::Bag(BagType::new_any())), ]) ], |lhs, rhs| { @@ -250,20 +250,24 @@ impl BindEvalExpr for EvalOpBinary { ) } EvalOpBinary::Concat => { - create!(Check, [TYPE_DYNAMIC, TYPE_DYNAMIC], |lhs, rhs| { - // TODO non-naive concat (i.e., don't just use debug print for non-strings). - let lhs = if let Value::String(s) = lhs { - s.as_ref().clone() - } else { - format!("{lhs:?}") - }; - let rhs = if let Value::String(s) = rhs { - s.as_ref().clone() - } else { - format!("{rhs:?}") - }; - Value::String(Box::new(format!("{lhs}{rhs}"))) - }) + create!( + Check, + [type_dynamic!(), type_dynamic!()], + |lhs, rhs| { + // TODO non-naive concat (i.e., don't just use debug print for non-strings). 
+ let lhs = if let Value::String(s) = lhs { + s.as_ref().clone() + } else { + format!("{lhs:?}") + }; + let rhs = if let Value::String(s) = rhs { + s.as_ref().clone() + } else { + format!("{rhs:?}") + }; + Value::String(Box::new(format!("{lhs}{rhs}"))) + } + ) } } } @@ -278,7 +282,7 @@ impl BindEvalExpr for EvalBetweenExpr { &self, args: Vec>, ) -> Result, BindError> { - let types = [TYPE_DYNAMIC, TYPE_DYNAMIC, TYPE_DYNAMIC]; + let types = [type_dynamic!(), type_dynamic!(), type_dynamic!()]; TernaryValueExpr::create_checked::<{ STRICT }, NullArgChecker, _>( types, args, @@ -316,7 +320,7 @@ impl BindEvalExpr for EvalFnAbs { &self, args: Vec>, ) -> Result, BindError> { - let nums = PartiqlShape::any_of(TYPE_NUMERIC_TYPES); + let nums = shape_builder().any_of(type_numeric!()); UnaryValueExpr::create_typed::<{ STRICT }, _>([nums], args, |v| { match NullableOrd::lt(v, &Value::from(0)) { Null => Null, @@ -337,10 +341,10 @@ impl BindEvalExpr for EvalFnCardinality { &self, args: Vec>, ) -> Result, BindError> { - let collections = PartiqlShape::any_of([ - PartiqlShape::new(Static::Array(ArrayType::new_any())), - PartiqlShape::new(Static::Bag(BagType::new_any())), - PartiqlShape::new(Static::Struct(StructType::new_any())), + let collections = shape_builder().any_of([ + shape_builder().new_static(Static::Array(ArrayType::new_any())), + shape_builder().new_static(Static::Bag(BagType::new_any())), + shape_builder().new_static(Static::Struct(StructType::new_any())), ]); UnaryValueExpr::create_typed::<{ STRICT }, _>([collections], args, |v| match v { diff --git a/partiql-eval/src/eval/expr/pattern_match.rs b/partiql-eval/src/eval/expr/pattern_match.rs index 056210a0..a6004e5b 100644 --- a/partiql-eval/src/eval/expr/pattern_match.rs +++ b/partiql-eval/src/eval/expr/pattern_match.rs @@ -2,7 +2,7 @@ use crate::error::PlanningError; use crate::eval::eval_expr_wrapper::{TernaryValueExpr, UnaryValueExpr}; use crate::eval::expr::{BindError, BindEvalExpr, EvalExpr}; -use 
partiql_types::TYPE_STRING; +use partiql_types::type_string; use partiql_value::Value; use partiql_value::Value::Missing; use regex::{Regex, RegexBuilder}; @@ -47,10 +47,11 @@ impl BindEvalExpr for EvalLikeMatch { args: Vec>, ) -> Result, BindError> { let pattern = self.pattern.clone(); - UnaryValueExpr::create_typed::<{ STRICT }, _>([TYPE_STRING], args, move |value| match value - { - Value::String(s) => Value::Boolean(pattern.is_match(s.as_ref())), - _ => Missing, + UnaryValueExpr::create_typed::<{ STRICT }, _>([type_string!()], args, move |value| { + match value { + Value::String(s) => Value::Boolean(pattern.is_match(s.as_ref())), + _ => Missing, + } }) } } @@ -65,7 +66,7 @@ impl BindEvalExpr for EvalLikeNonStringNonLiteralMatch { &self, args: Vec>, ) -> Result, BindError> { - let types = [TYPE_STRING, TYPE_STRING, TYPE_STRING]; + let types = [type_string!(), type_string!(), type_string!()]; TernaryValueExpr::create_typed::<{ STRICT }, _>( types, args, diff --git a/partiql-eval/src/eval/expr/strings.rs b/partiql-eval/src/eval/expr/strings.rs index 2d5f4c44..8e0d50a9 100644 --- a/partiql-eval/src/eval/expr/strings.rs +++ b/partiql-eval/src/eval/expr/strings.rs @@ -7,7 +7,7 @@ use crate::eval::expr::{BindError, BindEvalExpr, EvalExpr}; use crate::eval::EvalContext; use itertools::Itertools; -use partiql_types::{TYPE_INT, TYPE_STRING}; +use partiql_types::{type_int, type_string}; use partiql_value::Value; use partiql_value::Value::Missing; @@ -43,7 +43,7 @@ impl BindEvalExpr for EvalStringFn { F: Fn(&Box) -> R + 'static, R: Into + 'static, { - UnaryValueExpr::create_typed::<{ STRICT }, _>([TYPE_STRING], args, move |value| { + UnaryValueExpr::create_typed::<{ STRICT }, _>([type_string!()], args, move |value| { match value { Value::String(value) => (f(value)).into(), _ => Missing, @@ -99,7 +99,7 @@ impl BindEvalExpr for EvalTrimFn { ) -> Result, BindError> { let create = |f: for<'a> fn(&'a str, &'a str) -> &'a str| { BinaryValueExpr::create_typed::<{ STRICT }, _>( - 
[TYPE_STRING, TYPE_STRING], + [type_string!(), type_string!()], args, move |to_trim, value| match (to_trim, value) { (Value::String(to_trim), Value::String(value)) => { @@ -136,7 +136,7 @@ impl BindEvalExpr for EvalFnPosition { args: Vec>, ) -> Result, BindError> { BinaryValueExpr::create_typed::( - [TYPE_STRING, TYPE_STRING], + [type_string!(), type_string!()], args, |needle, haystack| match (needle, haystack) { (Value::String(needle), Value::String(haystack)) => { @@ -159,7 +159,7 @@ impl BindEvalExpr for EvalFnSubstring { ) -> Result, BindError> { match args.len() { 2 => BinaryValueExpr::create_typed::( - [TYPE_STRING, TYPE_INT], + [type_string!(), type_int!()], args, |value, offset| match (value, offset) { (Value::String(value), Value::Integer(offset)) => { @@ -171,7 +171,7 @@ impl BindEvalExpr for EvalFnSubstring { }, ), 3 => TernaryValueExpr::create_typed::( - [TYPE_STRING, TYPE_INT, TYPE_INT], + [type_string!(), type_int!(), type_int!()], args, |value, offset, length| match (value, offset, length) { (Value::String(value), Value::Integer(offset), Value::Integer(length)) => { @@ -222,7 +222,7 @@ impl BindEvalExpr for EvalFnOverlay { match args.len() { 3 => TernaryValueExpr::create_typed::( - [TYPE_STRING, TYPE_STRING, TYPE_INT], + [type_string!(), type_string!(), type_int!()], args, |value, replacement, offset| match (value, replacement, offset) { (Value::String(value), Value::String(replacement), Value::Integer(offset)) => { @@ -233,7 +233,7 @@ impl BindEvalExpr for EvalFnOverlay { }, ), 4 => QuaternaryValueExpr::create_typed::( - [TYPE_STRING, TYPE_STRING, TYPE_INT, TYPE_INT], + [type_string!(), type_string!(), type_int!(), type_int!()], args, |value, replacement, offset, length| match (value, replacement, offset, length) { ( diff --git a/partiql-logical-planner/src/lower.rs b/partiql-logical-planner/src/lower.rs index adcb5867..e23c8566 100644 --- a/partiql-logical-planner/src/lower.rs +++ b/partiql-logical-planner/src/lower.rs @@ -2007,7 +2007,7 @@ mod 
tests { use partiql_catalog::{PartiqlCatalog, TypeEnvEntry}; use partiql_logical::BindingsOp::Project; use partiql_logical::ValueExpr; - use partiql_types::dynamic; + use partiql_types::type_dynamic; #[test] fn test_plan_non_existent_fns() { @@ -2107,7 +2107,7 @@ mod tests { expected_logical.add_flow_with_branch_num(project, sink, 0); let mut catalog = PartiqlCatalog::default(); - let _oid = catalog.add_type_entry(TypeEnvEntry::new("customers", &[], dynamic!())); + let _oid = catalog.add_type_entry(TypeEnvEntry::new("customers", &[], type_dynamic!())); let statement = "SELECT c.id AS my_id, customers.name AS my_name FROM customers AS c"; let parsed = partiql_parser::Parser::default() .parse(statement) diff --git a/partiql-logical-planner/src/typer.rs b/partiql-logical-planner/src/typer.rs index 8c78800d..657ada2a 100644 --- a/partiql-logical-planner/src/typer.rs +++ b/partiql-logical-planner/src/typer.rs @@ -4,8 +4,8 @@ use partiql_ast::ast::{CaseSensitivity, SymbolPrimitive}; use partiql_catalog::Catalog; use partiql_logical::{BindingsOp, LogicalPlan, OpId, PathComponent, ValueExpr, VarRefType}; use partiql_types::{ - dynamic, undefined, ArrayType, BagType, PartiqlShape, ShapeResultError, Static, - StructConstraint, StructField, StructType, + shape_builder, type_dynamic, type_undefined, ArrayType, BagType, PartiqlShape, + ShapeResultError, Static, StructConstraint, StructField, StructType, }; use partiql_value::{BindingsName, Value}; use petgraph::algo::toposort; @@ -107,7 +107,7 @@ impl Default for TypeEnvContext { fn default() -> Self { TypeEnvContext { env: LocalTypeEnv::new(), - derived_type: dynamic!(), + derived_type: type_dynamic!(), } } } @@ -175,7 +175,7 @@ impl<'c> PlanTyper<'c> { } if self.errors.is_empty() { - Ok(self.output.clone().unwrap_or(undefined!())) + Ok(self.output.clone().unwrap_or(type_undefined!())) } else { let output_schema = self.get_singleton_type_from_env(); Err(TypeErr { @@ -218,16 +218,16 @@ impl<'c> PlanTyper<'c> { 
StructField::new(k.as_str(), self.get_singleton_type_from_env()) }); - let ty = PartiqlShape::new_struct(StructType::new(IndexSet::from([ + let ty = shape_builder().new_struct(StructType::new(IndexSet::from([ StructConstraint::Fields(fields.collect()), ]))); let derived_type_ctx = self.local_type_ctx(); let derived_type = &self.derived_type(&derived_type_ctx); let schema = if derived_type.is_ordered_collection() { - PartiqlShape::new_array(ArrayType::new(Box::new(ty))) + shape_builder().new_array(ArrayType::new(Box::new(ty))) } else if derived_type.is_unordered_collection() { - PartiqlShape::new_bag(BagType::new(Box::new(ty))) + shape_builder().new_static(Static::Bag(BagType::new(Box::new(ty)))) } else { self.errors.push(TypingError::IllegalState(format!( "Expecting Collection for the output Schema but found {:?}", @@ -304,8 +304,10 @@ impl<'c> PlanTyper<'c> { let ctx = ty_ctx![(&ty_env![(key_as_sym, ty.clone())], &ty)]; self.type_env_stack.push(ctx); } else { - let ctx = - ty_ctx![(&ty_env![(key_as_sym, undefined!())], &undefined!())]; + let ctx = ty_ctx![( + &ty_env![(key_as_sym, type_undefined!())], + &type_undefined!() + )]; self.type_env_stack.push(ctx); } } @@ -329,20 +331,24 @@ impl<'c> PlanTyper<'c> { } ValueExpr::Lit(v) => { let ty = match **v { - Value::Null => PartiqlShape::Undefined, - Value::Missing => PartiqlShape::Undefined, - Value::Integer(_) => PartiqlShape::new(Static::Int), - Value::Decimal(_) => PartiqlShape::new(Static::Decimal), - Value::Boolean(_) => PartiqlShape::new(Static::Bool), - Value::String(_) => PartiqlShape::new(Static::String), - Value::Tuple(_) => PartiqlShape::new(Static::Struct(StructType::new_any())), - Value::List(_) => PartiqlShape::new(Static::Array(ArrayType::new_any())), - Value::Bag(_) => PartiqlShape::new(Static::Bag(BagType::new_any())), + Value::Null => shape_builder().new_undefined(), + Value::Missing => shape_builder().new_undefined(), + Value::Integer(_) => shape_builder().new_static(Static::Int), + 
Value::Decimal(_) => shape_builder().new_static(Static::Decimal), + Value::Boolean(_) => shape_builder().new_static(Static::Bool), + Value::String(_) => shape_builder().new_static(Static::String), + Value::Tuple(_) => { + shape_builder().new_static(Static::Struct(StructType::new_any())) + } + Value::List(_) => { + shape_builder().new_static(Static::Array(ArrayType::new_any())) + } + Value::Bag(_) => shape_builder().new_static(Static::Bag(BagType::new_any())), _ => { self.errors.push(TypingError::NotYetImplemented( "Unsupported Literal".to_string(), )); - PartiqlShape::Undefined + shape_builder().new_undefined() } }; @@ -410,14 +416,14 @@ impl<'c> PlanTyper<'c> { fn element_type<'a>(&'a mut self, ty: &'a PartiqlShape) -> PartiqlShape { match ty { - PartiqlShape::Dynamic => dynamic!(), + PartiqlShape::Dynamic => type_dynamic!(), PartiqlShape::Static(s) => match s.ty() { Static::Bag(b) => b.element_type().clone(), Static::Array(a) => a.element_type().clone(), _ => ty.clone(), }, - undefined!() => { - todo!("Undefined type in catalog") + type_undefined!() => { + todo!("type_undefined type in catalog") } PartiqlShape::AnyOf(_any_of) => ty.clone(), } @@ -432,10 +438,10 @@ impl<'c> PlanTyper<'c> { Some(ty.clone()) } else if let Ok(s) = derived_type.expect_struct() { if s.is_partial() { - Some(dynamic!()) + Some(type_dynamic!()) } else { match &self.typing_mode { - TypingMode::Permissive => Some(undefined!()), + TypingMode::Permissive => Some(type_undefined!()), TypingMode::Strict => { self.errors.push(TypingError::TypeCheck(format!( "No Typing Information for {:?} in closed Schema {:?}", @@ -446,7 +452,7 @@ impl<'c> PlanTyper<'c> { } } } else if derived_type.is_dynamic() { - Some(dynamic!()) + Some(type_dynamic!()) } else { self.errors.push(TypingError::IllegalState(format!( "Illegal Derive Type {:?}", @@ -510,7 +516,7 @@ impl<'c> PlanTyper<'c> { let ty = self.element_type(type_entry.ty()); ty } else { - undefined!() + type_undefined!() } } @@ -521,14 +527,17 @@ impl<'c> 
PlanTyper<'c> { } } - undefined!() + type_undefined!() } - fn type_with_undefined(&mut self, key: &SymbolPrimitive) { + fn type_with_type_undefined(&mut self, key: &SymbolPrimitive) { if let TypingMode::Permissive = &self.typing_mode { // TODO Revise this once the following discussion is conclusive and spec. is // in place: https://github.com/partiql/partiql-spec/discussions/64 - let type_ctx = ty_ctx![(&ty_env![(key.clone(), undefined!())], &undefined!())]; + let type_ctx = ty_ctx![( + &ty_env![(key.clone(), type_undefined!())], + &type_undefined!() + )]; self.type_env_stack.push(type_ctx); } @@ -544,7 +553,7 @@ impl<'c> PlanTyper<'c> { "Unexpected Typing Environment; expected typing environment with only one type but found {:?} types", &env.len() ))); - undefined!() + type_undefined!() } else { env[0].clone() } @@ -552,7 +561,7 @@ impl<'c> PlanTyper<'c> { fn type_varef(&mut self, key: &SymbolPrimitive, ty: &PartiqlShape) { if ty.is_undefined() { - self.type_with_undefined(key); + self.type_with_type_undefined(key); } else { let mut new_type_env = LocalTypeEnv::new(); if let Ok(s) = ty.expect_struct() { @@ -598,7 +607,10 @@ mod tests { use partiql_ast_passes::error::AstTransformationError; use partiql_catalog::{PartiqlCatalog, TypeEnvEntry}; use partiql_parser::{Parsed, Parser}; - use partiql_types::{bag, int, r#struct, str, struct_fields, BagType, StructType}; + use partiql_types::{ + struct_fields, type_bag, type_int_with_const_id, type_string_with_const_id, + type_struct_with_const_id, type_undefined, BagType, StructType, + }; #[test] fn simple_sfw() { @@ -609,15 +621,15 @@ mod tests { create_customer_schema( false, [ - StructField::new("id", int!()), - StructField::new("name", str!()), - StructField::new("age", dynamic!()), + StructField::new("id", type_int_with_const_id!()), + StructField::new("name", type_string_with_const_id!()), + StructField::new("age", type_dynamic!()), ] .into(), ), vec![ - StructField::new("id", int!()), - StructField::new("name", 
str!()), + StructField::new("id", type_int_with_const_id!()), + StructField::new("name", type_string_with_const_id!()), ], ) .expect("Type"); @@ -629,15 +641,15 @@ mod tests { create_customer_schema( false, [ - StructField::new("id", int!()), - StructField::new("name", str!()), - StructField::new("age", dynamic!()), + StructField::new("id", type_int_with_const_id!()), + StructField::new("name", type_string_with_const_id!()), + StructField::new("age", type_dynamic!()), ] .into(), ), vec![ - StructField::new("id", int!()), - StructField::new("name", str!()), + StructField::new("id", type_int_with_const_id!()), + StructField::new("name", type_string_with_const_id!()), ], ) .expect("Type"); @@ -649,16 +661,16 @@ mod tests { create_customer_schema( true, [ - StructField::new("id", int!()), - StructField::new("name", str!()), - StructField::new("age", dynamic!()), + StructField::new("id", type_int_with_const_id!()), + StructField::new("name", type_string_with_const_id!()), + StructField::new("age", type_dynamic!()), ] .into(), ), vec![ - StructField::new("id", int!()), - StructField::new("name", str!()), - StructField::new("age", dynamic!()), + StructField::new("id", type_int_with_const_id!()), + StructField::new("name", type_string_with_const_id!()), + StructField::new("age", type_dynamic!()), ], ) .expect("Type"); @@ -670,22 +682,22 @@ mod tests { create_customer_schema( false, [ - StructField::new("id", int!()), - StructField::new("name", str!()), + StructField::new("id", type_int_with_const_id!()), + StructField::new("name", type_string_with_const_id!()), ] .into(), ), vec![ - StructField::new("id", int!()), - StructField::new("name", str!()), - StructField::new("age", undefined!()), + StructField::new("id", type_int_with_const_id!()), + StructField::new("name", type_string_with_const_id!()), + StructField::new("age", type_undefined!()), ], ) .expect("Type"); // Open Schema with `Strict` typing mode and `age` in nested attribute. 
- let details_fields = struct_fields![("age", int!())]; - let details = r#struct![IndexSet::from([details_fields])]; + let details_fields = struct_fields![("age", type_int_with_const_id!())]; + let details = type_struct_with_const_id![IndexSet::from([details_fields])]; assert_query_typing( TypingMode::Strict, @@ -693,16 +705,16 @@ mod tests { create_customer_schema( true, [ - StructField::new("id", int!()), - StructField::new("name", str!()), + StructField::new("id", type_int_with_const_id!()), + StructField::new("name", type_string_with_const_id!()), StructField::new("details", details.clone()), ] .into(), ), vec![ - StructField::new("id", int!()), - StructField::new("name", str!()), - StructField::new("age", int!()), + StructField::new("id", type_int_with_const_id!()), + StructField::new("name", type_string_with_const_id!()), + StructField::new("age", type_int_with_const_id!()), ], ) .expect("Type"); @@ -712,15 +724,15 @@ mod tests { TypingMode::Strict, "SELECT customers.id, customers.name, customers.details.age, customers.details.foo.bar FROM customers", create_customer_schema(true, [ - StructField::new("id", int!()), - StructField::new("name", str!()), + StructField::new("id", type_int_with_const_id!()), + StructField::new("name", type_string_with_const_id!()), StructField::new("details", details.clone()), ].into()), vec![ - StructField::new("id", int!()), - StructField::new("name", str!()), - StructField::new("age", int!()), - StructField::new("bar", dynamic!()), + StructField::new("id", type_int_with_const_id!()), + StructField::new("name", type_string_with_const_id!()), + StructField::new("age", type_int_with_const_id!()), + StructField::new("bar", type_dynamic!()), ], ) .expect("Type"); @@ -729,8 +741,8 @@ mod tests { #[test] fn simple_sfw_with_alias() { // Open Schema with `Strict` typing mode and `age` in nested attribute. 
- let details_fields = struct_fields![("age", int!())]; - let details = r#struct![IndexSet::from([details_fields])]; + let details_fields = struct_fields![("age", type_int_with_const_id!())]; + let details = type_struct_with_const_id![IndexSet::from([details_fields])]; // TODO Revise this behavior once the following discussion is conclusive and spec. is // in place: https://github.com/partiql/partiql-spec/discussions/65 @@ -740,13 +752,13 @@ mod tests { create_customer_schema( true, [ - StructField::new("id", int!()), - StructField::new("name", str!()), + StructField::new("id", type_int_with_const_id!()), + StructField::new("name", type_string_with_const_id!()), StructField::new("details", details.clone()), ] .into(), ), - vec![StructField::new("age", int!())], + vec![StructField::new("age", type_int_with_const_id!())], ) .expect("Type"); @@ -757,15 +769,15 @@ mod tests { create_customer_schema( false, [ - StructField::new("id", int!()), - StructField::new("name", str!()), - StructField::new("age", dynamic!()), + StructField::new("id", type_int_with_const_id!()), + StructField::new("name", type_string_with_const_id!()), + StructField::new("age", type_dynamic!()), ] .into(), ), vec![ - StructField::new("my_id", int!()), - StructField::new("my_name", str!()), + StructField::new("my_id", type_int_with_const_id!()), + StructField::new("my_name", type_string_with_const_id!()), ], ) .expect("Type"); @@ -774,7 +786,7 @@ mod tests { #[test] fn simple_sfw_err() { // Closed Schema with `Strict` typing mode and `age` non-existent projection. 
- let err1 = r#"No Typing Information for SymbolPrimitive { value: "age", case: CaseInsensitive } in closed Schema Static(StaticType { ty: Struct(StructType { constraints: {Fields({StructField { optional: false, name: "id", ty: Static(StaticType { ty: Int, nullable: true }) }, StructField { optional: false, name: "name", ty: Static(StaticType { ty: String, nullable: true }) }}), Open(false)} }), nullable: true })"#; + let err1 = r#"No Typing Information for SymbolPrimitive { value: "age", case: CaseInsensitive } in closed Schema Static(StaticType { id: NodeId(1), ty: Struct(StructType { constraints: {Fields({StructField { optional: false, name: "id", ty: Static(StaticType { id: NodeId(1), ty: Int, nullable: true }) }, StructField { optional: false, name: "name", ty: Static(StaticType { id: NodeId(1), ty: String, nullable: true }) }}), Open(false)} }), nullable: true })"#; assert_err( assert_query_typing( @@ -783,32 +795,24 @@ mod tests { create_customer_schema( false, [ - StructField::new("id", int!()), - StructField::new("name", str!()), + StructField::new("id", type_int_with_const_id!()), + StructField::new("name", type_string_with_const_id!()), ] .into(), ), vec![], ), vec![TypingError::TypeCheck(err1.to_string())], - Some(bag![r#struct![IndexSet::from([StructConstraint::Fields( - [ - StructField::new("id", int!()), - StructField::new("name", str!()), - StructField::new("age", undefined!()), - ] - .into() - ),])]]), ); // Closed Schema with `Strict` typing mode and `bar` non-existent projection from closed nested `details`. 
- let details_fields = struct_fields![("age", int!())]; - let details = r#struct![IndexSet::from([ + let details_fields = struct_fields![("age", type_int_with_const_id!())]; + let details = type_struct_with_const_id![IndexSet::from([ details_fields, StructConstraint::Open(false) ])]; - let err1 = r#"No Typing Information for SymbolPrimitive { value: "details", case: CaseInsensitive } in closed Schema Static(StaticType { ty: Struct(StructType { constraints: {Fields({StructField { optional: false, name: "age", ty: Static(StaticType { ty: Int, nullable: true }) }}), Open(false)} }), nullable: true })"#; + let err1 = r#"No Typing Information for SymbolPrimitive { value: "details", case: CaseInsensitive } in closed Schema Static(StaticType { id: NodeId(1), ty: Struct(StructType { constraints: {Fields({StructField { optional: false, name: "age", ty: Static(StaticType { id: NodeId(1), ty: Int, nullable: true }) }}), Open(false)} }), nullable: true })"#; let err2 = r"Illegal Derive Type Undefined"; assert_err( @@ -818,8 +822,8 @@ mod tests { create_customer_schema( false, [ - StructField::new("id", int!()), - StructField::new("name", str!()), + StructField::new("id", type_int_with_const_id!()), + StructField::new("name", type_string_with_const_id!()), StructField::new("details", details), ] .into(), @@ -830,40 +834,22 @@ mod tests { TypingError::TypeCheck(err1.to_string()), TypingError::IllegalState(err2.to_string()), ], - Some(bag![r#struct![IndexSet::from([StructConstraint::Fields( - [ - StructField::new("id", int!()), - StructField::new("name", str!()), - StructField::new("bar", undefined!()), - ] - .into() - ),])]]), ); } - fn assert_err( - result: Result<(), TypeErr>, - expected_errors: Vec, - output: Option, - ) { + fn assert_err(result: Result<(), TypeErr>, expected_errors: Vec) { match result { Ok(()) => { panic!("Expected Error"); } Err(e) => { - assert_eq!( - e, - TypeErr { - errors: expected_errors, - output, - } - ); + assert_eq!(e.errors, expected_errors); } 
}; } fn create_customer_schema(is_open: bool, fields: IndexSet) -> PartiqlShape { - bag![r#struct![IndexSet::from([ + type_bag![type_struct_with_const_id![IndexSet::from([ StructConstraint::Fields(fields), StructConstraint::Open(is_open) ])]] diff --git a/partiql-types/Cargo.toml b/partiql-types/Cargo.toml index 84eed7b5..21645c5c 100644 --- a/partiql-types/Cargo.toml +++ b/partiql-types/Cargo.toml @@ -21,7 +21,7 @@ edition.workspace = true bench = false [dependencies] - +partiql-common = { path = "../partiql-common", version = "0.10.*"} ordered-float = "3.*" itertools = "0.10.*" unicase = "2.6" @@ -32,6 +32,7 @@ thiserror = "1.*" indexmap = "2.2" derivative = "2.2" +lazy_static = "1.5.0" [dev-dependencies] criterion = "0.4" diff --git a/partiql-types/src/lib.rs b/partiql-types/src/lib.rs index 6a0b3616..564462dc 100644 --- a/partiql-types/src/lib.rs +++ b/partiql-types/src/lib.rs @@ -5,10 +5,19 @@ use derivative::Derivative; use indexmap::IndexSet; use itertools::Itertools; use miette::Diagnostic; +use partiql_common::node::{AutoNodeIdGenerator, NodeId, NodeIdGenerator}; +use std::collections::HashMap; use std::fmt::{Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; +use std::sync::OnceLock; use thiserror::Error; +#[track_caller] +pub fn shape_builder() -> &'static PartiqlShapeBuilder { + static SHAPE_BUILDER: OnceLock = OnceLock::new(); + SHAPE_BUILDER.get_or_init(PartiqlShapeBuilder::default) +} + #[derive(Debug, Clone, Eq, PartialEq, Hash, Error, Diagnostic)] #[error("ShapeResult Error")] #[non_exhaustive] @@ -35,84 +44,110 @@ where } #[macro_export] -macro_rules! dynamic { +macro_rules! type_dynamic { () => { - $crate::PartiqlShape::Dynamic + $crate::shape_builder().new_dynamic() }; } #[macro_export] -macro_rules! int { +macro_rules! type_int { () => { - $crate::PartiqlShape::new($crate::Static::Int) + $crate::shape_builder().new_static($crate::Static::Int) }; } #[macro_export] -macro_rules! int8 { +macro_rules! 
type_int8 { () => { - $crate::PartiqlShape::new($crate::Static::Int8) + $crate::shape_builder().new_static($crate::Static::Int8) }; } #[macro_export] -macro_rules! int16 { +macro_rules! type_int16 { () => { - $crate::PartiqlShape::new($crate::Static::Int16) + $crate::shape_builder().new_static($crate::Static::Int16) }; } #[macro_export] -macro_rules! int32 { +macro_rules! type_int32 { () => { - $crate::PartiqlShape::new($crate::Static::Int32) + $crate::shape_builder().new_static($crate::Static::Int32) }; } #[macro_export] -macro_rules! int64 { +macro_rules! type_int64 { () => { - $crate::PartiqlShape::new($crate::Static::Int64) + $crate::shape_builder().new_static($crate::Static::Int64) }; } #[macro_export] -macro_rules! dec { +macro_rules! type_decimal { () => { - $crate::PartiqlShape::new($crate::Static::Decimal) + $crate::shape_builder().new_static($crate::Static::Decimal) }; } // TODO add macro_rule for Decimal with precision and scale #[macro_export] -macro_rules! f32 { +macro_rules! type_float32 { + () => { + $crate::shape_builder().new_static($crate::Static::Float32) + }; +} + +#[macro_export] +macro_rules! type_float64 { () => { - $crate::PartiqlShape::new($crate::Static::Float32) + $crate::shape_builder().new_static($crate::Static::Float64) }; } #[macro_export] -macro_rules! f64 { +macro_rules! type_string { () => { - $crate::PartiqlShape::new($crate::Static::Float64) + $crate::shape_builder().new_static($crate::Static::String) }; } #[macro_export] -macro_rules! str { +macro_rules! type_bool { () => { - $crate::PartiqlShape::new($crate::Static::String) + $crate::shape_builder().new_static($crate::Static::Bool) }; } #[macro_export] -macro_rules! r#struct { +macro_rules! 
type_numeric { () => { - $crate::PartiqlShape::new_struct(StructType::new_any()) + [ + $crate::shape_builder().new_static($crate::Static::Int), + $crate::shape_builder().new_static($crate::Static::Float32), + $crate::shape_builder().new_static($crate::Static::Float64), + $crate::shape_builder().new_static($crate::Static::Decimal), + ] + }; +} + +#[macro_export] +macro_rules! type_datetime { + () => { + $crate::shape_builder().new_static($crate::Static::DateTime) + }; +} + +#[macro_export] +macro_rules! type_struct { + () => { + $crate::shape_builder().new_struct(StructType::new_any()) }; ($elem:expr) => { - $crate::PartiqlShape::new_struct(StructType::new($elem)) + $crate::shape_builder().new_struct(StructType::new($elem)) }; } @@ -124,32 +159,66 @@ macro_rules! struct_fields { } #[macro_export] -macro_rules! r#bag { +macro_rules! type_bag { () => { - $crate::PartiqlShape::new_bag(BagType::new_any()); + $crate::shape_builder().new_bag(BagType::new_any()); }; ($elem:expr) => { - $crate::PartiqlShape::new_bag(BagType::new(Box::new($elem))) + $crate::shape_builder().new_bag(BagType::new(Box::new($elem))) }; } #[macro_export] -macro_rules! r#array { +macro_rules! type_array { () => { - $crate::PartiqlShape::new_array(ArrayType::new_any()); + $crate::shape_builder().new_array(ArrayType::new_any()); }; ($elem:expr) => { - $crate::PartiqlShape::new_array(ArrayType::new(Box::new($elem))) + $crate::shape_builder().new_array(ArrayType::new(Box::new($elem))) }; } #[macro_export] -macro_rules! undefined { +macro_rules! type_undefined { () => { $crate::PartiqlShape::Undefined }; } +// Types with constant `NodeId`, e.g., `NodeId(1)` convenient for testing or use-cases with no +// requirement for unique node ids. + +#[macro_export] +macro_rules! type_int_with_const_id { + () => { + $crate::shape_builder().new_static_with_const_id($crate::Static::Int) + }; +} + +#[macro_export] +macro_rules! 
type_float32_with_const_id { + () => { + $crate::shape_builder().new_static_with_const_id($crate::Static::Float32) + }; +} + +#[macro_export] +macro_rules! type_string_with_const_id { + () => { + $crate::shape_builder().new_static_with_const_id($crate::Static::String) + }; +} + +#[macro_export] +macro_rules! type_struct_with_const_id { + () => { + $crate::shape_builder().new_static_with_const_id(Static::Struct(StructType::new_any())) + }; + ($elem:expr) => { + $crate::shape_builder().new_static_with_const_id(Static::Struct(StructType::new($elem))) + }; +} + /// Represents a PartiQL Shape #[derive(Debug, Clone, Eq, PartialEq, Hash)] // With this implementation `Dynamic` and `AnyOf` cannot have `nullability`; this does not mean their @@ -162,240 +231,44 @@ pub enum PartiqlShape { Undefined, } -#[derive(Debug, Clone, Eq, PartialEq, Hash)] -pub struct StaticType { - ty: Static, - nullable: bool, -} - -#[derive(Debug, Clone, Eq, PartialEq, Hash)] -pub enum Static { - // Scalar Types - Int, - Int8, - Int16, - Int32, - Int64, - Bool, - Decimal, - DecimalP(usize, usize), - - Float32, - Float64, - - String, - StringFixed(usize), - StringVarying(usize), - - DateTime, - - // Container Types - Struct(StructType), - Bag(BagType), - Array(ArrayType), - // TODO Add BitString, ByteString, Blob, Clob, and Graph types -} - -impl Static { - pub fn is_scalar(&self) -> bool { - !matches!(self, Static::Struct(_) | Static::Bag(_) | Static::Array(_)) - } - - pub fn is_sequence(&self) -> bool { - matches!(self, Static::Bag(_) | Static::Array(_)) - } - - pub fn is_struct(&self) -> bool { - matches!(self, Static::Struct(_)) - } -} - -impl StaticType { - #[must_use] - pub fn new(ty: Static) -> StaticType { - StaticType { ty, nullable: true } - } - - #[must_use] - pub fn new_non_nullable(ty: Static) -> StaticType { - StaticType { - ty, - nullable: false, - } - } - - #[must_use] - pub fn ty(&self) -> &Static { - &self.ty - } - - #[must_use] - pub fn is_nullable(&self) -> bool { - self.nullable 
- } - - #[must_use] - pub fn is_not_nullable(&self) -> bool { - !self.nullable - } - - pub fn is_scalar(&self) -> bool { - self.ty.is_scalar() - } - - pub fn is_sequence(&self) -> bool { - self.ty.is_sequence() - } - - pub fn is_struct(&self) -> bool { - self.ty.is_struct() - } -} - -impl Display for StaticType { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let nullable = if self.nullable { - "nullable" - } else { - "non_nullable" - }; - write!(f, "({}, {})", self.ty, nullable) - } -} - -impl Display for Static { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let x = match self { - Static::Int => "Int".to_string(), - Static::Int8 => "Int8".to_string(), - Static::Int16 => "Int16".to_string(), - Static::Int32 => "Int32".to_string(), - Static::Int64 => "Int64".to_string(), - Static::Bool => "Bool".to_string(), - Static::Decimal => "Decimal".to_string(), - Static::DecimalP(_, _) => { - todo!() - } - Static::Float32 => "Float32".to_string(), - Static::Float64 => "Float64".to_string(), - Static::String => "String".to_string(), - Static::StringFixed(_) => { - todo!() - } - Static::StringVarying(_) => { - todo!() - } - Static::DateTime => "DateTime".to_string(), - Static::Struct(_) => "Struct".to_string(), - Static::Bag(_) => "Bag".to_string(), - Static::Array(_) => "Array".to_string(), - }; - write!(f, "{x}") - } -} - -pub const TYPE_DYNAMIC: PartiqlShape = PartiqlShape::Dynamic; -pub const TYPE_BOOL: PartiqlShape = PartiqlShape::new(Static::Bool); -pub const TYPE_INT: PartiqlShape = PartiqlShape::new(Static::Int); -pub const TYPE_INT8: PartiqlShape = PartiqlShape::new(Static::Int8); -pub const TYPE_INT16: PartiqlShape = PartiqlShape::new(Static::Int16); -pub const TYPE_INT32: PartiqlShape = PartiqlShape::new(Static::Int32); -pub const TYPE_INT64: PartiqlShape = PartiqlShape::new(Static::Int64); -pub const TYPE_REAL: PartiqlShape = PartiqlShape::new(Static::Float32); -pub const TYPE_DOUBLE: PartiqlShape = 
PartiqlShape::new(Static::Float64); -pub const TYPE_DECIMAL: PartiqlShape = PartiqlShape::new(Static::Decimal); -pub const TYPE_STRING: PartiqlShape = PartiqlShape::new(Static::String); -pub const TYPE_DATETIME: PartiqlShape = PartiqlShape::new(Static::DateTime); -pub const TYPE_NUMERIC_TYPES: [PartiqlShape; 4] = [TYPE_INT, TYPE_REAL, TYPE_DOUBLE, TYPE_DECIMAL]; - #[allow(dead_code)] impl PartiqlShape { - #[must_use] - pub const fn new(ty: Static) -> PartiqlShape { - PartiqlShape::Static(StaticType { ty, nullable: true }) - } - #[must_use] - pub const fn new_non_nullable(ty: Static) -> PartiqlShape { - PartiqlShape::Static(StaticType { - ty, - nullable: false, - }) - } - - #[must_use] - pub fn as_non_nullable(&self) -> Option { - if let PartiqlShape::Static(stype) = self { - Some(PartiqlShape::Static(StaticType { - ty: stype.ty.clone(), - nullable: false, - })) - } else { - None - } - } - - #[must_use] - pub fn new_dynamic() -> PartiqlShape { - PartiqlShape::Dynamic - } - - #[must_use] - pub fn new_struct(s: StructType) -> PartiqlShape { - PartiqlShape::new(Static::Struct(s)) - } - - #[must_use] - pub fn new_bag(b: BagType) -> PartiqlShape { - PartiqlShape::new(Static::Bag(b)) - } - - #[must_use] - pub fn new_array(a: ArrayType) -> PartiqlShape { - PartiqlShape::new(Static::Array(a)) - } - - pub fn any_of(types: I) -> PartiqlShape - where - I: IntoIterator, - { - let any_of = AnyOf::from_iter(types); - match any_of.types.len() { - 0 => TYPE_DYNAMIC, - 1 => { - let AnyOf { types } = any_of; - types.into_iter().next().unwrap() - } - // TODO figure out what does it mean for a Union to be nullable or not - _ => PartiqlShape::AnyOf(any_of), - } - } - #[must_use] pub fn union_with(self, other: PartiqlShape) -> PartiqlShape { match (self, other) { - (PartiqlShape::Dynamic, _) | (_, PartiqlShape::Dynamic) => PartiqlShape::new_dynamic(), + (PartiqlShape::Dynamic, _) | (_, PartiqlShape::Dynamic) => PartiqlShape::Dynamic, (PartiqlShape::AnyOf(lhs), PartiqlShape::AnyOf(rhs)) 
=> { - PartiqlShape::any_of(lhs.types.into_iter().chain(rhs.types)) + shape_builder().any_of(lhs.types.into_iter().chain(rhs.types)) } (PartiqlShape::AnyOf(anyof), other) | (other, PartiqlShape::AnyOf(anyof)) => { let mut types = anyof.types; types.insert(other); - PartiqlShape::any_of(types) + shape_builder().any_of(types) } (l, r) => { let types = [l, r]; - PartiqlShape::any_of(types) + shape_builder().any_of(types) } } } + #[must_use] + pub fn static_type_id(&self) -> Option { + if let PartiqlShape::Static(StaticType { id, .. }) = self { + Some(*id) + } else { + None + } + } + #[must_use] pub fn is_string(&self) -> bool { matches!( &self, PartiqlShape::Static(StaticType { ty: Static::String, - nullable: true + nullable: true, + .. }) ) } @@ -406,7 +279,8 @@ impl PartiqlShape { *self, PartiqlShape::Static(StaticType { ty: Static::Struct(_), - nullable: true + nullable: true, + .. }) ) } @@ -417,13 +291,15 @@ impl PartiqlShape { *self, PartiqlShape::Static(StaticType { ty: Static::Bag(_), - nullable: true + nullable: true, + .. }) ) || matches!( *self, PartiqlShape::Static(StaticType { ty: Static::Array(_), - nullable: true + nullable: true, + .. }) ) } @@ -440,7 +316,8 @@ impl PartiqlShape { *self, PartiqlShape::Static(StaticType { ty: Static::Array(_), - nullable: true + nullable: true, + .. }) ) } @@ -451,7 +328,8 @@ impl PartiqlShape { *self, PartiqlShape::Static(StaticType { ty: Static::Bag(_), - nullable: true + nullable: true, + .. }) ) } @@ -462,7 +340,8 @@ impl PartiqlShape { *self, PartiqlShape::Static(StaticType { ty: Static::Array(_), - nullable: true + nullable: true, + .. 
}) ) } @@ -479,11 +358,13 @@ impl PartiqlShape { pub fn expect_bool(&self) -> ShapeResult { if let PartiqlShape::Static(StaticType { + id, ty: Static::Bool, nullable: n, }) = self { Ok(StaticType { + id: *id, ty: Static::Bool, nullable: *n, }) @@ -492,6 +373,18 @@ impl PartiqlShape { } } + pub fn expect_bag(&self) -> ShapeResult { + if let PartiqlShape::Static(StaticType { + ty: Static::Bag(bag), + .. + }) = self + { + Ok(bag.clone()) + } else { + Err(ShapeResultError::UnexpectedType(format!("{self}"))) + } + } + pub fn expect_struct(&self) -> ShapeResult { if let PartiqlShape::Static(StaticType { ty: Static::Struct(stct), @@ -543,6 +436,103 @@ impl Display for PartiqlShape { } } +#[derive(Default)] +pub struct PartiqlShapeBuilder { + id_gen: AutoNodeIdGenerator, +} + +impl PartiqlShapeBuilder { + #[must_use] + pub fn new_static(&self, ty: Static) -> PartiqlShape { + let id = self.id_gen.next_id(); + let id = id.read().expect("NodeId read lock"); + PartiqlShape::Static(StaticType { + id: *id, + ty, + nullable: true, + }) + } + + #[must_use] + pub fn new_static_with_const_id(&self, ty: Static) -> PartiqlShape { + PartiqlShape::Static(StaticType { + id: NodeId(1), + ty, + nullable: true, + }) + } + + #[must_use] + pub fn new_non_nullable_static(&self, ty: Static) -> PartiqlShape { + let id = self.id_gen.next_id(); + let id = id.read().expect("NodeId read lock"); + PartiqlShape::Static(StaticType { + id: *id, + ty, + nullable: false, + }) + } + + #[must_use] + pub fn new_non_nullable_static_with_const_id(&self, ty: Static) -> PartiqlShape { + PartiqlShape::Static(StaticType { + id: NodeId(1), + ty, + nullable: false, + }) + } + + #[must_use] + pub fn new_dynamic(&self) -> PartiqlShape { + PartiqlShape::Dynamic + } + + #[must_use] + pub fn new_undefined(&self) -> PartiqlShape { + PartiqlShape::Dynamic + } + + #[must_use] + pub fn new_struct(&self, s: StructType) -> PartiqlShape { + self.new_static(Static::Struct(s)) + } + + #[must_use] + pub fn new_bag(&self, b: 
BagType) -> PartiqlShape { + self.new_static(Static::Bag(b)) + } + + #[must_use] + pub fn new_array(&self, a: ArrayType) -> PartiqlShape { + self.new_static(Static::Array(a)) + } + + pub fn any_of(&self, types: I) -> PartiqlShape + where + I: IntoIterator, + { + let any_of = AnyOf::from_iter(types); + match any_of.types.len() { + 0 => type_dynamic!(), + 1 => { + let AnyOf { types } = any_of; + types.into_iter().next().unwrap() + } + // TODO figure out what does it mean for a Union to be nullable or not + _ => PartiqlShape::AnyOf(any_of), + } + } + + #[must_use] + pub fn as_non_nullable(&self, shape: &PartiqlShape) -> Option { + if let PartiqlShape::Static(stype) = shape { + Some(self.new_non_nullable_static(stype.ty.clone())) + } else { + None + } + } +} + #[derive(Derivative, Eq, Debug, Clone)] #[derivative(PartialEq, Hash)] #[allow(dead_code)] @@ -570,6 +560,174 @@ impl FromIterator for AnyOf { } } +/// A Builder for [`StaticType`]s that uses a [`NodeIdGenerator`] to assign [`NodeId`]s +pub struct StaticTypeBuilder { + /// Generator for 'fresh' [`NodeId`]s + pub id_gen: IdGen, +} + +impl StaticTypeBuilder +where + IdGen: NodeIdGenerator, +{ + pub fn new(id_gen: IdGen) -> Self { + Self { id_gen } + } + + pub fn nullable_ty(&mut self, ty: Static) -> StaticType { + let id = self.id_gen.id(); + let id = id.read().expect("NodeId read lock"); + self.ty(ty, *id, true) + } + + pub fn non_nullable_ty(&self, ty: Static) -> StaticType { + let id = self.id_gen.id(); + let id = id.read().expect("NodeId read lock"); + self.ty(ty, *id, false) + } + + pub fn ty(&self, ty: Static, id: NodeId, nullable: bool) -> StaticType { + StaticType { id, ty, nullable } + } +} + +impl Default for StaticTypeBuilder +where + T: NodeIdGenerator + Default, +{ + fn default() -> Self { + Self::new(T::default()) + } +} + +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct StaticType { + id: NodeId, + ty: Static, + nullable: bool, +} + +impl StaticType { + #[must_use] + pub fn ty(&self) -> &Static 
{ + &self.ty + } + + pub fn ty_id(&self) -> &NodeId { + &self.id + } + + #[must_use] + pub fn is_nullable(&self) -> bool { + self.nullable + } + + #[must_use] + pub fn is_not_nullable(&self) -> bool { + !self.nullable + } + + pub fn is_scalar(&self) -> bool { + self.ty.is_scalar() + } + + pub fn is_sequence(&self) -> bool { + self.ty.is_sequence() + } + + pub fn is_struct(&self) -> bool { + self.ty.is_struct() + } +} + +impl Display for StaticType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let nullable = if self.nullable { + "nullable" + } else { + "non_nullable" + }; + write!(f, "({}, {})", self.ty, nullable) + } +} + +pub type StaticTypeMetas = HashMap; + +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub enum Static { + // Scalar Types + Int, + Int8, + Int16, + Int32, + Int64, + Bool, + Decimal, + DecimalP(usize, usize), + + Float32, + Float64, + + String, + StringFixed(usize), + StringVarying(usize), + + DateTime, + + // Container Types + Struct(StructType), + Bag(BagType), + Array(ArrayType), + // TODO Add BitString, ByteString, Blob, Clob, and Graph types +} + +impl Static { + pub fn is_scalar(&self) -> bool { + !matches!(self, Static::Struct(_) | Static::Bag(_) | Static::Array(_)) + } + + pub fn is_sequence(&self) -> bool { + matches!(self, Static::Bag(_) | Static::Array(_)) + } + + pub fn is_struct(&self) -> bool { + matches!(self, Static::Struct(_)) + } +} + +impl Display for Static { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let x = match self { + Static::Int => "Int".to_string(), + Static::Int8 => "Int8".to_string(), + Static::Int16 => "Int16".to_string(), + Static::Int32 => "Int32".to_string(), + Static::Int64 => "Int64".to_string(), + Static::Bool => "Bool".to_string(), + Static::Decimal => "Decimal".to_string(), + Static::DecimalP(_, _) => { + todo!() + } + Static::Float32 => "Float32".to_string(), + Static::Float64 => "Float64".to_string(), + Static::String => "String".to_string(), + Static::StringFixed(_) 
=> { + todo!() + } + Static::StringVarying(_) => { + todo!() + } + Static::DateTime => "DateTime".to_string(), + Static::Struct(_) => "Struct".to_string(), + Static::Bag(_) => "Bag".to_string(), + Static::Array(_) => "Array".to_string(), + }; + write!(f, "{x}") + } +} + +pub const TYPE_DYNAMIC: PartiqlShape = PartiqlShape::Dynamic; + #[derive(Derivative, Eq, Debug, Clone)] #[derivative(PartialEq, Hash)] #[allow(dead_code)] @@ -754,32 +912,92 @@ impl ArrayType { #[cfg(test)] mod tests { - use crate::{PartiqlShape, TYPE_INT, TYPE_REAL}; + use crate::{ + shape_builder, BagType, PartiqlShape, Static, StructConstraint, StructField, StructType, + }; + use indexmap::IndexSet; #[test] fn union() { - let expect_int = TYPE_INT; - assert_eq!(expect_int, TYPE_INT.union_with(TYPE_INT)); + let expect_int = type_int_with_const_id!(); + assert_eq!( + expect_int, + type_int_with_const_id!().union_with(type_int_with_const_id!()) + ); - let expect_nums = PartiqlShape::any_of([TYPE_INT, TYPE_REAL]); - assert_eq!(expect_nums, TYPE_INT.union_with(TYPE_REAL)); + let expect_nums = + shape_builder().any_of([type_int_with_const_id!(), type_float32_with_const_id!()]); assert_eq!( expect_nums, - PartiqlShape::any_of([ - TYPE_INT.union_with(TYPE_REAL), - TYPE_INT.union_with(TYPE_REAL) + type_int_with_const_id!().union_with(type_float32_with_const_id!()) + ); + assert_eq!( + expect_nums, + shape_builder().any_of([ + type_int_with_const_id!().union_with(type_float32_with_const_id!()), + type_int_with_const_id!().union_with(type_float32_with_const_id!()) ]) ); assert_eq!( expect_nums, - PartiqlShape::any_of([ - TYPE_INT.union_with(TYPE_REAL), - TYPE_INT.union_with(TYPE_REAL), - PartiqlShape::any_of([ - TYPE_INT.union_with(TYPE_REAL), - TYPE_INT.union_with(TYPE_REAL) + shape_builder().any_of([ + type_int_with_const_id!().union_with(type_float32_with_const_id!()), + type_int_with_const_id!().union_with(type_float32_with_const_id!()), + shape_builder().any_of([ + 
type_int_with_const_id!().union_with(type_float32_with_const_id!()), + type_int_with_const_id!().union_with(type_float32_with_const_id!()) ]) ]) ); } + + #[test] + fn unique_node_ids() { + let age_field = struct_fields![("age", type_int!())]; + let details = type_struct![IndexSet::from([age_field])]; + + let fields = [ + StructField::new("id", type_int!()), + StructField::new("name", type_string!()), + StructField::new("details", details.clone()), + ]; + + let row = type_struct![IndexSet::from([ + StructConstraint::Fields(IndexSet::from(fields)), + StructConstraint::Open(false) + ])]; + + let shape = type_bag![row.clone()]; + + let mut ids = collect_ids(shape); + ids.sort_unstable(); + assert!(ids.windows(2).all(|w| w[0] != w[1])); + } + + fn collect_ids(row: PartiqlShape) -> Vec { + let mut out = vec![]; + match row { + PartiqlShape::Dynamic => {} + PartiqlShape::AnyOf(anyof) => { + for shape in anyof.types { + out.push(collect_ids(shape)); + } + } + PartiqlShape::Static(static_type) => { + out.push(vec![static_type.id.0]); + match static_type.ty { + Static::Struct(struct_type) => { + for f in struct_type.fields() { + out.push(collect_ids(f.ty.clone())); + } + } + Static::Bag(bag_type) => out.push(collect_ids(*bag_type.element_type)), + Static::Array(array_type) => out.push(collect_ids(*array_type.element_type)), + _ => {} + } + } + PartiqlShape::Undefined => {} + } + out.into_iter().flatten().collect() + } } From 1b946b424dd203ad4eead8de9e154de1bc63fae3 Mon Sep 17 00:00:00 2001 From: Arash M <27716912+am357@users.noreply.github.com> Date: Mon, 12 Aug 2024 16:37:37 -0700 Subject: [PATCH 2/4] self review --- CHANGELOG.md | 11 +- extension/partiql-extension-ddl/src/ddl.rs | 15 +- .../partiql-extension-ddl/tests/ddl-tests.rs | 6 +- partiql-ast/src/builder.rs | 2 +- partiql-common/src/node.rs | 31 ++-- partiql-eval/src/eval/expr/coll.rs | 26 ++-- partiql-eval/src/eval/expr/operators.rs | 27 ++-- partiql-logical-planner/src/typer.rs | 38 +++-- partiql-types/src/lib.rs 
| 133 +++++++----------- 9 files changed, 133 insertions(+), 156 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6262612a..fc612a92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,11 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - partiql-ast: improved pretty-printing of `CASE` and various clauses - ### Added -- Added `partiql-common` and moved node id generation and `partiql-source-map` code to it under `syntax` +- Added `partiql-common`. - Added `NodeId` to `StaticType`. - *BREAKING* Added thread-safe `PartiqlShapeBuilder` and automatic `NodeId` generation for the `StaticType`. -- *BREAKING* Moved some of the `PartiqlShape` APIs to the `PartiqlShapeBuilder`. -- *BREAKING* Prepended existing type macros with `type` such as `type_int!` to make macro names more friendly. - Added a static thread safe `shape_builder` function that provides a convenient way for using `PartiqlShapeBuilder` for creating new shapes. ### Removed @@ -26,8 +24,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.10.0] ### Changed - *BREAKING:* partiql-ast: added modeling of `EXCLUDE` -- *BREAKING:* partiql-ast: added pretty-printing of `EXCLUDE -- Changed `AutoNodeIdGenerator` to a thread-safe version +- *BREAKING:* partiql-ast: added pretty-printing of `EXCLUDE` +- *BREAKING* Moved some of the `PartiqlShape` APIs to the `PartiqlShapeBuilder`. +- *BREAKING* Prepended existing type macros with `type` to make macro names more friendly: e.g., `type_int!` +- *BREAKING* Moved node id generation and `partiql-source-map` to `partiql-common`. 
+- *BREAKING* Changed `AutoNodeIdGenerator` to a thread-safe version ### Added - *BREAKING:* partiql-parser: added parsing of `EXCLUDE` diff --git a/extension/partiql-extension-ddl/src/ddl.rs b/extension/partiql-extension-ddl/src/ddl.rs index ae16cd86..02e83760 100644 --- a/extension/partiql-extension-ddl/src/ddl.rs +++ b/extension/partiql-extension-ddl/src/ddl.rs @@ -230,8 +230,8 @@ mod tests { use super::*; use indexmap::IndexSet; use partiql_types::{ - shape_builder, struct_fields, type_array, type_bag, type_float64, type_int8, type_string, - type_struct, StructConstraint, + struct_fields, type_array, type_bag, type_float64, type_int8, type_string, type_struct, + PartiqlShapeBuilder, StructConstraint, }; #[test] @@ -239,9 +239,9 @@ mod tests { let nested_attrs = struct_fields![ ( "a", - shape_builder().any_of(vec![ - shape_builder().new_static(Static::DecimalP(5, 4)), - shape_builder().new_static(Static::Int8), + PartiqlShapeBuilder::init_or_get().any_of(vec![ + PartiqlShapeBuilder::init_or_get().new_static(Static::DecimalP(5, 4)), + PartiqlShapeBuilder::init_or_get().new_static(Static::Int8), ]) ), ("b", type_array![type_string![]]), @@ -252,7 +252,10 @@ mod tests { let fields = struct_fields![ ("employee_id", type_int8![]), ("full_name", type_string![]), - ("salary", shape_builder().new_static(Static::DecimalP(8, 2))), + ( + "salary", + PartiqlShapeBuilder::init_or_get().new_static(Static::DecimalP(8, 2)) + ), ("details", details), ("dependents", type_array![type_string![]]) ]; diff --git a/extension/partiql-extension-ddl/tests/ddl-tests.rs b/extension/partiql-extension-ddl/tests/ddl-tests.rs index 635de09b..20072856 100644 --- a/extension/partiql-extension-ddl/tests/ddl-tests.rs +++ b/extension/partiql-extension-ddl/tests/ddl-tests.rs @@ -1,8 +1,8 @@ use indexmap::IndexSet; use partiql_extension_ddl::ddl::{DdlFormat, PartiqlBasicDdlEncoder, PartiqlDdlEncoder}; use partiql_types::{ - shape_builder, struct_fields, type_bag, type_int, type_string, type_struct, 
StructConstraint, - StructField, + struct_fields, type_bag, type_int, type_string, type_struct, PartiqlShapeBuilder, + StructConstraint, StructField, }; use partiql_types::{BagType, Static, StructType}; @@ -15,7 +15,7 @@ fn basic_ddl_test() { StructField::new("name", type_string!()), StructField::new( "address", - shape_builder().new_non_nullable_static(Static::String), + PartiqlShapeBuilder::init_or_get().new_non_nullable_static(Static::String), ), StructField::new_optional("details", details.clone()), ] diff --git a/partiql-ast/src/builder.rs b/partiql-ast/src/builder.rs index a2f3219e..571b9796 100644 --- a/partiql-ast/src/builder.rs +++ b/partiql-ast/src/builder.rs @@ -17,7 +17,7 @@ where pub fn node(&mut self, node: T) -> AstNode { let id = self.id_gen.id(); - let id = id.read().expect("NodeId read lock"); + let id = id.read().expect("NodId read lock"); AstNode { id: *id, node } } } diff --git a/partiql-common/src/node.rs b/partiql-common/src/node.rs index 9cc3be78..d8a1640b 100644 --- a/partiql-common/src/node.rs +++ b/partiql-common/src/node.rs @@ -1,5 +1,6 @@ use indexmap::IndexMap; use std::hash::Hash; +use std::mem; use std::sync::{Arc, RwLock}; #[cfg(feature = "serde")] @@ -24,27 +25,27 @@ impl Default for AutoNodeIdGenerator { } } -impl AutoNodeIdGenerator { - pub fn next_id(&self) -> Arc> { - self.id() - } -} - /// A provider of 'fresh' [`NodeId`]s. pub trait NodeIdGenerator { - /// Provides a 'fresh' [`NodeId`]. fn id(&self) -> Arc>; + + /// Provides a 'fresh' [`NodeId`]. 
+ fn next_id(&self) -> NodeId; } impl NodeIdGenerator for AutoNodeIdGenerator { - #[inline] fn id(&self) -> Arc> { - let id = &self.next_id.read().expect("NodeId read lock"); - let next = NodeId(id.0 + 1); - let mut w = self.next_id.write().expect("NodeId write lock"); - *w = next; + let id = self.next_id(); + let mut w = self.next_id.write().expect("NodId write lock"); + *w = id; Arc::clone(&self.next_id) } + + #[inline] + fn next_id(&self) -> NodeId { + let id = &self.next_id.read().expect("NodId read lock"); + NodeId(id.0 + 1) + } } /// A provider of [`NodeId`]s that are always `0`; Useful for testing @@ -53,6 +54,10 @@ pub struct NullIdGenerator {} impl NodeIdGenerator for NullIdGenerator { fn id(&self) -> Arc> { - Arc::new(RwLock::new(NodeId(0))) + Arc::new(RwLock::from(self.next_id())) + } + + fn next_id(&self) -> NodeId { + NodeId(0) } } diff --git a/partiql-eval/src/eval/expr/coll.rs b/partiql-eval/src/eval/expr/coll.rs index 0af297ff..5991f670 100644 --- a/partiql-eval/src/eval/expr/coll.rs +++ b/partiql-eval/src/eval/expr/coll.rs @@ -5,7 +5,7 @@ use crate::eval::expr::{BindError, BindEvalExpr, EvalExpr}; use itertools::{Itertools, Unique}; use partiql_types::{ - shape_builder, type_bool, type_numeric, ArrayType, BagType, PartiqlShape, Static, + type_bool, type_numeric, ArrayType, BagType, PartiqlShape, PartiqlShapeBuilder, Static, }; use partiql_value::Value::{Missing, Null}; use partiql_value::{BinaryAnd, BinaryOr, Value, ValueIter}; @@ -51,21 +51,23 @@ impl BindEvalExpr for EvalCollFn { value.sequence_iter().map_or(Missing, &f) }) } - let boolean_elems = [shape_builder().any_of([ - shape_builder().new_static(Static::Array(ArrayType::new(Box::new(type_bool!())))), - shape_builder().new_static(Static::Bag(BagType::new(Box::new(type_bool!())))), + let boolean_elems = [PartiqlShapeBuilder::init_or_get().any_of([ + PartiqlShapeBuilder::init_or_get() + .new_static(Static::Array(ArrayType::new(Box::new(type_bool!())))), + PartiqlShapeBuilder::init_or_get() + 
.new_static(Static::Bag(BagType::new(Box::new(type_bool!())))), ])]; - let numeric_elems = [shape_builder().any_of([ - shape_builder().new_static(Static::Array(ArrayType::new(Box::new( - shape_builder().any_of(type_numeric!()), + let numeric_elems = [PartiqlShapeBuilder::init_or_get().any_of([ + PartiqlShapeBuilder::init_or_get().new_static(Static::Array(ArrayType::new(Box::new( + PartiqlShapeBuilder::init_or_get().any_of(type_numeric!()), )))), - shape_builder().new_static(Static::Bag(BagType::new(Box::new( - shape_builder().any_of(type_numeric!()), + PartiqlShapeBuilder::init_or_get().new_static(Static::Bag(BagType::new(Box::new( + PartiqlShapeBuilder::init_or_get().any_of(type_numeric!()), )))), ])]; - let any_elems = [shape_builder().any_of([ - shape_builder().new_static(Static::Array(ArrayType::new_any())), - shape_builder().new_static(Static::Bag(BagType::new_any())), + let any_elems = [PartiqlShapeBuilder::init_or_get().any_of([ + PartiqlShapeBuilder::init_or_get().new_static(Static::Array(ArrayType::new_any())), + PartiqlShapeBuilder::init_or_get().new_static(Static::Bag(BagType::new_any())), ])]; match *self { diff --git a/partiql-eval/src/eval/expr/operators.rs b/partiql-eval/src/eval/expr/operators.rs index bae68bc4..693bf695 100644 --- a/partiql-eval/src/eval/expr/operators.rs +++ b/partiql-eval/src/eval/expr/operators.rs @@ -8,8 +8,8 @@ use crate::eval::expr::{BindError, BindEvalExpr, EvalExpr}; use crate::eval::EvalContext; use partiql_types::{ - shape_builder, type_bool, type_dynamic, type_numeric, ArrayType, BagType, PartiqlShape, Static, - StructType, + type_bool, type_dynamic, type_numeric, ArrayType, BagType, PartiqlShape, PartiqlShapeBuilder, + Static, StructType, }; use partiql_value::Value::{Boolean, Missing, Null}; use partiql_value::{BinaryAnd, EqualityValue, NullableEq, NullableOrd, Tuple, Value}; @@ -80,7 +80,7 @@ impl BindEvalExpr for EvalOpUnary { &self, args: Vec>, ) -> Result, BindError> { - let any_num = 
shape_builder().any_of(type_numeric!()); + let any_num = PartiqlShapeBuilder::init_or_get().any_of(type_numeric!()); let unop = |types, f: fn(&Value) -> Value| { UnaryValueExpr::create_typed::<{ STRICT }, _>(types, args, f) @@ -179,7 +179,7 @@ impl BindEvalExpr for EvalOpBinary { macro_rules! math { ($f:expr) => {{ - let nums = shape_builder().any_of(type_numeric!()); + let nums = PartiqlShapeBuilder::init_or_get().any_of(type_numeric!()); create!(MathCheck, [nums.clone(), nums], $f) }}; } @@ -210,9 +210,11 @@ impl BindEvalExpr for EvalOpBinary { InCheck, [ type_dynamic!(), - shape_builder().any_of([ - shape_builder().new_static(Static::Array(ArrayType::new_any())), - shape_builder().new_static(Static::Bag(BagType::new_any())), + PartiqlShapeBuilder::init_or_get().any_of([ + PartiqlShapeBuilder::init_or_get() + .new_static(Static::Array(ArrayType::new_any())), + PartiqlShapeBuilder::init_or_get() + .new_static(Static::Bag(BagType::new_any())), ]) ], |lhs, rhs| { @@ -320,7 +322,7 @@ impl BindEvalExpr for EvalFnAbs { &self, args: Vec>, ) -> Result, BindError> { - let nums = shape_builder().any_of(type_numeric!()); + let nums = PartiqlShapeBuilder::init_or_get().any_of(type_numeric!()); UnaryValueExpr::create_typed::<{ STRICT }, _>([nums], args, |v| { match NullableOrd::lt(v, &Value::from(0)) { Null => Null, @@ -341,10 +343,11 @@ impl BindEvalExpr for EvalFnCardinality { &self, args: Vec>, ) -> Result, BindError> { - let collections = shape_builder().any_of([ - shape_builder().new_static(Static::Array(ArrayType::new_any())), - shape_builder().new_static(Static::Bag(BagType::new_any())), - shape_builder().new_static(Static::Struct(StructType::new_any())), + let shape_builder = PartiqlShapeBuilder::init_or_get(); + let collections = PartiqlShapeBuilder::init_or_get().any_of([ + shape_builder.new_static(Static::Array(ArrayType::new_any())), + shape_builder.new_static(Static::Bag(BagType::new_any())), + shape_builder.new_static(Static::Struct(StructType::new_any())), ]); 
UnaryValueExpr::create_typed::<{ STRICT }, _>([collections], args, |v| match v { diff --git a/partiql-logical-planner/src/typer.rs b/partiql-logical-planner/src/typer.rs index 657ada2a..f42b8851 100644 --- a/partiql-logical-planner/src/typer.rs +++ b/partiql-logical-planner/src/typer.rs @@ -4,7 +4,8 @@ use partiql_ast::ast::{CaseSensitivity, SymbolPrimitive}; use partiql_catalog::Catalog; use partiql_logical::{BindingsOp, LogicalPlan, OpId, PathComponent, ValueExpr, VarRefType}; use partiql_types::{ - shape_builder, type_dynamic, type_undefined, ArrayType, BagType, PartiqlShape, + type_array, type_bag, type_bool, type_decimal, type_dynamic, type_int, type_string, + type_struct, type_undefined, ArrayType, BagType, PartiqlShape, PartiqlShapeBuilder, ShapeResultError, Static, StructConstraint, StructField, StructType, }; use partiql_value::{BindingsName, Value}; @@ -218,16 +219,17 @@ impl<'c> PlanTyper<'c> { StructField::new(k.as_str(), self.get_singleton_type_from_env()) }); - let ty = shape_builder().new_struct(StructType::new(IndexSet::from([ - StructConstraint::Fields(fields.collect()), - ]))); + let ty = PartiqlShapeBuilder::init_or_get().new_struct(StructType::new( + IndexSet::from([StructConstraint::Fields(fields.collect())]), + )); let derived_type_ctx = self.local_type_ctx(); let derived_type = &self.derived_type(&derived_type_ctx); let schema = if derived_type.is_ordered_collection() { - shape_builder().new_array(ArrayType::new(Box::new(ty))) + PartiqlShapeBuilder::init_or_get().new_array(ArrayType::new(Box::new(ty))) } else if derived_type.is_unordered_collection() { - shape_builder().new_static(Static::Bag(BagType::new(Box::new(ty)))) + PartiqlShapeBuilder::init_or_get() + .new_static(Static::Bag(BagType::new(Box::new(ty)))) } else { self.errors.push(TypingError::IllegalState(format!( "Expecting Collection for the output Schema but found {:?}", @@ -331,24 +333,20 @@ impl<'c> PlanTyper<'c> { } ValueExpr::Lit(v) => { let ty = match **v { - Value::Null => 
shape_builder().new_undefined(), - Value::Missing => shape_builder().new_undefined(), - Value::Integer(_) => shape_builder().new_static(Static::Int), - Value::Decimal(_) => shape_builder().new_static(Static::Decimal), - Value::Boolean(_) => shape_builder().new_static(Static::Bool), - Value::String(_) => shape_builder().new_static(Static::String), - Value::Tuple(_) => { - shape_builder().new_static(Static::Struct(StructType::new_any())) - } - Value::List(_) => { - shape_builder().new_static(Static::Array(ArrayType::new_any())) - } - Value::Bag(_) => shape_builder().new_static(Static::Bag(BagType::new_any())), + Value::Null => type_undefined!(), + Value::Missing => type_undefined!(), + Value::Integer(_) => type_int!(), + Value::Decimal(_) => type_decimal!(), + Value::Boolean(_) => type_bool!(), + Value::String(_) => type_string!(), + Value::Tuple(_) => type_struct!(), + Value::List(_) => type_array!(), + Value::Bag(_) => type_bag!(), _ => { self.errors.push(TypingError::NotYetImplemented( "Unsupported Literal".to_string(), )); - shape_builder().new_undefined() + type_undefined!() } }; diff --git a/partiql-types/src/lib.rs b/partiql-types/src/lib.rs index 564462dc..f251795f 100644 --- a/partiql-types/src/lib.rs +++ b/partiql-types/src/lib.rs @@ -12,12 +12,6 @@ use std::hash::{Hash, Hasher}; use std::sync::OnceLock; use thiserror::Error; -#[track_caller] -pub fn shape_builder() -> &'static PartiqlShapeBuilder { - static SHAPE_BUILDER: OnceLock = OnceLock::new(); - SHAPE_BUILDER.get_or_init(PartiqlShapeBuilder::default) -} - #[derive(Debug, Clone, Eq, PartialEq, Hash, Error, Diagnostic)] #[error("ShapeResult Error")] #[non_exhaustive] @@ -46,49 +40,49 @@ where #[macro_export] macro_rules! type_dynamic { () => { - $crate::shape_builder().new_dynamic() + $crate::PartiqlShapeBuilder::init_or_get().new_dynamic() }; } #[macro_export] macro_rules! 
type_int { () => { - $crate::shape_builder().new_static($crate::Static::Int) + $crate::PartiqlShapeBuilder::init_or_get().new_static($crate::Static::Int) }; } #[macro_export] macro_rules! type_int8 { () => { - $crate::shape_builder().new_static($crate::Static::Int8) + $crate::PartiqlShapeBuilder::init_or_get().new_static($crate::Static::Int8) }; } #[macro_export] macro_rules! type_int16 { () => { - $crate::shape_builder().new_static($crate::Static::Int16) + $crate::PartiqlShapeBuilder::init_or_get().new_static($crate::Static::Int16) }; } #[macro_export] macro_rules! type_int32 { () => { - $crate::shape_builder().new_static($crate::Static::Int32) + $crate::PartiqlShapeBuilder::init_or_get().new_static($crate::Static::Int32) }; } #[macro_export] macro_rules! type_int64 { () => { - $crate::shape_builder().new_static($crate::Static::Int64) + $crate::PartiqlShapeBuilder::init_or_get().new_static($crate::Static::Int64) }; } #[macro_export] macro_rules! type_decimal { () => { - $crate::shape_builder().new_static($crate::Static::Decimal) + $crate::PartiqlShapeBuilder::init_or_get().new_static($crate::Static::Decimal) }; } @@ -97,28 +91,28 @@ macro_rules! type_decimal { #[macro_export] macro_rules! type_float32 { () => { - $crate::shape_builder().new_static($crate::Static::Float32) + $crate::PartiqlShapeBuilder::init_or_get().new_static($crate::Static::Float32) }; } #[macro_export] macro_rules! type_float64 { () => { - $crate::shape_builder().new_static($crate::Static::Float64) + $crate::PartiqlShapeBuilder::init_or_get().new_static($crate::Static::Float64) }; } #[macro_export] macro_rules! type_string { () => { - $crate::shape_builder().new_static($crate::Static::String) + $crate::PartiqlShapeBuilder::init_or_get().new_static($crate::Static::String) }; } #[macro_export] macro_rules! 
type_bool { () => { - $crate::shape_builder().new_static($crate::Static::Bool) + $crate::PartiqlShapeBuilder::init_or_get().new_static($crate::Static::Bool) }; } @@ -126,10 +120,10 @@ macro_rules! type_bool { macro_rules! type_numeric { () => { [ - $crate::shape_builder().new_static($crate::Static::Int), - $crate::shape_builder().new_static($crate::Static::Float32), - $crate::shape_builder().new_static($crate::Static::Float64), - $crate::shape_builder().new_static($crate::Static::Decimal), + $crate::PartiqlShapeBuilder::init_or_get().new_static($crate::Static::Int), + $crate::PartiqlShapeBuilder::init_or_get().new_static($crate::Static::Float32), + $crate::PartiqlShapeBuilder::init_or_get().new_static($crate::Static::Float64), + $crate::PartiqlShapeBuilder::init_or_get().new_static($crate::Static::Decimal), ] }; } @@ -137,17 +131,17 @@ macro_rules! type_numeric { #[macro_export] macro_rules! type_datetime { () => { - $crate::shape_builder().new_static($crate::Static::DateTime) + $crate::PartiqlShapeBuilder::init_or_get().new_static($crate::Static::DateTime) }; } #[macro_export] macro_rules! type_struct { () => { - $crate::shape_builder().new_struct(StructType::new_any()) + $crate::PartiqlShapeBuilder::init_or_get().new_struct(StructType::new_any()) }; ($elem:expr) => { - $crate::shape_builder().new_struct(StructType::new($elem)) + $crate::PartiqlShapeBuilder::init_or_get().new_struct(StructType::new($elem)) }; } @@ -161,20 +155,20 @@ macro_rules! struct_fields { #[macro_export] macro_rules! type_bag { () => { - $crate::shape_builder().new_bag(BagType::new_any()); + $crate::PartiqlShapeBuilder::init_or_get().new_bag(BagType::new_any()); }; ($elem:expr) => { - $crate::shape_builder().new_bag(BagType::new(Box::new($elem))) + $crate::PartiqlShapeBuilder::init_or_get().new_bag(BagType::new(Box::new($elem))) }; } #[macro_export] macro_rules! 
type_array { () => { - $crate::shape_builder().new_array(ArrayType::new_any()); + $crate::PartiqlShapeBuilder::init_or_get().new_array(ArrayType::new_any()); }; ($elem:expr) => { - $crate::shape_builder().new_array(ArrayType::new(Box::new($elem))) + $crate::PartiqlShapeBuilder::init_or_get().new_array(ArrayType::new(Box::new($elem))) }; } @@ -191,31 +185,33 @@ macro_rules! type_undefined { #[macro_export] macro_rules! type_int_with_const_id { () => { - $crate::shape_builder().new_static_with_const_id($crate::Static::Int) + $crate::PartiqlShapeBuilder::init_or_get().new_static_with_const_id($crate::Static::Int) }; } #[macro_export] macro_rules! type_float32_with_const_id { () => { - $crate::shape_builder().new_static_with_const_id($crate::Static::Float32) + $crate::PartiqlShapeBuilder::init_or_get().new_static_with_const_id($crate::Static::Float32) }; } #[macro_export] macro_rules! type_string_with_const_id { () => { - $crate::shape_builder().new_static_with_const_id($crate::Static::String) + $crate::PartiqlShapeBuilder::init_or_get().new_static_with_const_id($crate::Static::String) }; } #[macro_export] macro_rules! 
type_struct_with_const_id { () => { - $crate::shape_builder().new_static_with_const_id(Static::Struct(StructType::new_any())) + $crate::PartiqlShapeBuilder::init_or_get() + .new_static_with_const_id(Static::Struct(StructType::new_any())) }; ($elem:expr) => { - $crate::shape_builder().new_static_with_const_id(Static::Struct(StructType::new($elem))) + $crate::PartiqlShapeBuilder::init_or_get() + .new_static_with_const_id(Static::Struct(StructType::new($elem))) }; } @@ -238,16 +234,16 @@ impl PartiqlShape { match (self, other) { (PartiqlShape::Dynamic, _) | (_, PartiqlShape::Dynamic) => PartiqlShape::Dynamic, (PartiqlShape::AnyOf(lhs), PartiqlShape::AnyOf(rhs)) => { - shape_builder().any_of(lhs.types.into_iter().chain(rhs.types)) + PartiqlShapeBuilder::init_or_get().any_of(lhs.types.into_iter().chain(rhs.types)) } (PartiqlShape::AnyOf(anyof), other) | (other, PartiqlShape::AnyOf(anyof)) => { let mut types = anyof.types; types.insert(other); - shape_builder().any_of(types) + PartiqlShapeBuilder::init_or_get().any_of(types) } (l, r) => { let types = [l, r]; - shape_builder().any_of(types) + PartiqlShapeBuilder::init_or_get().any_of(types) } } } @@ -442,9 +438,17 @@ pub struct PartiqlShapeBuilder { } impl PartiqlShapeBuilder { + /// A thread-safe method for creating PartiQL shapes with guaranteed uniqueness over + /// generated `NodeId`s. 
+ #[track_caller] + pub fn init_or_get() -> &'static PartiqlShapeBuilder { + static SHAPE_BUILDER: OnceLock = OnceLock::new(); + SHAPE_BUILDER.get_or_init(PartiqlShapeBuilder::default) + } + #[must_use] pub fn new_static(&self, ty: Static) -> PartiqlShape { - let id = self.id_gen.next_id(); + let id = self.id_gen.id(); let id = id.read().expect("NodeId read lock"); PartiqlShape::Static(StaticType { id: *id, @@ -464,7 +468,7 @@ impl PartiqlShapeBuilder { #[must_use] pub fn new_non_nullable_static(&self, ty: Static) -> PartiqlShape { - let id = self.id_gen.next_id(); + let id = self.id_gen.id(); let id = id.read().expect("NodeId read lock"); PartiqlShape::Static(StaticType { id: *id, @@ -560,46 +564,6 @@ impl FromIterator for AnyOf { } } -/// A Builder for [`AstNode`]s that uses a [`NodeIdGenerator`] to assign [`NodeId`]s -pub struct StaticTypeBuilder { - /// Generator for 'fresh' [`NodeId`]s - pub id_gen: IdGen, -} - -impl StaticTypeBuilder -where - IdGen: NodeIdGenerator, -{ - pub fn new(id_gen: IdGen) -> Self { - Self { id_gen } - } - - pub fn nullable_ty(&mut self, ty: Static) -> StaticType { - let id = self.id_gen.id(); - let id = id.read().expect("NodeId read lock"); - self.ty(ty, *id, true) - } - - pub fn non_nullable_ty(&self, ty: Static) -> StaticType { - let id = self.id_gen.id(); - let id = id.read().expect("NodeId read lock"); - self.ty(ty, *id, false) - } - - pub fn ty(&self, ty: Static, id: NodeId, nullable: bool) -> StaticType { - StaticType { id, ty, nullable } - } -} - -impl Default for StaticTypeBuilder -where - T: NodeIdGenerator + Default, -{ - fn default() -> Self { - Self::new(T::default()) - } -} - #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub struct StaticType { id: NodeId, @@ -913,7 +877,8 @@ impl ArrayType { #[cfg(test)] mod tests { use crate::{ - shape_builder, BagType, PartiqlShape, Static, StructConstraint, StructField, StructType, + BagType, PartiqlShape, PartiqlShapeBuilder, Static, StructConstraint, StructField, + StructType, }; use 
indexmap::IndexSet; @@ -925,25 +890,25 @@ mod tests { type_int_with_const_id!().union_with(type_int_with_const_id!()) ); - let expect_nums = - shape_builder().any_of([type_int_with_const_id!(), type_float32_with_const_id!()]); + let expect_nums = PartiqlShapeBuilder::init_or_get() + .any_of([type_int_with_const_id!(), type_float32_with_const_id!()]); assert_eq!( expect_nums, type_int_with_const_id!().union_with(type_float32_with_const_id!()) ); assert_eq!( expect_nums, - shape_builder().any_of([ + PartiqlShapeBuilder::init_or_get().any_of([ type_int_with_const_id!().union_with(type_float32_with_const_id!()), type_int_with_const_id!().union_with(type_float32_with_const_id!()) ]) ); assert_eq!( expect_nums, - shape_builder().any_of([ + PartiqlShapeBuilder::init_or_get().any_of([ type_int_with_const_id!().union_with(type_float32_with_const_id!()), type_int_with_const_id!().union_with(type_float32_with_const_id!()), - shape_builder().any_of([ + PartiqlShapeBuilder::init_or_get().any_of([ type_int_with_const_id!().union_with(type_float32_with_const_id!()), type_int_with_const_id!().union_with(type_float32_with_const_id!()) ]) From 207b62f87676cb06c36e3287048b5607e86082e1 Mon Sep 17 00:00:00 2001 From: Arash M <27716912+am357@users.noreply.github.com> Date: Wed, 14 Aug 2024 12:18:07 -0700 Subject: [PATCH 3/4] Implement NodeMap with DashMap --- partiql-common/Cargo.toml | 1 + partiql-common/src/lib.rs | 1 + partiql-common/src/node.rs | 106 ++++++++++++++++++++++- partiql-parser/src/parse/partiql.lalrpop | 6 +- partiql-types/src/lib.rs | 2 +- 5 files changed, 110 insertions(+), 6 deletions(-) diff --git a/partiql-common/Cargo.toml b/partiql-common/Cargo.toml index 203a3b71..c7cfa413 100644 --- a/partiql-common/Cargo.toml +++ b/partiql-common/Cargo.toml @@ -25,6 +25,7 @@ pretty = "0.12" serde = { version = "1.*", features = ["derive"], optional = true } smallvec = { version = "1.*" } thiserror = "1.0" +dashmap = "6.0.1" [features] default = [] diff --git 
a/partiql-common/src/lib.rs b/partiql-common/src/lib.rs index 3f107b89..18891302 100644 --- a/partiql-common/src/lib.rs +++ b/partiql-common/src/lib.rs @@ -1,3 +1,4 @@ +#![feature(mapped_lock_guards)] #![deny(rust_2018_idioms)] #![deny(clippy::all)] pub mod node; diff --git a/partiql-common/src/node.rs b/partiql-common/src/node.rs index d8a1640b..550de051 100644 --- a/partiql-common/src/node.rs +++ b/partiql-common/src/node.rs @@ -1,12 +1,112 @@ +use dashmap::DashMap; use indexmap::IndexMap; use std::hash::Hash; -use std::mem; -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, MappedRwLockReadGuard, Mutex, RwLock, RwLockReadGuard}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -pub type NodeMap = IndexMap; +// #[derive(Debug, Clone)] +// pub struct NodeMap { +// map: Arc>>, +// } +// +// impl Default for NodeMap { +// fn default() -> Self { +// Self::new() +// } +// } +// +// impl NodeMap { +// pub fn new() -> Self { +// NodeMap { +// map: Arc::new(RwLock::new(IndexMap::new())), +// } +// } +// +// pub fn with_capacity(capacity: usize) -> Self { +// NodeMap { +// map: Arc::new(RwLock::new(IndexMap::with_capacity(capacity))), +// } +// } +// +// pub fn insert(&self, node_id: NodeId, value: T) -> Option { +// let mut map = self.map.write().expect("NodeMap write lock"); +// map.insert(node_id, value) +// } +// +// // pub fn get(&self, node_id: &NodeId) -> Option +// // where +// // T: Clone, +// // { +// // let map = self.map.read().expect("NodeMap read lock"); +// // map.get(node_id).cloned() +// // } +// +// pub fn get(&self, node_id: &NodeId) -> Option>> { +// let map = self.map.read().unwrap(); // Acquire a read lock +// Some(RwLockReadGuard::map(map, |m| m.get(node_id))) +// } +// } + +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct NodeMap { + map: Arc>, // DashMap for thread-safe key-value storage + order: Arc>>, // Mutex-protected Vec for maintaining insertion order +} + +impl Default for 
NodeMap { + fn default() -> Self { + Self::new() + } +} + +impl NodeMap { + // Constructor to create a new, empty NodeMap + pub fn new() -> Self { + NodeMap { + map: Arc::new(DashMap::new()), + order: Arc::new(Mutex::new(Vec::new())), + } + } + + // Constructor to create a NodeMap with a specified capacity + pub fn with_capacity(capacity: usize) -> Self { + NodeMap { + map: Arc::new(DashMap::with_capacity(capacity)), + order: Arc::new(Mutex::new(Vec::with_capacity(capacity))), + } + } + + // The insert method to add a new key-value pair to the map + pub fn insert(&self, node_id: NodeId, value: T) -> Option { + let mut order = self.order.lock().unwrap(); // Acquire a lock to modify the order + if !self.map.contains_key(&node_id) { + order.push(node_id); // Only add to order if it's a new key + } + self.map.insert(node_id, value) + } + + // The get method to retrieve a reference to a value by its NodeId + pub fn get(&self, node_id: &NodeId) -> Option> { + self.map.get(node_id) + } + + // The unwrap_or method to get a value or return a default if not found + pub fn unwrap_or(&self, node_id: &NodeId, default: T) -> T + where + T: Clone, + { + self.map.get(node_id).map(|r| r.clone()).unwrap_or(default) + } + + // Method to retrieve the keys in insertion order + pub fn keys_in_order(&self) -> Vec { + let order = self.order.lock().unwrap(); // Acquire a lock to read the order + order.clone() // Return a cloned version of the order + } +} #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] diff --git a/partiql-parser/src/parse/partiql.lalrpop b/partiql-parser/src/parse/partiql.lalrpop index 8aa04b52..2eb38dad 100644 --- a/partiql-parser/src/parse/partiql.lalrpop +++ b/partiql-parser/src/parse/partiql.lalrpop @@ -287,8 +287,10 @@ FromClause: ast::AstNode = { ast::FromSource::Join(node) => node.id, }; - let start = state.locations.get(&start_id).unwrap_or(&total).start.0.clone(); - let end = 
state.locations.get(&end_id).unwrap_or(&total).end.0.clone(); + // let start = state.locations.get(&start_id).unwrap_or(&total).start.0.clone(); + let start = state.locations.unwrap_or(&start_id, total.clone()).start.0.clone(); + // let end = state.locations.get(&end_id).unwrap_or(&total).end.0.clone(); + let end = state.locations.unwrap_or(&end_id, total.clone()).end.0.clone(); let range = start..end; let join = state.node(ast::Join { kind: ast::JoinKind::Cross, diff --git a/partiql-types/src/lib.rs b/partiql-types/src/lib.rs index f251795f..92961e1e 100644 --- a/partiql-types/src/lib.rs +++ b/partiql-types/src/lib.rs @@ -5,7 +5,7 @@ use derivative::Derivative; use indexmap::IndexSet; use itertools::Itertools; use miette::Diagnostic; -use partiql_common::node::{AutoNodeIdGenerator, NodeId, NodeIdGenerator}; +use partiql_common::node::{AutoNodeIdGenerator, NodeId, NodeIdGenerator, NodeMap}; use std::collections::HashMap; use std::fmt::{Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; From f8e7773b11b2b093e06a34c1de57bf9a1cd25de8 Mon Sep 17 00:00:00 2001 From: Arash M <27716912+am357@users.noreply.github.com> Date: Wed, 14 Aug 2024 14:27:50 -0700 Subject: [PATCH 4/4] in progress --- partiql-common/Cargo.toml | 3 +- partiql-common/src/node.rs | 88 ++++-------------------- partiql-eval/Cargo.toml | 3 + partiql-eval/src/eval/mod.rs | 1 - partiql-parser/Cargo.toml | 1 + partiql-parser/src/parse/partiql.lalrpop | 8 +-- partiql-types/src/lib.rs | 2 +- 7 files changed, 25 insertions(+), 81 deletions(-) diff --git a/partiql-common/Cargo.toml b/partiql-common/Cargo.toml index c7cfa413..4cf96291 100644 --- a/partiql-common/Cargo.toml +++ b/partiql-common/Cargo.toml @@ -32,5 +32,6 @@ default = [] serde = [ "dep:serde", "indexmap/serde", - "smallvec/serde" + "smallvec/serde", + "dashmap/serde" ] diff --git a/partiql-common/src/node.rs b/partiql-common/src/node.rs index 550de051..bd736f44 100644 --- a/partiql-common/src/node.rs +++ b/partiql-common/src/node.rs @@ -1,59 
+1,15 @@ use dashmap::DashMap; -use indexmap::IndexMap; -use std::hash::Hash; -use std::sync::{Arc, MappedRwLockReadGuard, Mutex, RwLock, RwLockReadGuard}; +use std::sync::{Arc, RwLock}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -// #[derive(Debug, Clone)] -// pub struct NodeMap { -// map: Arc>>, -// } -// -// impl Default for NodeMap { -// fn default() -> Self { -// Self::new() -// } -// } -// -// impl NodeMap { -// pub fn new() -> Self { -// NodeMap { -// map: Arc::new(RwLock::new(IndexMap::new())), -// } -// } -// -// pub fn with_capacity(capacity: usize) -> Self { -// NodeMap { -// map: Arc::new(RwLock::new(IndexMap::with_capacity(capacity))), -// } -// } -// -// pub fn insert(&self, node_id: NodeId, value: T) -> Option { -// let mut map = self.map.write().expect("NodeMap write lock"); -// map.insert(node_id, value) -// } -// -// // pub fn get(&self, node_id: &NodeId) -> Option -// // where -// // T: Clone, -// // { -// // let map = self.map.read().expect("NodeMap read lock"); -// // map.get(node_id).cloned() -// // } -// -// pub fn get(&self, node_id: &NodeId) -> Option>> { -// let map = self.map.read().unwrap(); // Acquire a read lock -// Some(RwLockReadGuard::map(map, |m| m.get(node_id))) -// } -// } - #[derive(Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct NodeMap { - map: Arc>, // DashMap for thread-safe key-value storage - order: Arc>>, // Mutex-protected Vec for maintaining insertion order + map: DashMap, + #[cfg_attr(feature = "serde", serde(skip))] + order: Arc>>, } impl Default for NodeMap { @@ -63,49 +19,33 @@ impl Default for NodeMap { } impl NodeMap { - // Constructor to create a new, empty NodeMap pub fn new() -> Self { NodeMap { - map: Arc::new(DashMap::new()), - order: Arc::new(Mutex::new(Vec::new())), + map: DashMap::new(), + order: Arc::new(RwLock::new(Vec::new())), } } - // Constructor to create a NodeMap with a specified capacity pub fn with_capacity(capacity: usize) -> Self { 
NodeMap { - map: Arc::new(DashMap::with_capacity(capacity)), - order: Arc::new(Mutex::new(Vec::with_capacity(capacity))), + map: DashMap::with_capacity(capacity), + order: Arc::new(RwLock::new(Vec::with_capacity(capacity))), } } - // The insert method to add a new key-value pair to the map pub fn insert(&self, node_id: NodeId, value: T) -> Option { - let mut order = self.order.lock().unwrap(); // Acquire a lock to modify the order - if !self.map.contains_key(&node_id) { - order.push(node_id); // Only add to order if it's a new key + let mut order = self.order.write().expect("NodeMap order write lock"); + if self.map.contains_key(&node_id) { + self.map.insert(node_id, value) + } else { + order.push(node_id); + self.map.insert(node_id, value) } - self.map.insert(node_id, value) } - // The get method to retrieve a reference to a value by its NodeId pub fn get(&self, node_id: &NodeId) -> Option> { self.map.get(node_id) } - - // The unwrap_or method to get a value or return a default if not found - pub fn unwrap_or(&self, node_id: &NodeId, default: T) -> T - where - T: Clone, - { - self.map.get(node_id).map(|r| r.clone()).unwrap_or(default) - } - - // Method to retrieve the keys in insertion order - pub fn keys_in_order(&self) -> Vec { - let order = self.order.lock().unwrap(); // Acquire a lock to read the order - order.clone() // Return a cloned version of the order - } } #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] diff --git a/partiql-eval/Cargo.toml b/partiql-eval/Cargo.toml index 4384d6c4..1ed152c4 100644 --- a/partiql-eval/Cargo.toml +++ b/partiql-eval/Cargo.toml @@ -41,6 +41,9 @@ delegate = "0.12" [dev-dependencies] criterion = "0.4" +[features] +default = [] + [[bench]] name = "bench_eval" harness = false diff --git a/partiql-eval/src/eval/mod.rs b/partiql-eval/src/eval/mod.rs index 5cfe7dce..d693427d 100644 --- a/partiql-eval/src/eval/mod.rs +++ b/partiql-eval/src/eval/mod.rs @@ -146,7 +146,6 @@ pub type EvalResult = Result; /// Represents result of 
evaluation as an evaluated entity. #[non_exhaustive] #[derive(Debug)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Evaluated { pub result: Value, } diff --git a/partiql-parser/Cargo.toml b/partiql-parser/Cargo.toml index fa666915..74dbce68 100644 --- a/partiql-parser/Cargo.toml +++ b/partiql-parser/Cargo.toml @@ -36,6 +36,7 @@ bigdecimal = "~0.2.0" rust_decimal = { version = "1.25.0", default-features = false, features = ["std"] } bitflags = "2" +dashmap = "6.0.1" lalrpop-util = "0.20" logos = "0.12" diff --git a/partiql-parser/src/parse/partiql.lalrpop b/partiql-parser/src/parse/partiql.lalrpop index 2eb38dad..c65227b0 100644 --- a/partiql-parser/src/parse/partiql.lalrpop +++ b/partiql-parser/src/parse/partiql.lalrpop @@ -287,10 +287,10 @@ FromClause: ast::AstNode = { ast::FromSource::Join(node) => node.id, }; - // let start = state.locations.get(&start_id).unwrap_or(&total).start.0.clone(); - let start = state.locations.unwrap_or(&start_id, total.clone()).start.0.clone(); - // let end = state.locations.get(&end_id).unwrap_or(&total).end.0.clone(); - let end = state.locations.unwrap_or(&end_id, total.clone()).end.0.clone(); + let start = state.locations.get(&start_id).map_or(total.start.0.clone(), |v| v.start.0.clone()); + // let start = state.locations.get_or(&start_id, &total).start.0.clone(); + let end = state.locations.get(&end_id).map_or(total.end.0.clone(), |v| v.end.0.clone()); + // let end = state.locations.get_or(&end_id, &total).end.0.clone(); let range = start..end; let join = state.node(ast::Join { kind: ast::JoinKind::Cross, diff --git a/partiql-types/src/lib.rs b/partiql-types/src/lib.rs index 92961e1e..f251795f 100644 --- a/partiql-types/src/lib.rs +++ b/partiql-types/src/lib.rs @@ -5,7 +5,7 @@ use derivative::Derivative; use indexmap::IndexSet; use itertools::Itertools; use miette::Diagnostic; -use partiql_common::node::{AutoNodeIdGenerator, NodeId, NodeIdGenerator, NodeMap}; +use 
partiql_common::node::{AutoNodeIdGenerator, NodeId, NodeIdGenerator}; use std::collections::HashMap; use std::fmt::{Debug, Display, Formatter}; use std::hash::{Hash, Hasher};