From 70a5563ba4c8a3a8345f4d1374c5d2e20ec9579c Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Sat, 9 Nov 2024 21:51:55 -0800 Subject: [PATCH] /* PR_START p--cte 17 */ Fix SQL optimization level not getting passed from the request. --- metricflow/engine/metricflow_engine.py | 13 +++++++------ metricflow/execution/dataflow_to_execution.py | 5 +++++ .../plan_conversion/test_dataflow_to_execution.py | 13 ++++--------- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/metricflow/engine/metricflow_engine.py b/metricflow/engine/metricflow_engine.py index d5916a11d1..319aa1a2b6 100644 --- a/metricflow/engine/metricflow_engine.py +++ b/metricflow/engine/metricflow_engine.py @@ -409,11 +409,6 @@ def __init__( column_association_resolver=self._column_association_resolver, semantic_manifest_lookup=self._semantic_manifest_lookup, ) - self._to_execution_plan_converter = DataflowToExecutionPlanConverter( - sql_plan_converter=self._to_sql_query_plan_converter, - sql_plan_renderer=self._sql_client.sql_query_plan_renderer, - sql_client=sql_client, - ) self._executor = SequentialPlanExecutor() self._query_parser = query_parser or MetricFlowQueryParser( @@ -539,7 +534,13 @@ def _create_execution_plan(self, mf_query_request: MetricFlowQueryRequest) -> Me ) logger.info(LazyFormat("Building execution plan")) - convert_to_execution_plan_result = self._to_execution_plan_converter.convert_to_execution_plan(dataflow_plan) + _to_execution_plan_converter = DataflowToExecutionPlanConverter( + sql_plan_converter=self._to_sql_query_plan_converter, + sql_plan_renderer=self._sql_client.sql_query_plan_renderer, + sql_client=self._sql_client, + sql_optimization_level=mf_query_request.sql_optimization_level, + ) + convert_to_execution_plan_result = _to_execution_plan_converter.convert_to_execution_plan(dataflow_plan) return MetricFlowExplainResult( query_spec=query_spec, dataflow_plan=dataflow_plan, diff --git a/metricflow/execution/dataflow_to_execution.py b/metricflow/execution/dataflow_to_execution.py index 3a438a9f7c..fa0518e847 100644 --- a/metricflow/execution/dataflow_to_execution.py +++ b/metricflow/execution/dataflow_to_execution.py @@ -40,6 +40,7 @@ from metricflow.plan_conversion.convert_to_sql_plan import ConvertToSqlPlanResult from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter from metricflow.protocols.sql_client import SqlClient +from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel from metricflow.sql.render.sql_plan_renderer import SqlPlanRenderResult, SqlQueryPlanRenderer logger = logging.getLogger(__name__) @@ -53,6 +54,7 @@ def __init__( sql_plan_converter: DataflowToSqlQueryPlanConverter, sql_plan_renderer: SqlQueryPlanRenderer, sql_client: SqlClient, + sql_optimization_level: SqlQueryOptimizationLevel, ) -> None: """Constructor. @@ -60,15 +62,18 @@ def __init__( sql_plan_converter: Converts a dataflow plan node to a SQL query plan sql_plan_renderer: Converts a SQL query plan to SQL text sql_client: The client to use for running queries. + sql_optimization_level: The optimization level to use for generating the SQL. """ self._sql_plan_converter = sql_plan_converter self._sql_plan_renderer = sql_plan_renderer self._sql_client = sql_client + self._optimization_level = sql_optimization_level def _convert_to_sql_plan(self, node: DataflowPlanNode) -> ConvertToSqlPlanResult: logger.debug(LazyFormat(lambda: f"Generating SQL query plan from {node.node_id}")) result = self._sql_plan_converter.convert_to_sql_query_plan( sql_engine_type=self._sql_client.sql_engine_type, + optimization_level=self._optimization_level, dataflow_plan_node=node, ) logger.debug(LazyFormat(lambda: f"Generated SQL query plan is:\n{result.sql_plan.structure_text()}")) diff --git a/tests_metricflow/plan_conversion/test_dataflow_to_execution.py b/tests_metricflow/plan_conversion/test_dataflow_to_execution.py index b92a26a2ec..c5b9503f3c 100644 --- a/tests_metricflow/plan_conversion/test_dataflow_to_execution.py +++ b/tests_metricflow/plan_conversion/test_dataflow_to_execution.py @@ -15,6 +15,7 @@ from metricflow.execution.dataflow_to_execution import DataflowToExecutionPlanConverter from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter from metricflow.protocols.sql_client import SqlClient +from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer from tests_metricflow.snapshot_utils import assert_execution_plan_text_equal @@ -30,6 +31,7 @@ def make_execution_plan_converter( # noqa: D103 ), sql_plan_renderer=DefaultSqlQueryPlanRenderer(), sql_client=sql_client, + sql_optimization_level=SqlQueryOptimizationLevel.O4, ) @@ -172,17 +174,10 @@ def test_multihop_joined_plan( ) ) - to_execution_plan_converter = DataflowToExecutionPlanConverter( - sql_plan_converter=DataflowToSqlQueryPlanConverter( - column_association_resolver=DunderColumnAssociationResolver( - partitioned_multi_hop_join_semantic_manifest_lookup - ), - semantic_manifest_lookup=partitioned_multi_hop_join_semantic_manifest_lookup, - ), - sql_plan_renderer=DefaultSqlQueryPlanRenderer(), + to_execution_plan_converter = make_execution_plan_converter( + semantic_manifest_lookup=partitioned_multi_hop_join_semantic_manifest_lookup, sql_client=sql_client, ) - execution_plan = to_execution_plan_converter.convert_to_execution_plan(dataflow_plan).execution_plan assert_execution_plan_text_equal(