From f6897d72ab8af1c30ffb29f5aebf327d611d0a03 Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Tue, 30 Jan 2024 18:08:07 +0800 Subject: [PATCH] fix: `is [not] null/[not] in` can be overridden to extract ConstantBinary normally --- src/expression/simplify.rs | 49 +++++++++++- src/optimizer/rule/implementation/dql/scan.rs | 1 + .../rule/normalization/simplification.rs | 75 +++++++++++++++++++ src/types/value.rs | 8 ++ 4 files changed, 131 insertions(+), 2 deletions(-) diff --git a/src/expression/simplify.rs b/src/expression/simplify.rs index 07008ca7..601e2e45 100644 --- a/src/expression/simplify.rs +++ b/src/expression/simplify.rs @@ -2,7 +2,7 @@ use crate::catalog::ColumnRef; use crate::expression::value_compute::{binary_op, unary_op}; use crate::expression::{BinaryOperator, ScalarExpression, UnaryOperator}; use crate::types::errors::TypeError; -use crate::types::value::{DataValue, ValueRef}; +use crate::types::value::{DataValue, ValueRef, FALSE_VALUE, NULL_VALUE, TRUE_VALUE}; use crate::types::{ColumnId, LogicalType}; use ahash::RandomState; use itertools::Itertools; @@ -611,6 +611,39 @@ impl ScalarExpression { })); } } + ScalarExpression::In { + expr, + negated, + args, + } => { + let (op_1, op_2, value) = if *negated { + ( + BinaryOperator::NotEq, + BinaryOperator::And, + TRUE_VALUE.clone(), + ) + } else { + (BinaryOperator::Eq, BinaryOperator::Or, FALSE_VALUE.clone()) + }; + + let mut new_expr = ScalarExpression::Constant(value); + + for arg in args.drain(..) { + new_expr = ScalarExpression::Binary { + op: op_2.clone(), + left_expr: Box::new(ScalarExpression::Binary { + op: op_1.clone(), + left_expr: expr.clone(), + right_expr: Box::new(arg), + ty: LogicalType::Boolean, + }), + right_expr: Box::new(new_expr), + ty: LogicalType::Boolean, + } + } + + let _ = mem::replace(self, new_expr); + } _ => (), } @@ -834,9 +867,21 @@ impl ScalarExpression { } ScalarExpression::Alias { expr, .. } | ScalarExpression::TypeCast { expr, .. } - | ScalarExpression::IsNull { expr, .. } | ScalarExpression::Unary { expr, .. } | ScalarExpression::In { expr, .. } => expr.convert_binary(col_id), + ScalarExpression::IsNull { expr, .. } => match expr.as_ref() { + ScalarExpression::ColumnRef(column) => Ok( + (column.id() == column.id()).then(|| ConstantBinary::Eq(NULL_VALUE.clone())) + ), + ScalarExpression::Constant(_) + | ScalarExpression::Alias { .. } + | ScalarExpression::TypeCast { .. } + | ScalarExpression::IsNull { .. } + | ScalarExpression::Unary { .. } + | ScalarExpression::Binary { .. } + | ScalarExpression::AggCall { .. } + | ScalarExpression::In { .. } => expr.convert_binary(col_id), + }, ScalarExpression::Constant(_) | ScalarExpression::ColumnRef(_) | ScalarExpression::AggCall { .. } => Ok(None), diff --git a/src/optimizer/rule/implementation/dql/scan.rs b/src/optimizer/rule/implementation/dql/scan.rs index 00f1631d..7ab96b24 100644 --- a/src/optimizer/rule/implementation/dql/scan.rs +++ b/src/optimizer/rule/implementation/dql/scan.rs @@ -85,6 +85,7 @@ impl ImplementationRule for IndexScanImplementation { cost = Some(histogram.collect_count(binaries) * 2); } } + assert!(!matches!(cost, Some(0))); group_expr.append_expr(Expression { op: PhysicalOption::IndexScan(index_info.clone()), diff --git a/src/optimizer/rule/normalization/simplification.rs b/src/optimizer/rule/normalization/simplification.rs index 19e92f3e..8e008ad1 100644 --- a/src/optimizer/rule/normalization/simplification.rs +++ b/src/optimizer/rule/normalization/simplification.rs @@ -527,6 +527,81 @@ mod test { let cb_1_c1 = op_1.predicate.convert_binary(&0).unwrap(); println!("op_1 => c1: {:#?}", cb_1_c1); + assert_eq!(cb_1_c1, Some(ConstantBinary::Eq(Arc::new(DataValue::Null)))); + + Ok(()) + } + + #[tokio::test] + async fn test_simplify_filter_column_in() -> Result<(), DatabaseError> { + let plan_1 = select_sql_run("select * from t1 where c1 in (1, 2, 3)").await?; + + let op = |plan: LogicalPlan, expr: &str| -> Result, DatabaseError> { + let best_plan = HepOptimizer::new(plan.clone()) + .batch( + "test_simplify_filter".to_string(), + HepBatchStrategy::once_topdown(), + vec![NormalizationRuleImpl::SimplifyFilter], + ) + .find_best::(None)?; + if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator { + println!("{expr}: {:#?}", filter_op); + + Ok(Some(filter_op)) + } else { + Ok(None) + } + }; + + let op_1 = op(plan_1, "c1 in (1, 2, 3)")?.unwrap(); + + let cb_1_c1 = op_1.predicate.convert_binary(&0).unwrap(); + println!("op_1 => c1: {:#?}", cb_1_c1); + assert_eq!( + cb_1_c1, + Some(ConstantBinary::Or(vec![ + ConstantBinary::Eq(Arc::new(DataValue::Int32(Some(2)))), + ConstantBinary::Eq(Arc::new(DataValue::Int32(Some(1)))), + ConstantBinary::Eq(Arc::new(DataValue::Int32(Some(3)))), + ])) + ); + + Ok(()) + } + + #[tokio::test] + async fn test_simplify_filter_column_not_in() -> Result<(), DatabaseError> { + let plan_1 = select_sql_run("select * from t1 where c1 not in (1, 2, 3)").await?; + + let op = |plan: LogicalPlan, expr: &str| -> Result, DatabaseError> { + let best_plan = HepOptimizer::new(plan.clone()) + .batch( + "test_simplify_filter".to_string(), + HepBatchStrategy::once_topdown(), + vec![NormalizationRuleImpl::SimplifyFilter], + ) + .find_best::(None)?; + if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator { + println!("{expr}: {:#?}", filter_op); + + Ok(Some(filter_op)) + } else { + Ok(None) + } + }; + + let op_1 = op(plan_1, "c1 not in (1, 2, 3)")?.unwrap(); + + let cb_1_c1 = op_1.predicate.convert_binary(&0).unwrap(); + println!("op_1 => c1: {:#?}", cb_1_c1); + assert_eq!( + cb_1_c1, + Some(ConstantBinary::And(vec![ + ConstantBinary::NotEq(Arc::new(DataValue::Int32(Some(2)))), + ConstantBinary::NotEq(Arc::new(DataValue::Int32(Some(1)))), + ConstantBinary::NotEq(Arc::new(DataValue::Int32(Some(3)))), + ])) + ); Ok(()) } diff --git a/src/types/value.rs b/src/types/value.rs index 720bdaf7..0630b8b9 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -18,6 +18,9 @@ use serde::{Deserialize, Serialize}; use super::LogicalType; lazy_static! { + pub static ref NULL_VALUE: ValueRef = Arc::new(DataValue::Null); + pub static ref FALSE_VALUE: ValueRef = Arc::new(DataValue::Boolean(Some(false))); + pub static ref TRUE_VALUE: ValueRef = Arc::new(DataValue::Boolean(Some(true))); static ref UNIX_DATETIME: NaiveDateTime = NaiveDateTime::from_timestamp_opt(0, 0).unwrap(); } @@ -86,6 +89,11 @@ generate_get_option!(DataValue, impl PartialEq for DataValue { fn eq(&self, other: &Self) -> bool { use DataValue::*; + + if self.is_null() && other.is_null() { + return true; + } + match (self, other) { (Boolean(v1), Boolean(v2)) => v1.eq(v2), (Boolean(_), _) => false,