diff --git a/src/accelerate/optimizer.py b/src/accelerate/optimizer.py index 3230aa6606e..1c0a777dcba 100644 --- a/src/accelerate/optimizer.py +++ b/src/accelerate/optimizer.py @@ -125,13 +125,15 @@ def train(self): """ Sets the optimizer to "train" mode. Useful for optimizers like `schedule_free` """ - return self.optimizer.train() + if hasattr(self.optimizer, "train") and callable(self.optimizer.train): + self.optimizer.train() def eval(self): """ Sets the optimizer to "eval" mode. Useful for optimizers like `schedule_free` """ - return self.optimizer.eval() + if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval): + self.optimizer.eval() def step(self, closure=None): if is_lomo_available():