Skip to content

Commit

Permalink
fix: is [not] null/[not] in can be overridden to extract ConstantBi…
Browse files Browse the repository at this point in the history
…nary normally (#122)

* fix: `is [not] null/[not] in` can be overridden to extract ConstantBinary normally

* fix: `is not null`
  • Loading branch information
KKould authored Jan 30, 2024
1 parent 756b939 commit 2550932
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 2 deletions.
55 changes: 53 additions & 2 deletions src/expression/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::catalog::ColumnRef;
use crate::expression::value_compute::{binary_op, unary_op};
use crate::expression::{BinaryOperator, ScalarExpression, UnaryOperator};
use crate::types::errors::TypeError;
use crate::types::value::{DataValue, ValueRef};
use crate::types::value::{DataValue, ValueRef, FALSE_VALUE, NULL_VALUE, TRUE_VALUE};
use crate::types::{ColumnId, LogicalType};
use ahash::RandomState;
use itertools::Itertools;
Expand Down Expand Up @@ -611,6 +611,39 @@ impl ScalarExpression {
}));
}
}
// Rewrite `expr [NOT] IN (a, b, ...)` into a chain of boolean binary
// expressions so that later passes can treat it like any other comparison:
//   IN      => (expr = c)  OR  ((expr = b)  OR  ((expr = a)  OR  false))
//   NOT IN  => (expr != c) AND ((expr != b) AND ((expr != a) AND true))
// The constant seed is the neutral element of the joining operator, so an
// empty argument list collapses to that constant.
ScalarExpression::In {
    expr,
    negated,
    args,
} => {
    let (compare_op, join_op, seed) = match *negated {
        true => (
            BinaryOperator::NotEq,
            BinaryOperator::And,
            TRUE_VALUE.clone(),
        ),
        false => (BinaryOperator::Eq, BinaryOperator::Or, FALSE_VALUE.clone()),
    };

    // Each drained argument wraps the chain built so far, so the last
    // argument ends up as the outermost comparison.
    let rewritten = args
        .drain(..)
        .fold(ScalarExpression::Constant(seed), |chain, arg| {
            let comparison = ScalarExpression::Binary {
                op: compare_op.clone(),
                left_expr: expr.clone(),
                right_expr: Box::new(arg),
                ty: LogicalType::Boolean,
            };

            ScalarExpression::Binary {
                op: join_op.clone(),
                left_expr: Box::new(comparison),
                right_expr: Box::new(chain),
                ty: LogicalType::Boolean,
            }
        });

    // Replace the `In` node in place; the previous node is dropped.
    *self = rewritten;
}
_ => (),
}

Expand Down Expand Up @@ -834,9 +867,27 @@ impl ScalarExpression {
}
ScalarExpression::Alias { expr, .. }
| ScalarExpression::TypeCast { expr, .. }
| ScalarExpression::IsNull { expr, .. }
| ScalarExpression::Unary { expr, .. }
| ScalarExpression::In { expr, .. } => expr.convert_binary(col_id),
// `<column> IS [NOT] NULL` is extracted as an (in)equality against NULL so
// it can be handled like any other constant comparison on that column.
// Every other operand shape is delegated to the operand's own
// `convert_binary`.
ScalarExpression::IsNull { expr, negated, .. } => match expr.as_ref() {
    ScalarExpression::ColumnRef(column) => {
        // Only emit a binary when the referenced column is the one being
        // analysed. The previous check compared `column.id()` with itself,
        // which is always true and attributed the NULL test to every column.
        // NOTE(review): assumes `column.id()` yields `Option<ColumnId>` and
        // `col_id` is `&ColumnId` — confirm against the column catalog API.
        Ok((column.id() == Some(*col_id)).then(|| {
            if *negated {
                ConstantBinary::NotEq(NULL_VALUE.clone())
            } else {
                ConstantBinary::Eq(NULL_VALUE.clone())
            }
        }))
    }
    ScalarExpression::Constant(_)
    | ScalarExpression::Alias { .. }
    | ScalarExpression::TypeCast { .. }
    | ScalarExpression::IsNull { .. }
    | ScalarExpression::Unary { .. }
    | ScalarExpression::Binary { .. }
    | ScalarExpression::AggCall { .. }
    | ScalarExpression::In { .. } => expr.convert_binary(col_id),
},
ScalarExpression::Constant(_)
| ScalarExpression::ColumnRef(_)
| ScalarExpression::AggCall { .. } => Ok(None),
Expand Down
1 change: 1 addition & 0 deletions src/optimizer/rule/implementation/dql/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ impl<T: Transaction> ImplementationRule<T> for IndexScanImplementation {
cost = Some(histogram.collect_count(binaries) * 2);
}
}
assert!(!matches!(cost, Some(0)));

group_expr.append_expr(Expression {
op: PhysicalOption::IndexScan(index_info.clone()),
Expand Down
108 changes: 108 additions & 0 deletions src/optimizer/rule/normalization/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,114 @@ mod test {

let cb_1_c1 = op_1.predicate.convert_binary(&0).unwrap();
println!("op_1 => c1: {:#?}", cb_1_c1);
assert_eq!(cb_1_c1, Some(ConstantBinary::Eq(Arc::new(DataValue::Null))));

Ok(())
}

/// `c1 is not null` must be extracted for column id 0 as `NotEq(Null)`.
#[tokio::test]
async fn test_simplify_filter_column_is_not_null() -> Result<(), DatabaseError> {
    let expr = "c1 is not null";
    let plan_1 = select_sql_run("select * from t1 where c1 is not null").await?;

    // Run only the `SimplifyFilter` rule, using the `once_topdown` strategy.
    let best_plan = HepOptimizer::new(plan_1)
        .batch(
            "test_simplify_filter".to_string(),
            HepBatchStrategy::once_topdown(),
            vec![NormalizationRuleImpl::SimplifyFilter],
        )
        .find_best::<KipTransaction>(None)?;

    // The first child of the optimized plan is expected to be the filter.
    let filter = match best_plan.childrens[0].clone().operator {
        Operator::Filter(filter_op) => {
            println!("{expr}: {:#?}", filter_op);
            Some(filter_op)
        }
        _ => None,
    };
    let op_1 = filter.unwrap();

    let cb_1_c1 = op_1.predicate.convert_binary(&0).unwrap();
    println!("op_1 => c1: {:#?}", cb_1_c1);

    let expected = Some(ConstantBinary::NotEq(Arc::new(DataValue::Null)));
    assert_eq!(cb_1_c1, expected);

    Ok(())
}

/// `c1 in (1, 2, 3)` must be extracted for column id 0 as an `Or` of three
/// `Eq` binaries.
#[tokio::test]
async fn test_simplify_filter_column_in() -> Result<(), DatabaseError> {
    let expr = "c1 in (1, 2, 3)";
    let plan_1 = select_sql_run("select * from t1 where c1 in (1, 2, 3)").await?;

    // Run only the `SimplifyFilter` rule, using the `once_topdown` strategy.
    let best_plan = HepOptimizer::new(plan_1)
        .batch(
            "test_simplify_filter".to_string(),
            HepBatchStrategy::once_topdown(),
            vec![NormalizationRuleImpl::SimplifyFilter],
        )
        .find_best::<KipTransaction>(None)?;

    // The first child of the optimized plan is expected to be the filter.
    let filter = match best_plan.childrens[0].clone().operator {
        Operator::Filter(filter_op) => {
            println!("{expr}: {:#?}", filter_op);
            Some(filter_op)
        }
        _ => None,
    };
    let op_1 = filter.unwrap();

    let cb_1_c1 = op_1.predicate.convert_binary(&0).unwrap();
    println!("op_1 => c1: {:#?}", cb_1_c1);

    // The order below is the order the extraction currently produces.
    let expected = Some(ConstantBinary::Or(vec![
        ConstantBinary::Eq(Arc::new(DataValue::Int32(Some(2)))),
        ConstantBinary::Eq(Arc::new(DataValue::Int32(Some(1)))),
        ConstantBinary::Eq(Arc::new(DataValue::Int32(Some(3)))),
    ]));
    assert_eq!(cb_1_c1, expected);

    Ok(())
}

/// `c1 not in (1, 2, 3)` must be extracted for column id 0 as an `And` of
/// three `NotEq` binaries.
#[tokio::test]
async fn test_simplify_filter_column_not_in() -> Result<(), DatabaseError> {
    let expr = "c1 not in (1, 2, 3)";
    let plan_1 = select_sql_run("select * from t1 where c1 not in (1, 2, 3)").await?;

    // Run only the `SimplifyFilter` rule, using the `once_topdown` strategy.
    let best_plan = HepOptimizer::new(plan_1)
        .batch(
            "test_simplify_filter".to_string(),
            HepBatchStrategy::once_topdown(),
            vec![NormalizationRuleImpl::SimplifyFilter],
        )
        .find_best::<KipTransaction>(None)?;

    // The first child of the optimized plan is expected to be the filter.
    let filter = match best_plan.childrens[0].clone().operator {
        Operator::Filter(filter_op) => {
            println!("{expr}: {:#?}", filter_op);
            Some(filter_op)
        }
        _ => None,
    };
    let op_1 = filter.unwrap();

    let cb_1_c1 = op_1.predicate.convert_binary(&0).unwrap();
    println!("op_1 => c1: {:#?}", cb_1_c1);

    // The order below is the order the extraction currently produces.
    let expected = Some(ConstantBinary::And(vec![
        ConstantBinary::NotEq(Arc::new(DataValue::Int32(Some(2)))),
        ConstantBinary::NotEq(Arc::new(DataValue::Int32(Some(1)))),
        ConstantBinary::NotEq(Arc::new(DataValue::Int32(Some(3)))),
    ]));
    assert_eq!(cb_1_c1, expected);

    Ok(())
}
Expand Down
8 changes: 8 additions & 0 deletions src/types/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ use serde::{Deserialize, Serialize};
use super::LogicalType;

// Lazily-initialized shared values. The `ValueRef` constants are `Arc`s, so
// callers obtain their own handle with `.clone()` instead of building a new
// `DataValue` each time.
lazy_static! {
/// Shared handle to `DataValue::Null`.
pub static ref NULL_VALUE: ValueRef = Arc::new(DataValue::Null);
/// Shared handle to the boolean value `false`.
pub static ref FALSE_VALUE: ValueRef = Arc::new(DataValue::Boolean(Some(false)));
/// Shared handle to the boolean value `true`.
pub static ref TRUE_VALUE: ValueRef = Arc::new(DataValue::Boolean(Some(true)));
/// `NaiveDateTime` built from timestamp `(0, 0)`; going by the name, this is
/// the Unix epoch used as a date-time reference point.
static ref UNIX_DATETIME: NaiveDateTime = NaiveDateTime::from_timestamp_opt(0, 0).unwrap();
}

Expand Down Expand Up @@ -86,6 +89,11 @@ generate_get_option!(DataValue,
impl PartialEq for DataValue {
fn eq(&self, other: &Self) -> bool {
use DataValue::*;

if self.is_null() && other.is_null() {
return true;
}

match (self, other) {
(Boolean(v1), Boolean(v2)) => v1.eq(v2),
(Boolean(_), _) => false,
Expand Down

0 comments on commit 2550932

Please sign in to comment.