Skip to content

Commit

Permalink
/* PR_START p--cte 03 */ Regularize select statement rendering.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Oct 15, 2024
1 parent 44ac927 commit 4dfc9c8
Showing 1 changed file with 79 additions and 78 deletions.
157 changes: 79 additions & 78 deletions metricflow/sql/render/sql_plan_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from string import Template
from typing import List, Optional, Sequence, Tuple
from typing import List, Optional, Sequence

from metricflow_semantics.mf_logging.formatting import indent
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet
Expand All @@ -16,9 +16,11 @@
SqlExpressionRenderResult,
)
from metricflow.sql.render.rendering_constants import SqlRenderingConstants
from metricflow.sql.sql_exprs import SqlExpressionNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlJoinDescription,
SqlOrderByDescription,
SqlQueryPlan,
SqlQueryPlanNode,
SqlQueryPlanNodeVisitor,
Expand Down Expand Up @@ -71,12 +73,24 @@ class DefaultSqlQueryPlanRenderer(SqlQueryPlanRenderer):
# The renderer that is used to render the SQL expressions.
EXPR_RENDERER = DefaultSqlExpressionRenderer()

def _render_description_section(self, description: str) -> Optional[SqlPlanRenderResult]:
"""Render the description section as a comment.
e.g.
-- Description of the node.
"""
if len(description) == 0:
return None
description_lines = [f"-- {x}" for x in description.split("\n") if x]
return SqlPlanRenderResult("\n".join(description_lines), SqlBindParameterSet())

def _render_select_columns_section(
self,
select_columns: Sequence[SqlSelectColumn],
num_parents: int,
distinct: bool,
) -> Tuple[str, SqlBindParameterSet]:
) -> SqlPlanRenderResult:
"""Convert the select columns into a "SELECT" section.
e.g.
Expand Down Expand Up @@ -119,11 +133,9 @@ def _render_select_columns_section(
indent(", " + column_select_str, indent_prefix=SqlRenderingConstants.INDENT)
)

return "\n".join(select_section_lines), params
return SqlPlanRenderResult("\n".join(select_section_lines), params)

def _render_from_section(
self, from_source: SqlQueryPlanNode, from_source_alias: str
) -> Tuple[str, SqlBindParameterSet]:
def _render_from_section(self, from_source: SqlQueryPlanNode, from_source_alias: str) -> SqlPlanRenderResult:
"""Convert the node into a "FROM" section.
e.g.
Expand All @@ -146,9 +158,9 @@ def _render_from_section(
from_section_lines.append(f") {from_source_alias}")
from_section = "\n".join(from_section_lines)

return from_section, from_render_result.bind_parameter_set
return SqlPlanRenderResult(from_section, from_render_result.bind_parameter_set)

def _render_joins_section(self, join_descriptions: Sequence[SqlJoinDescription]) -> Tuple[str, SqlBindParameterSet]:
def _render_joins_section(self, join_descriptions: Sequence[SqlJoinDescription]) -> Optional[SqlPlanRenderResult]:
"""Convert the join descriptions into a "JOIN" section.
e.g.
Expand All @@ -160,6 +172,9 @@ def _render_joins_section(self, join_descriptions: Sequence[SqlJoinDescription])
Returns a tuple of the "JOIN" section as a string and the associated execution parameters.
"""
if len(join_descriptions) == 0:
return None

params = SqlBindParameterSet()
join_section_lines = []
for join_description in join_descriptions:
Expand Down Expand Up @@ -194,9 +209,17 @@ def _render_joins_section(self, join_descriptions: Sequence[SqlJoinDescription])
textwrap.indent(on_condition_rendered.sql, prefix=SqlRenderingConstants.INDENT)
)

return "\n".join(join_section_lines), params
return SqlPlanRenderResult("\n".join(join_section_lines), params)

def _render_where(self, where_expression: Optional[SqlExpressionNode]) -> Optional[SqlPlanRenderResult]:
if where_expression is None:
return None

where_expression_render_result = self.EXPR_RENDERER.render_sql_expr(where_expression)
where_section = f"WHERE {where_expression_render_result.sql}"
return SqlPlanRenderResult(where_section, where_expression_render_result.bind_parameter_set)

def _render_group_by_section(self, group_by_columns: Sequence[SqlSelectColumn]) -> Tuple[str, SqlBindParameterSet]:
def _render_group_by_section(self, group_by_columns: Sequence[SqlSelectColumn]) -> Optional[SqlPlanRenderResult]:
"""Convert the group by columns into a "GROUP BY" section.
e.g.
Expand All @@ -205,6 +228,9 @@ def _render_group_by_section(self, group_by_columns: Sequence[SqlSelectColumn])
Returns a tuple of the "GROUP BY" section as a string and the associated execution parameters.
"""
if len(group_by_columns) == 0:
return None

group_by_section_lines: List[str] = []
params = SqlBindParameterSet()
first = True
Expand All @@ -222,83 +248,58 @@ def _render_group_by_section(self, group_by_columns: Sequence[SqlSelectColumn])
textwrap.indent(f", {group_by_expr_rendered.sql}", prefix=SqlRenderingConstants.INDENT)
)

return "\n".join(group_by_section_lines), params

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanRenderResult: # noqa: D102
# Keep track of all execution parameters for all expressions
combined_params = SqlBindParameterSet()

# Render description section
description_section = "\n".join([f"-- {x}" for x in node.description.split("\n") if x])

# Render "SELECT" column section
select_section, select_params = self._render_select_columns_section(
node.select_columns, len(node.parent_nodes), node.distinct
)
combined_params = combined_params.merge(select_params)

# Render "FROM" section
from_section, from_params = self._render_from_section(node.from_source, node.from_source_alias)
combined_params = combined_params.merge(from_params)
return SqlPlanRenderResult("\n".join(group_by_section_lines), params)

# Render "JOIN" section
join_section, join_params = self._render_joins_section(node.join_descs)
combined_params = combined_params.merge(join_params)

# Render "GROUP BY" section
group_by_section, group_by_params = self._render_group_by_section(node.group_bys)
combined_params = combined_params.merge(group_by_params)

# Render "WHERE" section
where_section = None
if node.where:
where_render_result = self.EXPR_RENDERER.render_sql_expr(node.where)
combined_params = combined_params.merge(where_render_result.bind_parameter_set)
where_section = f"WHERE {where_render_result.sql}"

# Render "ORDER BY" section
order_by_section = None
if node.order_bys:
order_by_items: List[str] = []
for order_by in node.order_bys:
order_by_render_result = self.EXPR_RENDERER.render_sql_expr(order_by.expr)
order_by_items.append(order_by_render_result.sql + (" DESC" if order_by.desc else ""))
combined_params = combined_params.merge(order_by_render_result.bind_parameter_set)

order_by_section = "ORDER BY " + ", ".join(order_by_items)

# Render "LIMIT" section
limit_section = None
if node.limit:
limit_section = f"LIMIT {node.limit}"

# Combine the sections into a single string.
sections_to_render = []

if description_section:
sections_to_render.append(description_section)
def _render_order_by_section(self, order_bys: Sequence[SqlOrderByDescription]) -> Optional[SqlPlanRenderResult]:
"""Convert the group by columns into a "GROUP BY" section.
sections_to_render.append(select_section)
sections_to_render.append(from_section)
e.g.
ORDER BY
a.ds DESC
"""
if len(order_bys) == 0:
return None

if join_section:
sections_to_render.append(join_section)
order_by_items: List[str] = []
bind_parameters = []

if where_section:
sections_to_render.append(where_section)
for order_by in order_bys:
expression_render_result = self.EXPR_RENDERER.render_sql_expr(order_by.expr)
order_by_items.append(expression_render_result.sql + (" DESC" if order_by.desc else ""))
bind_parameters.append(expression_render_result.bind_parameter_set)

if group_by_section:
sections_to_render.append(group_by_section)
return SqlPlanRenderResult(
"ORDER BY " + ", ".join(order_by_items), SqlBindParameterSet.merge_iterable(bind_parameters)
)

if order_by_section:
sections_to_render.append(order_by_section)
def _render_limit_section(self, limit_value: Optional[int]) -> Optional[SqlPlanRenderResult]:
"""Convert the limit value into a LIMIT section.
if limit_section:
sections_to_render.append(limit_section)
e.g.
LIMIT 1
"""
if limit_value is None:
return None
return SqlPlanRenderResult(sql=f"LIMIT {limit_value}", bind_parameter_set=SqlBindParameterSet())

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanRenderResult: # noqa: D102
render_results = [
self._render_description_section(node.description),
self._render_select_columns_section(node.select_columns, len(node.parent_nodes), node.distinct),
self._render_from_section(node.from_source, node.from_source_alias),
self._render_joins_section(node.join_descs),
self._render_where(node.where),
self._render_group_by_section(node.group_bys),
self._render_order_by_section(node.order_bys),
self._render_limit_section(node.limit),
]

valid_render_results = tuple(render_result for render_result in render_results if render_result is not None)
return SqlPlanRenderResult(
sql="\n".join(sections_to_render),
bind_parameter_set=combined_params,
sql="\n".join(render_result.sql for render_result in valid_render_results),
bind_parameter_set=SqlBindParameterSet.merge_iterable(
[render_result.bind_parameter_set for render_result in valid_render_results]
),
)

def visit_table_node(self, node: SqlTableNode) -> SqlPlanRenderResult: # noqa: D102
Expand Down

0 comments on commit 4dfc9c8

Please sign in to comment.