diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index a723d53bf..116930a20 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -104,8 +104,8 @@ ) from metricflow.protocols.sql_client import SqlEngine from metricflow.sql.optimizer.optimization_levels import ( + SqlQueryGenerationOptionSet, SqlQueryOptimizationLevel, - SqlQueryOptimizerConfiguration, ) from metricflow.sql.sql_exprs import ( SqlAggregateFunctionExpression, @@ -181,20 +181,23 @@ def convert_to_sql_query_plan( sql_query_plan_id: Optional[DagId] = None, ) -> ConvertToSqlPlanResult: """Create an SQL query plan that represents the computation up to the given dataflow plan node.""" - to_sql_visitor = DataflowNodeToSqlSubqueryVisitor( + # TODO: Handle generation with CTE. + to_sql_subquery_visitor = DataflowNodeToSqlSubqueryVisitor( column_association_resolver=self.column_association_resolver, semantic_manifest_lookup=self._semantic_manifest_lookup, ) - data_set = dataflow_plan_node.accept(to_sql_visitor) + data_set = dataflow_plan_node.accept(to_sql_subquery_visitor) sql_node: SqlQueryPlanNode = data_set.sql_node # TODO: Make this a more generally accessible attribute instead of checking against the # BigQuery-ness of the engine use_column_alias_in_group_by = sql_engine_type is SqlEngine.BIGQUERY - for optimizer in SqlQueryOptimizerConfiguration.optimizers_for_level( + option_set = SqlQueryGenerationOptionSet.options_for_level( optimization_level, use_column_alias_in_group_by=use_column_alias_in_group_by - ): + ) + + for optimizer in option_set.optimizers: logger.debug(LazyFormat(lambda: f"Applying optimizer: {optimizer.__class__.__name__}")) sql_node = optimizer.optimize(sql_node) logger.debug( diff --git a/metricflow/sql/optimizer/optimization_levels.py b/metricflow/sql/optimizer/optimization_levels.py index 2355ca53f..067019eb9 100644 --- a/metricflow/sql/optimizer/optimization_levels.py +++ b/metricflow/sql/optimizer/optimization_levels.py @@ -1,7 +1,10 @@ from __future__ import annotations +from dataclasses import dataclass from enum import Enum -from typing import Sequence +from typing import Tuple + +from dbt_semantic_interfaces.enum_extension import assert_values_exhausted from metricflow.sql.optimizer.column_pruner import SqlColumnPrunerOptimizer from metricflow.sql.optimizer.rewriting_sub_query_reducer import SqlRewritingSubQueryReducer @@ -18,27 +21,49 @@ class SqlQueryOptimizationLevel(Enum): O2 = "O2" O3 = "O3" O4 = "O4" + O5 = "O5" + +@dataclass(frozen=True) +class SqlQueryGenerationOptionSet: + """Defines the different SQL generation optimizers / options that should be used at each level.""" -class SqlQueryOptimizerConfiguration: - """Defines the different optimizers that should be used at each level.""" + optimizers: Tuple[SqlQueryPlanOptimizer, ...] + + # Specifies whether CTEs can be used to simplify generated SQL. + allow_cte: bool @staticmethod - def optimizers_for_level( + def options_for_level( # noqa: D102 level: SqlQueryOptimizationLevel, use_column_alias_in_group_by: bool - ) -> Sequence[SqlQueryPlanOptimizer]: - """Return the optimizers that should be applied (in order) for each level.""" + ) -> SqlQueryGenerationOptionSet: + optimizers: Tuple[SqlQueryPlanOptimizer, ...] = () + allow_cte = False if level is SqlQueryOptimizationLevel.O0: - return () + pass elif level is SqlQueryOptimizationLevel.O1: - return (SqlTableAliasSimplifier(),) + optimizers = (SqlTableAliasSimplifier(),) elif level is SqlQueryOptimizationLevel.O2: - return (SqlColumnPrunerOptimizer(), SqlTableAliasSimplifier()) + optimizers = (SqlColumnPrunerOptimizer(), SqlTableAliasSimplifier()) elif level is SqlQueryOptimizationLevel.O3: - return (SqlColumnPrunerOptimizer(), SqlSubQueryReducer(), SqlTableAliasSimplifier()) + optimizers = (SqlColumnPrunerOptimizer(), SqlSubQueryReducer(), SqlTableAliasSimplifier()) elif level is SqlQueryOptimizationLevel.O4: - return ( + optimizers = ( + SqlColumnPrunerOptimizer(), + SqlRewritingSubQueryReducer(use_column_alias_in_group_bys=use_column_alias_in_group_by), + SqlTableAliasSimplifier(), + ) + elif level is SqlQueryOptimizationLevel.O5: + optimizers = ( SqlColumnPrunerOptimizer(), SqlRewritingSubQueryReducer(use_column_alias_in_group_bys=use_column_alias_in_group_by), SqlTableAliasSimplifier(), ) + allow_cte = True + else: + assert_values_exhausted(level) + + return SqlQueryGenerationOptionSet( + optimizers=optimizers, + allow_cte=allow_cte, + )