Skip to content

Commit

Permalink
/* PR_START p--thread-local-id 02 */ Move `patch_id_generators_helper…
Browse files Browse the repository at this point in the history
…` to `SequentialIdGenerator`.
  • Loading branch information
plypaul committed Oct 21, 2024
1 parent 2a6f43c commit 87cdee6
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 26 deletions.
24 changes: 23 additions & 1 deletion metricflow-semantics/metricflow_semantics/dag/sequential_id.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import threading
from contextlib import ExitStack, contextmanager
from dataclasses import dataclass
from typing import Dict
from typing import Dict, Generator
from unittest.mock import patch

from typing_extensions import override

Expand Down Expand Up @@ -48,3 +50,23 @@ def reset(cls, default_start_value: int = 0) -> None:
with cls._state_lock:
cls._prefix_to_next_value = {}
cls._default_start_value = default_start_value

@contextmanager
@staticmethod
def patch_id_generators_helper(start_value: int) -> Generator[None, None, None]:
"""Replace ID generators in IdGeneratorRegistry with one that has the given start value.
TODO: This method will be modified in a later PR.
"""
# Create patch context managers for all ID generators in the registry with introspection magic.
patch_context_managers = [
patch.object(SequentialIdGenerator, "_prefix_to_next_value", {}),
patch.object(SequentialIdGenerator, "_default_start_value", start_value),
]

# Enter the patch context for the patches above.
with ExitStack() as stack:
for patch_context_manager in patch_context_managers:
stack.enter_context(patch_context_manager) # type: ignore
# This will un-patch when done with the test.
yield None
Original file line number Diff line number Diff line change
@@ -1,43 +1,21 @@
from __future__ import annotations

from contextlib import ExitStack, contextmanager
from dataclasses import dataclass
from typing import Generator
from unittest.mock import patch

import pytest

from metricflow_semantics.dag.sequential_id import SequentialIdGenerator


@contextmanager
def patch_id_generators_helper(start_value: int) -> Generator[None, None, None]:
"""Replace ID generators in IdGeneratorRegistry with one that has the given start value.
TODO: This method will be modified in a later PR.
"""
# Create patch context managers for all ID generators in the registry with introspection magic.
patch_context_managers = [
patch.object(SequentialIdGenerator, "_prefix_to_next_value", {}),
patch.object(SequentialIdGenerator, "_default_start_value", start_value),
]

# Enter the patch context for the patches above.
with ExitStack() as stack:
for patch_context_manager in patch_context_managers:
stack.enter_context(patch_context_manager) # type: ignore
# This will un-patch when done with the test.
yield None


@pytest.fixture(autouse=True, scope="function")
def patch_id_generators() -> Generator[None, None, None]:
"""Patch ID generators with a new one to get repeatability in plan outputs before every test.
Plan outputs contain IDs, so if the IDs are not consistent from run to run, there will be diffs in the actual vs.
expected outputs during a test.
"""
with patch_id_generators_helper(start_value=IdNumberSpace.for_test_start().start_value):
with SequentialIdGenerator.patch_id_generators_helper(start_value=IdNumberSpace.for_test_start().start_value):
yield None


Expand Down
5 changes: 3 additions & 2 deletions tests_metricflow/fixtures/manifest_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
from dbt_semantic_interfaces.implementations.semantic_manifest import PydanticSemanticManifest
from dbt_semantic_interfaces.protocols import SemanticModel
from dbt_semantic_interfaces.test_utils import as_datetime
from metricflow_semantics.dag.sequential_id import SequentialIdGenerator
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow_semantics.query.query_parser import MetricFlowQueryParser
from metricflow_semantics.specs.column_assoc import ColumnAssociationResolver
from metricflow_semantics.specs.dunder_column_association_resolver import DunderColumnAssociationResolver
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration
from metricflow_semantics.test_helpers.id_helpers import IdNumberSpace, patch_id_generators_helper
from metricflow_semantics.test_helpers.id_helpers import IdNumberSpace
from metricflow_semantics.test_helpers.manifest_helpers import load_semantic_manifest
from metricflow_semantics.test_helpers.semantic_manifest_yamls.ambiguous_resolution_manifest import (
AMBIGUOUS_RESOLUTION_MANIFEST_ANCHOR,
Expand Down Expand Up @@ -264,7 +265,7 @@ def mf_engine_test_fixture_mapping(
"""Returns a mapping for all semantic manifests used in testing to the associated test fixture."""
fixture_mapping: Dict[SemanticManifestSetup, MetricFlowEngineTestFixture] = {}
for semantic_manifest_setup in SemanticManifestSetup:
with patch_id_generators_helper(semantic_manifest_setup.id_number_space.start_value):
with SequentialIdGenerator.patch_id_generators_helper(semantic_manifest_setup.id_number_space.start_value):
fixture_mapping[semantic_manifest_setup] = MetricFlowEngineTestFixture.from_parameters(
sql_client, load_semantic_manifest(semantic_manifest_setup.yaml_file_dir, template_mapping)
)
Expand Down

0 comments on commit 87cdee6

Please sign in to comment.