diff --git a/zoobot/pytorch/training/finetune.py b/zoobot/pytorch/training/finetune.py
index 9f818d89..31f4ff06 100644
--- a/zoobot/pytorch/training/finetune.py
+++ b/zoobot/pytorch/training/finetune.py
@@ -255,20 +255,20 @@ def configure_optimizers(self):
         if self.cosine_schedule:
             logging.info('Using cosine schedule, warmup for {} epochs, max for {} epochs'.format(self.warmup_epochs, self.max_cosine_epochs))
             from lightly.utils.scheduler import CosineWarmupScheduler  # new dependency for zoobot, TBD - maybe just copy
+            # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule.configure_optimizers
+            # Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
+            lr_scheduler = CosineWarmupScheduler(
+                optimizer=opt,
+                warmup_epochs=self.warmup_epochs,
+                max_epochs=self.max_cosine_epochs,
+                start_value=self.learning_rate,
+                end_value=self.learning_rate * self.max_learning_rate_reduction_factor,
+            )
+            # lr_scheduler_config default is frequency=1, interval=epoch
             return {
                 "optimizer": opt,
-                "lr_scheduler": {
-                    "scheduler": CosineWarmupScheduler(
-                        optimizer=opt,
-                        warmup_epochs=self.warmup_epochs,
-                        max_epochs=self.max_cosine_epochs,
-                        start_value=self.learning_rate,
-                        end_value=self.learning_rate * self.max_learning_rate_reduction_factor,
-                    ),
-                    'interval': 'epoch',
-                    "frequency": 1
-                }
-            }
+                "lr_scheduler": lr_scheduler
+            }
         else:
             return opt
 
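
For context, a minimal standalone sketch of the resulting method. The optimizer construction here is hypothetical (AdamW with self.learning_rate); in the real finetune.py, `opt` is built earlier in configure_optimizers, above the hunk shown. The point it illustrates is that returning the bare scheduler is equivalent to the removed lr_scheduler_config dict, because Lightning applies interval='epoch' and frequency=1 by default:

    import torch
    from lightly.utils.scheduler import CosineWarmupScheduler

    def configure_optimizers(self):
        # Hypothetical optimizer setup, stands in for the code above this hunk.
        opt = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        lr_scheduler = CosineWarmupScheduler(
            optimizer=opt,
            warmup_epochs=self.warmup_epochs,
            max_epochs=self.max_cosine_epochs,
            start_value=self.learning_rate,
            end_value=self.learning_rate * self.max_learning_rate_reduction_factor,
        )
        # Bare scheduler: Lightning wraps it in an lr_scheduler_config with the
        # defaults interval='epoch', frequency=1, matching the removed dict form.
        return {"optimizer": opt, "lr_scheduler": lr_scheduler}

The explicit lr_scheduler_config dict is still the right shape when non-default stepping is needed, e.g. interval='step' or a 'monitor' key for ReduceLROnPlateau; here the defaults suffice, so the simpler return keeps the method flat.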