Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Regularize select statement rendering #1461

Merged
merged 1 commit into from
Oct 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 79 additions & 78 deletions metricflow/sql/render/sql_plan_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from string import Template
from typing import List, Optional, Sequence, Tuple
from typing import List, Optional, Sequence

from metricflow_semantics.mf_logging.formatting import indent
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet
Expand All @@ -16,9 +16,11 @@
SqlExpressionRenderResult,
)
from metricflow.sql.render.rendering_constants import SqlRenderingConstants
from metricflow.sql.sql_exprs import SqlExpressionNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlJoinDescription,
SqlOrderByDescription,
SqlQueryPlan,
SqlQueryPlanNode,
SqlQueryPlanNodeVisitor,
Expand Down Expand Up @@ -71,12 +73,24 @@ class DefaultSqlQueryPlanRenderer(SqlQueryPlanRenderer):
# The renderer that is used to render the SQL expressions.
EXPR_RENDERER = DefaultSqlExpressionRenderer()

def _render_description_section(self, description: str) -> Optional[SqlPlanRenderResult]:
"""Render the description section as a comment.

e.g.
-- Description of the node.

"""
if len(description) == 0:
return None
description_lines = [f"-- {x}" for x in description.split("\n") if x]
return SqlPlanRenderResult("\n".join(description_lines), SqlBindParameterSet())

def _render_select_columns_section(
self,
select_columns: Sequence[SqlSelectColumn],
num_parents: int,
distinct: bool,
) -> Tuple[str, SqlBindParameterSet]:
) -> SqlPlanRenderResult:
"""Convert the select columns into a "SELECT" section.

e.g.
Expand Down Expand Up @@ -119,11 +133,9 @@ def _render_select_columns_section(
indent(", " + column_select_str, indent_prefix=SqlRenderingConstants.INDENT)
)

return "\n".join(select_section_lines), params
return SqlPlanRenderResult("\n".join(select_section_lines), params)

def _render_from_section(
self, from_source: SqlQueryPlanNode, from_source_alias: str
) -> Tuple[str, SqlBindParameterSet]:
def _render_from_section(self, from_source: SqlQueryPlanNode, from_source_alias: str) -> SqlPlanRenderResult:
"""Convert the node into a "FROM" section.

e.g.
Expand All @@ -146,9 +158,9 @@ def _render_from_section(
from_section_lines.append(f") {from_source_alias}")
from_section = "\n".join(from_section_lines)

return from_section, from_render_result.bind_parameter_set
return SqlPlanRenderResult(from_section, from_render_result.bind_parameter_set)

def _render_joins_section(self, join_descriptions: Sequence[SqlJoinDescription]) -> Tuple[str, SqlBindParameterSet]:
def _render_joins_section(self, join_descriptions: Sequence[SqlJoinDescription]) -> Optional[SqlPlanRenderResult]:
"""Convert the join descriptions into a "JOIN" section.

e.g.
Expand All @@ -160,6 +172,9 @@ def _render_joins_section(self, join_descriptions: Sequence[SqlJoinDescription])

Returns the rendered "JOIN" section with its associated bind parameters, or None if there are no joins.
"""
if len(join_descriptions) == 0:
return None

params = SqlBindParameterSet()
join_section_lines = []
for join_description in join_descriptions:
Expand Down Expand Up @@ -194,9 +209,17 @@ def _render_joins_section(self, join_descriptions: Sequence[SqlJoinDescription])
textwrap.indent(on_condition_rendered.sql, prefix=SqlRenderingConstants.INDENT)
)

return "\n".join(join_section_lines), params
return SqlPlanRenderResult("\n".join(join_section_lines), params)

def _render_where(self, where_expression: Optional[SqlExpressionNode]) -> Optional[SqlPlanRenderResult]:
if where_expression is None:
return None

where_expression_render_result = self.EXPR_RENDERER.render_sql_expr(where_expression)
where_section = f"WHERE {where_expression_render_result.sql}"
return SqlPlanRenderResult(where_section, where_expression_render_result.bind_parameter_set)

def _render_group_by_section(self, group_by_columns: Sequence[SqlSelectColumn]) -> Tuple[str, SqlBindParameterSet]:
def _render_group_by_section(self, group_by_columns: Sequence[SqlSelectColumn]) -> Optional[SqlPlanRenderResult]:
"""Convert the group by columns into a "GROUP BY" section.

e.g.
Expand All @@ -205,6 +228,9 @@ def _render_group_by_section(self, group_by_columns: Sequence[SqlSelectColumn])

Returns the rendered "GROUP BY" section with its associated bind parameters, or None if there are no group by columns.
"""
if len(group_by_columns) == 0:
return None

group_by_section_lines: List[str] = []
params = SqlBindParameterSet()
first = True
Expand All @@ -222,83 +248,58 @@ def _render_group_by_section(self, group_by_columns: Sequence[SqlSelectColumn])
textwrap.indent(f", {group_by_expr_rendered.sql}", prefix=SqlRenderingConstants.INDENT)
)

return "\n".join(group_by_section_lines), params

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanRenderResult: # noqa: D102
# Keep track of all execution parameters for all expressions
combined_params = SqlBindParameterSet()

# Render description section
description_section = "\n".join([f"-- {x}" for x in node.description.split("\n") if x])

# Render "SELECT" column section
select_section, select_params = self._render_select_columns_section(
node.select_columns, len(node.parent_nodes), node.distinct
)
combined_params = combined_params.merge(select_params)

# Render "FROM" section
from_section, from_params = self._render_from_section(node.from_source, node.from_source_alias)
combined_params = combined_params.merge(from_params)
return SqlPlanRenderResult("\n".join(group_by_section_lines), params)

# Render "JOIN" section
join_section, join_params = self._render_joins_section(node.join_descs)
combined_params = combined_params.merge(join_params)

# Render "GROUP BY" section
group_by_section, group_by_params = self._render_group_by_section(node.group_bys)
combined_params = combined_params.merge(group_by_params)

# Render "WHERE" section
where_section = None
if node.where:
where_render_result = self.EXPR_RENDERER.render_sql_expr(node.where)
combined_params = combined_params.merge(where_render_result.bind_parameter_set)
where_section = f"WHERE {where_render_result.sql}"

# Render "ORDER BY" section
order_by_section = None
if node.order_bys:
order_by_items: List[str] = []
for order_by in node.order_bys:
order_by_render_result = self.EXPR_RENDERER.render_sql_expr(order_by.expr)
order_by_items.append(order_by_render_result.sql + (" DESC" if order_by.desc else ""))
combined_params = combined_params.merge(order_by_render_result.bind_parameter_set)

order_by_section = "ORDER BY " + ", ".join(order_by_items)

# Render "LIMIT" section
limit_section = None
if node.limit:
limit_section = f"LIMIT {node.limit}"

# Combine the sections into a single string.
sections_to_render = []

if description_section:
sections_to_render.append(description_section)
def _render_order_by_section(self, order_bys: Sequence[SqlOrderByDescription]) -> Optional[SqlPlanRenderResult]:
"""Convert the order by descriptions into an "ORDER BY" section.

sections_to_render.append(select_section)
sections_to_render.append(from_section)
e.g.
ORDER BY
a.ds DESC
"""
if len(order_bys) == 0:
return None

if join_section:
sections_to_render.append(join_section)
order_by_items: List[str] = []
bind_parameters = []

if where_section:
sections_to_render.append(where_section)
for order_by in order_bys:
expression_render_result = self.EXPR_RENDERER.render_sql_expr(order_by.expr)
order_by_items.append(expression_render_result.sql + (" DESC" if order_by.desc else ""))
bind_parameters.append(expression_render_result.bind_parameter_set)

if group_by_section:
sections_to_render.append(group_by_section)
return SqlPlanRenderResult(
"ORDER BY " + ", ".join(order_by_items), SqlBindParameterSet.merge_iterable(bind_parameters)
)

if order_by_section:
sections_to_render.append(order_by_section)
def _render_limit_section(self, limit_value: Optional[int]) -> Optional[SqlPlanRenderResult]:
"""Convert the limit value into a LIMIT section.

if limit_section:
sections_to_render.append(limit_section)
e.g.
LIMIT 1
"""
if limit_value is None:
return None
return SqlPlanRenderResult(sql=f"LIMIT {limit_value}", bind_parameter_set=SqlBindParameterSet())

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanRenderResult: # noqa: D102
render_results = [
self._render_description_section(node.description),
self._render_select_columns_section(node.select_columns, len(node.parent_nodes), node.distinct),
self._render_from_section(node.from_source, node.from_source_alias),
self._render_joins_section(node.join_descs),
self._render_where(node.where),
self._render_group_by_section(node.group_bys),
self._render_order_by_section(node.order_bys),
self._render_limit_section(node.limit),
]

valid_render_results = tuple(render_result for render_result in render_results if render_result is not None)
return SqlPlanRenderResult(
sql="\n".join(sections_to_render),
bind_parameter_set=combined_params,
sql="\n".join(render_result.sql for render_result in valid_render_results),
bind_parameter_set=SqlBindParameterSet.merge_iterable(
[render_result.bind_parameter_set for render_result in valid_render_results]
),
)

def visit_table_node(self, node: SqlTableNode) -> SqlPlanRenderResult: # noqa: D102
Expand Down
Loading