diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index 495bb14a69..bd409dee36 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -157,13 +157,6 @@ def generic_param_init_fn_( emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, **kwargs: Any, ) -> None: - - print('in init') - print(type(module)) - print(isinstance(module, GLU)) - print(isinstance(module, MLP)) - print(GLU) - print(MLP) del kwargs # unused, just to capture any extra args from the config # enable user to divide _is_residual weights by diff --git a/llmfoundry/utils/registry_utils.py b/llmfoundry/utils/registry_utils.py index 83348a7fd6..0eeefbae74 100644 --- a/llmfoundry/utils/registry_utils.py +++ b/llmfoundry/utils/registry_utils.py @@ -1,9 +1,11 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import copy import functools import importlib.util import os +from contextlib import contextmanager from pathlib import Path from types import ModuleType from typing import (Any, Callable, Dict, Generic, Optional, Sequence, Type, @@ -174,3 +176,13 @@ def import_file(loc: Union[str, Path]) -> ModuleType: except Exception as e: raise RuntimeError(f'Error executing {loc}') from e return module + + +@contextmanager +def save_registry(): + """Save the registry state and restore after the context manager exits.""" + saved_registry_state = copy.deepcopy(catalogue.REGISTRY) + + yield + + catalogue.REGISTRY = saved_registry_state diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py index 04c0812aeb..16e3f8ad6f 100644 --- a/tests/fixtures/autouse.py +++ b/tests/fixtures/autouse.py @@ -1,7 +1,6 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import copy import gc import os import sys @@ -10,21 +9,17 @@ import torch from composer.utils import dist, get_device, reproducibility +from llmfoundry.utils.registry_utils import save_registry + # Add llm-foundry repo root to path so we can import scripts in the tests REPO_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')) sys.path.append(REPO_DIR) @pytest.fixture(autouse=True) -def save_registry(): - from catalogue import REGISTRY - - # Save it - saved_registry = copy.deepcopy(REGISTRY) - # Yield - yield - # Restore it - REGISTRY.update(saved_registry) +def save_registry_fixture(): + with save_registry(): + yield @pytest.fixture(autouse=True)