diff --git a/src/accelerate/optimizer.py b/src/accelerate/optimizer.py index cfc6874826d..7c03f74a147 100644 --- a/src/accelerate/optimizer.py +++ b/src/accelerate/optimizer.py @@ -197,7 +197,7 @@ def __setstate__(self, state): self._optimizer_original_step_method = self.optimizer.step self._optimizer_patched_step_method = patch_optimizer_step(self, self.optimizer.step) - def multiply_grads(self, constant) -> None: + def multiply_grads(self, constant): """ Multiplies the gradients of the parameters by a constant. Needed during gradient accumulation.