Skip to content

Commit

Permalink
use duck-typing to ensure underlying optimizer supports schedulefree …
Browse files Browse the repository at this point in the history
…hooks (#3055)

* use duck-typing to ensure underlying optimizer supports schedulefree hooks

* fixup
  • Loading branch information
tmm1 authored Sep 2, 2024
1 parent 3fcc946 commit 1d09a20
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/accelerate/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,15 @@ def train(self):
"""
Sets the optimizer to "train" mode. Useful for optimizers like `schedule_free`
"""
return self.optimizer.train()
if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
self.optimizer.train()

def eval(self):
"""
Sets the optimizer to "eval" mode. Useful for optimizers like `schedule_free`
"""
return self.optimizer.eval()
if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
self.optimizer.eval()

def step(self, closure=None):
if is_lomo_available():
Expand Down

0 comments on commit 1d09a20

Please sign in to comment.