From 17f8aeb0d3dd85e3f4987d249f6d36efa70e3a63 Mon Sep 17 00:00:00 2001
From: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com>
Date: Tue, 9 Apr 2024 19:30:48 -0700
Subject: [PATCH] Update config_moe_args.py (#1104)

https://databricks.atlassian.net/browse/GRT-2812
see [here](https://github.com/mosaicml/llm-foundry-private/pull/245#issuecomment-2046306845)

Ran 500 steps, new version did marginally better
---
 llmfoundry/models/utils/config_moe_args.py | 29 +---------------------
 1 file changed, 1 insertion(+), 28 deletions(-)

diff --git a/llmfoundry/models/utils/config_moe_args.py b/llmfoundry/models/utils/config_moe_args.py
index 3386204e26..b69cd18348 100644
--- a/llmfoundry/models/utils/config_moe_args.py
+++ b/llmfoundry/models/utils/config_moe_args.py
@@ -12,33 +12,6 @@
 from llmfoundry.models.layers.ffn import resolve_ffn_hidden_size
 
 
-def create_process_group_ranks(ranks: tuple[int]):
-    """Creates a new distributed group.
-
-    Used in create_set_process_group and create_mod_process_group methods below.
-
-    This function is an alternative to `distributed.new_group(ranks)`.
-    When working with FSDP in torch1.13.1, using `distributed.new_group(ranks)`
-    resulted in an error but this method worked.
-
-    TODO(GRT-2416): When composer no longer has support for torch1.13.1, we should
-    consider using `distributed.new_group(ranks)` here and in composer's FSDP
-    custom process group init.
-
-    Args:
-        ranks (tuple[int]): Tuple of ranks of group members.
-
-    Returns:
-        A handle of distributed group that can be given to collective calls.
-    """
-    ranks_gather_list = [None for _ in range(distributed.get_world_size())]
-    distributed.all_gather_object(ranks_gather_list, ranks)
-    ranks_per_subgroup = list(set(ranks_gather_list))
-    group, _ = distributed.distributed_c10d.new_subgroups_by_enumeration(
-        ranks_per_subgroup)
-    return group
-
-
 def create_set_process_group(k: int):
     """Creates a new distributed group using sets of k GPUs.
 
@@ -60,7 +33,7 @@ def create_set_process_group(k: int):
         raise RuntimeError(f'{world_size=} must be divisible by {k=}.')
     start = distributed.get_rank() // k * k
     ranks = tuple(range(start, start + k))
-    return create_process_group_ranks(ranks)
+    return distributed.new_group(ranks)
 
 
 def config_megablocks_moe_args(
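
For context, here is a minimal sketch (not part of the patch) of the rank-set arithmetic that `create_set_process_group` keeps after this change. The helper `set_group_ranks` below is hypothetical and only mirrors the rank computation; in the patched function the resulting tuple is passed directly to `torch.distributed.new_group(ranks)` on an initialized process group.

```python
# Hypothetical helper, for illustration only: shows which ranks end up in the
# same set of size k. The real create_set_process_group computes this tuple and
# then calls torch.distributed.new_group(ranks) on it.
def set_group_ranks(rank: int, world_size: int, k: int) -> tuple[int, ...]:
    """Return the ranks that `rank` would be grouped with, in contiguous sets of k."""
    if world_size < k or world_size % k != 0:
        raise RuntimeError(f'{world_size=} must be divisible by {k=}.')
    start = rank // k * k
    return tuple(range(start, start + k))

# Example: with 8 GPUs and k=4, ranks 0-3 form one group and ranks 4-7 another.
assert set_group_ranks(rank=2, world_size=8, k=4) == (0, 1, 2, 3)
assert set_group_ranks(rank=5, world_size=8, k=4) == (4, 5, 6, 7)
```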