Skip to content

Commit

Permalink
Update visitor for SqlQueryPlanNode and add tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Mar 19, 2024
1 parent 2e2d2a8 commit e140250
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 2 deletions.
7 changes: 7 additions & 0 deletions metricflow/sql/optimizer/column_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
SqlExpressionTreeLineage,
)
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlJoinDescription,
SqlQueryPlanNode,
SqlQueryPlanNodeVisitor,
Expand Down Expand Up @@ -198,6 +199,12 @@ def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> Sq
"""Pruning cannot be done here since this is an arbitrary user-provided SQL query."""
return node

def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlanNode:
    """Rebuild the node for the same table around the parent node as processed by this visitor."""
    visited_parent = node.parent_node.accept(self)
    return SqlCreateTableAsNode(sql_table=node.sql_table, parent_node=visited_parent)


class SqlColumnPrunerOptimizer(SqlQueryPlanOptimizer):
"""Removes unnecessary columns in the SELECT clauses."""
Expand Down
13 changes: 13 additions & 0 deletions metricflow/sql/optimizer/rewriting_sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
SqlLogicalOperator,
)
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlJoinDescription,
SqlOrderByDescription,
SqlQueryPlanNode,
Expand Down Expand Up @@ -657,6 +658,12 @@ def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQuery
def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlQueryPlanNode: # noqa: D
    """Return the given node as-is; this visitor makes no changes to it."""
    return node

def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlanNode:
    """Apply this visitor to the parent node and wrap the result in a new node for the same table."""
    reduced_parent = node.parent_node.accept(self)
    return SqlCreateTableAsNode(sql_table=node.sql_table, parent_node=reduced_parent)


class SqlGroupByRewritingVisitor(SqlQueryPlanNodeVisitor[SqlQueryPlanNode]):
"""Re-writes the GROUP BY to use a SqlColumnAliasReferenceExpression."""
Expand Down Expand Up @@ -715,6 +722,12 @@ def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQuery
def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlQueryPlanNode: # noqa: D
    """Return the given node as-is; this visitor makes no changes to it."""
    return node

def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlanNode:
    """Apply this visitor to the parent node and wrap the result in a new node for the same table."""
    rewritten_parent = node.parent_node.accept(self)
    return SqlCreateTableAsNode(sql_table=node.sql_table, parent_node=rewritten_parent)


class SqlRewritingSubQueryReducer(SqlQueryPlanOptimizer):
"""Simplify queries by eliminating sub-queries when possible by rewriting expressions.
Expand Down
7 changes: 7 additions & 0 deletions metricflow/sql/optimizer/sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer
from metricflow.sql.sql_exprs import SqlColumnReference, SqlColumnReferenceExpression
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlJoinDescription,
SqlOrderByDescription,
SqlQueryPlanNode,
Expand Down Expand Up @@ -193,6 +194,12 @@ def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQuery
def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlQueryPlanNode: # noqa: D
    """Return the given node as-is; this visitor makes no changes to it."""
    return node

def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlanNode:
    """Apply this visitor to the parent node and wrap the result in a new node for the same table."""
    reduced_parent = node.parent_node.accept(self)
    return SqlCreateTableAsNode(sql_table=node.sql_table, parent_node=reduced_parent)


class SqlSubQueryReducer(SqlQueryPlanOptimizer):
"""Simplify queries by eliminating sub-queries when possible.
Expand Down
7 changes: 7 additions & 0 deletions metricflow/sql/optimizer/table_alias_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlJoinDescription,
SqlOrderByDescription,
SqlQueryPlanNode,
Expand Down Expand Up @@ -74,6 +75,12 @@ def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQuery
def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlQueryPlanNode: # noqa: D
    """Return the given node as-is; this visitor makes no changes to it."""
    return node

def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlanNode:
    """Apply this visitor to the parent node and wrap the result in a new node for the same table."""
    simplified_parent = node.parent_node.accept(self)
    return SqlCreateTableAsNode(sql_table=node.sql_table, parent_node=simplified_parent)


class SqlTableAliasSimplifier(SqlQueryPlanOptimizer):
"""Simplify queries by eliminating table aliases when possible.
Expand Down
22 changes: 22 additions & 0 deletions metricflow/sql/render/sql_plan_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import textwrap
from abc import ABC, abstractmethod
from dataclasses import dataclass
from string import Template
from typing import List, Optional, Sequence, Tuple

from metricflow.mf_logging.formatting import indent
Expand All @@ -15,6 +16,7 @@
from metricflow.sql.render.rendering_constants import SqlRenderingConstants
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlJoinDescription,
SqlQueryPlan,
SqlQueryPlanNode,
Expand Down Expand Up @@ -310,6 +312,26 @@ def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> Sq
bind_parameters=SqlBindParameters(),
)

def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlPlanRenderResult:
    """Render a `CREATE <TYPE> <table> AS (...)` statement around the rendered parent node.

    The parent node is rendered with this same visitor, indented, and placed between the parentheses.
    `<TYPE>` is the upper-cased value of the table's type, and `<table>` is the table's SQL name.

    Args:
        node: The node describing the table to create and the query that populates it.

    Returns:
        The rendered statement, carrying the bind parameters produced by rendering the parent node.
    """
    inner_sql_render_result = node.parent_node.accept(self)
    indented_inner_sql = indent(inner_sql_render_result.sql, indent_prefix=SqlRenderingConstants.INDENT)

    # All dynamic pieces are passed as substitution values instead of being formatted into the template text.
    # Otherwise, a `$` in the table name would be parsed by `Template` as a placeholder and raise an error.
    # Building the template from a single-line literal also avoids `textwrap.dedent()`, which does not behave as
    # needed when the multi-line inner SQL is part of the text being dedented.
    sql = Template("CREATE $table_type $table_name AS (\n$inner_sql\n)").substitute(
        table_type=node.sql_table.table_type.value.upper(),
        table_name=node.sql_table.sql,
        inner_sql=indented_inner_sql,
    )

    return SqlPlanRenderResult(
        sql=sql,
        bind_parameters=inner_sql_render_result.bind_parameters,
    )

@property
def expr_renderer(self) -> SqlExpressionRenderer: # noqa :D
    """Return the expression renderer stored in the `EXPR_RENDERER` attribute."""
    return self.EXPR_RENDERER
6 changes: 5 additions & 1 deletion metricflow/sql/sql_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> VisitorO
def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> VisitorOutputT: # noqa: D
raise NotImplementedError

@abstractmethod
def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> VisitorOutputT: # noqa: D
    """Visit a `SqlCreateTableAsNode`; abstract, so concrete visitors must provide the implementation."""
    raise NotImplementedError


@dataclass(frozen=True)
class SqlSelectColumn:
Expand Down Expand Up @@ -301,7 +305,7 @@ def __init__(self, sql_table: SqlTable, parent_node: SqlQueryPlanNode) -> None:

@override
def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT:
raise NotImplementedError
return visitor.visit_create_table_as_node(self)

@property
@override
Expand Down
55 changes: 54 additions & 1 deletion metricflow/test/sql/test_sql_plan_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@
SqlStringExpression,
)
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlJoinDescription,
SqlJoinType,
SqlOrderByDescription,
SqlSelectColumn,
SqlSelectStatementNode,
SqlTableFromClauseNode,
)
from metricflow.sql.sql_table import SqlTable
from metricflow.sql.sql_table import SqlTable, SqlTableType
from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState
from metricflow.test.sql.compare_sql_plan import assert_rendered_sql_equal

Expand Down Expand Up @@ -339,3 +340,55 @@ def test_render_limit( # noqa: D
plan_id="plan0",
sql_client=sql_client,
)


@pytest.mark.sql_engine_snapshot
def test_render_create_table_as(
    request: FixtureRequest,
    mf_test_session_state: MetricFlowTestSessionState,
    sql_client: SqlClient,
) -> None:
    """Check the rendered SQL when the same SELECT is wrapped to create a table and then to create a view."""
    bookings_column = SqlSelectColumn(
        expr=SqlColumnReferenceExpression(col_ref=SqlColumnReference(table_alias="a", column_name="bookings")),
        column_alias="bookings",
    )
    select_node = SqlSelectStatementNode(
        description="select_0",
        select_columns=(bookings_column,),
        from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")),
        from_source_alias="a",
        joins_descs=(),
        where=None,
        group_bys=(),
        order_bys=(),
        limit=1,
    )

    # Each case pairs the type of table to create with the plan ID used for that comparison.
    cases = (
        (SqlTableType.TABLE, "create_table_as"),
        (SqlTableType.VIEW, "create_view_as"),
    )
    for table_type, plan_id in cases:
        target_table = SqlTable(
            schema_name="schema_name",
            table_name="table_name",
            table_type=table_type,
        )
        assert_rendered_sql_equal(
            request=request,
            mf_test_session_state=mf_test_session_state,
            sql_plan_node=SqlCreateTableAsNode(sql_table=target_table, parent_node=select_node),
            plan_id=plan_id,
            sql_client=sql_client,
        )

0 comments on commit e140250

Please sign in to comment.