From ded2c9510bf839702db804fa3c8b6e1fffeccba2 Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Sun, 24 Sep 2023 19:13:22 +0800 Subject: [PATCH] fix: check or in `c1 > c2 or c1 > 1` --- src/expression/simplify.rs | 48 +++++++++++++++++----------- src/optimizer/rule/simplification.rs | 31 ++++++++++++++++++ 2 files changed, 60 insertions(+), 19 deletions(-) diff --git a/src/expression/simplify.rs b/src/expression/simplify.rs index c5efef39..a4dc9dd1 100644 --- a/src/expression/simplify.rs +++ b/src/expression/simplify.rs @@ -286,23 +286,11 @@ impl ScalarExpression { } } - fn unpack_col(&self, is_binary_then_return: bool) -> Option { + fn unpack_col(&self) -> Option { match self { ScalarExpression::ColumnRef(col) => Some(col.clone()), - ScalarExpression::Alias { expr, .. } => expr.unpack_col(is_binary_then_return), - ScalarExpression::Binary { left_expr, right_expr, .. } => { - if is_binary_then_return { - return None; - } - - match (left_expr.unpack_col(is_binary_then_return), - right_expr.unpack_col(is_binary_then_return)) - { - (Some(col), None) | (None, Some(col)) => Some(col), - _ => None - } - } - ScalarExpression::Unary { expr, .. } => expr.unpack_col(is_binary_then_return), + ScalarExpression::Alias { expr, .. } => expr.unpack_col(), + ScalarExpression::Unary { expr, .. 
} => expr.unpack_col(), _ => None } } @@ -323,7 +311,7 @@ impl ScalarExpression { if matches!(op, BinaryOperator::Plus | BinaryOperator::Divide | BinaryOperator::Minus | BinaryOperator::Multiply) { - match (left_expr.unpack_col(true), right_expr.unpack_col(true)) { + match (left_expr.unpack_col(), right_expr.unpack_col()) { (Some(_), Some(_)) => (), (Some(col), None) => { fix_option.replace(Replace::Binary(ReplaceBinary{ @@ -539,19 +527,20 @@ impl ScalarExpression { }, (None, None) => { if let (Some(col), Some(val)) = - (left_expr.unpack_col(false), right_expr.unpack_val()) + (left_expr.unpack_col(), right_expr.unpack_val()) { return Ok(Self::new_binary(col_id, *op, col, val, false)); } if let (Some(val), Some(col)) = - (left_expr.unpack_val(), right_expr.unpack_col(false)) + (left_expr.unpack_val(), right_expr.unpack_col()) { return Ok(Self::new_binary(col_id, *op, col, val, true)); } return Ok(None); } - (Some(binary), None) | (None, Some(binary)) => return Ok(Some(binary)), + (Some(binary), None) => Ok(Self::check_or(col_id, right_expr, op, binary)), + (None, Some(binary)) => Ok(Self::check_or(col_id, left_expr, op, binary)), } }, ScalarExpression::Alias { expr, .. 
} => expr.convert_binary(col_id), @@ -562,6 +551,27 @@ impl ScalarExpression { } } + /// Checks for predicates such as `c1 > c2 or c1 > 1`: + /// in that case it makes no sense to extract only `c1 > 1`. + fn check_or( + col_id: &ColumnId, + right_expr: &Box, + op: &BinaryOperator, + binary: ConstantBinary + ) -> Option { + let check_func = |expr: &ScalarExpression, col_id: &ColumnId| { + expr.referenced_columns() + .iter() + .find(|col| col.id == Some(*col_id)) + .is_some() + }; + if matches!(op, BinaryOperator::Or) && check_func(right_expr, col_id) { + return None + } + + Some(binary) + } + fn new_binary(col_id: &ColumnId, mut op: BinaryOperator, col: ColumnRef, val: ValueRef, is_flip: bool) -> Option { if col.id.unwrap() != *col_id { return None; diff --git a/src/optimizer/rule/simplification.rs b/src/optimizer/rule/simplification.rs index edff14f0..796668b7 100644 --- a/src/optimizer/rule/simplification.rs +++ b/src/optimizer/rule/simplification.rs @@ -255,4 +255,35 @@ mod test { Ok(()) } + + #[tokio::test] + async fn test_simplify_filter_multiple_column_in_or() -> Result<(), DatabaseError> { + // c1 > c2 or c1 > 1 => no constant range can be extracted for c1 + let plan_1 = select_sql_run("select * from t1 where c1 > c2 or c1 > 1").await?; + + let op = |plan: LogicalPlan, expr: &str| -> Result, DatabaseError> { + let best_plan = HepOptimizer::new(plan.clone()) + .batch( + "test_simplify_filter".to_string(), + HepBatchStrategy::once_topdown(), + vec![RuleImpl::SimplifyFilter] + ) + .find_best()?; + if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator { + println!("{expr}: {:#?}", filter_op); + + Ok(Some(filter_op)) + } else { + Ok(None) + } + }; + + let op_1 = op(plan_1, "c1 > c2 or c1 > 1")?.unwrap(); + + let cb_1_c1 = op_1.predicate.convert_binary(&0).unwrap(); + println!("op_1 => c1: {:#?}", cb_1_c1); + assert_eq!(cb_1_c1, None); + + Ok(()) + } } \ No newline at end of file