diff --git a/llmfoundry/models/utils/config_moe_args.py b/llmfoundry/models/utils/config_moe_args.py
index 4bbb246613..4de9a47bbc 100644
--- a/llmfoundry/models/utils/config_moe_args.py
+++ b/llmfoundry/models/utils/config_moe_args.py
@@ -13,6 +13,27 @@
 from llmfoundry.models.layers.ffn import resolve_ffn_hidden_size
 
 
+def create_process_group_ranks(ranks: tuple[int]):
+    """Creates a new distributed group.
+
+    Used in create_set_process_group and create_mod_process_group methods below.
+
+    This function is an alternative to `distributed.new_group(ranks)`.
+
+    Args:
+        ranks (tuple[int]): Tuple of ranks of group members.
+
+    Returns:
+        A handle of distributed group that can be given to collective calls.
+    """
+    ranks_gather_list = [None for _ in range(distributed.get_world_size())]
+    distributed.all_gather_object(ranks_gather_list, ranks)
+    ranks_per_subgroup = list(set(ranks_gather_list))
+    group, _ = distributed.distributed_c10d.new_subgroups_by_enumeration(
+        ranks_per_subgroup)
+    return group
+
+
 def create_set_process_group(k: int):
     """Creates a new distributed group using sets of k GPUs.
 
@@ -34,7 +55,7 @@ def create_set_process_group(k: int):
         raise RuntimeError(f'{world_size=} must be divisible by {k=}.')
     start = distributed.get_rank() // k * k
     ranks = tuple(range(start, start + k))
-    return distributed.new_group(ranks)
+    return create_process_group_ranks(ranks)
 
 
 def config_megablocks_moe_args(
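
Context on the change: `torch.distributed.new_group` is a collective that every process must enter with an identical `ranks` argument, but in `create_set_process_group` each rank computes only the ranks of its own set. The new helper appears to restore that collective invariant: `all_gather_object` collects every rank's tuple, deduplication recovers the full partition, and `distributed_c10d.new_subgroups_by_enumeration` creates all subgroups on all ranks and returns the one containing the current rank.

The rank arithmetic and the gather-then-deduplicate step can be checked without launching any processes. The following is a minimal single-process sketch (illustrative only, not part of the diff), assuming a hypothetical world_size=8 and k=2:

    # Illustrative sketch (not part of the diff): mirrors the rank arithmetic
    # of create_set_process_group and the dedup step of
    # create_process_group_ranks without torch.distributed, for a hypothetical
    # world_size=8 and k=2.
    def set_group_ranks(rank: int, k: int) -> tuple[int, ...]:
        # Same arithmetic as the patched function: round the rank down to
        # the nearest multiple of k, then take the next k consecutive ranks.
        start = rank // k * k
        return tuple(range(start, start + k))

    world_size, k = 8, 2
    if world_size % k != 0:
        raise RuntimeError(f'{world_size=} must be divisible by {k=}.')

    # Stand-in for distributed.all_gather_object: every rank contributes its
    # own tuple, and deduplication recovers the full partition that
    # new_subgroups_by_enumeration must see identically on every rank.
    # (sorted() is only for deterministic printing; the patch uses list(set(...)).)
    gathered = [set_group_ranks(rank, k) for rank in range(world_size)]
    ranks_per_subgroup = sorted(set(gathered))
    print(ranks_per_subgroup)  # [(0, 1), (2, 3), (4, 5), (6, 7)]

With world_size=8 and k=2, every rank ends up passing the same four-subgroup partition to the enumeration call, whereas the old `distributed.new_group(ranks)` call gave each rank a different two-element `ranks` argument.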