Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Track and propagate applied where filter specs to outer plan nodes #1302

Merged
merged 3 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 132 additions & 21 deletions metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ class PredicatePushdownBranchStateTracker:
"""Tracking class for monitoring pushdown state at the node level during a visitor walk."""

def __init__(self, initial_state: PredicatePushdownState) -> None: # noqa: D107
self._initial_state = initial_state
self._current_branch_state: List[PredicatePushdownState] = []
self._current_branch_state: List[PredicatePushdownState] = [initial_state]

@contextmanager
def track_pushdown_state(self, pushdown_state: PredicatePushdownState) -> Iterator[None]:
Expand All @@ -56,20 +55,90 @@ def track_pushdown_state(self, pushdown_state: PredicatePushdownState) -> Iterat
This retains a sequence of pushdown state objects to allow for tracking pushdown opportunities along
the current branch. Each entry represents the predicate pushdown state as of that point, and as such
callers need only concern themselves with the value returned by the last_pushdown_state property.

The back-propagation of pushdown_applied_where_filter_specs is necessary to ensure the outer query
node can evaluate which where filter specs needs to be applied. We capture the complete set because
we may have sequenced nodes where some where filters are applied (e.g., time dimension filters might
be applied to metric time nodes, etc.).

The state tracking and propagation is built to work as follows:

For a simple DAG of where_node -> join_node -> source_nodes there will be two branches:

where_node -> join_node -> left_source_node
where_node -> join_node -> right_source_node

In this case the where_node receives the initial predicate pushdown state, and then adds its own
updated state object (with the where filters) via the context manager and propagates that to the join_node.

The join_node then receives the where_node's predicate pushdown state, and, for each branch, adds an
updated state object via the context manager and propagates the updated state objects to the next node.

The left_source_node gets the join node's left branch state and evaluates it. If it can apply any filters
it adds an updated state object to note that the filters are applied and propagates it along via the context
manager. In this case, the context manager exits immediately and returns to the left_source_node, which
finishes with applying the where constraints and returns back to the join_node.

At this point, the join node has a left branch context manager with the left_join_branch pushdown state. The
join_node does NOT have access to the left_source_node's pushdown state, but it needs to be able to notify its
parent, the where_node, that some filters have been applied at the left_source_node.

How does it do this? The left_source_node's state update included applied where filters. When the context
manager exits it propagates the left_source_node's applied where filters back up to the preceding state (in
this case, the join node's left branch state). The same thing happens when the branch states for the join
node exit the context manager, so the where_node then sees the union of all filters applied downstream.

The where_node, then, has access to the complete set of filters applied downstream.

This is complicated because of joins - we can't store a single set of applied filters, because there's no good
way to keep them organized in the case of multiple join branches. Instead, we track up and down a single
branch, and merge the events of parent branches at the join nodes that created them.
"""
self._current_branch_state.append(pushdown_state)
yield
self._current_branch_state.pop(-1)
last_visited_pushdown_state = self._current_branch_state.pop(-1)
if len(last_visited_pushdown_state.applied_where_filter_specs) > 0:
pushdown_applied_where_filter_specs = frozenset.union(
*[
last_visited_pushdown_state.applied_where_filter_specs,
self.last_pushdown_state.applied_where_filter_specs,
]
)
self.override_last_pushdown_state(
PredicatePushdownState.with_pushdown_applied_where_filter_specs(
original_pushdown_state=self.last_pushdown_state,
pushdown_applied_where_filter_specs=pushdown_applied_where_filter_specs,
)
)

@property
def last_pushdown_state(self) -> PredicatePushdownState:
"""Returns the last seen PredicatePushdownState.

This is nearly always the input state a given node processing method should be using for pushdown operations.
This is the input state a given node processing method should be using for pushdown operations.
"""
assert len(self._current_branch_state) > 0, (
"There should always be at least one element in current branch state! "
"This suggests an inappropriate removal or improper initialization."
)
return self._current_branch_state[-1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any concern about KeyError handling here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, because there's always a value set since we initialize it non-empty, but I should put in an assertion guard so we get a useful error message if anybody changes that.


def override_last_pushdown_state(self, pushdown_state: PredicatePushdownState) -> None:
"""Method for forcibly updating the last seen predicate pushdown state to a new value.

This is necessary only for cases where we wish to back-propagate some updated state attribute
for handling in the exit condition of the preceding node in the DAG. The scenario where we use
this here is for indicating that a where filter has been applied elsewhere on the branch, and so
outer nodes can skip application as appropriate.

Since this is not something we want people doing by accident we designate it as a special method
rather than making it a property setter.
"""
if len(self._current_branch_state) > 0:
return self._current_branch_state[-1]
return self._initial_state
assert len(self._current_branch_state) > 0, (
"There should always be at least one element in current branch state! "
"This suggests an inappropriate removal or improper initialization."
)
self._current_branch_state[-1] = pushdown_state


class PredicatePushdownOptimizer(
Expand All @@ -81,8 +150,10 @@ class PredicatePushdownOptimizer(
This evaluates filter predicates to determine which, if any, can be directly to an input source node.
It operates by walking each branch in the DataflowPlan and collecting pushdown state information, then
evaluating that state at the input source node and applying the filter node (e.g., a WhereConstraintNode)
directly to the source. As the optimizer unrolls back through the branch it will remove the duplicated
constraint node if it is appropriate to do so.
directly to the source. As the optimizer unrolls back through the branch it will remove duplicated constraints.

This assumes that we never do predicate pushdown on a filter that needs to be re-applied, so every filter
we encounter gets applied exactly once per nested subquery branch encapsulated by a given constraint node.
"""

def __init__(self, node_data_set_resolver: DataflowPlanNodeOutputDataSetResolver) -> None:
Expand Down Expand Up @@ -128,8 +199,8 @@ def _default_handler(
) -> OptimizeBranchResult:
"""Encapsulates state updates and handling for most node types.

The dominant majority of nodes simply propagate the current predicate pushdown state along and return
whatever output the parent nodes produce. Of the nodes that do not do this, the most common deviation
The most common node-level operation is to simply propagate the current predicate pushdown state along and
return whatever output the parent nodes produce. Of the nodes that do not do this, the most common deviation
is a pushdown state update.

As such, this method defaults to propagating the last seen state, but allows an override for cases where
Expand Down Expand Up @@ -229,13 +300,16 @@ def _push_down_where_filters(
filters_left_over.append(filter_spec)

logger.log(level=self._log_level, msg=f"Filter specs to add:\n{filters_to_apply}")
applied_filters = frozenset.union(
*[frozenset(current_pushdown_state.applied_where_filter_specs), frozenset(filters_to_apply)]
)
updated_pushdown_state = PredicatePushdownState(
time_range_constraint=current_pushdown_state.time_range_constraint,
where_filter_specs=tuple(filters_left_over),
pushdown_enabled_types=current_pushdown_state.pushdown_enabled_types,
applied_where_filter_specs=applied_filters,
)
optimized_node = self._default_handler(node=node, pushdown_state=updated_pushdown_state)
# TODO: propagate filters applied back up the branch for removal
if len(filters_to_apply) > 0:
return OptimizeBranchResult(
optimized_branch=WhereConstraintNode(
Expand Down Expand Up @@ -267,11 +341,14 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> OptimizeBran
not be eligible for pushdown. This processor simply propagates them forward so long as where filter
predicate pushdown is still enabled for this branch.

The fact that they have been added at this point does not mean they will be pushed down, as intervening
join nodes might remove them from consideration, so we retain them ensure all filters are applied as specified
within this method.
When the visitor returns to this node from its parents, it updates the pushdown state for this node in the
tracker. It does this within the scope of the context manager in order to keep the pushdown state updates
consistent - modifying only the state entry associated with this node, and allowing the tracker to do all
of the upstream state propagation.

TODO: Update to only apply filters that have not been pushed down
The fact that the filters have been added at this point does not mean they will be pushed down, as intervening
join nodes might remove them from consideration, so we must ensure all filters are applied as specified
within this method.
"""
self._log_visit_node_type(node)
current_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state
Expand All @@ -283,7 +360,35 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> OptimizeBran
where_filter_specs=tuple(current_pushdown_state.where_filter_specs) + tuple(node.input_where_specs),
)

return self._default_handler(node=node, pushdown_state=updated_pushdown_state)
with self._predicate_pushdown_tracker.track_pushdown_state(updated_pushdown_state):
optimized_parent: OptimizeBranchResult = node.parent_node.accept(self)
# TODO: Update to only apply filters that have not been successfully pushed down
optimized_node = OptimizeBranchResult(
optimized_branch=node.with_new_parents((optimized_parent.optimized_branch,))
)

pushdown_state_updated_by_parent = self._predicate_pushdown_tracker.last_pushdown_state
# Override the pushdown state for this node and allow all upstream propagation to be handled by the tracker
if len(pushdown_state_updated_by_parent.applied_where_filter_specs) > 0:
updated_specs = frozenset.union(
frozenset(node.input_where_specs),
pushdown_state_updated_by_parent.applied_where_filter_specs,
)
self._predicate_pushdown_tracker.override_last_pushdown_state(
PredicatePushdownState.with_pushdown_applied_where_filter_specs(
original_pushdown_state=pushdown_state_updated_by_parent,
pushdown_applied_where_filter_specs=updated_specs,
)
)
logger.log(
level=self._log_level,
msg=(
f"Added applied specs to pushdown state. Added specs:\n\n{node.input_where_specs}\n\n"
+ f"Updated pushdown state:\n\n{self._predicate_pushdown_tracker.last_pushdown_state}"
),
)

return optimized_node

# Join nodes - these may affect pushdown state based on join type

Expand Down Expand Up @@ -354,6 +459,9 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranc
of the join. As such, time constraints are not propagated to the right side of the join. This restriction
may be relaxed at a later time, but for now it is largely irrelevant since we do not allow fanout joins and
do not yet have support for pre-filters based on time ranges for things like SCD joins.

Note we initialize branch state tracking objects prior to traversal to avoid back-propagation from
one branch affecting the predicate pushdown behavior along other branches.
"""
self._log_visit_node_type(node)
left_parent = node.left_node
Expand All @@ -364,16 +472,18 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranc
else:
left_branch_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state

optimized_parents: List[OptimizeBranchResult] = []
with self._predicate_pushdown_tracker.track_pushdown_state(left_branch_pushdown_state):
optimized_parents.append(left_parent.accept(self))

base_right_branch_pushdown_state = PredicatePushdownState.without_time_range_constraint(
self._predicate_pushdown_tracker.last_pushdown_state
)
outer_join_right_branch_pushdown_state = PredicatePushdownState.without_where_filter_specs(
original_pushdown_state=base_right_branch_pushdown_state
)

optimized_parents: List[OptimizeBranchResult] = []

with self._predicate_pushdown_tracker.track_pushdown_state(left_branch_pushdown_state):
optimized_parents.append(left_parent.accept(self))

for join_description in node.join_targets:
if (
join_description.join_type is SqlJoinType.LEFT_OUTER
Expand Down Expand Up @@ -422,6 +532,7 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> OptimizeBr
time_range_constraint=None,
where_filter_specs=tuple(),
pushdown_enabled_types=current_pushdown_state.pushdown_enabled_types,
applied_where_filter_specs=current_pushdown_state.applied_where_filter_specs,
)
else:
updated_pushdown_state = PredicatePushdownState.without_time_range_constraint(current_pushdown_state)
Expand Down
31 changes: 27 additions & 4 deletions metricflow/plan_conversion/node_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,20 +84,24 @@ class PredicateInputType(Enum):
class PredicatePushdownState:
"""Container class for maintaining state information relevant for predicate pushdown.

This broadly tracks two related items:
This broadly tracks three related items:
1. Filter predicates collected during the process of constructing a dataflow plan
2. Predicate types eligible for pushdown
3. Filters which have been applied already

The former may be updated as things like time constraints get altered or metric and measure filters are
The first may be updated as things like time constraints get altered or metric and measure filters are
added to the query filters.
The latter may be updated based on query configuration, like if a cumulative metric is added to the plan
The second may be updated based on query configuration, like if a cumulative metric is added to the plan
there may be changes to what sort of predicate pushdown operations are supported.
The last will be updated as filters are applied via pushdown or by the original WhereConstraintNode.

The time_range_constraint property holds the time window for setting up a time range filter expression.
Finally, the time_range_constraint property holds the time window for setting up a time range filter expression.
"""

time_range_constraint: Optional[TimeRangeConstraint]
# TODO: Deduplicate where_filter_specs
where_filter_specs: Sequence[WhereFilterSpec]
applied_where_filter_specs: FrozenSet[WhereFilterSpec] = frozenset()
pushdown_enabled_types: FrozenSet[PredicateInputType] = frozenset(
[PredicateInputType.TIME_RANGE_CONSTRAINT, PredicateInputType.CATEGORICAL_DIMENSION]
)
Expand Down Expand Up @@ -215,6 +219,7 @@ def with_time_range_constraint(
time_range_constraint=time_range_constraint,
pushdown_enabled_types=pushdown_enabled_types,
where_filter_specs=original_pushdown_state.where_filter_specs,
applied_where_filter_specs=original_pushdown_state.applied_where_filter_specs,
)

@staticmethod
Expand All @@ -236,6 +241,7 @@ def without_time_range_constraint(
time_range_constraint=None,
pushdown_enabled_types=pushdown_enabled_types,
where_filter_specs=original_pushdown_state.where_filter_specs,
applied_where_filter_specs=original_pushdown_state.applied_where_filter_specs,
)

@staticmethod
Expand Down Expand Up @@ -264,6 +270,23 @@ def with_where_filter_specs(
time_range_constraint=original_pushdown_state.time_range_constraint,
where_filter_specs=where_filter_specs,
pushdown_enabled_types=original_pushdown_state.pushdown_enabled_types,
applied_where_filter_specs=original_pushdown_state.applied_where_filter_specs,
)

@staticmethod
def with_pushdown_applied_where_filter_specs(
original_pushdown_state: PredicatePushdownState, pushdown_applied_where_filter_specs: FrozenSet[WhereFilterSpec]
) -> PredicatePushdownState:
"""Factory method for replacing pushdown applied where filter specs in pushdown operations.

This is useful for managing propagation - both forwards and backwards - of where filter specs that have been
applied via a pushdown operation.
"""
return PredicatePushdownState(
time_range_constraint=original_pushdown_state.time_range_constraint,
pushdown_enabled_types=original_pushdown_state.pushdown_enabled_types,
where_filter_specs=original_pushdown_state.where_filter_specs,
applied_where_filter_specs=pushdown_applied_where_filter_specs,
)

@staticmethod
Expand Down
Loading
Loading