Skip to content

Commit

Permalink
maybe fix
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Apr 12, 2024
1 parent 3fd8086 commit f8d4c8f
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 17 deletions.
7 changes: 0 additions & 7 deletions llmfoundry/models/utils/param_init_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,6 @@ def generic_param_init_fn_(
emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
**kwargs: Any,
) -> None:

print('in init')
print(type(module))
print(isinstance(module, GLU))
print(isinstance(module, MLP))
print(GLU)
print(MLP)
del kwargs # unused, just to capture any extra args from the config
# enable user to divide _is_residual weights by

Expand Down
12 changes: 12 additions & 0 deletions llmfoundry/utils/registry_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import copy
import functools
import importlib.util
import os
from contextlib import contextmanager
from pathlib import Path
from types import ModuleType
from typing import (Any, Callable, Dict, Generic, Optional, Sequence, Type,
Expand Down Expand Up @@ -174,3 +176,13 @@ def import_file(loc: Union[str, Path]) -> ModuleType:
except Exception as e:
raise RuntimeError(f'Error executing {loc}') from e
return module


@contextmanager
def save_registry():
"""Save the registry state and restore after the context manager exits."""
saved_registry_state = copy.deepcopy(catalogue.REGISTRY)

yield

catalogue.REGISTRY = saved_registry_state
15 changes: 5 additions & 10 deletions tests/fixtures/autouse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import copy
import gc
import os
import sys
Expand All @@ -10,21 +9,17 @@
import torch
from composer.utils import dist, get_device, reproducibility

from llmfoundry.utils.registry_utils import save_registry

# Add llm-foundry repo root to path so we can import scripts in the tests
REPO_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
sys.path.append(REPO_DIR)


@pytest.fixture(autouse=True)
def save_registry():
from catalogue import REGISTRY

# Save it
saved_registry = copy.deepcopy(REGISTRY)
# Yield
yield
# Restore it
REGISTRY.update(saved_registry)
def save_registry_fixture():
with save_registry():
yield


@pytest.fixture(autouse=True)
Expand Down

0 comments on commit f8d4c8f

Please sign in to comment.