Skip to content

Commit

Permalink
feat(condition): add column_self_eq_eliminate for condition simplification (#15901)
Browse files Browse the repository at this point in the history
  • Loading branch information
xzhseh authored Mar 28, 2024
1 parent e627002 commit 2e197e1
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 53 deletions.
16 changes: 8 additions & 8 deletions src/frontend/planner_test/tests/testdata/output/bushy_join.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,31 @@
│ ├─StreamExchange { dist: HashShard(t.id) }
│ │ └─StreamHashJoin { type: Inner, predicate: t.id = t.id, output: [t.id, t.id, t._row_id, t._row_id] }
│ │ ├─StreamExchange { dist: HashShard(t.id) }
│ │ │ └─StreamFilter { predicate: (t.id = t.id) }
│ │ │ └─StreamFilter { predicate: IsNotNull(t.id) }
│ │ │ └─StreamTableScan { table: t, columns: [t.id, t._row_id], stream_scan_type: ArrangementBackfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
│ │ └─StreamExchange { dist: HashShard(t.id) }
│ │ └─StreamFilter { predicate: (t.id = t.id) }
│ │ └─StreamFilter { predicate: IsNotNull(t.id) }
│ │ └─StreamTableScan { table: t, columns: [t.id, t._row_id], stream_scan_type: ArrangementBackfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
│ └─StreamHashJoin { type: Inner, predicate: t.id = t.id, output: [t.id, t.id, t._row_id, t._row_id] }
│ ├─StreamExchange { dist: HashShard(t.id) }
│ │ └─StreamFilter { predicate: (t.id = t.id) }
│ │ └─StreamFilter { predicate: IsNotNull(t.id) }
│ │ └─StreamTableScan { table: t, columns: [t.id, t._row_id], stream_scan_type: ArrangementBackfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
│ └─StreamExchange { dist: HashShard(t.id) }
│ └─StreamFilter { predicate: (t.id = t.id) }
│ └─StreamFilter { predicate: IsNotNull(t.id) }
│ └─StreamTableScan { table: t, columns: [t.id, t._row_id], stream_scan_type: ArrangementBackfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
└─StreamHashJoin { type: Inner, predicate: t.id = t.id AND t.id = t.id AND t.id = t.id AND t.id = t.id, output: [t.id, t.id, t.id, t.id, t._row_id, t._row_id, t._row_id, t._row_id] }
├─StreamExchange { dist: HashShard(t.id) }
│ └─StreamHashJoin { type: Inner, predicate: t.id = t.id, output: [t.id, t.id, t._row_id, t._row_id] }
│ ├─StreamExchange { dist: HashShard(t.id) }
│ │ └─StreamFilter { predicate: (t.id = t.id) }
│ │ └─StreamFilter { predicate: IsNotNull(t.id) }
│ │ └─StreamTableScan { table: t, columns: [t.id, t._row_id], stream_scan_type: ArrangementBackfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
│ └─StreamExchange { dist: HashShard(t.id) }
│ └─StreamFilter { predicate: (t.id = t.id) }
│ └─StreamFilter { predicate: IsNotNull(t.id) }
│ └─StreamTableScan { table: t, columns: [t.id, t._row_id], stream_scan_type: ArrangementBackfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
└─StreamHashJoin { type: Inner, predicate: t.id = t.id, output: [t.id, t.id, t._row_id, t._row_id] }
├─StreamExchange { dist: HashShard(t.id) }
│ └─StreamFilter { predicate: (t.id = t.id) }
│ └─StreamFilter { predicate: IsNotNull(t.id) }
│ └─StreamTableScan { table: t, columns: [t.id, t._row_id], stream_scan_type: ArrangementBackfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
└─StreamExchange { dist: HashShard(t.id) }
└─StreamFilter { predicate: (t.id = t.id) }
└─StreamFilter { predicate: IsNotNull(t.id) }
└─StreamTableScan { table: t, columns: [t.id, t._row_id], stream_scan_type: ArrangementBackfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
48 changes: 24 additions & 24 deletions src/frontend/planner_test/tests/testdata/output/ch_benchmark.yaml

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions src/frontend/planner_test/tests/testdata/output/tpch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@
│ └─StreamExchange { dist: HashShard(supplier.s_nationkey) }
│ └─StreamTableScan { table: supplier, columns: [supplier.s_suppkey, supplier.s_name, supplier.s_address, supplier.s_nationkey, supplier.s_phone, supplier.s_acctbal, supplier.s_comment], stream_scan_type: ArrangementBackfill, pk: [supplier.s_suppkey], dist: UpstreamHashShard(supplier.s_suppkey) }
└─StreamExchange { dist: HashShard(partsupp.ps_suppkey) }
└─StreamFilter { predicate: (partsupp.ps_partkey = partsupp.ps_partkey) }
└─StreamFilter { predicate: IsNotNull(partsupp.ps_partkey) }
└─StreamTableScan { table: partsupp, columns: [partsupp.ps_partkey, partsupp.ps_suppkey, partsupp.ps_supplycost], stream_scan_type: ArrangementBackfill, pk: [partsupp.ps_partkey, partsupp.ps_suppkey], dist: UpstreamHashShard(partsupp.ps_partkey, partsupp.ps_suppkey) }
stream_dist_plan: |+
Fragment 0
Expand Down Expand Up @@ -457,7 +457,7 @@
└── BatchPlanNode
Fragment 17
StreamFilter { predicate: (partsupp.ps_partkey = partsupp.ps_partkey) }
StreamFilter { predicate: IsNotNull(partsupp.ps_partkey) }
└── StreamTableScan { table: partsupp, columns: [partsupp.ps_partkey, partsupp.ps_suppkey, partsupp.ps_supplycost], stream_scan_type: ArrangementBackfill, pk: [partsupp.ps_partkey, partsupp.ps_suppkey], dist: UpstreamHashShard(partsupp.ps_partkey, partsupp.ps_suppkey) } { tables: [ StreamScan: 44 ] }
├── Upstream
└── BatchPlanNode
Expand Down Expand Up @@ -963,7 +963,7 @@
│ │ └─StreamFilter { predicate: (region.r_name = 'MIDDLE EAST':Varchar) }
│ │ └─StreamTableScan { table: region, columns: [region.r_regionkey, region.r_name], stream_scan_type: ArrangementBackfill, pk: [region.r_regionkey], dist: UpstreamHashShard(region.r_regionkey) }
│ └─StreamExchange { dist: HashShard(nation.n_regionkey) }
│ └─StreamFilter { predicate: (nation.n_nationkey = nation.n_nationkey) }
│ └─StreamFilter { predicate: IsNotNull(nation.n_nationkey) }
│ └─StreamTableScan { table: nation, columns: [nation.n_nationkey, nation.n_name, nation.n_regionkey], stream_scan_type: ArrangementBackfill, pk: [nation.n_nationkey], dist: UpstreamHashShard(nation.n_nationkey) }
└─StreamExchange { dist: HashShard(customer.c_nationkey, supplier.s_nationkey) }
└─StreamHashJoin { type: Inner, predicate: orders.o_orderkey = lineitem.l_orderkey AND customer.c_nationkey = supplier.s_nationkey, output: [customer.c_nationkey, lineitem.l_extendedprice, lineitem.l_discount, supplier.s_nationkey, orders.o_orderkey, orders.o_custkey, lineitem.l_orderkey, lineitem.l_linenumber, lineitem.l_suppkey] }
Expand Down Expand Up @@ -1017,7 +1017,7 @@
└── BatchPlanNode
Fragment 5
StreamFilter { predicate: (nation.n_nationkey = nation.n_nationkey) }
StreamFilter { predicate: IsNotNull(nation.n_nationkey) }
└── StreamTableScan { table: nation, columns: [nation.n_nationkey, nation.n_name, nation.n_regionkey], stream_scan_type: ArrangementBackfill, pk: [nation.n_nationkey], dist: UpstreamHashShard(nation.n_nationkey) } { tables: [ StreamScan: 12 ] }
├── Upstream
└── BatchPlanNode
Expand Down Expand Up @@ -1907,7 +1907,7 @@
│ │ └─StreamFilter { predicate: Like(part.p_name, '%yellow%':Varchar) }
│ │ └─StreamTableScan { table: part, columns: [part.p_partkey, part.p_name], stream_scan_type: ArrangementBackfill, pk: [part.p_partkey], dist: UpstreamHashShard(part.p_partkey) }
│ └─StreamExchange { dist: HashShard(partsupp.ps_partkey) }
│ └─StreamFilter { predicate: (partsupp.ps_suppkey = partsupp.ps_suppkey) }
│ └─StreamFilter { predicate: IsNotNull(partsupp.ps_suppkey) }
│ └─StreamTableScan { table: partsupp, columns: [partsupp.ps_partkey, partsupp.ps_suppkey, partsupp.ps_supplycost], stream_scan_type: ArrangementBackfill, pk: [partsupp.ps_partkey, partsupp.ps_suppkey], dist: UpstreamHashShard(partsupp.ps_partkey, partsupp.ps_suppkey) }
└─StreamHashJoin { type: Inner, predicate: supplier.s_suppkey = lineitem.l_suppkey, output: [nation.n_name, supplier.s_suppkey, orders.o_orderdate, lineitem.l_partkey, lineitem.l_suppkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, nation.n_nationkey, orders.o_orderkey, lineitem.l_linenumber] }
├─StreamExchange { dist: HashShard(supplier.s_suppkey) }
Expand All @@ -1921,7 +1921,7 @@
├─StreamExchange { dist: HashShard(orders.o_orderkey) }
│ └─StreamTableScan { table: orders, columns: [orders.o_orderkey, orders.o_orderdate], stream_scan_type: ArrangementBackfill, pk: [orders.o_orderkey], dist: UpstreamHashShard(orders.o_orderkey) }
└─StreamExchange { dist: HashShard(lineitem.l_orderkey) }
└─StreamFilter { predicate: (lineitem.l_partkey = lineitem.l_partkey) }
└─StreamFilter { predicate: IsNotNull(lineitem.l_partkey) }
└─StreamTableScan { table: lineitem, columns: [lineitem.l_orderkey, lineitem.l_partkey, lineitem.l_suppkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_linenumber], stream_scan_type: ArrangementBackfill, pk: [lineitem.l_orderkey, lineitem.l_linenumber], dist: UpstreamHashShard(lineitem.l_orderkey, lineitem.l_linenumber) }
stream_dist_plan: |+
Fragment 0
Expand Down Expand Up @@ -1960,7 +1960,7 @@
└── BatchPlanNode
Fragment 5
StreamFilter { predicate: (partsupp.ps_suppkey = partsupp.ps_suppkey) }
StreamFilter { predicate: IsNotNull(partsupp.ps_suppkey) }
└── StreamTableScan { table: partsupp, columns: [partsupp.ps_partkey, partsupp.ps_suppkey, partsupp.ps_supplycost], stream_scan_type: ArrangementBackfill, pk: [partsupp.ps_partkey, partsupp.ps_suppkey], dist: UpstreamHashShard(partsupp.ps_partkey, partsupp.ps_suppkey) } { tables: [ StreamScan: 12 ] }
├── Upstream
└── BatchPlanNode
Expand Down Expand Up @@ -1991,7 +1991,7 @@
└── BatchPlanNode
Fragment 11
StreamFilter { predicate: (lineitem.l_partkey = lineitem.l_partkey) }
StreamFilter { predicate: IsNotNull(lineitem.l_partkey) }
└── StreamTableScan { table: lineitem, columns: [lineitem.l_orderkey, lineitem.l_partkey, lineitem.l_suppkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_linenumber], stream_scan_type: ArrangementBackfill, pk: [lineitem.l_orderkey, lineitem.l_linenumber], dist: UpstreamHashShard(lineitem.l_orderkey, lineitem.l_linenumber) } { tables: [ StreamScan: 28 ] }
├── Upstream
└── BatchPlanNode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@
│ └─StreamRowIdGen { row_id_index: 7 }
│ └─StreamSource { source: supplier, columns: [s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment, _row_id] }
└─StreamExchange { dist: HashShard(ps_suppkey) }
└─StreamFilter { predicate: (ps_partkey = ps_partkey) }
└─StreamFilter { predicate: IsNotNull(ps_partkey) }
└─StreamShare { id: 5 }
└─StreamProject { exprs: [ps_partkey, ps_suppkey, ps_supplycost, _row_id] }
└─StreamRowIdGen { row_id_index: 5 }
Expand Down Expand Up @@ -372,7 +372,7 @@
└── StreamExchange NoShuffle from 9
Fragment 21
StreamFilter { predicate: (ps_partkey = ps_partkey) }
StreamFilter { predicate: IsNotNull(ps_partkey) }
└── StreamExchange NoShuffle from 7
Table 0 { columns: [ p_partkey, p_mfgr, ps_partkey, min(ps_supplycost), _row_id ], primary key: [ $0 ASC, $3 ASC, $2 ASC, $4 ASC ], value indices: [ 0, 1, 2, 3, 4 ], distribution key: [ 2, 3 ], read pk prefix len hint: 3 }
Expand Down Expand Up @@ -548,7 +548,7 @@
│ │ └─StreamRowIdGen { row_id_index: 3 }
│ │ └─StreamSource { source: region, columns: [r_regionkey, r_name, r_comment, _row_id] }
│ └─StreamExchange { dist: HashShard(n_regionkey) }
│ └─StreamFilter { predicate: (n_nationkey = n_nationkey) }
│ └─StreamFilter { predicate: IsNotNull(n_nationkey) }
│ └─StreamRowIdGen { row_id_index: 4 }
│ └─StreamSource { source: nation, columns: [n_nationkey, n_name, n_regionkey, n_comment, _row_id] }
└─StreamExchange { dist: HashShard(c_nationkey, s_nationkey) }
Expand Down Expand Up @@ -594,7 +594,7 @@
└── StreamSource { source: region, columns: [r_regionkey, r_name, r_comment, _row_id] } { tables: [ Source: 9 ] }
Fragment 4
StreamFilter { predicate: (n_nationkey = n_nationkey) }
StreamFilter { predicate: IsNotNull(n_nationkey) }
└── StreamRowIdGen { row_id_index: 4 }
└── StreamSource { source: nation, columns: [n_nationkey, n_name, n_regionkey, n_comment, _row_id] } { tables: [ Source: 10 ] }
Expand Down Expand Up @@ -1421,7 +1421,7 @@
│ │ └─StreamRowIdGen { row_id_index: 9 }
│ │ └─StreamSource { source: part, columns: [p_partkey, p_name, p_mfgr, p_brand, p_type, p_size, p_container, p_retailprice, p_comment, _row_id] }
│ └─StreamExchange { dist: HashShard(ps_partkey) }
│ └─StreamFilter { predicate: (ps_suppkey = ps_suppkey) }
│ └─StreamFilter { predicate: IsNotNull(ps_suppkey) }
│ └─StreamRowIdGen { row_id_index: 5 }
│ └─StreamSource { source: partsupp, columns: [ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment, _row_id] }
└─StreamHashJoin [append_only] { type: Inner, predicate: s_suppkey = l_suppkey, output: [n_name, s_suppkey, o_orderdate, l_partkey, l_suppkey, l_quantity, l_extendedprice, l_discount, _row_id, _row_id, n_nationkey, _row_id, _row_id, o_orderkey] }
Expand All @@ -1439,7 +1439,7 @@
│ └─StreamRowIdGen { row_id_index: 9 }
│ └─StreamSource { source: orders, columns: [o_orderkey, o_custkey, o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment, _row_id] }
└─StreamExchange { dist: HashShard(l_orderkey) }
└─StreamFilter { predicate: (l_partkey = l_partkey) }
└─StreamFilter { predicate: IsNotNull(l_partkey) }
└─StreamRowIdGen { row_id_index: 16 }
└─StreamSource { source: lineitem, columns: [l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment, _row_id] }
stream_dist_plan: |+
Expand Down Expand Up @@ -1469,7 +1469,7 @@
└── StreamSource { source: part, columns: [p_partkey, p_name, p_mfgr, p_brand, p_type, p_size, p_container, p_retailprice, p_comment, _row_id] } { tables: [ Source: 9 ] }
Fragment 4
StreamFilter { predicate: (ps_suppkey = ps_suppkey) }
StreamFilter { predicate: IsNotNull(ps_suppkey) }
└── StreamRowIdGen { row_id_index: 5 }
└── StreamSource { source: partsupp, columns: [ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment, _row_id] } { tables: [ Source: 10 ] }
Expand All @@ -1496,7 +1496,7 @@
└── StreamSource { source: orders, columns: [o_orderkey, o_custkey, o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment, _row_id] } { tables: [ Source: 25 ] }
Fragment 10
StreamFilter { predicate: (l_partkey = l_partkey) }
StreamFilter { predicate: IsNotNull(l_partkey) }
└── StreamRowIdGen { row_id_index: 16 }
└── StreamSource { source: lineitem, columns: [l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment, _row_id] } { tables: [ Source: 26 ] }
Expand Down
79 changes: 79 additions & 0 deletions src/frontend/src/expr/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,85 @@ pub fn fold_boolean_constant(expr: ExprImpl) -> ExprImpl {
rewriter.rewrite_expr(expr)
}

/// Simplify a condition by eliminating single-column self-equality:
/// rewrites `(col) = (col)` into `IsNotNull(col)`.
/// This is a thin entry point; check `ColumnSelfEqualRewriter`'s comment below
/// for the exact matching rules and limitations.
pub fn column_self_eq_eliminate(expr: ExprImpl) -> ExprImpl {
    ColumnSelfEqualRewriter::rewrite(expr)
}

/// For every `(col) = (col)` predicate, transform it to `IsNotNull(col)`,
/// since in the boolean context `null = (...)` will always
/// be treated as false, while `col = col` holds for every non-null value.
/// note: as always, only for *single column*.
pub struct ColumnSelfEqualRewriter {}

impl ColumnSelfEqualRewriter {
    /// Collect every distinct `InputRef` reachable from `expr` into `columns`.
    /// Subtrees rooted at functions that never return null are skipped, since
    /// their null-ness cannot affect the overall result.
    ///
    /// Adapted from `logical_filter_expression_simplify_rule`; this version
    /// traverses by reference instead of cloning the whole tree at each level.
    fn extract_column(expr: &ExprImpl, columns: &mut Vec<ExprImpl>) {
        match expr {
            ExprImpl::FunctionCall(func_call) => {
                // the functions that *never* return null will be ignored
                if Self::is_not_null(func_call.func_type()) {
                    return;
                }
                for sub_expr in func_call.inputs() {
                    Self::extract_column(sub_expr, columns);
                }
            }
            ExprImpl::InputRef(_) => {
                // only add the column if not exists
                if !columns.contains(expr) {
                    columns.push(expr.clone());
                }
            }
            _ => (),
        }
    }

    /// Whether `func_type` is one of the null-test / boolean-test functions
    /// that are guaranteed to never return null.
    ///
    /// Same semantics as in `logical_filter_expression_simplify_rule`,
    /// expressed with `matches!` instead of a chain of `==` comparisons.
    fn is_not_null(func_type: ExprType) -> bool {
        matches!(
            func_type,
            ExprType::IsNull
                | ExprType::IsNotNull
                | ExprType::IsTrue
                | ExprType::IsFalse
                | ExprType::IsNotTrue
                | ExprType::IsNotFalse
        )
    }

    /// Rewrite a top-level `e = e` (self-equality over at most one column)
    /// to `IsNotNull(col)`, or to literal `true` when no column is involved.
    /// Any expression that does not match this shape is returned unchanged.
    pub fn rewrite(expr: ExprImpl) -> ExprImpl {
        let mut columns = vec![];
        Self::extract_column(&expr, &mut columns);
        if columns.len() > 1 {
            // more than one column is involved, leave it intact
            return expr;
        }

        // extract the equal inputs with sanity check
        let ExprImpl::FunctionCall(func_call) = expr.clone() else {
            return expr;
        };
        if func_call.func_type() != ExprType::Equal || func_call.inputs().len() != 2 {
            return expr;
        }
        assert_eq!(func_call.return_type(), DataType::Boolean);
        // compare the two sides by reference; no need to clone them first
        let inputs = func_call.inputs();

        if inputs[0] == inputs[1] {
            if columns.is_empty() {
                // e.g., `1 = 1` — no nullable column involved, always true
                return ExprImpl::literal_bool(true);
            }
            // `col = col` is true exactly when `col` is not null
            let Ok(ret) = FunctionCall::new(ExprType::IsNotNull, vec![columns[0].clone()]) else {
                return expr;
            };
            ret.into()
        } else {
            expr
        }
    }
}

/// Fold boolean constants in a expr
struct BooleanConstantFolding {}

Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/optimizer/plan_node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ impl PlanRef {

for c in merge_predicate.conjunctions {
let c = Condition::with_expr(expr_rewriter.rewrite_cond(c));

// rebuild the conjunctions
new_predicate = new_predicate.and(c);
}

Expand Down
Loading

0 comments on commit 2e197e1

Please sign in to comment.