diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py
index 9f4af3099e..6abb758fb4 100644
--- a/llmfoundry/data/packing.py
+++ b/llmfoundry/data/packing.py
@@ -7,6 +7,7 @@
 
 import numpy as np
 import torch
+from composer.utils import dist
 from omegaconf import DictConfig
 from transformers import PreTrainedTokenizerBase
 
@@ -388,8 +389,19 @@ def profile_packing(
     dataloader_cfg.persistent_workers = False
 
     # If streaming dataset, use a temporary local folder for profiling
+    local_rank_zero = dist.get_global_rank() - dist.get_local_rank()
     if dataloader_cfg.dataset.get('remote') is not None:
-        dataloader_cfg.dataset.local = tempfile.TemporaryDirectory().name
+        tmp_path_to_broadcast = tempfile.TemporaryDirectory().name
+        gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
+        tmp_path = gathered_paths[local_rank_zero]
+        dataloader_cfg.dataset.local = tmp_path
+
+    if dataloader_cfg.dataset.get('streams') is not None:
+        for stream_config in dataloader_cfg.dataset.streams.values():
+            tmp_path_to_broadcast = tempfile.TemporaryDirectory().name
+            gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
+            tmp_path = gathered_paths[local_rank_zero]
+            stream_config.local = tmp_path
 
     # Determine the packing_ratio values we'll try
     packing_ratios, raw_batch_sizes = [], []
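
For reference, the added lines follow a common pattern for agreeing on one local path per node: every rank proposes a temporary directory, the proposals are all-gathered, and each rank adopts the path proposed by its node's local-rank-zero process, so all ranks on a node profile packing against the same local folder. A minimal standalone sketch of that pattern is below; the helper name share_node_local_path is illustrative (not part of this patch), and it assumes the distributed process group has already been initialized, as it is during training.

import tempfile

from composer.utils import dist


def share_node_local_path() -> str:
    # Global rank of the first (local-rank-zero) process on this node.
    local_rank_zero = dist.get_global_rank() - dist.get_local_rank()

    # Every rank proposes its own temporary directory name.
    tmp_path_to_broadcast = tempfile.TemporaryDirectory().name

    # Gather all proposals; index i holds the path proposed by global rank i.
    gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)

    # Adopt the path proposed by this node's local rank zero so every rank
    # on the node reads and writes the same local folder.
    return gathered_paths[local_rank_zero]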