diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index e31840d3fb..3f0163ff01 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -155,6 +155,19 @@ description=_schedulers_description, ) +_tokenizers_description = ( + 'The tokenizers registry is used to register tokenizers that implement the transformers.PreTrainedTokenizerBase interface. ' + + + 'The tokenizer will be passed to the build_dataloader() and build_composer_model() methods in train.py.' +) +tokenizers = create_registry( + 'llmfoundry', + 'tokenizers', + generic_type=Type[PreTrainedTokenizerBase], + entry_points=True, + description=_tokenizers_description, +) + _models_description = ( """The models registry is used to register classes that implement the ComposerModel interface. @@ -383,6 +396,7 @@ 'optimizers', 'algorithms', 'schedulers', + 'tokenizers', 'models', 'dataset_replication_validators', 'collators', diff --git a/llmfoundry/tokenizers/__init__.py b/llmfoundry/tokenizers/__init__.py index 1703ed8862..d37c12a555 100644 --- a/llmfoundry/tokenizers/__init__.py +++ b/llmfoundry/tokenizers/__init__.py @@ -1,8 +1,11 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +from llmfoundry.registry import tokenizers from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper +tokenizers.register('tiktoken', func=TiktokenTokenizerWrapper) + __all__ = [ 'TiktokenTokenizerWrapper', ] diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index b889155be0..cf27e7660e 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -37,7 +37,6 @@ from llmfoundry.data.dataloader import build_dataloader from llmfoundry.eval.datasets.in_context_learning_evaluation import \ get_icl_task_dataloader -from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper from llmfoundry.utils.config_utils import to_dict_container, to_list_container from llmfoundry.utils.registry_utils import construct_from_registry @@ -506,8 +505,15 @@ def build_tokenizer( with dist.local_rank_zero_download_and_wait(signal_file_path): pass - if tokenizer_name.startswith('tiktoken'): - tokenizer = TiktokenTokenizerWrapper(**tokenizer_kwargs) + if tokenizer_name in registry.tokenizers: + tokenizer = construct_from_registry( + name=tokenizer_name, + registry=registry.tokenizers, + partial_function=True, + pre_validation_function=PreTrainedTokenizerBase, + post_validation_function=None, + kwargs=tokenizer_kwargs, + ) else: tokenizer = AutoTokenizer.from_pretrained( tokenizer_name, diff --git a/tests/test_registry.py b/tests/test_registry.py index aa0c93ee13..c4d1a1bcd5 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -24,6 +24,7 @@ def test_expected_registries_exist(): 'loggers', 'optimizers', 'schedulers', + 'tokenizers', 'callbacks', 'algorithms', 'callbacks_with_config', diff --git a/tests/tokenizers/test_registry.py b/tests/tokenizers/test_registry.py new file mode 100644 index 0000000000..920c207a64 --- /dev/null +++ b/tests/tokenizers/test_registry.py @@ -0,0 +1,35 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, Optional + +from transformers import PreTrainedTokenizer + +from llmfoundry.registry import tokenizers +from llmfoundry.utils import build_tokenizer + + +class DummyTokenizer(PreTrainedTokenizer): + """A dummy tokenizer that inherits from ``PreTrainedTokenizer``.""" + + def __init__( + self, + model_name: Optional[str] = 'dummy', + **kwargs: Optional[Dict[str, Any]], + ): + """Dummy constructor that has no real purpose.""" + super().__init__( + model_name=model_name, + eos_token='0', + pad_token='1', + **kwargs, + ) + + def get_vocab(self) -> Dict[str, int]: + return {} + + +def test_tokenizer_registry(): + tokenizers.register('dummy', func=DummyTokenizer) + tokenizer = build_tokenizer(tokenizer_name='dummy', tokenizer_kwargs={}) + assert type(tokenizer) == DummyTokenizer