From 8fc896b0911715be229a66233df71e97b20724ed Mon Sep 17 00:00:00 2001 From: tlento Date: Tue, 25 Jun 2024 15:25:19 -0700 Subject: [PATCH] Enable PredicatePushdownOptimization for all MetricFlowEngine queries This effectively releases PredicatePushdownOptimization: the moment this change is deployed to cloud, it will be enabled. To allow rapid mitigation of any unexpected issues, this also parameterizes the query request object so that callers can disable optimizations as needed. This means cloud services calling this method can override the optimizer behaviors without requiring an update to MetricFlow. --- .../dataflow/optimizer/dataflow_optimizer_factory.py | 5 +++++ metricflow/engine/metricflow_engine.py | 11 ++++++++--- .../query_rendering/compare_rendered_query.py | 10 ++++------ 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/metricflow/dataflow/optimizer/dataflow_optimizer_factory.py b/metricflow/dataflow/optimizer/dataflow_optimizer_factory.py index 8ccc8d8bba..c06d02bb27 100644 --- a/metricflow/dataflow/optimizer/dataflow_optimizer_factory.py +++ b/metricflow/dataflow/optimizer/dataflow_optimizer_factory.py @@ -24,6 +24,11 @@ class DataflowPlanOptimization(Enum): SOURCE_SCAN = 0 PREDICATE_PUSHDOWN = 1 + @staticmethod + def all_optimizations() -> FrozenSet[DataflowPlanOptimization]: + """Convenience method for getting a set of all available optimizations.""" + return frozenset((DataflowPlanOptimization.SOURCE_SCAN, DataflowPlanOptimization.PREDICATE_PUSHDOWN)) + class DataflowPlanOptimizerFactory: """Factory class for initializing an enumerated set of optimizers. 
diff --git a/metricflow/engine/metricflow_engine.py b/metricflow/engine/metricflow_engine.py index ec85ba9278..24e070f0e7 100644 --- a/metricflow/engine/metricflow_engine.py +++ b/metricflow/engine/metricflow_engine.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum -from typing import List, Optional, Sequence, Tuple +from typing import FrozenSet, List, Optional, Sequence, Tuple from dbt_semantic_interfaces.implementations.elements.dimension import PydanticDimensionTypeParams from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter @@ -113,6 +113,7 @@ class MetricFlowQueryRequest: order_by: Optional[Sequence[OrderByQueryParameter]] = None min_max_only: bool = False sql_optimization_level: SqlQueryOptimizationLevel = SqlQueryOptimizationLevel.O4 + dataflow_plan_optimizations: FrozenSet[DataflowPlanOptimization] = DataflowPlanOptimization.all_optimizations() query_type: MetricFlowQueryType = MetricFlowQueryType.METRIC @staticmethod @@ -129,6 +130,7 @@ def create_with_random_request_id( # noqa: D102 order_by_names: Optional[Sequence[str]] = None, order_by: Optional[Sequence[OrderByQueryParameter]] = None, sql_optimization_level: SqlQueryOptimizationLevel = SqlQueryOptimizationLevel.O4, + dataflow_plan_optimizations: FrozenSet[DataflowPlanOptimization] = DataflowPlanOptimization.all_optimizations(), query_type: MetricFlowQueryType = MetricFlowQueryType.METRIC, min_max_only: bool = False, ) -> MetricFlowQueryRequest: @@ -146,6 +148,7 @@ def create_with_random_request_id( # noqa: D102 order_by_names=order_by_names, order_by=order_by, sql_optimization_level=sql_optimization_level, + dataflow_plan_optimizations=dataflow_plan_optimizations, query_type=query_type, min_max_only=min_max_only, ) @@ -500,10 +503,12 @@ def _create_execution_plan(self, mf_query_request: MetricFlowQueryRequest) -> Me dataflow_plan = self._dataflow_plan_builder.build_plan( query_spec=query_spec, 
output_selection_specs=output_selection_specs, - optimizations=frozenset({DataflowPlanOptimization.SOURCE_SCAN}), + optimizations=mf_query_request.dataflow_plan_optimizations, ) else: - dataflow_plan = self._dataflow_plan_builder.build_plan_for_distinct_values(query_spec=query_spec) + dataflow_plan = self._dataflow_plan_builder.build_plan_for_distinct_values( + query_spec=query_spec, optimizations=mf_query_request.dataflow_plan_optimizations + ) if len(dataflow_plan.sink_nodes) > 1: raise NotImplementedError( diff --git a/tests_metricflow/query_rendering/compare_rendered_query.py b/tests_metricflow/query_rendering/compare_rendered_query.py index d13085a928..b1651a9cac 100644 --- a/tests_metricflow/query_rendering/compare_rendered_query.py +++ b/tests_metricflow/query_rendering/compare_rendered_query.py @@ -52,16 +52,14 @@ def render_and_check( ) # Run dataflow -> sql conversion with all optimizers - optimizations = ( - DataflowPlanOptimization.SOURCE_SCAN, - DataflowPlanOptimization.PREDICATE_PUSHDOWN, - ) if is_distinct_values_plan: optimized_plan = dataflow_plan_builder.build_plan_for_distinct_values( - query_spec, optimizations=frozenset(optimizations) + query_spec, optimizations=DataflowPlanOptimization.all_optimizations() ) else: - optimized_plan = dataflow_plan_builder.build_plan(query_spec, optimizations=frozenset(optimizations)) + optimized_plan = dataflow_plan_builder.build_plan( + query_spec, optimizations=DataflowPlanOptimization.all_optimizations() + ) conversion_result = dataflow_to_sql_converter.convert_to_sql_query_plan( sql_engine_type=sql_client.sql_engine_type, dataflow_plan_node=optimized_plan.sink_node,