Skip to content

Commit

Permalink
Merge branch 'dev' into inigo_dev
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalmsley committed May 30, 2024
2 parents e0c6c8c + cea55ac commit 04b9c9f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 18 deletions.
42 changes: 24 additions & 18 deletions zoobot/pytorch/training/train_with_pytorch_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,26 +304,11 @@ def train_default_zoobot_from_scratch(
lightning_model = TorchSyncBatchNorm().apply(lightning_model)


extra_callbacks = extra_callbacks if extra_callbacks else []

monitor_metric = 'validation/supervised_loss'

# used later for checkpoint_callback.best_model_path
checkpoint_callback = ModelCheckpoint(
dirpath=os.path.join(save_dir, 'checkpoints'),
monitor=monitor_metric,
save_weights_only=True,
mode='min',
# custom filename for checkpointing due to / in metric
filename=checkpoint_file_template,
# https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html#pytorch_lightning.callbacks.ModelCheckpoint.params.auto_insert_metric_name
# avoids extra folders from the checkpoint name
auto_insert_metric_name=auto_insert_metric_name,
save_top_k=save_top_k
)

early_stopping_callback = EarlyStopping(monitor=monitor_metric, patience=patience, check_finite=True)
callbacks = [checkpoint_callback, early_stopping_callback] + extra_callbacks
checkpoint_callback, callbacks = get_default_callbacks(save_dir, patience, checkpoint_file_template, auto_insert_metric_name, save_top_k)
if extra_callbacks:
callbacks += extra_callbacks

trainer = pl.Trainer(
num_sanity_val_steps=0,
Expand Down Expand Up @@ -368,6 +353,27 @@ def train_default_zoobot_from_scratch(

return lightning_model, trainer

def get_default_callbacks(save_dir, patience=8, checkpoint_file_template=None, auto_insert_metric_name=True, save_top_k=3):

monitor_metric = 'validation/supervised_loss'

checkpoint_callback = ModelCheckpoint(
dirpath=os.path.join(save_dir, 'checkpoints'),
monitor=monitor_metric,
save_weights_only=True,
mode='min',
# custom filename for checkpointing due to / in metric
filename=checkpoint_file_template,
# https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html#pytorch_lightning.callbacks.ModelCheckpoint.params.auto_insert_metric_name
# avoids extra folders from the checkpoint name
auto_insert_metric_name=auto_insert_metric_name,
save_top_k=save_top_k
)

early_stopping_callback = EarlyStopping(monitor=monitor_metric, patience=patience, check_finite=True)
callbacks = [checkpoint_callback, early_stopping_callback]
return checkpoint_callback,callbacks




Expand Down
1 change: 1 addition & 0 deletions zoobot/shared/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def answers(self):
# so don't log anything during Schema.__init__!

gz_evo_v1_schema = Schema(label_metadata.gz_evo_v1_pairs, label_metadata.gz_evo_v1_dependencies)
gz_evo_v1_public_schema = Schema(label_metadata.gz_evo_v1_public_pairs, label_metadata.gz_evo_v1_public_dependencies)

gz_ukidss_schema = Schema(label_metadata.ukidss_ortho_pairs, label_metadata.ukidss_ortho_dependencies)
gz_jwst_schema = Schema(label_metadata.jwst_ortho_pairs, label_metadata.jwst_ortho_dependencies)
Expand Down

0 comments on commit 04b9c9f

Please sign in to comment.