From 161d81b52ce49ecf9d441bd4bc050af76185d095 Mon Sep 17 00:00:00 2001 From: tlento Date: Thu, 16 May 2024 18:41:02 -0700 Subject: [PATCH] Update pushdown_params -> pushdown_state --- .../dataflow/builder/dataflow_plan_builder.py | 98 +++++++++---------- metricflow/plan_conversion/node_processor.py | 8 +- .../builder/test_predicate_pushdown.py | 10 +- 3 files changed, 57 insertions(+), 59 deletions(-) diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index 0fa4ae9016..568ac9e81d 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -181,7 +181,7 @@ def _build_query_output_node( ) ) - predicate_pushdown_params = PredicatePushdownState(time_range_constraint=query_spec.time_range_constraint) + predicate_pushdown_state = PredicatePushdownState(time_range_constraint=query_spec.time_range_constraint) return self._build_metrics_output_node( metric_specs=tuple( @@ -193,7 +193,7 @@ def _build_query_output_node( ), queried_linkable_specs=query_spec.linkable_specs, filter_spec_factory=filter_spec_factory, - predicate_pushdown_params=predicate_pushdown_params, + predicate_pushdown_state=predicate_pushdown_state, for_group_by_source_node=for_group_by_source_node, ) @@ -237,7 +237,7 @@ def _build_aggregated_conversion_node( entity_spec: EntitySpec, window: Optional[MetricTimeWindow], queried_linkable_specs: LinkableSpecSet, - predicate_pushdown_params: PredicatePushdownState, + predicate_pushdown_state: PredicatePushdownState, constant_properties: Optional[Sequence[ConstantPropertyInput]] = None, ) -> DataflowPlanNode: """Builds a node that contains aggregated values of conversions and opportunities.""" @@ -247,7 +247,7 @@ def _build_aggregated_conversion_node( # implementation. disabled_pushdown_parameters = PredicatePushdownState.with_pushdown_disabled() time_range_only_pushdown_parameters = PredicatePushdownState( - time_range_constraint=predicate_pushdown_params.time_range_constraint, + time_range_constraint=predicate_pushdown_state.time_range_constraint, pushdown_enabled_types=frozenset([PredicateInputType.TIME_RANGE_CONSTRAINT]), ) @@ -258,13 +258,13 @@ def _build_aggregated_conversion_node( ) base_measure_recipe = self._find_dataflow_recipe( measure_spec_properties=self._build_measure_spec_properties([base_measure_spec.measure_spec]), - predicate_pushdown_params=time_range_only_pushdown_parameters, + predicate_pushdown_state=time_range_only_pushdown_parameters, linkable_spec_set=base_required_linkable_specs, ) logger.info(f"Recipe for base measure aggregation:\n{mf_pformat(base_measure_recipe)}") conversion_measure_recipe = self._find_dataflow_recipe( measure_spec_properties=self._build_measure_spec_properties([conversion_measure_spec.measure_spec]), - predicate_pushdown_params=disabled_pushdown_parameters, + predicate_pushdown_state=disabled_pushdown_parameters, linkable_spec_set=LinkableSpecSet(), ) logger.info(f"Recipe for conversion measure aggregation:\n{mf_pformat(conversion_measure_recipe)}") @@ -281,7 +281,7 @@ def _build_aggregated_conversion_node( aggregated_base_measure_node = self.build_aggregated_measure( metric_input_measure_spec=base_measure_spec, queried_linkable_specs=queried_linkable_specs, - predicate_pushdown_params=time_range_only_pushdown_parameters, + predicate_pushdown_state=time_range_only_pushdown_parameters, ) # Build unaggregated conversions source node @@ -357,7 +357,7 @@ def _build_aggregated_conversion_node( aggregated_conversions_node = self.build_aggregated_measure( metric_input_measure_spec=conversion_measure_spec, queried_linkable_specs=queried_linkable_specs, - predicate_pushdown_params=disabled_pushdown_parameters, + predicate_pushdown_state=disabled_pushdown_parameters, measure_recipe=recipe_with_join_conversion_source_node, ) @@ -369,7 +369,7 @@ def _build_conversion_metric_output_node( metric_spec: MetricSpec, queried_linkable_specs: LinkableSpecSet, filter_spec_factory: WhereSpecFactory, - predicate_pushdown_params: PredicatePushdownState, + predicate_pushdown_state: PredicatePushdownState, for_group_by_source_node: bool = False, ) -> ComputeMetricsNode: """Builds a compute metric node for a conversion metric.""" @@ -396,7 +396,7 @@ def _build_conversion_metric_output_node( base_measure_spec=base_measure, conversion_measure_spec=conversion_measure, queried_linkable_specs=queried_linkable_specs, - predicate_pushdown_params=predicate_pushdown_params, + predicate_pushdown_state=predicate_pushdown_state, entity_spec=entity_spec, window=conversion_type_params.window, constant_properties=conversion_type_params.constant_properties, @@ -414,7 +414,7 @@ def _build_base_metric_output_node( metric_spec: MetricSpec, queried_linkable_specs: LinkableSpecSet, filter_spec_factory: WhereSpecFactory, - predicate_pushdown_params: PredicatePushdownState, + predicate_pushdown_state: PredicatePushdownState, for_group_by_source_node: bool = False, ) -> ComputeMetricsNode: """Builds a node to compute a metric that is not defined from other metrics.""" @@ -459,7 +459,7 @@ def _build_base_metric_output_node( aggregated_measures_node = self.build_aggregated_measure( metric_input_measure_spec=metric_input_measure_spec, queried_linkable_specs=queried_linkable_specs, - predicate_pushdown_params=predicate_pushdown_params, + predicate_pushdown_state=predicate_pushdown_state, ) return self.build_computed_metrics_node( metric_spec=metric_spec, @@ -473,7 +473,7 @@ def _build_derived_metric_output_node( metric_spec: MetricSpec, queried_linkable_specs: LinkableSpecSet, filter_spec_factory: WhereSpecFactory, - predicate_pushdown_params: PredicatePushdownState, + predicate_pushdown_state: PredicatePushdownState, for_group_by_source_node: bool = False, ) -> DataflowPlanNode: """Builds a node to compute a metric defined from other metrics.""" @@ -509,8 +509,8 @@ def _build_derived_metric_output_node( if not metric_spec.has_time_offset: filter_specs.extend(metric_spec.filter_specs) - metric_pushdown_params = ( - predicate_pushdown_params + metric_pushdown_state = ( + predicate_pushdown_state if not metric_spec.has_time_offset else PredicatePushdownState.with_pushdown_disabled() ) @@ -528,7 +528,7 @@ def _build_derived_metric_output_node( queried_linkable_specs if not metric_spec.has_time_offset else required_linkable_specs ), filter_spec_factory=filter_spec_factory, - predicate_pushdown_params=metric_pushdown_params, + predicate_pushdown_state=metric_pushdown_state, ) ) @@ -554,7 +554,7 @@ def _build_derived_metric_output_node( parent_node=output_node, requested_agg_time_dimension_specs=queried_agg_time_dimension_specs, use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time, - time_range_constraint=predicate_pushdown_params.time_range_constraint, + time_range_constraint=predicate_pushdown_state.time_range_constraint, offset_window=metric_spec.offset_window, offset_to_grain=metric_spec.offset_to_grain, join_type=SqlJoinType.INNER, @@ -577,7 +577,7 @@ def _build_any_metric_output_node( metric_spec: MetricSpec, queried_linkable_specs: LinkableSpecSet, filter_spec_factory: WhereSpecFactory, - predicate_pushdown_params: PredicatePushdownState, + predicate_pushdown_state: PredicatePushdownState, for_group_by_source_node: bool = False, ) -> DataflowPlanNode: """Builds a node to compute a metric of any type.""" @@ -588,7 +588,7 @@ def _build_any_metric_output_node( metric_spec=metric_spec, queried_linkable_specs=queried_linkable_specs, filter_spec_factory=filter_spec_factory, - predicate_pushdown_params=predicate_pushdown_params, + predicate_pushdown_state=predicate_pushdown_state, for_group_by_source_node=for_group_by_source_node, ) @@ -597,7 +597,7 @@ def _build_any_metric_output_node( metric_spec=metric_spec, queried_linkable_specs=queried_linkable_specs, filter_spec_factory=filter_spec_factory, - predicate_pushdown_params=predicate_pushdown_params, + predicate_pushdown_state=predicate_pushdown_state, for_group_by_source_node=for_group_by_source_node, ) elif metric.type is MetricType.CONVERSION: @@ -605,7 +605,7 @@ def _build_any_metric_output_node( metric_spec=metric_spec, queried_linkable_specs=queried_linkable_specs, filter_spec_factory=filter_spec_factory, - predicate_pushdown_params=predicate_pushdown_params, + predicate_pushdown_state=predicate_pushdown_state, for_group_by_source_node=for_group_by_source_node, ) @@ -616,7 +616,7 @@ def _build_metrics_output_node( metric_specs: Sequence[MetricSpec], queried_linkable_specs: LinkableSpecSet, filter_spec_factory: WhereSpecFactory, - predicate_pushdown_params: PredicatePushdownState, + predicate_pushdown_state: PredicatePushdownState, for_group_by_source_node: bool = False, ) -> DataflowPlanNode: """Builds a node that computes all requested metrics. @@ -626,7 +626,7 @@ def _build_metrics_output_node( include offsets and filters. queried_linkable_specs: Dimensions/entities that were queried. filter_spec_factory: Constructs WhereFilterSpecs with the resolved ambiguous group-by-items in the filter. - predicate_pushdown_params: Parameters for evaluating and applying filter predicate pushdown, e.g., for + predicate_pushdown_state: Parameters for evaluating and applying filter predicate pushdown, e.g., for applying time constraints prior to other dimension joins. """ output_nodes: List[DataflowPlanNode] = [] @@ -640,7 +640,7 @@ def _build_metrics_output_node( metric_spec=metric_spec, queried_linkable_specs=queried_linkable_specs, filter_spec_factory=filter_spec_factory, - predicate_pushdown_params=predicate_pushdown_params, + predicate_pushdown_state=predicate_pushdown_state, for_group_by_source_node=for_group_by_source_node, ) ) @@ -680,9 +680,9 @@ def _build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> Da required_linkable_specs, _ = self.__get_required_and_extraneous_linkable_specs( queried_linkable_specs=query_spec.linkable_specs, filter_specs=query_level_filter_specs ) - predicate_pushdown_params = PredicatePushdownState(time_range_constraint=query_spec.time_range_constraint) + predicate_pushdown_state = PredicatePushdownState(time_range_constraint=query_spec.time_range_constraint) dataflow_recipe = self._find_dataflow_recipe( - linkable_spec_set=required_linkable_specs, predicate_pushdown_params=predicate_pushdown_params + linkable_spec_set=required_linkable_specs, predicate_pushdown_state=predicate_pushdown_state ) if not dataflow_recipe: raise UnableToSatisfyQueryError(f"Unable to join all items in request: {required_linkable_specs}") @@ -847,7 +847,7 @@ def _build_measure_spec_properties(self, measure_specs: Sequence[MeasureSpec]) - def _find_dataflow_recipe( self, linkable_spec_set: LinkableSpecSet, - predicate_pushdown_params: PredicatePushdownState, + predicate_pushdown_state: PredicatePushdownState, measure_spec_properties: Optional[MeasureSpecProperties] = None, ) -> Optional[DataflowRecipe]: linkable_specs = linkable_spec_set.as_tuple @@ -884,14 +884,14 @@ def _find_dataflow_recipe( ) # TODO - Pushdown: Encapsulate this in the node processor if ( - predicate_pushdown_params.is_pushdown_enabled_for_time_range_constraint - and predicate_pushdown_params.time_range_constraint + predicate_pushdown_state.is_pushdown_enabled_for_time_range_constraint + and predicate_pushdown_state.time_range_constraint ): candidate_nodes_for_left_side_of_join = list( node_processor.add_time_range_constraint( source_nodes=candidate_nodes_for_left_side_of_join, metric_time_dimension_reference=self._metric_time_dimension_reference, - time_range_constraint=predicate_pushdown_params.time_range_constraint, + time_range_constraint=predicate_pushdown_state.time_range_constraint, ) ) @@ -1212,7 +1212,7 @@ def build_aggregated_measure( self, metric_input_measure_spec: MetricInputMeasureSpec, queried_linkable_specs: LinkableSpecSet, - predicate_pushdown_params: PredicatePushdownState, + predicate_pushdown_state: PredicatePushdownState, measure_recipe: Optional[DataflowRecipe] = None, ) -> DataflowPlanNode: """Returns a node where the measures are aggregated by the linkable specs and constrained appropriately. @@ -1232,7 +1232,7 @@ def build_aggregated_measure( return self._build_aggregated_measure_from_measure_source_node( metric_input_measure_spec=metric_input_measure_spec, queried_linkable_specs=queried_linkable_specs, - predicate_pushdown_params=predicate_pushdown_params, + predicate_pushdown_state=predicate_pushdown_state, measure_recipe=measure_recipe, ) @@ -1264,7 +1264,7 @@ def _build_aggregated_measure_from_measure_source_node( self, metric_input_measure_spec: MetricInputMeasureSpec, queried_linkable_specs: LinkableSpecSet, - predicate_pushdown_params: PredicatePushdownState, + predicate_pushdown_state: PredicatePushdownState, measure_recipe: Optional[DataflowRecipe] = None, ) -> DataflowPlanNode: measure_spec = metric_input_measure_spec.measure_spec @@ -1283,8 +1283,8 @@ def _build_aggregated_measure_from_measure_source_node( non_additive_dimension_spec = measure_properties.non_additive_dimension_spec cumulative_metric_adjusted_time_constraint: Optional[TimeRangeConstraint] = None - if cumulative and predicate_pushdown_params.time_range_constraint is not None: - logger.info(f"Time range constraint before adjustment is {predicate_pushdown_params.time_range_constraint}") + if cumulative and predicate_pushdown_state.time_range_constraint is not None: + logger.info(f"Time range constraint before adjustment is {predicate_pushdown_state.time_range_constraint}") granularity: Optional[TimeGranularity] = None count = 0 if cumulative_window is not None: @@ -1295,7 +1295,7 @@ def _build_aggregated_measure_from_measure_source_node( granularity = cumulative_grain_to_date cumulative_metric_adjusted_time_constraint = ( - predicate_pushdown_params.time_range_constraint.adjust_time_constraint_for_cumulative_metric( + predicate_pushdown_state.time_range_constraint.adjust_time_constraint_for_cumulative_metric( granularity, count ) ) @@ -1318,24 +1318,22 @@ def _build_aggregated_measure_from_measure_source_node( + indent(f"\nevaluation:\n{mf_pformat(required_linkable_specs)}") ) measure_time_constraint = ( - (cumulative_metric_adjusted_time_constraint or predicate_pushdown_params.time_range_constraint) + (cumulative_metric_adjusted_time_constraint or predicate_pushdown_state.time_range_constraint) # If joining to time spine for time offset, constraints will be applied after that join. if not before_aggregation_time_spine_join_description else None ) if measure_time_constraint is None: - measure_pushdown_params = PredicatePushdownState.without_time_range_constraint( - predicate_pushdown_params - ) + measure_pushdown_state = PredicatePushdownState.without_time_range_constraint(predicate_pushdown_state) else: - measure_pushdown_params = PredicatePushdownState.with_time_range_constraint( - predicate_pushdown_params, time_range_constraint=measure_time_constraint + measure_pushdown_state = PredicatePushdownState.with_time_range_constraint( + predicate_pushdown_state, time_range_constraint=measure_time_constraint ) find_recipe_start_time = time.time() measure_recipe = self._find_dataflow_recipe( measure_spec_properties=measure_properties, - predicate_pushdown_params=measure_pushdown_params, + predicate_pushdown_state=measure_pushdown_state, linkable_spec_set=required_linkable_specs, ) logger.info( @@ -1370,7 +1368,7 @@ def _build_aggregated_measure_from_measure_source_node( # Note: we use the original constraint here because the JoinOverTimeRangeNode will eventually get # rendered with an interval that expands the join window time_range_constraint=( - predicate_pushdown_params.time_range_constraint + predicate_pushdown_state.time_range_constraint if not before_aggregation_time_spine_join_description else None ), @@ -1393,7 +1391,7 @@ def _build_aggregated_measure_from_measure_source_node( parent_node=time_range_node or measure_recipe.source_node, requested_agg_time_dimension_specs=queried_agg_time_dimension_specs, use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time, - time_range_constraint=predicate_pushdown_params.time_range_constraint, + time_range_constraint=predicate_pushdown_state.time_range_constraint, offset_window=before_aggregation_time_spine_join_description.offset_window, offset_to_grain=before_aggregation_time_spine_join_description.offset_to_grain, join_type=before_aggregation_time_spine_join_description.join_type, @@ -1432,13 +1430,13 @@ def _build_aggregated_measure_from_measure_source_node( # TODO - Pushdown: Encapsulate all of this window sliding bookkeeping in the pushdown params object if ( cumulative_metric_adjusted_time_constraint is not None - and predicate_pushdown_params.time_range_constraint is not None + and predicate_pushdown_state.time_range_constraint is not None ): assert ( queried_linkable_specs.contains_metric_time ), "Using time constraints currently requires querying with metric_time." cumulative_metric_constrained_node = ConstrainTimeRangeNode( - unaggregated_measure_node, predicate_pushdown_params.time_range_constraint + unaggregated_measure_node, predicate_pushdown_state.time_range_constraint ) pre_aggregate_node: DataflowPlanNode = cumulative_metric_constrained_node or unaggregated_measure_node @@ -1502,16 +1500,16 @@ def _build_aggregated_measure_from_measure_source_node( requested_agg_time_dimension_specs=queried_agg_time_dimension_specs, use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time, join_type=after_aggregation_time_spine_join_description.join_type, - time_range_constraint=predicate_pushdown_params.time_range_constraint, + time_range_constraint=predicate_pushdown_state.time_range_constraint, offset_window=after_aggregation_time_spine_join_description.offset_window, offset_to_grain=after_aggregation_time_spine_join_description.offset_to_grain, ) # Since new rows might have been added due to time spine join, apply constraints again here. if len(metric_input_measure_spec.filter_specs) > 0: output_node = WhereConstraintNode(parent_node=output_node, where_constraint=merged_where_filter_spec) - if predicate_pushdown_params.time_range_constraint is not None: + if predicate_pushdown_state.time_range_constraint is not None: output_node = ConstrainTimeRangeNode( - parent_node=output_node, time_range_constraint=predicate_pushdown_params.time_range_constraint + parent_node=output_node, time_range_constraint=predicate_pushdown_state.time_range_constraint ) return output_node diff --git a/metricflow/plan_conversion/node_processor.py b/metricflow/plan_conversion/node_processor.py index 666553940c..6cecd50021 100644 --- a/metricflow/plan_conversion/node_processor.py +++ b/metricflow/plan_conversion/node_processor.py @@ -143,14 +143,14 @@ def is_pushdown_enabled_for_time_range_constraint(self) -> bool: @staticmethod def with_time_range_constraint( - original_pushdown_params: PredicatePushdownState, time_range_constraint: TimeRangeConstraint + original_pushdown_state: PredicatePushdownState, time_range_constraint: TimeRangeConstraint ) -> PredicatePushdownState: """Factory method for adding or updating a time range constraint input to a set of pushdown parameters. This allows for temporarily overriding a time range constraint with an adjusted one, or enabling a time range constraint filter if one becomes available mid-stream during dataflow plan construction. """ - pushdown_enabled_types = original_pushdown_params.pushdown_enabled_types.union( + pushdown_enabled_types = original_pushdown_state.pushdown_enabled_types.union( {PredicateInputType.TIME_RANGE_CONSTRAINT} ) return PredicatePushdownState( @@ -159,10 +159,10 @@ def with_time_range_constraint( @staticmethod def without_time_range_constraint( - original_pushdown_params: PredicatePushdownState, + original_pushdown_state: PredicatePushdownState, ) -> PredicatePushdownState: """Factory method for removing the time range constraint, if any, from the given set of pushdown parameters.""" - pushdown_enabled_types = original_pushdown_params.pushdown_enabled_types.difference( + pushdown_enabled_types = original_pushdown_state.pushdown_enabled_types.difference( {PredicateInputType.TIME_RANGE_CONSTRAINT} ) return PredicatePushdownState(time_range_constraint=None, pushdown_enabled_types=pushdown_enabled_types) diff --git a/tests_metricflow/dataflow/builder/test_predicate_pushdown.py b/tests_metricflow/dataflow/builder/test_predicate_pushdown.py index d82d2255b4..cf2a2cf991 100644 --- a/tests_metricflow/dataflow/builder/test_predicate_pushdown.py +++ b/tests_metricflow/dataflow/builder/test_predicate_pushdown.py @@ -7,7 +7,7 @@ @pytest.fixture -def all_pushdown_params() -> PredicatePushdownState: +def all_pushdown_state() -> PredicatePushdownState: """Tests a valid configuration with all predicate properties set and pushdown fully enabled.""" params = PredicatePushdownState( time_range_constraint=TimeRangeConstraint.all_time(), @@ -15,7 +15,7 @@ def all_pushdown_params() -> PredicatePushdownState: return params -def test_time_range_pushdown_enabled_states(all_pushdown_params: PredicatePushdownState) -> None: +def test_time_range_pushdown_enabled_states(all_pushdown_state: PredicatePushdownState) -> None: """Tests pushdown enabled check for time range pushdown operations.""" time_range_only_params = PredicatePushdownState( time_range_constraint=TimeRangeConstraint.all_time(), @@ -23,18 +23,18 @@ def test_time_range_pushdown_enabled_states(all_pushdown_params: PredicatePushdo ) enabled_states = { - "fully enabled": all_pushdown_params.is_pushdown_enabled_for_time_range_constraint, + "fully enabled": all_pushdown_state.is_pushdown_enabled_for_time_range_constraint, "enabled for time range only": time_range_only_params.is_pushdown_enabled_for_time_range_constraint, } assert all(list(enabled_states.values())), ( "Expected pushdown to be enabled for pushdown params with time range constraint and global pushdown enabled, " f"but some params returned False for is_pushdown_enabled.\nPushdown enabled states: {enabled_states}\n" - f"All params: {all_pushdown_params}\nTime range only params: {time_range_only_params}" + f"All params: {all_pushdown_state}\nTime range only params: {time_range_only_params}" ) -def test_invalid_disabled_pushdown_params() -> None: +def test_invalid_disabled_pushdown_state() -> None: """Tests checks for invalid param configuration on disabled pushdown parameters.""" with pytest.raises(AssertionError, match="Disabled pushdown parameters cannot have properties set"): PredicatePushdownState(time_range_constraint=TimeRangeConstraint.all_time(), pushdown_enabled_types=frozenset())