diff --git a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py index 1e624a6496..0eec2f06ac 100644 --- a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py +++ b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py @@ -46,8 +46,7 @@ class PredicatePushdownBranchStateTracker: """Tracking class for monitoring pushdown state at the node level during a visitor walk.""" def __init__(self, initial_state: PredicatePushdownState) -> None: # noqa: D107 - self._initial_state = initial_state - self._current_branch_state: List[PredicatePushdownState] = [] + self._current_branch_state: List[PredicatePushdownState] = [initial_state] @contextmanager def track_pushdown_state(self, pushdown_state: PredicatePushdownState) -> Iterator[None]: @@ -56,20 +55,45 @@ def track_pushdown_state(self, pushdown_state: PredicatePushdownState) -> Iterat This retains a sequence of pushdown state objects to allow for tracking pushdown opportunities along the current branch. Each entry represents the predicate pushdown state as of that point, and as such callers need only concern themselves with the value returned by the last_pushdown_state property. + + The back-propagation of pushdown_applied_where_filter_specs is necessary to ensure the outer query + node can evaluate which where filter specs needs to be applied. We capture the complete set because + we may have sequenced nodes where some where filters are applied (e.g., time dimension filters might + be applied to metric time nodes, etc.). """ self._current_branch_state.append(pushdown_state) yield - self._current_branch_state.pop(-1) + last_visited_pushdown_state = self._current_branch_state.pop(-1) + if len(last_visited_pushdown_state.applied_where_filter_specs) > 0: + pushdown_applied_where_filter_specs = frozenset.union( + *[ + last_visited_pushdown_state.applied_where_filter_specs, + self.last_pushdown_state.applied_where_filter_specs, + ] + ) + self.override_last_pushdown_state( + PredicatePushdownState.with_pushdown_applied_where_filter_specs( + original_pushdown_state=self.last_pushdown_state, + pushdown_applied_where_filter_specs=pushdown_applied_where_filter_specs, + ) + ) @property def last_pushdown_state(self) -> PredicatePushdownState: """Returns the last seen PredicatePushdownState. - This is nearly always the input state a given node processing method should be using for pushdown operations. + This is the input state a given node processing method should be using for pushdown operations. """ - if len(self._current_branch_state) > 0: - return self._current_branch_state[-1] - return self._initial_state + return self._current_branch_state[-1] + + def override_last_pushdown_state(self, pushdown_state: PredicatePushdownState) -> None: + """Method for forcibly updating the last seen predicate pushdown state to a new value. + + This is necessary only for cases where we wish to back-propagate some updated state attribute + for handling in the exit condition of the preceding node in the DAG. Since it is something of an + extraordinary circumstance we designate it as a special method rather than making it a property setter. + """ + self._current_branch_state[-1] = pushdown_state class PredicatePushdownOptimizer( @@ -81,8 +105,10 @@ class PredicatePushdownOptimizer( This evaluates filter predicates to determine which, if any, can be directly to an input source node. It operates by walking each branch in the DataflowPlan and collecting pushdown state information, then evaluating that state at the input source node and applying the filter node (e.g., a WhereConstraintNode) - directly to the source. As the optimizer unrolls back through the branch it will remove the duplicated - constraint node if it is appropriate to do so. + directly to the source. As the optimizer unrolls back through the branch it will remove duplicated constraints. + + This assumes that we never do predicate pushdown on a filter that needs to be re-applied, so every filter + we encounter gets applied exactly once per nested subquery branch encapsulated by a given constraint node. """ def __init__(self, node_data_set_resolver: DataflowPlanNodeOutputDataSetResolver) -> None: @@ -128,8 +154,8 @@ def _default_handler( ) -> OptimizeBranchResult: """Encapsulates state updates and handling for most node types. - The dominant majority of nodes simply propagate the current predicate pushdown state along and return - whatever output the parent nodes produce. Of the nodes that do not do this, the most common deviation + The most common node-level operation is to simply propagate the current predicate pushdown state along and + return whatever output the parent nodes produce. Of the nodes that do not do this, the most common deviation is a pushdown state update. As such, this method defaults to propagating the last seen state, but allows an override for cases where @@ -229,13 +255,16 @@ def _push_down_where_filters( filters_left_over.append(filter_spec) logger.log(level=self._log_level, msg=f"Filter specs to add:\n{filters_to_apply}") + applied_filters = frozenset.union( + *[frozenset(current_pushdown_state.applied_where_filter_specs), frozenset(filters_to_apply)] + ) updated_pushdown_state = PredicatePushdownState( time_range_constraint=current_pushdown_state.time_range_constraint, where_filter_specs=tuple(filters_left_over), pushdown_enabled_types=current_pushdown_state.pushdown_enabled_types, + applied_where_filter_specs=applied_filters, ) optimized_node = self._default_handler(node=node, pushdown_state=updated_pushdown_state) - # TODO: propagate filters applied back up the branch for removal if len(filters_to_apply) > 0: return OptimizeBranchResult( optimized_branch=WhereConstraintNode( @@ -267,11 +296,14 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> OptimizeBran not be eligible for pushdown. This processor simply propagates them forward so long as where filter predicate pushdown is still enabled for this branch. - The fact that they have been added at this point does not mean they will be pushed down, as intervening - join nodes might remove them from consideration, so we retain them ensure all filters are applied as specified - within this method. + When the visitor returns to this node from its parents, it updates the pushdown state for this node in the + tracker. It does this within the scope of the context manager in order to keep the pushdown state updates + consistent - modifying only the state entry associated with this node, and allowing the tracker to do all + of the upstream state propagation. - TODO: Update to only apply filters that have not been pushed down + The fact that the filters have been added at this point does not mean they will be pushed down, as intervening + join nodes might remove them from consideration, so we must ensure all filters are applied as specified + within this method. """ self._log_visit_node_type(node) current_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state @@ -283,7 +315,35 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> OptimizeBran where_filter_specs=tuple(current_pushdown_state.where_filter_specs) + tuple(node.input_where_specs), ) - return self._default_handler(node=node, pushdown_state=updated_pushdown_state) + with self._predicate_pushdown_tracker.track_pushdown_state(updated_pushdown_state): + optimized_parent: OptimizeBranchResult = node.parent_node.accept(self) + # TODO: Update to only apply filters that have not been successfully pushed down + optimized_node = OptimizeBranchResult( + optimized_branch=node.with_new_parents((optimized_parent.optimized_branch,)) + ) + + pushdown_state_updated_by_parent = self._predicate_pushdown_tracker.last_pushdown_state + # Override the pushdown state for this node and allow all upstream propagation to be handled by the tracker + if len(pushdown_state_updated_by_parent.applied_where_filter_specs) > 0: + updated_specs = frozenset.union( + frozenset(node.input_where_specs), + pushdown_state_updated_by_parent.applied_where_filter_specs, + ) + self._predicate_pushdown_tracker.override_last_pushdown_state( + PredicatePushdownState.with_pushdown_applied_where_filter_specs( + original_pushdown_state=pushdown_state_updated_by_parent, + pushdown_applied_where_filter_specs=updated_specs, + ) + ) + logger.log( + level=self._log_level, + msg=( + f"Added applied specs to pushdown state. Added specs:\n\n{node.input_where_specs}\n\n" + + f"Updated pushdown state:\n\n{self._predicate_pushdown_tracker.last_pushdown_state}" + ), + ) + + return optimized_node # Join nodes - these may affect pushdown state based on join type @@ -422,6 +482,7 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> OptimizeBr time_range_constraint=None, where_filter_specs=tuple(), pushdown_enabled_types=current_pushdown_state.pushdown_enabled_types, + applied_where_filter_specs=current_pushdown_state.applied_where_filter_specs, ) else: updated_pushdown_state = PredicatePushdownState.without_time_range_constraint(current_pushdown_state) diff --git a/metricflow/plan_conversion/node_processor.py b/metricflow/plan_conversion/node_processor.py index a2422fdd33..8b68e22705 100644 --- a/metricflow/plan_conversion/node_processor.py +++ b/metricflow/plan_conversion/node_processor.py @@ -84,20 +84,24 @@ class PredicateInputType(Enum): class PredicatePushdownState: """Container class for maintaining state information relevant for predicate pushdown. - This broadly tracks two related items: + This broadly tracks three related items: 1. Filter predicates collected during the process of constructing a dataflow plan 2. Predicate types eligible for pushdown + 3. Filters which have been applied already - The former may be updated as things like time constraints get altered or metric and measure filters are + The first may be updated as things like time constraints get altered or metric and measure filters are added to the query filters. - The latter may be updated based on query configuration, like if a cumulative metric is added to the plan + The second may be updated based on query configuration, like if a cumulative metric is added to the plan there may be changes to what sort of predicate pushdown operations are supported. + The last will be updated as filters are applied via pushdown or by the original WhereConstraintNode. - The time_range_constraint property holds the time window for setting up a time range filter expression. + Finally, the time_range_constraint property holds the time window for setting up a time range filter expression. """ time_range_constraint: Optional[TimeRangeConstraint] + # TODO: Deduplicate where_filter_specs where_filter_specs: Sequence[WhereFilterSpec] + applied_where_filter_specs: FrozenSet[WhereFilterSpec] = frozenset() pushdown_enabled_types: FrozenSet[PredicateInputType] = frozenset( [PredicateInputType.TIME_RANGE_CONSTRAINT, PredicateInputType.CATEGORICAL_DIMENSION] ) @@ -215,6 +219,7 @@ def with_time_range_constraint( time_range_constraint=time_range_constraint, pushdown_enabled_types=pushdown_enabled_types, where_filter_specs=original_pushdown_state.where_filter_specs, + applied_where_filter_specs=original_pushdown_state.applied_where_filter_specs, ) @staticmethod @@ -236,6 +241,7 @@ def without_time_range_constraint( time_range_constraint=None, pushdown_enabled_types=pushdown_enabled_types, where_filter_specs=original_pushdown_state.where_filter_specs, + applied_where_filter_specs=original_pushdown_state.applied_where_filter_specs, ) @staticmethod @@ -264,6 +270,23 @@ def with_where_filter_specs( time_range_constraint=original_pushdown_state.time_range_constraint, where_filter_specs=where_filter_specs, pushdown_enabled_types=original_pushdown_state.pushdown_enabled_types, + applied_where_filter_specs=original_pushdown_state.applied_where_filter_specs, + ) + + @staticmethod + def with_pushdown_applied_where_filter_specs( + original_pushdown_state: PredicatePushdownState, pushdown_applied_where_filter_specs: FrozenSet[WhereFilterSpec] + ) -> PredicatePushdownState: + """Factory method for replacing pushdown applied where filter specs in pushdown operations. + + This is useful for managing propagation - both forwards and backwards - of where filter specs that have been + applied via a pushdown operation. + """ + return PredicatePushdownState( + time_range_constraint=original_pushdown_state.time_range_constraint, + pushdown_enabled_types=original_pushdown_state.pushdown_enabled_types, + where_filter_specs=original_pushdown_state.where_filter_specs, + applied_where_filter_specs=pushdown_applied_where_filter_specs, ) @staticmethod diff --git a/tests_metricflow/dataflow/builder/test_predicate_pushdown.py b/tests_metricflow/dataflow/builder/test_predicate_pushdown.py deleted file mode 100644 index a3cc4e3e4b..0000000000 --- a/tests_metricflow/dataflow/builder/test_predicate_pushdown.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -import pytest -from metricflow_semantics.filters.time_constraint import TimeRangeConstraint - -from metricflow.plan_conversion.node_processor import PredicateInputType, PredicatePushdownState - - -@pytest.fixture -def fully_enabled_pushdown_state() -> PredicatePushdownState: - """Tests a valid configuration with all predicate properties set and pushdown fully enabled.""" - params = PredicatePushdownState(time_range_constraint=TimeRangeConstraint.all_time(), where_filter_specs=tuple()) - return params - - -def test_time_range_pushdown_enabled_states(fully_enabled_pushdown_state: PredicatePushdownState) -> None: - """Tests pushdown enabled check for time range pushdown operations.""" - time_range_only_state = PredicatePushdownState( - time_range_constraint=TimeRangeConstraint.all_time(), - pushdown_enabled_types=frozenset([PredicateInputType.TIME_RANGE_CONSTRAINT]), - where_filter_specs=tuple(), - ) - - enabled_states = { - "fully enabled": fully_enabled_pushdown_state.has_time_range_constraint_to_push_down, - "enabled for time range only": time_range_only_state.has_time_range_constraint_to_push_down, - } - - assert all(list(enabled_states.values())), ( - "Expected pushdown to be enabled for pushdown state with time range constraint and global pushdown enabled, " - "but some states returned False for has_time_range_constraint_to_push_down.\n" - f"Pushdown enabled states: {enabled_states}\n" - f"Fully enabled state: {fully_enabled_pushdown_state}\n" - f"Time range only state: {time_range_only_state}" - ) - - -def test_invalid_disabled_pushdown_state() -> None: - """Tests checks for invalid param configuration on disabled pushdown parameters.""" - with pytest.raises(AssertionError, match="Disabled pushdown state objects cannot have properties set"): - PredicatePushdownState( - time_range_constraint=TimeRangeConstraint.all_time(), - pushdown_enabled_types=frozenset(), - where_filter_specs=tuple(), - ) diff --git a/tests_metricflow/dataflow/optimizer/test_predicate_pushdown_optimizer.py b/tests_metricflow/dataflow/optimizer/test_predicate_pushdown_optimizer.py new file mode 100644 index 0000000000..73ab080a3f --- /dev/null +++ b/tests_metricflow/dataflow/optimizer/test_predicate_pushdown_optimizer.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from datetime import datetime + +import pytest +from metricflow_semantics.filters.time_constraint import TimeRangeConstraint +from metricflow_semantics.specs.spec_classes import WhereFilterSpec +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters + +from metricflow.dataflow.optimizer.predicate_pushdown_optimizer import PredicatePushdownBranchStateTracker +from metricflow.plan_conversion.node_processor import PredicateInputType, PredicatePushdownState + + +@pytest.fixture +def fully_enabled_pushdown_state() -> PredicatePushdownState: + """Provides a valid configuration with all predicate properties set and pushdown fully enabled.""" + params = PredicatePushdownState(time_range_constraint=TimeRangeConstraint.all_time(), where_filter_specs=tuple()) + return params + + +@pytest.fixture +def branch_state_tracker(fully_enabled_pushdown_state: PredicatePushdownState) -> PredicatePushdownBranchStateTracker: + """Provides a branch state tracker for direct testing of update mechanics.""" + return PredicatePushdownBranchStateTracker(initial_state=fully_enabled_pushdown_state) + + +def test_time_range_pushdown_enabled_states(fully_enabled_pushdown_state: PredicatePushdownState) -> None: + """Tests pushdown enabled check for time range pushdown operations.""" + time_range_only_state = PredicatePushdownState( + time_range_constraint=TimeRangeConstraint.all_time(), + pushdown_enabled_types=frozenset([PredicateInputType.TIME_RANGE_CONSTRAINT]), + where_filter_specs=tuple(), + ) + + enabled_states = { + "fully enabled": fully_enabled_pushdown_state.has_time_range_constraint_to_push_down, + "enabled for time range only": time_range_only_state.has_time_range_constraint_to_push_down, + } + + assert all(list(enabled_states.values())), ( + "Expected pushdown to be enabled for pushdown state with time range constraint and global pushdown enabled, " + "but some states returned False for has_time_range_constraint_to_push_down.\n" + f"Pushdown enabled states: {enabled_states}\n" + f"Fully enabled state: {fully_enabled_pushdown_state}\n" + f"Time range only state: {time_range_only_state}" + ) + + +def test_invalid_disabled_pushdown_state() -> None: + """Tests checks for invalid param configuration on disabled pushdown parameters.""" + with pytest.raises(AssertionError, match="Disabled pushdown state objects cannot have properties set"): + PredicatePushdownState( + time_range_constraint=TimeRangeConstraint.all_time(), + pushdown_enabled_types=frozenset(), + where_filter_specs=tuple(), + ) + + +def test_branch_state_propagation(branch_state_tracker: PredicatePushdownBranchStateTracker) -> None: + """Tests forward propagation of predicate pushdown branch state. + + This asserts against expected results on entry and exit of a three-hop nested propagation. + """ + base_state = branch_state_tracker.last_pushdown_state + where_state = PredicatePushdownState.with_where_filter_specs( + original_pushdown_state=base_state, + where_filter_specs=( + WhereFilterSpec( + where_sql="x is true", bind_parameters=SqlBindParameters(), linkable_elements=(), linkable_specs=() + ), + ), + ) + time_state = PredicatePushdownState.with_time_range_constraint( + original_pushdown_state=base_state, + time_range_constraint=TimeRangeConstraint(datetime(2024, 1, 1), datetime(2024, 1, 1)), + ) + state_updates = (time_state, where_state, time_state) + with branch_state_tracker.track_pushdown_state(state_updates[0]): + assert branch_state_tracker.last_pushdown_state == state_updates[0], "Failed to track first state update!" + with branch_state_tracker.track_pushdown_state(state_updates[1]): + assert branch_state_tracker.last_pushdown_state == state_updates[1], "Failed to track second state update!" + with branch_state_tracker.track_pushdown_state(state_updates[2]): + assert ( + branch_state_tracker.last_pushdown_state == state_updates[2] + ), "Failed to track third state update!" + + assert branch_state_tracker.last_pushdown_state == state_updates[1], "Failed to remove third state update!" + + assert branch_state_tracker.last_pushdown_state == state_updates[0], "Failed to remove second state update!" + + assert branch_state_tracker.last_pushdown_state == base_state, "Failed to remove first state update!" + + +def test_applied_filter_back_propagation(branch_state_tracker: PredicatePushdownBranchStateTracker) -> None: + """Tests backwards propagation of applied where filter annotations. + + This asserts that propagation on entry remains unaffected while the applied where filter annotations are + back-propagated as expected after exit, both for cases where an update was applied on entry to the + context manager and where the value was overridden just prior to exit from the context manager. + """ + base_state = branch_state_tracker.last_pushdown_state + where_spec_x_is_true = WhereFilterSpec( + where_sql="x is true", bind_parameters=SqlBindParameters(), linkable_elements=(), linkable_specs=() + ) + where_spec_y_is_null = WhereFilterSpec( + where_sql="y is null", bind_parameters=SqlBindParameters(), linkable_elements=(), linkable_specs=() + ) + + where_state = PredicatePushdownState.with_where_filter_specs( + original_pushdown_state=base_state, where_filter_specs=(where_spec_x_is_true, where_spec_y_is_null) + ) + x_applied_state = PredicatePushdownState.with_pushdown_applied_where_filter_specs( + original_pushdown_state=where_state, pushdown_applied_where_filter_specs=frozenset((where_spec_x_is_true,)) + ) + both_applied_state = PredicatePushdownState.with_pushdown_applied_where_filter_specs( + original_pushdown_state=base_state, + pushdown_applied_where_filter_specs=frozenset((where_spec_x_is_true, where_spec_y_is_null)), + ) + + with branch_state_tracker.track_pushdown_state(base_state): + assert ( + branch_state_tracker.last_pushdown_state == base_state + ), "Initial condition AND initial tracking mis-configured!" + with branch_state_tracker.track_pushdown_state(where_state): + assert branch_state_tracker.last_pushdown_state == where_state, "Failed to track where state!" + with branch_state_tracker.track_pushdown_state(x_applied_state): + assert ( + branch_state_tracker.last_pushdown_state == x_applied_state + ), "Failed to track applied filter state!" + + assert ( + branch_state_tracker.last_pushdown_state.applied_where_filter_specs + == x_applied_state.applied_where_filter_specs + ), "Failed to back-propagate applied filter state!" + # Update internally from where state + branch_state_tracker.override_last_pushdown_state( + PredicatePushdownState.with_pushdown_applied_where_filter_specs( + original_pushdown_state=branch_state_tracker.last_pushdown_state, + pushdown_applied_where_filter_specs=frozenset((where_spec_x_is_true, where_spec_y_is_null)), + ) + ) + + assert not branch_state_tracker.last_pushdown_state.has_where_filters_to_push_down, ( + f"Failed to remove where filter state update! Should be {base_state} but got " + f"{branch_state_tracker.last_pushdown_state}!" + ) + assert branch_state_tracker.last_pushdown_state == both_applied_state + + # We expect to propagate back to the initial entry since we only ever want to apply a filter once within a branch + assert branch_state_tracker.last_pushdown_state == both_applied_state