Skip to content

Commit

Permalink
Busy wait utils in dist (#3396)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored Jun 14, 2024
1 parent 13f2a4f commit e494f9b
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 0 deletions.
73 changes: 73 additions & 0 deletions composer/utils/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import logging
import os
import pickle
import random
import string
import sys
import time
from contextlib import contextmanager
Expand Down Expand Up @@ -627,6 +629,77 @@ def get_sampler(
)


def get_node_signal_file_name(rng: Optional[random.Random] = None) -> str:
    """Build the signal-file name shared by all ranks for a file-based wait.

    A six-character random suffix drawn from ASCII letters and digits is appended
    to the name to avoid conflicts. Every rank assembles a candidate name, but the
    candidate is then replaced by the value broadcast from source rank 0, so all
    ranks return the same string.

    Args:
        rng (random.Random, optional): Generator used to draw the random suffix.
            If ``None``, a new ``random.Random()`` instance is created.

    Returns:
        str: The signal file name, as broadcast from rank 0.
    """
    generator = random.Random() if rng is None else rng

    alphabet = string.ascii_letters + string.digits
    suffix = ''.join(generator.choices(alphabet, k=6))
    node_rank = get_node_rank()

    # A one-element list is used because broadcast_object_list fills the list in place.
    candidates = [f'._signal_file_node{node_rank}_{suffix}']
    dist.broadcast_object_list(candidates, src=0)
    return candidates[0]


def write_signal_file(signal_file_name: str, dir_path: Optional[str] = None) -> str:
    """Write the signal file on local rank zero and return its path on every rank.

    When ``dir_path`` is given, the directory is created first (existing
    directories are accepted). Only local rank zero actually writes the file;
    every other rank just computes and returns the path it is expected to
    appear at.

    Args:
        signal_file_name (str): The name of the signal file.
        dir_path (str, optional): The full path to the directory in which to create the signal file. If ``None``,
            the current working directory will be used.

    Returns:
        str: The path of the signal file (``dir_path`` or the current working
        directory, joined with ``signal_file_name``).
    """
    if dir_path is not None:
        os.makedirs(dir_path, exist_ok=True)

    base_dir = dir_path if dir_path else os.getcwd()
    signal_file_path = os.path.join(base_dir, signal_file_name)

    # Ranks other than local rank zero only need to know where the file will be.
    if get_local_rank() != 0:
        return signal_file_path

    with open(signal_file_path, 'w') as handle:
        handle.write('local rank zero done')
    return signal_file_path


@contextmanager
def busy_wait_for_local_rank_zero(dir_path: Optional[str] = None):
    """Context manager that holds all ranks until local rank zero leaves the block.

    On entry, all ranks agree on a signal file name. The body of the ``with``
    statement then runs. On exit, local rank zero writes the signal file, the
    wait for that file is delegated to ``local_rank_zero_download_and_wait``,
    and a barrier is issued. Finally local rank zero deletes the signal file.

    Args:
        dir_path (str, optional): The directory in which to look for the signal file. If ``None``,
            the current working directory will be used.
    """
    # The name must be agreed on before yielding, while every rank is still in step.
    marker_name = get_node_signal_file_name()

    # Hand control to the caller's ``with`` body on every rank.
    yield

    # Local rank zero creates the file; the remaining ranks only learn its expected path.
    marker_path = write_signal_file(signal_file_name=marker_name, dir_path=dir_path)

    with local_rank_zero_download_and_wait(marker_path):
        # The file-based wait only covers ranks within a node, so sync across nodes too.
        dist.barrier()

    # Clean-up is the responsibility of the rank that wrote the file.
    if get_local_rank() != 0:
        return
    os.remove(marker_path)


@contextmanager
def local_rank_zero_download_and_wait(expected_file_path: str):
"""Context manager to wait for a file to exist on all ranks except local rank zero.
Expand Down
47 changes: 47 additions & 0 deletions tests/utils/test_dist.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import os
import time
from unittest.mock import patch

import pytest
Expand All @@ -27,3 +29,48 @@ def test_run_local_rank_first_context_runs_properly():
# so dist is initialized here and this code should run without error
with dist.run_local_rank_zero_first():
pass


@pytest.mark.world_size(2)
def test_get_node_signal_file_name():
    """Both ranks must end up with the same name, prefixed for node 0 plus 6 random characters."""
    expected_prefix = '._signal_file_node0_'

    local_name = dist.get_node_signal_file_name()
    all_names = dist.all_gather_object(local_name)

    assert len(all_names) == 2
    first_name, second_name = all_names
    assert first_name == second_name
    assert first_name == local_name
    assert local_name.startswith(expected_prefix)
    # Six random characters follow the fixed prefix.
    assert len(local_name) == len(expected_prefix) + 6


@pytest.mark.world_size(2)
def test_write_signal_file(tmp_path):
    """The signal file must exist only where local rank zero wrote it."""
    signal_name = dist.get_node_signal_file_name()
    expected_path = os.path.join(tmp_path, signal_name)
    dist.write_signal_file(signal_name, tmp_path)

    # tmp_path will be different on each rank, and only rank zero
    # should have written a file
    should_exist = dist.get_local_rank() == 0
    assert os.path.exists(expected_path) == should_exist


@pytest.mark.world_size(2)
def test_busy_wait_for_local_rank_zero(tmp_path):
    """All ranks must leave the context manager at about the same time, leaving no signal file behind."""
    # Use a single directory on every rank: the first gathered tmp_path.
    shared_dir = dist.all_gather_object(tmp_path)[0]

    # Line the ranks up so that their timers start together.
    dist.barrier()
    began = time.time()
    assert os.listdir(shared_dir) == []
    with dist.busy_wait_for_local_rank_zero(shared_dir):
        # Only local rank zero does "work"; the other rank has to wait for it on exit.
        if dist.get_local_rank() == 0:
            time.sleep(0.5)

    elapsed = time.time() - began
    per_rank_elapsed = dist.all_gather_object(elapsed)

    # The signal file must have been cleaned up again.
    assert os.listdir(shared_dir) == []
    assert len(per_rank_elapsed) == 2
    first_elapsed, second_elapsed = per_rank_elapsed
    assert abs(first_elapsed - second_elapsed) < 0.1

0 comments on commit e494f9b

Please sign in to comment.