diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index 8bdcf6faa053..96e8b8fafe85 100755
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -15,6 +15,8 @@
 from deepspeed.accelerator import get_accelerator
 from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_value_with_share_qk, shard_chunk_mlp
 from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
+import os
+import ast
 
 
 def move(tensor, device, copy=True):
@@ -279,6 +281,7 @@ def kernel_supported(module_list):
                 return True
         return False
 
+    # TP parser driven by the AutoTP config read from environment variables
     def tp_parser(model):
         policy_list = []
         module_list = []
@@ -289,39 +292,69 @@ def tp_parser(model):
         assert AutoTP.supported(model), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
             if AutoTP.kernel_supported(module_list) else "AutoTP not supported for model. Please provide policy."
         norm_layer_name_list = ['LayerNorm', 'layer_norm', 'ln_1', 'ln_2']
+
+        default_ds_common_reduceLinear_keys = ['out_proj', 'o_proj', 'down_proj']
+        # multiple model keywords for the same key are concatenated with commas
+        predefined_ds_common_reduceLinear_items = {
+            'attention.dense': 'GPTNeoX',
+            'self_attention.dense': 'falcon,ChatGLM',
+            'self_attn.dense': 'Phi',
+            'w2': 'Mixtral',
+            'dense_4h_to_h': 'ChatGLM'
+        }
+        ds_reduceLinear_items = predefined_ds_common_reduceLinear_items
+        # 'DS_ALL_REDUCE_LINEAR_ITEMS' is a dictionary whose keys are layer names to replace with LinearAllreduce
+        # and whose values are keywords that identify the module (model) name.
+        # If the same layer name should become LinearAllreduce in multiple models, concatenate the keywords of
+        # the different module names with commas:
+        # import os
+        # os.environ["DS_ALL_REDUCE_LINEAR_ITEMS"] = "{'layer_name_1':'model_1 keyword','layer_name_2':'model_1 keyword',...}"
+        # os.environ["DS_ALL_REDUCE_LINEAR_ITEMS"] = "{'layer_name_1':'model_1 keyword,model_2 keyword,...',
+        #                                              'layer_name_2':'model_1 keyword,model_2 keyword,...',...}"
+        # for example: os.environ["DS_ALL_REDUCE_LINEAR_ITEMS"] = "{'w2':'mixtral'}"
+        ds_user_reduceLinear_items = os.environ.get('DS_ALL_REDUCE_LINEAR_ITEMS')
+        if ds_user_reduceLinear_items:
+            ds_user_reduceLinear_items = ast.literal_eval(ds_user_reduceLinear_items)
+            ds_reduceLinear_items.update(ds_user_reduceLinear_items)
+
+        ds_reduceLinear_keys = ds_reduceLinear_items.keys()
+
+        # 'DS_REMOVED_COMMON_REDUCE_LINEAR_KEYS' is a list; layer names in it are removed from the default common LinearAllreduce keys.
+        # import os
+        # os.environ["DS_REMOVED_COMMON_REDUCE_LINEAR_KEYS"] = "['layer_name_1', 'layer_name_2',...]"
+        # for example: os.environ["DS_REMOVED_COMMON_REDUCE_LINEAR_KEYS"] = "['o_proj']"
+        ds_user_remove_reduceLinear_keys = os.environ.get('DS_REMOVED_COMMON_REDUCE_LINEAR_KEYS')
+        if ds_user_remove_reduceLinear_keys:
+            ds_user_remove_reduceLinear_keys = ast.literal_eval(ds_user_remove_reduceLinear_keys)
+            ds_common_reduceLinear_keys = [
+                item for item in default_ds_common_reduceLinear_keys if item not in ds_user_remove_reduceLinear_keys
+            ]
+        else:
+            ds_common_reduceLinear_keys = default_ds_common_reduceLinear_keys
+
         #ln_1 , ln_2 for Qwen
         for module in module_list:
             for key, submodule in module._modules.items():
                 if isinstance(submodule, nn.Linear):
                     layer_list = layer_list + ["." + key]
                 elif isinstance(submodule, nn.LayerNorm) or key in norm_layer_name_list:
                     layer_list = layer_list + ["ln"]
                 else:
                     layer_list = layer_list + AutoTP.get_layers(key, submodule)
+            module_name = str(type(module))
             for i, layer in enumerate(layer_list):
                 if layer == 'ln':
                     if layer_list[i - 1] != 'ln':
                         gem_list = gem_list + [layer_list[i - 1]]
-                elif 'out_proj' in layer:
-                    gem_list = gem_list + [layer]
-                elif 'o_proj' in layer:
-                    gem_list = gem_list + [layer]
-                elif 'down_proj' in layer:
-                    gem_list = gem_list + [layer]
-                elif 'attention.dense' in layer and 'GPTNeoX' in str(model):
-                    gem_list = gem_list + [layer]
-                elif 'self_attention.dense' in layer and 'falcon' in str(
-                        type(module)):  # this is a hack to get the right linear layer for this model!
-                    gem_list = gem_list + [layer]
-                # Mixtral-7x8b used w2*act(w1*w3) linear. need to replace w2 to linearallreduce.
-                elif 'w2' in layer and 'Mixtral' in str(type(module)):
-                    gem_list = gem_list + [layer]
-                elif 'self_attn.dense' in layer and 'Phi' in str(type(module)):
-                    gem_list = gem_list + [layer]
-                elif 'self_attention.dense' in layer and 'ChatGLM' in str(model):
-                    gem_list = gem_list + [layer]
-                elif 'dense_4h_to_h' in layer and 'ChatGLM' in str(model):
+
+                if any((key in layer) for key in ds_common_reduceLinear_keys):
                     gem_list = gem_list + [layer]
+                    continue
+
+                for key in ds_reduceLinear_keys:
+                    if key in layer:
+                        values = ds_reduceLinear_items[key].split(',')
+                        if any((v in module_name) for v in values):
+                            gem_list = gem_list + [layer]
+                        break
 
             layer_list = []
         if gem_list != []:
@@ -346,9 +379,38 @@ def _replace(self, child, name, conv_linear_layer):
         weight_shape = child.weight.shape
         mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
         # For TP layer skip, e.g., MoE gate, deepseek low rank layer skip
-        if "q_a_proj" in name or "kv_a_proj_with_mqa" in name or name == "block_sparse_moe.gate" or (
-            ('mlp.shared_expert_gate' == name or 'mlp.gate' == name) and 'qwen2_moe' in str(type(self.module))):
-            return child
+        predefined_keep_Linear_Items = {
+            'q_a_proj': 'DeepseekV2',
+            'kv_a_proj_with_mqa': 'DeepseekV2',
+            'block_sparse_moe.gate': 'Mixtral',
+            'mlp.shared_expert_gate': 'qwen2_moe',
+            'mlp.gate': 'qwen2_moe'
+        }
+        keep_Linear_Items = predefined_keep_Linear_Items
+        # 'DS_KEEP_LINEAR_ITEMS' is a dictionary whose keys are layer names to keep as plain Linear and
+        # whose values are keywords that identify the module (model) name.
+        # If the same layer name should stay Linear in multiple models, concatenate the keywords of the
+        # different module names with commas:
+        # import os
+        # one model keyword per layer name:
+        # os.environ["DS_KEEP_LINEAR_ITEMS"] = "{'layer_name_1':'model_1 keyword','layer_name_2':'model_1 keyword',...}"
+        # several model keywords per layer name:
+        # os.environ["DS_KEEP_LINEAR_ITEMS"] = "{'layer_name_1':'model_1 keyword,model_2 keyword,...',
+        #                                        'layer_name_2':'model_1 keyword,model_2 keyword,...',...}"
+        # for example:
+        # os.environ["DS_KEEP_LINEAR_ITEMS"] = "{'gate':'mixtral'}"
+        user_keep_Linear_Items = os.environ.get('DS_KEEP_LINEAR_ITEMS')
+        if user_keep_Linear_Items:
+            user_keep_Linear_Items = ast.literal_eval(user_keep_Linear_Items)
+            keep_Linear_Items.update(user_keep_Linear_Items)
+
+        keys = keep_Linear_Items.keys()
+        for item in keys:
+            if item in name:
+                values = keep_Linear_Items[item]
+                values = values.split(',')  # multiple model keywords for the same key are concatenated with commas
+                if any((v in str(type(self.module))) for v in values):
+                    return child
+
         # For Yuan model
         if 'Yuan' in str(self.module):
             if 'v_proj' in name:
@@ -494,7 +556,12 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''):
             if len(child._buffers) != 0 and self.state_dict is not None:
                 Loading.load_buffer(child, self.state_dict, checking_key)
             if child.__class__ in self.linear_policies:
-                setattr(r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name,
+                keepLinearItems = os.environ.get('DS_KEEP_LINEAR_ITEMS', '{}')
+                keepLinearItems = ast.literal_eval(keepLinearItems)
+
+                if all(item not in checking_key for item in keepLinearItems):
+                    setattr(
+                        r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name,
                                                                                self.conv_linear_layer))
             elif any(isinstance(child, lp) for lp in self.linear_policies):
                 # Added for falcon model support
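The hunks above wire three environment variables into AutoTP. As a usage illustration only, the sketch below shows how the two LinearAllreduce-related variables read by tp_parser might be set before inference; the checkpoint name, tp_size, and keyword strings are assumed example values, while the variable names and their string-encoded dict/list formats come from the patch.

    # Hypothetical driver script; values are examples, not defaults shipped with the patch.
    import os
    import torch
    import deepspeed
    from transformers import AutoModelForCausalLM

    # Treat 'w2' as a LinearAllreduce layer when the decoder layer's type string contains 'mixtral'
    # (mirrors the example in the patch comments).
    os.environ["DS_ALL_REDUCE_LINEAR_ITEMS"] = "{'w2':'mixtral'}"
    # Drop 'o_proj' from the default common LinearAllreduce keys.
    os.environ["DS_REMOVED_COMMON_REDUCE_LINEAR_KEYS"] = "['o_proj']"

    model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", torch_dtype=torch.float16)
    # AutoTP path (kernel injection disabled), so tp_parser reads the variables set above.
    model = deepspeed.init_inference(model,
                                     tensor_parallel={"tp_size": 2},
                                     dtype=torch.float16,
                                     replace_with_kernel_inject=False)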
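A similar sketch for DS_KEEP_LINEAR_ITEMS, which the keep-as-Linear check in _replace consumes: a Linear is returned unsharded when its name contains one of the keys and the model's type string contains one of the comma-separated keywords. The 'gate'/'mixtral' pair mirrors the example in the patch comments; the matching code below only illustrates the substring logic and is not part of the patch.

    import os
    import ast

    # Keep the MoE routing gate of Mixtral-style models as a plain, non-sharded Linear.
    os.environ["DS_KEEP_LINEAR_ITEMS"] = "{'gate':'mixtral'}"

    # The patch parses the value with ast.literal_eval and does substring matching, roughly:
    items = ast.literal_eval(os.environ["DS_KEEP_LINEAR_ITEMS"])
    layer_name = "block_sparse_moe.gate"   # example layer name seen by _replace
    module_type = "<class 'transformers...modeling_mixtral.MixtralForCausalLM'>"  # str(type(self.module))
    keep = any(key in layer_name and any(v in module_type for v in val.split(','))
               for key, val in items.items())
    print(keep)  # True -> this Linear would be left as-is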