From 32e14a67d3677024e22262aecbb5b485b4019ad3 Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Sun, 24 Mar 2024 11:38:11 -0700
Subject: [PATCH] Dataloaders registry (#1044)

---
 llmfoundry/data/__init__.py   |  5 +++++
 llmfoundry/data/dataloader.py | 31 ++++++++++++++++---------------
 llmfoundry/registry.py        | 15 +++++++++++++--
 tests/data/test_dataloader.py |  4 ++--
 tests/test_registry.py        |  1 +
 5 files changed, 37 insertions(+), 19 deletions(-)

diff --git a/llmfoundry/data/__init__.py b/llmfoundry/data/__init__.py
index 8da436b9b1..b64f54cfa3 100644
--- a/llmfoundry/data/__init__.py
+++ b/llmfoundry/data/__init__.py
@@ -9,6 +9,11 @@
                                         build_finetuning_dataloader)
 from llmfoundry.data.text_data import (StreamingTextDataset,
                                        build_text_dataloader)
+from llmfoundry.registry import dataloaders
+
+dataloaders.register('text', func=build_text_dataloader)
+dataloaders.register('text_denoising', func=build_text_denoising_dataloader)
+dataloaders.register('finetuning', func=build_finetuning_dataloader)
 
 __all__ = [
     'MixtureOfDenoisersCollator',
diff --git a/llmfoundry/data/dataloader.py b/llmfoundry/data/dataloader.py
index 63d47a65d5..a98526001a 100644
--- a/llmfoundry/data/dataloader.py
+++ b/llmfoundry/data/dataloader.py
@@ -7,15 +7,8 @@
 from omegaconf import DictConfig
 from transformers import PreTrainedTokenizerBase
 
-from llmfoundry.data.denoising import build_text_denoising_dataloader
-from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
-from llmfoundry.data.text_data import build_text_dataloader
-
-LOADER_NAME_TO_FUNCTION = {
-    'text': build_text_dataloader,
-    'text_denoising': build_text_denoising_dataloader,
-    'finetuning': build_finetuning_dataloader,
-}
+from llmfoundry import registry
+from llmfoundry.utils.registry_utils import construct_from_registry
 
 
 def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
@@ -28,9 +21,17 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
         device_batch_size (int): The size of the batches (number of examples)
             that the dataloader will produce.
""" - if cfg.name not in LOADER_NAME_TO_FUNCTION: - allowed = ', '.join(LOADER_NAME_TO_FUNCTION.keys()) - raise ValueError(f'Expected dataloader name to be one of {allowed}' + - f' but found name "{cfg.name}" in config: {cfg}') - - return LOADER_NAME_TO_FUNCTION[cfg.name](cfg, tokenizer, device_batch_size) + kwargs = { + 'cfg': cfg, + 'tokenizer': tokenizer, + 'device_batch_size': device_batch_size + } + + return construct_from_registry( + name=cfg.name, + registry=registry.dataloaders, + partial_function=False, + pre_validation_function=None, + post_validation_function=None, + kwargs=kwargs, + ) diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index 6e664ca9c1..897f714d62 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -1,12 +1,14 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import Type +from typing import Callable, Type -from composer.core import Algorithm, Callback +from composer.core import Algorithm, Callback, DataSpec from composer.loggers import LoggerDestination from composer.optim import ComposerScheduler +from omegaconf import DictConfig from torch.optim import Optimizer from torchmetrics import Metric +from transformers import PreTrainedTokenizerBase from llmfoundry.interfaces import CallbackWithConfig from llmfoundry.utils.registry_utils import create_registry @@ -81,6 +83,15 @@ entry_points=True, description=_schedulers_description) +_dataloaders_description = """The dataloaders registry is used to register functions that create a DataSpec. The function should take +a DictConfig, a PreTrainedTokenizerBase, and an int as arguments, and return a DataSpec.""" +dataloaders = create_registry( + 'llmfoundry', + 'dataloaders', + generic_type=Callable[[DictConfig, PreTrainedTokenizerBase, int], DataSpec], + entry_points=True, + description=_dataloaders_description) + _metrics_description = """The metrics registry is used to register classes that implement the torchmetrics.Metric interface.""" metrics = create_registry('llmfoundry', 'metrics', diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 00ac8df182..dc06bfb3c1 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -12,6 +12,7 @@ from typing import ContextManager, Literal, Optional, Union from unittest.mock import MagicMock, patch +import catalogue import numpy as np import pytest import torch @@ -1070,6 +1071,5 @@ def test_build_unknown_dataloader(): 'name': 'unknown', }) tokenizer = MagicMock() - with pytest.raises(ValueError, - match='Expected dataloader name to be one of'): + with pytest.raises(catalogue.RegistryError): _ = build_dataloader(cfg, tokenizer, 2) diff --git a/tests/test_registry.py b/tests/test_registry.py index df8fda8d9f..cf03cc18c9 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -27,6 +27,7 @@ def test_expected_registries_exist(): 'callbacks', 'algorithms', 'callbacks_with_config', + 'dataloaders', 'metrics', }