Update AutoTP default set lm_head tp
Yejing-Lai committed Jan 8, 2025
1 parent c9124f8 commit aec4e3f
Showing 1 changed file with 3 additions and 1 deletion.
deepspeed/module_inject/replace_module.py (3 additions, 1 deletion)
@@ -396,7 +396,9 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
                                          orig_class=orig_layer_impl,
                                          replace_fn=replace_fn,
                                          _replace_policy=config.injection_policy_tuple)
-    replaced_module = set_lm_head(replaced_module)
+    # AutoTP default set lm_head tp
+    if not config.replace_with_kernel_inject:
+        replaced_module = set_lm_head(replaced_module)
 
     quantizer = GroupQuantizer(q_int8=quantize)
     world_size = dist.get_world_size() if dist.is_initialized() else 1
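For context, the guarded set_lm_head call only runs on the AutoTP path, i.e. when kernel injection is disabled. Below is a minimal sketch of how that path is typically reached through deepspeed.init_inference; the checkpoint name and tp_size are illustrative assumptions, not taken from this commit.

# Minimal sketch of the AutoTP path this change affects; the checkpoint name
# and tp_size are illustrative assumptions, not part of this commit.
import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b", torch_dtype=torch.float16)

# replace_with_kernel_inject=False selects AutoTP, so after this change
# set_lm_head() runs and the lm_head weight is also sharded across TP ranks.
engine = deepspeed.init_inference(
    model,
    tensor_parallel={"tp_size": 2},
    dtype=torch.float16,
    replace_with_kernel_inject=False,
)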

0 comments on commit aec4e3f
