From aa7363a6d571ee32f58aaf92da2ec9ba89c3bcd7 Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Mon, 30 Oct 2023 14:18:28 -0700 Subject: [PATCH] Enforce tests have the SQL-engine-test marker set if necessary. --- metricflow/test/fixtures/setup_fixtures.py | 10 ++++++++++ metricflow/test/snapshot_utils.py | 5 ++++- metricflow/test/sql/compare_sql_plan.py | 4 +++- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/metricflow/test/fixtures/setup_fixtures.py b/metricflow/test/fixtures/setup_fixtures.py index 706732c6d7..8512802b01 100644 --- a/metricflow/test/fixtures/setup_fixtures.py +++ b/metricflow/test/fixtures/setup_fixtures.py @@ -80,6 +80,16 @@ def pytest_configure(config: _pytest.config.Config) -> None: ) +def check_sql_engine_snapshot_marker(request: FixtureRequest) -> None: + """Raises an error if the given test request does not have the sql-engine-test marker set.""" + mark_names = set(mark.name for mark in request.node.iter_markers(name=SQL_ENGINE_SNAPSHOT_MARKER_NAME)) + if SQL_ENGINE_SNAPSHOT_MARKER_NAME not in mark_names: + raise ValueError( + f"This test needs to be marked with '{SQL_ENGINE_SNAPSHOT_MARKER_NAME}' to keep track of all tests that " + f"generate SQL-engine specific snapshots." + ) + + @pytest.fixture(scope="session") def mf_test_session_state( # noqa: D request: FixtureRequest, diff --git a/metricflow/test/snapshot_utils.py b/metricflow/test/snapshot_utils.py index 935252d28c..5ad862960a 100644 --- a/metricflow/test/snapshot_utils.py +++ b/metricflow/test/snapshot_utils.py @@ -19,7 +19,7 @@ from metricflow.model.semantics.linkable_spec_resolver import LinkableElementSet from metricflow.protocols.sql_client import SqlClient from metricflow.specs.specs import InstanceSpecSet -from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState +from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState, check_sql_engine_snapshot_marker logger = logging.getLogger(__name__) @@ -261,6 +261,9 @@ def assert_object_snapshot_equal( # type: ignore[misc] sql_client: Optional[SqlClient] = None, ) -> None: """For tests to compare large objects, this can be used to snapshot a text representation of the object.""" + if sql_client is not None: + check_sql_engine_snapshot_marker(request) + assert_snapshot_text_equal( request=request, mf_test_session_state=mf_test_session_state, diff --git a/metricflow/test/sql/compare_sql_plan.py b/metricflow/test/sql/compare_sql_plan.py index 5808608db4..b029405cff 100644 --- a/metricflow/test/sql/compare_sql_plan.py +++ b/metricflow/test/sql/compare_sql_plan.py @@ -6,7 +6,7 @@ from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer from metricflow.sql.sql_plan import SqlQueryPlan, SqlQueryPlanNode, SqlSelectStatementNode from metricflow.sql.sql_plan_to_text import sql_query_plan_as_text -from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState +from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState, check_sql_engine_snapshot_marker from metricflow.test.snapshot_utils import ( assert_plan_snapshot_text_equal, make_schema_replacement_function, @@ -58,6 +58,8 @@ def assert_rendered_sql_from_plan_equal( sql_client: SqlClient, ) -> None: """Similar to assert_rendered_sql_equal, but takes in a SQL query plan.""" + check_sql_engine_snapshot_marker(request) + rendered_sql = sql_client.sql_query_plan_renderer.render_sql_query_plan(sql_query_plan).sql assert_plan_snapshot_text_equal(