diff --git a/src/db.rs b/src/db.rs index 9c05b919..3616c7ee 100644 --- a/src/db.rs +++ b/src/db.rs @@ -54,10 +54,10 @@ impl Database { /// Limit(1) /// Project(a,b) let source_plan = binder.bind(&stmts[0])?; - // println!("source_plan plan: {:#?}", source_plan); + //println!("source_plan plan: {:#?}", source_plan); let best_plan = Self::default_optimizer(source_plan).find_best()?; - // println!("best_plan plan: {:#?}", best_plan); + //println!("best_plan plan: {:#?}", best_plan); let transaction = RefCell::new(transaction); let mut stream = build(best_plan, &transaction); @@ -78,10 +78,10 @@ impl Database { .batch( "Simplify Filter".to_string(), HepBatchStrategy::fix_point_topdown(10), - vec![RuleImpl::SimplifyFilter, RuleImpl::ConstantCalculation], + vec![RuleImpl::LikeRewrite, RuleImpl::SimplifyFilter, RuleImpl::ConstantCalculation], ) .batch( - "Predicate Pushdown".to_string(), + "Predicate Pushown".to_string(), HepBatchStrategy::fix_point_topdown(10), vec![ RuleImpl::PushPredicateThroughJoin, @@ -206,6 +206,12 @@ mod test { let _ = kipsql .run("insert into t3 (a, b) values (4, 4444), (5, 5222), (6, 1.00)") .await?; + let _ = kipsql + .run("create table t4 (a int primary key, b varchar(100))") + .await?; + let _ = kipsql + .run("insert into t4 (a, b) values (1, 'abc'), (2, 'abdc'), (3, 'abcd'), (4, 'ddabc')") + .await?; println!("show tables:"); let tuples_show_tables = kipsql.run("show tables").await?; @@ -371,6 +377,10 @@ mod test { let tuples_decimal = kipsql.run("select * from t3").await?; println!("{}", create_table(&tuples_decimal)); + println!("like rewrite:"); + let tuples_like_rewrite = kipsql.run("select * from t4 where b like 'abc%'").await?; + println!("{}", create_table(&tuples_like_rewrite)); + Ok(()) } } diff --git a/src/optimizer/rule/mod.rs b/src/optimizer/rule/mod.rs index 1c9bbbed..969b6b31 100644 --- a/src/optimizer/rule/mod.rs +++ b/src/optimizer/rule/mod.rs @@ -9,7 +9,7 @@ use crate::optimizer::rule::pushdown_limit::{ }; use crate::optimizer::rule::pushdown_predicates::PushPredicateIntoScan; use crate::optimizer::rule::pushdown_predicates::PushPredicateThroughJoin; -use crate::optimizer::rule::simplification::ConstantCalculation; +use crate::optimizer::rule::simplification::{ConstantCalculation, LikeRewrite}; use crate::optimizer::rule::simplification::SimplifyFilter; use crate::optimizer::OptimizerError; @@ -37,6 +37,7 @@ pub enum RuleImpl { // Simplification SimplifyFilter, ConstantCalculation, + LikeRewrite, } impl Rule for RuleImpl { @@ -53,6 +54,7 @@ impl Rule for RuleImpl { RuleImpl::PushPredicateIntoScan => PushPredicateIntoScan.pattern(), RuleImpl::SimplifyFilter => SimplifyFilter.pattern(), RuleImpl::ConstantCalculation => ConstantCalculation.pattern(), + RuleImpl::LikeRewrite =>LikeRewrite.pattern(), } } @@ -69,6 +71,7 @@ impl Rule for RuleImpl { RuleImpl::SimplifyFilter => SimplifyFilter.apply(node_id, graph), RuleImpl::PushPredicateIntoScan => PushPredicateIntoScan.apply(node_id, graph), RuleImpl::ConstantCalculation => ConstantCalculation.apply(node_id, graph), + RuleImpl::LikeRewrite => LikeRewrite.apply(node_id, graph), } } } diff --git a/src/optimizer/rule/simplification.rs b/src/optimizer/rule/simplification.rs index 3f004451..884d0cc7 100644 --- a/src/optimizer/rule/simplification.rs +++ b/src/optimizer/rule/simplification.rs @@ -5,7 +5,15 @@ use crate::optimizer::OptimizerError; use crate::planner::operator::join::JoinCondition; use crate::planner::operator::Operator; use lazy_static::lazy_static; +use crate::expression::{BinaryOperator, ScalarExpression}; +use crate::types::value::{DataValue, ValueRef}; lazy_static! { + static ref LIKE_REWRITE_RULE: Pattern = { + Pattern { + predicate: |op| matches!(op, Operator::Filter(_)), + children: PatternChildrenPredicate::None, + } + }; static ref CONSTANT_CALCULATION_RULE: Pattern = { Pattern { predicate: |_| true, @@ -109,6 +117,84 @@ impl Rule for SimplifyFilter { } } +pub struct LikeRewrite; + +impl Rule for LikeRewrite { + fn pattern(&self) -> &Pattern { + &LIKE_REWRITE_RULE + } + + fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> { + if let Operator::Filter(mut filter_op) = graph.operator(node_id).clone() { + // if is like expression + if let ScalarExpression::Binary { + op: BinaryOperator::Like, + left_expr, + right_expr, + ty, + } = &mut filter_op.predicate + { + // if left is column and right is constant + if let ScalarExpression::ColumnRef(_) = left_expr.as_ref() { + if let ScalarExpression::Constant(value) = right_expr.as_ref() { + match value.as_ref() { + DataValue::Utf8(val_str) => { + let mut value = val_str.clone().unwrap_or_else(|| "".to_string()); + + if value.ends_with('%') { + value.pop(); // remove '%' + if let Some(last_char) = value.clone().pop() { + if let Some(next_char) = increment_char(last_char) { + let mut new_value = value.clone(); + new_value.pop(); + new_value.push(next_char); + + let new_expr = ScalarExpression::Binary { + op: BinaryOperator::And, + left_expr: Box::new(ScalarExpression::Binary { + op: BinaryOperator::GtEq, + left_expr: left_expr.clone(), + right_expr: Box::new(ScalarExpression::Constant(ValueRef::from(DataValue::Utf8(Some(value))))), + ty: ty.clone(), + }), + right_expr: Box::new(ScalarExpression::Binary { + op: BinaryOperator::Lt, + left_expr: left_expr.clone(), + right_expr: Box::new(ScalarExpression::Constant(ValueRef::from(DataValue::Utf8(Some(new_value))))), + ty: ty.clone(), + }), + ty: ty.clone(), + }; + filter_op.predicate = new_expr; + } + } + } + } + _ => { + graph.version += 1; + return Ok(()); + } + } + } + } + } + graph.replace_node(node_id, Operator::Filter(filter_op)) + } + // mark changed to skip this rule batch + graph.version += 1; + Ok(()) + } +} + +fn increment_char(v: char) -> Option { + match v { + 'z' => None, + 'Z' => None, + _ => std::char::from_u32(v as u32 + 1), + } +} + + #[cfg(test)] mod test { use crate::binder::test::select_sql_run; @@ -126,6 +212,15 @@ mod test { use crate::types::LogicalType; use std::collections::Bound; use std::sync::Arc; + use crate::optimizer::rule::simplification::increment_char; + + + #[test] + fn test_increment_char() { + assert_eq!(increment_char('a'), Some('b')); + assert_eq!(increment_char('z'), None); + assert_eq!(increment_char('A'), Some('B')); + } #[tokio::test] async fn test_constant_calculation_omitted() -> Result<(), DatabaseError> { @@ -302,6 +397,7 @@ mod test { Ok(()) } + #[tokio::test] async fn test_simplify_filter_multiple_column() -> Result<(), DatabaseError> { // c1 + 1 < -1 => c1 < -2 @@ -343,7 +439,7 @@ mod test { cb_1_c1, Some(ConstantBinary::Scope { min: Bound::Unbounded, - max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2)))) + max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2)))), }) ); @@ -353,7 +449,7 @@ mod test { cb_1_c2, Some(ConstantBinary::Scope { min: Bound::Excluded(Arc::new(DataValue::Int32(Some(2)))), - max: Bound::Unbounded + max: Bound::Unbounded, }) ); @@ -363,7 +459,7 @@ mod test { cb_2_c1, Some(ConstantBinary::Scope { min: Bound::Excluded(Arc::new(DataValue::Int32(Some(2)))), - max: Bound::Unbounded + max: Bound::Unbounded, }) ); @@ -373,7 +469,7 @@ mod test { cb_1_c1, Some(ConstantBinary::Scope { min: Bound::Unbounded, - max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2)))) + max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2)))), }) ); @@ -383,7 +479,7 @@ mod test { cb_3_c1, Some(ConstantBinary::Scope { min: Bound::Unbounded, - max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-1)))) + max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-1)))), }) ); @@ -393,7 +489,7 @@ mod test { cb_3_c2, Some(ConstantBinary::Scope { min: Bound::Excluded(Arc::new(DataValue::Int32(Some(0)))), - max: Bound::Unbounded + max: Bound::Unbounded, }) ); @@ -403,7 +499,7 @@ mod test { cb_4_c1, Some(ConstantBinary::Scope { min: Bound::Excluded(Arc::new(DataValue::Int32(Some(0)))), - max: Bound::Unbounded + max: Bound::Unbounded, }) ); @@ -413,7 +509,7 @@ mod test { cb_4_c2, Some(ConstantBinary::Scope { min: Bound::Unbounded, - max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-1)))) + max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-1)))), }) );