From c0507185037e4be6d5c164c5e55f85fd04c44ce3 Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Mon, 21 Oct 2024 12:40:10 -0700 Subject: [PATCH] Move `patch_id_generators_helper` to `SequentialIdGenerator` (#1469) This moves `patch_id_generators_helper` to be a method of `SequentialIdGenerator` so that it's more obvious that the patching is happening. Note that in a later PR, the patching is removed and replaced with a stack, but the placement is still helpful. --- .../metricflow_semantics/dag/sequential_id.py | 24 ++++++++++++++++++- .../test_helpers/id_helpers.py | 24 +------------------ .../fixtures/manifest_fixtures.py | 5 ++-- 3 files changed, 27 insertions(+), 26 deletions(-) diff --git a/metricflow-semantics/metricflow_semantics/dag/sequential_id.py b/metricflow-semantics/metricflow_semantics/dag/sequential_id.py index 853109ff6d..55c48530f1 100644 --- a/metricflow-semantics/metricflow_semantics/dag/sequential_id.py +++ b/metricflow-semantics/metricflow_semantics/dag/sequential_id.py @@ -1,8 +1,10 @@ from __future__ import annotations import threading +from contextlib import ExitStack, contextmanager from dataclasses import dataclass -from typing import Dict +from typing import Dict, Generator +from unittest.mock import patch from typing_extensions import override @@ -48,3 +50,23 @@ def reset(cls, default_start_value: int = 0) -> None: with cls._state_lock: cls._prefix_to_next_value = {} cls._default_start_value = default_start_value + + @classmethod + @contextmanager + def patch_id_generators_helper(cls, start_value: int) -> Generator[None, None, None]: + """Replace ID generators in IdGeneratorRegistry with one that has the given start value. + + TODO: This method will be modified in a later PR. + """ + # Create patch context managers for all ID generators in the registry with introspection magic. + patch_context_managers = [ + patch.object(SequentialIdGenerator, "_prefix_to_next_value", {}), + patch.object(SequentialIdGenerator, "_default_start_value", start_value), + ] + + # Enter the patch context for the patches above. + with ExitStack() as stack: + for patch_context_manager in patch_context_managers: + stack.enter_context(patch_context_manager) # type: ignore + # This will un-patch when done with the test. + yield None diff --git a/metricflow-semantics/metricflow_semantics/test_helpers/id_helpers.py b/metricflow-semantics/metricflow_semantics/test_helpers/id_helpers.py index 1a3a3c3a94..8621b4c305 100644 --- a/metricflow-semantics/metricflow_semantics/test_helpers/id_helpers.py +++ b/metricflow-semantics/metricflow_semantics/test_helpers/id_helpers.py @@ -1,35 +1,13 @@ from __future__ import annotations -from contextlib import ExitStack, contextmanager from dataclasses import dataclass from typing import Generator -from unittest.mock import patch import pytest from metricflow_semantics.dag.sequential_id import SequentialIdGenerator -@contextmanager -def patch_id_generators_helper(start_value: int) -> Generator[None, None, None]: - """Replace ID generators in IdGeneratorRegistry with one that has the given start value. - - TODO: This method will be modified in a later PR. - """ - # Create patch context managers for all ID generators in the registry with introspection magic. - patch_context_managers = [ - patch.object(SequentialIdGenerator, "_prefix_to_next_value", {}), - patch.object(SequentialIdGenerator, "_default_start_value", start_value), - ] - - # Enter the patch context for the patches above. - with ExitStack() as stack: - for patch_context_manager in patch_context_managers: - stack.enter_context(patch_context_manager) # type: ignore - # This will un-patch when done with the test. - yield None - - @pytest.fixture(autouse=True, scope="function") def patch_id_generators() -> Generator[None, None, None]: """Patch ID generators with a new one to get repeatability in plan outputs before every test. @@ -37,7 +15,7 @@ def patch_id_generators() -> Generator[None, None, None]: Plan outputs contain IDs, so if the IDs are not consistent from run to run, there will be diffs in the actual vs. expected outputs during a test. """ - with patch_id_generators_helper(start_value=IdNumberSpace.for_test_start().start_value): + with SequentialIdGenerator.patch_id_generators_helper(start_value=IdNumberSpace.for_test_start().start_value): yield None diff --git a/tests_metricflow/fixtures/manifest_fixtures.py b/tests_metricflow/fixtures/manifest_fixtures.py index 2efb7bca75..f50948ae7d 100644 --- a/tests_metricflow/fixtures/manifest_fixtures.py +++ b/tests_metricflow/fixtures/manifest_fixtures.py @@ -11,13 +11,14 @@ from dbt_semantic_interfaces.implementations.semantic_manifest import PydanticSemanticManifest from dbt_semantic_interfaces.protocols import SemanticModel from dbt_semantic_interfaces.test_utils import as_datetime +from metricflow_semantics.dag.sequential_id import SequentialIdGenerator from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup from metricflow_semantics.query.query_parser import MetricFlowQueryParser from metricflow_semantics.specs.column_assoc import ColumnAssociationResolver from metricflow_semantics.specs.dunder_column_association_resolver import DunderColumnAssociationResolver from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration -from metricflow_semantics.test_helpers.id_helpers import IdNumberSpace, patch_id_generators_helper +from metricflow_semantics.test_helpers.id_helpers import IdNumberSpace from metricflow_semantics.test_helpers.manifest_helpers import load_semantic_manifest from metricflow_semantics.test_helpers.semantic_manifest_yamls.ambiguous_resolution_manifest import ( AMBIGUOUS_RESOLUTION_MANIFEST_ANCHOR, @@ -264,7 +265,7 @@ def mf_engine_test_fixture_mapping( """Returns a mapping for all semantic manifests used in testing to the associated test fixture.""" fixture_mapping: Dict[SemanticManifestSetup, MetricFlowEngineTestFixture] = {} for semantic_manifest_setup in SemanticManifestSetup: - with patch_id_generators_helper(semantic_manifest_setup.id_number_space.start_value): + with SequentialIdGenerator.patch_id_generators_helper(semantic_manifest_setup.id_number_space.start_value): fixture_mapping[semantic_manifest_setup] = MetricFlowEngineTestFixture.from_parameters( sql_client, load_semantic_manifest(semantic_manifest_setup.yaml_file_dir, template_mapping) )