diff --git a/llmfoundry/callbacks/curriculum_learning_callback.py b/llmfoundry/callbacks/curriculum_learning_callback.py index 98a672f8db..449ab338bc 100644 --- a/llmfoundry/callbacks/curriculum_learning_callback.py +++ b/llmfoundry/callbacks/curriculum_learning_callback.py @@ -128,18 +128,17 @@ def after_load(self, state: State, logger: Logger): self._validate_dataloader(state.train_dataloader) # If checkpoint was saved before iteration was incremented, we need to increment it now + duration = self._schedule[self._schedule_index]['duration'] if (( - self._schedule[self._schedule_index]['duration'].unit - == TimeUnit.TOKEN and state.timestamp.token_in_iteration >= - self._schedule[self._schedule_index]['duration'].value + duration.unit == TimeUnit.TOKEN and + state.timestamp.token_in_iteration >= duration.value ) or ( - self._schedule[self._schedule_index]['duration'].unit - == TimeUnit.EPOCH and state.timestamp.epoch_in_iteration >= - self._schedule[self._schedule_index]['duration'].value + duration.unit == TimeUnit.EPOCH and + state.timestamp.epoch_in_iteration >= duration.value )): log.warning(( - 'The CurriculumLearning callback has detected that the previous run did not correctly ' - 'increment the iteration.' + 'The CurriculumLearning callback has detected that the ' + 'previous run did not correctly increment the iteration.' )) self._schedule_index += 1 state.timestamp = state.timestamp.to_next_iteration() @@ -199,24 +198,13 @@ def load_state_dict(self, state: dict[str, Any]): f'Expected {saved_loader} but got {current_loader}', )) - # Ensure that the current datamix duration is greater than timestamp + # Ensure that the current datamix duration is in the correct units duration = self._schedule[self._schedule_index]['duration'] if duration.unit != TimeUnit.TOKEN and duration.unit != TimeUnit.EPOCH: raise ValueError(( f'Duration must be in terms of tokens or epochs, but got ', f'{duration.unit}.', )) - if (( - duration.unit == TimeUnit.TOKEN and - duration > state['timestamp'].token_in_iteration - ) or ( - duration.unit == TimeUnit.EPOCH and - duration > state['timestamp'].epoch_in_iteration - )): - raise ValueError(( - 'The duration of the current datamix must be less or equal to ' - 'than the saved timestamp.' - )) def _build_train_loader( self,