Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consolidate where constraint predicate pushdown management #1300

Merged
merged 1 commit into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 29 additions & 38 deletions metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,17 @@ def _push_down_where_filters(

for filter_spec in current_pushdown_state.where_filter_specs:
filter_spec_semantic_models = self._models_for_spec(filter_spec)
invalid_element_types = [
element
for element in filter_spec.linkable_elements
if element.element_type not in current_pushdown_state.pushdown_eligible_element_types
]
if len(filter_spec_semantic_models) != 1 or len(invalid_element_types) > 0:
continue

all_linkable_specs_match = all(spec in source_node_linkable_specs for spec in filter_spec.linkable_specs)
semantic_models_match = (
len(filter_spec_semantic_models) == 1 and filter_spec_semantic_models[0] == source_semantic_model
)
# TODO: Handle the case where entities can be defined in multiple models, only one of which need match
tlento marked this conversation as resolved.
Show resolved Hide resolved
semantic_models_match = filter_spec_semantic_models[0] == source_semantic_model
if all_linkable_specs_match and semantic_models_match:
filters_to_apply.append(filter_spec)
else:
Expand Down Expand Up @@ -277,32 +284,24 @@ def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> Optim
def visit_where_constraint_node(self, node: WhereConstraintNode) -> OptimizeBranchResult:
"""Adds where filters from the input node to the current pushdown state.

The WhereConstraintNode carries the filter information in the form of WhereFilterSpecs. For any
filter specs that may be eligible for predicate pushdown this node will add them to the pushdown state.
The WhereConstraintNode carries the filter information in the form of WhereFilterSpecs, which may or may
not be eligible for pushdown. This processor simply propagates them forward so long as where filter
predicate pushdown is still enabled for this branch.

The fact that they have been added at this point does not mean they will be pushed down, as intervening
join nodes might remove them from consideration, so we retain them here as well in order to ensure all
filters are applied as specified.
join nodes might remove them from consideration, so we retain them ensure all filters are applied as specified
within this method.

TODO: Update to only apply filters that have not been pushed down
tlento marked this conversation as resolved.
Show resolved Hide resolved
"""
self._log_visit_node_type(node)
current_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state
if not current_pushdown_state.where_filter_pushdown_enabled:
return self._default_handler(node)

where_specs = node.input_where_specs
pushdown_eligible_specs: List[WhereFilterSpec] = []
for spec in where_specs:
semantic_models = self._models_for_spec(spec)
invalid_element_types = [
element
for element in spec.linkable_elements
if element.element_type not in current_pushdown_state.pushdown_eligible_element_types
]
if len(semantic_models) != 1 or len(invalid_element_types) > 0:
continue
pushdown_eligible_specs.append(spec)

updated_pushdown_state = PredicatePushdownState.with_additional_where_filter_specs(
original_pushdown_state=current_pushdown_state, additional_where_filter_specs=tuple(pushdown_eligible_specs)
updated_pushdown_state = PredicatePushdownState.with_where_filter_specs(
original_pushdown_state=current_pushdown_state,
where_filter_specs=tuple(current_pushdown_state.where_filter_specs) + tuple(node.input_where_specs),
)

return self._default_handler(node=node, pushdown_state=updated_pushdown_state)
Expand All @@ -323,10 +322,8 @@ def visit_combine_aggregated_outputs_node(self, node: CombineAggregatedOutputsNo
"""
self._log_visit_node_type(node)
# TODO: move this "remove where filters" logic into PredicatePushdownState
updated_pushdown_state = PredicatePushdownState(
time_range_constraint=self._predicate_pushdown_tracker.last_pushdown_state.time_range_constraint,
where_filter_specs=tuple(),
pushdown_enabled_types=self._predicate_pushdown_tracker.last_pushdown_state.pushdown_enabled_types,
updated_pushdown_state = PredicatePushdownState.without_where_filter_specs(
original_pushdown_state=self._predicate_pushdown_tracker.last_pushdown_state,
)
return self._default_handler(node=node, pushdown_state=updated_pushdown_state)

Expand All @@ -342,10 +339,8 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> O

"""
self._log_visit_node_type(node)
base_node_pushdown_state = PredicatePushdownState(
time_range_constraint=self._predicate_pushdown_tracker.last_pushdown_state.time_range_constraint,
where_filter_specs=tuple(),
pushdown_enabled_types=self._predicate_pushdown_tracker.last_pushdown_state.pushdown_enabled_types,
base_node_pushdown_state = PredicatePushdownState.without_where_filter_specs(
original_pushdown_state=self._predicate_pushdown_tracker.last_pushdown_state,
)
# The conversion metric branch silently removes all filters, so this is a redundant operation.
# TODO: Enable pushdown for the conversion metric branch when filters are supported
Expand Down Expand Up @@ -384,10 +379,8 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranc
self._log_visit_node_type(node)
left_parent = node.left_node
if any(join_description.join_type is SqlJoinType.FULL_OUTER for join_description in node.join_targets):
left_branch_pushdown_state = PredicatePushdownState(
time_range_constraint=self._predicate_pushdown_tracker.last_pushdown_state.time_range_constraint,
where_filter_specs=tuple(),
pushdown_enabled_types=self._predicate_pushdown_tracker.last_pushdown_state.pushdown_enabled_types,
left_branch_pushdown_state = PredicatePushdownState.without_where_filter_specs(
original_pushdown_state=self._predicate_pushdown_tracker.last_pushdown_state,
)
else:
left_branch_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state
Expand All @@ -399,10 +392,8 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranc
base_right_branch_pushdown_state = PredicatePushdownState.without_time_range_constraint(
self._predicate_pushdown_tracker.last_pushdown_state
)
outer_join_right_branch_pushdown_state = PredicatePushdownState(
time_range_constraint=None,
where_filter_specs=tuple(),
pushdown_enabled_types=base_right_branch_pushdown_state.pushdown_enabled_types,
outer_join_right_branch_pushdown_state = PredicatePushdownState.without_where_filter_specs(
original_pushdown_state=base_right_branch_pushdown_state
)
for join_description in node.join_targets:
if (
Expand Down
35 changes: 27 additions & 8 deletions metricflow/plan_conversion/node_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,14 @@ def with_time_range_constraint(
def without_time_range_constraint(
original_pushdown_state: PredicatePushdownState,
) -> PredicatePushdownState:
"""Factory method for updating pushdown state to bypass time range constraints."""
"""Factory method for updating pushdown state to bypass time range constraints.

This eliminates time range constraint pushdown as an option, since the only reason to remove
time range constraint metadata is to turn it off, so we avoid potential issues where
a second ConstrainTimeRange node might update the pushdown state.

TODO: replace or rename this method.
"""
pushdown_enabled_types = original_pushdown_state.pushdown_enabled_types.difference(
{PredicateInputType.TIME_RANGE_CONSTRAINT}
)
Expand All @@ -232,18 +239,30 @@ def without_time_range_constraint(
)

@staticmethod
def with_additional_where_filter_specs(
original_pushdown_state: PredicatePushdownState, additional_where_filter_specs: Sequence[WhereFilterSpec]
def without_where_filter_specs(
original_pushdown_state: PredicatePushdownState,
) -> PredicatePushdownState:
"""Factory method for updating pushdown state to remove existing where filter specs.

This simply blanks out the where filter specs without altering which types of pushdown are available.
"""
return PredicatePushdownState.with_where_filter_specs(
original_pushdown_state=original_pushdown_state,
where_filter_specs=tuple(),
)

@staticmethod
def with_where_filter_specs(
original_pushdown_state: PredicatePushdownState, where_filter_specs: Sequence[WhereFilterSpec]
) -> PredicatePushdownState:
"""Factory method for adding additional WhereFilterSpecs for pushdown operations.
"""Factory method for replacing WhereFilterSpecs in pushdown operations.

This requires that the PushdownState allow for where filters - time range only or disabled states will
raise an exception, and must be checked externally.
This requires that the PushdownState allow for where filters - time range only or disabled states will raise
an exception, and must be checked externally.
"""
updated_where_specs = tuple(original_pushdown_state.where_filter_specs) + tuple(additional_where_filter_specs)
return PredicatePushdownState(
time_range_constraint=original_pushdown_state.time_range_constraint,
where_filter_specs=updated_where_specs,
where_filter_specs=where_filter_specs,
pushdown_enabled_types=original_pushdown_state.pushdown_enabled_types,
)

Expand Down
Loading