Simplify CL API #1510

Merged: 1 commit, Sep 30, 2024
78 changes: 50 additions & 28 deletions llmfoundry/callbacks/curriculum_learning_callback.py
@@ -9,7 +9,8 @@

import copy
import logging
from typing import Any
import warnings
from typing import Any, Optional, Union

from composer import DataSpec
from composer.core import State, Time, TimeUnit, ensure_time
@@ -23,6 +24,7 @@
BaseContextualError,
TrainDataLoaderLocation,
)
from llmfoundry.utils.warnings import VersionedDeprecationWarning

log = logging.getLogger(__name__)

@@ -32,19 +34,21 @@
class CurriculumLearning(CallbackWithConfig):
"""Starts an epoch with a different dataset when resuming from a checkpoint.

Example duration:
<number>tok
Example schedule:
[
{
'duration': <number>tok,
'train_loader': <dataloader parameters>, # matches top level train_loader
'dataset': <dataset parameters>,
},
{
'duration': <number>tok,
'train_loader': <dataloader parameters>,
'dataset': <dataset parameters>,
},
{
'duration': <number>tok,
'train_loader': <dataloader parameters>,
'dataset': <dataset parameters>,
},
]

@@ -53,48 +57,59 @@ class CurriculumLearning(CallbackWithConfig):
being used. Note that this is the full train config and must
contain the 'train_loader', 'device_train_batch_size', and
'tokenizer' keys.
duration (Union[Time, str, int], optional): The duration of the first datamix
(which corresponds to the train_loader). Defaults to None.
schedule (list[dict[str, Any]]): The list of datamixes to use and their
durations. Duration units must match max_duration and be in terms of
a TimeUnit that is supported by Iteration. The duration values must
be positive. There must be at least one datamix in the schedule. The
first datamix in the schedule must match the train_loader in the
train_config. On resumption, previously trained on datamixes and
durations cannot be changed. The duration of the current datamix
must be greater than the saved timestamp. The dataset must be a
StreamingDataset.
first datamix during training is not included in the schedule. On
resumption, previously trained on datamixes and durations cannot be
changed. The duration of the current datamix must be greater than
the saved timestamp. The dataset must be a StreamingDataset.
"""

def __init__(
self,
train_config: dict[str, Any],
schedule: list[dict[str, Any]],
duration: Optional[Union[Time, str, int]] = None,
):
if duration is None:
warnings.warn(
VersionedDeprecationWarning(
'Specifying the full schedule in the CurriculumLearning ' +
'callback is deprecated. Please specify the duration of ' +
'the first datamix separately and change the schedule ' +
'to use datasets instead of dataloaders.',
remove_version='0.15.0',
),
)

# Ensure all duration units are in epochs or tokens and values are positive
self._schedule = schedule
if len(self._schedule) == 0:
raise ValueError('The schedule must have at least one datamix.')
for index, datamix in enumerate(self._schedule):
if duration is not None:
first_datamix = {
'duration': duration,
'dataset': train_config['train_loader']['dataset'],
}
self._schedule.insert(0, first_datamix)
for datamix in self._schedule:
self._validate_datamix(datamix)

if (
index == 0 and
train_config['train_loader'] != datamix['train_loader']
):
raise ValueError((
'The first datamix in the schedule must match the '
'train_loader in the train_config.'
))

self._schedule_index = 0
self.device_train_batch_size = train_config['device_train_batch_size']
self.tokenizer = None
self._train_loader_config: dict[str, Any] = train_config['train_loader']
self._device_train_batch_size = train_config['device_train_batch_size']
self._tokenizer = None
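
To make the prepend logic above concrete, here is a small self-contained sketch; the dataset names are invented for illustration.

train_config = {'train_loader': {'dataset': {'hf_name': 'phase1_data'}}}
duration = '5000tok'
schedule = [{'duration': '10000tok', 'dataset': {'hf_name': 'phase2_data'}}]

# Mirrors __init__: the top-level dataset becomes a synthetic first datamix.
first_datamix = {
    'duration': duration,
    'dataset': train_config['train_loader']['dataset'],
}
schedule.insert(0, first_datamix)
# schedule now has two entries: phase1_data for 5000tok, then phase2_data.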

def init(self, state: State, logger: Logger):
del logger # unused

if not hasattr(state.model, 'tokenizer'):
raise ValueError('state.model must have a tokenizer attribute.')
self.tokenizer = state.model.tokenizer
self._tokenizer = state.model.tokenizer

def before_load(self, state: State, logger: Logger):
del logger # unused
@@ -151,8 +166,13 @@ def iteration_start(self, state: State, logger: Logger):
# which is stale
clean_stale_shared_memory()
datamix = copy.deepcopy(self._schedule[self._schedule_index])
train_loader_config = copy.deepcopy(self._train_loader_config)
if 'dataset' in datamix:
train_loader_config['dataset'].update(datamix['dataset'])
else:
train_loader_config = datamix['train_loader']
data_spec = self._build_train_loader(
train_loader_config=datamix['train_loader'],
train_loader_config=train_loader_config,
logger=logger,
)
state.set_dataloader(
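
The `update` above is a shallow merge into a deep copy of the top-level loader config, so unrelated loader settings survive a datamix switch. A sketch with invented keys:

import copy

train_loader_config = copy.deepcopy({
    'name': 'text',
    'dataset': {'remote': 's3://bucket/phase1', 'shuffle': True},
})
datamix = {'dataset': {'remote': 's3://bucket/phase2'}}

train_loader_config['dataset'].update(datamix['dataset'])
# -> {'remote': 's3://bucket/phase2', 'shuffle': True}; 'name' is untouched.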
@@ -211,18 +231,20 @@ def _build_train_loader(
train_loader_config: dict[str, Any],
logger: Logger,
) -> DataSpec:
del logger # unused

from llmfoundry.data.dataloader import build_dataloader

# Copied from scripts/train/train.py
log.info(
f'Building train loader in CurriculumLearning callback for dataset {self._schedule_index}',
)
assert self.tokenizer is not None
assert self._tokenizer is not None
try:
return build_dataloader(
train_loader_config,
self.tokenizer,
self.device_train_batch_size,
self._tokenizer,
self._device_train_batch_size,
)
except BaseContextualError as e:
e.location = TrainDataLoaderLocation
@@ -260,5 +282,5 @@ def _validate_datamix(self, datamix: dict[str, Any]):
'Schedules can only be defined in terms of epochs or tokens.',
)

if 'train_loader' not in datamix:
raise ValueError('Each datamix must have a train_loader.')
if 'train_loader' not in datamix and 'dataset' not in datamix:
raise ValueError('Each datamix must have a dataset.')
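
Under the relaxed check, either key satisfies validation; only a datamix carrying neither raises. A sketch of the accepted shapes, with hypothetical dicts:

def has_data(datamix: dict) -> bool:
    # Mirrors the check above: 'train_loader' (legacy) or 'dataset' (new).
    return 'train_loader' in datamix or 'dataset' in datamix

assert has_data({'duration': '1ep', 'dataset': {}})        # new style
assert has_data({'duration': '1ep', 'train_loader': {}})   # legacy style
assert not has_data({'duration': '1ep'})                   # would raise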
48 changes: 32 additions & 16 deletions tests/callbacks/test_curriculum_learning_callback.py
@@ -22,7 +22,7 @@
[
(None, '1ep'),
({
'dataset': 'some_dataset',
'hf_name': 'some_dataset',
}, '1ep'),
(None, '10tok'),
(None, ''),
@@ -36,23 +36,29 @@ def test_curriculum_learning_callback_init(
):
test_cfg = _get_test_cfg()
test_cfg['train_loader'] = tiny_ft_dataloader_cfg
train_loader = test_cfg['train_loader'] if datamix is None else datamix
if datamix is None:
train_loader = test_cfg['train_loader']['dataset']
else:
train_loader = datamix
kwargs = {
'schedule': [{
'duration': duration,
'train_loader': train_loader,
'dataset': train_loader,
}, {
'duration': '2ep',
'train_loader': {},
'dataset': {},
}],
}

kwargs['duration'] = kwargs['schedule'].pop(0)['duration']

if duration == '':
del kwargs['schedule'][0]['duration']
if datamix is not None and len(datamix) == 0:
del kwargs['schedule'][0]['train_loader']
del kwargs['schedule'][0]['dataset']

context = nullcontext()
if datamix is not None or duration == '':
if (datamix is not None and len(datamix) == 0) or duration == '':
context = pytest.raises(ValueError)
with context:
callback = build_callback(
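
Each updated test applies the same mechanical conversion: pop the first schedule entry and pass its duration as the new top-level `duration` kwarg. In isolation the transformation looks like this (values invented):

kwargs = {
    'schedule': [
        {'duration': '1ep', 'dataset': {'hf_name': 'a'}},
        {'duration': '2ep', 'dataset': {'hf_name': 'b'}},
    ],
}
kwargs['duration'] = kwargs['schedule'].pop(0)['duration']
# kwargs now has duration '1ep' and a one-entry schedule for dataset 'b'.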
@@ -85,13 +91,15 @@ def test_curriculum_learning_callback_before_load(
kwargs = {
'schedule': [{
'duration': duration,
'train_loader': test_cfg['train_loader'],
'dataset': test_cfg['train_loader']['dataset'],
}, {
'duration': '2ep',
'train_loader': test_cfg['train_loader'],
'dataset': test_cfg['train_loader']['dataset'],
}],
}

kwargs['duration'] = kwargs['schedule'].pop(0)['duration']

callback = build_callback(
'curriculum_learning',
kwargs=kwargs,
@@ -123,13 +131,15 @@ def test_curriculum_learning_callback_after_load(build_tiny_mpt: Callable,):
kwargs = {
'schedule': [{
'duration': '1ep',
'train_loader': test_cfg['train_loader'],
'dataset': test_cfg['train_loader']['dataset'],
}, {
'duration': '2ep',
'train_loader': test_cfg['train_loader'],
'dataset': test_cfg['train_loader']['dataset'],
}],
}

kwargs['duration'] = kwargs['schedule'].pop(0)['duration']

callback = build_callback(
'curriculum_learning',
kwargs=kwargs,
@@ -168,13 +178,15 @@ def test_curriculum_learning_callback_iteration(
kwargs = {
'schedule': [{
'duration': '1ep',
'train_loader': test_cfg['train_loader'],
'dataset': test_cfg['train_loader']['dataset'],
}, {
'duration': '2ep',
'train_loader': test_cfg['train_loader'],
'dataset': test_cfg['train_loader']['dataset'],
}],
}

kwargs['duration'] = kwargs['schedule'].pop(0)['duration']

callback = build_callback(
'curriculum_learning',
kwargs=kwargs,
@@ -208,13 +220,15 @@ def test_curriculum_learning_callback_state_dict(build_tiny_mpt: Callable,):
kwargs = {
'schedule': [{
'duration': '1ep',
'train_loader': test_cfg['train_loader'],
'dataset': test_cfg['train_loader']['dataset'],
}, {
'duration': '2ep',
'train_loader': test_cfg['train_loader'],
'dataset': test_cfg['train_loader']['dataset'],
}],
}

kwargs['duration'] = kwargs['schedule'].pop(0)['duration']

callback = build_callback(
'curriculum_learning',
kwargs=kwargs,
@@ -249,13 +263,15 @@ def test_curriculum_learning_callback_load_state_dict(
kwargs = {
'schedule': [{
'duration': '1ep',
'train_loader': test_cfg['train_loader'],
'dataset': test_cfg['train_loader']['dataset'],
}, {
'duration': '2ep',
'train_loader': test_cfg['train_loader'],
'dataset': test_cfg['train_loader']['dataset'],
}],
}

kwargs['duration'] = kwargs['schedule'].pop(0)['duration']

callback = build_callback(
'curriculum_learning',
kwargs=kwargs,