diff --git a/zoobot/pytorch/estimators/define_model.py b/zoobot/pytorch/estimators/define_model.py
index 30c3481a..2524a116 100755
--- a/zoobot/pytorch/estimators/define_model.py
+++ b/zoobot/pytorch/estimators/define_model.py
@@ -127,7 +127,7 @@ def on_validation_epoch_end(self) -> None:
 
     def on_test_epoch_end(self) -> None:
         logging.info('start test epoch end')
-        self.log_all_metrics(subset='test')
+        # self.log_all_metrics(subset='test')
         logging.info('end test epoch end')
 
     def calculate_loss_and_update_loss_metrics(self, predictions, labels, step_name):
@@ -140,7 +140,7 @@ def log_all_metrics(self, subset=None):
         if subset is not None:
             for name, metric in self.loss_metrics.items():
                 if subset in name:
-                    print('logging', name)
+                    logging.info(name)
                     self.log(name, metric, on_epoch=True, on_step=False, prog_bar=True, logger=True)
         else:  # just log everything
             self.log_dict(self.loss_metrics, on_epoch=True, on_step=False, prog_bar=True, logger=True)
diff --git a/zoobot/pytorch/training/train_with_pytorch_lightning.py b/zoobot/pytorch/training/train_with_pytorch_lightning.py
index 96b0e503..6f464ab8 100644
--- a/zoobot/pytorch/training/train_with_pytorch_lightning.py
+++ b/zoobot/pytorch/training/train_with_pytorch_lightning.py
@@ -335,45 +335,45 @@ def train_default_zoobot_from_scratch(
 
     # can test as per the below, but note that datamodule must have a test dataset attribute as per pytorch lightning docs.
    # also be careful not to test regularly, as this breaks train/val/test conceptual separation and may cause hparam overfitting
-    if datamodule.test_dataloader is not None:
-        logging.info(f'Testing on {checkpoint_callback.best_model_path} with single GPU. Be careful not to overfit your choices to the test data...')
-        # test_trainer.validate(
-        #     model=lightning_model,
-        #     datamodule=datamodule,
-        #     ckpt_path=checkpoint_callback.best_model_path  # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
-        # )
-        test_trainer = pl.Trainer(
-            accelerator=accelerator,
-            devices=1,
-            precision=precision,
-            logger=wandb_logger,
-            default_root_dir=save_dir
-        )
-        if test_trainer.is_global_zero:
-            test_datamodule = webdatamodule.WebDataModule(
-                train_urls=None,
-                val_urls=None,
-                test_urls=test_urls,
-                label_cols=schema.label_cols,
-                batch_size=batch_size,
-                num_workers=1,  # 20 / 5 == 4, /2=2
-                prefetch_factor=prefetch_factor,
-                cache_dir=None,
-                color=color,
-                crop_scale_bounds=crop_scale_bounds,
-                crop_ratio_bounds=crop_ratio_bounds,
-                resize_after_crop=resize_after_crop
-            )
-            test_datamodule.setup(stage='test')
-            test_trainer.test(
-                model=lightning_model,
-                datamodule=test_datamodule,
-                ckpt_path=checkpoint_callback.best_model_path  # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
-            )
-        else:
-            logging.info('Not global zero, skipping test metrics')
-    else:
-        logging.info('No test dataloader found, skipping test metrics')
+    # if datamodule.test_dataloader is not None:
+    #     logging.info(f'Testing on {checkpoint_callback.best_model_path} with single GPU. Be careful not to overfit your choices to the test data...')
+    #     # test_trainer.validate(
+    #     #     model=lightning_model,
+    #     #     datamodule=datamodule,
+    #     #     ckpt_path=checkpoint_callback.best_model_path  # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
+    #     # )
+    #     test_trainer = pl.Trainer(
+    #         accelerator=accelerator,
+    #         devices=1,
+    #         precision=precision,
+    #         logger=wandb_logger,
+    #         default_root_dir=save_dir
+    #     )
+    #     if test_trainer.is_global_zero:
+    #         test_datamodule = webdatamodule.WebDataModule(
+    #             train_urls=None,
+    #             val_urls=None,
+    #             test_urls=test_urls,
+    #             label_cols=schema.label_cols,
+    #             batch_size=batch_size,
+    #             num_workers=1,  # 20 / 5 == 4, /2=2
+    #             prefetch_factor=prefetch_factor,
+    #             cache_dir=None,
+    #             color=color,
+    #             crop_scale_bounds=crop_scale_bounds,
+    #             crop_ratio_bounds=crop_ratio_bounds,
+    #             resize_after_crop=resize_after_crop
+    #         )
+    datamodule.setup(stage='test')
+    trainer.test(
+        model=lightning_model,
+        datamodule=datamodule,
+        ckpt_path=checkpoint_callback.best_model_path  # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
+    )
+    # else:
+    #     logging.info('Not global zero, skipping test metrics')
+    # else:
+    #     logging.info('No test dataloader found, skipping test metrics')
 
     # explicitly update the model weights to the best checkpoint before returning