diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py
index 38ca253f80..5b5a6b1449 100644
--- a/llmfoundry/models/layers/ffn.py
+++ b/llmfoundry/models/layers/ffn.py
@@ -397,7 +397,6 @@ def attach_ffn_mb_args(
     """
     ffn.experts.mlp.hidden_size = args.ffn_hidden_size
     ffn.experts.mlp.expert_parallel_group = expert_parallel_group
-    ffn.experts.mlp.weight_parallel_group = args.weight_parallel_group


 def get_fsdp_submesh_2d(device_mesh: DeviceMesh):
diff --git a/llmfoundry/models/utils/mpt_param_count.py b/llmfoundry/models/utils/mpt_param_count.py
index d7b61354c7..bd8f279ad5 100644
--- a/llmfoundry/models/utils/mpt_param_count.py
+++ b/llmfoundry/models/utils/mpt_param_count.py
@@ -62,13 +62,6 @@ def megablocks_n_total_params(mpt_model) -> int:  # type: ignore
     moe_world_size = mpt_model.config.ffn_config.get('moe_world_size')

-    if mpt_model.config.ffn_config.get('moe_weight_parallelism', False):
-        # If MegaBlocks shards experts, the total sharding world size
-        # must be increased by the degree to which MegaBlocks shards the
-        # experts.
-        mb_args = mpt_model.model.transformer.mb_args
-        moe_world_size *= mb_args.weight_parallel_group.size()
-
     n_total_params = 0
     for module in mpt_model.modules():
         if isinstance(
@@ -109,9 +102,6 @@ def megablocks_n_active_params(mpt_model) -> int:  # type: ignore
     moe_world_size = mpt_model.config.ffn_config.get('moe_world_size')
     local_experts = moe_num_experts / moe_world_size  # if local_experts is < 1, then the expert is sharded
-    if mpt_model.config.ffn_config.get('moe_weight_parallelism', False):
-        mb_args = mpt_model.model.transformer.mb_args
-        local_experts /= mb_args.weight_parallel_group.size()

     moe_top_k = mpt_model.config.ffn_config.get('moe_top_k', 1)
     n_active_params = 0
diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py
index 9941c2d049..180e7b894c 100644
--- a/llmfoundry/models/utils/param_init_fns.py
+++ b/llmfoundry/models/utils/param_init_fns.py
@@ -484,19 +484,12 @@ def _megablocks_sparse_mlp_generic_param_init_fn_(
         div_is_residual (float): The value by which parameter initialization is divided
             if init_div_is_residual flag is enabled.
""" - expert_process_group_size, rank, weight_parallel_group_size, weight_parallel_group_rank = 1, 0, 1, 0 + expert_process_group_size, rank = 1, 0 if module.expert_parallel_group is not None: expert_process_group_size = int( module.expert_parallel_group.size(), ) # type: ignore rank = int(module.expert_parallel_group.rank()) # type: ignore - if module.weight_parallel_group is not None: - weight_parallel_group_size = int( - module.weight_parallel_group.size(), - ) # type: ignore - weight_parallel_group_rank = int( - module.weight_parallel_group.rank(), - ) # type: ignore hidden_size = int(module.hidden_size) # type: ignore @@ -505,8 +498,7 @@ def _megablocks_sparse_mlp_generic_param_init_fn_( if isinstance(w1, DTensor): w1 = w1._local_tensor w1_size = list(w1.shape) # type: ignore - w1_size[ - 0] = w1_size[0] * expert_process_group_size * weight_parallel_group_size + w1_size[0] = w1_size[0] * expert_process_group_size n_exp = w1_size[0] // hidden_size _fused = (0, [(n + 1) * hidden_size for n in range(n_exp - 1)]) @@ -514,26 +506,21 @@ def _megablocks_sparse_mlp_generic_param_init_fn_( _w1 = w1.new_empty(w1_size) # type: ignore fused_param_init_helper(_w1, init_fn_, _fused) _w1_local = _w1.chunk(expert_process_group_size, dim=0)[rank] - _w1_local_slice = _w1_local.chunk(weight_parallel_group_size, - dim=0)[weight_parallel_group_rank] with torch.no_grad(): - w1.copy_(_w1_local_slice) # type: ignore + w1.copy_(_w1_local) # type: ignore # Initialize w2 w2 = module.w2 if isinstance(w2, DTensor): w2 = w2._local_tensor w2_size = list(w2.shape) # type: ignore - w2_size[ - 0] = w2_size[0] * expert_process_group_size * weight_parallel_group_size + w2_size[0] = w2_size[0] * expert_process_group_size _w2 = w2.new_empty(w2_size) # type: ignore # MegaBlocks operates on w2 as x @ w2, so needs flipped fan mode fused_param_init_helper(_w2, _flip_fan_mode(init_fn_), _fused) _w2_local = _w2.chunk(expert_process_group_size, dim=0)[rank] - _w2_local_slice = _w2_local.chunk(weight_parallel_group_size, - dim=0)[weight_parallel_group_rank] with torch.no_grad(): - w2.copy_(_w2_local_slice) # type: ignore + w2.copy_(_w2_local) # type: ignore if init_div_is_residual is not False: with torch.no_grad(): w2.div_(div_is_residual) # type: ignore @@ -567,19 +554,12 @@ def _megablocks_sparse_glu_generic_param_init_fn_( ) # Init ported from _megablocks_sparse_mlp_generic_param_init_fn_ for v1 - expert_process_group_size, rank, weight_parallel_group_size, weight_parallel_group_rank = 1, 0, 1, 0 + expert_process_group_size, rank = 1, 0 if module.expert_parallel_group is not None: expert_process_group_size = int( module.expert_parallel_group.size(), ) # type: ignore rank = int(module.expert_parallel_group.rank()) # type: ignore - if module.weight_parallel_group is not None: - weight_parallel_group_size = int( - module.weight_parallel_group.size(), - ) # type: ignore - weight_parallel_group_rank = int( - module.weight_parallel_group.rank(), - ) # type: ignore hidden_size = int(module.hidden_size) # type: ignore @@ -588,8 +568,7 @@ def _megablocks_sparse_glu_generic_param_init_fn_( if isinstance(v1, DTensor): v1 = v1._local_tensor v1_size = list(v1.shape) # type: ignore - v1_size[ - 0] = v1_size[0] * expert_process_group_size * weight_parallel_group_size + v1_size[0] = v1_size[0] * expert_process_group_size n_exp = v1_size[0] // hidden_size _fused = (0, [(n + 1) * hidden_size for n in range(n_exp - 1)]) @@ -597,10 +576,8 @@ def _megablocks_sparse_glu_generic_param_init_fn_( _v1 = v1.new_empty(v1_size) # type: ignore 
     fused_param_init_helper(_v1, init_fn_, _fused)
     _v1_local = _v1.chunk(expert_process_group_size, dim=0)[rank]
-    _v1_local_slice = _v1_local.chunk(weight_parallel_group_size,
-                                      dim=0)[weight_parallel_group_rank]
     with torch.no_grad():
-        v1.copy_(_v1_local_slice)  # type: ignore
+        v1.copy_(_v1_local)  # type: ignore


 def _megablocks_mlp_generic_param_init_fn_(
@@ -623,41 +600,32 @@ def _megablocks_mlp_generic_param_init_fn_(
         div_is_residual (float): The value by which parameter initialization is divided
             if init_div_is_residual flag is enabled.
     """
-    expert_process_group_size, rank, weight_parallel_group_size, w_rank = 1, 0, 1, 0
+    expert_process_group_size, rank = 1, 0
     if module.expert_parallel_group is not None:
         expert_process_group_size = int(
             module.expert_parallel_group.size(),
         )  # type: ignore
         rank = int(module.expert_parallel_group.rank())  # type: ignore
-    if module.weight_parallel_group is not None:
-        weight_parallel_group_size = int(
-            module.weight_parallel_group.size(),
-        )  # type: ignore
-        w_rank = int(module.weight_parallel_group.rank())  # type: ignore

     _init_fn_ = _flip_fan_mode(init_fn_)

     # Initialize w1
     w1_size = list(module.w1.shape)  # type: ignore
     w1_size[0] = w1_size[0] * expert_process_group_size
-    w1_size[1] = w1_size[1] * weight_parallel_group_size
     _w1 = module.w1.new_empty(w1_size)  # type: ignore
     stacked_param_init_helper(_w1, _init_fn_, module._stack_dim)  # type: ignore
     _w1_local = _w1.chunk(expert_process_group_size, dim=0)[rank]
-    _w1_local_slice = _w1_local.chunk(weight_parallel_group_size, dim=1)[w_rank]
     with torch.no_grad():
-        module.w1.copy_(_w1_local_slice)  # type: ignore
+        module.w1.copy_(_w1_local)  # type: ignore

     # Initialize w2
     w2_size = list(module.w2.shape)  # type: ignore
     w2_size[0] = w2_size[0] * expert_process_group_size
-    w2_size[1] = w2_size[1] * weight_parallel_group_size
     _w2 = module.w2.new_empty(w2_size)  # type: ignore
     stacked_param_init_helper(_w2, _init_fn_, module._stack_dim)  # type: ignore
     _w2_local = _w2.chunk(expert_process_group_size, dim=0)[rank]
-    _w2_local_slice = _w2_local.chunk(weight_parallel_group_size, dim=1)[w_rank]
     with torch.no_grad():
-        module.w2.copy_(_w2_local_slice)  # type: ignore
+        module.w2.copy_(_w2_local)  # type: ignore
     if init_div_is_residual is not False:
         with torch.no_grad():
             module.w2.div_(div_is_residual)  # type: ignore
diff --git a/scripts/train/yamls/pretrain/testing-moe.yaml b/scripts/train/yamls/pretrain/testing-moe.yaml
index e61e3e451e..ee9483ffd0 100644
--- a/scripts/train/yamls/pretrain/testing-moe.yaml
+++ b/scripts/train/yamls/pretrain/testing-moe.yaml
@@ -23,7 +23,6 @@ model:
    moe_num_experts: 4
    moe_top_k: 2
    moe_world_size: 1
-    moe_weight_parallelism: false
    uniform_expert_assignment: false
  n_heads: 2
  n_layers: 2
diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py
index 260988dc31..b863e1d0a8 100644
--- a/tests/a_scripts/inference/test_convert_composer_to_hf.py
+++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -519,7 +519,6 @@ def _get_model_and_tokenizer(
             'moe_num_experts': 4,
             'moe_top_k': 2,
             'moe_world_size': 1,
-            'moe_weight_parallelism': False,
             'uniform_expert_assignment': False,
         },
         'max_seq_len': max_seq_len,
@@ -1251,8 +1250,6 @@ def test_mptmoe_huggingface_conversion_callback(
                     2,
                 'moe_world_size':
                     2,
-                'moe_weight_parallelism':
-                    False,
                 'uniform_expert_assignment':
                     True,
                 'mlp_impl':
diff --git a/tests/models/test_mpt_gen.py b/tests/models/test_mpt_gen.py
index 134ca35ec0..820da5e71f 100644
--- a/tests/models/test_mpt_gen.py
+++ b/tests/models/test_mpt_gen.py
@@ -190,7 +190,6 @@ def test_gen_mpt_moe(
             'moe_num_experts': 4,
             'moe_top_k': 2,
             'moe_world_size': 1,
-            'moe_weight_parallelism': False,
             'uniform_expert_assignment': False,
         },
     )
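
Note (not part of the patch): after this change, MegaBlocks expert weights in the param init helpers are sharded only along the expert-parallel dimension; the extra chunk over a weight-parallel group is gone. The sketch below illustrates that remaining pattern with plain ints standing in for `module.expert_parallel_group`, `torch.nn.init.kaiming_normal_` standing in for the configured `init_fn_`, and a hypothetical helper name `init_expert_weight`.

```python
# Minimal sketch of the expert-parallel-only init pattern (illustrative only).
import torch


def init_expert_weight(
    w_local: torch.Tensor,
    expert_process_group_size: int = 1,  # stand-in for module.expert_parallel_group.size()
    rank: int = 0,                       # stand-in for module.expert_parallel_group.rank()
) -> None:
    """Initialize the full (unsharded) weight, then keep only this rank's chunk."""
    full_size = list(w_local.shape)
    # dim 0 holds the stacked experts, so the full tensor is group_size times larger
    full_size[0] = full_size[0] * expert_process_group_size
    _w = w_local.new_empty(full_size)
    torch.nn.init.kaiming_normal_(_w)  # stand-in for the configured init_fn_
    # single chunk over the expert-parallel group; no weight-parallel slice anymore
    _w_local = _w.chunk(expert_process_group_size, dim=0)[rank]
    with torch.no_grad():
        w_local.copy_(_w_local)


# Example: 2-way expert parallelism over a local (8, 4) shard.
w = torch.empty(8, 4)
init_expert_weight(w, expert_process_group_size=2, rank=1)
```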