Skip to content

Commit

Permalink
fix: check or in c1 > c2 or c1 > 1
Browse files Browse the repository at this point in the history
  • Loading branch information
KKould committed Sep 24, 2023
1 parent 654e75e commit ded2c95
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 19 deletions.
48 changes: 29 additions & 19 deletions src/expression/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,23 +286,11 @@ impl ScalarExpression {
}
}

fn unpack_col(&self, is_binary_then_return: bool) -> Option<ColumnRef> {
fn unpack_col(&self) -> Option<ColumnRef> {
match self {
ScalarExpression::ColumnRef(col) => Some(col.clone()),
ScalarExpression::Alias { expr, .. } => expr.unpack_col(is_binary_then_return),
ScalarExpression::Binary { left_expr, right_expr, .. } => {
if is_binary_then_return {
return None;
}

match (left_expr.unpack_col(is_binary_then_return),
right_expr.unpack_col(is_binary_then_return))
{
(Some(col), None) | (None, Some(col)) => Some(col),
_ => None
}
}
ScalarExpression::Unary { expr, .. } => expr.unpack_col(is_binary_then_return),
ScalarExpression::Alias { expr, .. } => expr.unpack_col(),
ScalarExpression::Unary { expr, .. } => expr.unpack_col(),
_ => None
}
}
Expand All @@ -323,7 +311,7 @@ impl ScalarExpression {
if matches!(op, BinaryOperator::Plus | BinaryOperator::Divide
| BinaryOperator::Minus | BinaryOperator::Multiply)
{
match (left_expr.unpack_col(true), right_expr.unpack_col(true)) {
match (left_expr.unpack_col(), right_expr.unpack_col()) {
(Some(_), Some(_)) => (),
(Some(col), None) => {
fix_option.replace(Replace::Binary(ReplaceBinary{
Expand Down Expand Up @@ -539,19 +527,20 @@ impl ScalarExpression {
},
(None, None) => {
if let (Some(col), Some(val)) =
(left_expr.unpack_col(false), right_expr.unpack_val())
(left_expr.unpack_col(), right_expr.unpack_val())
{
return Ok(Self::new_binary(col_id, *op, col, val, false));
}
if let (Some(val), Some(col)) =
(left_expr.unpack_val(), right_expr.unpack_col(false))
(left_expr.unpack_val(), right_expr.unpack_col())
{
return Ok(Self::new_binary(col_id, *op, col, val, true));
}

return Ok(None);
}
(Some(binary), None) | (None, Some(binary)) => return Ok(Some(binary)),
(Some(binary), None) => Ok(Self::check_or(col_id, right_expr, op, binary)),
(None, Some(binary)) => Ok(Self::check_or(col_id, left_expr, op, binary)),
}
},
ScalarExpression::Alias { expr, .. } => expr.convert_binary(col_id),
Expand All @@ -562,6 +551,27 @@ impl ScalarExpression {
}
}

/// check if: c1 > c2 or c1 > 1
/// this case it makes no sense to just extract c1 > 1
fn check_or(
col_id: &ColumnId,
right_expr: &Box<ScalarExpression>,
op: &BinaryOperator,
binary: ConstantBinary
) -> Option<ConstantBinary> {
let check_func = |expr: &ScalarExpression, col_id: &ColumnId| {
expr.referenced_columns()
.iter()
.find(|col| col.id == Some(*col_id))
.is_some()
};
if matches!(op, BinaryOperator::Or) && check_func(right_expr, col_id) {
return None
}

Some(binary)
}

fn new_binary(col_id: &ColumnId, mut op: BinaryOperator, col: ColumnRef, val: ValueRef, is_flip: bool) -> Option<ConstantBinary> {
if col.id.unwrap() != *col_id {
return None;
Expand Down
31 changes: 31 additions & 0 deletions src/optimizer/rule/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,4 +255,35 @@ mod test {

Ok(())
}

#[tokio::test]
async fn test_simplify_filter_multiple_column_in_or() -> Result<(), DatabaseError> {
// c1 + 1 < -1 => c1 < -2
let plan_1 = select_sql_run("select * from t1 where c1 > c2 or c1 > 1").await?;

let op = |plan: LogicalPlan, expr: &str| -> Result<Option<FilterOperator>, DatabaseError> {
let best_plan = HepOptimizer::new(plan.clone())
.batch(
"test_simplify_filter".to_string(),
HepBatchStrategy::once_topdown(),
vec![RuleImpl::SimplifyFilter]
)
.find_best()?;
if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator {
println!("{expr}: {:#?}", filter_op);

Ok(Some(filter_op))
} else {
Ok(None)
}
};

let op_1 = op(plan_1, "c1 > c2 or c1 > 1")?.unwrap();

let cb_1_c1 = op_1.predicate.convert_binary(&0).unwrap();
println!("op_1 => c1: {:#?}", cb_1_c1);
assert_eq!(cb_1_c1, None);

Ok(())
}
}

0 comments on commit ded2c95

Please sign in to comment.