diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index d2c894c9..247f8212 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -590,14 +590,14 @@ def step(self, closure=None):
                 if param.grad is not None:
                     self._store_and_try_reduce_grads_by_bucket(param)
 
-        # we need to reduce the gradients left in the communication bucket
-        for group_id in range(self.num_param_groups):
-            self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None, last_bucket=True)
-
         # we need to accumulate gradients left in the accumulate gardient bucket
         for group_id in range(self.num_param_groups):
             self._accum_grads_store_in_bucket(self._accum_grad_buckets[group_id], reduce_rank=None)
 
+        # we need to reduce the gradients left in the communication bucket
+        for group_id in range(self.num_param_groups):
+            self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None, last_bucket=True)
+
         # compute norm for gradients in the before bucket
         groups_norms = []
         for group_id in range(self.num_param_groups):
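The hunk above reorders two loops in step(): the loop that accumulates gradients left in the accumulation buckets now runs before the loop that reduces gradients left in the communication buckets, with the norm computation still following both. Below is a minimal, self-contained sketch of that ordering only; the method names mirror the diff but have simplified signatures, and the class is a toy stand-in, not the real HybridZeroOptimizer.

# Toy sketch of the post-change ordering in step(); not InternLM code.
class ToyZeroOptimizer:
    def __init__(self, num_param_groups):
        self.num_param_groups = num_param_groups
        self.trace = []

    def _accum_grads_store_in_bucket(self, group_id):
        # stand-in: handle gradients left in this group's accumulation bucket
        self.trace.append(f"accum group {group_id}")

    def _reduce_grads_stored_in_bucket(self, group_id, last_bucket):
        # stand-in: reduce whatever is still sitting in the communication bucket
        self.trace.append(f"reduce group {group_id} (last_bucket={last_bucket})")

    def step(self):
        # 1) accumulation buckets are handled first (moved earlier by this diff)
        for group_id in range(self.num_param_groups):
            self._accum_grads_store_in_bucket(group_id)
        # 2) the last communication bucket is reduced afterwards
        for group_id in range(self.num_param_groups):
            self._reduce_grads_stored_in_bucket(group_id, last_bucket=True)
        # 3) per-group gradient norms would be computed here, after both loops
        self.trace.append("compute norms")


opt = ToyZeroOptimizer(num_param_groups=2)
opt.step()
print(opt.trace)  # order: accum 0, accum 1, reduce 0, reduce 1, compute norms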