Skip to content

Commit

Permalink
Simplify predicate pushdown state tracking
Browse files Browse the repository at this point in the history
The original implementation of the state tracker for predicate pushdown
was built around an assumption that node-level and node-path access would
be useful for propagating the application of filter predicates back
up the dependency chain.

After working towards an implementation of that logic, it's clear that this
access is not particularly helpful. Although this tracking is theoretically
useful for debugging scenarios, the node-visit-level logging appears adequate,
so we simplify our tracking before adding additional logic.
  • Loading branch information
tlento committed Jun 26, 2024
1 parent 46517d0 commit 11f049a
Showing 1 changed file with 15 additions and 36 deletions.
51 changes: 15 additions & 36 deletions metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

import logging
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator, List, Optional, Sequence, Tuple, Union
from typing import Iterator, List, Optional, Sequence, Union

from dbt_semantic_interfaces.references import SemanticModelReference
from metricflow_semantics.dag.id_prefix import StaticIdPrefix
Expand Down Expand Up @@ -43,51 +42,31 @@
logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class PredicatePushdownBranchState:
"""State tracking class for managing predicate pushdown along a given branch.
This class is meant to show the state as of a given moment in time, matched up with the history of
all nodes visited.
TODO: streamline into a single sequence.
"""

branch_pushdown_state: Tuple[PredicatePushdownState, ...]
node_path: Tuple[DataflowPlanNode, ...]


class PredicatePushdownBranchStateTracker:
"""Tracking class for monitoring pushdown state at the node level during a visitor walk."""

def __init__(self, initial_state: PredicatePushdownState) -> None: # noqa: D107
self._initial_state = initial_state
self._current_branch_state: List[PredicatePushdownState] = []
self._current_node_path: List[DataflowPlanNode] = []

@contextmanager
def track_pushdown_state(
self, node: DataflowPlanNode, pushdown_state: PredicatePushdownState
) -> Iterator[PredicatePushdownBranchState]:
def track_pushdown_state(self, pushdown_state: PredicatePushdownState) -> Iterator[None]:
"""Context manager used to track pushdown state along branches in a Dataflow Plan.
This updates the branch state and node path on entry, and then pops the last entry off on exit in order to
allow tracking of pushdown state at the level of each node without repeating state or leaking to sibling
branches.
TODO: combine this with the DagTraversalPathTracker
This retains a sequence of pushdown state objects to allow for tracking pushdown opportunities along
the current branch. Each entry represents the predicate pushdown state as of that point, and as such
callers need only concern themselves with the value returned by the last_pushdown_state property.
"""
self._current_branch_state.append(pushdown_state)
self._current_node_path.append(node)
yield PredicatePushdownBranchState(
branch_pushdown_state=tuple(self._current_branch_state), node_path=tuple(self._current_node_path)
)
yield
self._current_branch_state.pop(-1)
self._current_node_path.pop(-1)

@property
def last_pushdown_state(self) -> PredicatePushdownState:
"""Returns the last seen PredicatePushdownState."""
"""Returns the last seen PredicatePushdownState.
This is nearly always the input state a given node processing method should be using for pushdown operations.
"""
if len(self._current_branch_state) > 0:
return self._current_branch_state[-1]
return self._initial_state
Expand Down Expand Up @@ -159,7 +138,7 @@ def _default_handler(
if pushdown_state is None:
pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state

with self._predicate_pushdown_tracker.track_pushdown_state(node, pushdown_state):
with self._predicate_pushdown_tracker.track_pushdown_state(pushdown_state):
optimized_parents: Sequence[OptimizeBranchResult] = tuple(
parent_node.accept(self) for parent_node in node.parent_nodes
)
Expand Down Expand Up @@ -347,10 +326,10 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> O
conversion_node_pushdown_state = PredicatePushdownState.with_pushdown_disabled()

optimized_parents: List[OptimizeBranchResult] = []
with self._predicate_pushdown_tracker.track_pushdown_state(node, base_node_pushdown_state):
with self._predicate_pushdown_tracker.track_pushdown_state(base_node_pushdown_state):
optimized_parents.append(node.base_node.accept(self))

with self._predicate_pushdown_tracker.track_pushdown_state(node, conversion_node_pushdown_state):
with self._predicate_pushdown_tracker.track_pushdown_state(conversion_node_pushdown_state):
optimized_parents.append(node.conversion_node.accept(self))

return OptimizeBranchResult(
Expand Down Expand Up @@ -386,7 +365,7 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranc
left_branch_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state

optimized_parents: List[OptimizeBranchResult] = []
with self._predicate_pushdown_tracker.track_pushdown_state(node, left_branch_pushdown_state):
with self._predicate_pushdown_tracker.track_pushdown_state(left_branch_pushdown_state):
optimized_parents.append(left_parent.accept(self))

base_right_branch_pushdown_state = PredicatePushdownState.without_time_range_constraint(
Expand All @@ -403,7 +382,7 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranc
right_branch_pushdown_state = outer_join_right_branch_pushdown_state
else:
right_branch_pushdown_state = base_right_branch_pushdown_state
with self._predicate_pushdown_tracker.track_pushdown_state(node, right_branch_pushdown_state):
with self._predicate_pushdown_tracker.track_pushdown_state(right_branch_pushdown_state):
optimized_parents.append(join_description.join_node.accept(self))

return OptimizeBranchResult(
Expand Down

0 comments on commit 11f049a

Please sign in to comment.