Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug with MetricFlowQueryRequest.sql_optimization_level handling #1524

Merged
merged 3 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,11 +409,6 @@ def __init__(
column_association_resolver=self._column_association_resolver,
semantic_manifest_lookup=self._semantic_manifest_lookup,
)
self._to_execution_plan_converter = DataflowToExecutionPlanConverter(
sql_plan_converter=self._to_sql_query_plan_converter,
sql_plan_renderer=self._sql_client.sql_query_plan_renderer,
sql_client=sql_client,
)
self._executor = SequentialPlanExecutor()

self._query_parser = query_parser or MetricFlowQueryParser(
Expand Down Expand Up @@ -539,7 +534,13 @@ def _create_execution_plan(self, mf_query_request: MetricFlowQueryRequest) -> Me
)

logger.info(LazyFormat("Building execution plan"))
convert_to_execution_plan_result = self._to_execution_plan_converter.convert_to_execution_plan(dataflow_plan)
_to_execution_plan_converter = DataflowToExecutionPlanConverter(
sql_plan_converter=self._to_sql_query_plan_converter,
sql_plan_renderer=self._sql_client.sql_query_plan_renderer,
sql_client=self._sql_client,
sql_optimization_level=mf_query_request.sql_optimization_level,
)
convert_to_execution_plan_result = _to_execution_plan_converter.convert_to_execution_plan(dataflow_plan)
return MetricFlowExplainResult(
query_spec=query_spec,
dataflow_plan=dataflow_plan,
Expand Down
5 changes: 5 additions & 0 deletions metricflow/execution/dataflow_to_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from metricflow.plan_conversion.convert_to_sql_plan import ConvertToSqlPlanResult
from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter
from metricflow.protocols.sql_client import SqlClient
from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel
from metricflow.sql.render.sql_plan_renderer import SqlPlanRenderResult, SqlQueryPlanRenderer

logger = logging.getLogger(__name__)
Expand All @@ -53,22 +54,26 @@ def __init__(
sql_plan_converter: DataflowToSqlQueryPlanConverter,
sql_plan_renderer: SqlQueryPlanRenderer,
sql_client: SqlClient,
sql_optimization_level: SqlQueryOptimizationLevel,
) -> None:
"""Constructor.

Args:
sql_plan_converter: Converts a dataflow plan node to a SQL query plan
sql_plan_renderer: Converts a SQL query plan to SQL text
sql_client: The client to use for running queries.
sql_optimization_level: The optimization level to use for generating the SQL.
"""
self._sql_plan_converter = sql_plan_converter
self._sql_plan_renderer = sql_plan_renderer
self._sql_client = sql_client
self._optimization_level = sql_optimization_level

def _convert_to_sql_plan(self, node: DataflowPlanNode) -> ConvertToSqlPlanResult:
logger.debug(LazyFormat(lambda: f"Generating SQL query plan from {node.node_id}"))
result = self._sql_plan_converter.convert_to_sql_query_plan(
sql_engine_type=self._sql_client.sql_engine_type,
optimization_level=self._optimization_level,
dataflow_plan_node=node,
)
logger.debug(LazyFormat(lambda: f"Generated SQL query plan is:\n{result.sql_plan.structure_text()}"))
Expand Down
4 changes: 4 additions & 0 deletions metricflow/sql/optimizer/optimization_levels.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ class SqlQueryOptimizationLevel(Enum):
O4 = "O4"
O5 = "O5"

@staticmethod
def default_level() -> SqlQueryOptimizationLevel: # noqa: D102
return SqlQueryOptimizationLevel.O4


@dataclass(frozen=True)
class SqlGenerationOptionSet:
Expand Down
25 changes: 25 additions & 0 deletions tests_metricflow/integration/test_mf_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from _pytest.fixtures import FixtureRequest
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration

from metricflow.engine.metricflow_engine import MetricFlowExplainResult, MetricFlowQueryRequest
from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel
from tests_metricflow.integration.conftest import IntegrationTestHelpers
from tests_metricflow.snapshot_utils import assert_object_snapshot_equal

Expand All @@ -16,3 +18,26 @@ def test_list_dimensions( # noqa: D103
obj_id="result0",
obj=sorted([dim.qualified_name for dim in it_helpers.mf_engine.list_dimensions()]),
)


def test_sql_optimization_level(it_helpers: IntegrationTestHelpers) -> None:
"""Check that different SQL optimization levels produce different SQL."""
assert (
SqlQueryOptimizationLevel.default_level() != SqlQueryOptimizationLevel.O0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a blocker for this PR, but I'm curious - do you know why the 00 enum option uses two different 0 characters? I've never seen that before!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Borrows from the syntax from gcc: https://www.rapidtables.com/code/linux/gcc/gcc-o.html

), "The default optimization level should be different from the lowest level."
explain_result_at_default_level: MetricFlowExplainResult = it_helpers.mf_engine.explain(
MetricFlowQueryRequest.create_with_random_request_id(
metric_names=("bookings",),
group_by_names=("metric_time",),
sql_optimization_level=SqlQueryOptimizationLevel.default_level(),
)
)
explain_result_at_level_0: MetricFlowExplainResult = it_helpers.mf_engine.explain(
MetricFlowQueryRequest.create_with_random_request_id(
metric_names=("bookings",),
group_by_names=("metric_time",),
sql_optimization_level=SqlQueryOptimizationLevel.O0,
)
)

assert explain_result_at_default_level.rendered_sql.sql_query != explain_result_at_level_0.rendered_sql.sql_query
13 changes: 4 additions & 9 deletions tests_metricflow/plan_conversion/test_dataflow_to_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from metricflow.execution.dataflow_to_execution import DataflowToExecutionPlanConverter
from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter
from metricflow.protocols.sql_client import SqlClient
from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel
from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer
from tests_metricflow.snapshot_utils import assert_execution_plan_text_equal

Expand All @@ -30,6 +31,7 @@ def make_execution_plan_converter( # noqa: D103
),
sql_plan_renderer=DefaultSqlQueryPlanRenderer(),
sql_client=sql_client,
sql_optimization_level=SqlQueryOptimizationLevel.O4,
)


Expand Down Expand Up @@ -172,17 +174,10 @@ def test_multihop_joined_plan(
)
)

to_execution_plan_converter = DataflowToExecutionPlanConverter(
sql_plan_converter=DataflowToSqlQueryPlanConverter(
column_association_resolver=DunderColumnAssociationResolver(
partitioned_multi_hop_join_semantic_manifest_lookup
),
semantic_manifest_lookup=partitioned_multi_hop_join_semantic_manifest_lookup,
),
sql_plan_renderer=DefaultSqlQueryPlanRenderer(),
to_execution_plan_converter = make_execution_plan_converter(
semantic_manifest_lookup=partitioned_multi_hop_join_semantic_manifest_lookup,
sql_client=sql_client,
)

execution_plan = to_execution_plan_converter.convert_to_execution_plan(dataflow_plan).execution_plan

assert_execution_plan_text_equal(
Expand Down
Loading