From d3bf3ac532b93664e48a21192e021ad920f66877 Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Mon, 14 Oct 2024 11:27:37 -0700 Subject: [PATCH] /* PR_START p--cte 03 */ Regularize select statement rendering. --- metricflow/sql/render/sql_plan_renderer.py | 157 +++++++++++---------- 1 file changed, 79 insertions(+), 78 deletions(-) diff --git a/metricflow/sql/render/sql_plan_renderer.py b/metricflow/sql/render/sql_plan_renderer.py index b493df346a..82cdb8d8d6 100644 --- a/metricflow/sql/render/sql_plan_renderer.py +++ b/metricflow/sql/render/sql_plan_renderer.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from string import Template -from typing import List, Optional, Sequence, Tuple +from typing import List, Optional, Sequence from metricflow_semantics.mf_logging.formatting import indent from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet @@ -16,9 +16,11 @@ SqlExpressionRenderResult, ) from metricflow.sql.render.rendering_constants import SqlRenderingConstants +from metricflow.sql.sql_exprs import SqlExpressionNode from metricflow.sql.sql_plan import ( SqlCreateTableAsNode, SqlJoinDescription, + SqlOrderByDescription, SqlQueryPlan, SqlQueryPlanNode, SqlQueryPlanNodeVisitor, @@ -71,12 +73,24 @@ class DefaultSqlQueryPlanRenderer(SqlQueryPlanRenderer): # The renderer that is used to render the SQL expressions. EXPR_RENDERER = DefaultSqlExpressionRenderer() + def _render_description_section(self, description: str) -> Optional[SqlPlanRenderResult]: + """Render the description section as a comment. + + e.g. + -- Description of the node. + + """ + if len(description) == 0: + return None + description_lines = [f"-- {x}" for x in description.split("\n") if x] + return SqlPlanRenderResult("\n".join(description_lines), SqlBindParameterSet()) + def _render_select_columns_section( self, select_columns: Sequence[SqlSelectColumn], num_parents: int, distinct: bool, - ) -> Tuple[str, SqlBindParameterSet]: + ) -> SqlPlanRenderResult: """Convert the select columns into a "SELECT" section. e.g. @@ -119,11 +133,9 @@ def _render_select_columns_section( indent(", " + column_select_str, indent_prefix=SqlRenderingConstants.INDENT) ) - return "\n".join(select_section_lines), params + return SqlPlanRenderResult("\n".join(select_section_lines), params) - def _render_from_section( - self, from_source: SqlQueryPlanNode, from_source_alias: str - ) -> Tuple[str, SqlBindParameterSet]: + def _render_from_section(self, from_source: SqlQueryPlanNode, from_source_alias: str) -> SqlPlanRenderResult: """Convert the node into a "FROM" section. e.g. @@ -146,9 +158,9 @@ def _render_from_section( from_section_lines.append(f") {from_source_alias}") from_section = "\n".join(from_section_lines) - return from_section, from_render_result.bind_parameter_set + return SqlPlanRenderResult(from_section, from_render_result.bind_parameter_set) - def _render_joins_section(self, join_descriptions: Sequence[SqlJoinDescription]) -> Tuple[str, SqlBindParameterSet]: + def _render_joins_section(self, join_descriptions: Sequence[SqlJoinDescription]) -> Optional[SqlPlanRenderResult]: """Convert the join descriptions into a "JOIN" section. e.g. @@ -160,6 +172,9 @@ def _render_joins_section(self, join_descriptions: Sequence[SqlJoinDescription]) Returns a tuple of the "JOIN" section as a string and the associated execution parameters. """ + if len(join_descriptions) == 0: + return None + params = SqlBindParameterSet() join_section_lines = [] for join_description in join_descriptions: @@ -194,9 +209,17 @@ def _render_joins_section(self, join_descriptions: Sequence[SqlJoinDescription]) textwrap.indent(on_condition_rendered.sql, prefix=SqlRenderingConstants.INDENT) ) - return "\n".join(join_section_lines), params + return SqlPlanRenderResult("\n".join(join_section_lines), params) + + def _render_where(self, where_expression: Optional[SqlExpressionNode]) -> Optional[SqlPlanRenderResult]: + if where_expression is None: + return None + + where_expression_render_result = self.EXPR_RENDERER.render_sql_expr(where_expression) + where_section = f"WHERE {where_expression_render_result.sql}" + return SqlPlanRenderResult(where_section, where_expression_render_result.bind_parameter_set) - def _render_group_by_section(self, group_by_columns: Sequence[SqlSelectColumn]) -> Tuple[str, SqlBindParameterSet]: + def _render_group_by_section(self, group_by_columns: Sequence[SqlSelectColumn]) -> Optional[SqlPlanRenderResult]: """Convert the group by columns into a "GROUP BY" section. e.g. @@ -205,6 +228,9 @@ def _render_group_by_section(self, group_by_columns: Sequence[SqlSelectColumn]) Returns a tuple of the "GROUP BY" section as a string and the associated execution parameters. """ + if len(group_by_columns) == 0: + return None + group_by_section_lines: List[str] = [] params = SqlBindParameterSet() first = True @@ -222,83 +248,58 @@ def _render_group_by_section(self, group_by_columns: Sequence[SqlSelectColumn]) textwrap.indent(f", {group_by_expr_rendered.sql}", prefix=SqlRenderingConstants.INDENT) ) - return "\n".join(group_by_section_lines), params - - def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanRenderResult: # noqa: D102 - # Keep track of all execution parameters for all expressions - combined_params = SqlBindParameterSet() - - # Render description section - description_section = "\n".join([f"-- {x}" for x in node.description.split("\n") if x]) - - # Render "SELECT" column section - select_section, select_params = self._render_select_columns_section( - node.select_columns, len(node.parent_nodes), node.distinct - ) - combined_params = combined_params.merge(select_params) - - # Render "FROM" section - from_section, from_params = self._render_from_section(node.from_source, node.from_source_alias) - combined_params = combined_params.merge(from_params) + return SqlPlanRenderResult("\n".join(group_by_section_lines), params) - # Render "JOIN" section - join_section, join_params = self._render_joins_section(node.join_descs) - combined_params = combined_params.merge(join_params) - - # Render "GROUP BY" section - group_by_section, group_by_params = self._render_group_by_section(node.group_bys) - combined_params = combined_params.merge(group_by_params) - - # Render "WHERE" section - where_section = None - if node.where: - where_render_result = self.EXPR_RENDERER.render_sql_expr(node.where) - combined_params = combined_params.merge(where_render_result.bind_parameter_set) - where_section = f"WHERE {where_render_result.sql}" - - # Render "ORDER BY" section - order_by_section = None - if node.order_bys: - order_by_items: List[str] = [] - for order_by in node.order_bys: - order_by_render_result = self.EXPR_RENDERER.render_sql_expr(order_by.expr) - order_by_items.append(order_by_render_result.sql + (" DESC" if order_by.desc else "")) - combined_params = combined_params.merge(order_by_render_result.bind_parameter_set) - - order_by_section = "ORDER BY " + ", ".join(order_by_items) - - # Render "LIMIT" section - limit_section = None - if node.limit: - limit_section = f"LIMIT {node.limit}" - - # Combine the sections into a single string. - sections_to_render = [] - - if description_section: - sections_to_render.append(description_section) + def _render_order_by_section(self, order_bys: Sequence[SqlOrderByDescription]) -> Optional[SqlPlanRenderResult]: + """Convert the group by columns into a "GROUP BY" section. - sections_to_render.append(select_section) - sections_to_render.append(from_section) + e.g. + ORDER BY + a.ds DESC + """ + if len(order_bys) == 0: + return None - if join_section: - sections_to_render.append(join_section) + order_by_items: List[str] = [] + bind_parameters = [] - if where_section: - sections_to_render.append(where_section) + for order_by in order_bys: + expression_render_result = self.EXPR_RENDERER.render_sql_expr(order_by.expr) + order_by_items.append(expression_render_result.sql + (" DESC" if order_by.desc else "")) + bind_parameters.append(expression_render_result.bind_parameter_set) - if group_by_section: - sections_to_render.append(group_by_section) + return SqlPlanRenderResult( + "ORDER BY " + ", ".join(order_by_items), SqlBindParameterSet.merge_iterable(bind_parameters) + ) - if order_by_section: - sections_to_render.append(order_by_section) + def _render_limit_section(self, limit_value: Optional[int]) -> Optional[SqlPlanRenderResult]: + """Convert the limit value into a LIMIT section. - if limit_section: - sections_to_render.append(limit_section) + e.g. + LIMIT 1 + """ + if limit_value is None: + return None + return SqlPlanRenderResult(sql=f"LIMIT {limit_value}", bind_parameter_set=SqlBindParameterSet()) + def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanRenderResult: # noqa: D102 + render_results = [ + self._render_description_section(node.description), + self._render_select_columns_section(node.select_columns, len(node.parent_nodes), node.distinct), + self._render_from_section(node.from_source, node.from_source_alias), + self._render_joins_section(node.join_descs), + self._render_where(node.where), + self._render_group_by_section(node.group_bys), + self._render_order_by_section(node.order_bys), + self._render_limit_section(node.limit), + ] + + valid_render_results = tuple(render_result for render_result in render_results if render_result is not None) return SqlPlanRenderResult( - sql="\n".join(sections_to_render), - bind_parameter_set=combined_params, + sql="\n".join(render_result.sql for render_result in valid_render_results), + bind_parameter_set=SqlBindParameterSet.merge_iterable( + [render_result.bind_parameter_set for render_result in valid_render_results] + ), ) def visit_table_node(self, node: SqlTableNode) -> SqlPlanRenderResult: # noqa: D102