From cac84519e7171b85f9dc477fdeb20d6e04aecf7f Mon Sep 17 00:00:00 2001 From: tlento Date: Tue, 25 Jun 2024 17:11:08 -0700 Subject: [PATCH] Improve readability of PredicatePushdownStateTracker Added some assertions to make it more obvious what the current branch state expectations are, and greatly expanded the documentation of current behavior. There is an update to the order of operations in the join handling nodes as well, which aligns the processing with the expanded documentation in the state tracking object. --- .../optimizer/predicate_pushdown_optimizer.py | 54 +++++++++++++++++-- 1 file changed, 50 insertions(+), 4 deletions(-) diff --git a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py index 7c61cf1ba1..a4b4e0b0cb 100644 --- a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py +++ b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py @@ -60,6 +60,39 @@ def track_pushdown_state(self, pushdown_state: PredicatePushdownState) -> Iterat node can evaluate which where filter specs needs to be applied. We capture the complete set because we may have sequenced nodes where some where filters are applied (e.g., time dimension filters might be applied to metric time nodes, etc.). + + The state tracking and propagation is built to work as follows: + + For a simple DAG of where_node -> join_node -> source_nodes there will be two branches: + + where_node -> join_node -> left_source_node + where_node -> join_node -> right_source_node + + In this case the where_node receives the initial predicate pushdown state, and then adds its own + updated state object (with the where filters) via the context manager and propagates that to the join_node. + + The join_node then receives the where_node's predicate pushdown state, and, for each branch, adds an + updated state object via the context manager and propagates the updated state objects to the next node. + + The left_source_node gets the join node's left branch state and evaluates it. If it can apply any filters + it adds an updated state object to note that the filters are applied and propagates it along via the context + manager. In this case, the context manager exits immediately and returns to the left_source_node, which + finishes with applying the where constraints and returns back to the join_node. + + At this point, the join node has a left branch context manager with the left_join_branch pushdown state. The + join_node does NOT have access to the left_source_node's pushdown state, but it needs to be able to notify its + parent, the where_node, that some filters have been applied at the left_source_node. + + How does it do this? The left_source_node's state update included applied where filters. When the context + manager exits it propagates the left_source_node's applied where filters back up to the preceding state (in + this case, the join node's left branch state). The same thing happens when the branch states for the join + node exit the context manager, so the where_node then sees the union of all filters applied downstream. + + The where_node, then, has access to the complete set of filters applied downstream. + + This is complicated because of joins - we can't store a single set of applied filters, because there's no good + way to keep them organized in the case of multiple join branches. Instead, we track up and down a single + branch, and merge the events of parent branches at the join nodes that created them. """ self._current_branch_state.append(pushdown_state) yield @@ -84,6 +117,10 @@ def last_pushdown_state(self) -> PredicatePushdownState: This is the input state a given node processing method should be using for pushdown operations. """ + assert len(self._current_branch_state) > 0, ( + "There should always be at least one element in current branch state! " + "This suggests an inappropriate removal or improper initialization." + ) return self._current_branch_state[-1] def override_last_pushdown_state(self, pushdown_state: PredicatePushdownState) -> None: @@ -97,6 +134,10 @@ def override_last_pushdown_state(self, pushdown_state: PredicatePushdownState) - Since this is not something we want people doing by accident we designate it as a special method rather than making it a property setter. """ + assert len(self._current_branch_state) > 0, ( + "There should always be at least one element in current branch state! " + "This suggests an inappropriate removal or improper initialization." + ) self._current_branch_state[-1] = pushdown_state @@ -418,6 +459,9 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranc of the join. As such, time constraints are not propagated to the right side of the join. This restriction may be relaxed at a later time, but for now it is largely irrelevant since we do not allow fanout joins and do not yet have support for pre-filters based on time ranges for things like SCD joins. + + Note we initialize branch state tracking objects prior to traversal to avoid back-propagation from + one branch affecting the predicate pushdown behavior along other branches. """ self._log_visit_node_type(node) left_parent = node.left_node @@ -428,16 +472,18 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranc else: left_branch_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state - optimized_parents: List[OptimizeBranchResult] = [] - with self._predicate_pushdown_tracker.track_pushdown_state(left_branch_pushdown_state): - optimized_parents.append(left_parent.accept(self)) - base_right_branch_pushdown_state = PredicatePushdownState.without_time_range_constraint( self._predicate_pushdown_tracker.last_pushdown_state ) outer_join_right_branch_pushdown_state = PredicatePushdownState.without_where_filter_specs( original_pushdown_state=base_right_branch_pushdown_state ) + + optimized_parents: List[OptimizeBranchResult] = [] + + with self._predicate_pushdown_tracker.track_pushdown_state(left_branch_pushdown_state): + optimized_parents.append(left_parent.accept(self)) + for join_description in node.join_targets: if ( join_description.join_type is SqlJoinType.LEFT_OUTER