Skip to content

Commit

Permalink
fixed merging conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
gyou2021 committed Jan 20, 2025
2 parents a2ee47a + b442b08 commit 75a1d91
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions deepspeed/module_inject/auto_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def kernel_supported(module_list):
return True
return False

## tp parser based on autoTP config in environment
def tp_parser(model):
policy_list = []
module_list = []
Expand Down Expand Up @@ -323,7 +324,7 @@ def tp_parser(model):
for key, submodule in module._modules.items():
if isinstance(submodule, nn.Linear):
layer_list = layer_list + ["." + key]
elif isinstance(submodule, nn.LayerNorm) or key in norm_layer_name_list:
elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
layer_list = layer_list + ["ln"]
else:
layer_list = layer_list + AutoTP.get_layers(key, submodule)
Expand Down Expand Up @@ -533,7 +534,12 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''):
if len(child._buffers) != 0 and self.state_dict is not None:
Loading.load_buffer(child, self.state_dict, checking_key)
if child.__class__ in self.linear_policies:
setattr(r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name,
keepLinearItems = os.environ['keepLinearItems']
keepLinearItems = ast.literal_eval(keepLinearItems)

if any(item not in checking_key for item in keepLinearItems):
setattr(
r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name,
self.conv_linear_layer))
elif any(isinstance(child, lp) for lp in self.linear_policies):
# Added for falcon model support
Expand Down

0 comments on commit 75a1d91

Please sign in to comment.