diff --git a/composer/utils/dist.py b/composer/utils/dist.py index 573e940bb9..95a95835f4 100644 --- a/composer/utils/dist.py +++ b/composer/utils/dist.py @@ -37,6 +37,8 @@ import logging import os import pickle +import random +import string import sys import time from contextlib import contextmanager @@ -627,6 +629,77 @@ def get_sampler( ) +def get_node_signal_file_name(rng: Optional[random.Random] = None) -> str: + """Returns a file name to use for a file based wait within a node. + + The file name will contain a randomly generated string to avoid conflicts. + Note: This file name will be the same on each node, so that it can be used for a file based wait. + + Returns: + str: The name of the file that will be created to signal the end of a node's training. + """ + if rng is None: + rng = random.Random() + + random_string = ''.join(rng.choices(string.ascii_letters + string.digits, k=6)) + node_rank = get_node_rank() + file_name_list = [f'._signal_file_node{node_rank}_{random_string}'] + dist.broadcast_object_list(file_name_list, src=0) + return file_name_list[0] + + +def write_signal_file(signal_file_name: str, dir_path: Optional[str] = None) -> str: + """Writes a signal file to the specified directory. + + This function creates a signal file in the specified directory. If the directory does + Note: Only local rank zero writes the signal file. All other ranks are expected to wait for the signal file. + + Args: + signal_file_name (str): The name of the signal file. + dir_path (str, optional): The full path to the directory in which to create the signal file. If ``None``, + the current working directory will be used. + """ + if dir_path is not None: + os.makedirs(dir_path, exist_ok=True) + + signal_file_path = os.path.join(dir_path or os.getcwd(), signal_file_name) + if get_local_rank() == 0: + with open(signal_file_path, 'w') as _f: + _f.write('local rank zero done') + + return signal_file_path + + +@contextmanager +def busy_wait_for_local_rank_zero(dir_path: Optional[str] = None): + """Busy waits for the signal file to be created by local rank zero. + + This function will wait for the signal file to be created by local rank zero. It will + check every 0.1 seconds for the existence of the file. + + Args: + dir_path (str, optional): The directory in which to look for the signal file. If ``None``, + the current working directory will be used. + """ + # Get unique file name + signal_file_name = get_node_signal_file_name() + + # All ranks yield execution to allow local rank zero to run the code it needs to + yield + + # Local rank zero writes the signal file, all other rank just get the expected path + signal_file_path = write_signal_file(signal_file_name=signal_file_name, dir_path=dir_path) + + # Wait for the signal file to be created by local rank zero + with local_rank_zero_download_and_wait(signal_file_path): + # Sync all ranks across nodes as busy wait only is within node + dist.barrier() + + # Remove the signal file + if get_local_rank() == 0: + os.remove(signal_file_path) + + @contextmanager def local_rank_zero_download_and_wait(expected_file_path: str): """Context manager to wait for a file to exist on all ranks except local rank zero. diff --git a/tests/utils/test_dist.py b/tests/utils/test_dist.py index 44aedecf3d..608e56e5d2 100644 --- a/tests/utils/test_dist.py +++ b/tests/utils/test_dist.py @@ -1,6 +1,8 @@ # Copyright 2022 MosaicML Composer authors # SPDX-License-Identifier: Apache-2.0 +import os +import time from unittest.mock import patch import pytest @@ -27,3 +29,48 @@ def test_run_local_rank_first_context_runs_properly(): # so dist is initialized here and this code should run without error with dist.run_local_rank_zero_first(): pass + + +@pytest.mark.world_size(2) +def test_get_node_signal_file_name(): + file_name = dist.get_node_signal_file_name() + gathered_file_names = dist.all_gather_object(file_name) + + assert len(gathered_file_names) == 2 + assert gathered_file_names[0] == gathered_file_names[1] + assert gathered_file_names[0] == file_name + assert file_name.startswith('._signal_file_node0_') + assert len(file_name) == len('._signal_file_node0_') + 6 + + +@pytest.mark.world_size(2) +def test_write_signal_file(tmp_path): + file_name = dist.get_node_signal_file_name() + file_path = os.path.join(tmp_path, file_name) + dist.write_signal_file(file_name, tmp_path) + + # tmp_path will be different on each rank, and only rank zero + # should have written a file + if dist.get_local_rank() == 0: + assert os.path.exists(file_path) + else: + assert not os.path.exists(file_path) + + +@pytest.mark.world_size(2) +def test_busy_wait_for_local_rank_zero(tmp_path): + gathered_tmp_path = dist.all_gather_object(tmp_path)[0] + + dist.barrier() + start_time = time.time() + assert os.listdir(gathered_tmp_path) == [] + with dist.busy_wait_for_local_rank_zero(gathered_tmp_path): + if dist.get_local_rank() == 0: + time.sleep(0.5) + + end_time = time.time() + total_time = end_time - start_time + gathered_times = dist.all_gather_object(total_time) + assert os.listdir(gathered_tmp_path) == [] + assert len(gathered_times) == 2 + assert abs(gathered_times[0] - gathered_times[1]) < 0.1