Skip to content

Commit

Permalink
/* PR_START p--cte 10 */ Support CTEs in table alias simplifier.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Nov 10, 2024
1 parent e870245 commit cb548a8
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 3 deletions.
11 changes: 9 additions & 2 deletions metricflow/sql/optimizer/table_alias_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ class SqlTableAliasSimplifierVisitor(SqlQueryPlanNodeVisitor[SqlQueryPlanNode]):

@override
def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
raise NotImplementedError
return node.with_new_select(node.select_statement.accept(self))

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102
# If there is only a single parent, no table aliases are required since there's no ambiguity.
should_simplify_table_aliases = len(node.parent_nodes) <= 1
should_simplify_table_aliases = len(node.join_descs) == 0

if should_simplify_table_aliases:
return SqlSelectStatementNode.create(
Expand All @@ -41,6 +41,10 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
),
from_source=node.from_source.accept(self),
from_source_alias=node.from_source_alias,
cte_sources=tuple(
cte_source.with_new_select(cte_source.select_statement.accept(self))
for cte_source in node.cte_sources
),
group_bys=tuple(
SqlSelectColumn(expr=x.expr.rewrite(should_render_table_alias=False), column_alias=x.column_alias)
for x in node.group_bys
Expand All @@ -59,6 +63,9 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
select_columns=node.select_columns,
from_source=node.from_source.accept(self),
from_source_alias=node.from_source_alias,
cte_sources=tuple(
cte_source.with_new_select(cte_source.select_statement.accept(self)) for cte_source in node.cte_sources
),
join_descs=tuple(
SqlJoinDescription(
right_source=x.right_source.accept(self),
Expand Down
46 changes: 46 additions & 0 deletions tests_metricflow/sql/optimizer/test_cte_table_alias_simplifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from __future__ import annotations

import pytest
from _pytest.fixtures import FixtureRequest
from metricflow_semantics.sql.sql_join_type import SqlJoinType
from metricflow_semantics.sql.sql_table import SqlTable
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration

from metricflow.sql.optimizer.table_alias_simplifier import SqlTableAliasSimplifier
from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer, SqlQueryPlanRenderer
from metricflow.sql.sql_exprs import (
SqlColumnReference,
SqlColumnReferenceExpression,
SqlComparison,
SqlComparisonExpression,
)
from metricflow.sql.sql_plan import (
SqlJoinDescription,
SqlSelectColumn,
SqlSelectStatementNode,
SqlTableNode,
)
from tests_metricflow.sql.compare_sql_plan import assert_default_rendered_sql_equal
from tests_metricflow.sql.optimizer.check_optimizer import assert_optimizer_result_snapshot_equal


@pytest.fixture
def sql_plan_renderer() -> DefaultSqlQueryPlanRenderer: # noqa: D103
return DefaultSqlQueryPlanRenderer()


def test_table_alias_simplification(
request: FixtureRequest,
mf_test_configuration: MetricFlowTestConfiguration,
sql_plan_renderer: DefaultSqlQueryPlanRenderer,
base_select_statement: SqlSelectStatementNode,
) -> None:
"""Tests that table aliases are removed when not needed in CTEs."""

assert_optimizer_result_snapshot_equal(
request=request,
mf_test_configuration=mf_test_configuration,
optimizer=SqlTableAliasSimplifier(),
sql_plan_renderer=sql_plan_renderer,
select_statement=base_select_statement,
)
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_table_alias_simplification(
mf_test_configuration: MetricFlowTestConfiguration,
base_select_statement: SqlSelectStatementNode,
) -> None:
"""Tests a case where no pruning should occur."""
"""Tests that table aliases are removed when not needed."""
assert_default_rendered_sql_equal(
request=request,
mf_test_configuration=mf_test_configuration,
Expand Down

0 comments on commit cb548a8

Please sign in to comment.