Skip to content

Commit

Permalink
Enable predicate pushdown for categorical dimension filters (#1227)
Browse files Browse the repository at this point in the history
Enable predicate pushdown for categorical dimension filters

We now have the ability to push down filter predicates within
the DataflowPlan. We start with categorical dimension filters,
as they are the simplest.

This change simply tracks the where filters applied at the measure
node and pushes all of them down to the construction of the source
node for evaluation. At this time a filter is eligible to be applied
to the source node if it only contains references to categorical dimensions
that originate from the same, singular semantic model definition that
feeds into the source node in question.

We do not support time dimensions at this time, as they can cause strange
interactions with things like cumulative metrics, which could result in
inappropriate input filtering that produces non-obviously censored metric
results.

We also do not support entities at this time, as entities may be defined
in multiple semantic models and as such filters must be applied with more
care to ensure we are correctly accounting for the entity link paths to
the relevant source node, if any, when we apply the filter.

Finally, we are not able to safely push predicates down for the "null value"
side of an outer join, which, in practice, restricts us to only doing predicate
pushdown to the measure source nodes.

The snapshot test changes for existing snapshots highlight the new behavior,
while the added test snapshots demonstrate specific circumstances of
interest.
  • Loading branch information
tlento authored Jun 25, 2024
1 parent 3f3f862 commit 51e7622
Show file tree
Hide file tree
Showing 188 changed files with 45,628 additions and 15,605 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240521-202252.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Enable predicate pushdown for categorical dimensions
time: 2024-05-21T20:22:52.841802-07:00
custom:
Author: tlento
Issue: "1011"
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ metric:
name: instant_booking_fraction_of_max_value
description: |
Average instant booking value as a ratio of overall max booking value.
Tests constrained ratio measure.
Tests constrained ratio measure and predicate pushdown with different filters
on the same measure input.
type: ratio
type_params:
numerator:
Expand Down Expand Up @@ -331,7 +332,8 @@ metric:
name: regional_starting_balance_ratios
description: |
First day account balance ratio of western vs eastern region starting balance ratios,
used to test interaction between semi-additive measures and measure constraints
used to test interaction between semi-additive measures and measure constraints, and
behavior of predicate pushdown when there are multiple filters on the same categorical dimension
type: ratio
type_params:
numerator:
Expand Down
26 changes: 23 additions & 3 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ def _build_query_output_node(
)
)

predicate_pushdown_state = PredicatePushdownState(time_range_constraint=query_spec.time_range_constraint)
predicate_pushdown_state = PredicatePushdownState(
time_range_constraint=query_spec.time_range_constraint, where_filter_specs=query_level_filter_specs
)

return self._build_metrics_output_node(
metric_specs=tuple(
Expand Down Expand Up @@ -251,6 +253,7 @@ def _build_aggregated_conversion_node(
disabled_pushdown_state = PredicatePushdownState.with_pushdown_disabled()
time_range_only_pushdown_state = PredicatePushdownState(
time_range_constraint=predicate_pushdown_state.time_range_constraint,
where_filter_specs=tuple(),
pushdown_enabled_types=frozenset([PredicateInputType.TIME_RANGE_CONSTRAINT]),
)

Expand Down Expand Up @@ -511,6 +514,11 @@ def _build_base_metric_output_node(
),
descendent_filter_specs=metric_spec.filter_specs,
)
if predicate_pushdown_state.where_filter_pushdown_enabled:
predicate_pushdown_state = PredicatePushdownState.with_additional_where_filter_specs(
original_pushdown_state=predicate_pushdown_state,
additional_where_filter_specs=metric_input_measure_spec.filter_specs,
)

logger.info(
f"For\n{indent(mf_pformat(metric_spec))}"
Expand Down Expand Up @@ -568,6 +576,9 @@ def _build_derived_metric_output_node(

# If metric is offset, we'll apply where constraint after offset to avoid removing values
# unexpectedly. Time constraint will be applied by INNER JOINing to time spine.
# We may consider encapsulating this in pushdown state later, but as of this moment pushdown
# is about post-join to pre-join for dimension access, and relies on the builder to collect
# predicates from query and metric specs and make them available at measure level.
if not metric_spec.has_time_offset:
filter_specs.extend(metric_spec.filter_specs)

Expand Down Expand Up @@ -751,7 +762,9 @@ def _build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> Da
required_linkable_specs, _ = self.__get_required_and_extraneous_linkable_specs(
queried_linkable_specs=query_spec.linkable_specs, filter_specs=query_level_filter_specs
)
predicate_pushdown_state = PredicatePushdownState(time_range_constraint=query_spec.time_range_constraint)
predicate_pushdown_state = PredicatePushdownState(
time_range_constraint=query_spec.time_range_constraint, where_filter_specs=query_level_filter_specs
)
dataflow_recipe = self._find_dataflow_recipe(
linkable_spec_set=required_linkable_specs, predicate_pushdown_state=predicate_pushdown_state
)
Expand Down Expand Up @@ -954,7 +967,14 @@ def _find_dataflow_recipe(
node_data_set_resolver=self._node_data_set_resolver,
)

if predicate_pushdown_state.has_pushdown_potential:
if predicate_pushdown_state.has_pushdown_potential and default_join_type is not SqlJoinType.FULL_OUTER:
# TODO: encapsulate join type and distinct values state and eventually move this to a DataflowPlanOptimizer
# This works today because all of our subsequent join configuration operations preserve the join type
# as-is, or else switch it to a CROSS JOIN or INNER JOIN type, both of which are safe for predicate
# pushdown. However, there is currently no way to enforce that invariant, so we will need to move
# to a model where we evaluate the join nodes themselves and decide on whether or not to push down
# the predicate. This will be much more straightforward once we finish encapsulating our existing
# time range constraint pushdown controls into this mechanism.
candidate_nodes_for_left_side_of_join = list(
node_processor.apply_matching_filter_predicates(
source_nodes=candidate_nodes_for_left_side_of_join,
Expand Down
147 changes: 131 additions & 16 deletions metricflow/plan_conversion/node_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
import dataclasses
import logging
from enum import Enum
from typing import FrozenSet, List, Optional, Sequence, Set
from typing import Dict, FrozenSet, List, Optional, Sequence, Set

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.references import EntityReference, TimeDimensionReference
from dbt_semantic_interfaces.references import EntityReference, SemanticModelReference, TimeDimensionReference
from metricflow_semantics.filters.time_constraint import TimeRangeConstraint
from metricflow_semantics.mf_logging.pretty_print import mf_pformat
from metricflow_semantics.model.semantics.linkable_element import LinkableElementType
from metricflow_semantics.model.semantics.semantic_model_join_evaluator import MAX_JOIN_HOPS
from metricflow_semantics.model.semantics.semantic_model_lookup import SemanticModelLookup
from metricflow_semantics.specs.spec_classes import LinkableInstanceSpec, LinklessEntitySpec
from metricflow_semantics.specs.spec_classes import LinkableInstanceSpec, LinklessEntitySpec, WhereFilterSpec
from metricflow_semantics.specs.spec_set import group_specs_by_type
from metricflow_semantics.specs.spec_set_transforms import ToElementNameSet
from metricflow_semantics.sql.sql_join_type import SqlJoinType
Expand All @@ -25,6 +26,7 @@
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_to_base import JoinDescription, JoinOnEntitiesNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.validation.dataflow_join_validator import JoinDataflowOutputValidator

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -95,7 +97,10 @@ class PredicatePushdownState:
"""

time_range_constraint: Optional[TimeRangeConstraint]
pushdown_enabled_types: FrozenSet[PredicateInputType] = frozenset([PredicateInputType.TIME_RANGE_CONSTRAINT])
where_filter_specs: Sequence[WhereFilterSpec]
pushdown_enabled_types: FrozenSet[PredicateInputType] = frozenset(
[PredicateInputType.TIME_RANGE_CONSTRAINT, PredicateInputType.CATEGORICAL_DIMENSION]
)

def __post_init__(self) -> None:
"""Validation to ensure pushdown states are configured correctly.
Expand All @@ -107,13 +112,12 @@ def __post_init__(self) -> None:
invalid_types: Set[PredicateInputType] = set()

for input_type in self.pushdown_enabled_types:
if (
if input_type is PredicateInputType.ENTITY or input_type is PredicateInputType.TIME_DIMENSION:
invalid_types.add(input_type)
elif (
input_type is PredicateInputType.CATEGORICAL_DIMENSION
or input_type is PredicateInputType.ENTITY
or input_type is PredicateInputType.TIME_DIMENSION
or input_type is PredicateInputType.TIME_RANGE_CONSTRAINT
):
invalid_types.add(input_type)
elif input_type is PredicateInputType.TIME_RANGE_CONSTRAINT:
continue
else:
assert_values_exhausted(input_type)
Expand All @@ -125,23 +129,24 @@ def __post_init__(self) -> None:
f"for {self.pushdown_enabled_types}, which includes the following invalid types: {invalid_types}."
)

# TODO: Include where filter specs when they are added to this class
time_range_constraint_is_valid = (
self.time_range_constraint is None
or PredicateInputType.TIME_RANGE_CONSTRAINT in self.pushdown_enabled_types
)
assert time_range_constraint_is_valid, (
where_filter_specs_are_valid = len(self.where_filter_specs) == 0 or self.where_filter_pushdown_enabled
assert time_range_constraint_is_valid and where_filter_specs_are_valid, (
"Invalid pushdown state configuration! Disabled pushdown state objects cannot have properties "
"set that may lead to improper access and use in other contexts, as that can lead to unintended "
"filtering operations in cases where these properties are accessed without appropriate checks against "
"pushdown configuration. The following properties should all have None values:\n"
f"time_range_constraint: {self.time_range_constraint}"
"pushdown configuration. The following properties should be None or empty:\n"
f"time_range_constraint: {self.time_range_constraint}\n"
f"where_filter_specs: {self.where_filter_specs}"
)

@property
def has_pushdown_potential(self) -> bool:
    """Returns whether or not pushdown is enabled for a type with predicate candidates in place.

    Pushdown is possible when either a time range constraint or at least one eligible
    where filter spec is available; both sources must be consulted here. (A stale
    duplicate return statement previously made the where-filter check unreachable.)
    """
    return self.has_time_range_constraint_to_push_down or self.has_where_filters_to_push_down

@property
def has_time_range_constraint_to_push_down(self) -> bool:
Expand All @@ -156,6 +161,44 @@ def has_time_range_constraint_to_push_down(self) -> bool:
and self.time_range_constraint is not None
)

@property
def has_where_filters_to_push_down(self) -> bool:
    """Whether this state holds any where filter specs that are eligible for pushdown."""
    if not self.where_filter_pushdown_enabled:
        return False
    return len(self.where_filter_specs) > 0

@property
def where_filter_pushdown_enabled(self) -> bool:
    """Indicates whether or not pushdown is enabled for where filters.

    Where filter pushdown is considered enabled if any of the where-filter-capable
    predicate input types is present in the enabled set.
    """
    where_filter_input_types = (
        PredicateInputType.CATEGORICAL_DIMENSION,
        PredicateInputType.ENTITY,
        PredicateInputType.TIME_DIMENSION,
    )
    return any(input_type in self.pushdown_enabled_types for input_type in where_filter_input_types)

@property
def pushdown_eligible_element_types(self) -> FrozenSet[LinkableElementType]:
    """Set of linkable element types eligible for predicate pushdown.

    This converts from the enabled PredicateInputTypes so callers can check whether the
    linkable elements referenced in a where filter spec are eligible for pushdown.
    """
    eligible: Set[LinkableElementType] = set()
    for enabled_type in self.pushdown_enabled_types:
        if enabled_type is PredicateInputType.CATEGORICAL_DIMENSION:
            eligible.add(LinkableElementType.DIMENSION)
        elif enabled_type is PredicateInputType.TIME_RANGE_CONSTRAINT:
            # Time range constraints are handled via a separate mechanism and do not map
            # to a linkable element type here.
            continue
        elif enabled_type is PredicateInputType.TIME_DIMENSION or enabled_type is PredicateInputType.ENTITY:
            # TODO: Remove as support for time dimensions and entities becomes available
            raise NotImplementedError(
                "Predicate pushdown is not currently supported for where filter predicates with time dimension or "
                f"entity references, but this pushdown state is enabled for {enabled_type}."
            )
        else:
            assert_values_exhausted(enabled_type)

    return frozenset(eligible)

@staticmethod
def with_time_range_constraint(
original_pushdown_state: PredicatePushdownState, time_range_constraint: TimeRangeConstraint
Expand All @@ -169,7 +212,9 @@ def with_time_range_constraint(
{PredicateInputType.TIME_RANGE_CONSTRAINT}
)
return PredicatePushdownState(
time_range_constraint=time_range_constraint, pushdown_enabled_types=pushdown_enabled_types
time_range_constraint=time_range_constraint,
pushdown_enabled_types=pushdown_enabled_types,
where_filter_specs=original_pushdown_state.where_filter_specs,
)

@staticmethod
Expand All @@ -180,7 +225,27 @@ def without_time_range_constraint(
pushdown_enabled_types = original_pushdown_state.pushdown_enabled_types.difference(
{PredicateInputType.TIME_RANGE_CONSTRAINT}
)
return PredicatePushdownState(time_range_constraint=None, pushdown_enabled_types=pushdown_enabled_types)
return PredicatePushdownState(
time_range_constraint=None,
pushdown_enabled_types=pushdown_enabled_types,
where_filter_specs=original_pushdown_state.where_filter_specs,
)

@staticmethod
def with_additional_where_filter_specs(
    original_pushdown_state: PredicatePushdownState, additional_where_filter_specs: Sequence[WhereFilterSpec]
) -> PredicatePushdownState:
    """Factory method for adding additional WhereFilterSpecs for pushdown operations.

    This requires that the PushdownState allow for where filters - time range only or disabled states will
    raise an exception, and must be checked externally.
    """
    combined_specs: List[WhereFilterSpec] = list(original_pushdown_state.where_filter_specs)
    combined_specs.extend(additional_where_filter_specs)
    return PredicatePushdownState(
        time_range_constraint=original_pushdown_state.time_range_constraint,
        pushdown_enabled_types=original_pushdown_state.pushdown_enabled_types,
        where_filter_specs=tuple(combined_specs),
    )

@staticmethod
def with_pushdown_disabled() -> PredicatePushdownState:
Expand All @@ -194,6 +259,7 @@ def with_pushdown_disabled() -> PredicatePushdownState:
return PredicatePushdownState(
time_range_constraint=None,
pushdown_enabled_types=frozenset(),
where_filter_specs=tuple(),
)


Expand Down Expand Up @@ -240,6 +306,13 @@ def apply_matching_filter_predicates(
time_range_constraint=predicate_pushdown_state.time_range_constraint,
)

if predicate_pushdown_state.has_where_filters_to_push_down:
source_nodes = self._add_where_constraint(
source_nodes=source_nodes,
where_filter_specs=predicate_pushdown_state.where_filter_specs,
enabled_element_types=predicate_pushdown_state.pushdown_eligible_element_types,
)

return source_nodes

def _add_time_range_constraint(
Expand Down Expand Up @@ -272,6 +345,48 @@ def _add_time_range_constraint(
processed_nodes.append(source_node)
return processed_nodes

def _add_where_constraint(
    self,
    source_nodes: Sequence[DataflowPlanNode],
    where_filter_specs: Sequence[WhereFilterSpec],
    enabled_element_types: FrozenSet[LinkableElementType],
) -> Sequence[DataflowPlanNode]:
    """Processes where filter specs and evaluates their fitness for pushdown against the provided node set."""
    # Group pushdown-eligible filter specs by the semantic model they reference. A spec is
    # eligible only if every linkable element it references has an enabled element type AND
    # all of its elements originate from exactly one semantic model.
    specs_by_model: Dict[SemanticModelReference, Sequence[WhereFilterSpec]] = {}
    for filter_spec in where_filter_specs:
        origin_models = {element.semantic_model_origin for element in filter_spec.linkable_elements}
        has_invalid_element_types = any(
            element.element_type not in enabled_element_types for element in filter_spec.linkable_elements
        )
        if len(origin_models) == 1 and not has_invalid_element_types:
            model = origin_models.pop()
            specs_by_model[model] = tuple(specs_by_model.get(model, tuple())) + (filter_spec,)

    updated_nodes: List[DataflowPlanNode] = []
    for source_node in source_nodes:
        node_models = tuple(source_node.as_plan().source_semantic_models)
        # Only single-model source nodes with at least one eligible spec are candidates.
        if len(node_models) != 1 or node_models[0] not in specs_by_model:
            updated_nodes.append(source_node)
            continue
        candidate_specs = specs_by_model[node_models[0]]
        node_spec_set = self._node_data_set_resolver.get_output_data_set(source_node).instance_set.spec_set
        # A spec applies only if every linkable spec it references is available in the node's output.
        applicable_specs = [
            candidate
            for candidate in candidate_specs
            if all(spec in node_spec_set.linkable_specs for spec in candidate.linkable_specs)
        ]
        if applicable_specs:
            merged_constraint = WhereFilterSpec.merge_iterable(applicable_specs)
            updated_nodes.append(WhereConstraintNode(parent_node=source_node, where_constraint=merged_constraint))
        else:
            updated_nodes.append(source_node)

    return updated_nodes

def _node_contains_entity(
self,
node: DataflowPlanNode,
Expand Down
11 changes: 7 additions & 4 deletions tests_metricflow/dataflow/builder/test_predicate_pushdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
@pytest.fixture
def fully_enabled_pushdown_state() -> PredicatePushdownState:
    """Tests a valid configuration with all predicate properties set and pushdown fully enabled.

    The stale pre-change construction (which omitted the now-required where_filter_specs
    argument and would raise a TypeError) has been removed; only the updated construction
    is kept.
    """
    params = PredicatePushdownState(time_range_constraint=TimeRangeConstraint.all_time(), where_filter_specs=tuple())
    return params


Expand All @@ -20,6 +18,7 @@ def test_time_range_pushdown_enabled_states(fully_enabled_pushdown_state: Predic
time_range_only_state = PredicatePushdownState(
time_range_constraint=TimeRangeConstraint.all_time(),
pushdown_enabled_types=frozenset([PredicateInputType.TIME_RANGE_CONSTRAINT]),
where_filter_specs=tuple(),
)

enabled_states = {
Expand All @@ -39,4 +38,8 @@ def test_time_range_pushdown_enabled_states(fully_enabled_pushdown_state: Predic
def test_invalid_disabled_pushdown_state() -> None:
    """Tests checks for invalid param configuration on disabled pushdown parameters.

    A stale duplicate construction (using the old signature without where_filter_specs)
    preceded the updated call inside the raises block; it would have raised TypeError
    before the intended AssertionError could be exercised, so it has been removed.
    """
    with pytest.raises(AssertionError, match="Disabled pushdown state objects cannot have properties set"):
        PredicatePushdownState(
            time_range_constraint=TimeRangeConstraint.all_time(),
            pushdown_enabled_types=frozenset(),
            where_filter_specs=tuple(),
        )
Loading

0 comments on commit 51e7622

Please sign in to comment.