Skip to content

Commit

Permalink
Consolidate where constraint predicate pushdown management
Browse files Browse the repository at this point in the history
Pushdown operations for where constraints were a bit scattered
around due to the transition from build-time to optimize-time
handling. This consolidates some of the mechanics.

In particular, the where constraint pushdown state updates have
been centralized into the PredicatePushdownState object, which
allows for more streamlined updates at the callsites, and the
where spec propagation has also been simplified to allow for
all filter specs to be evaluated at once instead of splitting
the evaluation between the WhereConstraintNode and input source
node handlers.
  • Loading branch information
tlento committed Jun 26, 2024
1 parent dc8bfad commit 46517d0
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 46 deletions.
67 changes: 29 additions & 38 deletions metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,17 @@ def _push_down_where_filters(

for filter_spec in current_pushdown_state.where_filter_specs:
filter_spec_semantic_models = self._models_for_spec(filter_spec)
invalid_element_types = [
element
for element in filter_spec.linkable_elements
if element.element_type not in current_pushdown_state.pushdown_eligible_element_types
]
if len(filter_spec_semantic_models) != 1 or len(invalid_element_types) > 0:
continue

all_linkable_specs_match = all(spec in source_node_linkable_specs for spec in filter_spec.linkable_specs)
semantic_models_match = (
len(filter_spec_semantic_models) == 1 and filter_spec_semantic_models[0] == source_semantic_model
)
# TODO: Handle the case where entities can be defined in multiple models, only one of which need match
semantic_models_match = filter_spec_semantic_models[0] == source_semantic_model
if all_linkable_specs_match and semantic_models_match:
filters_to_apply.append(filter_spec)
else:
Expand Down Expand Up @@ -277,32 +284,24 @@ def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> Optim
def visit_where_constraint_node(self, node: WhereConstraintNode) -> OptimizeBranchResult:
"""Adds where filters from the input node to the current pushdown state.
The WhereConstraintNode carries the filter information in the form of WhereFilterSpecs. For any
filter specs that may be eligible for predicate pushdown this node will add them to the pushdown state.
The WhereConstraintNode carries the filter information in the form of WhereFilterSpecs, which may or may
not be eligible for pushdown. This processor simply propagates them forward so long as where filter
predicate pushdown is still enabled for this branch.
The fact that they have been added at this point does not mean they will be pushed down, as intervening
join nodes might remove them from consideration, so we retain them here as well in order to ensure all
filters are applied as specified.
join nodes might remove them from consideration, so we retain them to ensure all filters are applied as specified
within this method.
TODO: Update to only apply filters that have not been pushed down
"""
self._log_visit_node_type(node)
current_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state
if not current_pushdown_state.where_filter_pushdown_enabled:
return self._default_handler(node)

where_specs = node.input_where_specs
pushdown_eligible_specs: List[WhereFilterSpec] = []
for spec in where_specs:
semantic_models = self._models_for_spec(spec)
invalid_element_types = [
element
for element in spec.linkable_elements
if element.element_type not in current_pushdown_state.pushdown_eligible_element_types
]
if len(semantic_models) != 1 or len(invalid_element_types) > 0:
continue
pushdown_eligible_specs.append(spec)

updated_pushdown_state = PredicatePushdownState.with_additional_where_filter_specs(
original_pushdown_state=current_pushdown_state, additional_where_filter_specs=tuple(pushdown_eligible_specs)
updated_pushdown_state = PredicatePushdownState.with_where_filter_specs(
original_pushdown_state=current_pushdown_state,
where_filter_specs=tuple(current_pushdown_state.where_filter_specs) + tuple(node.input_where_specs),
)

return self._default_handler(node=node, pushdown_state=updated_pushdown_state)
Expand All @@ -323,10 +322,8 @@ def visit_combine_aggregated_outputs_node(self, node: CombineAggregatedOutputsNo
"""
self._log_visit_node_type(node)
# TODO: move this "remove where filters" logic into PredicatePushdownState
updated_pushdown_state = PredicatePushdownState(
time_range_constraint=self._predicate_pushdown_tracker.last_pushdown_state.time_range_constraint,
where_filter_specs=tuple(),
pushdown_enabled_types=self._predicate_pushdown_tracker.last_pushdown_state.pushdown_enabled_types,
updated_pushdown_state = PredicatePushdownState.without_where_filter_specs(
original_pushdown_state=self._predicate_pushdown_tracker.last_pushdown_state,
)
return self._default_handler(node=node, pushdown_state=updated_pushdown_state)

Expand All @@ -342,10 +339,8 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> O
"""
self._log_visit_node_type(node)
base_node_pushdown_state = PredicatePushdownState(
time_range_constraint=self._predicate_pushdown_tracker.last_pushdown_state.time_range_constraint,
where_filter_specs=tuple(),
pushdown_enabled_types=self._predicate_pushdown_tracker.last_pushdown_state.pushdown_enabled_types,
base_node_pushdown_state = PredicatePushdownState.without_where_filter_specs(
original_pushdown_state=self._predicate_pushdown_tracker.last_pushdown_state,
)
# The conversion metric branch silently removes all filters, so this is a redundant operation.
# TODO: Enable pushdown for the conversion metric branch when filters are supported
Expand Down Expand Up @@ -384,10 +379,8 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranc
self._log_visit_node_type(node)
left_parent = node.left_node
if any(join_description.join_type is SqlJoinType.FULL_OUTER for join_description in node.join_targets):
left_branch_pushdown_state = PredicatePushdownState(
time_range_constraint=self._predicate_pushdown_tracker.last_pushdown_state.time_range_constraint,
where_filter_specs=tuple(),
pushdown_enabled_types=self._predicate_pushdown_tracker.last_pushdown_state.pushdown_enabled_types,
left_branch_pushdown_state = PredicatePushdownState.without_where_filter_specs(
original_pushdown_state=self._predicate_pushdown_tracker.last_pushdown_state,
)
else:
left_branch_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state
Expand All @@ -399,10 +392,8 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranc
base_right_branch_pushdown_state = PredicatePushdownState.without_time_range_constraint(
self._predicate_pushdown_tracker.last_pushdown_state
)
outer_join_right_branch_pushdown_state = PredicatePushdownState(
time_range_constraint=None,
where_filter_specs=tuple(),
pushdown_enabled_types=base_right_branch_pushdown_state.pushdown_enabled_types,
outer_join_right_branch_pushdown_state = PredicatePushdownState.without_where_filter_specs(
original_pushdown_state=base_right_branch_pushdown_state
)
for join_description in node.join_targets:
if (
Expand Down
35 changes: 27 additions & 8 deletions metricflow/plan_conversion/node_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,14 @@ def with_time_range_constraint(
def without_time_range_constraint(
original_pushdown_state: PredicatePushdownState,
) -> PredicatePushdownState:
"""Factory method for updating pushdown state to bypass time range constraints."""
"""Factory method for updating pushdown state to bypass time range constraints.
This eliminates time range constraint pushdown as an option: the only reason to remove
time range constraint metadata is to turn it off, so removing it here avoids potential issues where
a second ConstrainTimeRange node might update the pushdown state.
TODO: replace or rename this method.
"""
pushdown_enabled_types = original_pushdown_state.pushdown_enabled_types.difference(
{PredicateInputType.TIME_RANGE_CONSTRAINT}
)
Expand All @@ -232,18 +239,30 @@ def without_time_range_constraint(
)

@staticmethod
def with_additional_where_filter_specs(
original_pushdown_state: PredicatePushdownState, additional_where_filter_specs: Sequence[WhereFilterSpec]
def without_where_filter_specs(
original_pushdown_state: PredicatePushdownState,
) -> PredicatePushdownState:
"""Factory method for updating pushdown state to remove existing where filter specs.
This simply blanks out the where filter specs without altering which types of pushdown are available.
"""
return PredicatePushdownState.with_where_filter_specs(
original_pushdown_state=original_pushdown_state,
where_filter_specs=tuple(),
)

@staticmethod
def with_where_filter_specs(
original_pushdown_state: PredicatePushdownState, where_filter_specs: Sequence[WhereFilterSpec]
) -> PredicatePushdownState:
"""Factory method for adding additional WhereFilterSpecs for pushdown operations.
"""Factory method for replacing WhereFilterSpecs in pushdown operations.
This requires that the PushdownState allow for where filters - time range only or disabled states will
raise an exception, and must be checked externally.
This requires that the PushdownState allow for where filters - time range only or disabled states will raise
an exception, so callers must check for those states externally before invoking this method.
"""
updated_where_specs = tuple(original_pushdown_state.where_filter_specs) + tuple(additional_where_filter_specs)
return PredicatePushdownState(
time_range_constraint=original_pushdown_state.time_range_constraint,
where_filter_specs=updated_where_specs,
where_filter_specs=where_filter_specs,
pushdown_enabled_types=original_pushdown_state.pushdown_enabled_types,
)

Expand Down

0 comments on commit 46517d0

Please sign in to comment.