diff --git a/zoobot/pytorch/training/finetune.py b/zoobot/pytorch/training/finetune.py index 0d277a75..744e4b43 100644 --- a/zoobot/pytorch/training/finetune.py +++ b/zoobot/pytorch/training/finetune.py @@ -256,26 +256,25 @@ def configure_optimizers(self): logging.info('Optimizer ready, configuring scheduler') if self.cosine_schedule: - logging.info('Using cosine schedule, warmup for {} epochs, max for {} epochs'.format(self.warmup_epochs, self.max_cosine_epochs)) - from lightly.utils.scheduler import CosineWarmupScheduler # new dependency for zoobot, TBD - maybe just copy - # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule.configure_optimizers - # Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config. - lr_scheduler = CosineWarmupScheduler( - optimizer=opt, - warmup_epochs=self.warmup_epochs, - max_epochs=self.max_cosine_epochs, - start_value=self.learning_rate, - end_value=self.learning_rate * self.max_learning_rate_reduction_factor, - ) - - # logging.info('Using cosine schedule, warmup not supported, max for {} epochs'.format(self.max_cosine_epochs)) - # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + # logging.info('Using lightly cosine schedule, warmup for {} epochs, max for {} epochs'.format(self.warmup_epochs, self.max_cosine_epochs)) + # from lightly.utils.scheduler import CosineWarmupScheduler # new dependency for zoobot, TBD - maybe just copy + # # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule.configure_optimizers + # # Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config. + # lr_scheduler = CosineWarmupScheduler( # optimizer=opt, - # T_max=self.max_cosine_epochs, - # eta_min=self.learning_rate * self.max_learning_rate_reduction_factor + # warmup_epochs=self.warmup_epochs, + # max_epochs=self.max_cosine_epochs, + # start_value=self.learning_rate, + # end_value=self.learning_rate * self.max_learning_rate_reduction_factor, # ) - # lr_scheduler_config default is frequency=1, interval=epoch + logging.info('Using CosineAnnealingLR schedule, warmup not supported, max for {} epochs'.format(self.max_cosine_epochs)) + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer=opt, + T_max=self.max_cosine_epochs, + eta_min=self.learning_rate * self.max_learning_rate_reduction_factor + ) + return { "optimizer": opt, "lr_scheduler": {