Commit d5c7eca: Simplify CL API

b-chu committed Sep 3, 2024
1 parent 02802c5 commit d5c7eca

Showing 2 changed files with 53 additions and 36 deletions.
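At a glance: before this commit, every schedule entry carried a full train_loader config and the first entry had to duplicate train_config['train_loader']; after it, a top-level duration argument covers the initial datamix (which the callback now builds from the training config itself), and each remaining entry supplies only a dataset override. A minimal before/after sketch of the callback kwargs, inferred from the diff below; the hf_name values are hypothetical placeholders:

    # Sketch inferred from this diff; 'dataset_a'/'dataset_b' are hypothetical.

    # Before: each datamix duplicates a full train_loader config, and the
    # first entry must match train_config['train_loader'] verbatim.
    old_kwargs = {
        'schedule': [
            {'duration': '1ep', 'train_loader': {'dataset': {'hf_name': 'dataset_a'}}},
            {'duration': '2ep', 'train_loader': {'dataset': {'hf_name': 'dataset_b'}}},
        ],
    }

    # After: the first datamix is implicit; later entries override only the dataset.
    new_kwargs = {
        'duration': '1ep',
        'schedule': [
            {'duration': '2ep', 'dataset': {'hf_name': 'dataset_b'}},
        ],
    }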
41 changes: 21 additions & 20 deletions llmfoundry/callbacks/curriculum_learning_callback.py
@@ -9,7 +9,7 @@
 
 import copy
 import logging
-from typing import Any
+from typing import Any, Union
 
 from composer import DataSpec
 from composer.core import State, Time, TimeUnit, ensure_time
@@ -67,34 +67,32 @@ class CurriculumLearning(CallbackWithConfig):
     def __init__(
         self,
         train_config: dict[str, Any],
+        duration: Union[Time, str, int],
         schedule: list[dict[str, Any]],
     ):
         # Ensure all duration units are in epochs or tokens and values are positive
         self._schedule = schedule
         if len(self._schedule) == 0:
             raise ValueError('The schedule must have at least one datamix.')
-        for index, datamix in enumerate(self._schedule):
+        first_datamix = {
+            'duration': duration,
+            'dataset': train_config['train_loader']['dataset'],
+        }
+        self._schedule.insert(0, first_datamix)
+        for datamix in self._schedule:
             self._validate_datamix(datamix)
 
-            if (
-                index == 0 and
-                train_config['train_loader'] != datamix['train_loader']
-            ):
-                raise ValueError((
-                    'The first datamix in the schedule must match the '
-                    'train_loader in the train_config.'
-                ))
-
         self._schedule_index = 0
-        self.device_train_batch_size = train_config['device_train_batch_size']
-        self.tokenizer = None
+        self._train_loader_config: dict[str, Any] = train_config['train_loader']
+        self._device_train_batch_size = train_config['device_train_batch_size']
+        self._tokenizer = None
 
     def init(self, state: State, logger: Logger):
         del logger  # unused
 
         if not hasattr(state.model, 'tokenizer'):
             raise ValueError('state.model must have a tokenizer attribute.')
-        self.tokenizer = state.model.tokenizer
+        self._tokenizer = state.model.tokenizer
 
     def before_load(self, state: State, logger: Logger):
         del logger  # unused
@@ -151,8 +149,9 @@ def iteration_start(self, state: State, logger: Logger):
         # which is stale
         clean_stale_shared_memory()
         datamix = copy.deepcopy(self._schedule[self._schedule_index])
+        self._train_loader_config['dataset'].update(datamix['dataset'])
         data_spec = self._build_train_loader(
-            train_loader_config=datamix['train_loader'],
+            train_loader_config=self._train_loader_config,
             logger=logger,
         )
         state.set_dataloader(
@@ -211,18 +210,20 @@ def _build_train_loader(
         train_loader_config: dict[str, Any],
         logger: Logger,
     ) -> DataSpec:
+        del logger  # unused
+
         from llmfoundry.data.dataloader import build_dataloader
 
         # Copied from scripts/train/train.py
         log.info(
             f'Building train loader in CurriculumLearning callback for dataset {self._schedule_index}',
         )
-        assert self.tokenizer is not None
+        assert self._tokenizer is not None
         try:
             return build_dataloader(
                 train_loader_config,
-                self.tokenizer,
-                self.device_train_batch_size,
+                self._tokenizer,
+                self._device_train_batch_size,
             )
         except BaseContextualError as e:
             e.location = TrainDataLoaderLocation
@@ -260,5 +261,5 @@ def _validate_datamix(self, datamix: dict[str, Any]):
                 'Schedules can only be defined in terms of epochs or tokens.',
             )
 
-        if 'train_loader' not in datamix:
-            raise ValueError('Each datamix must have a train_loader.')
+        if 'dataset' not in datamix:
+            raise ValueError('Each datamix must have a dataset.')
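A consequence of the dataset-only schedule entries is visible in iteration_start above: each datamix's dataset dict is merged into the stored train loader config via dict.update rather than replacing it, so unspecified keys carry over (and, since the stored config is modified in place, overrides accumulate across iterations). A standalone sketch of that merge; the config values are hypothetical:

    import copy

    # Hypothetical base config standing in for train_config['train_loader'].
    train_loader_config = {
        'name': 'text',
        'dataset': {'hf_name': 'dataset_a', 'split': 'train', 'shuffle': True},
    }

    # A schedule entry now carries only the dataset keys it overrides.
    datamix = {'duration': '2ep', 'dataset': {'hf_name': 'dataset_b'}}

    merged = copy.deepcopy(train_loader_config)
    merged['dataset'].update(datamix['dataset'])

    # Keys the datamix does not override are preserved.
    assert merged['dataset'] == {
        'hf_name': 'dataset_b',
        'split': 'train',
        'shuffle': True,
    }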
48 changes: 32 additions & 16 deletions tests/callbacks/test_curriculum_learning_callback.py
@@ -22,7 +22,7 @@
     [
         (None, '1ep'),
         ({
-            'dataset': 'some_dataset',
+            'hf_name': 'some_dataset',
         }, '1ep'),
         (None, '10tok'),
         (None, ''),
@@ -36,23 +36,29 @@ def test_curriculum_learning_callback_init(
 ):
     test_cfg = _get_test_cfg()
     test_cfg['train_loader'] = tiny_ft_dataloader_cfg
-    train_loader = test_cfg['train_loader'] if datamix is None else datamix
+    if datamix is None:
+        train_loader = test_cfg['train_loader']['dataset']
+    else:
+        train_loader = datamix
     kwargs = {
         'schedule': [{
             'duration': duration,
-            'train_loader': train_loader,
+            'dataset': train_loader,
         }, {
             'duration': '2ep',
-            'train_loader': {},
+            'dataset': {},
         }],
     }
+
+    kwargs['duration'] = kwargs['schedule'].pop(0)['duration']
+
     if duration == '':
         del kwargs['schedule'][0]['duration']
     if datamix is not None and len(datamix) == 0:
-        del kwargs['schedule'][0]['train_loader']
+        del kwargs['schedule'][0]['dataset']
 
     context = nullcontext()
-    if datamix is not None or duration == '':
+    if (datamix is not None and len(datamix) == 0) or duration == '':
         context = pytest.raises(ValueError)
     with context:
         callback = build_callback(
@@ -85,13 +91,15 @@ def test_curriculum_learning_callback_before_load(
     kwargs = {
         'schedule': [{
             'duration': duration,
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }, {
             'duration': '2ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }],
     }
 
+    kwargs['duration'] = kwargs['schedule'].pop(0)['duration']
+
     callback = build_callback(
         'curriculum_learning',
         kwargs=kwargs,
@@ -123,13 +131,15 @@ def test_curriculum_learning_callback_after_load(build_tiny_mpt: Callable,):
     kwargs = {
         'schedule': [{
             'duration': '1ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }, {
             'duration': '2ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }],
     }
 
+    kwargs['duration'] = kwargs['schedule'].pop(0)['duration']
+
     callback = build_callback(
         'curriculum_learning',
         kwargs=kwargs,
@@ -168,13 +178,15 @@ def test_curriculum_learning_callback_iteration(
     kwargs = {
         'schedule': [{
             'duration': '1ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }, {
             'duration': '2ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }],
     }
 
+    kwargs['duration'] = kwargs['schedule'].pop(0)['duration']
+
     callback = build_callback(
         'curriculum_learning',
         kwargs=kwargs,
@@ -208,13 +220,15 @@ def test_curriculum_learning_callback_state_dict(build_tiny_mpt: Callable,):
     kwargs = {
         'schedule': [{
             'duration': '1ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }, {
             'duration': '2ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }],
     }
 
+    kwargs['duration'] = kwargs['schedule'].pop(0)['duration']
+
     callback = build_callback(
         'curriculum_learning',
         kwargs=kwargs,
@@ -249,13 +263,15 @@ def test_curriculum_learning_callback_load_state_dict(
     kwargs = {
         'schedule': [{
             'duration': '1ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
        }, {
             'duration': '2ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }],
     }
 
+    kwargs['duration'] = kwargs['schedule'].pop(0)['duration']
+
     callback = build_callback(
         'curriculum_learning',
         kwargs=kwargs,
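The test updates above all follow the same two-step migration: each schedule entry is rewritten to hold only a dataset, and the first entry's duration is hoisted into the new top-level kwarg with pop(0). The popped entry's dataset is simply discarded, since the callback now derives the first datamix from train_config. A minimal illustration of the pattern; the dataset names are hypothetical:

    kwargs = {
        'schedule': [
            {'duration': '1ep', 'dataset': {'hf_name': 'dataset_a'}},
            {'duration': '2ep', 'dataset': {'hf_name': 'dataset_b'}},
        ],
    }

    # Hoist the first datamix's duration; its dataset is dropped because the
    # callback reconstructs the first datamix from train_config itself.
    kwargs['duration'] = kwargs['schedule'].pop(0)['duration']

    assert kwargs == {
        'duration': '1ep',
        'schedule': [{'duration': '2ep', 'dataset': {'hf_name': 'dataset_b'}}],
    }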