From 88723e2906b0d704a058f22d86812d0e69a9b81d Mon Sep 17 00:00:00 2001
From: tlento
Date: Thu, 9 May 2024 16:33:24 -0700
Subject: [PATCH] Encapsulate time range constraints for predicate pushdown

MetricFlow currently allows a limited-scope form of filter predicate
pushdown, specific to time range constraints, for cases where the
querying user has provided an explicit time window to query against.

The pushdown operation is managed by threading the time range constraint
through the entire dataflow plan builder and applying the filter
operation as appropriate. This is exactly the mechanism we need for
robust predicate pushdown evaluation across our expanded set of pushdown
operations. Rather than wire a whole new set of parameters through, we
simply encapsulate the time range constraints inside a new object that
is more readily extensible to other predicate pushdown handling.

This specific change is kept as mechanical as possible in order to
minimize confusion. Places that stood out as candidates for improvement
via the encapsulating object have been marked with TODOs for later
updates, which will follow shortly.
---
 .../dataflow/builder/dataflow_plan_builder.py | 115 +++++++++++-------
 metricflow/plan_conversion/node_processor.py  |  10 ++
 2 files changed, 82 insertions(+), 43 deletions(-)

diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py
index 0d7cc19297..d26894d2c4 100644
--- a/metricflow/dataflow/builder/dataflow_plan_builder.py
+++ b/metricflow/dataflow/builder/dataflow_plan_builder.py
@@ -82,7 +82,7 @@
 from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode
 from metricflow.dataflow.optimizer.dataflow_plan_optimizer import DataflowPlanOptimizer
 from metricflow.dataset.dataset_classes import DataSet
-from metricflow.plan_conversion.node_processor import PreJoinNodeProcessor
+from metricflow.plan_conversion.node_processor import PredicatePushdownParameters, PreJoinNodeProcessor
 from metricflow.sql.sql_table import SqlTable
 
 logger = logging.getLogger(__name__)
@@ -177,6 +177,8 @@ def _build_query_output_node(
             )
         )
 
+        predicate_pushdown_params = PredicatePushdownParameters(time_range_constraint=query_spec.time_range_constraint)
+
         return self._build_metrics_output_node(
             metric_specs=tuple(
                 MetricSpec(
@@ -187,7 +189,7 @@ def _build_query_output_node(
             ),
             queried_linkable_specs=query_spec.linkable_specs,
             filter_spec_factory=filter_spec_factory,
-            time_range_constraint=query_spec.time_range_constraint,
+            predicate_pushdown_params=predicate_pushdown_params,
             for_group_by_source_node=for_group_by_source_node,
         )
 
@@ -231,7 +233,7 @@ def _build_aggregated_conversion_node(
         entity_spec: EntitySpec,
         window: Optional[MetricTimeWindow],
         queried_linkable_specs: LinkableSpecSet,
-        time_range_constraint: Optional[TimeRangeConstraint] = None,
+        predicate_pushdown_params: PredicatePushdownParameters,
         constant_properties: Optional[Sequence[ConstantPropertyInput]] = None,
     ) -> DataflowPlanNode:
         """Builds a node that contains aggregated values of conversions and opportunities."""
@@ -242,12 +244,14 @@ def _build_aggregated_conversion_node(
         )
         base_measure_recipe = self._find_dataflow_recipe(
             measure_spec_properties=self._build_measure_spec_properties([base_measure_spec.measure_spec]),
-            time_range_constraint=time_range_constraint,
+            predicate_pushdown_params=predicate_pushdown_params,
             linkable_spec_set=base_required_linkable_specs,
         )
         logger.info(f"Recipe for base measure aggregation:\n{mf_pformat(base_measure_recipe)}")
         conversion_measure_recipe = self._find_dataflow_recipe(
             measure_spec_properties=self._build_measure_spec_properties([conversion_measure_spec.measure_spec]),
+            # TODO - Pushdown: Evaluate the potential for applying time constraints and other predicates for conversion
+            predicate_pushdown_params=PredicatePushdownParameters(time_range_constraint=None),
             linkable_spec_set=LinkableSpecSet(),
         )
         logger.info(f"Recipe for conversion measure aggregation:\n{mf_pformat(conversion_measure_recipe)}")
@@ -264,7 +268,7 @@ def _build_aggregated_conversion_node(
         aggregated_base_measure_node = self.build_aggregated_measure(
             metric_input_measure_spec=base_measure_spec,
             queried_linkable_specs=queried_linkable_specs,
-            time_range_constraint=time_range_constraint,
+            predicate_pushdown_params=predicate_pushdown_params,
         )
 
         # Build unaggregated conversions source node
@@ -336,7 +340,7 @@ def _build_aggregated_conversion_node(
         aggregated_conversions_node = self.build_aggregated_measure(
             metric_input_measure_spec=conversion_measure_spec,
             queried_linkable_specs=queried_linkable_specs,
-            time_range_constraint=time_range_constraint,
+            predicate_pushdown_params=predicate_pushdown_params,
             measure_recipe=recipe_with_join_conversion_source_node,
         )
 
@@ -348,7 +352,7 @@ def _build_conversion_metric_output_node(
         metric_spec: MetricSpec,
         queried_linkable_specs: LinkableSpecSet,
         filter_spec_factory: WhereSpecFactory,
-        time_range_constraint: Optional[TimeRangeConstraint] = None,
+        predicate_pushdown_params: PredicatePushdownParameters,
         for_group_by_source_node: bool = False,
     ) -> ComputeMetricsNode:
         """Builds a compute metric node for a conversion metric."""
@@ -375,7 +379,7 @@ def _build_conversion_metric_output_node(
             base_measure_spec=base_measure,
             conversion_measure_spec=conversion_measure,
             queried_linkable_specs=queried_linkable_specs,
-            time_range_constraint=time_range_constraint,
+            predicate_pushdown_params=predicate_pushdown_params,
             entity_spec=entity_spec,
             window=conversion_type_params.window,
             constant_properties=conversion_type_params.constant_properties,
@@ -393,7 +397,7 @@ def _build_base_metric_output_node(
         metric_spec: MetricSpec,
         queried_linkable_specs: LinkableSpecSet,
         filter_spec_factory: WhereSpecFactory,
-        time_range_constraint: Optional[TimeRangeConstraint] = None,
+        predicate_pushdown_params: PredicatePushdownParameters,
         for_group_by_source_node: bool = False,
     ) -> ComputeMetricsNode:
         """Builds a node to compute a metric that is not defined from other metrics."""
@@ -438,7 +442,7 @@ def _build_base_metric_output_node(
         aggregated_measures_node = self.build_aggregated_measure(
             metric_input_measure_spec=metric_input_measure_spec,
             queried_linkable_specs=queried_linkable_specs,
-            time_range_constraint=time_range_constraint,
+            predicate_pushdown_params=predicate_pushdown_params,
         )
         return self.build_computed_metrics_node(
             metric_spec=metric_spec,
@@ -452,7 +456,7 @@ def _build_derived_metric_output_node(
         metric_spec: MetricSpec,
         queried_linkable_specs: LinkableSpecSet,
         filter_spec_factory: WhereSpecFactory,
-        time_range_constraint: Optional[TimeRangeConstraint] = None,
+        predicate_pushdown_params: PredicatePushdownParameters,
         for_group_by_source_node: bool = False,
     ) -> DataflowPlanNode:
         """Builds a node to compute a metric defined from other metrics."""
@@ -488,6 +492,13 @@ def _build_derived_metric_output_node(
             if not metric_spec.has_time_offset:
                 filter_specs.extend(metric_spec.filter_specs)
 
+            # TODO - Pushdown: use parameters to disable pushdown operations instead of clobbering the constraints
+            metric_pushdown_params = (
+                predicate_pushdown_params
+                if not metric_spec.has_time_offset
+                else PredicatePushdownParameters(time_range_constraint=None)
+            )
+
             parent_nodes.append(
                 self._build_any_metric_output_node(
                     metric_spec=MetricSpec(
@@ -501,7 +512,7 @@ def _build_derived_metric_output_node(
                         queried_linkable_specs if not metric_spec.has_time_offset else required_linkable_specs
                     ),
                     filter_spec_factory=filter_spec_factory,
-                    time_range_constraint=time_range_constraint if not metric_spec.has_time_offset else None,
+                    predicate_pushdown_params=metric_pushdown_params,
                 )
             )
 
@@ -527,7 +538,7 @@ def _build_derived_metric_output_node(
                 parent_node=output_node,
                 requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
                 use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time,
-                time_range_constraint=time_range_constraint,
+                time_range_constraint=predicate_pushdown_params.time_range_constraint,
                 offset_window=metric_spec.offset_window,
                 offset_to_grain=metric_spec.offset_to_grain,
                 join_type=SqlJoinType.INNER,
@@ -550,7 +561,7 @@ def _build_any_metric_output_node(
         metric_spec: MetricSpec,
         queried_linkable_specs: LinkableSpecSet,
         filter_spec_factory: WhereSpecFactory,
-        time_range_constraint: Optional[TimeRangeConstraint] = None,
+        predicate_pushdown_params: PredicatePushdownParameters,
         for_group_by_source_node: bool = False,
     ) -> DataflowPlanNode:
         """Builds a node to compute a metric of any type."""
@@ -561,7 +572,7 @@ def _build_any_metric_output_node(
                 metric_spec=metric_spec,
                 queried_linkable_specs=queried_linkable_specs,
                 filter_spec_factory=filter_spec_factory,
-                time_range_constraint=time_range_constraint,
+                predicate_pushdown_params=predicate_pushdown_params,
                 for_group_by_source_node=for_group_by_source_node,
             )
@@ -570,7 +581,7 @@ def _build_any_metric_output_node(
                 metric_spec=metric_spec,
                 queried_linkable_specs=queried_linkable_specs,
                 filter_spec_factory=filter_spec_factory,
-                time_range_constraint=time_range_constraint,
+                predicate_pushdown_params=predicate_pushdown_params,
                 for_group_by_source_node=for_group_by_source_node,
             )
         elif metric.type is MetricType.CONVERSION:
@@ -578,7 +589,7 @@ def _build_any_metric_output_node(
                 metric_spec=metric_spec,
                 queried_linkable_specs=queried_linkable_specs,
                 filter_spec_factory=filter_spec_factory,
-                time_range_constraint=time_range_constraint,
+                predicate_pushdown_params=predicate_pushdown_params,
                 for_group_by_source_node=for_group_by_source_node,
             )
 
@@ -589,7 +600,7 @@ def _build_metrics_output_node(
         metric_specs: Sequence[MetricSpec],
         queried_linkable_specs: LinkableSpecSet,
         filter_spec_factory: WhereSpecFactory,
-        time_range_constraint: Optional[TimeRangeConstraint] = None,
+        predicate_pushdown_params: PredicatePushdownParameters,
         for_group_by_source_node: bool = False,
     ) -> DataflowPlanNode:
         """Builds a node that computes all requested metrics.
@@ -599,7 +610,8 @@ def _build_metrics_output_node(
                 include offsets and filters.
             queried_linkable_specs: Dimensions/entities that were queried.
             filter_spec_factory: Constructs WhereFilterSpecs with the resolved ambiguous group-by-items in the filter.
-            time_range_constraint: Time range constraint used to compute the metric.
+            predicate_pushdown_params: Parameters for evaluating and applying filter predicate pushdown, e.g., for
+                applying time constraints prior to other dimension joins.
""" output_nodes: List[DataflowPlanNode] = [] @@ -612,7 +624,7 @@ def _build_metrics_output_node( metric_spec=metric_spec, queried_linkable_specs=queried_linkable_specs, filter_spec_factory=filter_spec_factory, - time_range_constraint=time_range_constraint, + predicate_pushdown_params=predicate_pushdown_params, for_group_by_source_node=for_group_by_source_node, ) ) @@ -652,8 +664,9 @@ def _build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> Da required_linkable_specs, _ = self.__get_required_and_extraneous_linkable_specs( queried_linkable_specs=query_spec.linkable_specs, filter_specs=query_level_filter_specs ) + predicate_pushdown_params = PredicatePushdownParameters(time_range_constraint=query_spec.time_range_constraint) dataflow_recipe = self._find_dataflow_recipe( - linkable_spec_set=required_linkable_specs, time_range_constraint=query_spec.time_range_constraint + linkable_spec_set=required_linkable_specs, predicate_pushdown_params=predicate_pushdown_params ) if not dataflow_recipe: raise UnableToSatisfyQueryError(f"Unable to join all items in request: {required_linkable_specs}") @@ -818,8 +831,8 @@ def _build_measure_spec_properties(self, measure_specs: Sequence[MeasureSpec]) - def _find_dataflow_recipe( self, linkable_spec_set: LinkableSpecSet, + predicate_pushdown_params: PredicatePushdownParameters, measure_spec_properties: Optional[MeasureSpecProperties] = None, - time_range_constraint: Optional[TimeRangeConstraint] = None, ) -> Optional[DataflowRecipe]: linkable_specs = linkable_spec_set.as_tuple candidate_nodes_for_left_side_of_join: List[DataflowPlanNode] = [] @@ -853,12 +866,13 @@ def _find_dataflow_recipe( semantic_model_lookup=self._semantic_model_lookup, node_data_set_resolver=self._node_data_set_resolver, ) - if time_range_constraint: + # TODO - Pushdown: Encapsulate this in the node processor + if predicate_pushdown_params.time_range_constraint: candidate_nodes_for_left_side_of_join = list( node_processor.add_time_range_constraint( source_nodes=candidate_nodes_for_left_side_of_join, metric_time_dimension_reference=self._metric_time_dimension_reference, - time_range_constraint=time_range_constraint, + time_range_constraint=predicate_pushdown_params.time_range_constraint, ) ) @@ -1179,7 +1193,7 @@ def build_aggregated_measure( self, metric_input_measure_spec: MetricInputMeasureSpec, queried_linkable_specs: LinkableSpecSet, - time_range_constraint: Optional[TimeRangeConstraint] = None, + predicate_pushdown_params: PredicatePushdownParameters, measure_recipe: Optional[DataflowRecipe] = None, ) -> DataflowPlanNode: """Returns a node where the measures are aggregated by the linkable specs and constrained appropriately. 
@@ -1199,7 +1213,7 @@ def build_aggregated_measure(
         return self._build_aggregated_measure_from_measure_source_node(
             metric_input_measure_spec=metric_input_measure_spec,
             queried_linkable_specs=queried_linkable_specs,
-            time_range_constraint=time_range_constraint,
+            predicate_pushdown_params=predicate_pushdown_params,
             measure_recipe=measure_recipe,
         )
 
@@ -1231,7 +1245,7 @@ def _build_aggregated_measure_from_measure_source_node(
         self,
         metric_input_measure_spec: MetricInputMeasureSpec,
         queried_linkable_specs: LinkableSpecSet,
-        time_range_constraint: Optional[TimeRangeConstraint] = None,
+        predicate_pushdown_params: PredicatePushdownParameters,
         measure_recipe: Optional[DataflowRecipe] = None,
     ) -> DataflowPlanNode:
         measure_spec = metric_input_measure_spec.measure_spec
@@ -1250,8 +1264,8 @@ def _build_aggregated_measure_from_measure_source_node(
         non_additive_dimension_spec = measure_properties.non_additive_dimension_spec
 
         cumulative_metric_adjusted_time_constraint: Optional[TimeRangeConstraint] = None
-        if cumulative and time_range_constraint is not None:
-            logger.info(f"Time range constraint before adjustment is {time_range_constraint}")
+        if cumulative and predicate_pushdown_params.time_range_constraint is not None:
+            logger.info(f"Time range constraint before adjustment is {predicate_pushdown_params.time_range_constraint}")
             granularity: Optional[TimeGranularity] = None
             count = 0
             if cumulative_window is not None:
@@ -1262,7 +1276,9 @@ def _build_aggregated_measure_from_measure_source_node(
                 granularity = cumulative_grain_to_date
 
             cumulative_metric_adjusted_time_constraint = (
-                time_range_constraint.adjust_time_constraint_for_cumulative_metric(granularity, count)
+                predicate_pushdown_params.time_range_constraint.adjust_time_constraint_for_cumulative_metric(
+                    granularity, count
+                )
             )
             logger.info(f"Adjusted time range constraint {cumulative_metric_adjusted_time_constraint}")
 
@@ -1283,15 +1299,18 @@ def _build_aggregated_measure_from_measure_source_node(
             + indent(f"\nevaluation:\n{mf_pformat(required_linkable_specs)}")
         )
 
+        # TODO - Pushdown: Update this to be more robust to additional pushdown parameters
+        measure_time_constraint = (
+            (cumulative_metric_adjusted_time_constraint or predicate_pushdown_params.time_range_constraint)
+            # If joining to time spine for time offset, constraints will be applied after that join.
+            if not before_aggregation_time_spine_join_description
+            else None
+        )
+        measure_pushdown_params = PredicatePushdownParameters(time_range_constraint=measure_time_constraint)
         find_recipe_start_time = time.time()
         measure_recipe = self._find_dataflow_recipe(
             measure_spec_properties=measure_properties,
-            time_range_constraint=(
-                (cumulative_metric_adjusted_time_constraint or time_range_constraint)
-                # If joining to time spine for time offset, constraints will be applied after that join.
-                if not before_aggregation_time_spine_join_description
-                else None
-            ),
+            predicate_pushdown_params=measure_pushdown_params,
             linkable_spec_set=required_linkable_specs,
         )
         logger.info(
@@ -1323,8 +1342,12 @@ def _build_aggregated_measure_from_measure_source_node(
                 time_dimension_spec_for_join=agg_time_dimension_spec_for_join,
                 window=cumulative_window,
                 grain_to_date=cumulative_grain_to_date,
+                # Note: we use the original constraint here because the JoinOverTimeRangeNode will eventually get
+                # rendered with an interval that expands the join window
                 time_range_constraint=(
-                    time_range_constraint if not before_aggregation_time_spine_join_description else None
+                    predicate_pushdown_params.time_range_constraint
+                    if not before_aggregation_time_spine_join_description
+                    else None
                 ),
             )
 
@@ -1339,11 +1362,13 @@ def _build_aggregated_measure_from_measure_source_node(
                 f"Expected {SqlJoinType.INNER} for joining to time spine before aggregation. Remove this if there's a "
                 f"new use case."
             )
+            # This also uses the original time range constraint due to the application of the time window intervals
+            # in join rendering
             join_to_time_spine_node = JoinToTimeSpineNode(
                 parent_node=time_range_node or measure_recipe.source_node,
                 requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
                 use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time,
-                time_range_constraint=time_range_constraint,
+                time_range_constraint=predicate_pushdown_params.time_range_constraint,
                 offset_window=before_aggregation_time_spine_join_description.offset_window,
                 offset_to_grain=before_aggregation_time_spine_join_description.offset_to_grain,
                 join_type=before_aggregation_time_spine_join_description.join_type,
@@ -1379,12 +1404,16 @@ def _build_aggregated_measure_from_measure_source_node(
         # If time constraint was previously adjusted for cumulative window or grain, apply original time constraint
         # here. Can skip if metric is being aggregated over all time.
         cumulative_metric_constrained_node: Optional[ConstrainTimeRangeNode] = None
-        if cumulative_metric_adjusted_time_constraint is not None and time_range_constraint is not None:
+        # TODO - Pushdown: Encapsulate all of this window sliding bookkeeping in the pushdown params object
+        if (
+            cumulative_metric_adjusted_time_constraint is not None
+            and predicate_pushdown_params.time_range_constraint is not None
+        ):
             assert (
                 queried_linkable_specs.contains_metric_time
             ), "Using time constraints currently requires querying with metric_time."
             cumulative_metric_constrained_node = ConstrainTimeRangeNode(
-                unaggregated_measure_node, time_range_constraint
+                unaggregated_measure_node, predicate_pushdown_params.time_range_constraint
             )
 
         pre_aggregate_node: DataflowPlanNode = cumulative_metric_constrained_node or unaggregated_measure_node
@@ -1448,16 +1477,16 @@ def _build_aggregated_measure_from_measure_source_node(
                 requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
                 use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time,
                 join_type=after_aggregation_time_spine_join_description.join_type,
-                time_range_constraint=time_range_constraint,
+                time_range_constraint=predicate_pushdown_params.time_range_constraint,
                 offset_window=after_aggregation_time_spine_join_description.offset_window,
                 offset_to_grain=after_aggregation_time_spine_join_description.offset_to_grain,
             )
 
         # Since new rows might have been added due to time spine join, apply constraints again here.
         if len(metric_input_measure_spec.filter_specs) > 0:
             output_node = WhereConstraintNode(parent_node=output_node, where_constraint=merged_where_filter_spec)
-        if time_range_constraint is not None:
+        if predicate_pushdown_params.time_range_constraint is not None:
             output_node = ConstrainTimeRangeNode(
-                parent_node=output_node, time_range_constraint=time_range_constraint
+                parent_node=output_node, time_range_constraint=predicate_pushdown_params.time_range_constraint
             )
 
         return output_node
diff --git a/metricflow/plan_conversion/node_processor.py b/metricflow/plan_conversion/node_processor.py
index 88904e3cbc..e21649e110 100644
--- a/metricflow/plan_conversion/node_processor.py
+++ b/metricflow/plan_conversion/node_processor.py
@@ -59,6 +59,16 @@ class MultiHopJoinCandidate:
     lineage: MultiHopJoinCandidateLineage
 
 
+@dataclass(frozen=True)
+class PredicatePushdownParameters:
+    """Container class for managing filter predicate pushdown.
+
+    Stores time constraint information for applying pre-join time filters.
+    """
+
+    time_range_constraint: Optional[TimeRangeConstraint]
+
+
 class PreJoinNodeProcessor:
     """Processes source nodes before other nodes are joined.
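
Illustration (not part of the patch): a minimal, self-contained sketch of the
usage pattern this change introduces, using a simplified stand-in for
MetricFlow's real TimeRangeConstraint class. The query-level time window is
wrapped once at the top of the plan builder and threaded through as a single
parameter; subtrees that must not push the filter down (e.g., time-offset
derived metrics) construct a variant with no constraint instead of clobbering
a bare Optional argument.

from dataclasses import dataclass
from datetime import datetime
from typing import Optional


@dataclass(frozen=True)
class TimeRangeConstraint:
    """Simplified stand-in for MetricFlow's TimeRangeConstraint."""

    start_time: datetime
    end_time: datetime


@dataclass(frozen=True)
class PredicatePushdownParameters:
    """Mirrors the container added in this patch; later pushdown state
    (e.g., other filter predicates) can hang here without changing
    every builder method signature again.
    """

    time_range_constraint: Optional[TimeRangeConstraint]


# Wrap the query's explicit time window once, at the top of the builder.
pushdown_params = PredicatePushdownParameters(
    time_range_constraint=TimeRangeConstraint(datetime(2024, 1, 1), datetime(2024, 3, 31))
)

# Disable pushdown for a subtree, as _build_derived_metric_output_node does
# for time-offset metrics, by constructing an empty variant.
offset_params = PredicatePushdownParameters(time_range_constraint=None)

# Downstream consumers check the encapsulated constraint rather than a loose
# Optional parameter (compare the check in _find_dataflow_recipe above).
if pushdown_params.time_range_constraint is not None:
    print(f"apply pre-join time filter: {pushdown_params.time_range_constraint}")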