diff --git a/metricflow/sql/optimizer/table_alias_simplifier.py b/metricflow/sql/optimizer/table_alias_simplifier.py index 19c43dc7f..b503060fd 100644 --- a/metricflow/sql/optimizer/table_alias_simplifier.py +++ b/metricflow/sql/optimizer/table_alias_simplifier.py @@ -29,7 +29,7 @@ def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode: return node.with_new_select(node.select_statement.accept(self)) def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102 - # If there is only a single parent, no table aliases are required since there's no ambiguity. + # If there is only a single source in the SELECT, no table aliases are required since there's no ambiguity. should_simplify_table_aliases = len(node.join_descs) == 0 if should_simplify_table_aliases: diff --git a/tests_metricflow/sql/optimizer/test_cte_table_alias_simplifier.py b/tests_metricflow/sql/optimizer/test_cte_table_alias_simplifier.py index e2f3da342..f2794f5ec 100644 --- a/tests_metricflow/sql/optimizer/test_cte_table_alias_simplifier.py +++ b/tests_metricflow/sql/optimizer/test_cte_table_alias_simplifier.py @@ -34,7 +34,7 @@ def test_table_alias_simplification( mf_test_configuration: MetricFlowTestConfiguration, sql_plan_renderer: DefaultSqlQueryPlanRenderer, ) -> None: - """Tests that table aliases are removed when not needed in CTEs.""" + """Tests that table aliases in the SELECT statement of a CTE are removed when not needed.""" select_statement = SqlSelectStatementNode.create( description="Top-level SELECT", select_columns=( @@ -69,20 +69,6 @@ def test_table_alias_simplification( join_type=SqlJoinType.INNER, ), ), - group_bys=( - SqlSelectColumn( - expr=SqlColumnReferenceExpression.create( - col_ref=SqlColumnReference(table_alias="cte_source_0_alias", column_name="cte_source_0__col_0") - ), - column_alias="top_level__col_0", - ), - SqlSelectColumn( - expr=SqlColumnReferenceExpression.create( - col_ref=SqlColumnReference(table_alias="right_source_alias", column_name="right_source__col_1") - ), - column_alias="top_level__col_1", - ), - ), cte_sources=( SqlCteNode.create( cte_alias="cte_source_0", @@ -181,3 +167,69 @@ def test_table_alias_simplification( sql_plan_renderer=sql_plan_renderer, select_statement=select_statement, ) + + +def test_table_alias_no_simplification( + request: FixtureRequest, + mf_test_configuration: MetricFlowTestConfiguration, + sql_plan_renderer: DefaultSqlQueryPlanRenderer, +) -> None: + """Tests that table aliases in the SELECT statement of a CTE are not removed when required.""" + select_statement = SqlSelectStatementNode.create( + description="Top-level SELECT", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="cte_source_0_alias", column_name="cte_source_0__col_0") + ), + column_alias="top_level__col_0", + ), + ), + from_source=SqlTableNode.create(sql_table=SqlTable(schema_name=None, table_name="cte_source_0")), + from_source_alias="cte_source_0_alias", + cte_sources=( + SqlCteNode.create( + cte_alias="cte_source_0", + select_statement=SqlSelectStatementNode.create( + description="CTE source 0", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="from_source_alias", column_name="col_0") + ), + column_alias="cte_source_0__col_0", + ), + ), + from_source=SqlTableNode.create( + sql_table=SqlTable(schema_name="test_schema", table_name="test_table_0") + ), + from_source_alias="from_source_alias", + join_descs=( + SqlJoinDescription( + right_source=SqlTableNode.create( + sql_table=SqlTable(schema_name="test_schema", table_name="test_table_1") + ), + right_source_alias="right_source_alias", + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="from_source_alias", column_name="col_0") + ), + comparison=SqlComparison.EQUALS, + right_expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="right_source_alias", column_name="col_0") + ), + ), + join_type=SqlJoinType.INNER, + ), + ), + ), + ), + ), + ) + assert_optimizer_result_snapshot_equal( + request=request, + mf_test_configuration=mf_test_configuration, + optimizer=SqlTableAliasSimplifier(), + sql_plan_renderer=sql_plan_renderer, + select_statement=select_statement, + )