diff --git a/llmfoundry/models/utils/config_moe_args.py b/llmfoundry/models/utils/config_moe_args.py
index b69cd18348..3386204e26 100644
--- a/llmfoundry/models/utils/config_moe_args.py
+++ b/llmfoundry/models/utils/config_moe_args.py
@@ -12,6 +12,33 @@
 from llmfoundry.models.layers.ffn import resolve_ffn_hidden_size
 
 
+def create_process_group_ranks(ranks: tuple[int]):
+    """Creates a new distributed group.
+
+    Used in create_set_process_group and create_mod_process_group methods below.
+
+    This function is an alternative to `distributed.new_group(ranks)`.
+    When working with FSDP in torch1.13.1, using `distributed.new_group(ranks)`
+    resulted in an error but this method worked.
+
+    TODO(GRT-2416): When composer no longer has support for torch1.13.1, we should
+    consider using `distributed.new_group(ranks)` here and in composer's FSDP
+    custom process group init.
+
+    Args:
+        ranks (tuple[int]): Tuple of ranks of group members.
+
+    Returns:
+        A handle of distributed group that can be given to collective calls.
+    """
+    ranks_gather_list = [None for _ in range(distributed.get_world_size())]
+    distributed.all_gather_object(ranks_gather_list, ranks)
+    ranks_per_subgroup = list(set(ranks_gather_list))
+    group, _ = distributed.distributed_c10d.new_subgroups_by_enumeration(
+        ranks_per_subgroup)
+    return group
+
+
 def create_set_process_group(k: int):
     """Creates a new distributed group using sets of k GPUs.
 
@@ -33,7 +60,7 @@ def create_set_process_group(k: int):
         raise RuntimeError(f'{world_size=} must be divisible by {k=}.')
     start = distributed.get_rank() // k * k
     ranks = tuple(range(start, start + k))
-    return distributed.new_group(ranks)
+    return create_process_group_ranks(ranks)
 
 
 def config_megablocks_moe_args(
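For context, here is a minimal sketch (not part of the patch) of how the new code path might be exercised end to end: each rank computes its contiguous set of `k` ranks, `create_process_group_ranks` all-gathers every rank's tuple, deduplicates them, and builds all subgroups at once via `new_subgroups_by_enumeration`. The `gloo` backend, the hard-coded master address/port, the `world_size=4, k=2` choice, and the toy `all_reduce` check are illustrative assumptions, not part of the PR.

```python
# Sketch only: exercises create_set_process_group on a 4-process CPU job.
import os

import torch
import torch.distributed as distributed
import torch.multiprocessing as mp

from llmfoundry.models.utils.config_moe_args import create_set_process_group


def worker(rank: int, world_size: int, k: int):
    # Assumed single-node rendezvous settings for this toy example.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    # gloo keeps the sketch CPU-only; NCCL would be used on GPUs in practice.
    distributed.init_process_group('gloo', rank=rank, world_size=world_size)

    # Ranks are grouped into contiguous sets of k, e.g. {0, 1} and {2, 3} for k=2.
    group = create_set_process_group(k)

    # Sum the global ranks inside each set to show the sets are disjoint.
    t = torch.tensor([float(rank)])
    distributed.all_reduce(t, group=group)
    print(f'global rank {rank}: sum over its set of {k} ranks = {t.item()}')

    distributed.destroy_process_group()


if __name__ == '__main__':
    world_size, k = 4, 2  # world_size must be divisible by k
    mp.spawn(worker, args=(world_size, k), nprocs=world_size, join=True)
```

The all-gather step is what makes this usable as a stand-in for `distributed.new_group(ranks)`: every rank ends up with the same global partition of ranks, so every rank participates in creating every subgroup, as subgroup creation requires.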