Add support for packing in a distributed environment.
dsesclei committed Apr 11, 2024
1 parent 739dd5f commit 67f1504
Showing 2 changed files with 56 additions and 33 deletions.
37 changes: 28 additions & 9 deletions src/axolotl/utils/distributed.py
@@ -38,18 +38,29 @@ def barrier():

 def is_main_process():
     """
-    Check if the current process is the main process.
-    If not in distributed mode, always return True.
+    Return whether the current process is on the main rank.
     """
     if not is_distributed():
         return True
-    return dist.get_rank() == 0
+    return get_rank() == 0


+def get_world_size():
+    """
+    Return world size (1 if not distributed.)
+    """
+    if not is_distributed():
+        return 1
+    return int(os.getenv("WORLD_SIZE", "1"))
+
+
+def get_rank():
+    """
+    Return rank (0 if not distributed.)
+    """
+    if not is_distributed():
+        return 0
+    return dist.get_rank()


 @contextmanager
 def zero_only():
     """
@@ -73,7 +84,9 @@ def zero_first(is_main):
         barrier()


-def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
+def gather_scalar_from_all_ranks(
+    fn, world_size=1
+):  # pylint: disable=invalid-name
     """
     Run a callable 'fn' on all ranks and gather the results on the specified rank.
@@ -95,7 +108,9 @@ def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
     if not is_main_process():
         dist.gather(value_tensor, dst=0)
     else:
-        gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
+        gathered_tensors = [
+            torch.zeros_like(value_tensor) for _ in range(world_size)
+        ]
         dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)

     # Convert tensors back to their original type (int or float)
@@ -187,7 +202,9 @@ def gather_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name

     # Placeholder tensor for gathering results
     if is_main_process():
-        gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
+        gathered_tensors = [
+            torch.zeros_like(value_tensor) for _ in range(world_size)
+        ]
     else:
         gathered_tensors = None

@@ -223,7 +240,9 @@ def reduce_and_broadcast(fn1, fn2):
     if not is_distributed():
         return fn2([fn1()])

-    gathered_values = gather_from_all_ranks(fn1, world_size=dist.get_world_size())
+    gathered_values = gather_from_all_ranks(
+        fn1, world_size=dist.get_world_size()
+    )

     # Use compute_and_broadcast to compute the reduced value on the main process
     # and then broadcast it to all ranks
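For context only, and not part of the diff: a minimal sketch of how the helpers above behave. The function log_on_main is hypothetical; only get_rank, get_world_size, and is_main_process come from the file shown here, and they fall back to single-process defaults when torch.distributed is not initialized.

# Hypothetical usage sketch (not from this commit).
from axolotl.utils.distributed import get_rank, get_world_size, is_main_process


def log_on_main(msg):
    # In a non-distributed run this prints "[rank 0/1] ..." from the single process;
    # under a multi-rank launch, only the rank-0 process prints.
    if is_main_process():
        print(f"[rank {get_rank()}/{get_world_size()}] {msg}")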
52 changes: 28 additions & 24 deletions src/axolotl/utils/samplers/multipack.py
@@ -2,13 +2,16 @@
 Multipack Batch Sampler
 """
 import logging
+import math
 from concurrent.futures import ProcessPoolExecutor
 from multiprocessing import cpu_count

 import numba
 import numpy as np
 from torch.utils.data import BatchSampler

+from axolotl.utils.distributed import get_rank, get_world_size, is_main_process
+
 LOG = logging.getLogger("axolotl.utils.samplers.multipack")


@@ -44,7 +47,6 @@ def pack_group(items, group_offset, bin_capacity, max_items_per_bin):


 def pack(items, bin_capacity, group_size, max_items_per_bin):
-    items = np.array(items, dtype=np.int32)
     num_items = len(items)
     num_processes = max(1, min(num_items // group_size, cpu_count()))
     tasks = [
@@ -72,48 +74,50 @@ def __init__(
         batch_size,
         group_size,
         bin_size,
-        shuffle=False,
         drop_last=False,
     ):
         self.sampler = sampler
-        self.sample_idxs = np.arange(len(sampler))
-        self.lengths = (
-            lengths if isinstance(lengths, np.ndarray) else np.array(lengths)
-        )
+        self.lengths = np.array(lengths, dtype=np.int32)
         self.batch_max_len = batch_max_len
         self.batch_size = batch_size
         self.group_size = group_size
         self.bin_size = bin_size
-        self.shuffle = shuffle
         self.drop_last = drop_last
         self._batches = None

     def _pack_batches(self):
-        # Shuffle indices if necessary.
-        idxs = np.copy(self.sample_idxs)
-        if self.shuffle:
-            np.random.shuffle(idxs)
-
-        # Repack based on shuffled indices.
-        shuffled_lengths = self.lengths[idxs]
+        # Initially, calculate packs for all ranks.
         pack_idxs = pack(
-            shuffled_lengths, self.batch_max_len, self.group_size, self.bin_size
+            self.lengths,
+            self.batch_max_len,
+            self.group_size,
+            self.bin_size,
         )

-        # Wrap packs into batches.
-        batch_idxs = [
-            pack_idxs[i : i + self.batch_size]
-            for i in range(0, len(pack_idxs), self.batch_size)
-        ]
+        if is_main_process():
+            used_tokens = self.lengths.sum()
+            available_tokens = len(pack_idxs) * self.batch_max_len
+            efficiency = used_tokens / available_tokens
+            LOG.debug(f"Sample packing efficiency: {efficiency * 100:.2f}%")
+
+        # Select pack indices for this rank.
+        world_size = get_world_size()
+        if self.drop_last:
+            batches_per_rank = len(pack_idxs) // world_size
+        else:
+            batches_per_rank = math.ceil(len(pack_idxs) / world_size)
+
+        start_idx = batches_per_rank * get_rank()
+        end_idx = min(start_idx + batches_per_rank, len(pack_idxs))
+
+        batch_idxs = pack_idxs[start_idx:end_idx]
         return batch_idxs

     def __iter__(self):
-        if self.shuffle or self._batches is None:
-            self._batches = self._pack_batches()
-
+        self._batches = self._pack_batches()
         return iter(self._batches)

     def __len__(self):
         if self._batches is None:
             self._batches = self._pack_batches()

         return len(self._batches)
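For illustration only (not part of this commit): a short worked sketch of the per-rank slicing arithmetic in _pack_batches, using made-up numbers — 10 packs spread across 4 ranks.

import math

# Hypothetical numbers, chosen only to illustrate the slicing above.
num_packs, world_size, drop_last = 10, 4, False

if drop_last:
    batches_per_rank = num_packs // world_size  # 2; the trailing packs are dropped
else:
    batches_per_rank = math.ceil(num_packs / world_size)  # 3

for rank in range(world_size):
    start_idx = batches_per_rank * rank
    end_idx = min(start_idx + batches_per_rank, num_packs)
    print(f"rank {rank}: packs {list(range(start_idx, end_idx))}")
# rank 0: packs [0, 1, 2]
# rank 1: packs [3, 4, 5]
# rank 2: packs [6, 7, 8]
# rank 3: packs [9]

With drop_last=False the last rank can receive fewer packs than the others, as above; with drop_last=True every rank gets num_packs // world_size packs and the remainder is never assigned.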
