Add improved test for the SQL optimization level in the request #1566

Merged: 5 commits, Dec 12, 2024
@@ -133,7 +133,7 @@ def __init__(self, left_branch_node: DataflowPlanNode) -> None:  # noqa: D107
self._current_left_node: DataflowPlanNode = left_branch_node

def _log_visit_node_type(self, node: DataflowPlanNode) -> None:
-logger.debug(lambda: f"Visiting {node.node_id}")
+logger.debug(LazyFormat(lambda: f"Visiting {node.node_id}"))

def _log_combine_failure(
self,
@@ -280,7 +280,14 @@ def visit_combine_aggregated_outputs_node(  # noqa: D102
for branch_combination_result in combination_results
]

logger.debug(lambda: f"Got {len(combined_parent_branches)} branches after combination")
logger.debug(
LazyFormat(
"Possible branches combined.",
count_of_branches_before_combination=len(optimized_parent_branches),
count_of_branches_after_combination=len(combined_parent_branches),
)
)

assert len(combined_parent_branches) > 0

# If we were able to reduce the parent branches of the CombineAggregatedOutputsNode into a single one, there's
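The two hunks above replace eagerly formatted debug messages with the LazyFormat wrapper from metricflow_semantics, so the message (and any keyword arguments) is only rendered if the DEBUG record is actually emitted. A minimal sketch of that lazy-logging idea, using a hypothetical LazyMessage stand-in rather than the real LazyFormat class:

```python
import logging
from typing import Callable, Union


class LazyMessage:
    """Hypothetical stand-in for LazyFormat: defers string building until a handler emits the record."""

    def __init__(self, message: Union[str, Callable[[], str]], **kwargs: object) -> None:
        self._message = message
        self._kwargs = kwargs

    def __str__(self) -> str:
        # Called by the logging machinery only when the record passes the level
        # filter, so the lambda call / formatting cost is skipped when DEBUG is off.
        message = self._message() if callable(self._message) else self._message
        return f"{message} {self._kwargs}" if self._kwargs else message


logger = logging.getLogger(__name__)
logger.debug(LazyMessage(lambda: "Visiting node_0"))
logger.debug(LazyMessage("Possible branches combined.", count_of_branches_after_combination=3))
```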
8 changes: 8 additions & 0 deletions metricflow/sql/optimizer/optimization_levels.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import functools
from dataclasses import dataclass
from enum import Enum
from typing import Tuple
@@ -13,6 +14,7 @@
from metricflow.sql.optimizer.table_alias_simplifier import SqlTableAliasSimplifier


@functools.total_ordering
Contributor (review comment): wow, didn't know about this

class SqlQueryOptimizationLevel(Enum):
"""Defines the level of query optimization and the associated optimizers to apply."""

@@ -27,6 +29,12 @@ def default_level() -> SqlQueryOptimizationLevel:  # noqa: D102
def default_level() -> SqlQueryOptimizationLevel: # noqa: D102
return SqlQueryOptimizationLevel.O5

def __lt__(self, other: SqlQueryOptimizationLevel) -> bool: # noqa: D105
if not isinstance(other, SqlQueryOptimizationLevel):
return NotImplemented

return self.name < other.name


@dataclass(frozen=True)
class SqlGenerationOptionSet:
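A note on the @functools.total_ordering decorator that the reviewer reacts to above: given an __eq__ (which Enum already provides) and the newly added __lt__, the decorator synthesizes <=, >, and >=, which is what lets the new test below write `optimization_level <= SqlQueryOptimizationLevel.O3`. A minimal sketch of the same pattern on a toy enum (not the actual MetricFlow class):

```python
from __future__ import annotations

import functools
from enum import Enum


@functools.total_ordering
class Level(Enum):
    """Toy enum mirroring the pattern added in optimization_levels.py."""

    O0 = "O0"
    O3 = "O3"
    O5 = "O5"

    def __lt__(self, other: Level) -> bool:
        if not isinstance(other, Level):
            return NotImplemented
        # Comparing names lexicographically matches numeric order for O0..O5.
        return self.name < other.name


# total_ordering derives __le__, __gt__, and __ge__ from __lt__ plus Enum's __eq__.
assert Level.O0 <= Level.O3
assert Level.O5 > Level.O3
assert not (Level.O3 < Level.O3)
```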
46 changes: 46 additions & 0 deletions tests_metricflow/engine/test_explain.py
@@ -1,10 +1,20 @@
from __future__ import annotations

import logging
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Mapping, Sequence

import pytest
from _pytest.fixtures import FixtureRequest
from metricflow_semantics.mf_logging.pretty_print import mf_pformat_dict
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration

from metricflow.engine.metricflow_engine import MetricFlowEngine, MetricFlowExplainResult, MetricFlowQueryRequest
from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel
from tests_metricflow.fixtures.manifest_fixtures import MetricFlowEngineTestFixture, SemanticManifestSetup
from tests_metricflow.snapshot_utils import assert_str_snapshot_equal

logger = logging.getLogger(__name__)


def _explain_one_query(mf_engine: MetricFlowEngine) -> str:
@@ -29,3 +39,39 @@ def test_concurrent_explain_consistency(
results = [future.result() for future in futures]
for result in results:
assert result == results[0], "Expected only one unique result / results to be the same"


@pytest.mark.sql_engine_snapshot
@pytest.mark.duckdb_only
def test_optimization_level(
request: FixtureRequest,
mf_test_configuration: MetricFlowTestConfiguration,
mf_engine_test_fixture_mapping: Mapping[SemanticManifestSetup, MetricFlowEngineTestFixture],
) -> None:
"""Tests that the results of explain reflect the SQL optimization level in the request."""
mf_engine = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].metricflow_engine

results = {}
for optimization_level in SqlQueryOptimizationLevel:
# Skip lower optimization levels as they are generally not used.
if optimization_level <= SqlQueryOptimizationLevel.O3:
continue

explain_result: MetricFlowExplainResult = mf_engine.explain(
MetricFlowQueryRequest.create_with_random_request_id(
metric_names=("bookings", "views"),
group_by_names=("metric_time", "listing__country_latest"),
sql_optimization_level=optimization_level,
)
)
results[optimization_level.value] = explain_result.rendered_sql_without_descriptions.sql_query

assert_str_snapshot_equal(
request=request,
mf_test_configuration=mf_test_configuration,
snapshot_id="result",
snapshot_str=mf_pformat_dict(
description=None, obj_dict=results, preserve_raw_strings=True, pad_items_with_newlines=True
),
expectation_description=f"The result for {SqlQueryOptimizationLevel.O5} should be SQL uses a CTE.",
)
25 changes: 0 additions & 25 deletions tests_metricflow/integration/test_mf_engine.py
@@ -3,8 +3,6 @@
from _pytest.fixtures import FixtureRequest
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration

from metricflow.engine.metricflow_engine import MetricFlowExplainResult, MetricFlowQueryRequest
from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel
from tests_metricflow.integration.conftest import IntegrationTestHelpers
from tests_metricflow.snapshot_utils import assert_object_snapshot_equal

@@ -18,26 +16,3 @@ def test_list_dimensions(  # noqa: D103
obj_id="result0",
obj=sorted([dim.qualified_name for dim in it_helpers.mf_engine.list_dimensions()]),
)


def test_sql_optimization_level(it_helpers: IntegrationTestHelpers) -> None:
"""Check that different SQL optimization levels produce different SQL."""
assert (
SqlQueryOptimizationLevel.default_level() != SqlQueryOptimizationLevel.O0
), "The default optimization level should be different from the lowest level."
explain_result_at_default_level: MetricFlowExplainResult = it_helpers.mf_engine.explain(
MetricFlowQueryRequest.create_with_random_request_id(
metric_names=("bookings",),
group_by_names=("metric_time",),
sql_optimization_level=SqlQueryOptimizationLevel.default_level(),
)
)
explain_result_at_level_0: MetricFlowExplainResult = it_helpers.mf_engine.explain(
MetricFlowQueryRequest.create_with_random_request_id(
metric_names=("bookings",),
group_by_names=("metric_time",),
sql_optimization_level=SqlQueryOptimizationLevel.O0,
)
)

assert explain_result_at_default_level.rendered_sql.sql_query != explain_result_at_level_0.rendered_sql.sql_query
4 changes: 4 additions & 0 deletions tests_metricflow/snapshot_utils.py
@@ -141,4 +141,8 @@ def assert_str_snapshot_equal(  # type: ignore[misc]
snapshot_file_extension=".txt",
additional_sub_directories_for_snapshots=(sql_engine.value,) if sql_engine is not None else (),
expectation_description=expectation_description,
incomparable_strings_replacement_function=make_schema_replacement_function(
system_schema=mf_test_configuration.mf_system_schema,
source_schema=mf_test_configuration.mf_source_schema,
),
)
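Context for the snapshot_utils.py change above and the snapshot that follows: the rendered SQL embeds warehouse schema names that differ between test environments, so assert_str_snapshot_equal now runs a replacement function over the string before comparing it to the stored snapshot, which is why schema names appear as runs of asterisks below. A minimal sketch of that masking idea, with a hypothetical helper name rather than the real make_schema_replacement_function:

```python
from typing import Callable


def make_schema_masker(system_schema: str, source_schema: str) -> Callable[[str], str]:
    """Hypothetical helper: build a function that hides environment-specific schema names."""

    def mask(snapshot_text: str) -> str:
        # Replace each schema name with a same-length run of '*' so the stored
        # snapshot stays stable across environments.
        for schema in (system_schema, source_schema):
            snapshot_text = snapshot_text.replace(schema, "*" * len(schema))
        return snapshot_text

    return mask


masker = make_schema_masker(system_schema="mf_system", source_schema="mf_demo_source")
print(masker("SELECT 1 FROM mf_demo_source.fct_bookings"))  # schema replaced by asterisks
```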
@@ -0,0 +1,125 @@
test_name: test_optimization_level
test_filename: test_explain.py
docstring:
Tests that the results of explain reflect the SQL optimization level in the request.
expectation_description:
  The result for SqlQueryOptimizationLevel.O5 should be SQL that uses a CTE.
---
O4:
SELECT
COALESCE(subq_8.metric_time__day, subq_17.metric_time__day) AS metric_time__day
, COALESCE(subq_8.listing__country_latest, subq_17.listing__country_latest) AS listing__country_latest
, MAX(subq_8.bookings) AS bookings
, MAX(subq_17.views) AS views
FROM (
SELECT
subq_1.metric_time__day AS metric_time__day
, listings_latest_src_10000.country AS listing__country_latest
, SUM(subq_1.bookings) AS bookings
FROM (
SELECT
DATE_TRUNC('day', ds) AS metric_time__day
, listing_id AS listing
, 1 AS bookings
FROM ***************************.fct_bookings bookings_source_src_10000
) subq_1
LEFT OUTER JOIN
***************************.dim_listings_latest listings_latest_src_10000
ON
subq_1.listing = listings_latest_src_10000.listing_id
GROUP BY
subq_1.metric_time__day
, listings_latest_src_10000.country
) subq_8
FULL OUTER JOIN (
SELECT
subq_10.metric_time__day AS metric_time__day
, listings_latest_src_10000.country AS listing__country_latest
, SUM(subq_10.views) AS views
FROM (
SELECT
DATE_TRUNC('day', ds) AS metric_time__day
, listing_id AS listing
, 1 AS views
FROM ***************************.fct_views views_source_src_10000
) subq_10
LEFT OUTER JOIN
***************************.dim_listings_latest listings_latest_src_10000
ON
subq_10.listing = listings_latest_src_10000.listing_id
GROUP BY
subq_10.metric_time__day
, listings_latest_src_10000.country
) subq_17
ON
(
subq_8.listing__country_latest = subq_17.listing__country_latest
) AND (
subq_8.metric_time__day = subq_17.metric_time__day
)
GROUP BY
COALESCE(subq_8.metric_time__day, subq_17.metric_time__day)
, COALESCE(subq_8.listing__country_latest, subq_17.listing__country_latest)

O5:
WITH sma_10014_cte AS (
SELECT
listing_id AS listing
, country AS country_latest
FROM ***************************.dim_listings_latest listings_latest_src_10000
)

SELECT
COALESCE(subq_8.metric_time__day, subq_16.metric_time__day) AS metric_time__day
, COALESCE(subq_8.listing__country_latest, subq_16.listing__country_latest) AS listing__country_latest
, MAX(subq_8.bookings) AS bookings
, MAX(subq_16.views) AS views
FROM (
SELECT
subq_1.metric_time__day AS metric_time__day
, sma_10014_cte.country_latest AS listing__country_latest
, SUM(subq_1.bookings) AS bookings
FROM (
SELECT
DATE_TRUNC('day', ds) AS metric_time__day
, listing_id AS listing
, 1 AS bookings
FROM ***************************.fct_bookings bookings_source_src_10000
) subq_1
LEFT OUTER JOIN
sma_10014_cte sma_10014_cte
ON
subq_1.listing = sma_10014_cte.listing
GROUP BY
subq_1.metric_time__day
, sma_10014_cte.country_latest
) subq_8
FULL OUTER JOIN (
SELECT
subq_10.metric_time__day AS metric_time__day
, sma_10014_cte.country_latest AS listing__country_latest
, SUM(subq_10.views) AS views
FROM (
SELECT
DATE_TRUNC('day', ds) AS metric_time__day
, listing_id AS listing
, 1 AS views
FROM ***************************.fct_views views_source_src_10000
) subq_10
LEFT OUTER JOIN
sma_10014_cte sma_10014_cte
ON
subq_10.listing = sma_10014_cte.listing
GROUP BY
subq_10.metric_time__day
, sma_10014_cte.country_latest
) subq_16
ON
(
subq_8.listing__country_latest = subq_16.listing__country_latest
) AND (
subq_8.metric_time__day = subq_16.metric_time__day
)
GROUP BY
COALESCE(subq_8.metric_time__day, subq_16.metric_time__day)
, COALESCE(subq_8.listing__country_latest, subq_16.listing__country_latest)