check example runs with v2
mwalmsley committed Mar 20, 2023
Commit dba7a4d (1 parent: 0b1a095)
Showing 1 changed file: zoobot/pytorch/examples/finetuning/finetune_counts_full_tree.py (9 additions, 6 deletions)
@@ -28,17 +28,20 @@

 schema = gz_candels_ortho_schema

+# TODO you will want to replace these paths with your own paths
+# I'm being a little lazy and leaving my if/else for local/cluster training here,
+# this is often convenient for debugging
 if os.path.isdir('/share/nas2'):  # run on cluster
     repo_dir = '/share/nas2/walml/repos'
-    data_download_dir = '/share/nas2/walml/repos/_data'
+    data_download_dir = '/share/nas2/walml/repos/_data/demo_gz_candels'
     accelerator = 'gpu'
     devices = 1
     batch_size = 64
     prog_bar = False
     max_galaxies = None
 else:  # test locally
-    repo_dir = '/home/walml/repos'
-    data_download_dir = '/share/nas2/walml/repos/galaxy-datasets/roots'
+    repo_dir = '/Users/user/repos'
+    data_download_dir = '/Users/user/repos/galaxy-datasets/roots/demo_gz_candels'
     accelerator = 'cpu'
     devices = None
     batch_size = 32  # 32 with resize=224, 16 at 380
@@ -65,8 +68,8 @@
 )

 checkpoint_loc = os.path.join(
-    # repo_dir, 'gz-decals-classifiers/results/pytorch/desi/_desi_pytorch_v4_posthp_train_all_test_dr8_m1/checkpoints/epoch=48-step=215159.ckpt') # bad hparams
-    repo_dir, 'gz-decals-classifiers/results/pytorch/desi/_desi_pytorch_v5_posthp_train_all_test_dr8_decals_hparams_m5/checkpoints/epoch=36-step=20313.ckpt') # decals hparams
+    # TODO replace with path to downloaded checkpoints. See Zoobot README for download links.
+    repo_dir, 'gz-decals-classifiers/results/benchmarks/pytorch/evo/uploaded/effnetb0_greyscale_224px.ckpt') # decals hparams

 model = finetune.FinetuneableZoobotTree(checkpoint_loc=checkpoint_loc, schema=schema)

@@ -77,7 +80,7 @@
 from pytorch_lightning.loggers import WandbLogger
 logger = WandbLogger(project='finetune', name='full_tree_example')

-trainer = finetune.get_trainer(save_dir=save_dir, logger=logger)
+trainer = finetune.get_trainer(save_dir=save_dir, logger=logger, accelerator=accelerator)
 trainer.fit(model, datamodule)

 # now save predictions on test set to evaluate performance
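
For readers checking the example end to end, the visible hunks stitch together into roughly the flow below. This is a minimal sketch, not the file itself: the catalog download and datamodule construction are collapsed in this view, so the gz_candels loader and GalaxyDataModule usage here are assumptions based on the galaxy-datasets package, and the literal paths are placeholders.

# Minimal sketch of the flow this commit checks. Assumptions are marked:
# the gz_candels loader and GalaxyDataModule are from the galaxy-datasets
# package and do not appear in the visible hunks.
import os

from galaxy_datasets import gz_candels  # assumed loader from galaxy-datasets
from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule  # assumed
from zoobot.pytorch.training import finetune
from zoobot.shared.schemas import gz_candels_ortho_schema  # import as in the full script

schema = gz_candels_ortho_schema

repo_dir = '/Users/user/repos'  # placeholder, as in the local branch above
data_download_dir = '/Users/user/repos/galaxy-datasets/roots/demo_gz_candels'

# assumed: download the demo GZ CANDELS catalog and wrap it in a datamodule
train_catalog, label_cols = gz_candels(
    root=data_download_dir, train=True, download=True)
datamodule = GalaxyDataModule(
    label_cols=schema.label_cols, catalog=train_catalog, batch_size=32)

# checkpoint path exactly as added by this commit (see Zoobot README for download links)
checkpoint_loc = os.path.join(
    repo_dir, 'gz-decals-classifiers/results/benchmarks/pytorch/evo/uploaded/effnetb0_greyscale_224px.ckpt')
model = finetune.FinetuneableZoobotTree(checkpoint_loc=checkpoint_loc, schema=schema)

# accelerator is now passed through, per the final hunk of this commit
trainer = finetune.get_trainer(save_dir='results/finetune', logger=None, accelerator='cpu')
trainer.fit(model, datamodule)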
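The closing comment points at test-set prediction code that is collapsed in this view. As a rough illustration only, a Zoobot-style prediction call might look like the sketch below; the module and signature are assumptions, not taken from this diff, so check the actual file in the repository for the real call.

# Hedged sketch only: the real prediction code is collapsed above.
# Assumes a predict_on_catalog helper with roughly this signature.
from zoobot.pytorch.predictions import predict_on_catalog  # assumed module

predict_on_catalog.predict(
    catalog=test_catalog,            # hypothetical held-out test catalog
    model=model,                     # fine-tuned model from trainer.fit above
    n_samples=1,                     # forward passes per galaxy
    label_cols=schema.label_cols,
    save_loc='results/finetune/test_predictions.csv',  # hypothetical output path
)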
