Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add classes to represent CTEs #1462

Merged
merged 3 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions metricflow-semantics/metricflow_semantics/dag/id_prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper):
SQL_PLAN_TABLE_FROM_CLAUSE_ID_PREFIX = "tfc"
SQL_PLAN_QUERY_FROM_CLAUSE_ID_PREFIX = "qfc"
SQL_PLAN_CREATE_TABLE_AS_ID_PREFIX = "cta"
SQL_PLAN_COMMON_TABLE_EXPRESSION_ID_PREFIX = "cte"

EXEC_NODE_READ_SQL_QUERY = "rsq"
EXEC_NODE_NOOP = "noop"
Expand Down
7 changes: 7 additions & 0 deletions metricflow/sql/optimizer/column_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
from collections import defaultdict
from typing import Dict, List, Set, Tuple

from typing_extensions import override

from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer
from metricflow.sql.sql_exprs import (
SqlExpressionTreeLineage,
)
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlJoinDescription,
SqlQueryPlanNode,
SqlQueryPlanNodeVisitor,
Expand Down Expand Up @@ -111,6 +114,10 @@ def _prune_columns_from_grandparents(
distinct=node.distinct,
)

@override
def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
raise NotImplementedError

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102
# Remove columns that are not needed from this SELECT statement because the parent SELECT statement doesn't
# need them. However, keep columns that are in group bys because that changes the meaning of the query.
Expand Down
10 changes: 10 additions & 0 deletions metricflow/sql/optimizer/rewriting_sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from metricflow_semantics.mf_logging.formatting import indent
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from typing_extensions import override

from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer
from metricflow.sql.sql_exprs import (
Expand All @@ -19,6 +20,7 @@
)
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlJoinDescription,
SqlOrderByDescription,
SqlQueryPlanNode,
Expand Down Expand Up @@ -582,6 +584,10 @@ def _rewrite_node_with_join(node: SqlSelectStatementNode) -> SqlSelectStatementN
distinct=node.distinct,
)

@override
def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
raise NotImplementedError

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102
node_with_reduced_parents = self._reduce_parents(node)

Expand Down Expand Up @@ -727,6 +733,10 @@ def _find_matching_select(
return select_column
return None

@override
def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
raise NotImplementedError

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102
new_group_bys = []
for group_by in node.group_bys:
Expand Down
7 changes: 7 additions & 0 deletions metricflow/sql/optimizer/sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
import logging
from typing import List, Optional

from typing_extensions import override

from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer
from metricflow.sql.sql_exprs import SqlColumnReference, SqlColumnReferenceExpression
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlJoinDescription,
SqlOrderByDescription,
SqlQueryPlanNode,
Expand Down Expand Up @@ -121,6 +124,10 @@ def _find_matching_table_alias(node: SqlSelectStatementNode, column_alias: str)
return column_reference_expr.col_ref.table_alias
return None

@override
def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
raise NotImplementedError

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102
node_with_reduced_parents = self._reduce_parents(node)

Expand Down
7 changes: 7 additions & 0 deletions metricflow/sql/optimizer/table_alias_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

import logging

from typing_extensions import override

from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlJoinDescription,
SqlOrderByDescription,
SqlQueryPlanNode,
Expand All @@ -21,6 +24,10 @@
class SqlTableAliasSimplifierVisitor(SqlQueryPlanNodeVisitor[SqlQueryPlanNode]):
"""Visits the SQL query plan to see if table aliases can be omitted when rendering column references."""

@override
def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
raise NotImplementedError

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102
# If there is only a single parent, no table aliases are required since there's no ambiguity.
should_simplify_table_aliases = len(node.parent_nodes) <= 1
Expand Down
41 changes: 41 additions & 0 deletions metricflow/sql/render/sql_plan_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from metricflow_semantics.mf_logging.formatting import indent
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet
from typing_extensions import override

from metricflow.sql.render.expr_renderer import (
DefaultSqlExpressionRenderer,
Expand All @@ -19,6 +20,7 @@
from metricflow.sql.sql_exprs import SqlExpressionNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlJoinDescription,
SqlOrderByDescription,
SqlQueryPlan,
Expand Down Expand Up @@ -85,6 +87,44 @@ def _render_description_section(self, description: str) -> Optional[SqlPlanRende
description_lines = [f"-- {x}" for x in description.split("\n") if x]
return SqlPlanRenderResult("\n".join(description_lines), SqlBindParameterSet())

@override
def visit_cte_node(self, node: SqlCteNode) -> SqlPlanRenderResult:
lines = []
collected_bind_parameters = []
lines.append(f"{node.cte_alias} AS (")
select_statement_render_result = node.select_statement.accept(self)
lines.append(indent(select_statement_render_result.sql, indent_prefix=SqlRenderingConstants.INDENT))
collected_bind_parameters.append(select_statement_render_result.bind_parameter_set)
lines.append(")")

return SqlPlanRenderResult(
sql="\n".join(lines), bind_parameter_set=SqlBindParameterSet.merge_iterable(collected_bind_parameters)
)

def _render_cte_sections(self, cte_nodes: Sequence[SqlCteNode]) -> Optional[SqlPlanRenderResult]:
"""Convert the CTEs into a series of `WITH` clauses.

e.g.
WITH cte_alias_0 AS (
...
)
cte_alias_1 AS (
...
)
...
"""
if len(cte_nodes) == 0:
return None

cte_render_results = tuple(self.visit_cte_node(cte_node) for cte_node in cte_nodes)

return SqlPlanRenderResult(
sql="WITH " + "\n, ".join(cte_render_result.sql + "\n" for cte_render_result in cte_render_results),
bind_parameter_set=SqlBindParameterSet.merge_iterable(
[cte_render_result.bind_parameter_set for cte_render_result in cte_render_results]
),
)

def _render_select_columns_section(
self,
select_columns: Sequence[SqlSelectColumn],
Expand Down Expand Up @@ -285,6 +325,7 @@ def _render_limit_section(self, limit_value: Optional[int]) -> Optional[SqlPlanR
def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanRenderResult: # noqa: D102
render_results = [
self._render_description_section(node.description),
self._render_cte_sections(node.cte_sources),
self._render_select_columns_section(node.select_columns, len(node.parent_nodes), node.distinct),
self._render_from_section(node.from_source, node.from_source_alias),
self._render_joins_section(node.join_descs),
Expand Down
51 changes: 49 additions & 2 deletions metricflow/sql/sql_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> Vi
def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_cte_node(self, node: SqlCteNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError


@dataclass(frozen=True)
class SqlSelectColumn:
Expand Down Expand Up @@ -121,6 +125,7 @@ class SqlSelectStatementNode(SqlQueryPlanNode):
select_columns: Tuple[SqlSelectColumn, ...]
from_source: SqlQueryPlanNode
from_source_alias: str
cte_sources: Tuple[SqlCteNode, ...]
join_descs: Tuple[SqlJoinDescription, ...]
group_bys: Tuple[SqlSelectColumn, ...]
order_bys: Tuple[SqlOrderByDescription, ...]
Expand All @@ -134,20 +139,22 @@ def create( # noqa: D102
select_columns: Tuple[SqlSelectColumn, ...],
from_source: SqlQueryPlanNode,
from_source_alias: str,
cte_sources: Tuple[SqlCteNode, ...] = (),
join_descs: Tuple[SqlJoinDescription, ...] = (),
group_bys: Tuple[SqlSelectColumn, ...] = (),
order_bys: Tuple[SqlOrderByDescription, ...] = (),
where: Optional[SqlExpressionNode] = None,
limit: Optional[int] = None,
distinct: bool = False,
) -> SqlSelectStatementNode:
parent_nodes = [from_source] + [x.right_source for x in join_descs]
parent_nodes = (from_source,) + tuple(x.right_source for x in join_descs) + cte_sources
return SqlSelectStatementNode(
parent_nodes=tuple(parent_nodes),
parent_nodes=parent_nodes,
_description=description,
select_columns=select_columns,
from_source=from_source,
from_source_alias=from_source_alias,
cte_sources=cte_sources,
join_descs=join_descs,
group_bys=group_bys,
order_bys=order_bys,
Expand Down Expand Up @@ -334,3 +341,43 @@ def __init__(self, render_node: SqlQueryPlanNode, plan_id: Optional[DagId] = Non
@property
def render_node(self) -> SqlQueryPlanNode: # noqa: D102
return self._render_node


@dataclass(frozen=True)
class SqlCteNode(SqlQueryPlanNode):
"""Represents a single common table expression."""

select_statement: SqlSelectStatementNode
cte_alias: str

@staticmethod
def create(select_statement: SqlSelectStatementNode, cte_alias: str) -> SqlCteNode: # noqa: D102
return SqlCteNode(
parent_nodes=(select_statement,),
select_statement=select_statement,
cte_alias=cte_alias,
)

@override
def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT:
return visitor.visit_cte_node(self)

@property
@override
def is_table(self) -> bool:
return False

@property
@override
def as_select_node(self) -> Optional[SqlSelectStatementNode]:
return None

@property
@override
def description(self) -> str:
return "CTE"

@classmethod
@override
def id_prefix(cls) -> IdPrefix:
return StaticIdPrefix.SQL_PLAN_COMMON_TABLE_EXPRESSION_ID_PREFIX
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
-- cte_test
WITH cte_0 AS (
-- cte_select_0
SELECT
cte_source_table_0.col_0
FROM demo.cte_source_table_0 cte_source_table_0
)

, cte_1 AS (
-- cte_select_1
SELECT
cte_source_table_1.col_1
FROM demo.cte_source_table_1 cte_source_table_1
)

SELECT
cte_0.col_0 AS col_0
, cte_1.col_1 AS col_1
FROM cte_0 cte_0
LEFT OUTER JOIN
cte_1 cte_1
ON
cte_0.col_0 = cte_1.col_1
Loading
Loading