Skip to content

Commit

Permalink
Add optimization level test.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Dec 11, 2024
1 parent 505c876 commit aaf5de9
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
45 changes: 45 additions & 0 deletions tests_metricflow/engine/test_explain.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
from __future__ import annotations

import logging
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Mapping, Sequence

import pytest
from _pytest.fixtures import FixtureRequest
from metricflow_semantics.mf_logging.pretty_print import mf_pformat_dict
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration

from metricflow.engine.metricflow_engine import MetricFlowEngine, MetricFlowExplainResult, MetricFlowQueryRequest
from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel
from tests_metricflow.fixtures.manifest_fixtures import MetricFlowEngineTestFixture, SemanticManifestSetup
from tests_metricflow.snapshot_utils import assert_str_snapshot_equal

logger = logging.getLogger(__name__)


def _explain_one_query(mf_engine: MetricFlowEngine) -> str:
Expand All @@ -29,3 +39,38 @@ def test_concurrent_explain_consistency(
results = [future.result() for future in futures]
for result in results:
assert result == results[0], "Expected only one unique result / results to be the same"


@pytest.mark.sql_engine_snapshot
@pytest.mark.duckdb_only
def test_optimization_level(
request: FixtureRequest,
mf_test_configuration: MetricFlowTestConfiguration,
mf_engine_test_fixture_mapping: Mapping[SemanticManifestSetup, MetricFlowEngineTestFixture],
) -> None:
"""Tests that the results of explain reflect the SQL optimization level in the request."""
mf_engine = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].metricflow_engine

results = {}
for optimization_level in SqlQueryOptimizationLevel:
# Skip lower optimization levels as they are generally not used.
if optimization_level <= SqlQueryOptimizationLevel.O3:
continue

explain_result: MetricFlowExplainResult = mf_engine.explain(
MetricFlowQueryRequest.create_with_random_request_id(
metric_names=("bookings", "views"),
group_by_names=("metric_time", "listing__country_latest"),
)
)
results[optimization_level.value] = explain_result.rendered_sql_without_descriptions.sql_query

assert_str_snapshot_equal(
request=request,
mf_test_configuration=mf_test_configuration,
snapshot_id="result",
snapshot_str=mf_pformat_dict(
description=None, obj_dict=results, preserve_raw_strings=True, pad_items_with_newlines=True
),
expectation_description=f"The result for {SqlQueryOptimizationLevel.O5} should be SQL uses a CTE.",
)
4 changes: 4 additions & 0 deletions tests_metricflow/snapshot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,8 @@ def assert_str_snapshot_equal( # type: ignore[misc]
snapshot_file_extension=".txt",
additional_sub_directories_for_snapshots=(sql_engine.value,) if sql_engine is not None else (),
expectation_description=expectation_description,
incomparable_strings_replacement_function=make_schema_replacement_function(
system_schema=mf_test_configuration.mf_system_schema,
source_schema=mf_test_configuration.mf_source_schema,
),
)

0 comments on commit aaf5de9

Please sign in to comment.