From aaf5de9772c39a1c6723091bd0b290eb96e5dc25 Mon Sep 17 00:00:00 2001
From: Paul Yang
Date: Tue, 10 Dec 2024 16:07:50 -0800
Subject: [PATCH] Add optimization level test.

Add a DuckDB-only snapshot test that calls `explain` once per SQL
optimization level above O3 and records the rendered SQL for each level.
The level under test is passed through the query request so that each
snapshot entry corresponds to the level it is keyed by.

Also pass the schema replacement function in `assert_str_snapshot_equal`
so that schema names in string snapshots are replaced before comparison.
---
 tests_metricflow/engine/test_explain.py | 47 +++++++++++++++++++++++++
 tests_metricflow/snapshot_utils.py      |  4 +++
 2 files changed, 51 insertions(+)

diff --git a/tests_metricflow/engine/test_explain.py b/tests_metricflow/engine/test_explain.py
index a456d2430b..59f32dcdf1 100644
--- a/tests_metricflow/engine/test_explain.py
+++ b/tests_metricflow/engine/test_explain.py
@@ -1,10 +1,20 @@
 from __future__ import annotations
 
+import logging
 from concurrent.futures import Future, ThreadPoolExecutor
 from typing import Mapping, Sequence
 
+import pytest
+from _pytest.fixtures import FixtureRequest
+from metricflow_semantics.mf_logging.pretty_print import mf_pformat_dict
+from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration
+
 from metricflow.engine.metricflow_engine import MetricFlowEngine, MetricFlowExplainResult, MetricFlowQueryRequest
+from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel
 from tests_metricflow.fixtures.manifest_fixtures import MetricFlowEngineTestFixture, SemanticManifestSetup
+from tests_metricflow.snapshot_utils import assert_str_snapshot_equal
+
+logger = logging.getLogger(__name__)
 
 
 def _explain_one_query(mf_engine: MetricFlowEngine) -> str:
@@ -29,3 +39,40 @@ def test_concurrent_explain_consistency(
         results = [future.result() for future in futures]
         for result in results:
             assert result == results[0], "Expected only one unique result / results to be the same"
+
+
+@pytest.mark.sql_engine_snapshot
+@pytest.mark.duckdb_only
+def test_optimization_level(
+    request: FixtureRequest,
+    mf_test_configuration: MetricFlowTestConfiguration,
+    mf_engine_test_fixture_mapping: Mapping[SemanticManifestSetup, MetricFlowEngineTestFixture],
+) -> None:
+    """Tests that the results of explain reflect the SQL optimization level in the request."""
+    mf_engine = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].metricflow_engine
+
+    results = {}
+    for optimization_level in SqlQueryOptimizationLevel:
+        # Skip lower optimization levels as they are generally not used.
+        if optimization_level <= SqlQueryOptimizationLevel.O3:
+            continue
+
+        explain_result: MetricFlowExplainResult = mf_engine.explain(
+            MetricFlowQueryRequest.create_with_random_request_id(
+                metric_names=("bookings", "views"),
+                group_by_names=("metric_time", "listing__country_latest"),
+                # Without this, every iteration would build the same request regardless of the loop variable.
+                sql_optimization_level=optimization_level,
+            )
+        )
+        results[optimization_level.value] = explain_result.rendered_sql_without_descriptions.sql_query
+
+    assert_str_snapshot_equal(
+        request=request,
+        mf_test_configuration=mf_test_configuration,
+        snapshot_id="result",
+        snapshot_str=mf_pformat_dict(
+            description=None, obj_dict=results, preserve_raw_strings=True, pad_items_with_newlines=True
+        ),
+        expectation_description=f"The result for {SqlQueryOptimizationLevel.O5} should be SQL uses a CTE.",
+    )
diff --git a/tests_metricflow/snapshot_utils.py b/tests_metricflow/snapshot_utils.py
index 4f0f674e5d..5237fa8c94 100644
--- a/tests_metricflow/snapshot_utils.py
+++ b/tests_metricflow/snapshot_utils.py
@@ -141,4 +141,8 @@ def assert_str_snapshot_equal(  # type: ignore[misc]
         snapshot_file_extension=".txt",
         additional_sub_directories_for_snapshots=(sql_engine.value,) if sql_engine is not None else (),
         expectation_description=expectation_description,
+        incomparable_strings_replacement_function=make_schema_replacement_function(
+            system_schema=mf_test_configuration.mf_system_schema,
+            source_schema=mf_test_configuration.mf_source_schema,
+        ),
     )