From ad0298d40a60fff45d72ca7171c06aaafca1bf5a Mon Sep 17 00:00:00 2001 From: tlento Date: Wed, 12 Jun 2024 15:25:56 -0700 Subject: [PATCH] Store a list of input WhereFilterSpecs in the WhereConstraintNode Up until now we've been merging sets of WhereFilterSpec instances into a single instance and then storing that in the WhereConstraintNode. This made sense for rendering, but when trying to do predicate pushdown having the specs all merged together limits the space of filter pushdown opportunities. In order to allow for the same breadth of predicate pushdown opportunities we have at dataflow plan build time we keep the specs separate, and encapsulate the merging of these specs into the WhereConstraintNode itself. --- .../dataflow/builder/dataflow_plan_builder.py | 14 ++----- metricflow/dataflow/nodes/where_filter.py | 18 ++++++-- .../optimizer/predicate_pushdown_optimizer.py | 4 +- metricflow/plan_conversion/node_processor.py | 3 +- .../test_dataflow_to_sql_plan.py | 42 ++++++++++--------- 5 files changed, 43 insertions(+), 38 deletions(-) diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index 84067fd722..56af8825de 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -634,8 +634,7 @@ def _build_derived_metric_output_node( ) if len(metric_spec.filter_specs) > 0: - merged_where_filter = WhereFilterSpec.merge_iterable(metric_spec.filter_specs) - output_node = WhereConstraintNode(parent_node=output_node, where_constraint=merged_where_filter) + output_node = WhereConstraintNode(parent_node=output_node, where_specs=metric_spec.filter_specs) if not extraneous_linkable_specs.is_subset_of(queried_linkable_specs): output_node = FilterElementsNode( parent_node=output_node, @@ -776,9 +775,7 @@ def _build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> Da output_node = JoinOnEntitiesNode(left_node=output_node, 
join_targets=dataflow_recipe.join_targets) if len(query_level_filter_specs) > 0: - output_node = WhereConstraintNode( - parent_node=output_node, where_constraint=WhereFilterSpec.merge_iterable(query_level_filter_specs) - ) + output_node = WhereConstraintNode(parent_node=output_node, where_specs=query_level_filter_specs) if query_spec.time_range_constraint: output_node = ConstrainTimeRangeNode( parent_node=output_node, time_range_constraint=query_spec.time_range_constraint @@ -1524,12 +1521,11 @@ def _build_aggregated_measure_from_measure_source_node( ) pre_aggregate_node: DataflowPlanNode = cumulative_metric_constrained_node or unaggregated_measure_node - merged_where_filter_spec = WhereFilterSpec.merge_iterable(metric_input_measure_spec.filter_specs) if len(metric_input_measure_spec.filter_specs) > 0: # Apply where constraint on the node pre_aggregate_node = WhereConstraintNode( parent_node=pre_aggregate_node, - where_constraint=merged_where_filter_spec, + where_specs=metric_input_measure_spec.filter_specs, ) if non_additive_dimension_spec is not None: @@ -1598,9 +1594,7 @@ def _build_aggregated_measure_from_measure_source_node( if set(filter_spec.linkable_specs).issubset(set(queried_linkable_specs.as_tuple)) ] if len(queried_filter_specs) > 0: - output_node = WhereConstraintNode( - parent_node=output_node, where_constraint=WhereFilterSpec.merge_iterable(queried_filter_specs) - ) + output_node = WhereConstraintNode(parent_node=output_node, where_specs=queried_filter_specs) # TODO: this will break if you query by agg_time_dimension but apply a time constraint on metric_time. # To fix when enabling time range constraints for agg_time_dimension. 
diff --git a/metricflow/dataflow/nodes/where_filter.py b/metricflow/dataflow/nodes/where_filter.py index fbed72fa2e..08b09ee435 100644 --- a/metricflow/dataflow/nodes/where_filter.py +++ b/metricflow/dataflow/nodes/where_filter.py @@ -16,9 +16,9 @@ class WhereConstraintNode(DataflowPlanNode): def __init__( # noqa: D107 self, parent_node: DataflowPlanNode, - where_constraint: WhereFilterSpec, + where_specs: Sequence[WhereFilterSpec], ) -> None: - self._where = where_constraint + self._where_specs = where_specs self.parent_node = parent_node super().__init__(node_id=self.create_unique_id(), parent_nodes=(parent_node,)) @@ -29,7 +29,17 @@ def id_prefix(cls) -> IdPrefix: # noqa: D102 @property def where(self) -> WhereFilterSpec: """Returns the specs for the elements that it should pass.""" - return self._where + return WhereFilterSpec.merge_iterable(self._where_specs) + + @property + def input_where_specs(self) -> Sequence[WhereFilterSpec]: + """Returns the discrete set of input where filter specs for this node. + + This is useful for things like predicate pushdown, where we need to differentiate between individual specs + for pushdown operations on the filter spec level. We merge them for things like rendering and node comparisons, + but in some cases we may be able to push down a subset of the input specs for efficiency reasons. 
+ """ + return self._where_specs def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_where_constraint_node(self) @@ -51,5 +61,5 @@ def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> Wher assert len(new_parent_nodes) == 1 return WhereConstraintNode( parent_node=new_parent_nodes[0], - where_constraint=self.where, + where_specs=self.input_where_specs, ) diff --git a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py index bb71c8a809..12d361999f 100644 --- a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py +++ b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py @@ -255,8 +255,8 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> OptimizeBran """ self._log_visit_node_type(node) current_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state - # TODO: update WhereConstraintNode to hold a list of specs instead of merging them all before initialization - where_specs = (node.where,) + # TODO: short-circuit cases where pushdown is disabled for where constraints + where_specs = node.input_where_specs pushdown_eligible_specs: List[WhereFilterSpec] = [] for spec in where_specs: semantic_models = self._models_for_spec(spec) diff --git a/metricflow/plan_conversion/node_processor.py b/metricflow/plan_conversion/node_processor.py index 0bf08942cf..28054d49bb 100644 --- a/metricflow/plan_conversion/node_processor.py +++ b/metricflow/plan_conversion/node_processor.py @@ -378,9 +378,8 @@ def _add_where_constraint( if len(matching_filter_specs) == 0: filtered_nodes.append(source_node) else: - where_constraint = WhereFilterSpec.merge_iterable(matching_filter_specs) filtered_nodes.append( - WhereConstraintNode(parent_node=source_node, where_constraint=where_constraint) + WhereConstraintNode(parent_node=source_node, where_specs=matching_filter_specs) ) else: 
filtered_nodes.append(source_node) diff --git a/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py b/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py index 2f981895a2..15ba3620d5 100644 --- a/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py +++ b/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py @@ -190,27 +190,29 @@ def test_filter_with_where_constraint_node( ) # need to include ds_spec because where constraint operates on ds where_constraint_node = WhereConstraintNode( parent_node=filter_node, - where_constraint=WhereFilterSpec( - where_sql="booking__ds__day = '2020-01-01'", - bind_parameters=SqlBindParameters(), - linkable_specs=( - TimeDimensionSpec( - element_name="ds", - entity_links=(EntityReference(element_name="booking"),), - time_granularity=TimeGranularity.DAY, + where_specs=( + WhereFilterSpec( + where_sql="booking__ds__day = '2020-01-01'", + bind_parameters=SqlBindParameters(), + linkable_specs=( + TimeDimensionSpec( + element_name="ds", + entity_links=(EntityReference(element_name="booking"),), + time_granularity=TimeGranularity.DAY, + ), ), - ), - linkable_elements=( - LinkableDimension( - defined_in_semantic_model=SemanticModelReference("bookings_source"), - element_name="ds", - dimension_type=DimensionType.TIME, - entity_links=(EntityReference(element_name="booking"),), - properties=frozenset(), - time_granularity=TimeGranularity.DAY, - date_part=None, - join_path=SemanticModelJoinPath( - left_semantic_model_reference=SemanticModelReference("bookings_source"), + linkable_elements=( + LinkableDimension( + defined_in_semantic_model=SemanticModelReference("bookings_source"), + element_name="ds", + dimension_type=DimensionType.TIME, + entity_links=(EntityReference(element_name="booking"),), + properties=frozenset(), + time_granularity=TimeGranularity.DAY, + date_part=None, + join_path=SemanticModelJoinPath( + left_semantic_model_reference=SemanticModelReference("bookings_source"), + ), ), ), ),