diff --git a/llmfoundry/layers_registry.py b/llmfoundry/layers_registry.py index f683c245da..7e21fc69b0 100644 --- a/llmfoundry/layers_registry.py +++ b/llmfoundry/layers_registry.py @@ -7,21 +7,27 @@ from llmfoundry.utils.registry_utils import create_registry -_norm_description = """The norms registry is used to register classes that implement normalization layers.""" +_norm_description = ( + 'The norms registry is used to register classes that implement normalization layers.' +) norms = create_registry('llmfoundry', 'norms', generic_type=Type[torch.nn.Module], entry_points=True, description=_norm_description) -_attention_class_description = """The attention_class registry is used to register classes that implement attention layers.""" +_attention_class_description = ( + 'The attention_class registry is used to register classes that implement attention layers. See ' + + 'attention.py for expected constructor signature.') attention_class = create_registry('llmfoundry', 'attention_class', generic_type=Type[torch.nn.Module], entry_points=True, description=_attention_class_description) -_attention_implementation_description = """The attention_implementation registry is used to register functions that implement the attention operation.""" +_attention_implementation_description = ( + 'The attention_implementation registry is used to register functions that implement the attention operation. ' 
+ + 'See attention.py for expected function signature.') attention_implementation = create_registry( 'llmfoundry', 'attention_implementation', diff --git a/llmfoundry/models/utils/act_ckpt.py b/llmfoundry/models/utils/act_ckpt.py index 3399e70e03..ed25c2cfa1 100644 --- a/llmfoundry/models/utils/act_ckpt.py +++ b/llmfoundry/models/utils/act_ckpt.py @@ -24,7 +24,7 @@ def get_act_ckpt_module(mod_name: str) -> Any: """Get the module type from the module name.""" if mod_name.lower() == 'mptblock': mod_type = MPTBlock - elif mod_name in attention_class.get_all(): + elif mod_name in attention_class: mod_type = attention_class.get(mod_name) elif mod_name in FFN_CLASS_REGISTRY: mod_type = FFN_CLASS_REGISTRY[mod_name] diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index 424075da3b..2729f2fca0 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -12,7 +12,8 @@ from transformers import PreTrainedTokenizerBase from llmfoundry.interfaces import CallbackWithConfig -from llmfoundry.layers_registry import norms +from llmfoundry.layers_registry import (attention_class, + attention_implementation, norms) from llmfoundry.utils.registry_utils import create_registry _loggers_description = ( @@ -85,17 +86,24 @@ entry_points=True, description=_schedulers_description) -_models_description = """The models registry is used to register classes that implement the ComposerModel interface. The model -constructor should accept two arguments: an omegaconf DictConfig named `om_model_config` and a PreTrainedTokenizerBase named `tokenizer`. -Note: This will soon be updated to take in named kwargs instead of a config directly.""" +_models_description = ( + 'The models registry is used to register classes that implement the ComposerModel interface. ' + + + 'The model constructor should accept two arguments: an omegaconf DictConfig named `om_model_config` and a PreTrainedTokenizerBase named `tokenizer`. 
' + + + 'Note: This will soon be updated to take in named kwargs instead of a config directly.' +) models = create_registry('llmfoundry', 'models', generic_type=Type[ComposerModel], entry_points=True, description=_models_description) -_dataloaders_description = """The dataloaders registry is used to register functions that create a DataSpec. The function should take -a DictConfig, a PreTrainedTokenizerBase, and an int as arguments, and return a DataSpec.""" +_dataloaders_description = ( + 'The dataloaders registry is used to register functions that create a DataSpec. The function should take ' + + + 'a DictConfig, a PreTrainedTokenizerBase, and an int as arguments, and return a DataSpec.' +) dataloaders = create_registry( 'llmfoundry', 'dataloaders', @@ -103,7 +111,9 @@ entry_points=True, description=_dataloaders_description) -_metrics_description = """The metrics registry is used to register classes that implement the torchmetrics.Metric interface.""" +_metrics_description = ( + 'The metrics registry is used to register classes that implement the torchmetrics.Metric interface.' +) metrics = create_registry('llmfoundry', 'metrics', generic_type=Type[Metric], @@ -121,4 +131,6 @@ 'metrics', 'dataloaders', 'norms', + 'attention_class', + 'attention_implementation', ] diff --git a/tests/test_registry.py b/tests/test_registry.py index c93c7c9749..17cdfa1457 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -31,6 +31,8 @@ def test_expected_registries_exist(): 'metrics', 'models', 'norms', + 'attention_class', + 'attention_implementation', } assert existing_registries == expected_registry_names