From f94d5259c90bc5c8cd921b30975b850b0ef7c324 Mon Sep 17 00:00:00 2001 From: Josh Pschorr Date: Tue, 14 Jan 2025 15:37:53 -0800 Subject: [PATCH] Refactor equality as per spec. `eqg` and add testing equality ops --- partiql-conformance-tests/tests/mod.rs | 4 +- partiql-conformance-tests/tests/test_value.rs | 29 +++++++-- partiql-eval/src/eval/expr/operators.rs | 4 +- partiql-value/src/bag.rs | 36 +++++++++-- partiql-value/src/comparison.rs | 64 ++++++++++++++----- partiql-value/src/lib.rs | 8 +-- partiql-value/src/list.rs | 28 ++++++-- partiql-value/src/sort.rs | 7 +- partiql-value/src/tuple.rs | 32 +++++++--- partiql-value/src/value.rs | 25 ++++++-- 10 files changed, 182 insertions(+), 55 deletions(-) diff --git a/partiql-conformance-tests/tests/mod.rs b/partiql-conformance-tests/tests/mod.rs index 26fe2826..c731070f 100644 --- a/partiql-conformance-tests/tests/mod.rs +++ b/partiql-conformance-tests/tests/mod.rs @@ -210,7 +210,9 @@ pub(crate) fn pass_eval( expected: &TestValue, ) { match eval(statement, mode, env) { - Ok(v) => assert_eq!(v.result, expected.value), + Ok(v) => { + assert_eq!(&TestValue::from(v), expected) + }, Err(TestError::Parse(_)) => { panic!("When evaluating (mode = {mode:#?}) `{statement}`, unexpected parse error") } diff --git a/partiql-conformance-tests/tests/test_value.rs b/partiql-conformance-tests/tests/test_value.rs index f46d46cf..8884dab6 100644 --- a/partiql-conformance-tests/tests/test_value.rs +++ b/partiql-conformance-tests/tests/test_value.rs @@ -1,18 +1,39 @@ -use partiql_value::Value; +use partiql_eval::eval::Evaluated; +use partiql_value::{EqualityValue, NullableEq, Value}; use partiql_extension_ion::decode::{IonDecoderBuilder, IonDecoderConfig}; use partiql_extension_ion::Encoding; #[allow(dead_code)] +#[derive(Debug, Ord, PartialOrd)] pub(crate) struct TestValue { pub value: Value, } +impl Eq for TestValue {} + +impl PartialEq for TestValue { + fn eq(&self, other: &Self) -> bool { + let wrap_value = EqualityValue::<'_, true, true, Value>; + NullableEq::eq(&wrap_value(&self.value), &wrap_value(&other.value)) == Value::Boolean(true) + } +} + +impl From for TestValue { + fn from(value: Value) -> Self { + TestValue { value } + } +} + +impl From for TestValue { + fn from(value: Evaluated) -> Self { + value.result.into() + } +} + impl From<&str> for TestValue { fn from(contents: &str) -> Self { - TestValue { - value: parse_test_value_str(contents), - } + parse_test_value_str(contents).into() } } diff --git a/partiql-eval/src/eval/expr/operators.rs b/partiql-eval/src/eval/expr/operators.rs index 2737a753..0639aaa0 100644 --- a/partiql-eval/src/eval/expr/operators.rs +++ b/partiql-eval/src/eval/expr/operators.rs @@ -249,11 +249,11 @@ impl BindEvalExpr for EvalOpBinary { EvalOpBinary::And => logical!(AndCheck, partiql_value::BinaryAnd::and), EvalOpBinary::Or => logical!(OrCheck, partiql_value::BinaryOr::or), EvalOpBinary::Eq => equality!(|lhs, rhs| { - let wrap = EqualityValue::; + let wrap = EqualityValue::; NullableEq::eq(&wrap(lhs), &wrap(rhs)) }), EvalOpBinary::Neq => equality!(|lhs, rhs| { - let wrap = EqualityValue::; + let wrap = EqualityValue::; NullableEq::neq(&wrap(lhs), &wrap(rhs)) }), EvalOpBinary::Gt => comparison!(NullableOrd::gt), diff --git a/partiql-value/src/bag.rs b/partiql-value/src/bag.rs index 9fc7ca68..259aa20f 100644 --- a/partiql-value/src/bag.rs +++ b/partiql-value/src/bag.rs @@ -178,16 +178,38 @@ impl Debug for Bag { impl PartialEq for Bag { fn eq(&self, other: &Self) -> bool { - if self.len() != other.len() { - return false; + let wrap = EqualityValue::; + NullableEq::eq(&wrap(self), &wrap(other)) == Value::Boolean(true) + } +} + +impl NullableEq + for EqualityValue<'_, NULLS_EQUAL, NAN_EQUAL, Bag> +{ + #[inline(always)] + fn eq(&self, other: &Self) -> Value { + let ord_wrap = NullSortedValue::<'_, false, _>; + let (l, r) = (self.0, other.0); + if l.len() != r.len() { + return Value::Boolean(false); } - for (v1, v2) in self.0.iter().sorted().zip(other.0.iter().sorted()) { - let wrap = EqualityValue::; - if NullableEq::eq(&wrap(v1), &wrap(v2)) != Value::Boolean(true) { - return false; + + let li = l.iter().map(ord_wrap).sorted().map(|nsv| nsv.0); + let ri = r.iter().map(ord_wrap).sorted().map(|nsv| nsv.0); + + for (v1, v2) in li.zip(ri) { + let wrap = EqualityValue::<{ NULLS_EQUAL }, { NAN_EQUAL }, Value>; + if NullableEq::eqg(&wrap(v1), &wrap(v2)) != Value::Boolean(true) { + return Value::Boolean(false); } } - true + Value::Boolean(true) + } + + #[inline(always)] + fn eqg(&self, rhs: &Self) -> Value { + let wrap = EqualityValue::<'_, true, { NAN_EQUAL }, _>; + NullableEq::eq(&wrap(self.0), &wrap(rhs.0)) } } diff --git a/partiql-value/src/comparison.rs b/partiql-value/src/comparison.rs index d37b5ae4..0640140e 100644 --- a/partiql-value/src/comparison.rs +++ b/partiql-value/src/comparison.rs @@ -1,5 +1,5 @@ -use crate::util; use crate::Value; +use crate::{util, Bag, List, Tuple}; pub trait Comparable { fn is_comparable_to(&self, rhs: &Self) -> bool; @@ -16,6 +16,7 @@ impl Comparable for Value { | (Value::Boolean(_), Value::Boolean(_)) | (Value::String(_), Value::String(_)) | (Value::Blob(_), Value::Blob(_)) + | (Value::DateTime(_), Value::DateTime(_)) | (Value::List(_), Value::List(_)) | (Value::Bag(_), Value::Bag(_)) | (Value::Tuple(_), Value::Tuple(_)) @@ -31,19 +32,43 @@ impl Comparable for Value { // `Value` `eq` and `neq` with Missing and Null propagation pub trait NullableEq { - type Output; - fn eq(&self, rhs: &Self) -> Self::Output; - fn neq(&self, rhs: &Self) -> Self::Output; + fn eq(&self, rhs: &Self) -> Value; + + fn neq(&self, rhs: &Self) -> Value { + let eq_result = NullableEq::eq(self, rhs); + match eq_result { + Value::Boolean(_) | Value::Null => !eq_result, + _ => Value::Missing, + } + } + + /// `PartiQL's `eqg` is used to compare the internals of Lists, Bags, and Tuples. + /// + /// > The eqg, unlike the =, returns true when a NULL is compared to a NULL or a MISSING + /// > to a MISSING + fn eqg(&self, rhs: &Self) -> Value; + + fn neqg(&self, rhs: &Self) -> Value { + let eqg_result = NullableEq::eqg(self, rhs); + match eqg_result { + Value::Boolean(_) | Value::Null => !eqg_result, + _ => Value::Missing, + } + } } /// A wrapper on [`T`] that specifies if missing and null values should be equal. -#[derive(Eq, PartialEq)] -pub struct EqualityValue<'a, const NULLS_EQUAL: bool, T>(pub &'a T); - -impl NullableEq for EqualityValue<'_, GROUP_NULLS, Value> { - type Output = Value; +#[derive(Eq, PartialEq, Debug)] +pub struct EqualityValue<'a, const NULLS_EQUAL: bool, const NAN_EQUAL: bool, T>(pub &'a T); - fn eq(&self, rhs: &Self) -> Self::Output { +impl NullableEq + for EqualityValue<'_, GROUP_NULLS, NAN_EQUAL, Value> +{ + #[inline(always)] + fn eq(&self, rhs: &Self) -> Value { + let wrap_list = EqualityValue::<'_, { GROUP_NULLS }, { NAN_EQUAL }, List>; + let wrap_bag = EqualityValue::<'_, { GROUP_NULLS }, { NAN_EQUAL }, Bag>; + let wrap_tuple = EqualityValue::<'_, { GROUP_NULLS }, { NAN_EQUAL }, Tuple>; if GROUP_NULLS { if let (Value::Missing | Value::Null, Value::Missing | Value::Null) = (self.0, rhs.0) { return Value::Boolean(true); @@ -73,16 +98,23 @@ impl NullableEq for EqualityValue<'_, GROUP_NULLS, Valu (Value::Decimal(_), Value::Real(_)) => { Value::from(self.0 == &util::coerce_int_or_real_to_decimal(rhs.0)) } + (Value::Real(l), Value::Real(r)) => { + if NAN_EQUAL && l.is_nan() && r.is_nan() { + return Value::Boolean(true); + } + Value::from(l == r) + } + (Value::List(l), Value::List(r)) => NullableEq::eq(&wrap_list(l), &wrap_list(r)), + (Value::Bag(l), Value::Bag(r)) => NullableEq::eq(&wrap_bag(l), &wrap_bag(r)), + (Value::Tuple(l), Value::Tuple(r)) => NullableEq::eq(&wrap_tuple(l), &wrap_tuple(r)), (_, _) => Value::from(self.0 == rhs.0), } } - fn neq(&self, rhs: &Self) -> Self::Output { - let eq_result = NullableEq::eq(self, rhs); - match eq_result { - Value::Boolean(_) | Value::Null => !eq_result, - _ => Value::Missing, - } + #[inline(always)] + fn eqg(&self, rhs: &Self) -> Value { + let wrap = EqualityValue::<'_, true, { NAN_EQUAL }, _>; + NullableEq::eq(&wrap(self.0), &wrap(rhs.0)) } } diff --git a/partiql-value/src/lib.rs b/partiql-value/src/lib.rs index 677a6f14..990e963c 100644 --- a/partiql-value/src/lib.rs +++ b/partiql-value/src/lib.rs @@ -88,7 +88,7 @@ mod tests { #[test] fn iterators() { - let bag: Bag = [1, 10, 3, 4].iter().collect(); + let bag: Bag = [1, 10, 3, 4].into_iter().collect(); assert_eq!(bag.len(), 4); let max = bag .iter() @@ -96,7 +96,7 @@ mod tests { assert_eq!(max, Value::Integer(10)); let _bref = Value::from(bag).as_bag_ref(); - let list: List = [1, 2, 3, -4].iter().collect(); + let list: List = [1, 2, 3, -4].into_iter().collect(); assert_eq!(list.len(), 4); let max = list .iter() @@ -445,14 +445,14 @@ mod tests { // tests fn nullable_eq(lhs: Value, rhs: Value) -> Value { - let wrap = EqualityValue::; + let wrap = EqualityValue::; let lhs = wrap(&lhs); let rhs = wrap(&rhs); NullableEq::eq(&lhs, &rhs) } fn nullable_neq(lhs: Value, rhs: Value) -> Value { - let wrap = EqualityValue::; + let wrap = EqualityValue::; let lhs = wrap(&lhs); let rhs = wrap(&rhs); NullableEq::neq(&lhs, &rhs) diff --git a/partiql-value/src/list.rs b/partiql-value/src/list.rs index 224b16e6..f4c4ef62 100644 --- a/partiql-value/src/list.rs +++ b/partiql-value/src/list.rs @@ -175,16 +175,32 @@ impl Debug for List { impl PartialEq for List { fn eq(&self, other: &Self) -> bool { - if self.len() != other.len() { - return false; + let wrap = EqualityValue::; + NullableEq::eq(&wrap(self), &wrap(other)) == Value::Boolean(true) + } +} + +impl NullableEq + for EqualityValue<'_, NULLS_EQUAL, NAN_EQUAL, List> +{ + #[inline(always)] + fn eq(&self, other: &Self) -> Value { + if self.0.len() != other.0.len() { + return Value::Boolean(false); } for (v1, v2) in self.0.iter().zip(other.0.iter()) { - let wrap = EqualityValue::; - if NullableEq::eq(&wrap(v1), &wrap(v2)) != Value::Boolean(true) { - return false; + let wrap = EqualityValue::<{ NULLS_EQUAL }, { NAN_EQUAL }, Value>; + if NullableEq::eqg(&wrap(v1), &wrap(v2)) != Value::Boolean(true) { + return Value::Boolean(false); } } - true + Value::Boolean(true) + } + + #[inline(always)] + fn eqg(&self, rhs: &Self) -> Value { + let wrap = EqualityValue::<'_, true, { NAN_EQUAL }, _>; + NullableEq::eq(&wrap(self.0), &wrap(rhs.0)) } } diff --git a/partiql-value/src/sort.rs b/partiql-value/src/sort.rs index 96faa0cf..6499937b 100644 --- a/partiql-value/src/sort.rs +++ b/partiql-value/src/sort.rs @@ -18,9 +18,10 @@ where impl Ord for NullSortedValue<'_, NULLS_FIRST, Value> { fn cmp(&self, other: &Self) -> Ordering { - let wrap_list = NullSortedValue::<{ NULLS_FIRST }, List>; - let wrap_tuple = NullSortedValue::<{ NULLS_FIRST }, Tuple>; - let wrap_bag = NullSortedValue::<{ NULLS_FIRST }, Bag>; + let wrap_value = NullSortedValue::<'_, { NULLS_FIRST }, Value>; + let wrap_list = NullSortedValue::<'_, { NULLS_FIRST }, List>; + let wrap_tuple = NullSortedValue::<'_, { NULLS_FIRST }, Tuple>; + let wrap_bag = NullSortedValue::<'_, { NULLS_FIRST }, Bag>; let null_cond = |order: Ordering| { if NULLS_FIRST { order diff --git a/partiql-value/src/tuple.rs b/partiql-value/src/tuple.rs index 337784be..aa27ad24 100644 --- a/partiql-value/src/tuple.rs +++ b/partiql-value/src/tuple.rs @@ -213,19 +213,35 @@ impl Iterator for Tuple { impl PartialEq for Tuple { fn eq(&self, other: &Self) -> bool { - if self.vals.len() != other.vals.len() { - return false; + let wrap = EqualityValue::; + NullableEq::eq(&wrap(self), &wrap(other)) == Value::Boolean(true) + } +} + +impl NullableEq + for EqualityValue<'_, NULLS_EQUAL, NAN_EQUAL, Tuple> +{ + #[inline(always)] + fn eq(&self, other: &Self) -> Value { + if self.0.vals.len() != other.0.vals.len() { + return Value::Boolean(false); } - for ((ls, lv), (rs, rv)) in self.pairs().sorted().zip(other.pairs().sorted()) { + for ((ls, lv), (rs, rv)) in self.0.pairs().sorted().zip(other.0.pairs().sorted()) { if ls != rs { - return false; + return Value::Boolean(false); } - let wrap = EqualityValue::; - if NullableEq::eq(&wrap(lv), &wrap(rv)) != Value::Boolean(true) { - return false; + let wrap = EqualityValue::<{ NULLS_EQUAL }, { NAN_EQUAL }, Value>; + if NullableEq::eqg(&wrap(lv), &wrap(rv)) != Value::Boolean(true) { + return Value::Boolean(false); } } - true + Value::Boolean(true) + } + + #[inline(always)] + fn eqg(&self, rhs: &Self) -> Value { + let wrap = EqualityValue::<'_, true, { NAN_EQUAL }, _>; + NullableEq::eq(&wrap(self.0), &wrap(rhs.0)) } } diff --git a/partiql-value/src/value.rs b/partiql-value/src/value.rs index 8b020bb3..1940230e 100644 --- a/partiql-value/src/value.rs +++ b/partiql-value/src/value.rs @@ -378,6 +378,13 @@ impl From<&str> for Value { } } +impl From for Value { + #[inline] + fn from(n: i128) -> Self { + Value::from(RustDecimal::from(n)) + } +} + impl From for Value { #[inline] fn from(n: i64) -> Self { @@ -409,8 +416,11 @@ impl From for Value { impl From for Value { #[inline] fn from(n: usize) -> Self { - // TODO overflow to bigint/decimal - Value::Integer(n as i64) + if n > i64::MAX as usize { + Value::from(RustDecimal::from(n)) + } else { + Value::Integer(n as i64) + } } } @@ -445,14 +455,21 @@ impl From for Value { impl From for Value { #[inline] fn from(n: u128) -> Self { - (n as usize).into() + Value::from(RustDecimal::from(n)) } } impl From for Value { #[inline] fn from(f: f64) -> Self { - Value::Real(OrderedFloat(f)) + Value::from(OrderedFloat(f)) + } +} + +impl From> for Value { + #[inline] + fn from(f: OrderedFloat) -> Self { + Value::Real(f) } }