Skip to content

Commit

Permalink
/* PR_START p--cte 16 */ Add option to control CTE generation.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Nov 13, 2024
1 parent 3ea7d29 commit 0bff032
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 16 deletions.
13 changes: 8 additions & 5 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@
)
from metricflow.protocols.sql_client import SqlEngine
from metricflow.sql.optimizer.optimization_levels import (
SqlQueryGenerationOptionSet,
SqlQueryOptimizationLevel,
SqlQueryOptimizerConfiguration,
)
from metricflow.sql.sql_exprs import (
SqlAggregateFunctionExpression,
Expand Down Expand Up @@ -181,20 +181,23 @@ def convert_to_sql_query_plan(
sql_query_plan_id: Optional[DagId] = None,
) -> ConvertToSqlPlanResult:
"""Create an SQL query plan that represents the computation up to the given dataflow plan node."""
to_sql_visitor = DataflowNodeToSqlSubqueryVisitor(
# TODO: Handle generation with CTEs.
to_sql_subquery_visitor = DataflowNodeToSqlSubqueryVisitor(
column_association_resolver=self.column_association_resolver,
semantic_manifest_lookup=self._semantic_manifest_lookup,
)
data_set = dataflow_plan_node.accept(to_sql_visitor)
data_set = dataflow_plan_node.accept(to_sql_subquery_visitor)

sql_node: SqlQueryPlanNode = data_set.sql_node
# TODO: Expose this as a generally accessible engine attribute instead of
# special-casing BigQuery here.
use_column_alias_in_group_by = sql_engine_type is SqlEngine.BIGQUERY

for optimizer in SqlQueryOptimizerConfiguration.optimizers_for_level(
option_set = SqlQueryGenerationOptionSet.options_for_level(
optimization_level, use_column_alias_in_group_by=use_column_alias_in_group_by
):
)

for optimizer in option_set.optimizers:
logger.debug(LazyFormat(lambda: f"Applying optimizer: {optimizer.__class__.__name__}"))
sql_node = optimizer.optimize(sql_node)
logger.debug(
Expand Down
47 changes: 36 additions & 11 deletions metricflow/sql/optimizer/optimization_levels.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Sequence
from typing import Tuple

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted

from metricflow.sql.optimizer.column_pruner import SqlColumnPrunerOptimizer
from metricflow.sql.optimizer.rewriting_sub_query_reducer import SqlRewritingSubQueryReducer
Expand All @@ -18,27 +21,49 @@ class SqlQueryOptimizationLevel(Enum):
O2 = "O2"
O3 = "O3"
O4 = "O4"
O5 = "O5"


@dataclass(frozen=True)
class SqlQueryGenerationOptionSet:
"""Defines the different SQL generation optimizers / options that should be used at each level."""

class SqlQueryOptimizerConfiguration:
"""Defines the different optimizers that should be used at each level."""
optimizers: Tuple[SqlQueryPlanOptimizer, ...]

# Specifies whether CTEs can be used to simplify generated SQL.
allow_cte: bool

@staticmethod
def optimizers_for_level(
def options_for_level( # noqa: D102
level: SqlQueryOptimizationLevel, use_column_alias_in_group_by: bool
) -> Sequence[SqlQueryPlanOptimizer]:
"""Return the optimizers that should be applied (in order) for each level."""
) -> SqlQueryGenerationOptionSet:
optimizers: Tuple[SqlQueryPlanOptimizer, ...] = ()
allow_cte = False
if level is SqlQueryOptimizationLevel.O0:
return ()
pass
elif level is SqlQueryOptimizationLevel.O1:
return (SqlTableAliasSimplifier(),)
optimizers = (SqlTableAliasSimplifier(),)
elif level is SqlQueryOptimizationLevel.O2:
return (SqlColumnPrunerOptimizer(), SqlTableAliasSimplifier())
optimizers = (SqlColumnPrunerOptimizer(), SqlTableAliasSimplifier())
elif level is SqlQueryOptimizationLevel.O3:
return (SqlColumnPrunerOptimizer(), SqlSubQueryReducer(), SqlTableAliasSimplifier())
optimizers = (SqlColumnPrunerOptimizer(), SqlSubQueryReducer(), SqlTableAliasSimplifier())
elif level is SqlQueryOptimizationLevel.O4:
return (
optimizers = (
SqlColumnPrunerOptimizer(),
SqlRewritingSubQueryReducer(use_column_alias_in_group_bys=use_column_alias_in_group_by),
SqlTableAliasSimplifier(),
)
elif level is SqlQueryOptimizationLevel.O5:
optimizers = (
SqlColumnPrunerOptimizer(),
SqlRewritingSubQueryReducer(use_column_alias_in_group_bys=use_column_alias_in_group_by),
SqlTableAliasSimplifier(),
)
allow_cte = True
else:
assert_values_exhausted(level)

return SqlQueryGenerationOptionSet(
optimizers=optimizers,
allow_cte=allow_cte,
)

0 comments on commit 0bff032

Please sign in to comment.