diff --git a/zoobot/pytorch/training/train_with_pytorch_lightning.py b/zoobot/pytorch/training/train_with_pytorch_lightning.py
index 2c9e7524..4cf47b04 100644
--- a/zoobot/pytorch/training/train_with_pytorch_lightning.py
+++ b/zoobot/pytorch/training/train_with_pytorch_lightning.py
@@ -304,26 +304,11 @@ def train_default_zoobot_from_scratch(
         lightning_model = TorchSyncBatchNorm().apply(lightning_model)
 
-    extra_callbacks = extra_callbacks if extra_callbacks else []
-
-    monitor_metric = 'validation/supervised_loss'  # used later for checkpoint_callback.best_model_path
-    checkpoint_callback = ModelCheckpoint(
-        dirpath=os.path.join(save_dir, 'checkpoints'),
-        monitor=monitor_metric,
-        save_weights_only=True,
-        mode='min',
-        # custom filename for checkpointing due to / in metric
-        filename=checkpoint_file_template,
-        # https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html#pytorch_lightning.callbacks.ModelCheckpoint.params.auto_insert_metric_name
-        # avoids extra folders from the checkpoint name
-        auto_insert_metric_name=auto_insert_metric_name,
-        save_top_k=save_top_k
-    )
-
-    early_stopping_callback = EarlyStopping(monitor=monitor_metric, patience=patience, check_finite=True)
-    callbacks = [checkpoint_callback, early_stopping_callback] + extra_callbacks
+    checkpoint_callback, callbacks = get_default_callbacks(save_dir, patience, checkpoint_file_template, auto_insert_metric_name, save_top_k)
+    if extra_callbacks:
+        callbacks += extra_callbacks
 
     trainer = pl.Trainer(
         num_sanity_val_steps=0,
@@ -368,6 +353,27 @@ def train_default_zoobot_from_scratch(
 
     return lightning_model, trainer
 
 
+def get_default_callbacks(save_dir, patience=8, checkpoint_file_template=None, auto_insert_metric_name=True, save_top_k=3):
+
+    monitor_metric = 'validation/supervised_loss'
+
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=os.path.join(save_dir, 'checkpoints'),
+        monitor=monitor_metric,
+        save_weights_only=True,
+        mode='min',
+        # custom filename for checkpointing due to / in metric
+        filename=checkpoint_file_template,
+        # https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html#pytorch_lightning.callbacks.ModelCheckpoint.params.auto_insert_metric_name
+        # avoids extra folders from the checkpoint name
+        auto_insert_metric_name=auto_insert_metric_name,
+        save_top_k=save_top_k
+    )
+
+    early_stopping_callback = EarlyStopping(monitor=monitor_metric, patience=patience, check_finite=True)
+    callbacks = [checkpoint_callback, early_stopping_callback]
+    return checkpoint_callback, callbacks
+
diff --git a/zoobot/shared/schemas.py b/zoobot/shared/schemas.py
index b0123fc3..1d58dcf6 100755
--- a/zoobot/shared/schemas.py
+++ b/zoobot/shared/schemas.py
@@ -296,6 +296,7 @@ def answers(self):
 
 # so don't log anything during Schema.__init__!
 gz_evo_v1_schema = Schema(label_metadata.gz_evo_v1_pairs, label_metadata.gz_evo_v1_dependencies)
+gz_evo_v1_public_schema = Schema(label_metadata.gz_evo_v1_public_pairs, label_metadata.gz_evo_v1_public_dependencies)
 gz_ukidss_schema = Schema(label_metadata.ukidss_ortho_pairs, label_metadata.ukidss_ortho_dependencies)
 gz_jwst_schema = Schema(label_metadata.jwst_ortho_pairs, label_metadata.jwst_ortho_dependencies)
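
A minimal usage sketch (not part of the diff) for the extracted `get_default_callbacks` helper, assuming the refactor above is applied. The save directory, patience value, and `Trainer` arguments are illustrative only:

```python
# Hypothetical example: reuse the extracted callback factory from a custom
# training script. Argument values are illustrative, not Zoobot defaults.
import pytorch_lightning as pl

from zoobot.pytorch.training.train_with_pytorch_lightning import get_default_callbacks

# The helper returns the ModelCheckpoint callback separately (so the caller can
# later read checkpoint_callback.best_model_path) plus the full callback list.
checkpoint_callback, callbacks = get_default_callbacks(
    save_dir='results/my_run',      # checkpoints saved under results/my_run/checkpoints
    patience=8,                     # early stopping on 'validation/supervised_loss'
    checkpoint_file_template=None,  # fall back to Lightning's default checkpoint filename
    auto_insert_metric_name=True,
    save_top_k=3
)

trainer = pl.Trainer(callbacks=callbacks, max_epochs=100)
# After trainer.fit(...), the best checkpoint path is checkpoint_callback.best_model_path.
```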