From 46517d01be160888932eab19b4e03aed2272361a Mon Sep 17 00:00:00 2001 From: tlento Date: Sun, 23 Jun 2024 13:16:47 -0700 Subject: [PATCH] Consolidate where constraint predicate pushdown management Pushdown operations for where constraints were a bit scattered around due to the transition from build-time to optimize-time handling. This consolidates some of the mechanics. In particular, the where constraint pushdown state updates have been centralized into the PredicatePushdownState object, which allows for more streamlined updates at the callsites, and the where spec propagation has also been simplified to allow for all filter specs to be evaluated at once instead of splitting the evaluation between the WhereConstraintNode and input source node handlers. --- .../optimizer/predicate_pushdown_optimizer.py | 67 ++++++++----------- metricflow/plan_conversion/node_processor.py | 35 +++++++--- 2 files changed, 56 insertions(+), 46 deletions(-) diff --git a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py index 79b4790dd4..ec202d0c0d 100644 --- a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py +++ b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py @@ -233,10 +233,17 @@ def _push_down_where_filters( for filter_spec in current_pushdown_state.where_filter_specs: filter_spec_semantic_models = self._models_for_spec(filter_spec) + invalid_element_types = [ + element + for element in filter_spec.linkable_elements + if element.element_type not in current_pushdown_state.pushdown_eligible_element_types + ] + if len(filter_spec_semantic_models) != 1 or len(invalid_element_types) > 0: + continue + all_linkable_specs_match = all(spec in source_node_linkable_specs for spec in filter_spec.linkable_specs) - semantic_models_match = ( - len(filter_spec_semantic_models) == 1 and filter_spec_semantic_models[0] == source_semantic_model - ) + # TODO: Handle the case where entities can be defined in multiple models, only one of which need match + semantic_models_match = filter_spec_semantic_models[0] == source_semantic_model if all_linkable_specs_match and semantic_models_match: filters_to_apply.append(filter_spec) else: @@ -277,32 +284,24 @@ def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> Optim def visit_where_constraint_node(self, node: WhereConstraintNode) -> OptimizeBranchResult: """Adds where filters from the input node to the current pushdown state. - The WhereConstraintNode carries the filter information in the form of WhereFilterSpecs. For any - filter specs that may be eligible for predicate pushdown this node will add them to the pushdown state. + The WhereConstraintNode carries the filter information in the form of WhereFilterSpecs, which may or may + not be eligible for pushdown. This processor simply propagates them forward so long as where filter + predicate pushdown is still enabled for this branch. + The fact that they have been added at this point does not mean they will be pushed down, as intervening - join nodes might remove them from consideration, so we retain them here as well in order to ensure all - filters are applied as specified. + join nodes might remove them from consideration, so we retain them ensure all filters are applied as specified + within this method. + + TODO: Update to only apply filters that have not been pushed down """ self._log_visit_node_type(node) current_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state if not current_pushdown_state.where_filter_pushdown_enabled: return self._default_handler(node) - where_specs = node.input_where_specs - pushdown_eligible_specs: List[WhereFilterSpec] = [] - for spec in where_specs: - semantic_models = self._models_for_spec(spec) - invalid_element_types = [ - element - for element in spec.linkable_elements - if element.element_type not in current_pushdown_state.pushdown_eligible_element_types - ] - if len(semantic_models) != 1 or len(invalid_element_types) > 0: - continue - pushdown_eligible_specs.append(spec) - - updated_pushdown_state = PredicatePushdownState.with_additional_where_filter_specs( - original_pushdown_state=current_pushdown_state, additional_where_filter_specs=tuple(pushdown_eligible_specs) + updated_pushdown_state = PredicatePushdownState.with_where_filter_specs( + original_pushdown_state=current_pushdown_state, + where_filter_specs=tuple(current_pushdown_state.where_filter_specs) + tuple(node.input_where_specs), ) return self._default_handler(node=node, pushdown_state=updated_pushdown_state) @@ -323,10 +322,8 @@ def visit_combine_aggregated_outputs_node(self, node: CombineAggregatedOutputsNo """ self._log_visit_node_type(node) # TODO: move this "remove where filters" logic into PredicatePushdownState - updated_pushdown_state = PredicatePushdownState( - time_range_constraint=self._predicate_pushdown_tracker.last_pushdown_state.time_range_constraint, - where_filter_specs=tuple(), - pushdown_enabled_types=self._predicate_pushdown_tracker.last_pushdown_state.pushdown_enabled_types, + updated_pushdown_state = PredicatePushdownState.without_where_filter_specs( + original_pushdown_state=self._predicate_pushdown_tracker.last_pushdown_state, ) return self._default_handler(node=node, pushdown_state=updated_pushdown_state) @@ -342,10 +339,8 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> O """ self._log_visit_node_type(node) - base_node_pushdown_state = PredicatePushdownState( - time_range_constraint=self._predicate_pushdown_tracker.last_pushdown_state.time_range_constraint, - where_filter_specs=tuple(), - pushdown_enabled_types=self._predicate_pushdown_tracker.last_pushdown_state.pushdown_enabled_types, + base_node_pushdown_state = PredicatePushdownState.without_where_filter_specs( + original_pushdown_state=self._predicate_pushdown_tracker.last_pushdown_state, ) # The conversion metric branch silently removes all filters, so this is a redundant operation. # TODO: Enable pushdown for the conversion metric branch when filters are supported @@ -384,10 +379,8 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranc self._log_visit_node_type(node) left_parent = node.left_node if any(join_description.join_type is SqlJoinType.FULL_OUTER for join_description in node.join_targets): - left_branch_pushdown_state = PredicatePushdownState( - time_range_constraint=self._predicate_pushdown_tracker.last_pushdown_state.time_range_constraint, - where_filter_specs=tuple(), - pushdown_enabled_types=self._predicate_pushdown_tracker.last_pushdown_state.pushdown_enabled_types, + left_branch_pushdown_state = PredicatePushdownState.without_where_filter_specs( + original_pushdown_state=self._predicate_pushdown_tracker.last_pushdown_state, ) else: left_branch_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state @@ -399,10 +392,8 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranc base_right_branch_pushdown_state = PredicatePushdownState.without_time_range_constraint( self._predicate_pushdown_tracker.last_pushdown_state ) - outer_join_right_branch_pushdown_state = PredicatePushdownState( - time_range_constraint=None, - where_filter_specs=tuple(), - pushdown_enabled_types=base_right_branch_pushdown_state.pushdown_enabled_types, + outer_join_right_branch_pushdown_state = PredicatePushdownState.without_where_filter_specs( + original_pushdown_state=base_right_branch_pushdown_state ) for join_description in node.join_targets: if ( diff --git a/metricflow/plan_conversion/node_processor.py b/metricflow/plan_conversion/node_processor.py index 28054d49bb..a2422fdd33 100644 --- a/metricflow/plan_conversion/node_processor.py +++ b/metricflow/plan_conversion/node_processor.py @@ -221,7 +221,14 @@ def with_time_range_constraint( def without_time_range_constraint( original_pushdown_state: PredicatePushdownState, ) -> PredicatePushdownState: - """Factory method for updating pushdown state to bypass time range constraints.""" + """Factory method for updating pushdown state to bypass time range constraints. + + This eliminates time range constraint pushdown as an option, since the only reason to remove + time range constraint metadata is to turn it off, so we avoid potential issues where + a second ConstrainTimeRange node might update the pushdown state. + + TODO: replace or rename this method. + """ pushdown_enabled_types = original_pushdown_state.pushdown_enabled_types.difference( {PredicateInputType.TIME_RANGE_CONSTRAINT} ) @@ -232,18 +239,30 @@ def without_time_range_constraint( ) @staticmethod - def with_additional_where_filter_specs( - original_pushdown_state: PredicatePushdownState, additional_where_filter_specs: Sequence[WhereFilterSpec] + def without_where_filter_specs( + original_pushdown_state: PredicatePushdownState, + ) -> PredicatePushdownState: + """Factory method for updating pushdown state to remove existing where filter specs. + + This simply blanks out the where filter specs without altering which types of pushdown are available. + """ + return PredicatePushdownState.with_where_filter_specs( + original_pushdown_state=original_pushdown_state, + where_filter_specs=tuple(), + ) + + @staticmethod + def with_where_filter_specs( + original_pushdown_state: PredicatePushdownState, where_filter_specs: Sequence[WhereFilterSpec] ) -> PredicatePushdownState: - """Factory method for adding additional WhereFilterSpecs for pushdown operations. + """Factory method for replacing WhereFilterSpecs in pushdown operations. - This requires that the PushdownState allow for where filters - time range only or disabled states will - raise an exception, and must be checked externally. + This requires that the PushdownState allow for where filters - time range only or disabled states will raise + an exception, and must be checked externally. """ - updated_where_specs = tuple(original_pushdown_state.where_filter_specs) + tuple(additional_where_filter_specs) return PredicatePushdownState( time_range_constraint=original_pushdown_state.time_range_constraint, - where_filter_specs=updated_where_specs, + where_filter_specs=where_filter_specs, pushdown_enabled_types=original_pushdown_state.pushdown_enabled_types, )