Skip to content

Commit

Permalink
Add classes to represent CTEs (#1462)
Browse files Browse the repository at this point in the history
* This PR adds classes to represent CTEs to the SQL object model.
* CTEs will be used in later PRs to simplify generated queries.
* Functionality to render CTEs is included, and existing code should not
hit the CTE use cases so many methods were left unimplemented.
  • Loading branch information
plypaul authored Oct 24, 2024
1 parent c5eeb9f commit 85dac0f
Show file tree
Hide file tree
Showing 9 changed files with 254 additions and 2 deletions.
1 change: 1 addition & 0 deletions metricflow-semantics/metricflow_semantics/dag/id_prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper):
SQL_PLAN_TABLE_FROM_CLAUSE_ID_PREFIX = "tfc"
SQL_PLAN_QUERY_FROM_CLAUSE_ID_PREFIX = "qfc"
SQL_PLAN_CREATE_TABLE_AS_ID_PREFIX = "cta"
SQL_PLAN_COMMON_TABLE_EXPRESSION_ID_PREFIX = "cte"

EXEC_NODE_READ_SQL_QUERY = "rsq"
EXEC_NODE_NOOP = "noop"
Expand Down
7 changes: 7 additions & 0 deletions metricflow/sql/optimizer/column_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
from collections import defaultdict
from typing import Dict, List, Set, Tuple

from typing_extensions import override

from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer
from metricflow.sql.sql_exprs import (
SqlExpressionTreeLineage,
)
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlJoinDescription,
SqlQueryPlanNode,
SqlQueryPlanNodeVisitor,
Expand Down Expand Up @@ -111,6 +114,10 @@ def _prune_columns_from_grandparents(
distinct=node.distinct,
)

@override
def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
raise NotImplementedError

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102
# Remove columns that are not needed from this SELECT statement because the parent SELECT statement doesn't
# need them. However, keep columns that are in group bys because that changes the meaning of the query.
Expand Down
10 changes: 10 additions & 0 deletions metricflow/sql/optimizer/rewriting_sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from metricflow_semantics.mf_logging.formatting import indent
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from typing_extensions import override

from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer
from metricflow.sql.sql_exprs import (
Expand All @@ -19,6 +20,7 @@
)
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlJoinDescription,
SqlOrderByDescription,
SqlQueryPlanNode,
Expand Down Expand Up @@ -582,6 +584,10 @@ def _rewrite_node_with_join(node: SqlSelectStatementNode) -> SqlSelectStatementN
distinct=node.distinct,
)

@override
def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
raise NotImplementedError

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102
node_with_reduced_parents = self._reduce_parents(node)

Expand Down Expand Up @@ -727,6 +733,10 @@ def _find_matching_select(
return select_column
return None

@override
def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
raise NotImplementedError

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102
new_group_bys = []
for group_by in node.group_bys:
Expand Down
7 changes: 7 additions & 0 deletions metricflow/sql/optimizer/sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
import logging
from typing import List, Optional

from typing_extensions import override

from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer
from metricflow.sql.sql_exprs import SqlColumnReference, SqlColumnReferenceExpression
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlJoinDescription,
SqlOrderByDescription,
SqlQueryPlanNode,
Expand Down Expand Up @@ -121,6 +124,10 @@ def _find_matching_table_alias(node: SqlSelectStatementNode, column_alias: str)
return column_reference_expr.col_ref.table_alias
return None

@override
def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
raise NotImplementedError

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102
node_with_reduced_parents = self._reduce_parents(node)

Expand Down
7 changes: 7 additions & 0 deletions metricflow/sql/optimizer/table_alias_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

import logging

from typing_extensions import override

from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlJoinDescription,
SqlOrderByDescription,
SqlQueryPlanNode,
Expand All @@ -21,6 +24,10 @@
class SqlTableAliasSimplifierVisitor(SqlQueryPlanNodeVisitor[SqlQueryPlanNode]):
"""Visits the SQL query plan to see if table aliases can be omitted when rendering column references."""

@override
def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
raise NotImplementedError

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102
# If there is only a single parent, no table aliases are required since there's no ambiguity.
should_simplify_table_aliases = len(node.parent_nodes) <= 1
Expand Down
41 changes: 41 additions & 0 deletions metricflow/sql/render/sql_plan_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from metricflow_semantics.mf_logging.formatting import indent
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet
from typing_extensions import override

from metricflow.sql.render.expr_renderer import (
DefaultSqlExpressionRenderer,
Expand All @@ -19,6 +20,7 @@
from metricflow.sql.sql_exprs import SqlExpressionNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlJoinDescription,
SqlOrderByDescription,
SqlQueryPlan,
Expand Down Expand Up @@ -85,6 +87,44 @@ def _render_description_section(self, description: str) -> Optional[SqlPlanRende
description_lines = [f"-- {x}" for x in description.split("\n") if x]
return SqlPlanRenderResult("\n".join(description_lines), SqlBindParameterSet())

@override
def visit_cte_node(self, node: SqlCteNode) -> SqlPlanRenderResult:
lines = []
collected_bind_parameters = []
lines.append(f"{node.cte_alias} AS (")
select_statement_render_result = node.select_statement.accept(self)
lines.append(indent(select_statement_render_result.sql, indent_prefix=SqlRenderingConstants.INDENT))
collected_bind_parameters.append(select_statement_render_result.bind_parameter_set)
lines.append(")")

return SqlPlanRenderResult(
sql="\n".join(lines), bind_parameter_set=SqlBindParameterSet.merge_iterable(collected_bind_parameters)
)

def _render_cte_sections(self, cte_nodes: Sequence[SqlCteNode]) -> Optional[SqlPlanRenderResult]:
"""Convert the CTEs into a series of `WITH` clauses.
e.g.
WITH cte_alias_0 AS (
...
)
cte_alias_1 AS (
...
)
...
"""
if len(cte_nodes) == 0:
return None

cte_render_results = tuple(self.visit_cte_node(cte_node) for cte_node in cte_nodes)

return SqlPlanRenderResult(
sql="WITH " + "\n, ".join(cte_render_result.sql + "\n" for cte_render_result in cte_render_results),
bind_parameter_set=SqlBindParameterSet.merge_iterable(
[cte_render_result.bind_parameter_set for cte_render_result in cte_render_results]
),
)

def _render_select_columns_section(
self,
select_columns: Sequence[SqlSelectColumn],
Expand Down Expand Up @@ -285,6 +325,7 @@ def _render_limit_section(self, limit_value: Optional[int]) -> Optional[SqlPlanR
def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanRenderResult: # noqa: D102
render_results = [
self._render_description_section(node.description),
self._render_cte_sections(node.cte_sources),
self._render_select_columns_section(node.select_columns, len(node.parent_nodes), node.distinct),
self._render_from_section(node.from_source, node.from_source_alias),
self._render_joins_section(node.join_descs),
Expand Down
51 changes: 49 additions & 2 deletions metricflow/sql/sql_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> Vi
def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_cte_node(self, node: SqlCteNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError


@dataclass(frozen=True)
class SqlSelectColumn:
Expand Down Expand Up @@ -121,6 +125,7 @@ class SqlSelectStatementNode(SqlQueryPlanNode):
select_columns: Tuple[SqlSelectColumn, ...]
from_source: SqlQueryPlanNode
from_source_alias: str
cte_sources: Tuple[SqlCteNode, ...]
join_descs: Tuple[SqlJoinDescription, ...]
group_bys: Tuple[SqlSelectColumn, ...]
order_bys: Tuple[SqlOrderByDescription, ...]
Expand All @@ -134,20 +139,22 @@ def create( # noqa: D102
select_columns: Tuple[SqlSelectColumn, ...],
from_source: SqlQueryPlanNode,
from_source_alias: str,
cte_sources: Tuple[SqlCteNode, ...] = (),
join_descs: Tuple[SqlJoinDescription, ...] = (),
group_bys: Tuple[SqlSelectColumn, ...] = (),
order_bys: Tuple[SqlOrderByDescription, ...] = (),
where: Optional[SqlExpressionNode] = None,
limit: Optional[int] = None,
distinct: bool = False,
) -> SqlSelectStatementNode:
parent_nodes = [from_source] + [x.right_source for x in join_descs]
parent_nodes = (from_source,) + tuple(x.right_source for x in join_descs) + cte_sources
return SqlSelectStatementNode(
parent_nodes=tuple(parent_nodes),
parent_nodes=parent_nodes,
_description=description,
select_columns=select_columns,
from_source=from_source,
from_source_alias=from_source_alias,
cte_sources=cte_sources,
join_descs=join_descs,
group_bys=group_bys,
order_bys=order_bys,
Expand Down Expand Up @@ -334,3 +341,43 @@ def __init__(self, render_node: SqlQueryPlanNode, plan_id: Optional[DagId] = Non
@property
def render_node(self) -> SqlQueryPlanNode: # noqa: D102
return self._render_node


@dataclass(frozen=True)
class SqlCteNode(SqlQueryPlanNode):
"""Represents a single common table expression."""

select_statement: SqlSelectStatementNode
cte_alias: str

@staticmethod
def create(select_statement: SqlSelectStatementNode, cte_alias: str) -> SqlCteNode: # noqa: D102
return SqlCteNode(
parent_nodes=(select_statement,),
select_statement=select_statement,
cte_alias=cte_alias,
)

@override
def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT:
return visitor.visit_cte_node(self)

@property
@override
def is_table(self) -> bool:
return False

@property
@override
def as_select_node(self) -> Optional[SqlSelectStatementNode]:
return None

@property
@override
def description(self) -> str:
return "CTE"

@classmethod
@override
def id_prefix(cls) -> IdPrefix:
return StaticIdPrefix.SQL_PLAN_COMMON_TABLE_EXPRESSION_ID_PREFIX
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
-- cte_test
WITH cte_0 AS (
-- cte_select_0
SELECT
cte_source_table_0.col_0
FROM demo.cte_source_table_0 cte_source_table_0
)

, cte_1 AS (
-- cte_select_1
SELECT
cte_source_table_1.col_1
FROM demo.cte_source_table_1 cte_source_table_1
)

SELECT
cte_0.col_0 AS col_0
, cte_1.col_1 AS col_1
FROM cte_0 cte_0
LEFT OUTER JOIN
cte_1 cte_1
ON
cte_0.col_0 = cte_1.col_1
Loading

0 comments on commit 85dac0f

Please sign in to comment.