diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/group_by.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/group_by.rs new file mode 100644 index 000000000000..3ab423a6640d --- /dev/null +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/group_by.rs @@ -0,0 +1,84 @@ +use super::*; + +#[allow(clippy::too_many_arguments)] +pub(super) fn process_group_by( + opt: &PredicatePushDown, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + input: Node, + keys: Vec, + aggs: Vec, + schema: SchemaRef, + maintain_order: bool, + apply: Option>, + options: Arc, + acc_predicates: PlHashMap, Node>, +) -> PolarsResult { + use ALogicalPlan::*; + + #[cfg(feature = "dynamic_group_by")] + let no_push = { options.rolling.is_some() || options.dynamic.is_some() }; + + #[cfg(not(feature = "dynamic_group_by"))] + let no_push = false; + + // Don't pushdown predicates on these cases. + if apply.is_some() || no_push || options.slice.is_some() { + let lp = Aggregate { + input, + keys, + aggs, + schema, + apply, + maintain_order, + options, + }; + return opt.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena); + } + + // If the predicate only resolves to the keys we can push it down. + // When it filters the aggregations, the predicate should be done after aggregation. + let mut local_predicates = Vec::with_capacity(acc_predicates.len()); + let key_schema = aexprs_to_schema( + &keys, + lp_arena.get(input).schema(lp_arena).as_ref(), + Context::Default, + expr_arena, + ); + + let mut new_acc_predicates = PlHashMap::with_capacity(acc_predicates.len()); + + for (pred_name, predicate) in &acc_predicates { + // Counts change due to groupby's + // TODO! handle aliases, so that the predicate that is pushed down refers to the column before alias. + let mut push_down = !has_aexpr(*predicate, expr_arena, |ae| { + matches!(ae, AExpr::Count | AExpr::Alias(_, _)) + }); + + for name in aexpr_to_leaf_names_iter(*predicate, expr_arena) { + push_down &= key_schema.contains(name.as_ref()); + + if !push_down { + break; + } + } + if !push_down { + local_predicates.push(*predicate) + } else { + new_acc_predicates.insert(pred_name.clone(), *predicate); + } + } + + opt.pushdown_and_assign(input, new_acc_predicates, lp_arena, expr_arena)?; + + let lp = Aggregate { + input, + keys, + aggs, + schema, + apply, + maintain_order, + options, + }; + Ok(opt.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) +} diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs index 22569ffee7d0..eec0ddaff940 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs @@ -1,3 +1,4 @@ +mod group_by; mod join; mod keys; mod rename; @@ -11,6 +12,7 @@ use utils::*; use super::*; use crate::dsl::function_expr::FunctionExpr; use crate::logical_plan::optimizer; +use crate::prelude::optimizer::predicate_pushdown::group_by::process_group_by; use crate::prelude::optimizer::predicate_pushdown::join::process_join; use crate::prelude::optimizer::predicate_pushdown::rename::process_rename; use crate::utils::{check_input_node, has_aexpr}; @@ -421,6 +423,10 @@ impl<'a> PredicatePushDown<'a> { self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena) } } + Aggregate {input, keys, aggs, schema, apply, maintain_order, options, } => { + process_group_by(self, lp_arena, expr_arena, input, keys, aggs, schema, maintain_order, apply, options, acc_predicates) + + }, lp @ Union {..} => { let mut local_predicates = vec![]; @@ -462,8 +468,7 @@ impl<'a> PredicatePushDown<'a> { lp @ Slice { .. } // caches will be different | lp @ Cache { .. } - // dont push down predicates. An aggregation needs all rows - | lp @ Aggregate {..} => { + => { self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena) } #[cfg(feature = "python")] diff --git a/py-polars/tests/unit/test_predicates.py b/py-polars/tests/unit/test_predicates.py index 608cf627252e..2af6376d05ef 100644 --- a/py-polars/tests/unit/test_predicates.py +++ b/py-polars/tests/unit/test_predicates.py @@ -179,3 +179,16 @@ def test_is_in_join_blocked() -> None: assert df_all.filter(~pl.col("Groups").is_in(["A", "B", "F"])).collect().to_dict( False ) == {"values22": [None, 4, 5], "values20": [3, 4, 5], "Groups": ["C", "D", "E"]} + + +def test_predicate_pushdown_group_by_keys() -> None: + df = pl.LazyFrame( + {"str": ["A", "B", "A", "B", "C"], "group": [1, 1, 2, 1, 2]} + ).lazy() + assert ( + 'SELECTION: "None"' + not in df.group_by("group") + .agg([pl.count().alias("str_list")]) + .filter(pl.col("group") == 1) + .explain() + )