Track and propagate applied where filter specs to outer plan nodes
The PredicatePushdownOptimizer currently pushes predicates down
along the DataflowPlan DAG from the outermost WhereConstraintNode to
as close to the source node for that branch as possible. This results
in duplicate where filter application, because the WhereConstraintNode
has no way of evaluating whether a given set of where filter specs
has already been applied downstream.

This change adds the tracking mechanism for propagating the applied
filters back up along the branch. For now this is a tracking-only
change - the selective application of these filters will follow shortly.

In addition to the added test cases for the propagation mechanism, the
propagation mechanics were verified by running several pushdown-enabled
rendering tests with the `log-cli-level=DEBUG` flag set in pytest.
tlento committed Jun 25, 2024
1 parent eaf6079 commit 63efcda
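
A minimal sketch of the back-propagation mechanism this commit introduces, using plain strings in place of WhereFilterSpec instances and a bare frozenset in place of the full PredicatePushdownState - an illustration of the idea, not the MetricFlow API:

```python
from contextlib import contextmanager
from typing import FrozenSet, Iterator, List


class BranchStateSketch:
    """Tracks applied filter specs per branch level and merges them upward on exit."""

    def __init__(self) -> None:
        # Seeded with an empty entry so the top of the stack is always defined.
        self._stack: List[FrozenSet[str]] = [frozenset()]

    @contextmanager
    def track(self, applied: FrozenSet[str]) -> Iterator[None]:
        self._stack.append(applied)
        yield
        # Back-propagate: whatever the child branch applied becomes visible to
        # the parent entry, so the outer node can later skip re-applying it.
        child_applied = self._stack.pop()
        self._stack[-1] = self._stack[-1] | child_applied

    @property
    def last_applied(self) -> FrozenSet[str]:
        return self._stack[-1]


tracker = BranchStateSketch()
with tracker.track(frozenset()):  # outer WhereConstraintNode
    with tracker.track(frozenset({"booking__is_instant"})):  # filter applied via pushdown
        pass
    # On exit from the inner branch, the outer entry knows the filter was applied.
    assert tracker.last_applied == frozenset({"booking__is_instant"})
```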
Showing 4 changed files with 255 additions and 66 deletions.
95 changes: 78 additions & 17 deletions metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
@@ -46,8 +46,7 @@ class PredicatePushdownBranchStateTracker:
"""Tracking class for monitoring pushdown state at the node level during a visitor walk."""

def __init__(self, initial_state: PredicatePushdownState) -> None: # noqa: D107
self._initial_state = initial_state
self._current_branch_state: List[PredicatePushdownState] = []
self._current_branch_state: List[PredicatePushdownState] = [initial_state]

@contextmanager
def track_pushdown_state(self, pushdown_state: PredicatePushdownState) -> Iterator[None]:
@@ -56,20 +55,45 @@ def track_pushdown_state(self, pushdown_state: PredicatePushdownState) -> Iterator[None]:
This retains a sequence of pushdown state objects to allow for tracking pushdown opportunities along
the current branch. Each entry represents the predicate pushdown state as of that point, and as such
callers need only concern themselves with the value returned by the last_pushdown_state property.
The back-propagation of pushdown_applied_where_filter_specs is necessary to ensure the outer query
node can evaluate which where filter specs need to be applied. We capture the complete set because
we may have sequenced nodes where some where filters are applied (e.g., time dimension filters might
be applied to metric time nodes, etc.).
"""
self._current_branch_state.append(pushdown_state)
yield
self._current_branch_state.pop(-1)
last_visited_pushdown_state = self._current_branch_state.pop(-1)
if len(last_visited_pushdown_state.applied_where_filter_specs) > 0:
pushdown_applied_where_filter_specs = frozenset.union(
*[
last_visited_pushdown_state.applied_where_filter_specs,
self.last_pushdown_state.applied_where_filter_specs,
]
)
self.override_last_pushdown_state(
PredicatePushdownState.with_pushdown_applied_where_filter_specs(
original_pushdown_state=self.last_pushdown_state,
pushdown_applied_where_filter_specs=pushdown_applied_where_filter_specs,
)
)

@property
def last_pushdown_state(self) -> PredicatePushdownState:
"""Returns the last seen PredicatePushdownState.
This is nearly always the input state a given node processing method should be using for pushdown operations.
This is the input state a given node processing method should be using for pushdown operations.
"""
if len(self._current_branch_state) > 0:
return self._current_branch_state[-1]
return self._initial_state
return self._current_branch_state[-1]

def override_last_pushdown_state(self, pushdown_state: PredicatePushdownState) -> None:
"""Method for forcibly updating the last seen predicate pushdown state to a new value.
This is necessary only for cases where we wish to back-propagate some updated state attribute
for handling in the exit condition of the preceding node in the DAG. Since it is something of an
extraordinary circumstance we designate it as a special method rather than making it a property setter.
"""
self._current_branch_state[-1] = pushdown_state
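
A usage sketch of the tracker as added here. The construction mirrors the definitions above; `some_spec` stands in for a WhereFilterSpec built elsewhere and is the only assumed value:

```python
from metricflow.plan_conversion.node_processor import PredicatePushdownState

initial_state = PredicatePushdownState(time_range_constraint=None, where_filter_specs=())
tracker = PredicatePushdownBranchStateTracker(initial_state=initial_state)

# The stack is now seeded with the initial state, so last_pushdown_state is
# always defined, even outside of any track_pushdown_state context.
assert tracker.last_pushdown_state is initial_state

child_state = PredicatePushdownState.with_pushdown_applied_where_filter_specs(
    original_pushdown_state=initial_state,
    pushdown_applied_where_filter_specs=frozenset({some_spec}),  # hypothetical pre-built spec
)
with tracker.track_pushdown_state(initial_state):
    with tracker.track_pushdown_state(child_state):
        pass
    # On exit, the child's applied specs were merged into this level's entry.
    assert some_spec in tracker.last_pushdown_state.applied_where_filter_specs
```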


class PredicatePushdownOptimizer(
@@ -81,8 +105,10 @@ class PredicatePushdownOptimizer(
This evaluates filter predicates to determine which, if any, can be applied directly to an input source node.
It operates by walking each branch in the DataflowPlan and collecting pushdown state information, then
evaluating that state at the input source node and applying the filter node (e.g., a WhereConstraintNode)
directly to the source. As the optimizer unrolls back through the branch it will remove the duplicated
constraint node if it is appropriate to do so.
directly to the source. As the optimizer unrolls back through the branch it will remove duplicated constraints.
This assumes that we never do predicate pushdown on a filter that needs to be re-applied, so every filter
we encounter gets applied exactly once per nested subquery branch encapsulated by a given constraint node.
"""

def __init__(self, node_data_set_resolver: DataflowPlanNodeOutputDataSetResolver) -> None:
@@ -128,8 +154,8 @@ def _default_handler(
) -> OptimizeBranchResult:
"""Encapsulates state updates and handling for most node types.
The dominant majority of nodes simply propagate the current predicate pushdown state along and return
whatever output the parent nodes produce. Of the nodes that do not do this, the most common deviation
The most common node-level operation is to simply propagate the current predicate pushdown state along and
return whatever output the parent nodes produce. Of the nodes that do not do this, the most common deviation
is a pushdown state update.
As such, this method defaults to propagating the last seen state, but allows an override for cases where
@@ -229,13 +255,16 @@ def _push_down_where_filters(
filters_left_over.append(filter_spec)

logger.log(level=self._log_level, msg=f"Filter specs to add:\n{filters_to_apply}")
applied_filters = frozenset.union(
*[frozenset(current_pushdown_state.applied_where_filter_specs), frozenset(filters_to_apply)]
)
updated_pushdown_state = PredicatePushdownState(
time_range_constraint=current_pushdown_state.time_range_constraint,
where_filter_specs=tuple(filters_left_over),
pushdown_enabled_types=current_pushdown_state.pushdown_enabled_types,
applied_where_filter_specs=applied_filters,
)
optimized_node = self._default_handler(node=node, pushdown_state=updated_pushdown_state)
# TODO: propagate filters applied back up the branch for removal
if len(filters_to_apply) > 0:
return OptimizeBranchResult(
optimized_branch=WhereConstraintNode(
Expand Down Expand Up @@ -267,11 +296,14 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> OptimizeBran
not be eligible for pushdown. This processor simply propagates them forward so long as where filter
predicate pushdown is still enabled for this branch.
The fact that they have been added at this point does not mean they will be pushed down, as intervening
join nodes might remove them from consideration, so we retain them to ensure all filters are applied as
specified within this method.
When the visitor returns to this node from its parents, it updates the pushdown state for this node in the
tracker. It does this within the scope of the context manager in order to keep the pushdown state updates
consistent - modifying only the state entry associated with this node, and allowing the tracker to do all
of the upstream state propagation.
TODO: Update to only apply filters that have not been pushed down
The fact that the filters have been added at this point does not mean they will be pushed down, as intervening
join nodes might remove them from consideration, so we must ensure all filters are applied as specified
within this method.
"""
self._log_visit_node_type(node)
current_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state
@@ -283,7 +315,35 @@
where_filter_specs=tuple(current_pushdown_state.where_filter_specs) + tuple(node.input_where_specs),
)

return self._default_handler(node=node, pushdown_state=updated_pushdown_state)
with self._predicate_pushdown_tracker.track_pushdown_state(updated_pushdown_state):
optimized_parent: OptimizeBranchResult = node.parent_node.accept(self)
# TODO: Update to only apply filters that have not been successfully pushed down
optimized_node = OptimizeBranchResult(
optimized_branch=node.with_new_parents((optimized_parent.optimized_branch,))
)

pushdown_state_updated_by_parent = self._predicate_pushdown_tracker.last_pushdown_state
# Override the pushdown state for this node and allow all upstream propagation to be handled by the tracker
if len(pushdown_state_updated_by_parent.applied_where_filter_specs) > 0:
updated_specs = frozenset.union(
frozenset(node.input_where_specs),
pushdown_state_updated_by_parent.applied_where_filter_specs,
)
self._predicate_pushdown_tracker.override_last_pushdown_state(
PredicatePushdownState.with_pushdown_applied_where_filter_specs(
original_pushdown_state=pushdown_state_updated_by_parent,
pushdown_applied_where_filter_specs=updated_specs,
)
)
logger.log(
level=self._log_level,
msg=(
f"Added applied specs to pushdown state. Added specs:\n\n{node.input_where_specs}\n\n"
+ f"Updated pushdown state:\n\n{self._predicate_pushdown_tracker.last_pushdown_state}"
),
)

return optimized_node
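
With the applied specs now propagated back to this node, the selective application flagged in the TODO above might look like the following sketch, again with strings standing in for WhereFilterSpec instances (hypothetical follow-up, not part of this commit):

```python
# Back-propagated from the optimized branch via the tracker.
pushdown_applied = frozenset({"booking__is_instant"})
# The specs this WhereConstraintNode was originally asked to apply.
input_where_specs = ["booking__is_instant", "listing__country = 'us'"]

# Re-apply only the filters the branch did not already handle.
specs_still_needed = [spec for spec in input_where_specs if spec not in pushdown_applied]
assert specs_still_needed == ["listing__country = 'us'"]
```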

# Join nodes - these may affect pushdown state based on join type

@@ -422,6 +482,7 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> OptimizeBranchResult:
time_range_constraint=None,
where_filter_specs=tuple(),
pushdown_enabled_types=current_pushdown_state.pushdown_enabled_types,
applied_where_filter_specs=current_pushdown_state.applied_where_filter_specs,
)
else:
updated_pushdown_state = PredicatePushdownState.without_time_range_constraint(current_pushdown_state)
31 changes: 27 additions & 4 deletions metricflow/plan_conversion/node_processor.py
@@ -84,20 +84,24 @@ class PredicateInputType(Enum):
class PredicatePushdownState:
"""Container class for maintaining state information relevant for predicate pushdown.
This broadly tracks two related items:
This broadly tracks three related items:
1. Filter predicates collected during the process of constructing a dataflow plan
2. Predicate types eligible for pushdown
3. Filters which have been applied already
The former may be updated as things like time constraints get altered or metric and measure filters are
The first may be updated as things like time constraints get altered or metric and measure filters are
added to the query filters.
The latter may be updated based on query configuration, like if a cumulative metric is added to the plan
The second may be updated based on query configuration, like if a cumulative metric is added to the plan
there may be changes to what sort of predicate pushdown operations are supported.
The last will be updated as filters are applied via pushdown or by the original WhereConstraintNode.
The time_range_constraint property holds the time window for setting up a time range filter expression.
Finally, the time_range_constraint property holds the time window for setting up a time range filter expression.
"""

time_range_constraint: Optional[TimeRangeConstraint]
# TODO: Deduplicate where_filter_specs
where_filter_specs: Sequence[WhereFilterSpec]
applied_where_filter_specs: FrozenSet[WhereFilterSpec] = frozenset()
pushdown_enabled_types: FrozenSet[PredicateInputType] = frozenset(
[PredicateInputType.TIME_RANGE_CONSTRAINT, PredicateInputType.CATEGORICAL_DIMENSION]
)
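
A construction sketch based on the field defaults above, assuming no validation beyond what the field declarations show:

```python
state = PredicatePushdownState(time_range_constraint=None, where_filter_specs=())

# The new applied_where_filter_specs field defaults to an empty frozenset,
# and both default-enabled predicate types are present.
assert state.applied_where_filter_specs == frozenset()
assert PredicateInputType.TIME_RANGE_CONSTRAINT in state.pushdown_enabled_types
assert PredicateInputType.CATEGORICAL_DIMENSION in state.pushdown_enabled_types
```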
@@ -215,6 +219,7 @@ def with_time_range_constraint(
time_range_constraint=time_range_constraint,
pushdown_enabled_types=pushdown_enabled_types,
where_filter_specs=original_pushdown_state.where_filter_specs,
applied_where_filter_specs=original_pushdown_state.applied_where_filter_specs,
)

@staticmethod
@@ -236,6 +241,7 @@ def without_time_range_constraint(
time_range_constraint=None,
pushdown_enabled_types=pushdown_enabled_types,
where_filter_specs=original_pushdown_state.where_filter_specs,
applied_where_filter_specs=original_pushdown_state.applied_where_filter_specs,
)

@staticmethod
@@ -264,6 +270,23 @@ def with_where_filter_specs(
time_range_constraint=original_pushdown_state.time_range_constraint,
where_filter_specs=where_filter_specs,
pushdown_enabled_types=original_pushdown_state.pushdown_enabled_types,
applied_where_filter_specs=original_pushdown_state.applied_where_filter_specs,
)

@staticmethod
def with_pushdown_applied_where_filter_specs(
original_pushdown_state: PredicatePushdownState, pushdown_applied_where_filter_specs: FrozenSet[WhereFilterSpec]
) -> PredicatePushdownState:
"""Factory method for replacing pushdown applied where filter specs in pushdown operations.
This is useful for managing propagation - both forwards and backwards - of where filter specs that have been
applied via a pushdown operation.
"""
return PredicatePushdownState(
time_range_constraint=original_pushdown_state.time_range_constraint,
pushdown_enabled_types=original_pushdown_state.pushdown_enabled_types,
where_filter_specs=original_pushdown_state.where_filter_specs,
applied_where_filter_specs=pushdown_applied_where_filter_specs,
)

@staticmethod
45 changes: 0 additions & 45 deletions tests_metricflow/dataflow/builder/test_predicate_pushdown.py

This file was deleted.

