diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 86305312..bc8c615a 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -546,7 +546,6 @@ def step(self, closure=None): total_norms[group_name] = self._compute_norm_with_stage( group_id=group_id, last_bucket=True, - last_stage=True, previous_norm=groups_norms[group_id], ) @@ -620,9 +619,7 @@ def _step(self, closure=None, norms=None): # the following operations are performed only on the rank to which parameters are assigned. if len(single_grad_partition_groups) != 0 and self._clip_grad_norm > 0: - self._unscale_and_clip_grads( - single_grad_partition_groups, list(global_norm_groups.values()), loss_scale - ) + self._unscale_and_clip_grads(single_grad_partition_groups, list(global_norm_groups.values()), loss_scale) # update the parameters timer("step").start()