Enabled configurable auto Tensor Parallelism (TP) for the inference of diverse models #6553

Open · wants to merge 15 commits into master
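The diff below wires three environment variables into AutoTP, each holding a Python literal that is read via ast.literal_eval: DS_ALL_REDUCE_LINEAR_ITEMS, DS_REMOVED_COMMON_REDUCE_LINEAR_KEYS, and DS_KEEP_LINEAR_ITEMS. A minimal usage sketch based on the comments in the diff (the model id, dtype, and tp_size are illustrative, not part of this PR):

import os

# Treat Mixtral's 'w2' projection as LinearAllreduce (already a predefined
# default; shown here only to illustrate the "{'layer': 'model keyword'}" format).
os.environ["DS_ALL_REDUCE_LINEAR_ITEMS"] = "{'w2': 'Mixtral'}"
# Drop 'o_proj' from the default common LinearAllreduce keys.
os.environ["DS_REMOVED_COMMON_REDUCE_LINEAR_KEYS"] = "['o_proj']"
# Keep MoE gate layers as plain, unsharded Linear.
os.environ["DS_KEEP_LINEAR_ITEMS"] = "{'gate': 'mixtral'}"

import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", torch_dtype=torch.float16)
# AutoTP path: no kernel injection policy; shard the model across two ranks.
engine = deepspeed.init_inference(model, tensor_parallel={"tp_size": 2})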
115 changes: 91 additions & 24 deletions deepspeed/module_inject/auto_tp.py
@@ -15,6 +15,8 @@
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_value_with_share_qk, shard_chunk_mlp
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
import os
import ast


def move(tensor, device, copy=True):
@@ -279,6 +281,7 @@ def kernel_supported(module_list):
                return True
        return False

    # TP parser based on the AutoTP config in environment variables
    def tp_parser(model):
        policy_list = []
        module_list = []
@@ -289,39 +292,69 @@ def tp_parser(model):
        assert AutoTP.supported(model), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
            if AutoTP.kernel_supported(module_list) else "AutoTP not supported for model. Please provide policy."
        norm_layer_name_list = ['LayerNorm', 'layer_norm', 'ln_1', 'ln_2']

        default_ds_common_reduceLinear_keys = ['out_proj', 'o_proj', 'down_proj']
        # Different model names for the same key are concatenated with commas.
        predefined_ds_common_reduceLinear_items = {
            'attention.dense': 'GPTNeoX',
            'self_attention.dense': 'falcon,ChatGLM,Phi',
            'w2': 'Mixtral',
            'dense_4h_to_h': 'ChatGLM'
        }
        ds_reduceLinear_items = predefined_ds_common_reduceLinear_items
        # 'DS_ALL_REDUCE_LINEAR_ITEMS' is a dictionary whose keys are layer names of LinearAllreduce and
        # whose values are keywords in the module name.
        # If the same layer name is LinearAllreduce in multiple models, concatenate the keywords of the
        # different module names with commas:
        #   import os
        #   os.environ["DS_ALL_REDUCE_LINEAR_ITEMS"] = "{'layer_name_1':'model_1 keyword','layer_name_2':'model_2 keyword',...}"
        #   os.environ["DS_ALL_REDUCE_LINEAR_ITEMS"] = "{'layer_name_1':'model-1 keyword,model-2 keyword,...',
        #                                                'layer_name_2':'model-1 keyword,model-2 keyword,...',...}"
        # For example: os.environ["DS_ALL_REDUCE_LINEAR_ITEMS"] = "{'w2':'Mixtral'}"
        ds_user_reduceLinear_items = os.environ.get('DS_ALL_REDUCE_LINEAR_ITEMS')
        if ds_user_reduceLinear_items:
            ds_user_reduceLinear_items = ast.literal_eval(ds_user_reduceLinear_items)
            ds_reduceLinear_items.update(ds_user_reduceLinear_items)

        ds_reduceLinear_keys = ds_reduceLinear_items.keys()

        # 'DS_REMOVED_COMMON_REDUCE_LINEAR_KEYS' is a list. Layer names in the list are removed from the
        # default common LinearAllreduce keys:
        #   import os
        #   os.environ["DS_REMOVED_COMMON_REDUCE_LINEAR_KEYS"] = "['layer_name_1', 'layer_name_2',...]"
        # For example: os.environ["DS_REMOVED_COMMON_REDUCE_LINEAR_KEYS"] = "['o_proj']"
        ds_user_remove_reduceLinear_keys = os.environ.get('DS_REMOVED_COMMON_REDUCE_LINEAR_KEYS')
        if ds_user_remove_reduceLinear_keys:
            ds_user_remove_reduceLinear_keys = ast.literal_eval(ds_user_remove_reduceLinear_keys)
            ds_common_reduceLinear_keys = [
                item for item in default_ds_common_reduceLinear_keys if item not in ds_user_remove_reduceLinear_keys
            ]
        else:
            ds_common_reduceLinear_keys = default_ds_common_reduceLinear_keys

        # ln_1, ln_2 for Qwen
        for module in module_list:
            for key, submodule in module._modules.items():
                if isinstance(submodule, nn.Linear):
                    layer_list = layer_list + ["." + key]
                elif isinstance(submodule, nn.LayerNorm) or key in norm_layer_name_list:
                elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
                    layer_list = layer_list + ["ln"]
                else:
                    layer_list = layer_list + AutoTP.get_layers(key, submodule)
            module_name = str(type(module))
            for i, layer in enumerate(layer_list):
                if layer == 'ln':
                    if layer_list[i - 1] != 'ln':
                        gem_list = gem_list + [layer_list[i - 1]]
                elif 'out_proj' in layer:
                    gem_list = gem_list + [layer]
                elif 'o_proj' in layer:
                    gem_list = gem_list + [layer]
                elif 'down_proj' in layer:
                    gem_list = gem_list + [layer]
                elif 'attention.dense' in layer and 'GPTNeoX' in str(model):
                    gem_list = gem_list + [layer]
                elif 'self_attention.dense' in layer and 'falcon' in str(
                        type(module)):  # this is a hack to get the right linear layer for this model!
                    gem_list = gem_list + [layer]
                # Mixtral-8x7B uses w2 * act(w1 * w3) linears; w2 needs to be replaced with LinearAllreduce.
                elif 'w2' in layer and 'Mixtral' in str(type(module)):
                    gem_list = gem_list + [layer]
                elif 'self_attn.dense' in layer and 'Phi' in str(type(module)):
                    gem_list = gem_list + [layer]
                elif 'self_attention.dense' in layer and 'ChatGLM' in str(model):
                    gem_list = gem_list + [layer]
                elif 'dense_4h_to_h' in layer and 'ChatGLM' in str(model):

                if any((key in layer) for key in ds_common_reduceLinear_keys):
                    gem_list = gem_list + [layer]
                    continue

                for key in ds_reduceLinear_keys:
                    if key in layer:
                        values = ds_reduceLinear_items[key].split(',')
                        if any((v in module_name) for v in values):
                            gem_list = gem_list + [layer]
                        break

            layer_list = []
            if gem_list != []:
@@ -346,9 +379,38 @@ def _replace(self, child, name, conv_linear_layer):
        weight_shape = child.weight.shape
        mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
        # For TP layer skip, e.g., MoE gate, deepseek low rank layer skip
        if "q_a_proj" in name or "kv_a_proj_with_mqa" in name or name == "block_sparse_moe.gate" or (
                ('mlp.shared_expert_gate' == name or 'mlp.gate' == name) and 'qwen2_moe' in str(type(self.module))):
            return child
        predefined_keep_Linear_Items = {
            'q_a_proj': 'DeepseekV2',
            'kv_a_proj_with_mqa': 'DeepseekV2',
            'block_sparse_moe.gate': 'DeepseekV2',
            'mlp.shared_expert_gate': 'qwen2_moe',
            'mlp.gate': 'qwen2_moe'
        }
        keep_Linear_Items = predefined_keep_Linear_Items
        # 'DS_KEEP_LINEAR_ITEMS' is a dictionary whose keys are layer names of Linear layers and
        # whose values are keywords in the module name.
        # If the same layer name keeps Linear in multiple models, concatenate the keywords of the
        # different module names with commas:
        #   import os
        #   One keyword for one model:
        #   os.environ["DS_KEEP_LINEAR_ITEMS"] = "{'layer_name_1':'model_1 keyword','layer_name_2':'model_2 keyword',...}"
        #   One keyword for multiple models:
        #   os.environ["DS_KEEP_LINEAR_ITEMS"] = "{'layer_name_1':'model_1 keyword,model_2 keyword,...',
        #                                          'layer_name_2':'model_1 keyword,model_2 keyword,...',...}"
        # For example:
        #   os.environ["DS_KEEP_LINEAR_ITEMS"] = "{'gate':'mixtral'}"
        user_keep_Linear_Items = os.environ.get('DS_KEEP_LINEAR_ITEMS')
        if user_keep_Linear_Items:
            user_keep_Linear_Items = ast.literal_eval(user_keep_Linear_Items)
            keep_Linear_Items.update(user_keep_Linear_Items)

        keys = keep_Linear_Items.keys()
        for item in keys:
            if item in name:
                values = keep_Linear_Items[item]
                values = values.split(',')  # different model names for the same key are concatenated with commas
                if any((v in str(type(self.module))) for v in values):
                    return child

        # For Yuan model
        if 'Yuan' in str(self.module):
            if 'v_proj' in name:
@@ -494,7 +556,12 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''):
            if len(child._buffers) != 0 and self.state_dict is not None:
                Loading.load_buffer(child, self.state_dict, checking_key)
            if child.__class__ in self.linear_policies:
                setattr(r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name,
                keepLinearItems = os.environ.get('DS_KEEP_LINEAR_ITEMS', '{}')
                keepLinearItems = ast.literal_eval(keepLinearItems)

                if all(item not in checking_key for item in keepLinearItems):
                    setattr(
                        r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name,
                                                                              self.conv_linear_layer))
            elif any(isinstance(child, lp) for lp in self.linear_policies):
                # Added for falcon model support
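For reference, the selection rule implemented by the rewritten tp_parser loop can be restated as a standalone predicate (a simplified sketch, not code from this PR; the module names in the asserts are illustrative):

def is_linear_allreduce(layer, module_name, common_keys, reduce_items):
    """Decide whether `layer` should be replaced with LinearAllreduce.

    common_keys:  layer-name fragments that apply to every model.
    reduce_items: dict of {layer fragment: 'comma,separated,model keywords'};
                  a fragment only counts when one of its keywords occurs in
                  the module's type name, and only the first matching
                  fragment is consulted (mirroring the loop's `break`).
    """
    if any(key in layer for key in common_keys):
        return True
    for key, keywords in reduce_items.items():
        if key in layer:
            return any(kw in module_name for kw in keywords.split(','))
    return False

# Mirrors the predefined defaults in tp_parser:
common = ['out_proj', 'o_proj', 'down_proj']
items = {'attention.dense': 'GPTNeoX', 'self_attention.dense': 'falcon,ChatGLM,Phi',
         'w2': 'Mixtral', 'dense_4h_to_h': 'ChatGLM'}
assert is_linear_allreduce('.w2', 'MixtralDecoderLayer', common, items)
assert not is_linear_allreduce('.gate', 'MixtralDecoderLayer', common, items)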