diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index 0d7cc19297..d26894d2c4 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -82,7 +82,7 @@ from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode from metricflow.dataflow.optimizer.dataflow_plan_optimizer import DataflowPlanOptimizer from metricflow.dataset.dataset_classes import DataSet -from metricflow.plan_conversion.node_processor import PreJoinNodeProcessor +from metricflow.plan_conversion.node_processor import PredicatePushdownParameters, PreJoinNodeProcessor from metricflow.sql.sql_table import SqlTable logger = logging.getLogger(__name__) @@ -177,6 +177,8 @@ def _build_query_output_node( ) ) + predicate_pushdown_params = PredicatePushdownParameters(time_range_constraint=query_spec.time_range_constraint) + return self._build_metrics_output_node( metric_specs=tuple( MetricSpec( @@ -187,7 +189,7 @@ def _build_query_output_node( ), queried_linkable_specs=query_spec.linkable_specs, filter_spec_factory=filter_spec_factory, - time_range_constraint=query_spec.time_range_constraint, + predicate_pushdown_params=predicate_pushdown_params, for_group_by_source_node=for_group_by_source_node, ) @@ -231,7 +233,7 @@ def _build_aggregated_conversion_node( entity_spec: EntitySpec, window: Optional[MetricTimeWindow], queried_linkable_specs: LinkableSpecSet, - time_range_constraint: Optional[TimeRangeConstraint] = None, + predicate_pushdown_params: PredicatePushdownParameters, constant_properties: Optional[Sequence[ConstantPropertyInput]] = None, ) -> DataflowPlanNode: """Builds a node that contains aggregated values of conversions and opportunities.""" @@ -242,12 +244,14 @@ def _build_aggregated_conversion_node( ) base_measure_recipe = self._find_dataflow_recipe( measure_spec_properties=self._build_measure_spec_properties([base_measure_spec.measure_spec]), - time_range_constraint=time_range_constraint, + predicate_pushdown_params=predicate_pushdown_params, linkable_spec_set=base_required_linkable_specs, ) logger.info(f"Recipe for base measure aggregation:\n{mf_pformat(base_measure_recipe)}") conversion_measure_recipe = self._find_dataflow_recipe( measure_spec_properties=self._build_measure_spec_properties([conversion_measure_spec.measure_spec]), + # TODO - Pushdown: Evaluate the potential for applying time constraints and other predicates for conversion + predicate_pushdown_params=PredicatePushdownParameters(time_range_constraint=None), linkable_spec_set=LinkableSpecSet(), ) logger.info(f"Recipe for conversion measure aggregation:\n{mf_pformat(conversion_measure_recipe)}") @@ -264,7 +268,7 @@ def _build_aggregated_conversion_node( aggregated_base_measure_node = self.build_aggregated_measure( metric_input_measure_spec=base_measure_spec, queried_linkable_specs=queried_linkable_specs, - time_range_constraint=time_range_constraint, + predicate_pushdown_params=predicate_pushdown_params, ) # Build unaggregated conversions source node @@ -336,7 +340,7 @@ def _build_aggregated_conversion_node( aggregated_conversions_node = self.build_aggregated_measure( metric_input_measure_spec=conversion_measure_spec, queried_linkable_specs=queried_linkable_specs, - time_range_constraint=time_range_constraint, + predicate_pushdown_params=predicate_pushdown_params, measure_recipe=recipe_with_join_conversion_source_node, ) @@ -348,7 +352,7 @@ def _build_conversion_metric_output_node( metric_spec: MetricSpec, queried_linkable_specs: LinkableSpecSet, filter_spec_factory: WhereSpecFactory, - time_range_constraint: Optional[TimeRangeConstraint] = None, + predicate_pushdown_params: PredicatePushdownParameters, for_group_by_source_node: bool = False, ) -> ComputeMetricsNode: """Builds a compute metric node for a conversion metric.""" @@ -375,7 +379,7 @@ def _build_conversion_metric_output_node( base_measure_spec=base_measure, conversion_measure_spec=conversion_measure, queried_linkable_specs=queried_linkable_specs, - time_range_constraint=time_range_constraint, + predicate_pushdown_params=predicate_pushdown_params, entity_spec=entity_spec, window=conversion_type_params.window, constant_properties=conversion_type_params.constant_properties, @@ -393,7 +397,7 @@ def _build_base_metric_output_node( metric_spec: MetricSpec, queried_linkable_specs: LinkableSpecSet, filter_spec_factory: WhereSpecFactory, - time_range_constraint: Optional[TimeRangeConstraint] = None, + predicate_pushdown_params: PredicatePushdownParameters, for_group_by_source_node: bool = False, ) -> ComputeMetricsNode: """Builds a node to compute a metric that is not defined from other metrics.""" @@ -438,7 +442,7 @@ def _build_base_metric_output_node( aggregated_measures_node = self.build_aggregated_measure( metric_input_measure_spec=metric_input_measure_spec, queried_linkable_specs=queried_linkable_specs, - time_range_constraint=time_range_constraint, + predicate_pushdown_params=predicate_pushdown_params, ) return self.build_computed_metrics_node( metric_spec=metric_spec, @@ -452,7 +456,7 @@ def _build_derived_metric_output_node( metric_spec: MetricSpec, queried_linkable_specs: LinkableSpecSet, filter_spec_factory: WhereSpecFactory, - time_range_constraint: Optional[TimeRangeConstraint] = None, + predicate_pushdown_params: PredicatePushdownParameters, for_group_by_source_node: bool = False, ) -> DataflowPlanNode: """Builds a node to compute a metric defined from other metrics.""" @@ -488,6 +492,13 @@ def _build_derived_metric_output_node( if not metric_spec.has_time_offset: filter_specs.extend(metric_spec.filter_specs) + # TODO - Pushdown: use parameters to disable pushdown operations instead of clobbering the constraints + metric_pushdown_params = ( + predicate_pushdown_params + if not metric_spec.has_time_offset + else PredicatePushdownParameters(time_range_constraint=None) + ) + parent_nodes.append( self._build_any_metric_output_node( metric_spec=MetricSpec( @@ -501,7 +512,7 @@ def _build_derived_metric_output_node( queried_linkable_specs if not metric_spec.has_time_offset else required_linkable_specs ), filter_spec_factory=filter_spec_factory, - time_range_constraint=time_range_constraint if not metric_spec.has_time_offset else None, + predicate_pushdown_params=metric_pushdown_params, ) ) @@ -527,7 +538,7 @@ def _build_derived_metric_output_node( parent_node=output_node, requested_agg_time_dimension_specs=queried_agg_time_dimension_specs, use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time, - time_range_constraint=time_range_constraint, + time_range_constraint=predicate_pushdown_params.time_range_constraint, offset_window=metric_spec.offset_window, offset_to_grain=metric_spec.offset_to_grain, join_type=SqlJoinType.INNER, @@ -550,7 +561,7 @@ def _build_any_metric_output_node( metric_spec: MetricSpec, queried_linkable_specs: LinkableSpecSet, filter_spec_factory: WhereSpecFactory, - time_range_constraint: Optional[TimeRangeConstraint] = None, + predicate_pushdown_params: PredicatePushdownParameters, for_group_by_source_node: bool = False, ) -> DataflowPlanNode: """Builds a node to compute a metric of any type.""" @@ -561,7 +572,7 @@ def _build_any_metric_output_node( metric_spec=metric_spec, queried_linkable_specs=queried_linkable_specs, filter_spec_factory=filter_spec_factory, - time_range_constraint=time_range_constraint, + predicate_pushdown_params=predicate_pushdown_params, for_group_by_source_node=for_group_by_source_node, ) @@ -570,7 +581,7 @@ def _build_any_metric_output_node( metric_spec=metric_spec, queried_linkable_specs=queried_linkable_specs, filter_spec_factory=filter_spec_factory, - time_range_constraint=time_range_constraint, + predicate_pushdown_params=predicate_pushdown_params, for_group_by_source_node=for_group_by_source_node, ) elif metric.type is MetricType.CONVERSION: @@ -578,7 +589,7 @@ def _build_any_metric_output_node( metric_spec=metric_spec, queried_linkable_specs=queried_linkable_specs, filter_spec_factory=filter_spec_factory, - time_range_constraint=time_range_constraint, + predicate_pushdown_params=predicate_pushdown_params, for_group_by_source_node=for_group_by_source_node, ) @@ -589,7 +600,7 @@ def _build_metrics_output_node( metric_specs: Sequence[MetricSpec], queried_linkable_specs: LinkableSpecSet, filter_spec_factory: WhereSpecFactory, - time_range_constraint: Optional[TimeRangeConstraint] = None, + predicate_pushdown_params: PredicatePushdownParameters, for_group_by_source_node: bool = False, ) -> DataflowPlanNode: """Builds a node that computes all requested metrics. @@ -599,7 +610,8 @@ def _build_metrics_output_node( include offsets and filters. queried_linkable_specs: Dimensions/entities that were queried. filter_spec_factory: Constructs WhereFilterSpecs with the resolved ambiguous group-by-items in the filter. - time_range_constraint: Time range constraint used to compute the metric. + predicate_pushdown_params: Parameters for evaluating and applying filter predicate pushdown, e.g., for + applying time constraints prior to other dimension joins. """ output_nodes: List[DataflowPlanNode] = [] @@ -612,7 +624,7 @@ def _build_metrics_output_node( metric_spec=metric_spec, queried_linkable_specs=queried_linkable_specs, filter_spec_factory=filter_spec_factory, - time_range_constraint=time_range_constraint, + predicate_pushdown_params=predicate_pushdown_params, for_group_by_source_node=for_group_by_source_node, ) ) @@ -652,8 +664,9 @@ def _build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> Da required_linkable_specs, _ = self.__get_required_and_extraneous_linkable_specs( queried_linkable_specs=query_spec.linkable_specs, filter_specs=query_level_filter_specs ) + predicate_pushdown_params = PredicatePushdownParameters(time_range_constraint=query_spec.time_range_constraint) dataflow_recipe = self._find_dataflow_recipe( - linkable_spec_set=required_linkable_specs, time_range_constraint=query_spec.time_range_constraint + linkable_spec_set=required_linkable_specs, predicate_pushdown_params=predicate_pushdown_params ) if not dataflow_recipe: raise UnableToSatisfyQueryError(f"Unable to join all items in request: {required_linkable_specs}") @@ -818,8 +831,8 @@ def _build_measure_spec_properties(self, measure_specs: Sequence[MeasureSpec]) - def _find_dataflow_recipe( self, linkable_spec_set: LinkableSpecSet, + predicate_pushdown_params: PredicatePushdownParameters, measure_spec_properties: Optional[MeasureSpecProperties] = None, - time_range_constraint: Optional[TimeRangeConstraint] = None, ) -> Optional[DataflowRecipe]: linkable_specs = linkable_spec_set.as_tuple candidate_nodes_for_left_side_of_join: List[DataflowPlanNode] = [] @@ -853,12 +866,13 @@ def _find_dataflow_recipe( semantic_model_lookup=self._semantic_model_lookup, node_data_set_resolver=self._node_data_set_resolver, ) - if time_range_constraint: + # TODO - Pushdown: Encapsulate this in the node processor + if predicate_pushdown_params.time_range_constraint: candidate_nodes_for_left_side_of_join = list( node_processor.add_time_range_constraint( source_nodes=candidate_nodes_for_left_side_of_join, metric_time_dimension_reference=self._metric_time_dimension_reference, - time_range_constraint=time_range_constraint, + time_range_constraint=predicate_pushdown_params.time_range_constraint, ) ) @@ -1179,7 +1193,7 @@ def build_aggregated_measure( self, metric_input_measure_spec: MetricInputMeasureSpec, queried_linkable_specs: LinkableSpecSet, - time_range_constraint: Optional[TimeRangeConstraint] = None, + predicate_pushdown_params: PredicatePushdownParameters, measure_recipe: Optional[DataflowRecipe] = None, ) -> DataflowPlanNode: """Returns a node where the measures are aggregated by the linkable specs and constrained appropriately. @@ -1199,7 +1213,7 @@ def build_aggregated_measure( return self._build_aggregated_measure_from_measure_source_node( metric_input_measure_spec=metric_input_measure_spec, queried_linkable_specs=queried_linkable_specs, - time_range_constraint=time_range_constraint, + predicate_pushdown_params=predicate_pushdown_params, measure_recipe=measure_recipe, ) @@ -1231,7 +1245,7 @@ def _build_aggregated_measure_from_measure_source_node( self, metric_input_measure_spec: MetricInputMeasureSpec, queried_linkable_specs: LinkableSpecSet, - time_range_constraint: Optional[TimeRangeConstraint] = None, + predicate_pushdown_params: PredicatePushdownParameters, measure_recipe: Optional[DataflowRecipe] = None, ) -> DataflowPlanNode: measure_spec = metric_input_measure_spec.measure_spec @@ -1250,8 +1264,8 @@ def _build_aggregated_measure_from_measure_source_node( non_additive_dimension_spec = measure_properties.non_additive_dimension_spec cumulative_metric_adjusted_time_constraint: Optional[TimeRangeConstraint] = None - if cumulative and time_range_constraint is not None: - logger.info(f"Time range constraint before adjustment is {time_range_constraint}") + if cumulative and predicate_pushdown_params.time_range_constraint is not None: + logger.info(f"Time range constraint before adjustment is {predicate_pushdown_params.time_range_constraint}") granularity: Optional[TimeGranularity] = None count = 0 if cumulative_window is not None: @@ -1262,7 +1276,9 @@ def _build_aggregated_measure_from_measure_source_node( granularity = cumulative_grain_to_date cumulative_metric_adjusted_time_constraint = ( - time_range_constraint.adjust_time_constraint_for_cumulative_metric(granularity, count) + predicate_pushdown_params.time_range_constraint.adjust_time_constraint_for_cumulative_metric( + granularity, count + ) ) logger.info(f"Adjusted time range constraint {cumulative_metric_adjusted_time_constraint}") @@ -1283,15 +1299,18 @@ def _build_aggregated_measure_from_measure_source_node( + indent(f"\nevaluation:\n{mf_pformat(required_linkable_specs)}") ) + # TODO - Pushdown: Update this to be more robust to additional pushdown parameters + measure_time_constraint = ( + (cumulative_metric_adjusted_time_constraint or predicate_pushdown_params.time_range_constraint) + # If joining to time spine for time offset, constraints will be applied after that join. + if not before_aggregation_time_spine_join_description + else None + ) + measure_pushdown_params = PredicatePushdownParameters(time_range_constraint=measure_time_constraint) find_recipe_start_time = time.time() measure_recipe = self._find_dataflow_recipe( measure_spec_properties=measure_properties, - time_range_constraint=( - (cumulative_metric_adjusted_time_constraint or time_range_constraint) - # If joining to time spine for time offset, constraints will be applied after that join. - if not before_aggregation_time_spine_join_description - else None - ), + predicate_pushdown_params=measure_pushdown_params, linkable_spec_set=required_linkable_specs, ) logger.info( @@ -1323,8 +1342,12 @@ def _build_aggregated_measure_from_measure_source_node( time_dimension_spec_for_join=agg_time_dimension_spec_for_join, window=cumulative_window, grain_to_date=cumulative_grain_to_date, + # Note: we use the original constraint here because the JoinOverTimeRangeNode will eventually get + # rendered with an interval that expands the join window time_range_constraint=( - time_range_constraint if not before_aggregation_time_spine_join_description else None + predicate_pushdown_params.time_range_constraint + if not before_aggregation_time_spine_join_description + else None ), ) @@ -1339,11 +1362,13 @@ def _build_aggregated_measure_from_measure_source_node( f"Expected {SqlJoinType.INNER} for joining to time spine before aggregation. Remove this if there's a " f"new use case." ) + # This also uses the original time range constraint due to the application of the time window intervals + # in join rendering join_to_time_spine_node = JoinToTimeSpineNode( parent_node=time_range_node or measure_recipe.source_node, requested_agg_time_dimension_specs=queried_agg_time_dimension_specs, use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time, - time_range_constraint=time_range_constraint, + time_range_constraint=predicate_pushdown_params.time_range_constraint, offset_window=before_aggregation_time_spine_join_description.offset_window, offset_to_grain=before_aggregation_time_spine_join_description.offset_to_grain, join_type=before_aggregation_time_spine_join_description.join_type, @@ -1379,12 +1404,16 @@ def _build_aggregated_measure_from_measure_source_node( # If time constraint was previously adjusted for cumulative window or grain, apply original time constraint # here. Can skip if metric is being aggregated over all time. cumulative_metric_constrained_node: Optional[ConstrainTimeRangeNode] = None - if cumulative_metric_adjusted_time_constraint is not None and time_range_constraint is not None: + # TODO - Pushdown: Encapsulate all of this window sliding bookkeeping in the pushdown params object + if ( + cumulative_metric_adjusted_time_constraint is not None + and predicate_pushdown_params.time_range_constraint is not None + ): assert ( queried_linkable_specs.contains_metric_time ), "Using time constraints currently requires querying with metric_time." cumulative_metric_constrained_node = ConstrainTimeRangeNode( - unaggregated_measure_node, time_range_constraint + unaggregated_measure_node, predicate_pushdown_params.time_range_constraint ) pre_aggregate_node: DataflowPlanNode = cumulative_metric_constrained_node or unaggregated_measure_node @@ -1448,16 +1477,16 @@ def _build_aggregated_measure_from_measure_source_node( requested_agg_time_dimension_specs=queried_agg_time_dimension_specs, use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time, join_type=after_aggregation_time_spine_join_description.join_type, - time_range_constraint=time_range_constraint, + time_range_constraint=predicate_pushdown_params.time_range_constraint, offset_window=after_aggregation_time_spine_join_description.offset_window, offset_to_grain=after_aggregation_time_spine_join_description.offset_to_grain, ) # Since new rows might have been added due to time spine join, apply constraints again here. if len(metric_input_measure_spec.filter_specs) > 0: output_node = WhereConstraintNode(parent_node=output_node, where_constraint=merged_where_filter_spec) - if time_range_constraint is not None: + if predicate_pushdown_params.time_range_constraint is not None: output_node = ConstrainTimeRangeNode( - parent_node=output_node, time_range_constraint=time_range_constraint + parent_node=output_node, time_range_constraint=predicate_pushdown_params.time_range_constraint ) return output_node diff --git a/metricflow/plan_conversion/node_processor.py b/metricflow/plan_conversion/node_processor.py index 88904e3cbc..e21649e110 100644 --- a/metricflow/plan_conversion/node_processor.py +++ b/metricflow/plan_conversion/node_processor.py @@ -59,6 +59,16 @@ class MultiHopJoinCandidate: lineage: MultiHopJoinCandidateLineage +@dataclass(frozen=True) +class PredicatePushdownParameters: + """Container class for managing filter predicate pushdown. + + Stores time constraint information for applying pre-join time filters. + """ + + time_range_constraint: Optional[TimeRangeConstraint] + + class PreJoinNodeProcessor: """Processes source nodes before other nodes are joined.