diff --git a/llmfoundry/optim/scheduler.py b/llmfoundry/optim/scheduler.py index 9d3cb6be02..5598db28a1 100644 --- a/llmfoundry/optim/scheduler.py +++ b/llmfoundry/optim/scheduler.py @@ -16,23 +16,19 @@ def _raise_if_units_dont_match(time: Union[str, Time], t_max: Union[str, Time], name: str) -> None: - time = Time.from_input(time) - t_max = Time.from_input(t_max) + new_time = Time.from_input(time) + new_t_max = Time.from_input(t_max) - assert not isinstance(time, str) and not isinstance(t_max, str) - - if time.unit != t_max.unit: + if new_time.unit != new_t_max.unit: raise ValueError( - f'{name} (unit {time.unit=}) must match max_duration unit ({t_max.unit=}).' + f'{name} (unit {new_time.unit=}) must match max_duration unit ({new_t_max.unit=}).' ) def _raise_if_units_dur(time: Union[str, Time], name: str) -> None: - time = Time.from_input(time) - - assert not isinstance(time, str) + new_time = Time.from_input(time) - if time.unit == TimeUnit('dur'): + if new_time.unit == TimeUnit('dur'): raise ValueError(f'{name} cannot be in units of "dur".')