diff --git a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py index ec202d0c0d..1e624a6496 100644 --- a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py +++ b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py @@ -2,8 +2,7 @@ import logging from contextlib import contextmanager -from dataclasses import dataclass -from typing import Iterator, List, Optional, Sequence, Tuple, Union +from typing import Iterator, List, Optional, Sequence, Union from dbt_semantic_interfaces.references import SemanticModelReference from metricflow_semantics.dag.id_prefix import StaticIdPrefix @@ -43,51 +42,31 @@ logger = logging.getLogger(__name__) -@dataclass(frozen=True) -class PredicatePushdownBranchState: - """State tracking class for managing predicate pushdown along a given branch. - - This class is meant to show the state as of a given moment in time, matched up with the history of - all nodes visited. - - TODO: streamline into a single sequence. - """ - - branch_pushdown_state: Tuple[PredicatePushdownState, ...] - node_path: Tuple[DataflowPlanNode, ...] - - class PredicatePushdownBranchStateTracker: """Tracking class for monitoring pushdown state at the node level during a visitor walk.""" def __init__(self, initial_state: PredicatePushdownState) -> None: # noqa: D107 self._initial_state = initial_state self._current_branch_state: List[PredicatePushdownState] = [] - self._current_node_path: List[DataflowPlanNode] = [] @contextmanager - def track_pushdown_state( - self, node: DataflowPlanNode, pushdown_state: PredicatePushdownState - ) -> Iterator[PredicatePushdownBranchState]: + def track_pushdown_state(self, pushdown_state: PredicatePushdownState) -> Iterator[None]: """Context manager used to track pushdown state along branches in a Dataflow Plan. - This updates the branch state and node path on entry, and then pops the last entry off on exit in order to - allow tracking of pushdown state at the level of each node without repeating state or leaking to sibling - branches. - - TODO: combine this with the DagTraversalPathTracker + This retains a sequence of pushdown state objects to allow for tracking pushdown opportunities along + the current branch. Each entry represents the predicate pushdown state as of that point, and as such + callers need only concern themselves with the value returned by the last_pushdown_state property. """ self._current_branch_state.append(pushdown_state) - self._current_node_path.append(node) - yield PredicatePushdownBranchState( - branch_pushdown_state=tuple(self._current_branch_state), node_path=tuple(self._current_node_path) - ) + yield self._current_branch_state.pop(-1) - self._current_node_path.pop(-1) @property def last_pushdown_state(self) -> PredicatePushdownState: - """Returns the last seen PredicatePushdownState.""" + """Returns the last seen PredicatePushdownState. + + This is nearly always the input state a given node processing method should be using for pushdown operations. + """ if len(self._current_branch_state) > 0: return self._current_branch_state[-1] return self._initial_state @@ -159,7 +138,7 @@ def _default_handler( if pushdown_state is None: pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state - with self._predicate_pushdown_tracker.track_pushdown_state(node, pushdown_state): + with self._predicate_pushdown_tracker.track_pushdown_state(pushdown_state): optimized_parents: Sequence[OptimizeBranchResult] = tuple( parent_node.accept(self) for parent_node in node.parent_nodes ) @@ -347,10 +326,10 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> O conversion_node_pushdown_state = PredicatePushdownState.with_pushdown_disabled() optimized_parents: List[OptimizeBranchResult] = [] - with self._predicate_pushdown_tracker.track_pushdown_state(node, base_node_pushdown_state): + with self._predicate_pushdown_tracker.track_pushdown_state(base_node_pushdown_state): optimized_parents.append(node.base_node.accept(self)) - with self._predicate_pushdown_tracker.track_pushdown_state(node, conversion_node_pushdown_state): + with self._predicate_pushdown_tracker.track_pushdown_state(conversion_node_pushdown_state): optimized_parents.append(node.conversion_node.accept(self)) return OptimizeBranchResult( @@ -386,7 +365,7 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranc left_branch_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state optimized_parents: List[OptimizeBranchResult] = [] - with self._predicate_pushdown_tracker.track_pushdown_state(node, left_branch_pushdown_state): + with self._predicate_pushdown_tracker.track_pushdown_state(left_branch_pushdown_state): optimized_parents.append(left_parent.accept(self)) base_right_branch_pushdown_state = PredicatePushdownState.without_time_range_constraint( @@ -403,7 +382,7 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranc right_branch_pushdown_state = outer_join_right_branch_pushdown_state else: right_branch_pushdown_state = base_right_branch_pushdown_state - with self._predicate_pushdown_tracker.track_pushdown_state(node, right_branch_pushdown_state): + with self._predicate_pushdown_tracker.track_pushdown_state(right_branch_pushdown_state): optimized_parents.append(join_description.join_node.accept(self)) return OptimizeBranchResult(