From c4dd2fdf4b4a2024b31286d42f4fdfb797759499 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 12 Apr 2024 07:10:38 +0000 Subject: [PATCH] maybe fix --- llmfoundry/utils/registry_utils.py | 10 ++++++++++ tests/fixtures/autouse.py | 6 ++++++ 2 files changed, 16 insertions(+) diff --git a/llmfoundry/utils/registry_utils.py b/llmfoundry/utils/registry_utils.py index d9c23e6f26..d7861cc557 100644 --- a/llmfoundry/utils/registry_utils.py +++ b/llmfoundry/utils/registry_utils.py @@ -3,6 +3,7 @@ import functools import importlib.util +from contextlib import contextmanager import os from pathlib import Path from types import ModuleType @@ -174,3 +175,12 @@ def import_file(loc: Union[str, Path]) -> ModuleType: except Exception as e: raise RuntimeError(f'Error executing {loc}') from e return module + +@contextmanager +def save_registry(): + """Save the registry state and restore after the context manager exits.""" + saved_registry_state = copy.deepcopy(catalogue.REGISTRY) + + yield + + catalogue.REGISTRY = saved_registry_state \ No newline at end of file diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py index ccbe1b69f7..6d6a5ad006 100644 --- a/tests/fixtures/autouse.py +++ b/tests/fixtures/autouse.py @@ -9,10 +9,16 @@ import torch from composer.utils import dist, get_device, reproducibility +from llmfoundry.utils.registry_utils import save_registry + # Add llm-foundry repo root to path so we can import scripts in the tests REPO_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')) sys.path.append(REPO_DIR) +@pytest.fixture(autouse=True) +def save_registry_fixture(): + with save_registry(): + yield @pytest.fixture(autouse=True) def initialize_dist(request: pytest.FixtureRequest):