diff --git a/llmfoundry/callbacks/curriculum_learning_callback.py b/llmfoundry/callbacks/curriculum_learning_callback.py
index 449ab338bc..d6123addb3 100644
--- a/llmfoundry/callbacks/curriculum_learning_callback.py
+++ b/llmfoundry/callbacks/curriculum_learning_callback.py
@@ -9,7 +9,7 @@
 import copy
 import logging
-from typing import Any
+from typing import Any, Union
 
 from composer import DataSpec
 from composer.core import State, Time, TimeUnit, ensure_time
@@ -67,34 +67,32 @@ class CurriculumLearning(CallbackWithConfig):
     def __init__(
         self,
         train_config: dict[str, Any],
+        duration: Union[Time, str, int],
         schedule: list[dict[str, Any]],
     ):
         # Ensure all duration units are in epochs or tokens and values are positive
         self._schedule = schedule
         if len(self._schedule) == 0:
             raise ValueError('The schedule must have at least one datamix.')
-        for index, datamix in enumerate(self._schedule):
+        first_datamix = {
+            'duration': duration,
+            'dataset': train_config['train_loader']['dataset'],
+        }
+        self._schedule.insert(0, first_datamix)
+        for datamix in self._schedule:
             self._validate_datamix(datamix)
-            if (
-                index == 0 and
-                train_config['train_loader'] != datamix['train_loader']
-            ):
-                raise ValueError((
-                    'The first datamix in the schedule must match the '
-                    'train_loader in the train_config.'
-                ))
-
         self._schedule_index = 0
-        self.device_train_batch_size = train_config['device_train_batch_size']
-        self.tokenizer = None
+        self._train_loader_config: dict[str, Any] = train_config['train_loader']
+        self._device_train_batch_size = train_config['device_train_batch_size']
+        self._tokenizer = None
 
     def init(self, state: State, logger: Logger):
         del logger  # unused
         if not hasattr(state.model, 'tokenizer'):
             raise ValueError('state.model must have a tokenizer attribute.')
-        self.tokenizer = state.model.tokenizer
+        self._tokenizer = state.model.tokenizer
 
     def before_load(self, state: State, logger: Logger):
         del logger  # unused
@@ -151,8 +149,9 @@ def iteration_start(self, state: State, logger: Logger):
         # which is stale
         clean_stale_shared_memory()
         datamix = copy.deepcopy(self._schedule[self._schedule_index])
+        self._train_loader_config['dataset'].update(datamix['dataset'])
         data_spec = self._build_train_loader(
-            train_loader_config=datamix['train_loader'],
+            train_loader_config=self._train_loader_config,
             logger=logger,
         )
         state.set_dataloader(
@@ -211,18 +210,20 @@ def _build_train_loader(
         train_loader_config: dict[str, Any],
         logger: Logger,
     ) -> DataSpec:
+        del logger  # unused
+
         from llmfoundry.data.dataloader import build_dataloader
 
         # Copied from scripts/train/train.py
         log.info(
            f'Building train loader in CurriculumLearning callback for dataset {self._schedule_index}',
         )
-        assert self.tokenizer is not None
+        assert self._tokenizer is not None
         try:
             return build_dataloader(
                 train_loader_config,
-                self.tokenizer,
-                self.device_train_batch_size,
+                self._tokenizer,
+                self._device_train_batch_size,
             )
         except BaseContextualError as e:
             e.location = TrainDataLoaderLocation
@@ -260,5 +261,5 @@ def _validate_datamix(self, datamix: dict[str, Any]):
                 'Schedules can only be defined in terms of epochs or tokens.',
             )
 
-        if 'train_loader' not in datamix:
-            raise ValueError('Each datamix must have a train_loader.')
+        if 'dataset' not in datamix:
+            raise ValueError('Each datamix must have a dataset.')
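Taken together, the callback changes above replace per-datamix `train_loader` configs with a single base loader plus per-datamix `dataset` overrides: the first datamix's duration moves to a top-level `duration` argument and its dataset is taken from `train_config['train_loader']` directly. A minimal sketch of the new constructor contract, assuming a `train_config` shaped like the test fixtures (the `name`, `hf_name`, and batch-size values are illustrative placeholders, not from this diff):

```python
# Minimal sketch of the new CurriculumLearning contract; config values are
# illustrative placeholders, not a real training setup.
from llmfoundry.callbacks.curriculum_learning_callback import CurriculumLearning

train_config = {
    'device_train_batch_size': 8,
    'train_loader': {
        'name': 'finetuning',  # assumed loader name for illustration
        'dataset': {'hf_name': 'base_dataset'},
    },
}

callback = CurriculumLearning(
    train_config=train_config,
    # Duration of the first datamix; the callback now derives that datamix's
    # dataset from train_config['train_loader'] instead of requiring the
    # schedule to duplicate the whole train_loader config.
    duration='1ep',
    # Later datamixes carry only a 'dataset' override, merged into the base
    # loader config at each iteration boundary.
    schedule=[
        {'duration': '2ep', 'dataset': {'hf_name': 'second_dataset'}},
    ],
)
```

The test changes below adapt the existing fixtures to this contract, popping the first schedule entry's duration into the new `duration` kwarg.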
diff --git a/tests/callbacks/test_curriculum_learning_callback.py b/tests/callbacks/test_curriculum_learning_callback.py
index 075698a4c0..0e6a6c1efe 100644
--- a/tests/callbacks/test_curriculum_learning_callback.py
+++ b/tests/callbacks/test_curriculum_learning_callback.py
@@ -22,7 +22,7 @@
     [
         (None, '1ep'),
         ({
-            'dataset': 'some_dataset',
+            'hf_name': 'some_dataset',
         }, '1ep'),
         (None, '10tok'),
         (None, ''),
@@ -36,23 +36,29 @@ def test_curriculum_learning_callback_init(
 ):
     test_cfg = _get_test_cfg()
     test_cfg['train_loader'] = tiny_ft_dataloader_cfg
-    train_loader = test_cfg['train_loader'] if datamix is None else datamix
+    if datamix is None:
+        train_loader = test_cfg['train_loader']['dataset']
+    else:
+        train_loader = datamix
     kwargs = {
         'schedule': [{
             'duration': duration,
-            'train_loader': train_loader,
+            'dataset': train_loader,
         }, {
             'duration': '2ep',
-            'train_loader': {},
+            'dataset': {},
         }],
     }
+
+    kwargs['duration'] = kwargs['schedule'].pop(0)['duration']
+
     if duration == '':
         del kwargs['schedule'][0]['duration']
     if datamix is not None and len(datamix) == 0:
-        del kwargs['schedule'][0]['train_loader']
+        del kwargs['schedule'][0]['dataset']
 
     context = nullcontext()
-    if datamix is not None or duration == '':
+    if (datamix is not None and len(datamix) == 0) or duration == '':
         context = pytest.raises(ValueError)
     with context:
         callback = build_callback(
@@ -85,13 +91,15 @@ def test_curriculum_learning_callback_before_load(
     kwargs = {
         'schedule': [{
             'duration': duration,
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }, {
             'duration': '2ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }],
     }
 
+    kwargs['duration'] = kwargs['schedule'].pop(0)['duration']
+
     callback = build_callback(
         'curriculum_learning',
         kwargs=kwargs,
@@ -123,13 +131,15 @@ def test_curriculum_learning_callback_after_load(build_tiny_mpt: Callable,):
     kwargs = {
         'schedule': [{
             'duration': '1ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }, {
             'duration': '2ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }],
     }
 
+    kwargs['duration'] = kwargs['schedule'].pop(0)['duration']
+
     callback = build_callback(
         'curriculum_learning',
         kwargs=kwargs,
@@ -168,13 +178,15 @@ def test_curriculum_learning_callback_iteration(
     kwargs = {
         'schedule': [{
             'duration': '1ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }, {
             'duration': '2ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }],
     }
 
+    kwargs['duration'] = kwargs['schedule'].pop(0)['duration']
+
     callback = build_callback(
         'curriculum_learning',
         kwargs=kwargs,
@@ -208,13 +220,15 @@ def test_curriculum_learning_callback_state_dict(build_tiny_mpt: Callable,):
     kwargs = {
         'schedule': [{
             'duration': '1ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }, {
             'duration': '2ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
        }],
     }
 
+    kwargs['duration'] = kwargs['schedule'].pop(0)['duration']
+
     callback = build_callback(
         'curriculum_learning',
         kwargs=kwargs,
@@ -249,13 +263,15 @@ def test_curriculum_learning_callback_load_state_dict(
     kwargs = {
         'schedule': [{
             'duration': '1ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }, {
             'duration': '2ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }],
     }
 
+    kwargs['duration'] = kwargs['schedule'].pop(0)['duration']
+
     callback = build_callback(
         'curriculum_learning',
         kwargs=kwargs,
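One behavioral consequence worth noting: `iteration_start` applies each datamix with `dict.update` on the shared base dataset config (`self._train_loader_config['dataset'].update(datamix['dataset'])`), so overrides accumulate across iterations; keys a later datamix omits keep the value set by an earlier one. A standalone sketch of that merge behavior (the `max_seq_len` key is an assumed example, not from the diff):

```python
# Sketch of the cumulative merge performed by
# self._train_loader_config['dataset'].update(datamix['dataset']).
# 'max_seq_len' is an assumed illustrative key.
base_dataset = {'hf_name': 'base_dataset', 'max_seq_len': 2048}

base_dataset.update({'hf_name': 'second_dataset'})  # datamix at iteration 1
assert base_dataset == {'hf_name': 'second_dataset', 'max_seq_len': 2048}

base_dataset.update({'hf_name': 'third_dataset'})   # datamix at iteration 2
# Keys the later datamix omits persist from earlier datamixes.
assert base_dataset == {'hf_name': 'third_dataset', 'max_seq_len': 2048}
```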