diff --git a/llmfoundry/callbacks/curriculum_learning_callback.py b/llmfoundry/callbacks/curriculum_learning_callback.py index 449ab338bc..70e996e494 100644 --- a/llmfoundry/callbacks/curriculum_learning_callback.py +++ b/llmfoundry/callbacks/curriculum_learning_callback.py @@ -9,7 +9,8 @@ import copy import logging -from typing import Any +import warnings +from typing import Any, Optional, Union from composer import DataSpec from composer.core import State, Time, TimeUnit, ensure_time @@ -23,6 +24,7 @@ BaseContextualError, TrainDataLoaderLocation, ) +from llmfoundry.utils.warnings import VersionedDeprecationWarning log = logging.getLogger(__name__) @@ -32,19 +34,21 @@ class CurriculumLearning(CallbackWithConfig): """Starts an epoch with a different dataset when resuming from a checkpoint. + Example duration: + <number>tok Example schedule: [ { 'duration': <number>tok, - 'train_loader': <dataloader parameters>, # matches top level train_loader + 'dataset': <dataset parameters>, }, { 'duration': <number>tok, - 'train_loader': <dataloader parameters>, + 'dataset': <dataset parameters>, }, { 'duration': <number>tok, - 'train_loader': <dataloader parameters>, + 'dataset': <dataset parameters>, ], ] @@ -53,48 +57,59 @@ class CurriculumLearning(CallbackWithConfig): being used. Note that this is the full train config and must contain the 'train_loader', 'device_train_batch_size', and 'tokenizer' keys. + duration (Union[Time, str, int], optional): The duration of the first datamix + (which corresponds to the train_loader). Defaults to None. schedule (list[dict[str, Any]]): The list of datamixes to use and their durations. Duration units must match max_duration and be in terms of a TimeUnit that is supported by Iteration. The duration values must be positive. There must be at least one datamix in the schedule. The - first datamix in the schedule must match the train_loader in the - train_config. On resumption, previously trained on datamixes and - durations cannot be changed. The duration of the current datamix - must be greater than the saved timestamp. The dataset must be a - StreamingDataset. 
+ first datamix during training is not included in the schedule. On + resumption, previously trained on datamixes and durations cannot be + changed. The duration of the current datamix must be greater than + the saved timestamp. The dataset must be a StreamingDataset. """ def __init__( self, train_config: dict[str, Any], schedule: list[dict[str, Any]], + duration: Optional[Union[Time, str, int]] = None, ): + if duration is None: + warnings.warn( + VersionedDeprecationWarning( + 'Specifying the full schedule in the CurriculumLearning ' + + 'callback is deprecated. Please specify the duration of ' + + 'the first datamix separately and change the schedule ' + + 'use datasets instead of dataloaders.', + remove_version='0.15.0', + ), + ) + # Ensure all duration units are in epochs or tokens and values are positive self._schedule = schedule if len(self._schedule) == 0: raise ValueError('The schedule must have at least one datamix.') - for index, datamix in enumerate(self._schedule): + if duration is not None: + first_datamix = { + 'duration': duration, + 'dataset': train_config['train_loader']['dataset'], + } + self._schedule.insert(0, first_datamix) + for datamix in self._schedule: self._validate_datamix(datamix) - if ( - index == 0 and - train_config['train_loader'] != datamix['train_loader'] - ): - raise ValueError(( - 'The first datamix in the schedule must match the ' - 'train_loader in the train_config.' 
- )) - self._schedule_index = 0 - self.device_train_batch_size = train_config['device_train_batch_size'] - self.tokenizer = None + self._train_loader_config: dict[str, Any] = train_config['train_loader'] + self._device_train_batch_size = train_config['device_train_batch_size'] + self._tokenizer = None def init(self, state: State, logger: Logger): del logger # unused if not hasattr(state.model, 'tokenizer'): raise ValueError('state.model must have a tokenizer attribute.') - self.tokenizer = state.model.tokenizer + self._tokenizer = state.model.tokenizer def before_load(self, state: State, logger: Logger): del logger # unused @@ -151,8 +166,13 @@ def iteration_start(self, state: State, logger: Logger): # which is stale clean_stale_shared_memory() datamix = copy.deepcopy(self._schedule[self._schedule_index]) + train_loader_config = copy.deepcopy(self._train_loader_config) + if 'dataset' in datamix: + train_loader_config['dataset'].update(datamix['dataset']) + else: + train_loader_config = datamix['train_loader'] data_spec = self._build_train_loader( - train_loader_config=datamix['train_loader'], + train_loader_config=train_loader_config, logger=logger, ) state.set_dataloader( @@ -211,18 +231,20 @@ def _build_train_loader( train_loader_config: dict[str, Any], logger: Logger, ) -> DataSpec: + del logger # unused + from llmfoundry.data.dataloader import build_dataloader # Copied from scripts/train/train.py log.info( f'Building train loader in CurriculumLearning callback for dataset {self._schedule_index}', ) - assert self.tokenizer is not None + assert self._tokenizer is not None try: return build_dataloader( train_loader_config, - self.tokenizer, - self.device_train_batch_size, + self._tokenizer, + self._device_train_batch_size, ) except BaseContextualError as e: e.location = TrainDataLoaderLocation @@ -260,5 +282,5 @@ def _validate_datamix(self, datamix: dict[str, Any]): 'Schedules can only be defined in terms of epochs or tokens.', ) - if 'train_loader' not in 
datamix: - raise ValueError('Each datamix must have a train_loader.') + if 'train_loader' not in datamix and 'dataset' not in datamix: + raise ValueError('Each datamix must have a dataset.') diff --git a/tests/callbacks/test_curriculum_learning_callback.py b/tests/callbacks/test_curriculum_learning_callback.py index 075698a4c0..0e6a6c1efe 100644 --- a/tests/callbacks/test_curriculum_learning_callback.py +++ b/tests/callbacks/test_curriculum_learning_callback.py @@ -22,7 +22,7 @@ [ (None, '1ep'), ({ - 'dataset': 'some_dataset', + 'hf_name': 'some_dataset', }, '1ep'), (None, '10tok'), (None, ''), @@ -36,23 +36,29 @@ def test_curriculum_learning_callback_init( ): test_cfg = _get_test_cfg() test_cfg['train_loader'] = tiny_ft_dataloader_cfg - train_loader = test_cfg['train_loader'] if datamix is None else datamix + if datamix is None: + train_loader = test_cfg['train_loader']['dataset'] + else: + train_loader = datamix kwargs = { 'schedule': [{ 'duration': duration, - 'train_loader': train_loader, + 'dataset': train_loader, }, { 'duration': '2ep', - 'train_loader': {}, + 'dataset': {}, }], } + + kwargs['duration'] = kwargs['schedule'].pop(0)['duration'] + if duration == '': del kwargs['schedule'][0]['duration'] if datamix is not None and len(datamix) == 0: - del kwargs['schedule'][0]['train_loader'] + del kwargs['schedule'][0]['dataset'] context = nullcontext() - if datamix is not None or duration == '': + if (datamix is not None and len(datamix) == 0) or duration == '': context = pytest.raises(ValueError) with context: callback = build_callback( @@ -85,13 +91,15 @@ def test_curriculum_learning_callback_before_load( kwargs = { 'schedule': [{ 'duration': duration, - 'train_loader': test_cfg['train_loader'], + 'dataset': test_cfg['train_loader']['dataset'], }, { 'duration': '2ep', - 'train_loader': test_cfg['train_loader'], + 'dataset': test_cfg['train_loader']['dataset'], }], } + kwargs['duration'] = kwargs['schedule'].pop(0)['duration'] + callback = build_callback( 
'curriculum_learning', kwargs=kwargs, @@ -123,13 +131,15 @@ def test_curriculum_learning_callback_after_load(build_tiny_mpt: Callable,): kwargs = { 'schedule': [{ 'duration': '1ep', - 'train_loader': test_cfg['train_loader'], + 'dataset': test_cfg['train_loader']['dataset'], }, { 'duration': '2ep', - 'train_loader': test_cfg['train_loader'], + 'dataset': test_cfg['train_loader']['dataset'], }], } + kwargs['duration'] = kwargs['schedule'].pop(0)['duration'] + callback = build_callback( 'curriculum_learning', kwargs=kwargs, @@ -168,13 +178,15 @@ def test_curriculum_learning_callback_iteration( kwargs = { 'schedule': [{ 'duration': '1ep', - 'train_loader': test_cfg['train_loader'], + 'dataset': test_cfg['train_loader']['dataset'], }, { 'duration': '2ep', - 'train_loader': test_cfg['train_loader'], + 'dataset': test_cfg['train_loader']['dataset'], }], } + kwargs['duration'] = kwargs['schedule'].pop(0)['duration'] + callback = build_callback( 'curriculum_learning', kwargs=kwargs, @@ -208,13 +220,15 @@ def test_curriculum_learning_callback_state_dict(build_tiny_mpt: Callable,): kwargs = { 'schedule': [{ 'duration': '1ep', - 'train_loader': test_cfg['train_loader'], + 'dataset': test_cfg['train_loader']['dataset'], }, { 'duration': '2ep', - 'train_loader': test_cfg['train_loader'], + 'dataset': test_cfg['train_loader']['dataset'], }], } + kwargs['duration'] = kwargs['schedule'].pop(0)['duration'] + callback = build_callback( 'curriculum_learning', kwargs=kwargs, @@ -249,13 +263,15 @@ def test_curriculum_learning_callback_load_state_dict( kwargs = { 'schedule': [{ 'duration': '1ep', - 'train_loader': test_cfg['train_loader'], + 'dataset': test_cfg['train_loader']['dataset'], }, { 'duration': '2ep', - 'train_loader': test_cfg['train_loader'], + 'dataset': test_cfg['train_loader']['dataset'], }], } + kwargs['duration'] = kwargs['schedule'].pop(0)['duration'] + callback = build_callback( 'curriculum_learning', kwargs=kwargs,