Skip to content

Commit

Permalink
Add required alias case.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Nov 13, 2024
1 parent fb3b1d3 commit 8675e15
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 16 deletions.
2 changes: 1 addition & 1 deletion metricflow/sql/optimizer/table_alias_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
return node.with_new_select(node.select_statement.accept(self))

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102
# If there is only a single parent, no table aliases are required since there's no ambiguity.
# If there is only a single source in the SELECT, no table aliases are required since there's no ambiguity.
should_simplify_table_aliases = len(node.join_descs) == 0

if should_simplify_table_aliases:
Expand Down
82 changes: 67 additions & 15 deletions tests_metricflow/sql/optimizer/test_cte_table_alias_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_table_alias_simplification(
mf_test_configuration: MetricFlowTestConfiguration,
sql_plan_renderer: DefaultSqlQueryPlanRenderer,
) -> None:
"""Tests that table aliases are removed when not needed in CTEs."""
"""Tests that table aliases in the SELECT statement of a CTE are removed when not needed."""
select_statement = SqlSelectStatementNode.create(
description="Top-level SELECT",
select_columns=(
Expand Down Expand Up @@ -69,20 +69,6 @@ def test_table_alias_simplification(
join_type=SqlJoinType.INNER,
),
),
group_bys=(
SqlSelectColumn(
expr=SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(table_alias="cte_source_0_alias", column_name="cte_source_0__col_0")
),
column_alias="top_level__col_0",
),
SqlSelectColumn(
expr=SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(table_alias="right_source_alias", column_name="right_source__col_1")
),
column_alias="top_level__col_1",
),
),
cte_sources=(
SqlCteNode.create(
cte_alias="cte_source_0",
Expand Down Expand Up @@ -181,3 +167,69 @@ def test_table_alias_simplification(
sql_plan_renderer=sql_plan_renderer,
select_statement=select_statement,
)


def test_table_alias_no_simplification(
request: FixtureRequest,
mf_test_configuration: MetricFlowTestConfiguration,
sql_plan_renderer: DefaultSqlQueryPlanRenderer,
) -> None:
"""Tests that table aliases in the SELECT statement of a CTE are not removed when required."""
select_statement = SqlSelectStatementNode.create(
description="Top-level SELECT",
select_columns=(
SqlSelectColumn(
expr=SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(table_alias="cte_source_0_alias", column_name="cte_source_0__col_0")
),
column_alias="top_level__col_0",
),
),
from_source=SqlTableNode.create(sql_table=SqlTable(schema_name=None, table_name="cte_source_0")),
from_source_alias="cte_source_0_alias",
cte_sources=(
SqlCteNode.create(
cte_alias="cte_source_0",
select_statement=SqlSelectStatementNode.create(
description="CTE source 0",
select_columns=(
SqlSelectColumn(
expr=SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(table_alias="from_source_alias", column_name="col_0")
),
column_alias="cte_source_0__col_0",
),
),
from_source=SqlTableNode.create(
sql_table=SqlTable(schema_name="test_schema", table_name="test_table_0")
),
from_source_alias="from_source_alias",
join_descs=(
SqlJoinDescription(
right_source=SqlTableNode.create(
sql_table=SqlTable(schema_name="test_schema", table_name="test_table_1")
),
right_source_alias="right_source_alias",
on_condition=SqlComparisonExpression.create(
left_expr=SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(table_alias="from_source_alias", column_name="col_0")
),
comparison=SqlComparison.EQUALS,
right_expr=SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(table_alias="right_source_alias", column_name="col_0")
),
),
join_type=SqlJoinType.INNER,
),
),
),
),
),
)
assert_optimizer_result_snapshot_equal(
request=request,
mf_test_configuration=mf_test_configuration,
optimizer=SqlTableAliasSimplifier(),
sql_plan_renderer=sql_plan_renderer,
select_statement=select_statement,
)

0 comments on commit 8675e15

Please sign in to comment.