Skip to content

Commit

Permalink
better types
Browse files Browse the repository at this point in the history
  • Loading branch information
aspfohl committed Jan 10, 2024
1 parent f16e9da commit 9cab780
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions llmfoundry/optim/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,19 @@

def _raise_if_units_dont_match(time: Union[str, Time], t_max: Union[str, Time],
name: str) -> None:
time = Time.from_input(time)
t_max = Time.from_input(t_max)
new_time = Time.from_input(time)
new_t_max = Time.from_input(t_max)

assert not isinstance(time, str) and not isinstance(t_max, str)

if time.unit != t_max.unit:
if new_time.unit != new_t_max.unit:
raise ValueError(
f'{name} (unit {time.unit=}) must match max_duration unit ({t_max.unit=}).'
f'{name} (unit {new_time.unit=}) must match max_duration unit ({new_t_max.unit=}).'
)


def _raise_if_units_dur(time: Union[str, Time], name: str) -> None:
time = Time.from_input(time)

assert not isinstance(time, str)
new_time = Time.from_input(time)

if time.unit == TimeUnit('dur'):
if new_time.unit == TimeUnit('dur'):
raise ValueError(f'{name} cannot be in units of "dur".')


Expand Down

0 comments on commit 9cab780

Please sign in to comment.