diff --git a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py index 1e624a6496..a4b4e0b0cb 100644 --- a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py +++ b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py @@ -46,8 +46,7 @@ class PredicatePushdownBranchStateTracker: """Tracking class for monitoring pushdown state at the node level during a visitor walk.""" def __init__(self, initial_state: PredicatePushdownState) -> None: # noqa: D107 - self._initial_state = initial_state - self._current_branch_state: List[PredicatePushdownState] = [] + self._current_branch_state: List[PredicatePushdownState] = [initial_state] @contextmanager def track_pushdown_state(self, pushdown_state: PredicatePushdownState) -> Iterator[None]: @@ -56,20 +55,90 @@ def track_pushdown_state(self, pushdown_state: PredicatePushdownState) -> Iterat This retains a sequence of pushdown state objects to allow for tracking pushdown opportunities along the current branch. Each entry represents the predicate pushdown state as of that point, and as such callers need only concern themselves with the value returned by the last_pushdown_state property. + + The back-propagation of pushdown_applied_where_filter_specs is necessary to ensure the outer query + node can evaluate which where filter specs need to be applied. We capture the complete set because + we may have sequenced nodes where some where filters are applied (e.g., time dimension filters might + be applied to metric time nodes, etc.). + + The state tracking and propagation is built to work as follows: + + For a simple DAG of where_node -> join_node -> source_nodes there will be two branches: + + where_node -> join_node -> left_source_node + where_node -> join_node -> right_source_node + + In this case the where_node receives the initial predicate pushdown state, and then adds its own + updated state object (with the where filters) via the context manager and propagates that to the join_node. + + The join_node then receives the where_node's predicate pushdown state, and, for each branch, adds an + updated state object via the context manager and propagates the updated state objects to the next node. + + The left_source_node gets the join node's left branch state and evaluates it. If it can apply any filters + it adds an updated state object to note that the filters are applied and propagates it along via the context + manager. In this case, the context manager exits immediately and returns to the left_source_node, which + finishes applying the where constraints and returns to the join_node. + + At this point, the join node has a left branch context manager with the left branch pushdown state. The + join_node does NOT have access to the left_source_node's pushdown state, but it needs to be able to notify its + parent, the where_node, that some filters have been applied at the left_source_node. + + How does it do this? The left_source_node's state update included applied where filters. When the context + manager exits it propagates the left_source_node's applied where filters back up to the preceding state (in + this case, the join node's left branch state). The same thing happens when the branch states for the join + node exit the context manager, so the where_node then sees the union of all filters applied downstream. + + The where_node, then, has access to the complete set of filters applied downstream.
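The walkthrough above can be condensed into a minimal, self-contained sketch of the stack-and-back-propagation mechanic. The names below are simplified stand-ins (not the real PredicatePushdownState or tracker types), and plain strings stand in for WhereFilterSpec objects:

from __future__ import annotations

from contextlib import contextmanager
from dataclasses import dataclass
from typing import FrozenSet, Iterator, List


@dataclass(frozen=True)
class ToyPushdownState:
    """Simplified stand-in for PredicatePushdownState: only the applied-filter set matters here."""

    applied_filters: FrozenSet[str] = frozenset()


class ToyBranchStateTracker:
    """Simplified stand-in for the branch state tracker: a stack plus back-propagation on exit."""

    def __init__(self, initial_state: ToyPushdownState) -> None:
        self._stack: List[ToyPushdownState] = [initial_state]

    @contextmanager
    def track(self, state: ToyPushdownState) -> Iterator[None]:
        self._stack.append(state)
        yield
        exiting_state = self._stack.pop()
        if exiting_state.applied_filters:
            # Back-propagate applied filters into the preceding entry on the branch.
            merged = exiting_state.applied_filters | self._stack[-1].applied_filters
            self._stack[-1] = ToyPushdownState(applied_filters=merged)

    @property
    def last(self) -> ToyPushdownState:
        return self._stack[-1]


tracker = ToyBranchStateTracker(ToyPushdownState())
with tracker.track(ToyPushdownState()):  # where_node adds its state
    with tracker.track(ToyPushdownState()):  # join_node adds its left branch state
        # left_source_node applies a filter and records it in its state update
        with tracker.track(ToyPushdownState(applied_filters=frozenset({"x is true"}))):
            pass
    # Back at the where_node's entry: the applied filter has been merged upward twice.
    assert tracker.last.applied_filters == frozenset({"x is true"})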
+ + This is complicated because of joins - we can't store a single set of applied filters, because there's no good + way to keep them organized in the case of multiple join branches. Instead, we track up and down a single + branch, and merge the events of parent branches at the join nodes that created them. """ self._current_branch_state.append(pushdown_state) yield - self._current_branch_state.pop(-1) + last_visited_pushdown_state = self._current_branch_state.pop(-1) + if len(last_visited_pushdown_state.applied_where_filter_specs) > 0: + pushdown_applied_where_filter_specs = frozenset.union( + *[ + last_visited_pushdown_state.applied_where_filter_specs, + self.last_pushdown_state.applied_where_filter_specs, + ] + ) + self.override_last_pushdown_state( + PredicatePushdownState.with_pushdown_applied_where_filter_specs( + original_pushdown_state=self.last_pushdown_state, + pushdown_applied_where_filter_specs=pushdown_applied_where_filter_specs, + ) + ) @property def last_pushdown_state(self) -> PredicatePushdownState: """Returns the last seen PredicatePushdownState. - This is nearly always the input state a given node processing method should be using for pushdown operations. + This is the input state a given node processing method should be using for pushdown operations. + """ + assert len(self._current_branch_state) > 0, ( + "There should always be at least one element in current branch state! " + "This suggests an inappropriate removal or improper initialization." + ) + return self._current_branch_state[-1] + + def override_last_pushdown_state(self, pushdown_state: PredicatePushdownState) -> None: + """Method for forcibly updating the last seen predicate pushdown state to a new value. + + This is necessary only for cases where we wish to back-propagate some updated state attribute + for handling in the exit condition of the preceding node in the DAG. The scenario where we use + it here is to indicate that a where filter has been applied elsewhere on the branch, so outer + nodes can skip application as appropriate. + + Since this is not something we want people doing by accident, we designate it as a special method + rather than making it a property setter. """ - if len(self._current_branch_state) > 0: - return self._current_branch_state[-1] - return self._initial_state + assert len(self._current_branch_state) > 0, ( + "There should always be at least one element in current branch state! " + "This suggests an inappropriate removal or improper initialization." + ) + self._current_branch_state[-1] = pushdown_state class PredicatePushdownOptimizer( @@ -81,8 +150,10 @@ class PredicatePushdownOptimizer( This evaluates filter predicates to determine which, if any, can be directly applied to an input source node. It operates by walking each branch in the DataflowPlan and collecting pushdown state information, then evaluating that state at the input source node and applying the filter node (e.g., a WhereConstraintNode) - directly to the source. As the optimizer unrolls back through the branch it will remove the duplicated - constraint node if it is appropriate to do so. + directly to the source. As the optimizer unrolls back through the branch it will remove duplicated constraints. + + This assumes that we never do predicate pushdown on a filter that needs to be re-applied, so every filter + we encounter gets applied exactly once per nested subquery branch encapsulated by a given constraint node.
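A condensed usage sketch of the tracker API defined above, based on the signatures visible in this diff; it assumes metricflow and metricflow_semantics are importable, and the spec construction mirrors the new tests:

from metricflow_semantics.filters.time_constraint import TimeRangeConstraint
from metricflow_semantics.specs.spec_classes import WhereFilterSpec
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters

from metricflow.dataflow.optimizer.predicate_pushdown_optimizer import PredicatePushdownBranchStateTracker
from metricflow.plan_conversion.node_processor import PredicatePushdownState

initial_state = PredicatePushdownState(
    time_range_constraint=TimeRangeConstraint.all_time(), where_filter_specs=tuple()
)
tracker = PredicatePushdownBranchStateTracker(initial_state=initial_state)

spec = WhereFilterSpec(
    where_sql="x is true", bind_parameters=SqlBindParameters(), linkable_elements=(), linkable_specs=()
)

# A downstream node records that it applied the filter; on exit the tracker merges the applied
# specs back into the preceding state entry on the branch.
with tracker.track_pushdown_state(
    PredicatePushdownState.with_where_filter_specs(
        original_pushdown_state=initial_state, where_filter_specs=(spec,)
    )
):
    tracker.override_last_pushdown_state(
        PredicatePushdownState.with_pushdown_applied_where_filter_specs(
            original_pushdown_state=tracker.last_pushdown_state,
            pushdown_applied_where_filter_specs=frozenset((spec,)),
        )
    )

assert spec in tracker.last_pushdown_state.applied_where_filter_specs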
""" def __init__(self, node_data_set_resolver: DataflowPlanNodeOutputDataSetResolver) -> None: @@ -128,8 +199,8 @@ def _default_handler( ) -> OptimizeBranchResult: """Encapsulates state updates and handling for most node types. - The dominant majority of nodes simply propagate the current predicate pushdown state along and return - whatever output the parent nodes produce. Of the nodes that do not do this, the most common deviation + The most common node-level operation is to simply propagate the current predicate pushdown state along and + return whatever output the parent nodes produce. Of the nodes that do not do this, the most common deviation is a pushdown state update. As such, this method defaults to propagating the last seen state, but allows an override for cases where @@ -229,13 +300,16 @@ def _push_down_where_filters( filters_left_over.append(filter_spec) logger.log(level=self._log_level, msg=f"Filter specs to add:\n{filters_to_apply}") + applied_filters = frozenset.union( + *[frozenset(current_pushdown_state.applied_where_filter_specs), frozenset(filters_to_apply)] + ) updated_pushdown_state = PredicatePushdownState( time_range_constraint=current_pushdown_state.time_range_constraint, where_filter_specs=tuple(filters_left_over), pushdown_enabled_types=current_pushdown_state.pushdown_enabled_types, + applied_where_filter_specs=applied_filters, ) optimized_node = self._default_handler(node=node, pushdown_state=updated_pushdown_state) - # TODO: propagate filters applied back up the branch for removal if len(filters_to_apply) > 0: return OptimizeBranchResult( optimized_branch=WhereConstraintNode( @@ -267,11 +341,14 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> OptimizeBran not be eligible for pushdown. This processor simply propagates them forward so long as where filter predicate pushdown is still enabled for this branch. - The fact that they have been added at this point does not mean they will be pushed down, as intervening - join nodes might remove them from consideration, so we retain them ensure all filters are applied as specified - within this method. + When the visitor returns to this node from its parents, it updates the pushdown state for this node in the + tracker. It does this within the scope of the context manager in order to keep the pushdown state updates + consistent - modifying only the state entry associated with this node, and allowing the tracker to do all + of the upstream state propagation. - TODO: Update to only apply filters that have not been pushed down + The fact that the filters have been added at this point does not mean they will be pushed down, as intervening + join nodes might remove them from consideration, so we must ensure all filters are applied as specified + within this method. 
""" self._log_visit_node_type(node) current_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state @@ -283,7 +360,35 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> OptimizeBran where_filter_specs=tuple(current_pushdown_state.where_filter_specs) + tuple(node.input_where_specs), ) - return self._default_handler(node=node, pushdown_state=updated_pushdown_state) + with self._predicate_pushdown_tracker.track_pushdown_state(updated_pushdown_state): + optimized_parent: OptimizeBranchResult = node.parent_node.accept(self) + # TODO: Update to only apply filters that have not been successfully pushed down + optimized_node = OptimizeBranchResult( + optimized_branch=node.with_new_parents((optimized_parent.optimized_branch,)) + ) + + pushdown_state_updated_by_parent = self._predicate_pushdown_tracker.last_pushdown_state + # Override the pushdown state for this node and allow all upstream propagation to be handled by the tracker + if len(pushdown_state_updated_by_parent.applied_where_filter_specs) > 0: + updated_specs = frozenset.union( + frozenset(node.input_where_specs), + pushdown_state_updated_by_parent.applied_where_filter_specs, + ) + self._predicate_pushdown_tracker.override_last_pushdown_state( + PredicatePushdownState.with_pushdown_applied_where_filter_specs( + original_pushdown_state=pushdown_state_updated_by_parent, + pushdown_applied_where_filter_specs=updated_specs, + ) + ) + logger.log( + level=self._log_level, + msg=( + f"Added applied specs to pushdown state. Added specs:\n\n{node.input_where_specs}\n\n" + + f"Updated pushdown state:\n\n{self._predicate_pushdown_tracker.last_pushdown_state}" + ), + ) + + return optimized_node # Join nodes - these may affect pushdown state based on join type @@ -354,6 +459,9 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranc of the join. As such, time constraints are not propagated to the right side of the join. This restriction may be relaxed at a later time, but for now it is largely irrelevant since we do not allow fanout joins and do not yet have support for pre-filters based on time ranges for things like SCD joins. + + Note we initialize branch state tracking objects prior to traversal to avoid back-propagation from + one branch affecting the predicate pushdown behavior along other branches. 
""" self._log_visit_node_type(node) left_parent = node.left_node @@ -364,16 +472,18 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranc else: left_branch_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state - optimized_parents: List[OptimizeBranchResult] = [] - with self._predicate_pushdown_tracker.track_pushdown_state(left_branch_pushdown_state): - optimized_parents.append(left_parent.accept(self)) - base_right_branch_pushdown_state = PredicatePushdownState.without_time_range_constraint( self._predicate_pushdown_tracker.last_pushdown_state ) outer_join_right_branch_pushdown_state = PredicatePushdownState.without_where_filter_specs( original_pushdown_state=base_right_branch_pushdown_state ) + + optimized_parents: List[OptimizeBranchResult] = [] + + with self._predicate_pushdown_tracker.track_pushdown_state(left_branch_pushdown_state): + optimized_parents.append(left_parent.accept(self)) + for join_description in node.join_targets: if ( join_description.join_type is SqlJoinType.LEFT_OUTER @@ -422,6 +532,7 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> OptimizeBr time_range_constraint=None, where_filter_specs=tuple(), pushdown_enabled_types=current_pushdown_state.pushdown_enabled_types, + applied_where_filter_specs=current_pushdown_state.applied_where_filter_specs, ) else: updated_pushdown_state = PredicatePushdownState.without_time_range_constraint(current_pushdown_state) diff --git a/metricflow/plan_conversion/node_processor.py b/metricflow/plan_conversion/node_processor.py index a2422fdd33..8b68e22705 100644 --- a/metricflow/plan_conversion/node_processor.py +++ b/metricflow/plan_conversion/node_processor.py @@ -84,20 +84,24 @@ class PredicateInputType(Enum): class PredicatePushdownState: """Container class for maintaining state information relevant for predicate pushdown. - This broadly tracks two related items: + This broadly tracks three related items: 1. Filter predicates collected during the process of constructing a dataflow plan 2. Predicate types eligible for pushdown + 3. Filters which have been applied already - The former may be updated as things like time constraints get altered or metric and measure filters are + The first may be updated as things like time constraints get altered or metric and measure filters are added to the query filters. - The latter may be updated based on query configuration, like if a cumulative metric is added to the plan + The second may be updated based on query configuration, like if a cumulative metric is added to the plan there may be changes to what sort of predicate pushdown operations are supported. + The last will be updated as filters are applied via pushdown or by the original WhereConstraintNode. - The time_range_constraint property holds the time window for setting up a time range filter expression. + Finally, the time_range_constraint property holds the time window for setting up a time range filter expression. 
""" time_range_constraint: Optional[TimeRangeConstraint] + # TODO: Deduplicate where_filter_specs where_filter_specs: Sequence[WhereFilterSpec] + applied_where_filter_specs: FrozenSet[WhereFilterSpec] = frozenset() pushdown_enabled_types: FrozenSet[PredicateInputType] = frozenset( [PredicateInputType.TIME_RANGE_CONSTRAINT, PredicateInputType.CATEGORICAL_DIMENSION] ) @@ -215,6 +219,7 @@ def with_time_range_constraint( time_range_constraint=time_range_constraint, pushdown_enabled_types=pushdown_enabled_types, where_filter_specs=original_pushdown_state.where_filter_specs, + applied_where_filter_specs=original_pushdown_state.applied_where_filter_specs, ) @staticmethod @@ -236,6 +241,7 @@ def without_time_range_constraint( time_range_constraint=None, pushdown_enabled_types=pushdown_enabled_types, where_filter_specs=original_pushdown_state.where_filter_specs, + applied_where_filter_specs=original_pushdown_state.applied_where_filter_specs, ) @staticmethod @@ -264,6 +270,23 @@ def with_where_filter_specs( time_range_constraint=original_pushdown_state.time_range_constraint, where_filter_specs=where_filter_specs, pushdown_enabled_types=original_pushdown_state.pushdown_enabled_types, + applied_where_filter_specs=original_pushdown_state.applied_where_filter_specs, + ) + + @staticmethod + def with_pushdown_applied_where_filter_specs( + original_pushdown_state: PredicatePushdownState, pushdown_applied_where_filter_specs: FrozenSet[WhereFilterSpec] + ) -> PredicatePushdownState: + """Factory method for replacing pushdown applied where filter specs in pushdown operations. + + This is useful for managing propagation - both forwards and backwards - of where filter specs that have been + applied via a pushdown operation. + """ + return PredicatePushdownState( + time_range_constraint=original_pushdown_state.time_range_constraint, + pushdown_enabled_types=original_pushdown_state.pushdown_enabled_types, + where_filter_specs=original_pushdown_state.where_filter_specs, + applied_where_filter_specs=pushdown_applied_where_filter_specs, ) @staticmethod diff --git a/tests_metricflow/dataflow/builder/test_predicate_pushdown.py b/tests_metricflow/dataflow/builder/test_predicate_pushdown.py deleted file mode 100644 index a3cc4e3e4b..0000000000 --- a/tests_metricflow/dataflow/builder/test_predicate_pushdown.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -import pytest -from metricflow_semantics.filters.time_constraint import TimeRangeConstraint - -from metricflow.plan_conversion.node_processor import PredicateInputType, PredicatePushdownState - - -@pytest.fixture -def fully_enabled_pushdown_state() -> PredicatePushdownState: - """Tests a valid configuration with all predicate properties set and pushdown fully enabled.""" - params = PredicatePushdownState(time_range_constraint=TimeRangeConstraint.all_time(), where_filter_specs=tuple()) - return params - - -def test_time_range_pushdown_enabled_states(fully_enabled_pushdown_state: PredicatePushdownState) -> None: - """Tests pushdown enabled check for time range pushdown operations.""" - time_range_only_state = PredicatePushdownState( - time_range_constraint=TimeRangeConstraint.all_time(), - pushdown_enabled_types=frozenset([PredicateInputType.TIME_RANGE_CONSTRAINT]), - where_filter_specs=tuple(), - ) - - enabled_states = { - "fully enabled": fully_enabled_pushdown_state.has_time_range_constraint_to_push_down, - "enabled for time range only": time_range_only_state.has_time_range_constraint_to_push_down, - } - - assert 
all(list(enabled_states.values())), ( - "Expected pushdown to be enabled for pushdown state with time range constraint and global pushdown enabled, " - "but some states returned False for has_time_range_constraint_to_push_down.\n" - f"Pushdown enabled states: {enabled_states}\n" - f"Fully enabled state: {fully_enabled_pushdown_state}\n" - f"Time range only state: {time_range_only_state}" - ) - - -def test_invalid_disabled_pushdown_state() -> None: - """Tests checks for invalid param configuration on disabled pushdown parameters.""" - with pytest.raises(AssertionError, match="Disabled pushdown state objects cannot have properties set"): - PredicatePushdownState( - time_range_constraint=TimeRangeConstraint.all_time(), - pushdown_enabled_types=frozenset(), - where_filter_specs=tuple(), - ) diff --git a/tests_metricflow/dataflow/optimizer/test_predicate_pushdown_optimizer.py b/tests_metricflow/dataflow/optimizer/test_predicate_pushdown_optimizer.py new file mode 100644 index 0000000000..73ab080a3f --- /dev/null +++ b/tests_metricflow/dataflow/optimizer/test_predicate_pushdown_optimizer.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from datetime import datetime + +import pytest +from metricflow_semantics.filters.time_constraint import TimeRangeConstraint +from metricflow_semantics.specs.spec_classes import WhereFilterSpec +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters + +from metricflow.dataflow.optimizer.predicate_pushdown_optimizer import PredicatePushdownBranchStateTracker +from metricflow.plan_conversion.node_processor import PredicateInputType, PredicatePushdownState + + +@pytest.fixture +def fully_enabled_pushdown_state() -> PredicatePushdownState: + """Provides a valid configuration with all predicate properties set and pushdown fully enabled.""" + params = PredicatePushdownState(time_range_constraint=TimeRangeConstraint.all_time(), where_filter_specs=tuple()) + return params + + +@pytest.fixture +def branch_state_tracker(fully_enabled_pushdown_state: PredicatePushdownState) -> PredicatePushdownBranchStateTracker: + """Provides a branch state tracker for direct testing of update mechanics.""" + return PredicatePushdownBranchStateTracker(initial_state=fully_enabled_pushdown_state) + + +def test_time_range_pushdown_enabled_states(fully_enabled_pushdown_state: PredicatePushdownState) -> None: + """Tests pushdown enabled check for time range pushdown operations.""" + time_range_only_state = PredicatePushdownState( + time_range_constraint=TimeRangeConstraint.all_time(), + pushdown_enabled_types=frozenset([PredicateInputType.TIME_RANGE_CONSTRAINT]), + where_filter_specs=tuple(), + ) + + enabled_states = { + "fully enabled": fully_enabled_pushdown_state.has_time_range_constraint_to_push_down, + "enabled for time range only": time_range_only_state.has_time_range_constraint_to_push_down, + } + + assert all(list(enabled_states.values())), ( + "Expected pushdown to be enabled for pushdown state with time range constraint and global pushdown enabled, " + "but some states returned False for has_time_range_constraint_to_push_down.\n" + f"Pushdown enabled states: {enabled_states}\n" + f"Fully enabled state: {fully_enabled_pushdown_state}\n" + f"Time range only state: {time_range_only_state}" + ) + + +def test_invalid_disabled_pushdown_state() -> None: + """Tests checks for invalid param configuration on disabled pushdown parameters.""" + with pytest.raises(AssertionError, match="Disabled pushdown state objects cannot have properties set"): + 
PredicatePushdownState( + time_range_constraint=TimeRangeConstraint.all_time(), + pushdown_enabled_types=frozenset(), + where_filter_specs=tuple(), + ) + + +def test_branch_state_propagation(branch_state_tracker: PredicatePushdownBranchStateTracker) -> None: + """Tests forward propagation of predicate pushdown branch state. + + This asserts against expected results on entry and exit of a three-hop nested propagation. + """ + base_state = branch_state_tracker.last_pushdown_state + where_state = PredicatePushdownState.with_where_filter_specs( + original_pushdown_state=base_state, + where_filter_specs=( + WhereFilterSpec( + where_sql="x is true", bind_parameters=SqlBindParameters(), linkable_elements=(), linkable_specs=() + ), + ), + ) + time_state = PredicatePushdownState.with_time_range_constraint( + original_pushdown_state=base_state, + time_range_constraint=TimeRangeConstraint(datetime(2024, 1, 1), datetime(2024, 1, 1)), + ) + state_updates = (time_state, where_state, time_state) + with branch_state_tracker.track_pushdown_state(state_updates[0]): + assert branch_state_tracker.last_pushdown_state == state_updates[0], "Failed to track first state update!" + with branch_state_tracker.track_pushdown_state(state_updates[1]): + assert branch_state_tracker.last_pushdown_state == state_updates[1], "Failed to track second state update!" + with branch_state_tracker.track_pushdown_state(state_updates[2]): + assert ( + branch_state_tracker.last_pushdown_state == state_updates[2] + ), "Failed to track third state update!" + + assert branch_state_tracker.last_pushdown_state == state_updates[1], "Failed to remove third state update!" + + assert branch_state_tracker.last_pushdown_state == state_updates[0], "Failed to remove second state update!" + + assert branch_state_tracker.last_pushdown_state == base_state, "Failed to remove first state update!" + + +def test_applied_filter_back_propagation(branch_state_tracker: PredicatePushdownBranchStateTracker) -> None: + """Tests backwards propagation of applied where filter annotations. + + This asserts that propagation on entry remains unaffected while the applied where filter annotations are + back-propagated as expected after exit, both for cases where an update was applied on entry to the + context manager and where the value was overridden just prior to exit from the context manager. + """ + base_state = branch_state_tracker.last_pushdown_state + where_spec_x_is_true = WhereFilterSpec( + where_sql="x is true", bind_parameters=SqlBindParameters(), linkable_elements=(), linkable_specs=() + ) + where_spec_y_is_null = WhereFilterSpec( + where_sql="y is null", bind_parameters=SqlBindParameters(), linkable_elements=(), linkable_specs=() + ) + + where_state = PredicatePushdownState.with_where_filter_specs( + original_pushdown_state=base_state, where_filter_specs=(where_spec_x_is_true, where_spec_y_is_null) + ) + x_applied_state = PredicatePushdownState.with_pushdown_applied_where_filter_specs( + original_pushdown_state=where_state, pushdown_applied_where_filter_specs=frozenset((where_spec_x_is_true,)) + ) + both_applied_state = PredicatePushdownState.with_pushdown_applied_where_filter_specs( + original_pushdown_state=base_state, + pushdown_applied_where_filter_specs=frozenset((where_spec_x_is_true, where_spec_y_is_null)), + ) + + with branch_state_tracker.track_pushdown_state(base_state): + assert ( + branch_state_tracker.last_pushdown_state == base_state + ), "Initial condition AND initial tracking mis-configured!" 
+ with branch_state_tracker.track_pushdown_state(where_state): + assert branch_state_tracker.last_pushdown_state == where_state, "Failed to track where state!" + with branch_state_tracker.track_pushdown_state(x_applied_state): + assert ( + branch_state_tracker.last_pushdown_state == x_applied_state + ), "Failed to track applied filter state!" + + assert ( + branch_state_tracker.last_pushdown_state.applied_where_filter_specs + == x_applied_state.applied_where_filter_specs + ), "Failed to back-propagate applied filter state!" + # Update internally from where state + branch_state_tracker.override_last_pushdown_state( + PredicatePushdownState.with_pushdown_applied_where_filter_specs( + original_pushdown_state=branch_state_tracker.last_pushdown_state, + pushdown_applied_where_filter_specs=frozenset((where_spec_x_is_true, where_spec_y_is_null)), + ) + ) + + assert not branch_state_tracker.last_pushdown_state.has_where_filters_to_push_down, ( + f"Failed to remove where filter state update! Should be {base_state} but got " + f"{branch_state_tracker.last_pushdown_state}!" + ) + assert branch_state_tracker.last_pushdown_state == both_applied_state + + # We expect to propagate back to the initial entry since we only ever want to apply a filter once within a branch + assert branch_state_tracker.last_pushdown_state == both_applied_state
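Assuming the repository's usual pytest setup, the new tracker tests can be run in isolation with something along these lines (the -k expression simply matches the two new test names):

import pytest

# Run only the tracker-focused tests added by this diff.
pytest.main(
    [
        "tests_metricflow/dataflow/optimizer/test_predicate_pushdown_optimizer.py",
        "-k",
        "branch_state_propagation or applied_filter_back_propagation",
    ]
)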