Skip to content

Commit

Permalink
Add tests for CTEs in table alias simplifier.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Nov 10, 2024
1 parent cb548a8 commit 60412e5
Showing 1 changed file with 142 additions and 5 deletions.
147 changes: 142 additions & 5 deletions tests_metricflow/sql/optimizer/test_cte_table_alias_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration

from metricflow.sql.optimizer.table_alias_simplifier import SqlTableAliasSimplifier
from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer, SqlQueryPlanRenderer
from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer
from metricflow.sql.sql_exprs import (
SqlColumnReference,
SqlColumnReferenceExpression,
SqlComparison,
SqlComparisonExpression,
)
from metricflow.sql.sql_plan import (
SqlCteNode,
SqlJoinDescription,
SqlSelectColumn,
SqlSelectStatementNode,
SqlTableNode,
)
from tests_metricflow.sql.compare_sql_plan import assert_default_rendered_sql_equal
from tests_metricflow.sql.optimizer.check_optimizer import assert_optimizer_result_snapshot_equal


Expand All @@ -33,14 +33,151 @@ def test_table_alias_simplification(
request: FixtureRequest,
mf_test_configuration: MetricFlowTestConfiguration,
sql_plan_renderer: DefaultSqlQueryPlanRenderer,
base_select_statement: SqlSelectStatementNode,
) -> None:
"""Tests that table aliases are removed when not needed in CTEs."""

select_statement = SqlSelectStatementNode.create(
description="Top-level SELECT",
select_columns=(
SqlSelectColumn(
expr=SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(table_alias="cte_source_0_alias", column_name="cte_source_0__col_0")
),
column_alias="top_level__col_0",
),
SqlSelectColumn(
expr=SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(table_alias="right_source_alias", column_name="right_source__col_1")
),
column_alias="top_level__col_1",
),
),
from_source=SqlTableNode.create(sql_table=SqlTable(schema_name=None, table_name="cte_source_0")),
from_source_alias="cte_source_0_alias",
join_descs=(
SqlJoinDescription(
right_source=SqlTableNode.create(sql_table=SqlTable(schema_name=None, table_name="cte_source_1")),
right_source_alias="right_source_alias",
on_condition=SqlComparisonExpression.create(
left_expr=SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(table_alias="cte_source_0_alias", column_name="cte_source_0__col_1")
),
comparison=SqlComparison.EQUALS,
right_expr=SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(table_alias="right_source_alias", column_name="right_source__col_1")
),
),
join_type=SqlJoinType.INNER,
),
),
group_bys=(
SqlSelectColumn(
expr=SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(table_alias="cte_source_0_alias", column_name="cte_source_0__col_0")
),
column_alias="top_level__col_0",
),
SqlSelectColumn(
expr=SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(table_alias="right_source_alias", column_name="right_source__col_1")
),
column_alias="top_level__col_1",
),
),
cte_sources=(
SqlCteNode.create(
cte_alias="cte_source_0",
select_statement=SqlSelectStatementNode.create(
description="CTE source 0",
select_columns=(
SqlSelectColumn(
expr=SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0")
),
column_alias="cte_source_0__col_0",
),
SqlSelectColumn(
expr=SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_1")
),
column_alias="cte_source_0__col_1",
),
),
from_source=SqlSelectStatementNode.create(
description="CTE source 0 sub-query",
select_columns=(
SqlSelectColumn(
expr=SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0")
),
column_alias="cte_source_0_subquery__col_0",
),
SqlSelectColumn(
expr=SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0")
),
column_alias="cte_source_0_subquery__col_1",
),
),
from_source=SqlTableNode.create(
sql_table=SqlTable(schema_name="test_schema", table_name="test_table")
),
from_source_alias="test_table_alias",
),
from_source_alias="cte_source_0_subquery",
),
),
SqlCteNode.create(
cte_alias="cte_source_1",
select_statement=SqlSelectStatementNode.create(
description="CTE source 1",
select_columns=(
SqlSelectColumn(
expr=SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0")
),
column_alias="cte_source_1__col_0",
),
SqlSelectColumn(
expr=SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_1")
),
column_alias="cte_source_1__col_1",
),
),
from_source=SqlSelectStatementNode.create(
description="CTE source 1 sub-query",
select_columns=(
SqlSelectColumn(
expr=SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(
table_alias="cte_source_0_alias", column_name="cte_source_0__col_0"
)
),
column_alias="cte_source_1_subquery__col_0",
),
SqlSelectColumn(
expr=SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(
table_alias="cte_source_0_alias", column_name="cte_source_0__col_0"
)
),
column_alias="cte_source_1_subquery__col_1",
),
),
from_source=SqlTableNode.create(
sql_table=SqlTable(schema_name=None, table_name="cte_source_0")
),
from_source_alias="cte_source_0_alias",
),
from_source_alias="cte_source_1_subquery",
),
),
),
)
assert_optimizer_result_snapshot_equal(
request=request,
mf_test_configuration=mf_test_configuration,
optimizer=SqlTableAliasSimplifier(),
sql_plan_renderer=sql_plan_renderer,
select_statement=base_select_statement,
select_statement=select_statement,
)

0 comments on commit 60412e5

Please sign in to comment.