Skip to content

Commit

Permalink
Remove foundry time wrangling
Browse files Browse the repository at this point in the history
  • Loading branch information
aspfohl committed Jan 10, 2024
1 parent ddba5c8 commit 8f81a3c
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 32 deletions.
26 changes: 8 additions & 18 deletions llmfoundry/callbacks/async_eval_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,31 +158,21 @@ def get_eval_parameters(

def validate_interval(interval: Union[str, int, Time],
save_interval: Union[str, int, Time]) -> Time:
if isinstance(save_interval, str):
new_save_interval: Time = Time.from_timestring(save_interval)
elif isinstance(save_interval, int):
new_save_interval: Time = Time(save_interval, TimeUnit.EPOCH)
else:
new_save_interval: Time = save_interval

if isinstance(interval, str):
result: Time = Time.from_timestring(interval)
elif isinstance(interval, int):
result: Time = Time(interval, TimeUnit.EPOCH)
else:
result: Time = interval

if new_save_interval.unit != result.unit:

new_save_interval = Time.from_input(save_interval, TimeUnit.EPOCH)
async_interval = Time.from_input(interval, TimeUnit.EPOCH)

if new_save_interval.unit != async_interval.unit:
raise ValueError(
'Save interval and async eval interval must be in the same unit')
if result < new_save_interval:
if async_interval < new_save_interval:
raise ValueError(
'Async eval interval must be equal or greater (less frequent) than save interval'
)
if result.value % new_save_interval.value != 0:
if async_interval.value % new_save_interval.value != 0:
raise ValueError(
'Async eval interval must be a multiple of save interval')
return result
return async_interval


class AsyncEval(Callback):
Expand Down
9 changes: 2 additions & 7 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,9 @@ def __init__(
self.huggingface_folder_name_fstr = os.path.join(
'huggingface', huggingface_folder_name)

if isinstance(save_interval, str):
save_interval = Time.from_timestring(save_interval)
if isinstance(save_interval, int):
save_interval = Time(save_interval, TimeUnit.EPOCH)

self.save_interval: Time = save_interval
self.save_interval: Time = Time.from_input(save_interval, TimeUnit.EPOCH)
self.check_interval = create_interval_scheduler(
save_interval, include_end_of_training=True)
self.save_interval, include_end_of_training=True)
self.remote_ud = maybe_create_remote_uploader_downloader_from_uri(
save_folder, loggers=[])
if self.remote_ud is not None:
Expand Down
13 changes: 6 additions & 7 deletions llmfoundry/optim/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,19 @@

def _raise_if_units_dont_match(time: Union[str, Time], t_max: Union[str, Time],
name: str) -> None:
if isinstance(time, str):
time = Time.from_timestring(time)
if isinstance(t_max, str):
t_max = Time.from_timestring(t_max)
time = Time.from_input(time)
t_max = Time.from_input(t_max)

assert not isinstance(time, str) and not isinstance(t_max, str)

if time.unit != t_max.unit:
raise ValueError(f'{time.unit=} does not match {t_max.unit=}.')
raise ValueError(
f'{name} (unit {time.unit=}) must match max_duration unit ({t_max.unit=}).'
)


def _raise_if_units_dur(time: Union[str, Time], name: str) -> None:
if isinstance(time, str):
time = Time.from_timestring(time)
time = Time.from_input(time)

assert not isinstance(time, str)

Expand Down

0 comments on commit 8f81a3c

Please sign in to comment.