
Commit b9f9812
Merge pull request #90 from mwalmsley/dev
Support Lightning v2.0.0
mwalmsley authored Mar 20, 2023
2 parents 2d32d46 + dba7a4d commit b9f9812
Showing 9 changed files with 137 additions and 98 deletions.
benchmarks/pytorch/run_benchmarks.sh: 6 changes (3 additions, 3 deletions)
@@ -14,10 +14,10 @@ SEED=$RANDOM

# GZ Evo i.e. all galaxies
# effnet, greyscale and color
-# sbatch --job-name=evo_py_gr_eff_224_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
+sbatch --job-name=evo_py_gr_eff_224_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
# sbatch --job-name=evo_py_gr_eff_300_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=256,RESIZE_AFTER_CROP=300,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
-sbatch --job-name=evo_py_co_eff_224_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,COLOR_STRING=--color,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
-sbatch --job-name=evo_py_co_eff_300_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=256,RESIZE_AFTER_CROP=300,DATASET=gz_evo,COLOR_STRING=--color,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
+# sbatch --job-name=evo_py_co_eff_224_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,COLOR_STRING=--color,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
+# sbatch --job-name=evo_py_co_eff_300_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=256,RESIZE_AFTER_CROP=300,DATASET=gz_evo,COLOR_STRING=--color,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
# and resnet18
# sbatch --job-name=evo_py_gr_res18_224_$SEED --export=ARCHITECTURE=resnet18,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
# sbatch --job-name=evo_py_gr_res18_300_$SEED --export=ARCHITECTURE=resnet18,BATCH_SIZE=256,RESIZE_AFTER_CROP=300,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
docs/index.rst: 4 changes (4 additions, 0 deletions)
@@ -80,6 +80,10 @@ method.
.. install sphinx https://www.sphinx-doc.org/en/master/usage/installation.html is confusing, you can just use pip install -U sphinx
.. run from in docs folder: make html
.. can also check docs with
.. make linkcheck
.. (thanks, BS!)
.. docs/autodoc contains the tree that sphinx uses to add automatic documentation
.. it needs folders and files matching the python source
.. you will need to add a new {folder}.rst, a new folder, and a new {file}.rst
setup.py: 10 changes (5 additions, 5 deletions)
@@ -5,7 +5,7 @@

setuptools.setup(
name="zoobot",
version="1.0.0",
version="1.0.1",
author="Mike Walmsley",
author_email="[email protected]",
description="Galaxy morphology classifiers",
@@ -28,7 +28,7 @@
'torch == 1.12.1+cpu',
'torchvision == 0.13.1+cpu',
'torchaudio == 0.12.1',
-'pytorch-lightning==1.9.4', # tensorboard/protobuf issue fixed now
+'pytorch-lightning >= 2.0.0',
'simplejpeg',
'albumentations',
'pyro-ppl == 1.8.0',
@@ -41,7 +41,7 @@
'torch == 1.12.1',
'torchvision == 0.13.1',
'torchaudio == 0.12.1',
-'pytorch-lightning==1.9.4', # tensorboard/protobuf issue fixed now
+'pytorch-lightning >= 2.0.0',
'simplejpeg',
'albumentations',
'pyro-ppl == 1.8.0',
@@ -55,15 +55,15 @@
'torch == 1.12.1+cu113',
'torchvision == 0.13.1+cu113',
'torchaudio == 0.12.1',
-'pytorch-lightning>=1.9.4',
+'pytorch-lightning >= 2.0.0',
'simplejpeg',
'albumentations',
'pyro-ppl == 1.8.0',
'torchmetrics == 0.11.0',
'timm'
],
'pytorch_colab': [
-'pytorch-lightning>=1.9.4',
+'pytorch-lightning >= 2.0.0',
'simplejpeg',
'albumentations',
'pyro-ppl>=1.8.0',
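The dependency pin above moves from an exact pytorch-lightning==1.9.4 to a minimum of 2.0.0. As a quick sanity check that an existing environment really picked up the new major version, something like the following works; this is only an illustrative snippet, not part of the commit, and it assumes the packaging module is available (it ships with modern pip/setuptools environments).

# Illustrative only: confirm the installed Lightning meets the new >= 2.0.0 floor from setup.py.
import pytorch_lightning as pl
from packaging import version

assert version.parse(pl.__version__) >= version.parse('2.0.0'), (
    f'Expected pytorch-lightning >= 2.0.0, found {pl.__version__}'
)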
zoobot/pytorch/estimators/define_model.py: 8 changes (5 additions, 3 deletions)
@@ -85,19 +85,19 @@ def configure_optimizers(self):
def training_step(self, batch, batch_idx):
return self.make_step(batch, batch_idx, step_name='train')

-def training_step_end(self, outputs):
+def on_training_batch_end(self, outputs, *args):
self.log_outputs(outputs, step_name='train')

def validation_step(self, batch, batch_idx):
return self.make_step(batch, batch_idx, step_name='validation')

-def validation_step_end(self, outputs):
+def on_validation_batch_end(self, outputs, *args):
self.log_outputs(outputs, step_name='validation')

def test_step(self, batch, batch_idx):
return self.make_step(batch, batch_idx, step_name='test')

-def test_step_end(self, outputs):
+def on_test_batch_end(self, outputs, *args):
self.log_outputs(outputs, step_name='test')
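For context on the renames above: Lightning 2.0 dropped the training_step_end / validation_step_end / test_step_end hooks, and the batch-level replacements also receive the batch and batch_idx, which is what the *args catch-all in this commit absorbs. Note that Lightning's built-in training-time hook is spelled on_train_batch_end (the on_validation_batch_end and on_test_batch_end names used here do match Lightning's hooks). A minimal standalone sketch of the 2.0-style hooks, not Zoobot's actual module:

# Minimal sketch of Lightning >= 2.0 batch-end hooks (illustrative, not Zoobot code).
import torch
import pytorch_lightning as pl

class SketchModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': torch.nn.functional.mse_loss(self.layer(x), y)}

    def on_train_batch_end(self, outputs, batch, batch_idx):
        # outputs is whatever training_step returned for this batch
        self.log('train/loss', outputs['loss'], on_step=False, on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)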


@@ -248,6 +248,8 @@ def log_outputs(self, outputs, step_name):
def log_loss_per_question(self, multiq_loss, prefix):
# log questions individually
# TODO need schema attribute or similar to have access to question names, this will do for now
+# unlike Finetuneable..., does not use TorchMetrics, simply logs directly
+# TODO could use TorchMetrics and for q in schema, self.q_metric loop
for question_n in range(multiq_loss.shape[1]):
self.log(f'{prefix}/epoch_questions/question_{question_n}_loss:0', torch.mean(multiq_loss[:, question_n]), on_epoch=True, on_step=False, sync_dist=True)

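The new comments above leave the TorchMetrics approach as a TODO. A rough, hypothetical sketch of that per-question loop, using placeholder question names rather than the real schema:

# Hypothetical sketch of the TODO above: one MeanMetric per question, outside any LightningModule.
# Question names are placeholders, not the real Galaxy Zoo schema.
import torch
import torchmetrics

question_names = ['question_0', 'question_1']
per_question_loss = torch.nn.ModuleDict({q: torchmetrics.MeanMetric() for q in question_names})

multiq_loss = torch.rand(8, len(question_names))  # (batch, question), as in log_loss_per_question
for question_n, q in enumerate(question_names):
    per_question_loss[q].update(multiq_loss[:, question_n].mean())

epoch_means = {q: metric.compute() for q, metric in per_question_loss.items()}
print(epoch_means)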
@@ -15,7 +15,7 @@
# load in catalogs of images and labels to finetune on
# each catalog should be a dataframe with columns of "id_str", "file_loc", and any labels
# here I'm using galaxy-datasets to download some premade data - check it out for examples
-data_dir = '/Users/user/repos/galaxy-datasets/roots/demo_rings'
+data_dir = '/Users/user/repos/galaxy-datasets/roots/demo_rings' # TODO set to any directory. rings dataset will be downloaded here
train_catalog, _ = demo_rings(root=data_dir, download=True, train=True)
test_catalog, _ = demo_rings(root=data_dir, download=True, train=False)
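As the comments above describe, each catalog is just a dataframe with "id_str", "file_loc", and any label columns. A tiny illustrative catalog (paths and labels are made up, not real files) would look like:

# Illustrative only: the catalog format the finetuning examples expect.
import pandas as pd

train_catalog = pd.DataFrame({
    'id_str': ['gal_0', 'gal_1', 'gal_2'],  # unique identifier per galaxy
    'file_loc': ['/data/gal_0.jpg', '/data/gal_1.jpg', '/data/gal_2.jpg'],  # path to each image
    'ring': [0, 1, 0],  # example binary label column
})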

@@ -35,7 +35,7 @@

datamodule = GalaxyDataModule(
label_cols=label_cols,
-catalog=train_catalog,
+catalog=train_catalog, # very small, as a demo
batch_size=32
)
# datamodule.setup()
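Assuming GalaxyDataModule follows the standard LightningDataModule interface, which the commented-out setup() call above suggests, one way to peek at a training batch is sketched below; the (images, labels) batch structure is an assumption, not taken from the commit.

# Hedged sketch: inspect one training batch. Assumes the standard LightningDataModule API
# and that batches are (images, labels) pairs; adjust if GalaxyDataModule differs.
datamodule.prepare_data()
datamodule.setup(stage='fit')
images, labels = next(iter(datamodule.train_dataloader()))
print(images.shape, labels.shape)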
zoobot/pytorch/examples/finetuning/finetune_counts_full_tree.py: 15 changes (9 additions, 6 deletions)
@@ -28,17 +28,20 @@

schema = gz_candels_ortho_schema

+# TODO you will want to replace these paths with your own paths
+# I'm being a little lazy and leaving my if/else for local/cluster training here,
+# this is often convenient for debugging
if os.path.isdir('/share/nas2'): # run on cluster
repo_dir = '/share/nas2/walml/repos'
-data_download_dir = '/share/nas2/walml/repos/_data'
+data_download_dir = '/share/nas2/walml/repos/_data/demo_gz_candels'
accelerator = 'gpu'
devices = 1
batch_size = 64
prog_bar = False
max_galaxies = None
else: # test locally
-repo_dir = '/home/walml/repos'
-data_download_dir = '/share/nas2/walml/repos/galaxy-datasets/roots'
+repo_dir = '/Users/user/repos'
+data_download_dir = '/Users/user/repos/galaxy-datasets/roots/demo_gz_candels'
accelerator = 'cpu'
devices = None
batch_size = 32 # 32 with resize=224, 16 at 380
@@ -65,8 +68,8 @@
)

checkpoint_loc = os.path.join(
-# repo_dir, 'gz-decals-classifiers/results/pytorch/desi/_desi_pytorch_v4_posthp_train_all_test_dr8_m1/checkpoints/epoch=48-step=215159.ckpt') # bad hparams
-repo_dir, 'gz-decals-classifiers/results/pytorch/desi/_desi_pytorch_v5_posthp_train_all_test_dr8_decals_hparams_m5/checkpoints/epoch=36-step=20313.ckpt') # decals hparams
+# TODO replace with path to downloaded checkpoints. See Zoobot README for download links.
+repo_dir, 'gz-decals-classifiers/results/benchmarks/pytorch/evo/uploaded/effnetb0_greyscale_224px.ckpt') # decals hparams

model = finetune.FinetuneableZoobotTree(checkpoint_loc=checkpoint_loc, schema=schema)

@@ -77,7 +80,7 @@
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(project='finetune', name='full_tree_example')

-trainer = finetune.get_trainer(save_dir=save_dir, logger=logger)
+trainer = finetune.get_trainer(save_dir=save_dir, logger=logger, accelerator=accelerator)
trainer.fit(model, datamodule)

# now save predictions on test set to evaluate performance
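The accelerator argument now passed to get_trainer above mirrors how a Lightning 2.x Trainer is configured directly. For comparison, a standalone sketch of the plain Lightning call (this is not Zoobot's get_trainer wrapper, and the values are examples only):

# Plain Lightning 2.x Trainer configuration, for comparison (not Zoobot's get_trainer).
import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator='cpu',  # 'gpu' on the cluster branch above
    devices=1,          # 'auto' also works in Lightning 2.x
    max_epochs=10,
    logger=None,        # a WandbLogger could be passed here, as in the example
)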
(The remaining changed files from this commit are not shown in this view.)
