diff --git a/zoobot/pytorch/datasets/webdatamodule.py b/zoobot/pytorch/datasets/webdatamodule.py
index 7943294c..abbb32f1 100644
--- a/zoobot/pytorch/datasets/webdatamodule.py
+++ b/zoobot/pytorch/datasets/webdatamodule.py
@@ -163,7 +163,7 @@ def val_dataloader(self):
         return self.make_loader(self.val_urls, mode="val")
 
     def test_dataloader(self):
-        return self.make_loader(self.val_urls, mode="test")
+        return self.make_loader(self.test_urls, mode="test")
 
     def predict_dataloader(self):
         return self.make_loader(self.predict_urls, mode="predict")
diff --git a/zoobot/pytorch/training/train_with_pytorch_lightning.py b/zoobot/pytorch/training/train_with_pytorch_lightning.py
index 8caea736..99c96f1e 100644
--- a/zoobot/pytorch/training/train_with_pytorch_lightning.py
+++ b/zoobot/pytorch/training/train_with_pytorch_lightning.py
@@ -345,16 +345,20 @@ def train_default_zoobot_from_scratch(
     # also be careful not to test regularly, as this breaks train/val/test conceptual separation and may cause hparam overfitting
     if datamodule.test_dataloader is not None:
         logging.info(f'Testing on {checkpoint_callback.best_model_path} with single GPU. Be careful not to overfit your choices to the test data...')
-        test_trainer.validate(
-            model=lightning_model,
-            datamodule=datamodule,
-            ckpt_path=checkpoint_callback.best_model_path  # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
-        )
+        # test_trainer.validate(
+        #     model=lightning_model,
+        #     datamodule=datamodule,
+        #     ckpt_path=checkpoint_callback.best_model_path  # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
+        # )
+        datamodule.setup(stage='test')
+        # temp
+        print(datamodule.test_urls)
         test_trainer.test(
             model=lightning_model,
             datamodule=datamodule,
             ckpt_path=checkpoint_callback.best_model_path  # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
         )
+        # TODO may need to remake on 1 gpu only
 
     # explicitly update the model weights to the best checkpoint before returning
     # (assumes only one checkpoint callback, very likely in practice)