Skip to content

Commit

Permalink
Store a list of input WhereFilterSpecs in the WhereConstraintNode
Browse files Browse the repository at this point in the history
Up until now we've been merging sets of WhereFilterSpec instances
into a single instance and then storing that in the WhereConstraintNode.

This made sense for rendering, but when trying to do predicate pushdown
having the specs all merged together limits the space of filter pushdown
opporunities. In order to allow for the same breadth of predicate pushdown
opportunities we have at dataflow plan build time we keep the specs separate,
and encapsulate the merging of these specs into the WhereConstraintNode itself.
  • Loading branch information
tlento committed Jun 25, 2024
1 parent d18efed commit ad0298d
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 38 deletions.
14 changes: 4 additions & 10 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,8 +634,7 @@ def _build_derived_metric_output_node(
)

if len(metric_spec.filter_specs) > 0:
merged_where_filter = WhereFilterSpec.merge_iterable(metric_spec.filter_specs)
output_node = WhereConstraintNode(parent_node=output_node, where_constraint=merged_where_filter)
output_node = WhereConstraintNode(parent_node=output_node, where_specs=metric_spec.filter_specs)
if not extraneous_linkable_specs.is_subset_of(queried_linkable_specs):
output_node = FilterElementsNode(
parent_node=output_node,
Expand Down Expand Up @@ -776,9 +775,7 @@ def _build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> Da
output_node = JoinOnEntitiesNode(left_node=output_node, join_targets=dataflow_recipe.join_targets)

if len(query_level_filter_specs) > 0:
output_node = WhereConstraintNode(
parent_node=output_node, where_constraint=WhereFilterSpec.merge_iterable(query_level_filter_specs)
)
output_node = WhereConstraintNode(parent_node=output_node, where_specs=query_level_filter_specs)
if query_spec.time_range_constraint:
output_node = ConstrainTimeRangeNode(
parent_node=output_node, time_range_constraint=query_spec.time_range_constraint
Expand Down Expand Up @@ -1524,12 +1521,11 @@ def _build_aggregated_measure_from_measure_source_node(
)

pre_aggregate_node: DataflowPlanNode = cumulative_metric_constrained_node or unaggregated_measure_node
merged_where_filter_spec = WhereFilterSpec.merge_iterable(metric_input_measure_spec.filter_specs)
if len(metric_input_measure_spec.filter_specs) > 0:
# Apply where constraint on the node
pre_aggregate_node = WhereConstraintNode(
parent_node=pre_aggregate_node,
where_constraint=merged_where_filter_spec,
where_specs=metric_input_measure_spec.filter_specs,
)

if non_additive_dimension_spec is not None:
Expand Down Expand Up @@ -1598,9 +1594,7 @@ def _build_aggregated_measure_from_measure_source_node(
if set(filter_spec.linkable_specs).issubset(set(queried_linkable_specs.as_tuple))
]
if len(queried_filter_specs) > 0:
output_node = WhereConstraintNode(
parent_node=output_node, where_constraint=WhereFilterSpec.merge_iterable(queried_filter_specs)
)
output_node = WhereConstraintNode(parent_node=output_node, where_specs=queried_filter_specs)

# TODO: this will break if you query by agg_time_dimension but apply a time constraint on metric_time.
# To fix when enabling time range constraints for agg_time_dimension.
Expand Down
18 changes: 14 additions & 4 deletions metricflow/dataflow/nodes/where_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ class WhereConstraintNode(DataflowPlanNode):
def __init__( # noqa: D107
self,
parent_node: DataflowPlanNode,
where_constraint: WhereFilterSpec,
where_specs: Sequence[WhereFilterSpec],
) -> None:
self._where = where_constraint
self._where_specs = where_specs
self.parent_node = parent_node
super().__init__(node_id=self.create_unique_id(), parent_nodes=(parent_node,))

Expand All @@ -29,7 +29,17 @@ def id_prefix(cls) -> IdPrefix: # noqa: D102
@property
def where(self) -> WhereFilterSpec:
"""Returns the specs for the elements that it should pass."""
return self._where
return WhereFilterSpec.merge_iterable(self._where_specs)

@property
def input_where_specs(self) -> Sequence[WhereFilterSpec]:
"""Returns the discrete set of input where filter specs for this node.
This is useful for things like predicate pushdown, where we need to differentiate between individual specs
for pushdown operations on the filter spec level. We merge them for things like rendering and node comparisons,
but in some cases we may be able to push down a subset of the input specs for efficiency reasons.
"""
return self._where_specs

def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_where_constraint_node(self)
Expand All @@ -51,5 +61,5 @@ def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> Wher
assert len(new_parent_nodes) == 1
return WhereConstraintNode(
parent_node=new_parent_nodes[0],
where_constraint=self.where,
where_specs=self.input_where_specs,
)
4 changes: 2 additions & 2 deletions metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,8 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> OptimizeBran
"""
self._log_visit_node_type(node)
current_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state
# TODO: update WhereConstraintNode to hold a list of specs instead of merging them all before initialization
where_specs = (node.where,)
# TODO: short-circuit cases where pushdown is disabled for where constraints
where_specs = node.input_where_specs
pushdown_eligible_specs: List[WhereFilterSpec] = []
for spec in where_specs:
semantic_models = self._models_for_spec(spec)
Expand Down
3 changes: 1 addition & 2 deletions metricflow/plan_conversion/node_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,8 @@ def _add_where_constraint(
if len(matching_filter_specs) == 0:
filtered_nodes.append(source_node)
else:
where_constraint = WhereFilterSpec.merge_iterable(matching_filter_specs)
filtered_nodes.append(
WhereConstraintNode(parent_node=source_node, where_constraint=where_constraint)
WhereConstraintNode(parent_node=source_node, where_specs=matching_filter_specs)
)
else:
filtered_nodes.append(source_node)
Expand Down
42 changes: 22 additions & 20 deletions tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,27 +190,29 @@ def test_filter_with_where_constraint_node(
) # need to include ds_spec because where constraint operates on ds
where_constraint_node = WhereConstraintNode(
parent_node=filter_node,
where_constraint=WhereFilterSpec(
where_sql="booking__ds__day = '2020-01-01'",
bind_parameters=SqlBindParameters(),
linkable_specs=(
TimeDimensionSpec(
element_name="ds",
entity_links=(EntityReference(element_name="booking"),),
time_granularity=TimeGranularity.DAY,
where_specs=(
WhereFilterSpec(
where_sql="booking__ds__day = '2020-01-01'",
bind_parameters=SqlBindParameters(),
linkable_specs=(
TimeDimensionSpec(
element_name="ds",
entity_links=(EntityReference(element_name="booking"),),
time_granularity=TimeGranularity.DAY,
),
),
),
linkable_elements=(
LinkableDimension(
defined_in_semantic_model=SemanticModelReference("bookings_source"),
element_name="ds",
dimension_type=DimensionType.TIME,
entity_links=(EntityReference(element_name="booking"),),
properties=frozenset(),
time_granularity=TimeGranularity.DAY,
date_part=None,
join_path=SemanticModelJoinPath(
left_semantic_model_reference=SemanticModelReference("bookings_source"),
linkable_elements=(
LinkableDimension(
defined_in_semantic_model=SemanticModelReference("bookings_source"),
element_name="ds",
dimension_type=DimensionType.TIME,
entity_links=(EntityReference(element_name="booking"),),
properties=frozenset(),
time_granularity=TimeGranularity.DAY,
date_part=None,
join_path=SemanticModelJoinPath(
left_semantic_model_reference=SemanticModelReference("bookings_source"),
),
),
),
),
Expand Down

0 comments on commit ad0298d

Please sign in to comment.