From 8f81a3cc1023f25d52f28df9a2be0b2c272296a3 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Wed, 10 Jan 2024 19:35:24 +0000 Subject: [PATCH] Remove foundry time wrangling --- llmfoundry/callbacks/async_eval_callback.py | 26 +++++++-------------- llmfoundry/callbacks/hf_checkpointer.py | 9 ++----- llmfoundry/optim/scheduler.py | 13 +++++------ 3 files changed, 16 insertions(+), 32 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 8352a9e283..4227448d87 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -158,31 +158,21 @@ def get_eval_parameters( def validate_interval(interval: Union[str, int, Time], save_interval: Union[str, int, Time]) -> Time: - if isinstance(save_interval, str): - new_save_interval: Time = Time.from_timestring(save_interval) - elif isinstance(save_interval, int): - new_save_interval: Time = Time(save_interval, TimeUnit.EPOCH) - else: - new_save_interval: Time = save_interval - - if isinstance(interval, str): - result: Time = Time.from_timestring(interval) - elif isinstance(interval, int): - result: Time = Time(interval, TimeUnit.EPOCH) - else: - result: Time = interval - - if new_save_interval.unit != result.unit: + + new_save_interval = Time.from_input(save_interval, TimeUnit.EPOCH) + async_interval = Time.from_input(interval, TimeUnit.EPOCH) + + if new_save_interval.unit != async_interval.unit: raise ValueError( 'Save interval and async eval interval must be in the same unit') - if result < new_save_interval: + if async_interval < new_save_interval: raise ValueError( 'Async eval interval must be equal or greater (less frequent) than save interval' ) - if result.value % new_save_interval.value != 0: + if async_interval.value % new_save_interval.value != 0: raise ValueError( 'Async eval interval must be a multiple of save interval') - return result + return async_interval class AsyncEval(Callback): diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 491d510188..06fb541aec 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -96,14 +96,9 @@ def __init__( self.huggingface_folder_name_fstr = os.path.join( 'huggingface', huggingface_folder_name) - if isinstance(save_interval, str): - save_interval = Time.from_timestring(save_interval) - if isinstance(save_interval, int): - save_interval = Time(save_interval, TimeUnit.EPOCH) - - self.save_interval: Time = save_interval + self.save_interval: Time = Time.from_input(save_interval, TimeUnit.EPOCH) self.check_interval = create_interval_scheduler( - save_interval, include_end_of_training=True) + self.save_interval, include_end_of_training=True) self.remote_ud = maybe_create_remote_uploader_downloader_from_uri( save_folder, loggers=[]) if self.remote_ud is not None: diff --git a/llmfoundry/optim/scheduler.py b/llmfoundry/optim/scheduler.py index 4a6d21c873..9d3cb6be02 100644 --- a/llmfoundry/optim/scheduler.py +++ b/llmfoundry/optim/scheduler.py @@ -16,20 +16,19 @@ def _raise_if_units_dont_match(time: Union[str, Time], t_max: Union[str, Time], name: str) -> None: - if isinstance(time, str): - time = Time.from_timestring(time) - if isinstance(t_max, str): - t_max = Time.from_timestring(t_max) + time = Time.from_input(time) + t_max = Time.from_input(t_max) assert not isinstance(time, str) and not isinstance(t_max, str) if time.unit != t_max.unit: - raise ValueError(f'{time.unit=} does not match {t_max.unit=}.') + raise ValueError( + f'{name} (unit {time.unit=}) must match max_duration unit ({t_max.unit=}).' + ) def _raise_if_units_dur(time: Union[str, Time], name: str) -> None: - if isinstance(time, str): - time = Time.from_timestring(time) + time = Time.from_input(time) assert not isinstance(time, str)