From 393e48f98872c696a90fce033fa584533d2326fa Mon Sep 17 00:00:00 2001 From: Will Jones Date: Sat, 18 Nov 2023 08:30:05 -0800 Subject: [PATCH] feat: support arbitrary binaryexpr simplifications (#8256) --- .../src/simplify_expressions/guarantees.rs | 73 +++++++++++-------- 1 file changed, 44 insertions(+), 29 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 5504d7d76e35..0204698571b4 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -20,7 +20,7 @@ //! [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees use datafusion_common::{tree_node::TreeNodeRewriter, DataFusionError, Result}; use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr}; -use std::collections::HashMap; +use std::{borrow::Cow, collections::HashMap}; use datafusion_physical_expr::intervals::{Interval, IntervalBound, NullableInterval}; @@ -103,37 +103,44 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - // We only support comparisons for now - if !op.is_comparison_operator() { - return Ok(expr); - }; - - // Check if this is a comparison between a column and literal - let (col, op, value) = match (left.as_ref(), right.as_ref()) { - (Expr::Column(_), Expr::Literal(value)) => (left, *op, value), - (Expr::Literal(value), Expr::Column(_)) => { - // If we can swap the op, we can simplify the expression - if let Some(op) = op.swap() { - (right, op, value) + // The left or right side of expression might either have a guarantee + // or be a literal. Either way, we can resolve them to a NullableInterval. + let left_interval = self + .guarantees + .get(left.as_ref()) + .map(|interval| Cow::Borrowed(*interval)) + .or_else(|| { + if let Expr::Literal(value) = left.as_ref() { + Some(Cow::Owned(value.clone().into())) } else { - return Ok(expr); + None + } + }); + let right_interval = self + .guarantees + .get(right.as_ref()) + .map(|interval| Cow::Borrowed(*interval)) + .or_else(|| { + if let Expr::Literal(value) = right.as_ref() { + Some(Cow::Owned(value.clone().into())) + } else { + None + } + }); + + match (left_interval, right_interval) { + (Some(left_interval), Some(right_interval)) => { + let result = + left_interval.apply_operator(op, right_interval.as_ref())?; + if result.is_certainly_true() { + Ok(lit(true)) + } else if result.is_certainly_false() { + Ok(lit(false)) + } else { + Ok(expr) } } - _ => return Ok(expr), - }; - - if let Some(col_interval) = self.guarantees.get(col.as_ref()) { - let result = - col_interval.apply_operator(&op, &value.clone().into())?; - if result.is_certainly_true() { - Ok(lit(true)) - } else if result.is_certainly_false() { - Ok(lit(false)) - } else { - Ok(expr) - } - } else { - Ok(expr) + _ => Ok(expr), } } @@ -262,6 +269,13 @@ mod tests { values: Interval::make(Some(1_i32), Some(3_i32), (true, false)), }, ), + // s.y ∈ (1, 3] (not null) + ( + col("s").field("y"), + NullableInterval::NotNull { + values: Interval::make(Some(1_i32), Some(3_i32), (true, false)), + }, + ), ]; let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); @@ -269,6 +283,7 @@ mod tests { // (original_expr, expected_simplification) let simplified_cases = &[ (col("x").lt_eq(lit(1)), false), + (col("s").field("y").lt_eq(lit(1)), false), (col("x").lt_eq(lit(3)), true), (col("x").gt(lit(3)), false), (col("x").gt(lit(1)), true),