diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py index b3ea07c055..aa7e3feda3 100644 --- a/src/axolotl/utils/distributed.py +++ b/src/axolotl/utils/distributed.py @@ -3,6 +3,7 @@ """ from contextlib import contextmanager +import torch import torch.distributed as dist from accelerate import Accelerator @@ -53,3 +54,91 @@ def zero_first(is_main): yield if is_main: # then rank 0 waits after it has run the context barrier() + + +def compute_and_broadcast(fn): # pylint: disable=invalid-name + """ + Compute a value using the function 'fn' only on the specified rank (default is 0). + The value is then broadcasted to all other ranks. + + Args: + - fn (callable): A function that computes the value. This should not have any side effects. + - rank (int, optional): The rank that computes the value. Default is 0. + + Returns: + - The computed value (int or float). + """ + if is_main_process(): + value_scalar = fn() + value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float() + else: + value_tensor = torch.tensor(0.0, device=dist.get_rank()) # Placeholder tensor + + # Broadcast the tensor to all processes. + barrier() + dist.broadcast(value_tensor, src=0) + + # Convert the tensor back to its original type (int or float) + if value_tensor == value_tensor.int(): + return int(value_tensor.item()) + return float(value_tensor.item()) + + +def gather_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name + """ + Run a callable 'fn' on all ranks and gather the results on the specified rank. + + Args: + - fn (callable): A function that computes the value. This should not have any side effects. + - rank (int, optional): The rank that gathers the values. Default is 0. + - world_size (int, optional): Total number of processes in the current distributed setup. + + Returns: + - A list of computed values from all ranks if on the gathering rank, otherwise None. + """ + value_scalar = fn() + value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float() + + # Placeholder tensor for gathering results + if is_main_process(): + gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)] + else: + gathered_tensors = None + + dist.gather(value_tensor, gather_list=gathered_tensors, dst=0) + + if is_main_process(): + # Convert tensors back to their original type (int or float) + gathered_values = [] + for tensor in gathered_tensors: + if tensor == tensor.int(): + gathered_values.append(int(tensor.item())) + else: + gathered_values.append(float(tensor.item())) + return gathered_values + return None + + +def reduce_and_broadcast(fn1, fn2): + """ + Run a callable 'fn1' on all ranks, gather the results, reduce them using 'fn2', + and then broadcast the reduced result to all ranks. + + Args: + - fn1 (callable): A function that computes the value on each rank. + - fn2 (callable): A reduction function that takes a list of values and returns a single value. + - world_size (int, optional): Total number of processes in the current distributed setup. + + Returns: + - The reduced and broadcasted value. + """ + + # Gather values from all ranks using fn1 + if not is_distributed(): + return fn2([fn1()]) + + gathered_values = gather_from_all_ranks(fn1, world_size=dist.get_world_size()) + + # Use compute_and_broadcast to compute the reduced value on the main process + # and then broadcast it to all ranks + return compute_and_broadcast(lambda: fn2(gathered_values)) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index c9c17fe33c..06ff906fab 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -8,11 +8,12 @@ from dataclasses import dataclass, field from functools import partial from pathlib import Path -from typing import Optional, Union +from typing import List, Optional, Union import bitsandbytes as bnb import numpy as np import torch.cuda +import torch.distributed as dist import transformers from datasets import Dataset, set_caching_enabled from torch import nn @@ -31,6 +32,7 @@ ) from axolotl.utils.collators import DataCollatorForSeq2Seq from axolotl.utils.dataloader import MultipackDistributedDataloader +from axolotl.utils.distributed import is_distributed, reduce_and_broadcast from axolotl.utils.schedulers import ( InterpolatingLogScheduler, get_cosine_schedule_with_quadratic_warmup, @@ -331,7 +333,16 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer): f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}" ) else: - sampler = RandomSampler(train_dataset) + if cfg.world_size > 1 and is_distributed(): + sampler = DistributedSampler( + train_dataset, + num_replicas=cfg.world_size, + rank=dist.get_rank(), + seed=cfg.seed or 42, + ) + else: + sampler = RandomSampler(train_dataset) + data_loader = MultipackDistributedDataloader( train_dataset, batch_size=cfg.micro_batch_size, @@ -349,18 +360,21 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer): data_loader_len = data_loader.len_w_stats() actual_eff = data_loader.efficiency() LOG.info(f"data_loader_len: {data_loader_len}") - total_num_steps = int( - math.floor( - data_loader_len - * cfg.micro_batch_size - * cfg.num_epochs - // cfg.batch_size - ) + total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs)) + + def calc_sample_packing_eff_est(estimates: List[float]): + LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}") + return max(estimates) + + sample_packing_eff_est = reduce_and_broadcast( + lambda: math.ceil(actual_eff * 100.0) / 100.0, + calc_sample_packing_eff_est, ) + sample_packing_eff_est = math.ceil(sample_packing_eff_est * 100.0) / 100.0 LOG.info( - f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`" + f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {sample_packing_eff_est}`" ) - cfg.sample_packing_eff_est = math.ceil(actual_eff * 100.0) / 100.0 + cfg.sample_packing_eff_est = sample_packing_eff_est else: total_num_steps = int( math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)