diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
index 313dd24e8c..5f2aee0676 100644
--- a/src/axolotl/utils/distributed.py
+++ b/src/axolotl/utils/distributed.py
@@ -38,18 +38,29 @@ def barrier():
 
 def is_main_process():
     """
-    Check if the current process is the main process.
-    If not in distributed mode, always return True.
+    Return whether the current process is on the main rank.
     """
-    if not is_distributed():
-        return True
-    return dist.get_rank() == 0
+    return get_rank() == 0
 
 
 def get_world_size():
+    """
+    Return world size (1 if not distributed).
+    """
+    if not is_distributed():
+        return 1
     return int(os.getenv("WORLD_SIZE", "1"))
 
 
+def get_rank():
+    """
+    Return rank (0 if not distributed).
+    """
+    if not is_distributed():
+        return 0
+    return dist.get_rank()
+
+
 @contextmanager
 def zero_only():
     """
@@ -73,7 +84,9 @@ def zero_first(is_main):
         barrier()
 
 
-def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
+def gather_scalar_from_all_ranks(
+    fn, world_size=1
+):  # pylint: disable=invalid-name
     """
     Run a callable 'fn' on all ranks and gather the results on the specified rank.
 
@@ -95,7 +108,9 @@ def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-n
     if not is_main_process():
         dist.gather(value_tensor, dst=0)
     else:
-        gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
+        gathered_tensors = [
+            torch.zeros_like(value_tensor) for _ in range(world_size)
+        ]
         dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)
 
         # Convert tensors back to their original type (int or float)
@@ -187,7 +202,9 @@ def gather_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
 
     # Placeholder tensor for gathering results
     if is_main_process():
-        gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
+        gathered_tensors = [
+            torch.zeros_like(value_tensor) for _ in range(world_size)
+        ]
     else:
         gathered_tensors = None
 
@@ -223,7 +240,9 @@ def reduce_and_broadcast(fn1, fn2):
     if not is_distributed():
         return fn2([fn1()])
 
-    gathered_values = gather_from_all_ranks(fn1, world_size=dist.get_world_size())
+    gathered_values = gather_from_all_ranks(
+        fn1, world_size=dist.get_world_size()
+    )
 
     # Use compute_and_broadcast to compute the reduced value on the main process
     # and then broadcast it to all ranks
diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py
index 4446fece12..1461f160ba 100644
--- a/src/axolotl/utils/samplers/multipack.py
+++ b/src/axolotl/utils/samplers/multipack.py
@@ -2,6 +2,7 @@
 Multipack Batch Sampler
 """
 import logging
+import math
 from concurrent.futures import ProcessPoolExecutor
 from multiprocessing import cpu_count
 
@@ -9,6 +10,8 @@
 import numpy as np
 from torch.utils.data import BatchSampler
 
+from axolotl.utils.distributed import get_rank, get_world_size, is_main_process
+
 LOG = logging.getLogger("axolotl.utils.samplers.multipack")
 
 
@@ -44,7 +47,6 @@ def pack_group(items, group_offset, bin_capacity, max_items_per_bin):
 
 
 def pack(items, bin_capacity, group_size, max_items_per_bin):
-    items = np.array(items, dtype=np.int32)
     num_items = len(items)
     num_processes = max(1, min(num_items // group_size, cpu_count()))
     tasks = [
@@ -72,48 +74,50 @@ def __init__(
         self,
         sampler,
         lengths,
         batch_max_len,
         batch_size,
         group_size,
         bin_size,
-        shuffle=False,
+        drop_last=False,
     ):
         self.sampler = sampler
-        self.sample_idxs = np.arange(len(sampler))
-        self.lengths = (
-            lengths if isinstance(lengths, np.ndarray) else np.array(lengths)
-        )
+        self.lengths = np.array(lengths, dtype=np.int32)
         self.batch_max_len = batch_max_len
         self.batch_size = batch_size
         self.group_size = group_size
         self.bin_size = bin_size
-        self.shuffle = shuffle
+        self.drop_last = drop_last
         self._batches = None
 
     def _pack_batches(self):
-        # Shuffle indices if necessary.
-        idxs = np.copy(self.sample_idxs)
-        if self.shuffle:
-            np.random.shuffle(idxs)
-
-        # Repack based on shuffled indices.
-        shuffled_lengths = self.lengths[idxs]
+        # Initially, calculate packs for all ranks.
         pack_idxs = pack(
-            shuffled_lengths, self.batch_max_len, self.group_size, self.bin_size
+            self.lengths,
+            self.batch_max_len,
+            self.group_size,
+            self.bin_size,
         )
 
-        # Wrap packs into batches.
-        batch_idxs = [
-            pack_idxs[i : i + self.batch_size]
-            for i in range(0, len(pack_idxs), self.batch_size)
-        ]
+        if is_main_process():
+            used_tokens = self.lengths.sum()
+            available_tokens = len(pack_idxs) * self.batch_max_len
+            efficiency = used_tokens / available_tokens
+            LOG.debug(f"Sample packing efficiency: {efficiency * 100:.2f}%")
+
+        # Select pack indices for this rank.
+        world_size = get_world_size()
+        if self.drop_last:
+            batches_per_rank = len(pack_idxs) // world_size
+        else:
+            batches_per_rank = math.ceil(len(pack_idxs) / world_size)
+        start_idx = batches_per_rank * get_rank()
+        end_idx = min(start_idx + batches_per_rank, len(pack_idxs))
+
+        batch_idxs = pack_idxs[start_idx:end_idx]
 
         return batch_idxs
 
     def __iter__(self):
-        if self.shuffle or self._batches is None:
-            self._batches = self._pack_batches()
-
+        self._batches = self._pack_batches()
        return iter(self._batches)
 
     def __len__(self):
         if self._batches is None:
             self._batches = self._pack_batches()
-
         return len(self._batches)
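For reference, below is a minimal sketch of how the reworked sampler might be driven. Everything outside the diff is an assumption for illustration: the class is taken to be `MultipackBatchSampler` from `axolotl.utils.samplers.multipack` (the class name is not visible in the hunks above), `ToyDataset` and the parameter values are invented, and the run is single-process, where `get_rank()` / `get_world_size()` fall back to 0 and 1 so the lone rank receives every pack.

```python
# Illustrative sketch only. MultipackBatchSampler, its import path, ToyDataset,
# and all parameter values are assumptions, not part of the diff above.
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, SequentialSampler

from axolotl.utils.samplers.multipack import MultipackBatchSampler  # assumed path


class ToyDataset(Dataset):
    """Variable-length token sequences standing in for tokenized samples."""

    def __init__(self, lengths):
        self.samples = [torch.ones(int(n), dtype=torch.long) for n in lengths]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def main():
    lengths = np.random.randint(64, 512, size=1000).astype(np.int32)
    dataset = ToyDataset(lengths)

    batch_sampler = MultipackBatchSampler(
        SequentialSampler(dataset),
        lengths,
        batch_max_len=2048,  # token capacity of one packed bin
        batch_size=1,        # stored by __init__; each yielded batch is one pack here
        group_size=500,      # items handed to each packing worker
        bin_size=200,        # max sequences allowed in one bin
        drop_last=True,      # equal shard sizes across ranks
    )

    # Each index list yielded by the sampler is one pack whose summed lengths fit
    # within batch_max_len; the DataLoader hands back the corresponding samples.
    loader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=list)

    for packed_samples in loader:
        pass  # concatenate / pad the sequences in `packed_samples` here


if __name__ == "__main__":  # guard needed: pack() spawns worker processes
    main()
```

The `drop_last` switch controls shard balance: with `drop_last=True` every rank takes `len(pack_idxs) // world_size` packs, so per-rank lengths agree; with `drop_last=False` the count is rounded up and the last rank may receive a shorter slice.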