From 05e574783c0d43af21b11bac0291ff0a3aa8c090 Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Mon, 4 Nov 2024 08:45:04 -0800 Subject: [PATCH] Add tests for column pruning CTEs. --- .../str/test_multi_child_pruning__result.txt | 50 +++ .../str/test_nested_pruning__result.txt | 44 +++ .../str/test_no_pruning__result.txt | 28 ++ .../str/test_simple_pruning__result.txt | 29 ++ .../sql/optimizer/check_optimizer.py | 61 ++++ .../sql/optimizer/test_cte_column_pruner.py | 330 ++++++++++++++++++ 6 files changed, 542 insertions(+) create mode 100644 tests_metricflow/snapshots/test_cte_column_pruner.py/str/test_multi_child_pruning__result.txt create mode 100644 tests_metricflow/snapshots/test_cte_column_pruner.py/str/test_nested_pruning__result.txt create mode 100644 tests_metricflow/snapshots/test_cte_column_pruner.py/str/test_no_pruning__result.txt create mode 100644 tests_metricflow/snapshots/test_cte_column_pruner.py/str/test_simple_pruning__result.txt create mode 100644 tests_metricflow/sql/optimizer/check_optimizer.py create mode 100644 tests_metricflow/sql/optimizer/test_cte_column_pruner.py diff --git a/tests_metricflow/snapshots/test_cte_column_pruner.py/str/test_multi_child_pruning__result.txt b/tests_metricflow/snapshots/test_cte_column_pruner.py/str/test_multi_child_pruning__result.txt new file mode 100644 index 000000000..2f9356034 --- /dev/null +++ b/tests_metricflow/snapshots/test_cte_column_pruner.py/str/test_multi_child_pruning__result.txt @@ -0,0 +1,50 @@ +optimizer: + SqlColumnPrunerOptimizer + +sql_before_optimizing: + -- Top-level SELECT + WITH cte_source_0 AS ( + -- CTE source 0 + SELECT + test_table_alias.col_0 AS cte_source_0__col_0 + , test_table_alias.col_1 AS cte_source_0__col_1 + , test_table_alias.col_1 AS cte_source_0__col_2 + FROM test_schema.test_table test_table_alias + ) + + SELECT + cte_source_0_alias.cte_source_0__col_0 AS top_level__col_0 + , right_source_alias.right_source__col_1 AS top_level__col_1 + FROM cte_source_0 cte_source_0_alias + INNER JOIN ( + -- Joined sub-query + SELECT + cte_source_0_alias_in_right_source.cte_source_0__col_0 AS right_source__col_0 + , cte_source_0_alias_in_right_source.cte_source_0__col_1 AS right_source__col_1 + FROM cte_source_0 cte_source_0_alias_in_right_source + ) right_source_alias + ON + cte_source_0_alias.cte_source_0__col_1 = right_source_alias.right_source__col_1 + +sql_after_optimizing: + -- Top-level SELECT + WITH cte_source_0 AS ( + -- CTE source 0 + SELECT + test_table_alias.col_0 AS cte_source_0__col_0 + , test_table_alias.col_1 AS cte_source_0__col_1 + FROM test_schema.test_table test_table_alias + ) + + SELECT + cte_source_0_alias.cte_source_0__col_0 AS top_level__col_0 + , right_source_alias.right_source__col_1 AS top_level__col_1 + FROM cte_source_0 cte_source_0_alias + INNER JOIN ( + -- Joined sub-query + SELECT + cte_source_0_alias_in_right_source.cte_source_0__col_1 AS right_source__col_1 + FROM cte_source_0 cte_source_0_alias_in_right_source + ) right_source_alias + ON + cte_source_0_alias.cte_source_0__col_1 = right_source_alias.right_source__col_1 diff --git a/tests_metricflow/snapshots/test_cte_column_pruner.py/str/test_nested_pruning__result.txt b/tests_metricflow/snapshots/test_cte_column_pruner.py/str/test_nested_pruning__result.txt new file mode 100644 index 000000000..ec3325a15 --- /dev/null +++ b/tests_metricflow/snapshots/test_cte_column_pruner.py/str/test_nested_pruning__result.txt @@ -0,0 +1,44 @@ +optimizer: + SqlColumnPrunerOptimizer + +sql_before_optimizing: + -- Top-level SELECT + WITH cte_source_0 AS ( + -- CTE source 0 + SELECT + test_table_alias.col_0 AS cte_source_0__col_0 + , test_table_alias.col_1 AS cte_source_0__col_1 + FROM test_schema.test_table test_table_alias + ) + + , cte_source_1 AS ( + -- CTE source 1 + SELECT + cte_source_0_alias.cte_source_0__col_0 AS cte_source_1__col_0 + , cte_source_0_alias.cte_source_0__col_0 AS cte_source_1__col_1 + FROM cte_source_0 cte_source_0_alias + ) + + SELECT + cte_source_1_alias.cte_source_1__col_0 AS top_level__col_0 + FROM cte_source_1 cte_source_1_alias + +sql_after_optimizing: + -- Top-level SELECT + WITH cte_source_0 AS ( + -- CTE source 0 + SELECT + test_table_alias.col_0 AS cte_source_0__col_0 + FROM test_schema.test_table test_table_alias + ) + + , cte_source_1 AS ( + -- CTE source 1 + SELECT + cte_source_0_alias.cte_source_0__col_0 AS cte_source_1__col_0 + FROM cte_source_0 cte_source_0_alias + ) + + SELECT + cte_source_1_alias.cte_source_1__col_0 AS top_level__col_0 + FROM cte_source_1 cte_source_1_alias diff --git a/tests_metricflow/snapshots/test_cte_column_pruner.py/str/test_no_pruning__result.txt b/tests_metricflow/snapshots/test_cte_column_pruner.py/str/test_no_pruning__result.txt new file mode 100644 index 000000000..1d4c73378 --- /dev/null +++ b/tests_metricflow/snapshots/test_cte_column_pruner.py/str/test_no_pruning__result.txt @@ -0,0 +1,28 @@ +optimizer: + SqlColumnPrunerOptimizer + +sql_before_optimizing: + -- Top-level SELECT + WITH cte_source_0 AS ( + -- CTE source 0 + SELECT + test_table_alias.col_0 AS cte_source_0__col_0 + FROM test_schema.test_table test_table_alias + ) + + SELECT + cte_source_0_alias.cte_source_0__col_0 AS top_level__col_0 + FROM cte_source_0 cte_source_0_alias + +sql_after_optimizing: + -- Top-level SELECT + WITH cte_source_0 AS ( + -- CTE source 0 + SELECT + test_table_alias.col_0 AS cte_source_0__col_0 + FROM test_schema.test_table test_table_alias + ) + + SELECT + cte_source_0_alias.cte_source_0__col_0 AS top_level__col_0 + FROM cte_source_0 cte_source_0_alias diff --git a/tests_metricflow/snapshots/test_cte_column_pruner.py/str/test_simple_pruning__result.txt b/tests_metricflow/snapshots/test_cte_column_pruner.py/str/test_simple_pruning__result.txt new file mode 100644 index 000000000..a27c7df2b --- /dev/null +++ b/tests_metricflow/snapshots/test_cte_column_pruner.py/str/test_simple_pruning__result.txt @@ -0,0 +1,29 @@ +optimizer: + SqlColumnPrunerOptimizer + +sql_before_optimizing: + -- Top-level SELECT + WITH cte_source_0 AS ( + -- CTE source 0 + SELECT + test_table_alias.col_0 AS cte_source_0__col_0 + , test_table_alias.col_0 AS cte_source_0__col_1 + FROM test_schema.test_table test_table_alias + ) + + SELECT + cte_source_0_alias.cte_source_0__col_0 AS top_level__col_0 + FROM cte_source_0 cte_source_0_alias + +sql_after_optimizing: + -- Top-level SELECT + WITH cte_source_0 AS ( + -- CTE source 0 + SELECT + test_table_alias.col_0 AS cte_source_0__col_0 + FROM test_schema.test_table test_table_alias + ) + + SELECT + cte_source_0_alias.cte_source_0__col_0 AS top_level__col_0 + FROM cte_source_0 cte_source_0_alias diff --git a/tests_metricflow/sql/optimizer/check_optimizer.py b/tests_metricflow/sql/optimizer/check_optimizer.py new file mode 100644 index 000000000..7ba642f3e --- /dev/null +++ b/tests_metricflow/sql/optimizer/check_optimizer.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import logging + +from _pytest.fixtures import FixtureRequest +from metricflow_semantics.mf_logging.formatting import indent +from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat +from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration +from metricflow_semantics.test_helpers.snapshot_helpers import assert_str_snapshot_equal + +from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer +from metricflow.sql.render.sql_plan_renderer import SqlQueryPlanRenderer +from metricflow.sql.sql_plan import SqlQueryPlan, SqlSelectStatementNode + +logger = logging.getLogger(__name__) + + +def assert_optimizer_result_snapshot_equal( + request: FixtureRequest, + mf_test_configuration: MetricFlowTestConfiguration, + optimizer: SqlQueryPlanOptimizer, + sql_plan_renderer: SqlQueryPlanRenderer, + select_statement: SqlSelectStatementNode, +) -> None: + """Helper to assert that the SQL snapshot of the optimizer result is the same as the stored one.""" + sql_before_optimizing = sql_plan_renderer.render_sql_query_plan(SqlQueryPlan(select_statement)).sql + logger.debug( + LazyFormat( + "Optimizing SELECT statement", + select_statement=select_statement.structure_text(), + sql_before_optimizing=sql_before_optimizing, + ) + ) + + column_pruned_select_node = optimizer.optimize(select_statement) + sql_after_optimizing = sql_plan_renderer.render_sql_query_plan(SqlQueryPlan(column_pruned_select_node)).sql + logger.debug( + LazyFormat( + "Optimized SQL", + sql_before_optimizing=sql_before_optimizing, + sql_after_optimizing=sql_after_optimizing, + ) + ) + snapshot_str = "\n".join( + [ + "optimizer:", + indent(optimizer.__class__.__name__), + "", + "sql_before_optimizing:", + indent(sql_before_optimizing), + "", + "sql_after_optimizing:", + indent(sql_after_optimizing), + ] + ) + assert_str_snapshot_equal( + request=request, + mf_test_configuration=mf_test_configuration, + snapshot_id="result", + snapshot_str=snapshot_str, + ) diff --git a/tests_metricflow/sql/optimizer/test_cte_column_pruner.py b/tests_metricflow/sql/optimizer/test_cte_column_pruner.py new file mode 100644 index 000000000..4b19d46e6 --- /dev/null +++ b/tests_metricflow/sql/optimizer/test_cte_column_pruner.py @@ -0,0 +1,330 @@ +from __future__ import annotations + +import logging + +import pytest +from _pytest.fixtures import FixtureRequest +from metricflow_semantics.sql.sql_join_type import SqlJoinType +from metricflow_semantics.sql.sql_table import SqlTable +from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration + +from metricflow.sql.optimizer.column_pruner import SqlColumnPrunerOptimizer +from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer, SqlQueryPlanRenderer +from metricflow.sql.sql_exprs import ( + SqlColumnReference, + SqlColumnReferenceExpression, + SqlComparison, + SqlComparisonExpression, +) +from metricflow.sql.sql_plan import ( + SqlCteNode, + SqlJoinDescription, + SqlSelectColumn, + SqlSelectStatementNode, + SqlTableNode, +) +from tests_metricflow.sql.optimizer.check_optimizer import assert_optimizer_result_snapshot_equal + +logger = logging.getLogger(__name__) + + +@pytest.fixture +def column_pruner() -> SqlColumnPrunerOptimizer: # noqa: D103 + return SqlColumnPrunerOptimizer() + + +@pytest.fixture +def sql_plan_renderer() -> SqlQueryPlanRenderer: # noqa: D103 + return DefaultSqlQueryPlanRenderer() + + +def test_no_pruning( + request: FixtureRequest, + mf_test_configuration: MetricFlowTestConfiguration, + column_pruner: SqlColumnPrunerOptimizer, + sql_plan_renderer: DefaultSqlQueryPlanRenderer, +) -> None: + """Tests a case where no pruning should occur for a CTE.""" + select_statement = SqlSelectStatementNode.create( + description="Top-level SELECT", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="cte_source_0_alias", column_name="cte_source_0__col_0") + ), + column_alias="top_level__col_0", + ), + ), + from_source=SqlTableNode.create(sql_table=SqlTable(schema_name=None, table_name="cte_source_0")), + from_source_alias="cte_source_0_alias", + cte_sources=( + SqlCteNode.create( + cte_alias="cte_source_0", + select_statement=SqlSelectStatementNode.create( + description="CTE source 0", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0") + ), + column_alias="cte_source_0__col_0", + ), + ), + from_source=SqlTableNode.create( + sql_table=SqlTable(schema_name="test_schema", table_name="test_table") + ), + from_source_alias="test_table_alias", + ), + ), + ), + ) + assert_optimizer_result_snapshot_equal( + request=request, + mf_test_configuration=mf_test_configuration, + optimizer=column_pruner, + sql_plan_renderer=sql_plan_renderer, + select_statement=select_statement, + ) + + +def test_simple_pruning( + request: FixtureRequest, + mf_test_configuration: MetricFlowTestConfiguration, + column_pruner: SqlColumnPrunerOptimizer, + sql_plan_renderer: DefaultSqlQueryPlanRenderer, +) -> None: + """Tests the simplest case of pruning a CTE where a query depends on a CTE, and that CTE is pruned.""" + select_statement = SqlSelectStatementNode.create( + description="Top-level SELECT", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="cte_source_0_alias", column_name="cte_source_0__col_0") + ), + column_alias="top_level__col_0", + ), + ), + from_source=SqlTableNode.create(sql_table=SqlTable(schema_name=None, table_name="cte_source_0")), + from_source_alias="cte_source_0_alias", + cte_sources=( + SqlCteNode.create( + cte_alias="cte_source_0", + select_statement=SqlSelectStatementNode.create( + description="CTE source 0", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0") + ), + column_alias="cte_source_0__col_0", + ), + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0") + ), + column_alias="cte_source_0__col_1", + ), + ), + from_source=SqlTableNode.create( + sql_table=SqlTable(schema_name="test_schema", table_name="test_table") + ), + from_source_alias="test_table_alias", + ), + ), + ), + ) + assert_optimizer_result_snapshot_equal( + request=request, + mf_test_configuration=mf_test_configuration, + optimizer=column_pruner, + sql_plan_renderer=sql_plan_renderer, + select_statement=select_statement, + ) + + +def test_nested_pruning( + request: FixtureRequest, + mf_test_configuration: MetricFlowTestConfiguration, + column_pruner: SqlColumnPrunerOptimizer, + sql_plan_renderer: DefaultSqlQueryPlanRenderer, +) -> None: + """Tests the case of pruning a CTE where a query depends on a CTE, and that CTE depends on another CTE.""" + select_statement = SqlSelectStatementNode.create( + description="Top-level SELECT", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="cte_source_1_alias", column_name="cte_source_1__col_0") + ), + column_alias="top_level__col_0", + ), + ), + from_source=SqlTableNode.create(sql_table=SqlTable(schema_name=None, table_name="cte_source_1")), + from_source_alias="cte_source_1_alias", + cte_sources=( + SqlCteNode.create( + cte_alias="cte_source_0", + select_statement=SqlSelectStatementNode.create( + description="CTE source 0", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0") + ), + column_alias="cte_source_0__col_0", + ), + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_1") + ), + column_alias="cte_source_0__col_1", + ), + ), + from_source=SqlTableNode.create( + sql_table=SqlTable(schema_name="test_schema", table_name="test_table") + ), + from_source_alias="test_table_alias", + ), + ), + SqlCteNode.create( + cte_alias="cte_source_1", + select_statement=SqlSelectStatementNode.create( + description="CTE source 1", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference( + table_alias="cte_source_0_alias", column_name="cte_source_0__col_0" + ) + ), + column_alias="cte_source_1__col_0", + ), + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference( + table_alias="cte_source_0_alias", column_name="cte_source_0__col_0" + ) + ), + column_alias="cte_source_1__col_1", + ), + ), + from_source=SqlTableNode.create(sql_table=SqlTable(schema_name=None, table_name="cte_source_0")), + from_source_alias="cte_source_0_alias", + ), + ), + ), + ) + + assert_optimizer_result_snapshot_equal( + request=request, + mf_test_configuration=mf_test_configuration, + optimizer=column_pruner, + sql_plan_renderer=sql_plan_renderer, + select_statement=select_statement, + ) + + +def test_multi_child_pruning( + request: FixtureRequest, + mf_test_configuration: MetricFlowTestConfiguration, + column_pruner: SqlColumnPrunerOptimizer, + sql_plan_renderer: DefaultSqlQueryPlanRenderer, +) -> None: + """Tests the case of pruning a CTE where difference sources depend on the same CTE.""" + select_statement = SqlSelectStatementNode.create( + description="Top-level SELECT", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="cte_source_0_alias", column_name="cte_source_0__col_0") + ), + column_alias="top_level__col_0", + ), + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="right_source_alias", column_name="right_source__col_1") + ), + column_alias="top_level__col_1", + ), + ), + from_source=SqlTableNode.create(sql_table=SqlTable(schema_name=None, table_name="cte_source_0")), + from_source_alias="cte_source_0_alias", + join_descs=( + SqlJoinDescription( + right_source=SqlSelectStatementNode.create( + description="Joined sub-query", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference( + table_alias="cte_source_0_alias_in_right_source", column_name="cte_source_0__col_0" + ) + ), + column_alias="right_source__col_0", + ), + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference( + table_alias="cte_source_0_alias_in_right_source", column_name="cte_source_0__col_1" + ) + ), + column_alias="right_source__col_1", + ), + ), + from_source=SqlTableNode.create(sql_table=SqlTable(schema_name=None, table_name="cte_source_0")), + from_source_alias="cte_source_0_alias_in_right_source", + ), + right_source_alias="right_source_alias", + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="cte_source_0_alias", column_name="cte_source_0__col_1") + ), + comparison=SqlComparison.EQUALS, + right_expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="right_source_alias", column_name="right_source__col_1") + ), + ), + join_type=SqlJoinType.INNER, + ), + ), + cte_sources=( + SqlCteNode.create( + cte_alias="cte_source_0", + select_statement=SqlSelectStatementNode.create( + description="CTE source 0", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0") + ), + column_alias="cte_source_0__col_0", + ), + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_1") + ), + column_alias="cte_source_0__col_1", + ), + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_1") + ), + column_alias="cte_source_0__col_2", + ), + ), + from_source=SqlTableNode.create( + sql_table=SqlTable(schema_name="test_schema", table_name="test_table") + ), + from_source_alias="test_table_alias", + ), + ), + ), + ) + + assert_optimizer_result_snapshot_equal( + request=request, + mf_test_configuration=mf_test_configuration, + optimizer=column_pruner, + sql_plan_renderer=sql_plan_renderer, + select_statement=select_statement, + )