From 11f049accabbdd435cdafa40879fe59f282b0bc3 Mon Sep 17 00:00:00 2001 From: tlento Date: Mon, 24 Jun 2024 12:10:16 -0700 Subject: [PATCH] Simplify predicate pushdown state tracking The original implementation of the state tracker for predicate pushdown was built around an assumption that node-level and node-path access would be useful for propagating the application of filter predicates back up the dependency chain. After working towards an implementation of that logic it's clear that this is not particularly helpful. Although this tracking is theoretically useful for debugging scenarios the node-visit-level logging appears adequate, so we simplify our tracking before adding additional logic. --- .../optimizer/predicate_pushdown_optimizer.py | 51 ++++++------------- 1 file changed, 15 insertions(+), 36 deletions(-) diff --git a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py index ec202d0c0d..1e624a6496 100644 --- a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py +++ b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py @@ -2,8 +2,7 @@ import logging from contextlib import contextmanager -from dataclasses import dataclass -from typing import Iterator, List, Optional, Sequence, Tuple, Union +from typing import Iterator, List, Optional, Sequence, Union from dbt_semantic_interfaces.references import SemanticModelReference from metricflow_semantics.dag.id_prefix import StaticIdPrefix @@ -43,51 +42,31 @@ logger = logging.getLogger(__name__) -@dataclass(frozen=True) -class PredicatePushdownBranchState: - """State tracking class for managing predicate pushdown along a given branch. - - This class is meant to show the state as of a given moment in time, matched up with the history of - all nodes visited. - - TODO: streamline into a single sequence. - """ - - branch_pushdown_state: Tuple[PredicatePushdownState, ...] - node_path: Tuple[DataflowPlanNode, ...] - - class PredicatePushdownBranchStateTracker: """Tracking class for monitoring pushdown state at the node level during a visitor walk.""" def __init__(self, initial_state: PredicatePushdownState) -> None: # noqa: D107 self._initial_state = initial_state self._current_branch_state: List[PredicatePushdownState] = [] - self._current_node_path: List[DataflowPlanNode] = [] @contextmanager - def track_pushdown_state( - self, node: DataflowPlanNode, pushdown_state: PredicatePushdownState - ) -> Iterator[PredicatePushdownBranchState]: + def track_pushdown_state(self, pushdown_state: PredicatePushdownState) -> Iterator[None]: """Context manager used to track pushdown state along branches in a Dataflow Plan. - This updates the branch state and node path on entry, and then pops the last entry off on exit in order to - allow tracking of pushdown state at the level of each node without repeating state or leaking to sibling - branches. - - TODO: combine this with the DagTraversalPathTracker + This retains a sequence of pushdown state objects to allow for tracking pushdown opportunities along + the current branch. Each entry represents the predicate pushdown state as of that point, and as such + callers need only concern themselves with the value returned by the last_pushdown_state property. """ self._current_branch_state.append(pushdown_state) - self._current_node_path.append(node) - yield PredicatePushdownBranchState( - branch_pushdown_state=tuple(self._current_branch_state), node_path=tuple(self._current_node_path) - ) + yield self._current_branch_state.pop(-1) - self._current_node_path.pop(-1) @property def last_pushdown_state(self) -> PredicatePushdownState: - """Returns the last seen PredicatePushdownState.""" + """Returns the last seen PredicatePushdownState. + + This is nearly always the input state a given node processing method should be using for pushdown operations. + """ if len(self._current_branch_state) > 0: return self._current_branch_state[-1] return self._initial_state @@ -159,7 +138,7 @@ def _default_handler( if pushdown_state is None: pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state - with self._predicate_pushdown_tracker.track_pushdown_state(node, pushdown_state): + with self._predicate_pushdown_tracker.track_pushdown_state(pushdown_state): optimized_parents: Sequence[OptimizeBranchResult] = tuple( parent_node.accept(self) for parent_node in node.parent_nodes ) @@ -347,10 +326,10 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> O conversion_node_pushdown_state = PredicatePushdownState.with_pushdown_disabled() optimized_parents: List[OptimizeBranchResult] = [] - with self._predicate_pushdown_tracker.track_pushdown_state(node, base_node_pushdown_state): + with self._predicate_pushdown_tracker.track_pushdown_state(base_node_pushdown_state): optimized_parents.append(node.base_node.accept(self)) - with self._predicate_pushdown_tracker.track_pushdown_state(node, conversion_node_pushdown_state): + with self._predicate_pushdown_tracker.track_pushdown_state(conversion_node_pushdown_state): optimized_parents.append(node.conversion_node.accept(self)) return OptimizeBranchResult( @@ -386,7 +365,7 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranc left_branch_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state optimized_parents: List[OptimizeBranchResult] = [] - with self._predicate_pushdown_tracker.track_pushdown_state(node, left_branch_pushdown_state): + with self._predicate_pushdown_tracker.track_pushdown_state(left_branch_pushdown_state): optimized_parents.append(left_parent.accept(self)) base_right_branch_pushdown_state = PredicatePushdownState.without_time_range_constraint( @@ -403,7 +382,7 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranc right_branch_pushdown_state = outer_join_right_branch_pushdown_state else: right_branch_pushdown_state = base_right_branch_pushdown_state - with self._predicate_pushdown_tracker.track_pushdown_state(node, right_branch_pushdown_state): + with self._predicate_pushdown_tracker.track_pushdown_state(right_branch_pushdown_state): optimized_parents.append(join_description.join_node.accept(self)) return OptimizeBranchResult(