Skip to content

Commit

Permalink
maybe fix
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Apr 12, 2024
1 parent d110f74 commit c4dd2fd
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
10 changes: 10 additions & 0 deletions llmfoundry/utils/registry_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import functools
import importlib.util
from contextlib import contextmanager
import os
from pathlib import Path
from types import ModuleType
Expand Down Expand Up @@ -174,3 +175,12 @@ def import_file(loc: Union[str, Path]) -> ModuleType:
except Exception as e:
raise RuntimeError(f'Error executing {loc}') from e
return module

@contextmanager
def save_registry():
"""Save the registry state and restore after the context manager exits."""
saved_registry_state = copy.deepcopy(catalogue.REGISTRY)

yield

catalogue.REGISTRY = saved_registry_state
6 changes: 6 additions & 0 deletions tests/fixtures/autouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@
import torch
from composer.utils import dist, get_device, reproducibility

from llmfoundry.utils.registry_utils import save_registry

# Add llm-foundry repo root to path so we can import scripts in the tests
REPO_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
sys.path.append(REPO_DIR)

@pytest.fixture(autouse=True)
def save_registry_fixture():
with save_registry():
yield

@pytest.fixture(autouse=True)
def initialize_dist(request: pytest.FixtureRequest):
Expand Down

0 comments on commit c4dd2fd

Please sign in to comment.