[kushalkodnad/tokenizer-registry] Introduce new registry for tokenizers (#1386)
kushalkodn-db authored Jul 23, 2024
1 parent cefd616 commit 51949c4
Showing 5 changed files with 62 additions and 3 deletions.
14 changes: 14 additions & 0 deletions llmfoundry/registry.py
@@ -155,6 +155,19 @@
     description=_schedulers_description,
 )
 
+_tokenizers_description = (
+    'The tokenizers registry is used to register tokenizers that implement the transformers.PreTrainedTokenizerBase interface. '
+    +
+    'The tokenizer will be passed to the build_dataloader() and build_composer_model() methods in train.py.'
+)
+tokenizers = create_registry(
+    'llmfoundry',
+    'tokenizers',
+    generic_type=Type[PreTrainedTokenizerBase],
+    entry_points=True,
+    description=_tokenizers_description,
+)
+
 _models_description = (
     """The models registry is used to register classes that implement the ComposerModel interface.
@@ -383,6 +396,7 @@
     'optimizers',
     'algorithms',
     'schedulers',
+    'tokenizers',
     'models',
     'dataset_replication_validators',
     'collators',
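A minimal usage sketch of the new registry (not part of this commit; the class name MyCustomTokenizer and the key 'my-tokenizer' are hypothetical placeholders):

from transformers import PreTrainedTokenizer

from llmfoundry.registry import tokenizers


class MyCustomTokenizer(PreTrainedTokenizer):
    """Hypothetical tokenizer satisfying the PreTrainedTokenizerBase interface."""

    def get_vocab(self):
        return {}


# Register the class so build_tokenizer() can later resolve it by name.
tokenizers.register('my-tokenizer', func=MyCustomTokenizer)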
3 changes: 3 additions & 0 deletions llmfoundry/tokenizers/__init__.py
@@ -1,8 +1,11 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+from llmfoundry.registry import tokenizers
 from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
 
+tokenizers.register('tiktoken', func=TiktokenTokenizerWrapper)
+
 __all__ = [
     'TiktokenTokenizerWrapper',
 ]
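Because this registration runs at import time, the wrapper becomes discoverable by name. A quick sketch, assuming the registry keeps catalogue-style membership and get() helpers:

from llmfoundry.registry import tokenizers
from llmfoundry.tokenizers import TiktokenTokenizerWrapper  # importing the package runs the register() call above

assert 'tiktoken' in tokenizers
assert tokenizers.get('tiktoken') is TiktokenTokenizerWrapper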
12 changes: 9 additions & 3 deletions llmfoundry/utils/builders.py
@@ -37,7 +37,6 @@
 from llmfoundry.data.dataloader import build_dataloader
 from llmfoundry.eval.datasets.in_context_learning_evaluation import \
     get_icl_task_dataloader
-from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
 from llmfoundry.utils.config_utils import to_dict_container, to_list_container
 from llmfoundry.utils.registry_utils import construct_from_registry

@@ -506,8 +505,15 @@ def build_tokenizer(
         with dist.local_rank_zero_download_and_wait(signal_file_path):
             pass
 
-    if tokenizer_name.startswith('tiktoken'):
-        tokenizer = TiktokenTokenizerWrapper(**tokenizer_kwargs)
+    if tokenizer_name in registry.tokenizers:
+        tokenizer = construct_from_registry(
+            name=tokenizer_name,
+            registry=registry.tokenizers,
+            partial_function=True,
+            pre_validation_function=PreTrainedTokenizerBase,
+            post_validation_function=None,
+            kwargs=tokenizer_kwargs,
+        )
     else:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
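With this change, build_tokenizer() resolves registered names through registry.tokenizers and only falls back to AutoTokenizer.from_pretrained for unregistered names. A hedged usage sketch; the specific tokenizer_kwargs values below are illustrative assumptions, not taken from this commit:

from llmfoundry.utils import build_tokenizer

# 'tiktoken' is in the registry, so it is constructed via construct_from_registry.
tiktoken_tokenizer = build_tokenizer(
    tokenizer_name='tiktoken',
    tokenizer_kwargs={'model_name': 'gpt-4'},  # assumed kwargs for TiktokenTokenizerWrapper
)

# An unregistered name still takes the AutoTokenizer.from_pretrained path.
hf_tokenizer = build_tokenizer(
    tokenizer_name='EleutherAI/gpt-neox-20b',
    tokenizer_kwargs={'model_max_length': 2048},
)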
1 change: 1 addition & 0 deletions tests/test_registry.py
@@ -24,6 +24,7 @@ def test_expected_registries_exist():
         'loggers',
         'optimizers',
         'schedulers',
+        'tokenizers',
         'callbacks',
         'algorithms',
         'callbacks_with_config',
35 changes: 35 additions & 0 deletions tests/tokenizers/test_registry.py
@@ -0,0 +1,35 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Dict, Optional
+
+from transformers import PreTrainedTokenizer
+
+from llmfoundry.registry import tokenizers
+from llmfoundry.utils import build_tokenizer
+
+
+class DummyTokenizer(PreTrainedTokenizer):
+    """A dummy tokenizer that inherits from ``PreTrainedTokenizer``."""
+
+    def __init__(
+        self,
+        model_name: Optional[str] = 'dummy',
+        **kwargs: Optional[Dict[str, Any]],
+    ):
+        """Dummy constructor that has no real purpose."""
+        super().__init__(
+            model_name=model_name,
+            eos_token='0',
+            pad_token='1',
+            **kwargs,
+        )
+
+    def get_vocab(self) -> Dict[str, int]:
+        return {}
+
+
+def test_tokenizer_registry():
+    tokenizers.register('dummy', func=DummyTokenizer)
+    tokenizer = build_tokenizer(tokenizer_name='dummy', tokenizer_kwargs={})
+    assert type(tokenizer) == DummyTokenizer

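The test registers DummyTokenizer with the explicit func= form, mirroring the 'tiktoken' registration above. If the underlying TypedRegistry keeps catalogue's decorator form of register (an assumption, not shown in this commit), the same registration could also be written as a decorator:

from typing import Any, Dict, Optional

from transformers import PreTrainedTokenizer

from llmfoundry.registry import tokenizers


@tokenizers.register('dummy_decorated')  # assumes catalogue-style decorator support
class DecoratedDummyTokenizer(PreTrainedTokenizer):

    def __init__(self, **kwargs: Optional[Dict[str, Any]]):
        super().__init__(eos_token='0', pad_token='1', **kwargs)

    def get_vocab(self) -> Dict[str, int]:
        return {}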