Skip to content

Commit

Permalink
Merge branch 'main' into background
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea authored Apr 3, 2024
2 parents e1b905c + d452c60 commit 2a3515d
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion llmfoundry/data/packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np
import torch
from composer.utils import dist
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase

Expand Down Expand Up @@ -388,8 +389,19 @@ def profile_packing(
dataloader_cfg.persistent_workers = False

# If streaming dataset, use a temporary local folder for profiling
local_rank_zero = dist.get_global_rank() - dist.get_local_rank()
if dataloader_cfg.dataset.get('remote') is not None:
dataloader_cfg.dataset.local = tempfile.TemporaryDirectory().name
tmp_path_to_broadcast = tempfile.TemporaryDirectory().name
gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
tmp_path = gathered_paths[local_rank_zero]
dataloader_cfg.dataset.local = tmp_path

if dataloader_cfg.dataset.get('streams') is not None:
for stream_config in dataloader_cfg.dataset.streams.values():
tmp_path_to_broadcast = tempfile.TemporaryDirectory().name
gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
tmp_path = gathered_paths[local_rank_zero]
stream_config.local = tmp_path

# Determine the packing_ratio values we'll try
packing_ratios, raw_batch_sizes = [], []
Expand Down

0 comments on commit 2a3515d

Please sign in to comment.