Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Apr 5, 2024
1 parent e7ffd66 commit 24b95b2
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 11 deletions.
12 changes: 9 additions & 3 deletions llmfoundry/layers_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,27 @@

from llmfoundry.utils.registry_utils import create_registry

_norm_description = """The norms registry is used to register classes that implement normalization layers."""
_norm_description = (
'The norms registry is used to register classes that implement normalization layers.'
)
norms = create_registry('llmfoundry',
'norms',
generic_type=Type[torch.nn.Module],
entry_points=True,
description=_norm_description)

_attention_class_description = """The attention_class registry is used to register classes that implement attention layers."""
_attention_class_description = (
'The attention_class registry is used to register classes that implement attention layers. See '
+ 'attention.py for expected constructor signature.')
attention_class = create_registry('llmfoundry',
'attention_class',
generic_type=Type[torch.nn.Module],
entry_points=True,
description=_attention_class_description)

_attention_implementation_description = """The attention_implementation registry is used to register functions that implement the attention operation."""
_attention_implementation_description = (
    'The attention_implementation registry is used to register functions that implement the attention operation. '
+ 'See attention.py for expected function signature.')
attention_implementation = create_registry(
'llmfoundry',
'attention_implementation',
Expand Down
2 changes: 1 addition & 1 deletion llmfoundry/models/utils/act_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_act_ckpt_module(mod_name: str) -> Any:
"""Get the module type from the module name."""
if mod_name.lower() == 'mptblock':
mod_type = MPTBlock
elif mod_name in attention_class.get_all():
elif mod_name in attention_class:
mod_type = attention_class.get(mod_name)
elif mod_name in FFN_CLASS_REGISTRY:
mod_type = FFN_CLASS_REGISTRY[mod_name]
Expand Down
26 changes: 19 additions & 7 deletions llmfoundry/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from transformers import PreTrainedTokenizerBase

from llmfoundry.interfaces import CallbackWithConfig
from llmfoundry.layers_registry import norms
from llmfoundry.layers_registry import (attention_class,
attention_implementation, norms)
from llmfoundry.utils.registry_utils import create_registry

_loggers_description = (
Expand Down Expand Up @@ -85,25 +86,34 @@
entry_points=True,
description=_schedulers_description)

_models_description = """The models registry is used to register classes that implement the ComposerModel interface. The model
constructor should accept two arguments: an omegaconf DictConfig named `om_model_config` and a PreTrainedTokenizerBase named `tokenizer`.
Note: This will soon be updated to take in named kwargs instead of a config directly."""
_models_description = (
'The models registry is used to register classes that implement the ComposerModel interface. '
+
'The model constructor should accept two arguments: an omegaconf DictConfig named `om_model_config` and a PreTrainedTokenizerBase named `tokenizer`. '
+
'Note: This will soon be updated to take in named kwargs instead of a config directly.'
)
models = create_registry('llmfoundry',
'models',
generic_type=Type[ComposerModel],
entry_points=True,
description=_models_description)

_dataloaders_description = """The dataloaders registry is used to register functions that create a DataSpec. The function should take
a DictConfig, a PreTrainedTokenizerBase, and an int as arguments, and return a DataSpec."""
_dataloaders_description = (
'The dataloaders registry is used to register functions that create a DataSpec. The function should take '
+
'a DictConfig, a PreTrainedTokenizerBase, and an int as arguments, and return a DataSpec.'
)
dataloaders = create_registry(
'llmfoundry',
'dataloaders',
generic_type=Callable[[DictConfig, PreTrainedTokenizerBase, int], DataSpec],
entry_points=True,
description=_dataloaders_description)

_metrics_description = """The metrics registry is used to register classes that implement the torchmetrics.Metric interface."""
_metrics_description = (
'The metrics registry is used to register classes that implement the torchmetrics.Metric interface.'
)
metrics = create_registry('llmfoundry',
'metrics',
generic_type=Type[Metric],
Expand All @@ -121,4 +131,6 @@
'metrics',
'dataloaders',
'norms',
'attention_class',
'attention_implementation',
]
2 changes: 2 additions & 0 deletions tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def test_expected_registries_exist():
'metrics',
'models',
'norms',
'attention_class',
'attention_implementation',
}

assert existing_registries == expected_registry_names
Expand Down

0 comments on commit 24b95b2

Please sign in to comment.