Dataloaders registry (#1044)
dakinggg authored Mar 24, 2024
1 parent 5c8a829 · commit 32e14a6
Showing 5 changed files with 37 additions and 19 deletions.
5 changes: 5 additions & 0 deletions llmfoundry/data/__init__.py
@@ -9,6 +9,11 @@
                                          build_finetuning_dataloader)
 from llmfoundry.data.text_data import (StreamingTextDataset,
                                        build_text_dataloader)
+from llmfoundry.registry import dataloaders
+
+dataloaders.register('text', func=build_text_dataloader)
+dataloaders.register('text_denoising', func=build_text_denoising_dataloader)
+dataloaders.register('finetuning', func=build_finetuning_dataloader)
 
 __all__ = [
     'MixtureOfDenoisersCollator',
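The three register calls above are what make the built-in loaders resolvable by name. As a minimal sketch (not part of this diff, and using a hypothetical name and builder), a downstream package could add its own dataloader through the same API:

from composer.core import DataSpec
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase

from llmfoundry.registry import dataloaders


def build_my_text_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
                             device_batch_size: int) -> DataSpec:
    # Build a torch DataLoader from cfg and wrap it in a composer DataSpec here.
    raise NotImplementedError


# Same registration pattern as the built-in 'text', 'text_denoising', and
# 'finetuning' entries registered above.
dataloaders.register('my_text_loader', func=build_my_text_dataloader)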
31 changes: 16 additions & 15 deletions llmfoundry/data/dataloader.py
@@ -7,15 +7,8 @@
 from omegaconf import DictConfig
 from transformers import PreTrainedTokenizerBase
 
-from llmfoundry.data.denoising import build_text_denoising_dataloader
-from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
-from llmfoundry.data.text_data import build_text_dataloader
-
-LOADER_NAME_TO_FUNCTION = {
-    'text': build_text_dataloader,
-    'text_denoising': build_text_denoising_dataloader,
-    'finetuning': build_finetuning_dataloader,
-}
+from llmfoundry import registry
+from llmfoundry.utils.registry_utils import construct_from_registry
 
 
 def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
@@ -28,9 +21,17 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
         device_batch_size (int): The size of the batches (number of examples)
             that the dataloader will produce.
     """
-    if cfg.name not in LOADER_NAME_TO_FUNCTION:
-        allowed = ', '.join(LOADER_NAME_TO_FUNCTION.keys())
-        raise ValueError(f'Expected dataloader name to be one of {allowed}' +
-                         f' but found name "{cfg.name}" in config: {cfg}')
-
-    return LOADER_NAME_TO_FUNCTION[cfg.name](cfg, tokenizer, device_batch_size)
+    kwargs = {
+        'cfg': cfg,
+        'tokenizer': tokenizer,
+        'device_batch_size': device_batch_size
+    }
+
+    return construct_from_registry(
+        name=cfg.name,
+        registry=registry.dataloaders,
+        partial_function=False,
+        pre_validation_function=None,
+        post_validation_function=None,
+        kwargs=kwargs,
+    )
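With the hard-coded LOADER_NAME_TO_FUNCTION mapping removed, cfg.name is now resolved through registry.dataloaders at call time. A hedged usage sketch follows; the dataset fields are illustrative placeholders, not a verified config for the text loader:

from omegaconf import OmegaConf
from transformers import AutoTokenizer

from llmfoundry.data.dataloader import build_dataloader

# 'text' resolves to build_text_dataloader via the registry; an unknown name
# now raises catalogue.RegistryError rather than the old ValueError.
cfg = OmegaConf.create({
    'name': 'text',
    'dataset': {
        'local': '/tmp/my-copy-c4',  # illustrative streaming-dataset options
        'split': 'train',
        'max_seq_len': 2048,
        'shuffle': True,
    },
    'drop_last': True,
    'num_workers': 8,
})

tokenizer = AutoTokenizer.from_pretrained('gpt2')
data_spec = build_dataloader(cfg, tokenizer, device_batch_size=8)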
15 changes: 13 additions & 2 deletions llmfoundry/registry.py
@@ -1,12 +1,14 @@
 # Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
-from typing import Type
+from typing import Callable, Type
 
-from composer.core import Algorithm, Callback
+from composer.core import Algorithm, Callback, DataSpec
 from composer.loggers import LoggerDestination
 from composer.optim import ComposerScheduler
+from omegaconf import DictConfig
 from torch.optim import Optimizer
 from torchmetrics import Metric
+from transformers import PreTrainedTokenizerBase
 
 from llmfoundry.interfaces import CallbackWithConfig
 from llmfoundry.utils.registry_utils import create_registry
@@ -81,6 +83,15 @@
                              entry_points=True,
                              description=_schedulers_description)
 
+_dataloaders_description = """The dataloaders registry is used to register functions that create a DataSpec. The function should take
+a DictConfig, a PreTrainedTokenizerBase, and an int as arguments, and return a DataSpec."""
+dataloaders = create_registry(
+    'llmfoundry',
+    'dataloaders',
+    generic_type=Callable[[DictConfig, PreTrainedTokenizerBase, int], DataSpec],
+    entry_points=True,
+    description=_dataloaders_description)
+
 _metrics_description = """The metrics registry is used to register classes that implement the torchmetrics.Metric interface."""
 metrics = create_registry('llmfoundry',
                           'metrics',
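Because create_registry is backed by catalogue, the registered builders can also be looked up or enumerated directly. A short sketch, assuming the wrapper exposes the standard catalogue get/get_all methods:

import llmfoundry.data  # noqa: F401  (importing runs the register(...) calls above)
from llmfoundry import registry

# Resolve a single builder by name, which is what build_dataloader does
# internally via construct_from_registry.
build_text = registry.dataloaders.get('text')

# Enumerate everything currently registered (plus any entry-point plugins).
print(sorted(registry.dataloaders.get_all()))  # ['finetuning', 'text', 'text_denoising']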
4 changes: 2 additions & 2 deletions tests/data/test_dataloader.py
@@ -12,6 +12,7 @@
 from typing import ContextManager, Literal, Optional, Union
 from unittest.mock import MagicMock, patch
 
+import catalogue
 import numpy as np
 import pytest
 import torch
@@ -1070,6 +1071,5 @@ def test_build_unknown_dataloader():
         'name': 'unknown',
     })
     tokenizer = MagicMock()
-    with pytest.raises(ValueError,
-                       match='Expected dataloader name to be one of'):
+    with pytest.raises(catalogue.RegistryError):
         _ = build_dataloader(cfg, tokenizer, 2)
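The updated test captures the new failure mode: an unknown name is rejected by the registry lookup itself. The same behavior outside pytest, using the inputs from the test above:

import catalogue
from omegaconf import DictConfig
from unittest.mock import MagicMock

from llmfoundry.data.dataloader import build_dataloader

cfg = DictConfig({'name': 'unknown'})
try:
    build_dataloader(cfg, MagicMock(), 2)
except catalogue.RegistryError as err:
    # Previously this raised ValueError('Expected dataloader name to be one of ...').
    print(f'rejected unknown dataloader: {err}')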
1 change: 1 addition & 0 deletions tests/test_registry.py
@@ -27,6 +27,7 @@ def test_expected_registries_exist():
         'callbacks',
         'algorithms',
         'callbacks_with_config',
+        'dataloaders',
         'metrics',
     }
 
