Skip to content

Commit

Permalink
Merge branch 'main' into milo/foundry-type-cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
milocress authored Apr 10, 2024
2 parents 6a3d43a + 17f8aeb commit 7f3d913
Showing 1 changed file with 1 addition and 28 deletions.
29 changes: 1 addition & 28 deletions llmfoundry/models/utils/config_moe_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,6 @@
from llmfoundry.models.layers.ffn import resolve_ffn_hidden_size


def create_process_group_ranks(ranks: tuple[int]):
"""Creates a new distributed group.
Used in create_set_process_group and create_mod_process_group methods below.
This function is an alternative to `distributed.new_group(ranks)`.
When working with FSDP in torch1.13.1, using `distributed.new_group(ranks)`
resulted in an error but this method worked.
TODO(GRT-2416): When composer no longer has support for torch1.13.1, we should
consider using `distributed.new_group(ranks)` here and in composer's FSDP
custom process group init.
Args:
ranks (tuple[int]): Tuple of ranks of group members.
Returns:
A handle of distributed group that can be given to collective calls.
"""
ranks_gather_list = [None for _ in range(distributed.get_world_size())]
distributed.all_gather_object(ranks_gather_list, ranks)
ranks_per_subgroup = list(set(ranks_gather_list))
group, _ = distributed.distributed_c10d.new_subgroups_by_enumeration(
ranks_per_subgroup)
return group


def create_set_process_group(k: int):
"""Creates a new distributed group using sets of k GPUs.
Expand All @@ -60,7 +33,7 @@ def create_set_process_group(k: int):
raise RuntimeError(f'{world_size=} must be divisible by {k=}.')
start = distributed.get_rank() // k * k
ranks = tuple(range(start, start + k))
return create_process_group_ranks(ranks)
return distributed.new_group(ranks)


def config_megablocks_moe_args(
Expand Down

0 comments on commit 7f3d913

Please sign in to comment.