Skip to content

Commit

Permalink
address pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea committed Oct 23, 2024
1 parent 17eb0be commit f8fd8a0
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
3 changes: 0 additions & 3 deletions llmfoundry/data/finetuning/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,13 +617,10 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str:
with open(signal_file_path, 'wb') as f:
f.write(b'local_rank0_completed_download')

print(dist.get_local_rank(), f'signal_file_path: {signal_file_path}')

# Avoid the collective call until the local rank zero has finished trying to download the dataset
# so that we don't timeout for large downloads. This syncs all processes on the node
with dist.local_rank_zero_download_and_wait(signal_file_path):
# Then, wait to ensure every node has finished trying to download the dataset
print('GOT TO BARRIER')
dist.barrier()

# clean up signal file
Expand Down
7 changes: 5 additions & 2 deletions llmfoundry/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@ def dist_mkdtemp() -> str:
str: The path to the temporary directory.
"""
tempdir = None
if dist.get_local_rank() == 0:
local_rank = dist.get_local_rank()
global_rank = dist.get_global_rank()
if local_rank == 0:
tempdir = tempfile.mkdtemp()
tempdir = dist.all_gather_object(tempdir)[0]

tempdir = dist.all_gather_object(tempdir)[global_rank - local_rank]
if tempdir is None:
raise RuntimeError('Dist operation to get tempdir failed.')
return tempdir

0 comments on commit f8fd8a0

Please sign in to comment.