Skip to content

Commit

Permalink
Revert "Update config_moe_args.py (#1104)" (#1111)
Browse files Browse the repository at this point in the history
This reverts commit 17f8aeb.
  • Loading branch information
vchiley authored Apr 12, 2024
1 parent e9b1c6e commit b58d68c
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion llmfoundry/models/utils/config_moe_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,33 @@
from llmfoundry.models.layers.ffn import resolve_ffn_hidden_size


def create_process_group_ranks(ranks: tuple[int]):
"""Creates a new distributed group.
Used in create_set_process_group and create_mod_process_group methods below.
This function is an alternative to `distributed.new_group(ranks)`.
When working with FSDP in torch1.13.1, using `distributed.new_group(ranks)`
resulted in an error but this method worked.
TODO(GRT-2416): When composer no longer has support for torch1.13.1, we should
consider using `distributed.new_group(ranks)` here and in composer's FSDP
custom process group init.
Args:
ranks (tuple[int]): Tuple of ranks of group members.
Returns:
A handle of distributed group that can be given to collective calls.
"""
ranks_gather_list = [None for _ in range(distributed.get_world_size())]
distributed.all_gather_object(ranks_gather_list, ranks)
ranks_per_subgroup = list(set(ranks_gather_list))
group, _ = distributed.distributed_c10d.new_subgroups_by_enumeration(
ranks_per_subgroup)
return group


def create_set_process_group(k: int):
"""Creates a new distributed group using sets of k GPUs.
Expand All @@ -33,7 +60,7 @@ def create_set_process_group(k: int):
raise RuntimeError(f'{world_size=} must be divisible by {k=}.')
start = distributed.get_rank() // k * k
ranks = tuple(range(start, start + k))
return distributed.new_group(ranks)
return create_process_group_ranks(ranks)


def config_megablocks_moe_args(
Expand Down

0 comments on commit b58d68c

Please sign in to comment.