Revert "Use utils to get shared fs safe signal file name (#1381)" (#1389
Browse files Browse the repository at this point in the history
)

This reverts commit d2d29ad.
dakinggg authored Jul 23, 2024
1 parent 3d7d12e commit 221d3e2
Showing 4 changed files with 38 additions and 23 deletions.
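
Three of the four files swap a call to dist.get_node_signal_file_name() back for an inline, per-node f-string, and the dataloader additionally restores its own signal-file-and-barrier logic in place of dist.busy_wait_for_local_rank_zero(). A minimal before/after sketch of the naming change, with both spellings taken from the hunks below and assuming dist is composer.utils.dist (which is how these modules import it):

    from composer.utils import dist

    # Removed by this revert (introduced in #1381): helper-generated signal file name.
    signal_file_path = dist.get_node_signal_file_name()

    # Restored by this revert: inline name keyed on the node rank.
    signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed'
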
55 changes: 35 additions & 20 deletions llmfoundry/data/finetuning/dataloader.py
@@ -534,27 +534,42 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str:

         # Since we don't know exactly what the extension will be, since it is one of a list
         # use a signal file to wait for instead of the desired file
-        with dist.busy_wait_for_local_rank_zero(finetune_dir):
-            if dist.get_local_rank() == 0:
-                try:
-                    get_file(path=name, destination=destination, overwrite=True)
-                except FileNotFoundError as e:
-                    if extension == SUPPORTED_EXTENSIONS[-1]:
-                        files_searched = [
-                            f'{name}/{split}{ext}'
-                            for ext in SUPPORTED_EXTENSIONS
-                        ]
-                        raise FileNotFoundError(
-                            f'Could not find a file with any of ' + \
-                            f'the supported extensions: {SUPPORTED_EXTENSIONS}\n' + \
-                            f'at {files_searched}',
-                        ) from e
-                    else:
-                        log.debug(
-                            f'Could not find {name}, looking for another extension',
-                        )
-                    continue
+        signal_file_path = os.path.join(
+            finetune_dir,
+            f'.node_{dist.get_node_rank()}_local_rank0_completed',
+        )
+        if dist.get_local_rank() == 0:
+            try:
+                get_file(path=name, destination=destination, overwrite=True)
+            except FileNotFoundError as e:
+                if extension == SUPPORTED_EXTENSIONS[-1]:
+                    files_searched = [
+                        f'{name}/{split}{ext}' for ext in SUPPORTED_EXTENSIONS
+                    ]
+                    raise FileNotFoundError(
+                        f'Could not find a file with any of ' + \
+                        f'the supported extensions: {SUPPORTED_EXTENSIONS}\n' + \
+                        f'at {files_searched}',
+                    ) from e
+                else:
+                    log.debug(
+                        f'Could not find {name}, looking for another extension',
+                    )
+                continue
+
+            os.makedirs(os.path.dirname(signal_file_path), exist_ok=True)
+            with open(signal_file_path, 'wb') as f:
+                f.write(b'local_rank0_completed_download')
+
+        # Avoid the collective call until the local rank zero has finished trying to download the dataset
+        # so that we don't timeout for large downloads. This syncs all processes on the node
+        with dist.local_rank_zero_download_and_wait(signal_file_path):
+            # Then, wait to ensure every node has finished trying to download the dataset
+            dist.barrier()
+
+        # clean up signal file
+        if dist.get_local_rank() == 0:
+            os.remove(signal_file_path)
+        dist.barrier()
         break
     return finetune_dir
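
The restored dataloader logic is a signal-file handshake: local rank 0 downloads the split and writes a marker file, the node-level wait plus barrier holds the other ranks until that marker exists, and the marker is removed afterwards. A minimal, self-contained sketch of the same pattern, reusing the composer dist helpers visible in the hunk; the function name, directory layout, and placeholder download below are illustrative, not part of the original code:

    import os

    from composer.utils import dist


    def download_once_per_node(work_dir: str) -> None:
        # Per-node marker file; the node rank keeps names distinct when
        # several nodes share a filesystem.
        signal_file_path = os.path.join(
            work_dir,
            f'.node_{dist.get_node_rank()}_local_rank0_completed',
        )

        if dist.get_local_rank() == 0:
            # Placeholder standing in for the real get_file() download.
            os.makedirs(work_dir, exist_ok=True)
            with open(os.path.join(work_dir, 'train.jsonl'), 'w') as f:
                f.write('{"prompt": "hi", "response": "hello"}\n')

            # Signal the rest of the node that the download finished.
            with open(signal_file_path, 'wb') as f:
                f.write(b'local_rank0_completed_download')

        # Delay the collective until local rank 0 is done so large downloads
        # do not hit a barrier timeout; this syncs all processes on the node.
        with dist.local_rank_zero_download_and_wait(signal_file_path):
            dist.barrier()

        # Clean up the marker so a later call can reuse the same name.
        if dist.get_local_rank() == 0:
            os.remove(signal_file_path)
        dist.barrier()
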
2 changes: 1 addition & 1 deletion llmfoundry/data/finetuning/tasks.py
@@ -827,7 +827,7 @@ def build_from_hf(
         Returns:
             Dataset: The tokenized dataset.
         """
-        signal_file_path = dist.get_node_signal_file_name()
+        signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed'

         # Non local rank 0 ranks will wait here for local rank 0 to finish the data processing.
         # Once local rank 0 is done, the datasets are all cached on disk, and all other ranks
2 changes: 1 addition & 1 deletion llmfoundry/models/hf/hf_causal_lm.py
@@ -333,7 +333,7 @@ def _autoset_attn_implementation_monkeypatch(
                 f'init_device="{init_device}" must be either "cpu" or "meta".',
             )

-        signal_file_path = dist.get_node_signal_file_name()
+        signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed'
         if dist.get_local_rank() == 0:
             with open(signal_file_path, 'wb') as f:
                 f.write(b'local_rank0_completed_download')
2 changes: 1 addition & 1 deletion llmfoundry/utils/builders.py
@@ -497,7 +497,7 @@ def build_tokenizer(
     os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
     os.environ['TOKENIZERS_PARALLELISM'] = 'false'

-    signal_file_path = dist.get_node_signal_file_name()
+    signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed_tokenizer_setup'

     if dist.is_available() and dist.is_initialized(
     ) and dist.get_world_size() > 1:
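
Taken together, the four hunks restore three distinct per-node signal file names, one per synchronization point. The f-strings below are copied from the hunks above; the variable names are only for illustration:

    from composer.utils import dist

    node = dist.get_node_rank()
    dataset_download_signal = f'.node_{node}_local_rank0_completed'  # dataloader.py (joined under finetune_dir) and hf_causal_lm.py
    data_prep_signal = f'.node_{node}_local_rank0_data_prep_completed'  # tasks.py
    tokenizer_setup_signal = f'.node_{node}_local_rank0_completed_tokenizer_setup'  # builders.py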
