From 5980a44271bb7bfb2633190a68a95ef361e08747 Mon Sep 17 00:00:00 2001 From: root <23239305+b-chu@users.noreply.github.com> Date: Wed, 14 Feb 2024 21:25:14 +0000 Subject: [PATCH] Add iteration to TimeUnit commit-id:e4729f79 --- composer/core/time.py | 131 +++++++++++++++++++++++++++++++++++++++++- tests/test_time.py | 39 ++++++++++++- 2 files changed, 166 insertions(+), 4 deletions(-) diff --git a/composer/core/time.py b/composer/core/time.py index 90b3bfdb97..f9baa042ad 100644 --- a/composer/core/time.py +++ b/composer/core/time.py @@ -31,12 +31,14 @@ class TimeUnit(StringEnum): """Enum class to represent units of time for the training process. Attributes: + ITERATION (str): Iterations. EPOCH (str): Epochs. BATCH (str): Batches (i.e. number of optimization steps) SAMPLE (str): Samples. TOKEN (str): Tokens. Applicable for natural language processing (NLP) models. DURATION (str): Fraction of the training process complete, on ``[0.0, 1.0)`` """ + ITERATION = 'iter' EPOCH = 'ep' BATCH = 'ba' SAMPLE = 'sp' @@ -122,6 +124,20 @@ def __init__( raise TypeError(f'value {value} is of type {type(value)}. Units {unit} require integer values.') self._value, self._unit = value, TimeUnit(unit) + @classmethod + def from_iteration(cls, iteration: int) -> Time: + """Create a :class:`Time` with units of :attr:`TimeUnit.ITERATION`. + + Equivalent to ``Time(iteration, TimeUnit.ITERATION)``. + + Args: + iteration (int): Number of iterations. + + Returns: + Time: :class:`Time` instance, in iterations. + """ + return cls(iteration, TimeUnit.ITERATION) + @classmethod def from_epoch(cls, epoch: int) -> Time: """Create a :class:`Time` with units of :attr:`TimeUnit.EPOCH`. @@ -391,37 +407,48 @@ def from_timestring(cls, timestring: str) -> Time: class Timestamp(Serializable): """Timestamp represents a snapshot of the current training progress. - The timestamp measures training progress in terms of epochs, batches, samples, tokens, and wall clock time. 
+ The timestamp measures training progress in terms of iterations, epochs, batches, samples, tokens, and wall clock time. Timestamps are not updated in-place. See the :doc:`Time Guide ` for more details on tracking time during training. Args: + iteration (int | Time[int], optional): The iteration. epoch (int | Time[int], optional): The epoch. batch (int | Time[int], optional): the batch. sample (int | Time[int], optional): The sample. token (int | Time[int], optional): The token. + epoch_in_iteration (int | Time[int], optional): The epoch in the iteration. batch_in_epoch (int | Time[int], optional): The batch in the epoch. sample_in_epoch (int | Time[int], optional): The sample in the epoch. token_in_epoch (int | Time[int], optional): The token in the epoch. total_wct (datetime.timedelta, optional): The total wall-clock duration. + iteration_wct (datetime.timedelta, optional): The wall-clock duration of the last iteration. epoch_wct (datetime.timedelta, optional): The wall-clock duration of the last epoch. batch_wct (datetime.timedelta, optional): The wall-clock duration of the last batch. 
""" def __init__( self, + iteration: Union[int, Time[int]] = 0, epoch: Union[int, Time[int]] = 0, batch: Union[int, Time[int]] = 0, sample: Union[int, Time[int]] = 0, token: Union[int, Time[int]] = 0, + epoch_in_iteration: Union[int, Time[int]] = 0, batch_in_epoch: Union[int, Time[int]] = 0, sample_in_epoch: Union[int, Time[int]] = 0, token_in_epoch: Union[int, Time[int]] = 0, total_wct: Optional[datetime.timedelta] = None, + iteration_wct: Optional[datetime.timedelta] = None, epoch_wct: Optional[datetime.timedelta] = None, batch_wct: Optional[datetime.timedelta] = None, ): + iteration = Time.from_input(iteration, TimeUnit.ITERATION) + if iteration.unit != TimeUnit.ITERATION: + raise ValueError(f'The `iteration` argument has units of {iteration.unit}; not {TimeUnit.ITERATION}.') + self._iteration = iteration + epoch = Time.from_input(epoch, TimeUnit.EPOCH) if epoch.unit != TimeUnit.EPOCH: raise ValueError(f'The `epoch` argument has units of {epoch.unit}; not {TimeUnit.EPOCH}.') @@ -442,6 +469,12 @@ def __init__( raise ValueError(f'The `token` argument has units of {token.unit}; not {TimeUnit.TOKEN}.') self._token = token + epoch_in_iteration = Time.from_input(epoch_in_iteration, TimeUnit.BATCH) + if epoch_in_iteration.unit != TimeUnit.BATCH: + raise ValueError((f'The `epoch_in_iteration` argument has units of {epoch_in_iteration.unit}; ' + f'not {TimeUnit.EPOCH}.')) + self._epoch_in_iteration = epoch_in_iteration + batch_in_epoch = Time.from_input(batch_in_epoch, TimeUnit.BATCH) if batch_in_epoch.unit != TimeUnit.BATCH: raise ValueError((f'The `batch_in_epoch` argument has units of {batch_in_epoch.unit}; ' @@ -464,6 +497,10 @@ def __init__( total_wct = datetime.timedelta(seconds=0) self._total_wct = total_wct + if iteration_wct is None: + iteration_wct = datetime.timedelta(seconds=0) + self._iteration_wct = iteration_wct + if epoch_wct is None: epoch_wct = datetime.timedelta(seconds=0) self._epoch_wct = epoch_wct @@ -474,14 +511,17 @@ def __init__( def 
state_dict(self) -> Dict[str, Any]: return { + 'iteration': self.iteration.value, 'epoch': self.epoch.value, 'batch': self.batch.value, 'sample': self.sample.value, 'token': self.token.value, + 'epoch_in_iteration': self.epoch_in_iteration.value, 'batch_in_epoch': self.batch_in_epoch.value, 'sample_in_epoch': self.sample_in_epoch.value, 'token_in_epoch': self.token_in_epoch.value, 'total_wct': self.total_wct, + 'iteration_wct': self.iteration_wct, 'epoch_wct': self.epoch_wct, 'batch_wct': self.batch_wct, } @@ -493,23 +533,28 @@ def get_state(self) -> Dict[str, Union[Time[int], datetime.timedelta]]: Dict[str, Union[Time[int], datetime.timedelta]]: All values of the timestamp object. """ return { + 'iteration': self.iteration.value, 'epoch': self.epoch, 'batch': self.batch, 'sample': self.sample, 'token': self.token, + 'epoch_in_iteration': self.epoch_in_iteration.value, 'batch_in_epoch': self.batch_in_epoch, 'sample_in_epoch': self.sample_in_epoch, 'token_in_epoch': self.token_in_epoch, 'total_wct': self.total_wct, + 'iteration_wct': self.iteration_wct, 'epoch_wct': self.epoch_wct, 'batch_wct': self.batch_wct, } def load_state_dict(self, state: Dict[str, Any]) -> None: + self._iteration = Time(state['iteration'], TimeUnit.ITERATION) self._epoch = Time(state['epoch'], TimeUnit.EPOCH) self._batch = Time(state['batch'], TimeUnit.BATCH) self._sample = Time(state['sample'], TimeUnit.SAMPLE) self._token = Time(state['token'], TimeUnit.TOKEN) + self._epoch_in_iteration = Time(state['epoch_in_iteration'], TimeUnit.EPOCH) self._batch_in_epoch = Time(state['batch_in_epoch'], TimeUnit.BATCH) self._sample_in_epoch = Time(state['sample_in_epoch'], TimeUnit.SAMPLE) self._token_in_epoch = Time(state['token_in_epoch'], TimeUnit.TOKEN) @@ -517,11 +562,18 @@ def load_state_dict(self, state: Dict[str, Any]) -> None: # Using conditional checks as not to break old checkpoints if 'total_wct' in state: self._total_wct = state['total_wct'] + if 'iteration_wct' in state: + 
self._iteration_wct = state['iteration_wct'] if 'epoch_wct' in state: self._epoch_wct = state['epoch_wct'] if 'batch_wct' in state: self._batch_wct = state['batch_wct'] + @property + def iteration(self) -> Time[int]: + """The total iteration count.""" + return self._iteration + @property def epoch(self) -> Time[int]: """The total epoch count.""" @@ -541,6 +593,11 @@ def sample(self) -> Time[int]: def token(self) -> Time[int]: """The total token count.""" return self._token + + @property + def epoch_in_iteration(self) -> Time[int]: + """The epoch count in the current iteration (resets at 0 at the beginning of every iteration).""" + return self._epoch_in_iteration @property def batch_in_epoch(self) -> Time[int]: @@ -562,6 +619,11 @@ def total_wct(self) -> datetime.timedelta: """The wall-clock duration (in seconds) from the beginning of training.""" return self._total_wct + @property + def iteration_wct(self) -> datetime.timedelta: + """The wall-clock duration (in seconds) for the current iteration.""" + return self._iteration_wct + @property def epoch_wct(self) -> datetime.timedelta: """The wall-clock duration (in seconds) for the current epoch.""" @@ -582,6 +644,8 @@ def get(self, unit: Union[str, TimeUnit]) -> Time[int]: Time: The current time, in the specified unit. """ unit = TimeUnit(unit) + if unit == TimeUnit.ITERATION: + return self.iteration if unit == TimeUnit.EPOCH: return self.epoch if unit == TimeUnit.BATCH: @@ -678,6 +742,7 @@ def to_next_batch( ... token = timestamp.token + tokens, ... token_in_epoch=timestamp.token_in_epoch + tokens, ... total_wct=timestamp.total_wct + duration, + ... iteration_wct=timestamp.iteration_wct + duration, ... epoch_wct=timestamp.epoch_wct + duration, ... batch_wct=duration, ... 
) @@ -705,6 +770,7 @@ def to_next_batch( token=self.token + tokens, token_in_epoch=self.token_in_epoch + tokens, total_wct=self.total_wct + duration, + iteration_wct=self.iteration_wct + duration, epoch_wct=self.epoch_wct + duration, batch_wct=duration, ) @@ -729,10 +795,12 @@ def to_next_epoch( >>> timestamp.copy( ... epoch=timestamp.epoch + 1, + ... epoch_in_iteration=timestamp.epoch_in_iteration + 1, ... batch_in_epoch=0, ... sample_in_epoch=0, ... token_in_epoch=0, ... total_wct=timestamp.total_wct + duration, + ... iteration_wct=timestamp.iteration_wct + duration, ... epoch_wct=datetime.timedelta(seconds=0), ... batch_wct=datetime.timedelta(seconds=0), ... ) @@ -743,24 +811,74 @@ def to_next_epoch( duration = datetime.timedelta(seconds=0) return self.copy( epoch=self.epoch + 1, + epoch_in_iteration=self.epoch_in_iteration + 1, + batch_in_epoch=0, + sample_in_epoch=0, + token_in_epoch=0, + total_wct=self.total_wct + duration, + iteration_wct=self.iteration_wct + duration, + epoch_wct=datetime.timedelta(seconds=0), + batch_wct=datetime.timedelta(seconds=0), + ) + + def to_next_iteration( + self, + duration: Optional[datetime.timedelta] = None, + ): + """Create a new :class:`.Timestamp`, advanced to the next iteration. + + Equivalent to: + + .. testsetup:: + + from composer.core.time import Timestamp + import datetime + + timestamp = Timestamp() + + .. doctest:: + + >>> timestamp.copy( + ... iteration=timestamp.iteration + 1, + ... epoch_in_iteration=0, + ... batch_in_epoch=0, + ... sample_in_epoch=0, + ... token_in_epoch=0, + ... total_wct=timestamp.total_wct + duration, + ... iteration_wct=datetime.timedelta(seconds=0), + ... epoch_wct=datetime.timedelta(seconds=0), + ... batch_wct=datetime.timedelta(seconds=0), + ... ) + Timestamp(...) 
+ + """ + if duration is None: + duration = datetime.timedelta(seconds=0) + return self.copy( + iteration=self.iteration + 1, + epoch_in_iteration=0, batch_in_epoch=0, sample_in_epoch=0, token_in_epoch=0, total_wct=self.total_wct + duration, + iteration_wct=datetime.timedelta(seconds=0), epoch_wct=datetime.timedelta(seconds=0), batch_wct=datetime.timedelta(seconds=0), ) def copy( self, + iteration: Optional[Union[int, Time[int]]] = None, epoch: Optional[Union[int, Time[int]]] = None, batch: Optional[Union[int, Time[int]]] = None, sample: Optional[Union[int, Time[int]]] = None, token: Optional[Union[int, Time[int]]] = None, + epoch_in_iteration: Optional[Union[int, Time[int]]] = None, batch_in_epoch: Optional[Union[int, Time[int]]] = None, sample_in_epoch: Optional[Union[int, Time[int]]] = None, token_in_epoch: Optional[Union[int, Time[int]]] = None, total_wct: Optional[datetime.timedelta] = None, + iteration_wct: Optional[datetime.timedelta] = None, epoch_wct: Optional[datetime.timedelta] = None, batch_wct: Optional[datetime.timedelta] = None, ) -> Timestamp: @@ -769,42 +887,53 @@ def copy( Any specified values will override the existing values in the returned copy. Args: + iteration (int | Time[int], optional): The iteration. epoch (int | Time[int], optional): The epoch. batch (int | Time[int], optional): the batch. sample (int | Time[int], optional): The sample. token (int | Time[int], optional): The token. + epoch_in_iteration (int | Time[int], optional): The epoch in the iteration. batch_in_epoch (int | Time[int], optional): The batch in the epoch. sample_in_epoch (int | Time[int], optional): The sample in the epoch. token_in_epoch (int | Time[int], optional): The token in the epoch. total_wct (datetime.timedelta, optional): The elapsed duration from the beginning of training. + iteration_wct (datetime.timedelta, optional): The wall-clock duration of the last iteration. + epoch_wct (datetime.timedelta, optional): The wall-clock duration of the last epoch. 
+ batch_wct (datetime.timedelta, optional): The wall-clock duration of the last batch. Returns: Timestamp: A new timestamp instance, created from a copy, but with any specified values overriding the existing values. """ return Timestamp( + iteration=iteration if iteration is not None else self.iteration, epoch=epoch if epoch is not None else self.epoch, batch=batch if batch is not None else self.batch, sample=sample if sample is not None else self.sample, token=token if token is not None else self.token, + epoch_in_iteration=epoch_in_iteration if epoch_in_iteration is not None else self.epoch_in_iteration, batch_in_epoch=batch_in_epoch if batch_in_epoch is not None else self.batch_in_epoch, sample_in_epoch=sample_in_epoch if sample_in_epoch is not None else self.sample_in_epoch, token_in_epoch=token_in_epoch if token_in_epoch is not None else self.token_in_epoch, total_wct=total_wct if total_wct is not None else self.total_wct, + iteration_wct=iteration_wct if iteration_wct is not None else self.iteration_wct, epoch_wct=epoch_wct if epoch_wct is not None else self.epoch_wct, batch_wct=batch_wct if batch_wct is not None else self.batch_wct, ) def __repr__(self) -> str: return (f'Timestamp(' + f'iteration={int(self.iteration)}, ' f'epoch={int(self.epoch)}, ' f'batch={int(self.batch)}, ' f'sample={int(self.sample)}, ' f'token={int(self.token)}, ' + f'epoch_in_iteration={int(self.epoch_in_iteration)}, ' f'batch_in_epoch={int(self.batch_in_epoch)}, ' f'sample_in_epoch={int(self.sample_in_epoch)}, ' f'token_in_epoch={int(self.token_in_epoch)}, ' f'total_wct={repr(self.total_wct)}, ' + f'iteration_wct={repr(self.iteration_wct)}, ' f'epoch_wct={repr(self.epoch_wct)}, ' f'batch_wct={repr(self.batch_wct)}' ')') diff --git a/tests/test_time.py b/tests/test_time.py index 611bf83f72..7898f7a6d0 100644 --- a/tests/test_time.py +++ b/tests/test_time.py @@ -9,6 +9,7 @@ @pytest.mark.parametrize('time_string,expected_value,expected_unit', [ + ['2iter', 2, TimeUnit.ITERATION], 
['1ep', 1, TimeUnit.EPOCH], ['2ba', 2, TimeUnit.BATCH], ['3e10sp', 3 * 10**10, TimeUnit.SAMPLE], @@ -25,6 +26,7 @@ def test_time_parse(time_string: str, expected_value: int, expected_unit: TimeUn @pytest.mark.parametrize('expected_timestring,time', [ + ['2iter', Time(2, TimeUnit.ITERATION)], ['1ep', Time(1, TimeUnit.EPOCH)], ['2ba', Time(2, TimeUnit.BATCH)], ['3sp', Time(3, TimeUnit.SAMPLE)], @@ -136,7 +138,7 @@ def test_timestamp_update(): assert timestamp is not timestamp_2 -def test_timestamp_to_next_batch_epoch(): +def test_timestamp_to_next_batch_epoch_iteration(): timestamp = Timestamp() # Step batch 0 in epoch 0 timestamp = timestamp.to_next_batch(10, 20, datetime.timedelta(seconds=5)) @@ -168,6 +170,7 @@ def test_timestamp_to_next_batch_epoch(): timestamp = timestamp.to_next_batch(5, 0, datetime.timedelta(seconds=10)) assert timestamp.epoch == 1 assert timestamp.batch == 2 + assert timestamp.epoch_in_iteration == 1 assert timestamp.batch_in_epoch == 1 assert timestamp.sample == 15 assert timestamp.sample_in_epoch == 5 @@ -194,6 +197,36 @@ def test_timestamp_to_next_batch_epoch(): timestamp = timestamp.to_next_epoch() assert timestamp.epoch == 2 assert timestamp.batch == 3 + assert timestamp.epoch_in_iteration == 2 + assert timestamp.batch_in_epoch == 0 + assert timestamp.sample == 20 + assert timestamp.sample_in_epoch == 0 + assert timestamp.token == 21 + assert timestamp.token_in_epoch == 0 + assert timestamp.total_wct == datetime.timedelta(seconds=30) + assert timestamp.epoch_wct == datetime.timedelta(seconds=0) + assert timestamp.batch_wct == datetime.timedelta(seconds=0) + + # Step batch 0 in epoch 2 + timestamp = timestamp.to_next_batch(5, 1, datetime.timedelta(seconds=10)) + assert timestamp.epoch == 2 + assert timestamp.batch == 4 + assert timestamp.epoch_in_iteration == 2 + assert timestamp.batch_in_epoch == 1 + assert timestamp.sample == 25 + assert timestamp.sample_in_epoch == 5 + assert timestamp.token == 22 + assert timestamp.token_in_epoch == 1 
+ assert timestamp.total_wct == datetime.timedelta(seconds=40) + assert timestamp.epoch_wct == datetime.timedelta(seconds=10) + assert timestamp.batch_wct == datetime.timedelta(seconds=10) + + # Finish iteration 0 + timestamp = timestamp.to_next_iteration() + assert timestamp.iteration == 1 + assert timestamp.epoch == 2 + assert timestamp.batch == 3 + assert timestamp.epoch_in_iteration == 0 assert timestamp.batch_in_epoch == 0 assert timestamp.sample == 20 assert timestamp.sample_in_epoch == 0 @@ -209,12 +242,12 @@ def test_timestamp_repr(): assert timestamp == eval(repr(timestamp)) -@pytest.mark.parametrize('time_string', ['1.5ep', '2.1ba', '3.2sp', '3.4tok']) +@pytest.mark.parametrize('time_string', ['1.1iter', '1.5ep', '2.1ba', '3.2sp', '3.4tok']) def test_timestep_bad_strings(time_string: str): with pytest.raises(TypeError): Time.from_timestring(time_string) -@pytest.mark.parametrize('time_string', ['0.5dur', '2.0ep', '3.000ba', '030.0sp']) +@pytest.mark.parametrize('time_string', ['0.5dur', '1.0iter', '2.0ep', '3.000ba', '030.0sp']) def test_timestep_valid_strings(time_string: str): Time.from_timestring(time_string)