Commit 7e2900d

Remove MegaBlocks weight parallelism (moe_weight_parallelism) support
snarayan21 committed Aug 31, 2024
1 parent f4d01fd commit 7e2900d
Showing 6 changed files with 11 additions and 59 deletions.
1 change: 0 additions & 1 deletion llmfoundry/models/layers/ffn.py
@@ -397,7 +397,6 @@ def attach_ffn_mb_args(
     """
     ffn.experts.mlp.hidden_size = args.ffn_hidden_size
     ffn.experts.mlp.expert_parallel_group = expert_parallel_group
-    ffn.experts.mlp.weight_parallel_group = args.weight_parallel_group
 
 
 def get_fsdp_submesh_2d(device_mesh: DeviceMesh):
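
With the weight-parallel group removed, attach_ffn_mb_args only attaches the hidden size and the expert-parallel group. A minimal sketch of the resulting behaviour (the ffn/args attribute names follow the diff above; the surrounding MegaBlocks setup is assumed):

# Sketch only: the attachment step after this change, assuming `ffn` is a
# MegaBlocks MoE module and `args` is a megablocks Arguments object, as above.
def attach_ffn_mb_args_sketch(ffn, expert_parallel_group, args):
    ffn.experts.mlp.hidden_size = args.ffn_hidden_size
    ffn.experts.mlp.expert_parallel_group = expert_parallel_group
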
10 changes: 0 additions & 10 deletions llmfoundry/models/utils/mpt_param_count.py
@@ -62,13 +62,6 @@ def megablocks_n_total_params(mpt_model) -> int:  # type: ignore
 
     moe_world_size = mpt_model.config.ffn_config.get('moe_world_size')
 
-    if mpt_model.config.ffn_config.get('moe_weight_parallelism', False):
-        # If MegaBlocks shards experts, the total sharding world size
-        # must be increased by the degree to which MegaBlocks shards the
-        # experts.
-        mb_args = mpt_model.model.transformer.mb_args
-        moe_world_size *= mb_args.weight_parallel_group.size()
-
     n_total_params = 0
     for module in mpt_model.modules():
         if isinstance(
@@ -109,9 +102,6 @@ def megablocks_n_active_params(mpt_model) -> int:  # type: ignore
     moe_world_size = mpt_model.config.ffn_config.get('moe_world_size')
 
     local_experts = moe_num_experts / moe_world_size  # if local_experts is < 1, then the expert is sharded
-    if mpt_model.config.ffn_config.get('moe_weight_parallelism', False):
-        mb_args = mpt_model.model.transformer.mb_args
-        local_experts /= mb_args.weight_parallel_group.size()
 
     moe_top_k = mpt_model.config.ffn_config.get('moe_top_k', 1)
     n_active_params = 0
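
With weight parallelism gone, the counting logic above depends only on the expert-parallel world size: total expert parameters are the local shard scaled by moe_world_size, and active parameters follow moe_top_k. A rough worked example with made-up numbers (not taken from any real config):

# Hypothetical numbers purely to illustrate the counting logic above.
moe_num_experts = 4        # experts in the model
moe_world_size = 2         # expert-parallel degree
moe_top_k = 2              # experts routed per token
params_per_expert = 1_000  # placeholder size of one expert's FFN

local_experts = moe_num_experts / moe_world_size               # 2.0 experts per rank
local_expert_params = int(local_experts) * params_per_expert   # 2_000 on this rank
n_total_expert_params = local_expert_params * moe_world_size   # 4_000 across all ranks
n_active_expert_params = moe_top_k * params_per_expert         # 2_000 used per token
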
54 changes: 11 additions & 43 deletions llmfoundry/models/utils/param_init_fns.py
@@ -484,19 +484,12 @@ def _megablocks_sparse_mlp_generic_param_init_fn_(
         div_is_residual (float): The value by which parameter initialization is divided
             if init_div_is_residual flag is enabled.
     """
-    expert_process_group_size, rank, weight_parallel_group_size, weight_parallel_group_rank = 1, 0, 1, 0
+    expert_process_group_size, rank = 1, 0
     if module.expert_parallel_group is not None:
         expert_process_group_size = int(
             module.expert_parallel_group.size(),
         )  # type: ignore
         rank = int(module.expert_parallel_group.rank())  # type: ignore
-    if module.weight_parallel_group is not None:
-        weight_parallel_group_size = int(
-            module.weight_parallel_group.size(),
-        )  # type: ignore
-        weight_parallel_group_rank = int(
-            module.weight_parallel_group.rank(),
-        )  # type: ignore
 
     hidden_size = int(module.hidden_size)  # type: ignore
 
@@ -505,35 +498,29 @@ def _megablocks_sparse_mlp_generic_param_init_fn_(
     if isinstance(w1, DTensor):
         w1 = w1._local_tensor
     w1_size = list(w1.shape)  # type: ignore
-    w1_size[
-        0] = w1_size[0] * expert_process_group_size * weight_parallel_group_size
+    w1_size[0] = w1_size[0] * expert_process_group_size
 
     n_exp = w1_size[0] // hidden_size
     _fused = (0, [(n + 1) * hidden_size for n in range(n_exp - 1)])
 
     _w1 = w1.new_empty(w1_size)  # type: ignore
     fused_param_init_helper(_w1, init_fn_, _fused)
     _w1_local = _w1.chunk(expert_process_group_size, dim=0)[rank]
-    _w1_local_slice = _w1_local.chunk(weight_parallel_group_size,
-                                      dim=0)[weight_parallel_group_rank]
     with torch.no_grad():
-        w1.copy_(_w1_local_slice)  # type: ignore
+        w1.copy_(_w1_local)  # type: ignore
 
     # Initialize w2
     w2 = module.w2
     if isinstance(w2, DTensor):
         w2 = w2._local_tensor
     w2_size = list(w2.shape)  # type: ignore
-    w2_size[
-        0] = w2_size[0] * expert_process_group_size * weight_parallel_group_size
+    w2_size[0] = w2_size[0] * expert_process_group_size
     _w2 = w2.new_empty(w2_size)  # type: ignore
     # MegaBlocks operates on w2 as x @ w2, so needs flipped fan mode
     fused_param_init_helper(_w2, _flip_fan_mode(init_fn_), _fused)
     _w2_local = _w2.chunk(expert_process_group_size, dim=0)[rank]
-    _w2_local_slice = _w2_local.chunk(weight_parallel_group_size,
-                                      dim=0)[weight_parallel_group_rank]
     with torch.no_grad():
-        w2.copy_(_w2_local_slice)  # type: ignore
+        w2.copy_(_w2_local)  # type: ignore
     if init_div_is_residual is not False:
         with torch.no_grad():
             w2.div_(div_is_residual)  # type: ignore
@@ -567,19 +554,12 @@ def _megablocks_sparse_glu_generic_param_init_fn_(
     )
 
     # Init ported from _megablocks_sparse_mlp_generic_param_init_fn_ for v1
-    expert_process_group_size, rank, weight_parallel_group_size, weight_parallel_group_rank = 1, 0, 1, 0
+    expert_process_group_size, rank = 1, 0
     if module.expert_parallel_group is not None:
         expert_process_group_size = int(
             module.expert_parallel_group.size(),
         )  # type: ignore
         rank = int(module.expert_parallel_group.rank())  # type: ignore
-    if module.weight_parallel_group is not None:
-        weight_parallel_group_size = int(
-            module.weight_parallel_group.size(),
-        )  # type: ignore
-        weight_parallel_group_rank = int(
-            module.weight_parallel_group.rank(),
-        )  # type: ignore
 
     hidden_size = int(module.hidden_size)  # type: ignore
 
@@ -588,19 +568,16 @@ def _megablocks_sparse_glu_generic_param_init_fn_(
     if isinstance(v1, DTensor):
         v1 = v1._local_tensor
     v1_size = list(v1.shape)  # type: ignore
-    v1_size[
-        0] = v1_size[0] * expert_process_group_size * weight_parallel_group_size
+    v1_size[0] = v1_size[0] * expert_process_group_size
 
     n_exp = v1_size[0] // hidden_size
     _fused = (0, [(n + 1) * hidden_size for n in range(n_exp - 1)])
 
     _v1 = v1.new_empty(v1_size)  # type: ignore
     fused_param_init_helper(_v1, init_fn_, _fused)
     _v1_local = _v1.chunk(expert_process_group_size, dim=0)[rank]
-    _v1_local_slice = _v1_local.chunk(weight_parallel_group_size,
-                                      dim=0)[weight_parallel_group_rank]
     with torch.no_grad():
-        v1.copy_(_v1_local_slice)  # type: ignore
+        v1.copy_(_v1_local)  # type: ignore
 
 
 def _megablocks_mlp_generic_param_init_fn_(
@@ -623,41 +600,32 @@ def _megablocks_mlp_generic_param_init_fn_(
         div_is_residual (float): The value by which parameter initialization is divided
             if init_div_is_residual flag is enabled.
     """
-    expert_process_group_size, rank, weight_parallel_group_size, w_rank = 1, 0, 1, 0
+    expert_process_group_size, rank = 1, 0
     if module.expert_parallel_group is not None:
         expert_process_group_size = int(
             module.expert_parallel_group.size(),
         )  # type: ignore
         rank = int(module.expert_parallel_group.rank())  # type: ignore
-    if module.weight_parallel_group is not None:
-        weight_parallel_group_size = int(
-            module.weight_parallel_group.size(),
-        )  # type: ignore
-        w_rank = int(module.weight_parallel_group.rank())  # type: ignore
 
     _init_fn_ = _flip_fan_mode(init_fn_)
 
     # Initialize w1
     w1_size = list(module.w1.shape)  # type: ignore
     w1_size[0] = w1_size[0] * expert_process_group_size
-    w1_size[1] = w1_size[1] * weight_parallel_group_size
     _w1 = module.w1.new_empty(w1_size)  # type: ignore
     stacked_param_init_helper(_w1, _init_fn_, module._stack_dim)  # type: ignore
     _w1_local = _w1.chunk(expert_process_group_size, dim=0)[rank]
-    _w1_local_slice = _w1_local.chunk(weight_parallel_group_size, dim=1)[w_rank]
     with torch.no_grad():
-        module.w1.copy_(_w1_local_slice)  # type: ignore
+        module.w1.copy_(_w1_local)  # type: ignore
 
     # Initialize w2
     w2_size = list(module.w2.shape)  # type: ignore
     w2_size[0] = w2_size[0] * expert_process_group_size
-    w2_size[1] = w2_size[1] * weight_parallel_group_size
     _w2 = module.w2.new_empty(w2_size)  # type: ignore
     stacked_param_init_helper(_w2, _init_fn_, module._stack_dim)  # type: ignore
     _w2_local = _w2.chunk(expert_process_group_size, dim=0)[rank]
-    _w2_local_slice = _w2_local.chunk(weight_parallel_group_size, dim=1)[w_rank]
     with torch.no_grad():
-        module.w2.copy_(_w2_local_slice)  # type: ignore
+        module.w2.copy_(_w2_local)  # type: ignore
     if init_div_is_residual is not False:
         with torch.no_grad():
             module.w2.div_(div_is_residual)  # type: ignore
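
Every initializer above now follows the same pattern: materialize the full, unsharded expert weight, run the deterministic init over it, then copy back only this rank's expert-parallel chunk, so all ranks reproduce the same global initialization. A standalone sketch of that idea (the function name and default init are illustrative, not the library's API):

import torch

def init_local_expert_shard(
    local_weight: torch.Tensor,  # this rank's shard, experts stacked along dim 0
    expert_parallel_size: int,
    expert_parallel_rank: int,
    init_fn_=torch.nn.init.xavier_uniform_,  # placeholder init function
) -> None:
    # Rebuild the full (unsharded) shape by undoing the expert-parallel split.
    full_size = list(local_weight.shape)
    full_size[0] *= expert_parallel_size
    full_weight = local_weight.new_empty(full_size)
    # Deterministically initialize all experts, then keep only the local chunk.
    init_fn_(full_weight)
    local_chunk = full_weight.chunk(expert_parallel_size, dim=0)[expert_parallel_rank]
    with torch.no_grad():
        local_weight.copy_(local_chunk)

Because each rank initializes the identical full tensor before slicing, the sharded result matches what a single device would produce.
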
1 change: 0 additions & 1 deletion scripts/train/yamls/pretrain/testing-moe.yaml
@@ -23,7 +23,6 @@ model:
     moe_num_experts: 4
     moe_top_k: 2
     moe_world_size: 1
-    moe_weight_parallelism: false
     uniform_expert_assignment: false
   n_heads: 2
   n_layers: 2
3 changes: 0 additions & 3 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -519,7 +519,6 @@ def _get_model_and_tokenizer(
             'moe_num_experts': 4,
             'moe_top_k': 2,
             'moe_world_size': 1,
-            'moe_weight_parallelism': False,
             'uniform_expert_assignment': False,
         },
         'max_seq_len': max_seq_len,
@@ -1251,8 +1250,6 @@ def test_mptmoe_huggingface_conversion_callback(
                 2,
             'moe_world_size':
                 2,
-            'moe_weight_parallelism':
-                False,
             'uniform_expert_assignment':
                 True,
             'mlp_impl':
1 change: 0 additions & 1 deletion tests/models/test_mpt_gen.py
@@ -190,7 +190,6 @@ def test_gen_mpt_moe(
             'moe_num_experts': 4,
             'moe_top_k': 2,
             'moe_world_size': 1,
-            'moe_weight_parallelism': False,
             'uniform_expert_assignment': False,
         },
     )
