From 87cdee67b99f9c4e557bbfaa84bcd7f2d6a04cd0 Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Sun, 20 Oct 2024 13:32:38 -0700 Subject: [PATCH] /* PR_START p--thread-local-id 02 */ Move `patch_id_generators_helper` to `SequentialIdGenerator`. --- .../metricflow_semantics/dag/sequential_id.py | 24 ++++++++++++++++++- .../test_helpers/id_helpers.py | 24 +------------------ .../fixtures/manifest_fixtures.py | 5 ++-- 3 files changed, 27 insertions(+), 26 deletions(-) diff --git a/metricflow-semantics/metricflow_semantics/dag/sequential_id.py b/metricflow-semantics/metricflow_semantics/dag/sequential_id.py index 853109ff6d..99ab7100cc 100644 --- a/metricflow-semantics/metricflow_semantics/dag/sequential_id.py +++ b/metricflow-semantics/metricflow_semantics/dag/sequential_id.py @@ -1,8 +1,10 @@ from __future__ import annotations import threading +from contextlib import ExitStack, contextmanager from dataclasses import dataclass -from typing import Dict +from typing import Dict, Generator +from unittest.mock import patch from typing_extensions import override @@ -48,3 +50,23 @@ def reset(cls, default_start_value: int = 0) -> None: with cls._state_lock: cls._prefix_to_next_value = {} cls._default_start_value = default_start_value + + @contextmanager + @staticmethod + def patch_id_generators_helper(start_value: int) -> Generator[None, None, None]: + """Replace ID generators in IdGeneratorRegistry with one that has the given start value. + + TODO: This method will be modified in a later PR. + """ + # Create patch context managers for all ID generators in the registry with introspection magic. + patch_context_managers = [ + patch.object(SequentialIdGenerator, "_prefix_to_next_value", {}), + patch.object(SequentialIdGenerator, "_default_start_value", start_value), + ] + + # Enter the patch context for the patches above. + with ExitStack() as stack: + for patch_context_manager in patch_context_managers: + stack.enter_context(patch_context_manager) # type: ignore + # This will un-patch when done with the test. + yield None diff --git a/metricflow-semantics/metricflow_semantics/test_helpers/id_helpers.py b/metricflow-semantics/metricflow_semantics/test_helpers/id_helpers.py index 1a3a3c3a94..8621b4c305 100644 --- a/metricflow-semantics/metricflow_semantics/test_helpers/id_helpers.py +++ b/metricflow-semantics/metricflow_semantics/test_helpers/id_helpers.py @@ -1,35 +1,13 @@ from __future__ import annotations -from contextlib import ExitStack, contextmanager from dataclasses import dataclass from typing import Generator -from unittest.mock import patch import pytest from metricflow_semantics.dag.sequential_id import SequentialIdGenerator -@contextmanager -def patch_id_generators_helper(start_value: int) -> Generator[None, None, None]: - """Replace ID generators in IdGeneratorRegistry with one that has the given start value. - - TODO: This method will be modified in a later PR. - """ - # Create patch context managers for all ID generators in the registry with introspection magic. - patch_context_managers = [ - patch.object(SequentialIdGenerator, "_prefix_to_next_value", {}), - patch.object(SequentialIdGenerator, "_default_start_value", start_value), - ] - - # Enter the patch context for the patches above. - with ExitStack() as stack: - for patch_context_manager in patch_context_managers: - stack.enter_context(patch_context_manager) # type: ignore - # This will un-patch when done with the test. - yield None - - @pytest.fixture(autouse=True, scope="function") def patch_id_generators() -> Generator[None, None, None]: """Patch ID generators with a new one to get repeatability in plan outputs before every test. @@ -37,7 +15,7 @@ def patch_id_generators() -> Generator[None, None, None]: Plan outputs contain IDs, so if the IDs are not consistent from run to run, there will be diffs in the actual vs. expected outputs during a test. """ - with patch_id_generators_helper(start_value=IdNumberSpace.for_test_start().start_value): + with SequentialIdGenerator.patch_id_generators_helper(start_value=IdNumberSpace.for_test_start().start_value): yield None diff --git a/tests_metricflow/fixtures/manifest_fixtures.py b/tests_metricflow/fixtures/manifest_fixtures.py index 2efb7bca75..f50948ae7d 100644 --- a/tests_metricflow/fixtures/manifest_fixtures.py +++ b/tests_metricflow/fixtures/manifest_fixtures.py @@ -11,13 +11,14 @@ from dbt_semantic_interfaces.implementations.semantic_manifest import PydanticSemanticManifest from dbt_semantic_interfaces.protocols import SemanticModel from dbt_semantic_interfaces.test_utils import as_datetime +from metricflow_semantics.dag.sequential_id import SequentialIdGenerator from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup from metricflow_semantics.query.query_parser import MetricFlowQueryParser from metricflow_semantics.specs.column_assoc import ColumnAssociationResolver from metricflow_semantics.specs.dunder_column_association_resolver import DunderColumnAssociationResolver from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration -from metricflow_semantics.test_helpers.id_helpers import IdNumberSpace, patch_id_generators_helper +from metricflow_semantics.test_helpers.id_helpers import IdNumberSpace from metricflow_semantics.test_helpers.manifest_helpers import load_semantic_manifest from metricflow_semantics.test_helpers.semantic_manifest_yamls.ambiguous_resolution_manifest import ( AMBIGUOUS_RESOLUTION_MANIFEST_ANCHOR, @@ -264,7 +265,7 @@ def mf_engine_test_fixture_mapping( """Returns a mapping for all semantic manifests used in testing to the associated test fixture.""" fixture_mapping: Dict[SemanticManifestSetup, MetricFlowEngineTestFixture] = {} for semantic_manifest_setup in SemanticManifestSetup: - with patch_id_generators_helper(semantic_manifest_setup.id_number_space.start_value): + with SequentialIdGenerator.patch_id_generators_helper(semantic_manifest_setup.id_number_space.start_value): fixture_mapping[semantic_manifest_setup] = MetricFlowEngineTestFixture.from_parameters( sql_client, load_semantic_manifest(semantic_manifest_setup.yaml_file_dir, template_mapping) )