From 4ee56bd600ae09319a4738845b525df89005f4e7 Mon Sep 17 00:00:00 2001 From: root <23239305+b-chu@users.noreply.github.com> Date: Wed, 5 Jun 2024 21:28:42 +0000 Subject: [PATCH] Add curriculum learning callback --- .../callbacks/curriculum_learning_callback.py | 248 +++++++++++++----- llmfoundry/models/mpt/modeling_mpt.py | 7 +- llmfoundry/utils/__init__.py | 4 + .../test_curriculum_learning_callback.py | 2 +- 4 files changed, 191 insertions(+), 70 deletions(-) diff --git a/llmfoundry/callbacks/curriculum_learning_callback.py b/llmfoundry/callbacks/curriculum_learning_callback.py index 961bf1cae1..24f754d974 100644 --- a/llmfoundry/callbacks/curriculum_learning_callback.py +++ b/llmfoundry/callbacks/curriculum_learning_callback.py @@ -8,22 +8,26 @@ """ import logging -from typing import Any, Dict +import copy +from typing import Any -from composer.core import State -from composer.loggers import Logger +from composer import DataSpec +from composer.core import State, Time, TimeUnit, ensure_time +from composer.loggers import Logger, MosaicMLLogger +from composer.trainer.trainer import _get_initial_device_train_microbatch_size from streaming import StreamingDataset +from streaming.base.util import clean_stale_shared_memory from torch.utils.data import DataLoader from llmfoundry.interfaces import CallbackWithConfig -from llmfoundry.utils.warnings import experimental_class +from llmfoundry.utils.config_utils import calculate_batch_size_info +from llmfoundry.utils.exceptions import BaseContextualError, TrainDataLoaderLocation log = logging.getLogger(__name__) __all__ = ['CurriculumLearning'] -@experimental_class('CurriculumLearning callback') class CurriculumLearning(CallbackWithConfig): """Starts an epoch with a different dataset when resuming from a checkpoint. @@ -34,20 +38,183 @@ class CurriculumLearning(CallbackWithConfig): dataset_index (int): The index of the dataset currently being used. """ - def __init__(self, train_config: Dict, dataset_index: int): - self.dataset_index = dataset_index - self.saved_dataset_index = 0 - self.all_dataset_configs = [] - self.current_dataset_state = {} - # The current dataset config is resolved and passed in train.py - self.current_dataset_config = train_config['train_loader'] + def __init__( + self, train_config: dict[str, Any], duration: str | int | Time, + schedule: list[dict[str, Any]] + ): + from llmfoundry.utils.builders import build_tokenizer + non_positive_error = ValueError('The duration must be positive.') + unit_error = ValueError( + 'Schedules can only be defined in terms of epochs or tokens.' + ) + + # Ensure all duration values are positive + # Ensure all duration units are in epochs or tokens + self._duration = ensure_time(duration, TimeUnit.EPOCH) + if self._duration.value <= 0: + raise non_positive_error + if self._duration.unit != TimeUnit.EPOCH and self._duration.unit != TimeUnit.TOKEN: + raise unit_error + + self._schedule = schedule + for datamix in self._schedule: + assert 'duration' in datamix, 'Each datamix must have a duration.' + datamix['duration'] = ensure_time( + datamix['duration'], TimeUnit.EPOCH + ) + if datamix['duration'].value <= 0: + raise non_positive_error + if datamix['duration'].unit != TimeUnit.EPOCH and datamix[ + 'duration'].unit != TimeUnit.TOKEN: + raise unit_error + assert 'train_loader' in datamix, 'Each datamix must have a train_loader.' + + self._schedule_index = -1 + + # Copied from llmfoundry/utils/config_utils.py + self.device_train_batch_size, _, _ = calculate_batch_size_info( + train_config['global_train_batch_size'], + train_config['device_train_microbatch_size'], + data_replication_degree=1, + ) + + # Copied from scripts/train/train.py + tokenizer_name = train_config['tokenizer']['name'] + tokenizer_kwargs = train_config['tokenizer'].get('kwargs', {}) + self.tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) def before_load(self, state: State, logger: Logger): del logger - # Save the current dataset state so we can restore it correctly - # if we are resuming with a new dataset. - train_loader = state.train_dataloader + # Ensure all duration units are the same as max_duration + units_match = True + assert state.max_duration is not None, 'max_duration should have beeen set.' + if self._duration.unit != state.max_duration.unit: + units_match = False + for datamix in self._schedule: + if datamix['duration'].unit != state.max_duration.unit: + units_match = False + if not units_match: + raise ValueError(( + 'All durations in the schedule must have the same units as ' + 'the max_duration.' + )) + + # Ensure schedule duration is greater than max_duration + schedule_duration = self._duration + for datamix in self._schedule: + assert isinstance(datamix['duration'], Time) + schedule_duration += datamix['duration'] + if schedule_duration < state.max_duration: + raise ValueError(( + 'The sum of all durations in the schedule must be greater than ' + 'or equal to the max_duration.' + )) + + self._validate_dataloader(state.train_dataloader) + + def after_load(self, state: State, logger: Logger): + del logger + + self._validate_dataloader(state.train_dataloader) + + # Check if adding a new datamix to a run that didn't use this callback + if self._schedule_index == -1 and state.timestamp >= self._duration: + self._schedule_index = 0 + state.timestamp = state.timestamp.to_next_iteration() + # If checkpoint was saved before iteration was incremented, we need to increment it now + elif (( + self._schedule[self._schedule_index]['duration'].unit + == TimeUnit.TOKEN and state.timestamp.token_in_iteration + >= self._schedule[self._schedule_index]['duration'].value + ) or ( + self._schedule[self._schedule_index]['duration'].unit + == TimeUnit.EPOCH and state.timestamp.epoch_in_iteration + >= self._schedule[self._schedule_index]['duration'].value + )): + log.warning(( + 'The CurriculumLearning callback has detected that the previous run did not correctly ' + 'increment the iteration.' + )) + self._schedule_index += 1 + state.timestamp = state.timestamp.to_next_iteration() + + def iteration_start(self, state: State, logger: Logger): + # Reset and initialize state train dataloader + log.warning( + 'trainer._train_data_spec should be updated whenever the dataloader is updated' + ) + + # Swap the dataset if starting a new iteration that's not the original datamix + if self._schedule_index >= 0: + clean_stale_shared_memory() + data_spec = self._build_train_loader( + train_loader_config=copy.deepcopy( + self._schedule[self._schedule_index] + )['train_loader'], + logger=logger, + ) + state.set_dataloader( + dataloader=data_spec.dataloader, + dataloader_label='train', + ) + # state.train_dataloader = state.dataloader + state.device_train_microbatch_size = _get_initial_device_train_microbatch_size( + state.device_train_microbatch_size, + state.auto_microbatching, + state.train_dataloader, + ) + self._validate_dataloader(state.train_dataloader) + + # Set the length of the new iteration + if self._schedule_index == -1: + state._iteration_length = self._duration + else: + state._iteration_length = self._schedule[self._schedule_index + ]['duration'] + + def iteration_end(self, state: State, logger: Logger): + del state, logger # unused + + self._schedule_index += 1 + + def state_dict(self): + return { + 'duration': self._duration, + 'schedule': self._schedule, + 'schedule_index': self._schedule_index, + } + + def load_state_dict(self, state: dict[str, Any]): + # Ensure that the schedule has not changed on already trained datamixes + assert self._duration == state['duration'] + for idx in range(state['schedule_index'] + 1): + assert self._schedule[idx] == state['schedule'][idx] + + self._schedule_index = state['schedule_index'] + + def _build_train_loader( + self, train_loader_config: dict[str, Any], logger: Logger + ) -> DataSpec: + from llmfoundry.data.dataloader import build_dataloader + # Copied from scripts/train/train.py + log.info( + f'Building train loader in CurriculumLearning callback for dataset {self._schedule_index}' + ) + try: + return build_dataloader( + train_loader_config, + self.tokenizer, + self.device_train_batch_size, + ) + except BaseContextualError as e: + for destination in logger.destinations: + if isinstance(destination, MosaicMLLogger): + e.location = TrainDataLoaderLocation + destination.log_exception(e) + raise e + + def _validate_dataloader(self, train_loader: Any): # Check if we are using a DataLoader and StreamingDataset if not isinstance(train_loader, DataLoader): raise ValueError( @@ -61,54 +228,3 @@ def before_load(self, state: State, logger: Logger): f'because it requires loading and saving dataset state. ', f'Instead, got a dataset of type {type(dataset)}', ) - assert isinstance(dataset, StreamingDataset) - # Save the current dataset state so we can restore it if needed. - self.current_dataset_state = dataset.state_dict( # type: ignore - num_samples=0, from_beginning=False) - - def after_load(self, state: State, logger: Logger): - del logger - - # As saved_dataset_index is loaded from state_dict, this only runs when - # a user explicitly increments the dataset_index and not on any other - # resumption, including autoresume. - train_loader = state._train_dataloader - assert isinstance( - train_loader, - DataLoader, - ), 'CurriculumLearning callback requires a DataLoader.' - dataset = train_loader.dataset - assert isinstance( - dataset, - StreamingDataset, - ), 'CurriculumLearning callback requires a StreamingDataset.' - if self.saved_dataset_index < self.dataset_index: - # Ignore the dataset state that was read in from the checkpoint, and - # replace with the new dataset state. This preserves resumption info. - if self.current_dataset_state['epoch'] < 0: - # Make sure the epoch in the loaded state dict is not negative. - # Since `__iter__` has not yet been called on the dataset, the - # epoch index in the dataset will still be -1. We need to ensure - # that we set the epoch correctly to 0 in this case. - self.current_dataset_state['epoch'] = 0 - dataset.load_state_dict( # type: ignore - self.current_dataset_state) - # Start a new epoch since we are using a new dataset. - # This will also reset the sample_in_epoch written to checkpoint, - # making sure that subsequent resumptions proceed correctly. - state.timestamp = state.timestamp.to_next_epoch() - # Append the new dataset config to the list of all dataset configs. - self.all_dataset_configs.append(self.current_dataset_config) - elif self.dataset_index == 0 and len(self.all_dataset_configs) == 0: - # Make sure to track our current dataset config if we are just starting training. - self.all_dataset_configs.append(self.current_dataset_config) - - def state_dict(self): - return { - 'dataset_index': self.dataset_index, - 'all_dataset_configs': self.all_dataset_configs, - } - - def load_state_dict(self, state: Dict[str, Any]): - self.saved_dataset_index = state.get('dataset_index', 0) - self.all_dataset_configs = state.get('all_dataset_configs', []) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 9d18799e93..195c617972 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -580,8 +580,9 @@ def forward( 'sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.', ) - elif (self.attn_uses_sequence_id is - False) and (sequence_id is not None): + elif ( + self.attn_uses_sequence_id is False and sequence_id is not None + ): warnings.warn( 'MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + @@ -1092,7 +1093,7 @@ def __init__( additional_train_metrics = additional_train_metrics or [] - model = self.model_class(self.config_class(**kwargs),) + model = self.model_class(self.config_class(**kwargs)) use_train_metrics = use_train_metrics train_metric_names = DEFAULT_CAUSAL_LM_TRAIN_METRICS + additional_train_metrics diff --git a/llmfoundry/utils/__init__.py b/llmfoundry/utils/__init__.py index dd43efcdd7..f5d94bd7c2 100644 --- a/llmfoundry/utils/__init__.py +++ b/llmfoundry/utils/__init__.py @@ -2,9 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 from llmfoundry.utils.builders import ( + add_metrics_to_eval_loaders, build_algorithm, build_callback, build_composer_model, + build_eval_loaders, build_evaluators, build_icl_data_and_gauntlet, build_icl_evaluators, @@ -60,8 +62,10 @@ ) __all__ = [ + 'add_metrics_to_eval_loaders', 'build_algorithm', 'build_callback', + 'build_eval_loaders', 'build_evaluators', 'build_icl_data_and_gauntlet', 'build_icl_evaluators', diff --git a/tests/callbacks/test_curriculum_learning_callback.py b/tests/callbacks/test_curriculum_learning_callback.py index bbdbf3d691..807f640465 100644 --- a/tests/callbacks/test_curriculum_learning_callback.py +++ b/tests/callbacks/test_curriculum_learning_callback.py @@ -5,7 +5,7 @@ def test_curriculum_learning_callback_builds(): - kwargs = {'dataset_index': 0} + kwargs = {'duration': '1ep'} callback = build_callback( 'curriculum_learning', kwargs=kwargs,