Enable predicate pushdown for categorical dimension filters #1227
Changes from all commits
b4c5472
ef6fc82
df90189
f93031e
4bc075e
e9dcde3
ca3043c
@@ -0,0 +1,6 @@
kind: Features
body: Enable predicate pushdown for categorical dimensions
time: 2024-05-21T20:22:52.841802-07:00
custom:
  Author: tlento
  Issue: "1011"
@@ -3,15 +3,16 @@

import dataclasses
import logging
from enum import Enum
from typing import FrozenSet, List, Optional, Sequence, Set
from typing import Dict, FrozenSet, List, Optional, Sequence, Set

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.references import EntityReference, TimeDimensionReference
from dbt_semantic_interfaces.references import EntityReference, SemanticModelReference, TimeDimensionReference
from metricflow_semantics.filters.time_constraint import TimeRangeConstraint
from metricflow_semantics.mf_logging.pretty_print import mf_pformat
from metricflow_semantics.model.semantics.linkable_element import LinkableElementType
from metricflow_semantics.model.semantics.semantic_model_join_evaluator import MAX_JOIN_HOPS
from metricflow_semantics.model.semantics.semantic_model_lookup import SemanticModelLookup
from metricflow_semantics.specs.spec_classes import LinkableInstanceSpec, LinklessEntitySpec
from metricflow_semantics.specs.spec_classes import LinkableInstanceSpec, LinklessEntitySpec, WhereFilterSpec
from metricflow_semantics.specs.spec_set import group_specs_by_type
from metricflow_semantics.specs.spec_set_transforms import ToElementNameSet
from metricflow_semantics.sql.sql_join_type import SqlJoinType
@@ -25,6 +26,7 @@
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_to_base import JoinDescription, JoinOnEntitiesNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.validation.dataflow_join_validator import JoinDataflowOutputValidator

logger = logging.getLogger(__name__)
@@ -95,7 +97,10 @@ class PredicatePushdownState:
    """

    time_range_constraint: Optional[TimeRangeConstraint]
    pushdown_enabled_types: FrozenSet[PredicateInputType] = frozenset([PredicateInputType.TIME_RANGE_CONSTRAINT])
    where_filter_specs: Sequence[WhereFilterSpec]
    pushdown_enabled_types: FrozenSet[PredicateInputType] = frozenset(
        [PredicateInputType.TIME_RANGE_CONSTRAINT, PredicateInputType.CATEGORICAL_DIMENSION]
    )

    def __post_init__(self) -> None:
        """Validation to ensure pushdown states are configured correctly.
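As a rough illustration (not part of the diff), the updated field layout means a pushdown state can now carry where filter specs directly. Imports are omitted, and `listing_country_filter` is a hypothetical `WhereFilterSpec` produced by the usual filter resolution path:

```python
# Hypothetical usage sketch, not MetricFlow source; names come from the diff above.
pushdown_state = PredicatePushdownState(
    time_range_constraint=None,
    where_filter_specs=(listing_country_filter,),  # placeholder WhereFilterSpec
    # pushdown_enabled_types keeps its default:
    # {TIME_RANGE_CONSTRAINT, CATEGORICAL_DIMENSION}
)
assert pushdown_state.has_where_filters_to_push_down
```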
@@ -107,13 +112,12 @@ def __post_init__(self) -> None:
        invalid_types: Set[PredicateInputType] = set()

        for input_type in self.pushdown_enabled_types:
            if (
            if input_type is PredicateInputType.ENTITY or input_type is PredicateInputType.TIME_DIMENSION:
                invalid_types.add(input_type)
            elif (
                input_type is PredicateInputType.CATEGORICAL_DIMENSION
                or input_type is PredicateInputType.ENTITY
                or input_type is PredicateInputType.TIME_DIMENSION
                or input_type is PredicateInputType.TIME_RANGE_CONSTRAINT
            ):
                invalid_types.add(input_type)
            elif input_type is PredicateInputType.TIME_RANGE_CONSTRAINT:
                continue
            else:
                assert_values_exhausted(input_type)
@@ -125,23 +129,24 @@ def __post_init__(self) -> None:
            f"for {self.pushdown_enabled_types}, which includes the following invalid types: {invalid_types}."
        )

        # TODO: Include where filter specs when they are added to this class
        time_range_constraint_is_valid = (
            self.time_range_constraint is None
            or PredicateInputType.TIME_RANGE_CONSTRAINT in self.pushdown_enabled_types
        )
        assert time_range_constraint_is_valid, (
        where_filter_specs_are_valid = len(self.where_filter_specs) == 0 or self.where_filter_pushdown_enabled
        assert time_range_constraint_is_valid and where_filter_specs_are_valid, (
            "Invalid pushdown state configuration! Disabled pushdown state objects cannot have properties "
            "set that may lead to improper access and use in other contexts, as that can lead to unintended "
            "filtering operations in cases where these properties are accessed without appropriate checks against "
            "pushdown configuration. The following properties should all have None values:\n"
            f"time_range_constraint: {self.time_range_constraint}"
            "pushdown configuration. The following properties should be None or empty:\n"
            f"time_range_constraint: {self.time_range_constraint}\n"
            f"where_filter_specs: {self.where_filter_specs}"
        )
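To make the validation above concrete, here is a hedged sketch of a configuration it is meant to reject: a fully disabled state that still carries filter properties. `leftover_spec` is a hypothetical `WhereFilterSpec`.

```python
# Hypothetical sketch, not MetricFlow source: pushdown is disabled (no enabled types),
# so a non-empty where_filter_specs trips the assert in __post_init__ above.
PredicatePushdownState(
    time_range_constraint=None,
    where_filter_specs=(leftover_spec,),  # leftover_spec is a placeholder WhereFilterSpec
    pushdown_enabled_types=frozenset(),
)
```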
    @property
    def has_pushdown_potential(self) -> bool:
        """Returns whether or not pushdown is enabled for a type with predicate candidates in place."""
        return self.has_time_range_constraint_to_push_down
        return self.has_time_range_constraint_to_push_down or self.has_where_filters_to_push_down

    @property
    def has_time_range_constraint_to_push_down(self) -> bool:
@@ -156,6 +161,44 @@ def has_time_range_constraint_to_push_down(self) -> bool:
            and self.time_range_constraint is not None
        )

    @property
    def has_where_filters_to_push_down(self) -> bool:
        """Convenience accessor for checking if there are any where filters to push down."""
        return self.where_filter_pushdown_enabled and len(self.where_filter_specs) > 0

    @property
    def where_filter_pushdown_enabled(self) -> bool:
        """Indicates whether or not pushdown is enabled for where filters."""
        return (
            PredicateInputType.CATEGORICAL_DIMENSION in self.pushdown_enabled_types
            or PredicateInputType.ENTITY in self.pushdown_enabled_types
            or PredicateInputType.TIME_DIMENSION in self.pushdown_enabled_types
        )

    @property
    def pushdown_eligible_element_types(self) -> FrozenSet[LinkableElementType]:
        """Set of linkable element types eligible for predicate pushdown.

        This converts from enabled PushdownInputTypes for checking if linkable elements in where filter specs are
        eligible for pushdown.
        """
        eligible_types: List[LinkableElementType] = []
        for enabled_type in self.pushdown_enabled_types:
            if enabled_type is PredicateInputType.TIME_RANGE_CONSTRAINT:
                pass
            elif enabled_type is PredicateInputType.CATEGORICAL_DIMENSION:
                eligible_types.append(LinkableElementType.DIMENSION)
            elif enabled_type is PredicateInputType.TIME_DIMENSION or enabled_type is PredicateInputType.ENTITY:
                # TODO: Remove as support for time dimensions and entities becomes available
                raise NotImplementedError(
                    "Predicate pushdown is not currently supported for where filter predicates with time dimension or "
                    f"entity references, but this pushdown state is enabled for {enabled_type}."
                )
            else:
                assert_values_exhausted(enabled_type)

        return frozenset(eligible_types)
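A quick sketch (not from the PR) of what the property above yields for the new default configuration; this is also where the pushdown path narrows to categorical dimensions only:

```python
# Hypothetical sketch, not MetricFlow source. With the default enabled types
# (TIME_RANGE_CONSTRAINT + CATEGORICAL_DIMENSION), only the DIMENSION element type
# is eligible; enabling ENTITY or TIME_DIMENSION would raise NotImplementedError.
state = PredicatePushdownState(time_range_constraint=None, where_filter_specs=tuple())
assert state.pushdown_eligible_element_types == frozenset([LinkableElementType.DIMENSION])
```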
    @staticmethod
    def with_time_range_constraint(
        original_pushdown_state: PredicatePushdownState, time_range_constraint: TimeRangeConstraint
@@ -169,7 +212,9 @@ def with_time_range_constraint(
            {PredicateInputType.TIME_RANGE_CONSTRAINT}
        )
        return PredicatePushdownState(
            time_range_constraint=time_range_constraint, pushdown_enabled_types=pushdown_enabled_types
            time_range_constraint=time_range_constraint,
            pushdown_enabled_types=pushdown_enabled_types,
            where_filter_specs=original_pushdown_state.where_filter_specs,
        )

    @staticmethod
@@ -180,7 +225,27 @@ def without_time_range_constraint(
        pushdown_enabled_types = original_pushdown_state.pushdown_enabled_types.difference(
            {PredicateInputType.TIME_RANGE_CONSTRAINT}
        )
        return PredicatePushdownState(time_range_constraint=None, pushdown_enabled_types=pushdown_enabled_types)
        return PredicatePushdownState(
            time_range_constraint=None,
            pushdown_enabled_types=pushdown_enabled_types,
            where_filter_specs=original_pushdown_state.where_filter_specs,
        )

    @staticmethod
    def with_additional_where_filter_specs(
        original_pushdown_state: PredicatePushdownState, additional_where_filter_specs: Sequence[WhereFilterSpec]
    ) -> PredicatePushdownState:
        """Factory method for adding additional WhereFilterSpecs for pushdown operations.

        This requires that the PushdownState allow for where filters - time range only or disabled states will
        raise an exception, and must be checked externally.
        """
        updated_where_specs = tuple(original_pushdown_state.where_filter_specs) + tuple(additional_where_filter_specs)
        return PredicatePushdownState(
            time_range_constraint=original_pushdown_state.time_range_constraint,
            where_filter_specs=updated_where_specs,
            pushdown_enabled_types=original_pushdown_state.pushdown_enabled_types,
        )
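A hedged usage sketch for the factory method above, assuming `current_state` already allows where filter pushdown and `country_filter` is a hypothetical `WhereFilterSpec` not yet present on the state:

```python
# Hypothetical sketch, not MetricFlow source: states are treated as immutable, so
# adding a spec returns a new PredicatePushdownState rather than mutating the original.
updated_state = PredicatePushdownState.with_additional_where_filter_specs(
    original_pushdown_state=current_state,
    additional_where_filter_specs=(country_filter,),
)
assert country_filter in updated_state.where_filter_specs
assert country_filter not in current_state.where_filter_specs  # original untouched
```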
    @staticmethod
    def with_pushdown_disabled() -> PredicatePushdownState:
@@ -194,6 +259,7 @@ def with_pushdown_disabled() -> PredicatePushdownState:
        return PredicatePushdownState(
            time_range_constraint=None,
            pushdown_enabled_types=frozenset(),
            where_filter_specs=tuple(),
        )
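For contrast, a short sketch (not from the PR) of the fully disabled state, which downstream callers can check via `has_pushdown_potential`:

```python
# Hypothetical sketch, not MetricFlow source.
disabled = PredicatePushdownState.with_pushdown_disabled()
assert not disabled.has_pushdown_potential
assert not disabled.has_where_filters_to_push_down
```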
@@ -240,6 +306,13 @@ def apply_matching_filter_predicates(
                time_range_constraint=predicate_pushdown_state.time_range_constraint,
            )

        if predicate_pushdown_state.has_where_filters_to_push_down:
            source_nodes = self._add_where_constraint(
                source_nodes=source_nodes,
                where_filter_specs=predicate_pushdown_state.where_filter_specs,
                enabled_element_types=predicate_pushdown_state.pushdown_eligible_element_types,
            )

        return source_nodes

    def _add_time_range_constraint(
@@ -272,6 +345,48 @@ def _add_time_range_constraint(
                processed_nodes.append(source_node)
        return processed_nodes

    def _add_where_constraint(
        self,
        source_nodes: Sequence[DataflowPlanNode],
        where_filter_specs: Sequence[WhereFilterSpec],
        enabled_element_types: FrozenSet[LinkableElementType],
    ) -> Sequence[DataflowPlanNode]:
        """Processes where filter specs and evaluates their fitness for pushdown against the provided node set."""
        eligible_filter_specs_by_model: Dict[SemanticModelReference, Sequence[WhereFilterSpec]] = {}
        for spec in where_filter_specs:
            semantic_models = set(element.semantic_model_origin for element in spec.linkable_elements)
            invalid_element_types = [

[Review comment] Ok this is the logic I was looking for 👍

                element for element in spec.linkable_elements if element.element_type not in enabled_element_types
            ]
            if len(semantic_models) == 1 and len(invalid_element_types) == 0:

[Review comment] Ok, so we only push down filters if ALL the filtered elements are eligible element types. Is that because we don't know if this is an […]

[Reply] Correct, we cannot push down a filter with any invalid element types. The AND vs OR nature of things isn't relevant, it's because we don't have a way to handle those element types. Right now that's just because we haven't implemented handling, but in future it could be due to a given query being too difficult to manage for a given element type. For example, agg time dimension filters against a mixture of cumulative and derived offset metric inputs could get very tricky. In those cases we may not be able to push down a where filter with a time dimension. My expectation is that this will be more refined than clobbering anything that has a time dimension of any kind in it, but for now this definitely works and we can use more finesse later.

                model = semantic_models.pop()
                eligible_filter_specs_by_model[model] = tuple(eligible_filter_specs_by_model.get(model, tuple())) + (
                    spec,
                )

        filtered_nodes: List[DataflowPlanNode] = []
        for source_node in source_nodes:
            node_semantic_models = tuple(source_node.as_plan().source_semantic_models)
            if len(node_semantic_models) == 1 and node_semantic_models[0] in eligible_filter_specs_by_model:
                eligible_filter_specs = eligible_filter_specs_by_model[node_semantic_models[0]]
                source_node_specs = self._node_data_set_resolver.get_output_data_set(source_node).instance_set.spec_set
                matching_filter_specs = [
                    filter_spec
                    for filter_spec in eligible_filter_specs
                    if all([spec in source_node_specs.linkable_specs for spec in filter_spec.linkable_specs])

[Review comment] Is there ever a time where this won't be true, since we get the semantic model from the linkable element above? Or is this just an extra safety check in case something gets misconfigured?

[Reply] At this time it's a safeguard against something weird happening where a given source node isn't configured correctly. However, I expect this filter to be relevant for entities, since they may be defined in multiple semantic models and we need to be able to explicitly allow or disallow pushdown in those cases. If we ever add a pre-joined source node, for example, we might encounter a scenario where the entity and dimension come from different semantic models and then we couldn't push down past this point (and maybe shouldn't push down to this point, either).

                ]
                if len(matching_filter_specs) == 0:
                    filtered_nodes.append(source_node)
                else:
                    where_constraint = WhereFilterSpec.merge_iterable(matching_filter_specs)
                    filtered_nodes.append(
                        WhereConstraintNode(parent_node=source_node, where_constraint=where_constraint)
                    )
            else:
                filtered_nodes.append(source_node)

        return filtered_nodes
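The review thread above boils down to two gates on each filter spec: every referenced element type must be enabled for pushdown, and every referenced element must resolve to a single semantic model. Here is a standalone toy model of just that rule, using plain strings instead of the real spec and reference classes (the model names are made up):

```python
from typing import FrozenSet, Sequence


def is_pushdown_eligible(
    element_types: Sequence[str],
    element_models: Sequence[str],
    enabled_element_types: FrozenSet[str],
) -> bool:
    """Toy mirror of the eligibility checks in _add_where_constraint; not MetricFlow code."""
    single_model = len(set(element_models)) == 1
    all_types_enabled = all(t in enabled_element_types for t in element_types)
    return single_model and all_types_enabled


# A categorical dimension filter scoped to one semantic model can be pushed down...
assert is_pushdown_eligible(["DIMENSION"], ["bookings_source"], frozenset(["DIMENSION"]))
# ...but a time dimension reference, or elements spanning two models, blocks pushdown.
assert not is_pushdown_eligible(["TIME_DIMENSION"], ["bookings_source"], frozenset(["DIMENSION"]))
assert not is_pushdown_eligible(["DIMENSION"], ["bookings_source", "listings_source"], frozenset(["DIMENSION"]))
```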
    def _node_contains_entity(
        self,
        node: DataflowPlanNode,
[Review comment] Where does this get narrowed down to only categorical dimensions? 🤔