diff --git a/metricflow-semantics/metricflow_semantics/dag/id_prefix.py b/metricflow-semantics/metricflow_semantics/dag/id_prefix.py index b6d1b2791b..e2be46af56 100644 --- a/metricflow-semantics/metricflow_semantics/dag/id_prefix.py +++ b/metricflow-semantics/metricflow_semantics/dag/id_prefix.py @@ -78,6 +78,7 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper): SQL_PLAN_TABLE_FROM_CLAUSE_ID_PREFIX = "tfc" SQL_PLAN_QUERY_FROM_CLAUSE_ID_PREFIX = "qfc" SQL_PLAN_CREATE_TABLE_AS_ID_PREFIX = "cta" + SQL_PLAN_COMMON_TABLE_EXPRESSION_ID_PREFIX = "cta" EXEC_NODE_READ_SQL_QUERY = "rsq" EXEC_NODE_NOOP = "noop" diff --git a/metricflow/sql/optimizer/column_pruner.py b/metricflow/sql/optimizer/column_pruner.py index 2f44a1277f..79e6d8250a 100644 --- a/metricflow/sql/optimizer/column_pruner.py +++ b/metricflow/sql/optimizer/column_pruner.py @@ -4,12 +4,15 @@ from collections import defaultdict from typing import Dict, List, Set, Tuple +from typing_extensions import override + from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer from metricflow.sql.sql_exprs import ( SqlExpressionTreeLineage, ) from metricflow.sql.sql_plan import ( SqlCreateTableAsNode, + SqlCteNode, SqlJoinDescription, SqlQueryPlanNode, SqlQueryPlanNodeVisitor, @@ -111,6 +114,10 @@ def _prune_columns_from_grandparents( distinct=node.distinct, ) + @override + def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode: + raise NotImplementedError + def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102 # Remove columns that are not needed from this SELECT statement because the parent SELECT statement doesn't # need them. However, keep columns that are in group bys because that changes the meaning of the query. diff --git a/metricflow/sql/optimizer/rewriting_sub_query_reducer.py b/metricflow/sql/optimizer/rewriting_sub_query_reducer.py index 66b66832d3..0a7f6de9d8 100644 --- a/metricflow/sql/optimizer/rewriting_sub_query_reducer.py +++ b/metricflow/sql/optimizer/rewriting_sub_query_reducer.py @@ -6,6 +6,7 @@ from metricflow_semantics.mf_logging.formatting import indent from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat +from typing_extensions import override from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer from metricflow.sql.sql_exprs import ( @@ -19,6 +20,7 @@ ) from metricflow.sql.sql_plan import ( SqlCreateTableAsNode, + SqlCteNode, SqlJoinDescription, SqlOrderByDescription, SqlQueryPlanNode, @@ -582,6 +584,10 @@ def _rewrite_node_with_join(node: SqlSelectStatementNode) -> SqlSelectStatementN distinct=node.distinct, ) + @override + def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode: + raise NotImplementedError + def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102 node_with_reduced_parents = self._reduce_parents(node) @@ -727,6 +733,10 @@ def _find_matching_select( return select_column return None + @override + def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode: + raise NotImplementedError + def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102 new_group_bys = [] for group_by in node.group_bys: diff --git a/metricflow/sql/optimizer/sub_query_reducer.py b/metricflow/sql/optimizer/sub_query_reducer.py index f8f2098fca..c223d5d3f7 100644 --- a/metricflow/sql/optimizer/sub_query_reducer.py +++ b/metricflow/sql/optimizer/sub_query_reducer.py @@ -3,10 +3,13 @@ import logging from typing import List, Optional +from typing_extensions import override + from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer from metricflow.sql.sql_exprs import SqlColumnReference, SqlColumnReferenceExpression from metricflow.sql.sql_plan import ( SqlCreateTableAsNode, + SqlCteNode, SqlJoinDescription, SqlOrderByDescription, SqlQueryPlanNode, @@ -121,6 +124,10 @@ def _find_matching_table_alias(node: SqlSelectStatementNode, column_alias: str) return column_reference_expr.col_ref.table_alias return None + @override + def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode: + raise NotImplementedError + def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102 node_with_reduced_parents = self._reduce_parents(node) diff --git a/metricflow/sql/optimizer/table_alias_simplifier.py b/metricflow/sql/optimizer/table_alias_simplifier.py index 6cc2906df7..646bdd0f05 100644 --- a/metricflow/sql/optimizer/table_alias_simplifier.py +++ b/metricflow/sql/optimizer/table_alias_simplifier.py @@ -2,9 +2,12 @@ import logging +from typing_extensions import override + from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer from metricflow.sql.sql_plan import ( SqlCreateTableAsNode, + SqlCteNode, SqlJoinDescription, SqlOrderByDescription, SqlQueryPlanNode, @@ -21,6 +24,10 @@ class SqlTableAliasSimplifierVisitor(SqlQueryPlanNodeVisitor[SqlQueryPlanNode]): """Visits the SQL query plan to see if table aliases can be omitted when rendering column references.""" + @override + def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode: + raise NotImplementedError + def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102 # If there is only a single parent, no table aliases are required since there's no ambiguity. should_simplify_table_aliases = len(node.parent_nodes) <= 1 diff --git a/metricflow/sql/render/sql_plan_renderer.py b/metricflow/sql/render/sql_plan_renderer.py index 82cdb8d8d6..db4ccac9b6 100644 --- a/metricflow/sql/render/sql_plan_renderer.py +++ b/metricflow/sql/render/sql_plan_renderer.py @@ -9,6 +9,7 @@ from metricflow_semantics.mf_logging.formatting import indent from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet +from typing_extensions import override from metricflow.sql.render.expr_renderer import ( DefaultSqlExpressionRenderer, @@ -19,6 +20,7 @@ from metricflow.sql.sql_exprs import SqlExpressionNode from metricflow.sql.sql_plan import ( SqlCreateTableAsNode, + SqlCteNode, SqlJoinDescription, SqlOrderByDescription, SqlQueryPlan, @@ -85,6 +87,44 @@ def _render_description_section(self, description: str) -> Optional[SqlPlanRende description_lines = [f"-- {x}" for x in description.split("\n") if x] return SqlPlanRenderResult("\n".join(description_lines), SqlBindParameterSet()) + @override + def visit_cte_node(self, node: SqlCteNode) -> SqlPlanRenderResult: + lines = [] + collected_bind_parameters = [] + lines.append(f"{node.cte_alias} AS (") + select_statement_render_result = node.select_statement.accept(self) + lines.append(indent(select_statement_render_result.sql, indent_prefix=SqlRenderingConstants.INDENT)) + collected_bind_parameters.append(select_statement_render_result.bind_parameter_set) + lines.append(")") + + return SqlPlanRenderResult( + sql="\n".join(lines), bind_parameter_set=SqlBindParameterSet.merge_iterable(collected_bind_parameters) + ) + + def _render_cte_sections(self, cte_nodes: Sequence[SqlCteNode]) -> Optional[SqlPlanRenderResult]: + """Convert the CTEs into a series of `WITH` clauses. + + e.g. + WITH cte_alias_0 AS ( + ... + ) + cte_alias_1 AS ( + ... + ) + ... + """ + if len(cte_nodes) == 0: + return None + + cte_render_results = tuple(self.visit_cte_node(cte_node) for cte_node in cte_nodes) + + return SqlPlanRenderResult( + sql="WITH " + "\n, ".join(cte_render_result.sql + "\n" for cte_render_result in cte_render_results), + bind_parameter_set=SqlBindParameterSet.merge_iterable( + [cte_render_result.bind_parameter_set for cte_render_result in cte_render_results] + ), + ) + def _render_select_columns_section( self, select_columns: Sequence[SqlSelectColumn], @@ -285,6 +325,7 @@ def _render_limit_section(self, limit_value: Optional[int]) -> Optional[SqlPlanR def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanRenderResult: # noqa: D102 render_results = [ self._render_description_section(node.description), + self._render_cte_sections(node.cte_sources), self._render_select_columns_section(node.select_columns, len(node.parent_nodes), node.distinct), self._render_from_section(node.from_source, node.from_source_alias), self._render_joins_section(node.join_descs), diff --git a/metricflow/sql/sql_plan.py b/metricflow/sql/sql_plan.py index 9fdabfc05d..ea4a61756b 100644 --- a/metricflow/sql/sql_plan.py +++ b/metricflow/sql/sql_plan.py @@ -74,6 +74,10 @@ def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> Vi def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> VisitorOutputT: # noqa: D102 raise NotImplementedError + @abstractmethod + def visit_cte_node(self, node: SqlCteNode) -> VisitorOutputT: # noqa: D102 + raise NotImplementedError + @dataclass(frozen=True) class SqlSelectColumn: @@ -121,6 +125,7 @@ class SqlSelectStatementNode(SqlQueryPlanNode): select_columns: Tuple[SqlSelectColumn, ...] from_source: SqlQueryPlanNode from_source_alias: str + cte_sources: Tuple[SqlCteNode, ...] join_descs: Tuple[SqlJoinDescription, ...] group_bys: Tuple[SqlSelectColumn, ...] order_bys: Tuple[SqlOrderByDescription, ...] @@ -134,6 +139,7 @@ def create( # noqa: D102 select_columns: Tuple[SqlSelectColumn, ...], from_source: SqlQueryPlanNode, from_source_alias: str, + cte_sources: Tuple[SqlCteNode, ...] = (), join_descs: Tuple[SqlJoinDescription, ...] = (), group_bys: Tuple[SqlSelectColumn, ...] = (), order_bys: Tuple[SqlOrderByDescription, ...] = (), @@ -141,13 +147,14 @@ def create( # noqa: D102 limit: Optional[int] = None, distinct: bool = False, ) -> SqlSelectStatementNode: - parent_nodes = [from_source] + [x.right_source for x in join_descs] + parent_nodes = (from_source,) + tuple(x.right_source for x in join_descs) + cte_sources return SqlSelectStatementNode( - parent_nodes=tuple(parent_nodes), + parent_nodes=parent_nodes, _description=description, select_columns=select_columns, from_source=from_source, from_source_alias=from_source_alias, + cte_sources=cte_sources, join_descs=join_descs, group_bys=group_bys, order_bys=order_bys, @@ -334,3 +341,43 @@ def __init__(self, render_node: SqlQueryPlanNode, plan_id: Optional[DagId] = Non @property def render_node(self) -> SqlQueryPlanNode: # noqa: D102 return self._render_node + + +@dataclass(frozen=True) +class SqlCteNode(SqlQueryPlanNode): + """Represents a single common table expression.""" + + select_statement: SqlSelectStatementNode + cte_alias: str + + @staticmethod + def create(select_statement: SqlSelectStatementNode, cte_alias: str) -> SqlCteNode: # noqa: D102 + return SqlCteNode( + parent_nodes=(select_statement,), + select_statement=select_statement, + cte_alias=cte_alias, + ) + + @override + def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: + return visitor.visit_cte_node(self) + + @property + @override + def is_table(self) -> bool: + return False + + @property + @override + def as_select_node(self) -> Optional[SqlSelectStatementNode]: + return None + + @property + @override + def description(self) -> str: + return "CTE" + + @classmethod + @override + def id_prefix(cls) -> IdPrefix: + return StaticIdPrefix.SQL_PLAN_COMMON_TABLE_EXPRESSION_ID_PREFIX