Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(condition): add column_self_eq_eliminate for condition simplification #15901

Merged
merged 12 commits
Mar 28, 2024
16 changes: 8 additions & 8 deletions src/frontend/planner_test/tests/testdata/output/bushy_join.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,31 @@
│ ├─StreamExchange { dist: HashShard(t.id) }
│ │ └─StreamHashJoin { type: Inner, predicate: t.id = t.id, output: [t.id, t.id, t._row_id, t._row_id] }
│ │ ├─StreamExchange { dist: HashShard(t.id) }
│ │ │ └─StreamFilter { predicate: (t.id = t.id) }
│ │ │ └─StreamFilter { predicate: IsNotNull(t.id) }
│ │ │ └─StreamTableScan { table: t, columns: [t.id, t._row_id], stream_scan_type: ArrangementBackfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
│ │ └─StreamExchange { dist: HashShard(t.id) }
│ │ └─StreamFilter { predicate: (t.id = t.id) }
│ │ └─StreamFilter { predicate: IsNotNull(t.id) }
│ │ └─StreamTableScan { table: t, columns: [t.id, t._row_id], stream_scan_type: ArrangementBackfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
│ └─StreamHashJoin { type: Inner, predicate: t.id = t.id, output: [t.id, t.id, t._row_id, t._row_id] }
│ ├─StreamExchange { dist: HashShard(t.id) }
│ │ └─StreamFilter { predicate: (t.id = t.id) }
│ │ └─StreamFilter { predicate: IsNotNull(t.id) }
│ │ └─StreamTableScan { table: t, columns: [t.id, t._row_id], stream_scan_type: ArrangementBackfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
│ └─StreamExchange { dist: HashShard(t.id) }
│ └─StreamFilter { predicate: (t.id = t.id) }
│ └─StreamFilter { predicate: IsNotNull(t.id) }
│ └─StreamTableScan { table: t, columns: [t.id, t._row_id], stream_scan_type: ArrangementBackfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
└─StreamHashJoin { type: Inner, predicate: t.id = t.id AND t.id = t.id AND t.id = t.id AND t.id = t.id, output: [t.id, t.id, t.id, t.id, t._row_id, t._row_id, t._row_id, t._row_id] }
├─StreamExchange { dist: HashShard(t.id) }
│ └─StreamHashJoin { type: Inner, predicate: t.id = t.id, output: [t.id, t.id, t._row_id, t._row_id] }
│ ├─StreamExchange { dist: HashShard(t.id) }
│ │ └─StreamFilter { predicate: (t.id = t.id) }
│ │ └─StreamFilter { predicate: IsNotNull(t.id) }
│ │ └─StreamTableScan { table: t, columns: [t.id, t._row_id], stream_scan_type: ArrangementBackfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
│ └─StreamExchange { dist: HashShard(t.id) }
│ └─StreamFilter { predicate: (t.id = t.id) }
│ └─StreamFilter { predicate: IsNotNull(t.id) }
│ └─StreamTableScan { table: t, columns: [t.id, t._row_id], stream_scan_type: ArrangementBackfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
└─StreamHashJoin { type: Inner, predicate: t.id = t.id, output: [t.id, t.id, t._row_id, t._row_id] }
├─StreamExchange { dist: HashShard(t.id) }
│ └─StreamFilter { predicate: (t.id = t.id) }
│ └─StreamFilter { predicate: IsNotNull(t.id) }
│ └─StreamTableScan { table: t, columns: [t.id, t._row_id], stream_scan_type: ArrangementBackfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
└─StreamExchange { dist: HashShard(t.id) }
└─StreamFilter { predicate: (t.id = t.id) }
└─StreamFilter { predicate: IsNotNull(t.id) }
└─StreamTableScan { table: t, columns: [t.id, t._row_id], stream_scan_type: ArrangementBackfill, pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
48 changes: 24 additions & 24 deletions src/frontend/planner_test/tests/testdata/output/ch_benchmark.yaml

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions src/frontend/planner_test/tests/testdata/output/tpch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@
│ └─StreamExchange { dist: HashShard(supplier.s_nationkey) }
│ └─StreamTableScan { table: supplier, columns: [supplier.s_suppkey, supplier.s_name, supplier.s_address, supplier.s_nationkey, supplier.s_phone, supplier.s_acctbal, supplier.s_comment], stream_scan_type: ArrangementBackfill, pk: [supplier.s_suppkey], dist: UpstreamHashShard(supplier.s_suppkey) }
└─StreamExchange { dist: HashShard(partsupp.ps_suppkey) }
└─StreamFilter { predicate: (partsupp.ps_partkey = partsupp.ps_partkey) }
└─StreamFilter { predicate: IsNotNull(partsupp.ps_partkey) }
└─StreamTableScan { table: partsupp, columns: [partsupp.ps_partkey, partsupp.ps_suppkey, partsupp.ps_supplycost], stream_scan_type: ArrangementBackfill, pk: [partsupp.ps_partkey, partsupp.ps_suppkey], dist: UpstreamHashShard(partsupp.ps_partkey, partsupp.ps_suppkey) }
stream_dist_plan: |+
Fragment 0
Expand Down Expand Up @@ -457,7 +457,7 @@
└── BatchPlanNode

Fragment 17
StreamFilter { predicate: (partsupp.ps_partkey = partsupp.ps_partkey) }
StreamFilter { predicate: IsNotNull(partsupp.ps_partkey) }
└── StreamTableScan { table: partsupp, columns: [partsupp.ps_partkey, partsupp.ps_suppkey, partsupp.ps_supplycost], stream_scan_type: ArrangementBackfill, pk: [partsupp.ps_partkey, partsupp.ps_suppkey], dist: UpstreamHashShard(partsupp.ps_partkey, partsupp.ps_suppkey) } { tables: [ StreamScan: 44 ] }
├── Upstream
└── BatchPlanNode
Expand Down Expand Up @@ -963,7 +963,7 @@
│ │ └─StreamFilter { predicate: (region.r_name = 'MIDDLE EAST':Varchar) }
│ │ └─StreamTableScan { table: region, columns: [region.r_regionkey, region.r_name], stream_scan_type: ArrangementBackfill, pk: [region.r_regionkey], dist: UpstreamHashShard(region.r_regionkey) }
│ └─StreamExchange { dist: HashShard(nation.n_regionkey) }
│ └─StreamFilter { predicate: (nation.n_nationkey = nation.n_nationkey) }
│ └─StreamFilter { predicate: IsNotNull(nation.n_nationkey) }
│ └─StreamTableScan { table: nation, columns: [nation.n_nationkey, nation.n_name, nation.n_regionkey], stream_scan_type: ArrangementBackfill, pk: [nation.n_nationkey], dist: UpstreamHashShard(nation.n_nationkey) }
└─StreamExchange { dist: HashShard(customer.c_nationkey, supplier.s_nationkey) }
└─StreamHashJoin { type: Inner, predicate: orders.o_orderkey = lineitem.l_orderkey AND customer.c_nationkey = supplier.s_nationkey, output: [customer.c_nationkey, lineitem.l_extendedprice, lineitem.l_discount, supplier.s_nationkey, orders.o_orderkey, orders.o_custkey, lineitem.l_orderkey, lineitem.l_linenumber, lineitem.l_suppkey] }
Expand Down Expand Up @@ -1017,7 +1017,7 @@
└── BatchPlanNode

Fragment 5
StreamFilter { predicate: (nation.n_nationkey = nation.n_nationkey) }
StreamFilter { predicate: IsNotNull(nation.n_nationkey) }
└── StreamTableScan { table: nation, columns: [nation.n_nationkey, nation.n_name, nation.n_regionkey], stream_scan_type: ArrangementBackfill, pk: [nation.n_nationkey], dist: UpstreamHashShard(nation.n_nationkey) } { tables: [ StreamScan: 12 ] }
├── Upstream
└── BatchPlanNode
Expand Down Expand Up @@ -1907,7 +1907,7 @@
│ │ └─StreamFilter { predicate: Like(part.p_name, '%yellow%':Varchar) }
│ │ └─StreamTableScan { table: part, columns: [part.p_partkey, part.p_name], stream_scan_type: ArrangementBackfill, pk: [part.p_partkey], dist: UpstreamHashShard(part.p_partkey) }
│ └─StreamExchange { dist: HashShard(partsupp.ps_partkey) }
│ └─StreamFilter { predicate: (partsupp.ps_suppkey = partsupp.ps_suppkey) }
│ └─StreamFilter { predicate: IsNotNull(partsupp.ps_suppkey) }
│ └─StreamTableScan { table: partsupp, columns: [partsupp.ps_partkey, partsupp.ps_suppkey, partsupp.ps_supplycost], stream_scan_type: ArrangementBackfill, pk: [partsupp.ps_partkey, partsupp.ps_suppkey], dist: UpstreamHashShard(partsupp.ps_partkey, partsupp.ps_suppkey) }
└─StreamHashJoin { type: Inner, predicate: supplier.s_suppkey = lineitem.l_suppkey, output: [nation.n_name, supplier.s_suppkey, orders.o_orderdate, lineitem.l_partkey, lineitem.l_suppkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, nation.n_nationkey, orders.o_orderkey, lineitem.l_linenumber] }
├─StreamExchange { dist: HashShard(supplier.s_suppkey) }
Expand All @@ -1921,7 +1921,7 @@
├─StreamExchange { dist: HashShard(orders.o_orderkey) }
│ └─StreamTableScan { table: orders, columns: [orders.o_orderkey, orders.o_orderdate], stream_scan_type: ArrangementBackfill, pk: [orders.o_orderkey], dist: UpstreamHashShard(orders.o_orderkey) }
└─StreamExchange { dist: HashShard(lineitem.l_orderkey) }
└─StreamFilter { predicate: (lineitem.l_partkey = lineitem.l_partkey) }
└─StreamFilter { predicate: IsNotNull(lineitem.l_partkey) }
└─StreamTableScan { table: lineitem, columns: [lineitem.l_orderkey, lineitem.l_partkey, lineitem.l_suppkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_linenumber], stream_scan_type: ArrangementBackfill, pk: [lineitem.l_orderkey, lineitem.l_linenumber], dist: UpstreamHashShard(lineitem.l_orderkey, lineitem.l_linenumber) }
stream_dist_plan: |+
Fragment 0
Expand Down Expand Up @@ -1960,7 +1960,7 @@
└── BatchPlanNode

Fragment 5
StreamFilter { predicate: (partsupp.ps_suppkey = partsupp.ps_suppkey) }
StreamFilter { predicate: IsNotNull(partsupp.ps_suppkey) }
└── StreamTableScan { table: partsupp, columns: [partsupp.ps_partkey, partsupp.ps_suppkey, partsupp.ps_supplycost], stream_scan_type: ArrangementBackfill, pk: [partsupp.ps_partkey, partsupp.ps_suppkey], dist: UpstreamHashShard(partsupp.ps_partkey, partsupp.ps_suppkey) } { tables: [ StreamScan: 12 ] }
├── Upstream
└── BatchPlanNode
Expand Down Expand Up @@ -1991,7 +1991,7 @@
└── BatchPlanNode

Fragment 11
StreamFilter { predicate: (lineitem.l_partkey = lineitem.l_partkey) }
StreamFilter { predicate: IsNotNull(lineitem.l_partkey) }
└── StreamTableScan { table: lineitem, columns: [lineitem.l_orderkey, lineitem.l_partkey, lineitem.l_suppkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_linenumber], stream_scan_type: ArrangementBackfill, pk: [lineitem.l_orderkey, lineitem.l_linenumber], dist: UpstreamHashShard(lineitem.l_orderkey, lineitem.l_linenumber) } { tables: [ StreamScan: 28 ] }
├── Upstream
└── BatchPlanNode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@
│ └─StreamRowIdGen { row_id_index: 7 }
│ └─StreamSource { source: supplier, columns: [s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment, _row_id] }
└─StreamExchange { dist: HashShard(ps_suppkey) }
└─StreamFilter { predicate: (ps_partkey = ps_partkey) }
└─StreamFilter { predicate: IsNotNull(ps_partkey) }
└─StreamShare { id: 5 }
└─StreamProject { exprs: [ps_partkey, ps_suppkey, ps_supplycost, _row_id] }
└─StreamRowIdGen { row_id_index: 5 }
Expand Down Expand Up @@ -372,7 +372,7 @@
└── StreamExchange NoShuffle from 9

Fragment 21
StreamFilter { predicate: (ps_partkey = ps_partkey) }
StreamFilter { predicate: IsNotNull(ps_partkey) }
└── StreamExchange NoShuffle from 7

Table 0 { columns: [ p_partkey, p_mfgr, ps_partkey, min(ps_supplycost), _row_id ], primary key: [ $0 ASC, $3 ASC, $2 ASC, $4 ASC ], value indices: [ 0, 1, 2, 3, 4 ], distribution key: [ 2, 3 ], read pk prefix len hint: 3 }
Expand Down Expand Up @@ -548,7 +548,7 @@
│ │ └─StreamRowIdGen { row_id_index: 3 }
│ │ └─StreamSource { source: region, columns: [r_regionkey, r_name, r_comment, _row_id] }
│ └─StreamExchange { dist: HashShard(n_regionkey) }
│ └─StreamFilter { predicate: (n_nationkey = n_nationkey) }
│ └─StreamFilter { predicate: IsNotNull(n_nationkey) }
│ └─StreamRowIdGen { row_id_index: 4 }
│ └─StreamSource { source: nation, columns: [n_nationkey, n_name, n_regionkey, n_comment, _row_id] }
└─StreamExchange { dist: HashShard(c_nationkey, s_nationkey) }
Expand Down Expand Up @@ -594,7 +594,7 @@
└── StreamSource { source: region, columns: [r_regionkey, r_name, r_comment, _row_id] } { tables: [ Source: 9 ] }

Fragment 4
StreamFilter { predicate: (n_nationkey = n_nationkey) }
StreamFilter { predicate: IsNotNull(n_nationkey) }
└── StreamRowIdGen { row_id_index: 4 }
└── StreamSource { source: nation, columns: [n_nationkey, n_name, n_regionkey, n_comment, _row_id] } { tables: [ Source: 10 ] }

Expand Down Expand Up @@ -1421,7 +1421,7 @@
│ │ └─StreamRowIdGen { row_id_index: 9 }
│ │ └─StreamSource { source: part, columns: [p_partkey, p_name, p_mfgr, p_brand, p_type, p_size, p_container, p_retailprice, p_comment, _row_id] }
│ └─StreamExchange { dist: HashShard(ps_partkey) }
│ └─StreamFilter { predicate: (ps_suppkey = ps_suppkey) }
│ └─StreamFilter { predicate: IsNotNull(ps_suppkey) }
│ └─StreamRowIdGen { row_id_index: 5 }
│ └─StreamSource { source: partsupp, columns: [ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment, _row_id] }
└─StreamHashJoin [append_only] { type: Inner, predicate: s_suppkey = l_suppkey, output: [n_name, s_suppkey, o_orderdate, l_partkey, l_suppkey, l_quantity, l_extendedprice, l_discount, _row_id, _row_id, n_nationkey, _row_id, _row_id, o_orderkey] }
Expand All @@ -1439,7 +1439,7 @@
│ └─StreamRowIdGen { row_id_index: 9 }
│ └─StreamSource { source: orders, columns: [o_orderkey, o_custkey, o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment, _row_id] }
└─StreamExchange { dist: HashShard(l_orderkey) }
└─StreamFilter { predicate: (l_partkey = l_partkey) }
└─StreamFilter { predicate: IsNotNull(l_partkey) }
└─StreamRowIdGen { row_id_index: 16 }
└─StreamSource { source: lineitem, columns: [l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment, _row_id] }
stream_dist_plan: |+
Expand Down Expand Up @@ -1469,7 +1469,7 @@
└── StreamSource { source: part, columns: [p_partkey, p_name, p_mfgr, p_brand, p_type, p_size, p_container, p_retailprice, p_comment, _row_id] } { tables: [ Source: 9 ] }

Fragment 4
StreamFilter { predicate: (ps_suppkey = ps_suppkey) }
StreamFilter { predicate: IsNotNull(ps_suppkey) }
└── StreamRowIdGen { row_id_index: 5 }
└── StreamSource { source: partsupp, columns: [ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment, _row_id] } { tables: [ Source: 10 ] }

Expand All @@ -1496,7 +1496,7 @@
└── StreamSource { source: orders, columns: [o_orderkey, o_custkey, o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment, _row_id] } { tables: [ Source: 25 ] }

Fragment 10
StreamFilter { predicate: (l_partkey = l_partkey) }
StreamFilter { predicate: IsNotNull(l_partkey) }
└── StreamRowIdGen { row_id_index: 16 }
└── StreamSource { source: lineitem, columns: [l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment, _row_id] } { tables: [ Source: 26 ] }

Expand Down
79 changes: 79 additions & 0 deletions src/frontend/src/expr/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,85 @@ pub fn fold_boolean_constant(expr: ExprImpl) -> ExprImpl {
rewriter.rewrite_expr(expr)
}

/// Eliminate `(col) = (col)` self-equality predicates from a condition
/// expression; see [`ColumnSelfEqualRewriter`]'s documentation below for the
/// exact rewrite rule.
pub fn column_self_eq_eliminate(expr: ExprImpl) -> ExprImpl {
ColumnSelfEqualRewriter::rewrite(expr)
}

/// For every `(col) = (col)` expression, transform it to `IsNotNull(col)`,
/// since in a boolean context `null = (...)` is always treated as false.
/// Note: as always, this applies only when a *single column* is involved.
pub struct ColumnSelfEqualRewriter {}

impl ColumnSelfEqualRewriter {
/// The exact copy from `logical_filter_expression_simplify_rule`.
///
/// Recursively collects the distinct `InputRef`s appearing in `expr` into
/// `columns`, skipping subtrees rooted at functions that never return null.
fn extract_column(expr: ExprImpl, columns: &mut Vec<ExprImpl>) {
    match &expr {
        ExprImpl::FunctionCall(func_call) => {
            // Functions that *never* return null are ignored entirely.
            if Self::is_not_null(func_call.func_type()) {
                return;
            }
            for sub_expr in func_call.inputs() {
                Self::extract_column(sub_expr.clone(), columns);
            }
        }
        ExprImpl::InputRef(_) => {
            // Deduplicate: record each column at most once.
            if !columns.contains(&expr) {
                columns.push(expr);
            }
        }
        _ => {}
    }
}

/// The exact copy from `logical_filter_expression_simplify_rule`.
///
/// Returns true for function types whose result is never null
/// (the `IS [NOT] NULL / TRUE / FALSE` family).
fn is_not_null(func_type: ExprType) -> bool {
    matches!(
        func_type,
        ExprType::IsNull
            | ExprType::IsNotNull
            | ExprType::IsTrue
            | ExprType::IsFalse
            | ExprType::IsNotTrue
            | ExprType::IsNotFalse
    )
}

pub fn rewrite(expr: ExprImpl) -> ExprImpl {
Copy link
Contributor Author

@xzhseh xzhseh Mar 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and why bother to separate the logic from ExpressionSimplifyRewriter in logical_filter_expression_simplify_rule?

  1. the former mainly aims to optimize pattern like, e.g., (e) [or / and] not(e).
  2. it's hard to maintain the module with two different logic interleaving. (and presumably more in the future)
  3. this one is especially to simplify Condition at constructing time, while the former will be used during logical_optimization.

let mut columns = vec![];
Self::extract_column(expr.clone(), &mut columns);
if columns.len() > 1 {
// leave it intact
return expr;
}

// extract the equal inputs with sanity check
let ExprImpl::FunctionCall(func_call) = expr.clone() else {
return expr;
};
if func_call.func_type() != ExprType::Equal || func_call.inputs().len() != 2 {
return expr;
}
assert_eq!(func_call.return_type(), DataType::Boolean);
let inputs = func_call.inputs();
let e1 = inputs[0].clone();
let e2 = inputs[1].clone();

if e1 == e2 {
if columns.is_empty() {
return ExprImpl::literal_bool(true);
}
let Ok(ret) = FunctionCall::new(ExprType::IsNotNull, vec![columns[0].clone()]) else {
return expr;
};
ret.into()
} else {
expr
}
}
}

/// Fold boolean constants in a expr
struct BooleanConstantFolding {}

Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/optimizer/plan_node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ impl PlanRef {

for c in merge_predicate.conjunctions {
let c = Condition::with_expr(expr_rewriter.rewrite_cond(c));

// rebuild the conjunctions
new_predicate = new_predicate.and(c);
}

Expand Down
Loading
Loading