From e1402509dd81081adf2ec1c3531d69e41a1eae92 Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Sat, 16 Mar 2024 14:34:22 -0700 Subject: [PATCH] Update visitor for `SqlQueryPlanNode` and add tests. --- metricflow/sql/optimizer/column_pruner.py | 7 +++ .../optimizer/rewriting_sub_query_reducer.py | 13 +++++ metricflow/sql/optimizer/sub_query_reducer.py | 7 +++ .../sql/optimizer/table_alias_simplifier.py | 7 +++ metricflow/sql/render/sql_plan_renderer.py | 22 ++++++++ metricflow/sql/sql_plan.py | 6 +- metricflow/test/sql/test_sql_plan_render.py | 55 ++++++++++++++++++- 7 files changed, 115 insertions(+), 2 deletions(-) diff --git a/metricflow/sql/optimizer/column_pruner.py b/metricflow/sql/optimizer/column_pruner.py index f7e2778d92..9dc73c97c2 100644 --- a/metricflow/sql/optimizer/column_pruner.py +++ b/metricflow/sql/optimizer/column_pruner.py @@ -9,6 +9,7 @@ SqlExpressionTreeLineage, ) from metricflow.sql.sql_plan import ( + SqlCreateTableAsNode, SqlJoinDescription, SqlQueryPlanNode, SqlQueryPlanNodeVisitor, @@ -198,6 +199,12 @@ def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> Sq """Pruning cannot be done here since this is an arbitrary user-provided SQL query.""" return node + def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlanNode: # noqa: D + return SqlCreateTableAsNode( + sql_table=node.sql_table, + parent_node=node.parent_node.accept(self), + ) + class SqlColumnPrunerOptimizer(SqlQueryPlanOptimizer): """Removes unnecessary columns in the SELECT clauses.""" diff --git a/metricflow/sql/optimizer/rewriting_sub_query_reducer.py b/metricflow/sql/optimizer/rewriting_sub_query_reducer.py index 0ec2cf40a7..c0527acb45 100644 --- a/metricflow/sql/optimizer/rewriting_sub_query_reducer.py +++ b/metricflow/sql/optimizer/rewriting_sub_query_reducer.py @@ -16,6 +16,7 @@ SqlLogicalOperator, ) from metricflow.sql.sql_plan import ( + SqlCreateTableAsNode, SqlJoinDescription, SqlOrderByDescription, SqlQueryPlanNode, @@ -657,6 +658,12 @@ def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQuery def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlQueryPlanNode: # noqa: D return node + def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlanNode: # noqa: D + return SqlCreateTableAsNode( + sql_table=node.sql_table, + parent_node=node.parent_node.accept(self), + ) + class SqlGroupByRewritingVisitor(SqlQueryPlanNodeVisitor[SqlQueryPlanNode]): """Re-writes the GROUP BY to use a SqlColumnAliasReferenceExpression.""" @@ -715,6 +722,12 @@ def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQuery def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlQueryPlanNode: # noqa: D return node + def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlanNode: # noqa: D + return SqlCreateTableAsNode( + sql_table=node.sql_table, + parent_node=node.parent_node.accept(self), + ) + class SqlRewritingSubQueryReducer(SqlQueryPlanOptimizer): """Simplify queries by eliminating sub-queries when possible by rewriting expressions. diff --git a/metricflow/sql/optimizer/sub_query_reducer.py b/metricflow/sql/optimizer/sub_query_reducer.py index c3ef22a6ff..981794c773 100644 --- a/metricflow/sql/optimizer/sub_query_reducer.py +++ b/metricflow/sql/optimizer/sub_query_reducer.py @@ -6,6 +6,7 @@ from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer from metricflow.sql.sql_exprs import SqlColumnReference, SqlColumnReferenceExpression from metricflow.sql.sql_plan import ( + SqlCreateTableAsNode, SqlJoinDescription, SqlOrderByDescription, SqlQueryPlanNode, @@ -193,6 +194,12 @@ def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQuery def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlQueryPlanNode: # noqa: D return node + def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlanNode: # noqa: D + return SqlCreateTableAsNode( + sql_table=node.sql_table, + parent_node=node.parent_node.accept(self), + ) + class SqlSubQueryReducer(SqlQueryPlanOptimizer): """Simplify queries by eliminating sub-queries when possible. diff --git a/metricflow/sql/optimizer/table_alias_simplifier.py b/metricflow/sql/optimizer/table_alias_simplifier.py index 19a71e1936..2a8cf9a8f2 100644 --- a/metricflow/sql/optimizer/table_alias_simplifier.py +++ b/metricflow/sql/optimizer/table_alias_simplifier.py @@ -4,6 +4,7 @@ from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer from metricflow.sql.sql_plan import ( + SqlCreateTableAsNode, SqlJoinDescription, SqlOrderByDescription, SqlQueryPlanNode, @@ -74,6 +75,12 @@ def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQuery def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlQueryPlanNode: # noqa: D return node + def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlanNode: # noqa: D + return SqlCreateTableAsNode( + sql_table=node.sql_table, + parent_node=node.parent_node.accept(self), + ) + class SqlTableAliasSimplifier(SqlQueryPlanOptimizer): """Simplify queries by eliminating table aliases when possible. diff --git a/metricflow/sql/render/sql_plan_renderer.py b/metricflow/sql/render/sql_plan_renderer.py index 7bdd8e09b2..c1ca17c46b 100644 --- a/metricflow/sql/render/sql_plan_renderer.py +++ b/metricflow/sql/render/sql_plan_renderer.py @@ -4,6 +4,7 @@ import textwrap from abc import ABC, abstractmethod from dataclasses import dataclass +from string import Template from typing import List, Optional, Sequence, Tuple from metricflow.mf_logging.formatting import indent @@ -15,6 +16,7 @@ from metricflow.sql.render.rendering_constants import SqlRenderingConstants from metricflow.sql.sql_bind_parameters import SqlBindParameters from metricflow.sql.sql_plan import ( + SqlCreateTableAsNode, SqlJoinDescription, SqlQueryPlan, SqlQueryPlanNode, @@ -310,6 +312,26 @@ def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> Sq bind_parameters=SqlBindParameters(), ) + def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlPlanRenderResult: # noqa: D + inner_sql_render_result = node.parent_node.accept(self) + inner_sql = inner_sql_render_result.sql + # Using a substitution since inner_sql can have multiple lines, and then dedent() wouldn't dent due to the + # short line. + sql = Template( + textwrap.dedent( + f"""\ + CREATE {node.sql_table.table_type.value.upper()} {node.sql_table.sql} AS ( + $inner_sql + ) + """ + ).rstrip() + ).substitute({"inner_sql": indent(inner_sql, indent_prefix=SqlRenderingConstants.INDENT)}) + + return SqlPlanRenderResult( + sql=sql, + bind_parameters=inner_sql_render_result.bind_parameters, + ) + @property def expr_renderer(self) -> SqlExpressionRenderer: # noqa :D return self.EXPR_RENDERER diff --git a/metricflow/sql/sql_plan.py b/metricflow/sql/sql_plan.py index 8e36c9403d..4eaa78c32f 100644 --- a/metricflow/sql/sql_plan.py +++ b/metricflow/sql/sql_plan.py @@ -77,6 +77,10 @@ def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> VisitorO def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> VisitorOutputT: # noqa: D raise NotImplementedError + @abstractmethod + def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> VisitorOutputT: # noqa: D + raise NotImplementedError + @dataclass(frozen=True) class SqlSelectColumn: @@ -301,7 +305,7 @@ def __init__(self, sql_table: SqlTable, parent_node: SqlQueryPlanNode) -> None: @override def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: - raise NotImplementedError + return visitor.visit_create_table_as_node(self) @property @override diff --git a/metricflow/test/sql/test_sql_plan_render.py b/metricflow/test/sql/test_sql_plan_render.py index d5ae779a1c..4db9743eb2 100644 --- a/metricflow/test/sql/test_sql_plan_render.py +++ b/metricflow/test/sql/test_sql_plan_render.py @@ -17,6 +17,7 @@ SqlStringExpression, ) from metricflow.sql.sql_plan import ( + SqlCreateTableAsNode, SqlJoinDescription, SqlJoinType, SqlOrderByDescription, @@ -24,7 +25,7 @@ SqlSelectStatementNode, SqlTableFromClauseNode, ) -from metricflow.sql.sql_table import SqlTable +from metricflow.sql.sql_table import SqlTable, SqlTableType from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState from metricflow.test.sql.compare_sql_plan import assert_rendered_sql_equal @@ -339,3 +340,55 @@ def test_render_limit( # noqa: D plan_id="plan0", sql_client=sql_client, ) + + +@pytest.mark.sql_engine_snapshot +def test_render_create_table_as( # noqa: D + request: FixtureRequest, + mf_test_session_state: MetricFlowTestSessionState, + sql_client: SqlClient, +) -> None: + select_node = SqlSelectStatementNode( + description="select_0", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression(col_ref=SqlColumnReference(table_alias="a", column_name="bookings")), + column_alias="bookings", + ), + ), + from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source_alias="a", + joins_descs=(), + where=None, + group_bys=(), + order_bys=(), + limit=1, + ) + assert_rendered_sql_equal( + request=request, + mf_test_session_state=mf_test_session_state, + sql_plan_node=SqlCreateTableAsNode( + sql_table=SqlTable( + schema_name="schema_name", + table_name="table_name", + table_type=SqlTableType.TABLE, + ), + parent_node=select_node, + ), + plan_id="create_table_as", + sql_client=sql_client, + ) + assert_rendered_sql_equal( + request=request, + mf_test_session_state=mf_test_session_state, + sql_plan_node=SqlCreateTableAsNode( + sql_table=SqlTable( + schema_name="schema_name", + table_name="table_name", + table_type=SqlTableType.VIEW, + ), + parent_node=select_node, + ), + plan_id="create_view_as", + sql_client=sql_client, + )