diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index 84067fd722..56af8825de 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -634,8 +634,7 @@ def _build_derived_metric_output_node( ) if len(metric_spec.filter_specs) > 0: - merged_where_filter = WhereFilterSpec.merge_iterable(metric_spec.filter_specs) - output_node = WhereConstraintNode(parent_node=output_node, where_constraint=merged_where_filter) + output_node = WhereConstraintNode(parent_node=output_node, where_specs=metric_spec.filter_specs) if not extraneous_linkable_specs.is_subset_of(queried_linkable_specs): output_node = FilterElementsNode( parent_node=output_node, @@ -776,9 +775,7 @@ def _build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> Da output_node = JoinOnEntitiesNode(left_node=output_node, join_targets=dataflow_recipe.join_targets) if len(query_level_filter_specs) > 0: - output_node = WhereConstraintNode( - parent_node=output_node, where_constraint=WhereFilterSpec.merge_iterable(query_level_filter_specs) - ) + output_node = WhereConstraintNode(parent_node=output_node, where_specs=query_level_filter_specs) if query_spec.time_range_constraint: output_node = ConstrainTimeRangeNode( parent_node=output_node, time_range_constraint=query_spec.time_range_constraint @@ -1524,12 +1521,11 @@ def _build_aggregated_measure_from_measure_source_node( ) pre_aggregate_node: DataflowPlanNode = cumulative_metric_constrained_node or unaggregated_measure_node - merged_where_filter_spec = WhereFilterSpec.merge_iterable(metric_input_measure_spec.filter_specs) if len(metric_input_measure_spec.filter_specs) > 0: # Apply where constraint on the node pre_aggregate_node = WhereConstraintNode( parent_node=pre_aggregate_node, - where_constraint=merged_where_filter_spec, + where_specs=metric_input_measure_spec.filter_specs, ) if non_additive_dimension_spec is not None: @@ -1598,9 +1594,7 @@ def _build_aggregated_measure_from_measure_source_node( if set(filter_spec.linkable_specs).issubset(set(queried_linkable_specs.as_tuple)) ] if len(queried_filter_specs) > 0: - output_node = WhereConstraintNode( - parent_node=output_node, where_constraint=WhereFilterSpec.merge_iterable(queried_filter_specs) - ) + output_node = WhereConstraintNode(parent_node=output_node, where_specs=queried_filter_specs) # TODO: this will break if you query by agg_time_dimension but apply a time constraint on metric_time. # To fix when enabling time range constraints for agg_time_dimension. diff --git a/metricflow/dataflow/nodes/where_filter.py b/metricflow/dataflow/nodes/where_filter.py index fbed72fa2e..08b09ee435 100644 --- a/metricflow/dataflow/nodes/where_filter.py +++ b/metricflow/dataflow/nodes/where_filter.py @@ -16,9 +16,9 @@ class WhereConstraintNode(DataflowPlanNode): def __init__( # noqa: D107 self, parent_node: DataflowPlanNode, - where_constraint: WhereFilterSpec, + where_specs: Sequence[WhereFilterSpec], ) -> None: - self._where = where_constraint + self._where_specs = where_specs self.parent_node = parent_node super().__init__(node_id=self.create_unique_id(), parent_nodes=(parent_node,)) @@ -29,7 +29,17 @@ def id_prefix(cls) -> IdPrefix: # noqa: D102 @property def where(self) -> WhereFilterSpec: """Returns the specs for the elements that it should pass.""" - return self._where + return WhereFilterSpec.merge_iterable(self._where_specs) + + @property + def input_where_specs(self) -> Sequence[WhereFilterSpec]: + """Returns the discrete set of input where filter specs for this node. + + This is useful for things like predicate pushdown, where we need to differentiate between individual specs + for pushdown operations on the filter spec level. We merge them for things like rendering and node comparisons, + but in some cases we may be able to push down a subset of the input specs for efficiency reasons. + """ + return self._where_specs def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_where_constraint_node(self) @@ -51,5 +61,5 @@ def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> Wher assert len(new_parent_nodes) == 1 return WhereConstraintNode( parent_node=new_parent_nodes[0], - where_constraint=self.where, + where_specs=self.input_where_specs, ) diff --git a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py index bb71c8a809..12d361999f 100644 --- a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py +++ b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py @@ -255,8 +255,8 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> OptimizeBran """ self._log_visit_node_type(node) current_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state - # TODO: update WhereConstraintNode to hold a list of specs instead of merging them all before initialization - where_specs = (node.where,) + # TODO: short-circuit cases where pushdown is disabled for where constraints + where_specs = node.input_where_specs pushdown_eligible_specs: List[WhereFilterSpec] = [] for spec in where_specs: semantic_models = self._models_for_spec(spec) diff --git a/metricflow/plan_conversion/node_processor.py b/metricflow/plan_conversion/node_processor.py index 0bf08942cf..28054d49bb 100644 --- a/metricflow/plan_conversion/node_processor.py +++ b/metricflow/plan_conversion/node_processor.py @@ -378,9 +378,8 @@ def _add_where_constraint( if len(matching_filter_specs) == 0: filtered_nodes.append(source_node) else: - where_constraint = WhereFilterSpec.merge_iterable(matching_filter_specs) filtered_nodes.append( - WhereConstraintNode(parent_node=source_node, where_constraint=where_constraint) + WhereConstraintNode(parent_node=source_node, where_specs=matching_filter_specs) ) else: filtered_nodes.append(source_node) diff --git a/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py b/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py index 2f981895a2..15ba3620d5 100644 --- a/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py +++ b/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py @@ -190,27 +190,29 @@ def test_filter_with_where_constraint_node( ) # need to include ds_spec because where constraint operates on ds where_constraint_node = WhereConstraintNode( parent_node=filter_node, - where_constraint=WhereFilterSpec( - where_sql="booking__ds__day = '2020-01-01'", - bind_parameters=SqlBindParameters(), - linkable_specs=( - TimeDimensionSpec( - element_name="ds", - entity_links=(EntityReference(element_name="booking"),), - time_granularity=TimeGranularity.DAY, + where_specs=( + WhereFilterSpec( + where_sql="booking__ds__day = '2020-01-01'", + bind_parameters=SqlBindParameters(), + linkable_specs=( + TimeDimensionSpec( + element_name="ds", + entity_links=(EntityReference(element_name="booking"),), + time_granularity=TimeGranularity.DAY, + ), ), - ), - linkable_elements=( - LinkableDimension( - defined_in_semantic_model=SemanticModelReference("bookings_source"), - element_name="ds", - dimension_type=DimensionType.TIME, - entity_links=(EntityReference(element_name="booking"),), - properties=frozenset(), - time_granularity=TimeGranularity.DAY, - date_part=None, - join_path=SemanticModelJoinPath( - left_semantic_model_reference=SemanticModelReference("bookings_source"), + linkable_elements=( + LinkableDimension( + defined_in_semantic_model=SemanticModelReference("bookings_source"), + element_name="ds", + dimension_type=DimensionType.TIME, + entity_links=(EntityReference(element_name="booking"),), + properties=frozenset(), + time_granularity=TimeGranularity.DAY, + date_part=None, + join_path=SemanticModelJoinPath( + left_semantic_model_reference=SemanticModelReference("bookings_source"), + ), ), ), ),