Skip to content

Commit

Permalink
Merge branch 'main' into ffn-registry
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored Apr 12, 2024
2 parents f8d4c8f + 6257e5b commit 03548e8
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion llmfoundry/models/utils/config_moe_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,27 @@
from llmfoundry.models.layers.ffn import resolve_ffn_hidden_size


def create_process_group_ranks(ranks: tuple[int]):
"""Creates a new distributed group.
Used in create_set_process_group and create_mod_process_group methods below.
This function is an alternative to `distributed.new_group(ranks)`.
Args:
ranks (tuple[int]): Tuple of ranks of group members.
Returns:
A handle of distributed group that can be given to collective calls.
"""
ranks_gather_list = [None for _ in range(distributed.get_world_size())]
distributed.all_gather_object(ranks_gather_list, ranks)
ranks_per_subgroup = list(set(ranks_gather_list))
group, _ = distributed.distributed_c10d.new_subgroups_by_enumeration(
ranks_per_subgroup)
return group


def create_set_process_group(k: int):
"""Creates a new distributed group using sets of k GPUs.
Expand All @@ -34,7 +55,7 @@ def create_set_process_group(k: int):
raise RuntimeError(f'{world_size=} must be divisible by {k=}.')
start = distributed.get_rank() // k * k
ranks = tuple(range(start, start + k))
return distributed.new_group(ranks)
return create_process_group_ranks(ranks)


def config_megablocks_moe_args(
Expand Down

0 comments on commit 03548e8

Please sign in to comment.