Skip to content

Commit

Permalink
Remove curriculum learning error when duration less than saved timestamp (#1406)
Browse files Browse the repository at this point in the history

Co-authored-by: Saaketh Narayan <[email protected]>
  • Loading branch information
b-chu and snarayan21 authored Jul 29, 2024
1 parent 5c7e99b commit 6d5d016
Showing 1 changed file with 8 additions and 20 deletions.
28 changes: 8 additions & 20 deletions llmfoundry/callbacks/curriculum_learning_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,18 +128,17 @@ def after_load(self, state: State, logger: Logger):
self._validate_dataloader(state.train_dataloader)

# If checkpoint was saved before iteration was incremented, we need to increment it now
duration = self._schedule[self._schedule_index]['duration']
if ((
self._schedule[self._schedule_index]['duration'].unit
== TimeUnit.TOKEN and state.timestamp.token_in_iteration >=
self._schedule[self._schedule_index]['duration'].value
duration.unit == TimeUnit.TOKEN and
state.timestamp.token_in_iteration >= duration.value
) or (
self._schedule[self._schedule_index]['duration'].unit
== TimeUnit.EPOCH and state.timestamp.epoch_in_iteration >=
self._schedule[self._schedule_index]['duration'].value
duration.unit == TimeUnit.EPOCH and
state.timestamp.epoch_in_iteration >= duration.value
)):
log.warning((
'The CurriculumLearning callback has detected that the previous run did not correctly '
'increment the iteration.'
'The CurriculumLearning callback has detected that the '
'previous run did not correctly increment the iteration.'
))
self._schedule_index += 1
state.timestamp = state.timestamp.to_next_iteration()
Expand Down Expand Up @@ -199,24 +198,13 @@ def load_state_dict(self, state: dict[str, Any]):
f'Expected {saved_loader} but got {current_loader}',
))

# Ensure that the current datamix duration is greater than timestamp
# Ensure that the current datamix duration is in the correct units
duration = self._schedule[self._schedule_index]['duration']
if duration.unit != TimeUnit.TOKEN and duration.unit != TimeUnit.EPOCH:
raise ValueError((
f'Duration must be in terms of tokens or epochs, but got ',
f'{duration.unit}.',
))
if ((
duration.unit == TimeUnit.TOKEN and
duration > state['timestamp'].token_in_iteration
) or (
duration.unit == TimeUnit.EPOCH and
duration > state['timestamp'].epoch_in_iteration
)):
raise ValueError((
'The duration of the current datamix must be less or equal to '
'than the saved timestamp.'
))

def _build_train_loader(
self,
Expand Down

0 comments on commit 6d5d016

Please sign in to comment.