From e63c4dadec1b9aebad8a9dac45e191c1b2c51e95 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Tue, 2 Apr 2024 15:58:53 +0000 Subject: [PATCH] Add timeout --- llmfoundry/callbacks/hf_checkpointer.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 90189ddeeb..2ff39a51fa 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -75,10 +75,13 @@ def _maybe_get_license_filename( return None -def _register_model_with_run_id_multiprocess(mlflow_logger: MLFlowLogger, - composer_logging_level: int, - model_uri: str, name: str, - await_creation_for: int): +def _register_model_with_run_id_multiprocess( + mlflow_logger: MLFlowLogger, + composer_logging_level: int, + model_uri: str, + name: str, + await_creation_for: int, +): """Function for calling MLFlowLogger.register_model_with_run_id from a. spawned child process. @@ -233,7 +236,14 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: '5GB') elif event == Event.FIT_END: # Wait for all child processes spawned by the callback to finish. + timeout = 3600 + wait_start = time.time() while not self._all_child_processes_done(): + wait_time = time.time() - wait_start + if wait_time > timeout: + raise TimeoutError( + f'Waited {wait_time} seconds for child processes to complete. Exceed timeout of {timeout} seconds.' + ) time.sleep(2) # Clean up temporary save directory; all processes are done with it.