Skip to content

Commit

Permalink
Improve readability of PredicatePushdownStateTracker
Browse files Browse the repository at this point in the history
Added some assertions to make it more obvious what the current
branch state expectations are, and greatly expanded the documentation
of current behavior.

There is an update to the order of operations in the join handling nodes
as well, which aligns the processing with the expanded documentation in
the state tracking object.
  • Loading branch information
tlento committed Jun 26, 2024
1 parent 036cdb2 commit cac8451
Showing 1 changed file with 50 additions and 4 deletions.
54 changes: 50 additions & 4 deletions metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,39 @@ def track_pushdown_state(self, pushdown_state: PredicatePushdownState) -> Iterat
node can evaluate which where filter specs needs to be applied. We capture the complete set because
we may have sequenced nodes where some where filters are applied (e.g., time dimension filters might
be applied to metric time nodes, etc.).
The state tracking and propagation is built to work as follows:
For a simple DAG of where_node -> join_node -> source_nodes there will be two branches:
where_node -> join_node -> left_source_node
where_node -> join_node -> right_source_node
In this case the where_node receives the initial predicate pushdown state, and then adds its own
updated state object (with the where filters) via the context manager and propagates that to the join_node.
The join_node then receives the where_node's predicate pushdown state, and, for each branch, adds an
updated state object via the context manager and propagates the updated state objects to the next node.
The left_source_node gets the join node's left branch state and evaluates it. If it can apply any filters
it adds an updated state object to note that the filters are applied and propagates it along via the context
manager. In this case, the context manager exits immediately and returns to the left_source_node, which
finishes with applying the where constraints and returns back to the join_node.
At this point, the join node has a left branch context manager with the left_join_branch pushdown state. The
join_node does NOT have access to the left_source_node's pushdown state, but it needs to be able to notify its
parent, the where_node, that some filters have been applied at the left_source_node.
How does it do this? The left_source_node's state update included applied where filters. When the context
manager exits it propagates the left_source_node's applied where filters back up to the preceding state (in
this case, the join node's left branch state). The same thing happens when the branch states for the join
node exit the context manager, so the where_node then sees the union of all filters applied downstream.
The where_node, then, has access to the complete set of filters applied downstream.
This is complicated because of joins - we can't store a single set of applied filters, because there's no good
way to keep them organized in the case of multiple join branches. Instead, we track up and down a single
branch, and merge the events of parent branches at the join nodes that created them.
"""
self._current_branch_state.append(pushdown_state)
yield
Expand All @@ -84,6 +117,10 @@ def last_pushdown_state(self) -> PredicatePushdownState:
This is the input state a given node processing method should be using for pushdown operations.
"""
assert len(self._current_branch_state) > 0, (
"There should always be at least one element in current branch state! "
"This suggests an inappropriate removal or improper initialization."
)
return self._current_branch_state[-1]

def override_last_pushdown_state(self, pushdown_state: PredicatePushdownState) -> None:
Expand All @@ -97,6 +134,10 @@ def override_last_pushdown_state(self, pushdown_state: PredicatePushdownState) -
Since this is not something we want people doing by accident we designate it as a special method
rather than making it a property setter.
"""
assert len(self._current_branch_state) > 0, (
"There should always be at least one element in current branch state! "
"This suggests an inappropriate removal or improper initialization."
)
self._current_branch_state[-1] = pushdown_state


Expand Down Expand Up @@ -418,6 +459,9 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranc
of the join. As such, time constraints are not propagated to the right side of the join. This restriction
may be relaxed at a later time, but for now it is largely irrelevant since we do not allow fanout joins and
do not yet have support for pre-filters based on time ranges for things like SCD joins.
Note we initialize branch state tracking objects prior to traversal to avoid back-propagation from
one branch affecting the predicate pushdown behavior along other branches.
"""
self._log_visit_node_type(node)
left_parent = node.left_node
Expand All @@ -428,16 +472,18 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranc
else:
left_branch_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state

optimized_parents: List[OptimizeBranchResult] = []
with self._predicate_pushdown_tracker.track_pushdown_state(left_branch_pushdown_state):
optimized_parents.append(left_parent.accept(self))

base_right_branch_pushdown_state = PredicatePushdownState.without_time_range_constraint(
self._predicate_pushdown_tracker.last_pushdown_state
)
outer_join_right_branch_pushdown_state = PredicatePushdownState.without_where_filter_specs(
original_pushdown_state=base_right_branch_pushdown_state
)

optimized_parents: List[OptimizeBranchResult] = []

with self._predicate_pushdown_tracker.track_pushdown_state(left_branch_pushdown_state):
optimized_parents.append(left_parent.accept(self))

for join_description in node.join_targets:
if (
join_description.join_type is SqlJoinType.LEFT_OUTER
Expand Down

0 comments on commit cac8451

Please sign in to comment.