Skip to content

Commit

Permalink
Add timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea committed Apr 2, 2024
1 parent 8e6af11 commit e63c4da
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,13 @@ def _maybe_get_license_filename(
return None


def _register_model_with_run_id_multiprocess(mlflow_logger: MLFlowLogger,
composer_logging_level: int,
model_uri: str, name: str,
await_creation_for: int):
def _register_model_with_run_id_multiprocess(
mlflow_logger: MLFlowLogger,
composer_logging_level: int,
model_uri: str,
name: str,
await_creation_for: int,
):
"""Function for calling MLFlowLogger.register_model_with_run_id from a.
spawned child process.
Expand Down Expand Up @@ -233,7 +236,14 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
'5GB')
elif event == Event.FIT_END:
# Wait for all child processes spawned by the callback to finish.
timeout = 3600
wait_start = time.time()
while not self._all_child_processes_done():
wait_time = time.time() - wait_start
if wait_time > timeout:
raise TimeoutError(
f'Waited {wait_time} seconds for child processes to complete. Exceed timeout of {timeout} seconds.'
)
time.sleep(2)

# Clean up temporary save directory; all processes are done with it.
Expand Down

0 comments on commit e63c4da

Please sign in to comment.