Skip to content

Commit

Permalink
major fixes to finetuning
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalmsley committed Oct 17, 2023
1 parent 0132b8e commit 8a1e2ff
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 73 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,13 @@
'pandas',
'scipy',
'astropy', # for reading fits
'scikit-image >= 0.19.2',
# 'scikit-image >= 0.19.2', # TODO remove
'scikit-learn >= 1.0.2',
'matplotlib',
'pyarrow', # to read parquet, which is very handy for big datasets
# for saving metrics to weights&biases (cloud service, free within limits)
'wandb',
'setuptools==59.5.0', # wandb logger incompatibility
'galaxy-datasets==0.0.14' # for dataset loading in both TF and Torch (renamed from pytorch-galaxy-datasets)
'galaxy-datasets==0.0.15' # for dataset loading in both TF and Torch (renamed from pytorch-galaxy-datasets)
]
)
Original file line number Diff line number Diff line change
@@ -1,50 +1,51 @@
import logging
import os

import pandas as pd

from zoobot.pytorch.training import finetune
from galaxy_datasets import demo_rings
from galaxy_datasets import galaxy_mnist
from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule


if __name__ == '__main__':

logging.basicConfig(level=logging.INFO)

zoobot_dir = '/Users/user/repos/zoobot' # TODO set to directory where you cloned Zoobot
zoobot_dir = '/home/walml/repos/zoobot' # TODO set to directory where you cloned Zoobot
data_dir = '/home/walml/repos/galaxy-datasets/roots/galaxy_mnist' # TODO set to any directory. rings dataset will be downloaded here
batch_size = 32
num_workers= 8
n_blocks = 1 # EffnetB0 is divided into 7 blocks. set 0 to only fit the head weights. Set 1, 2, etc to finetune deeper.
max_epochs = 6 # 6 epochs should get you ~93% accuracy. Set much higher (e.g. 1000) for harder problems, to use Zoobot's default early stopping.

# load in catalogs of images and labels to finetune on
# each catalog should be a dataframe with columns of "id_str", "file_loc", and any labels
# here I'm using galaxy-datasets to download some premade data - check it out for examples
data_dir = '/Users/user/repos/galaxy-datasets/roots/demo_rings' # TODO set to any directory. rings dataset will be downloaded here
train_catalog, _ = demo_rings(root=data_dir, download=True, train=True)
test_catalog, _ = demo_rings(root=data_dir, download=True, train=False)

train_catalog, _ = galaxy_mnist(root=data_dir, download=False, train=True)
test_catalog, _ = galaxy_mnist(root=data_dir, download=False, train=False)

# wondering about "label_cols"?
# This is a list of catalog columns which should be used as labels
# Here:
# TODO should use Galaxy MNIST as my example here
label_cols = ['ring']
# For binary classification, the label column should have binary (0 or 1) labels for your classes
import numpy as np
# 0, 1, 2
train_catalog['ring'] = np.random.randint(low=0, high=3, size=len(train_catalog))

# TODO
# To support more complicated labels, Zoobot expects a list of columns. A list with one element works fine.

# Here, it's a single column, 'label', with values 0-3 (for each of the 4 classes)
label_cols = ['label']
num_classes = 4

# load a pretrained checkpoint saved here
checkpoint_loc = os.path.join(zoobot_dir, 'data/pretrained_models/pytorch/effnetb0_greyscale_224px.ckpt')
# checkpoint_loc = '/Users/user/repos/gz-decals-classifiers/results/benchmarks/pytorch/dr5/dr5_py_gr_15366/checkpoints/epoch=58-step=18939.ckpt'

# save the finetuning results here
save_dir = os.path.join(zoobot_dir, 'results/pytorch/finetune/finetune_multiclass_classification')

datamodule = GalaxyDataModule(
label_cols=label_cols,
catalog=train_catalog, # very small, as a demo
batch_size=32
batch_size=batch_size, # increase for faster training, decrease to avoid out-of-memory errors
num_workers=num_workers # TODO set to a little less than num. CPUs
)
# datamodule.setup()
datamodule.setup()
# optionally, check the data loads and looks okay
# for images, labels in datamodule.train_dataloader():
# print(images.shape)
# print(labels.shape)
Expand All @@ -53,31 +54,38 @@

model = finetune.FinetuneableZoobotClassifier(
checkpoint_loc=checkpoint_loc,
num_classes=3,
n_layers=0 # only updating the head weights. Set e.g. 1, 2 to finetune deeper.
num_classes=num_classes,
n_blocks=n_blocks
)
# under the hood, this does:
# encoder = finetune.load_pretrained_encoder(checkpoint_loc)
# model = finetune.FinetuneableZoobotClassifier(encoder=encoder, ...)

# retrain to find rings
trainer = finetune.get_trainer(save_dir, accelerator='cpu', max_epochs=1)
trainer = finetune.get_trainer(save_dir, accelerator='auto', max_epochs=max_epochs)
trainer.fit(model, datamodule)
# can now use this model or saved checkpoint to make predictions on new data. Well done!

# see how well the model performs
# (don't do this all the time)
trainer.test(model, datamodule)

# we can load the model later any time
# pretending we want to load from scratch:
best_checkpoint = trainer.checkpoint_callback.best_model_path
finetuned_model = finetune.FinetuneableZoobotClassifier.load_from_checkpoint(best_checkpoint)

from zoobot.pytorch.predictions import predict_on_catalog

predictions_save_loc = os.path.join(save_dir, 'finetuned_predictions.csv')
predict_on_catalog.predict(
test_catalog,
finetuned_model,
n_samples=1,
label_cols=label_cols,
save_loc=os.path.join(save_dir, 'finetuned_predictions.csv')
# trainer_kwargs={'accelerator': 'gpu'}
save_loc=predictions_save_loc,
trainer_kwargs={'accelerator': 'auto'},
datamodule_kwargs={'batch_size': batch_size, 'num_workers': num_workers}
)
"""
Under the hood, this is essentially doing:
Expand All @@ -91,4 +99,9 @@
)
preds = predict_trainer.predict(finetuned_model, predict_datamodule)
print(preds)
"""
"""

predictions = pd.read_csv(predictions_save_loc)
print(predictions)

exit() # now over to you!
2 changes: 1 addition & 1 deletion zoobot/pytorch/predictions/predict_on_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule


def predict(catalog: pd.DataFrame, model: pl.LightningModule, n_samples: int, label_cols: List, save_loc: str, datamodule_kwargs={}, trainer_kwargs={}):
def predict(catalog: pd.DataFrame, model: pl.LightningModule, n_samples: int, label_cols: List, save_loc: str, datamodule_kwargs={}, trainer_kwargs={}) -> None:
"""
Use trained model to make predictions on a catalog of galaxies.
Expand Down
110 changes: 65 additions & 45 deletions zoobot/pytorch/training/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
def freeze_batchnorm_layers(model):
for name, child in (model.named_children()):
if isinstance(child, torch.nn.BatchNorm2d):
logging.debug('freezing {} {}'.format(child, name))
logging.debug('Freezing {} {}'.format(child, name))
child.eval() # no grads, no param updates, no statistic updates
else:
freeze_batchnorm_layers(child) # recurse
Expand Down Expand Up @@ -64,15 +64,16 @@ def __init__(
encoder=None,
encoder_dim=1280, # as per current Zooot. TODO Could get automatically?
n_epochs=100, # TODO early stopping
n_layers=0, # how many layers deep to FT
n_blocks=0, # how many layers deep to FT
lr_decay=0.75,
weight_decay=0.05,
learning_rate=1e-4,
learning_rate=1e-4, # 10x lower than typical, you may like to experiment
dropout_prob=0.5,
freeze_batchnorm=True,
always_train_batchnorm=True,
prog_bar=True,
visualize_images=False, # upload examples to wandb, good for debugging
seed=42
seed=42,
n_layers=0 # for backward compat., n_blocks preferred
):
super().__init__()

Expand All @@ -94,69 +95,78 @@ def __init__(
self.encoder = encoder

self.encoder_dim = encoder_dim
self.n_layers = n_layers
self.freeze = True if n_layers == 0 else False
self.n_blocks = n_blocks

# for backwards compat.
if n_layers:
logging.warning('FinetuneableZoobot(n_layers) is now renamed to n_blocks, please update to pass n_blocks instead! For now, setting n_blocks=n_layers')
self.n_blocks = n_layers
logging.info('Layers to finetune: {}'.format(n_layers))

self.learning_rate = learning_rate
self.lr_decay = lr_decay
self.weight_decay = weight_decay
self.dropout_prob = dropout_prob
self.n_epochs = n_epochs

self.freeze_batchnorm = freeze_batchnorm
self.always_train_batchnorm = always_train_batchnorm
if self.always_train_batchnorm:
logging.info('always_train_batchnorm=True, so all batch norm layers will be finetuned')

self.train_loss_metric = tm.MeanMetric()
self.val_loss_metric = tm.MeanMetric()
self.test_loss_metric = tm.MeanMetric()


if self.freeze_batchnorm:
freeze_batchnorm_layers(self.encoder) # inplace

self.seed = seed
self.prog_bar = prog_bar
self.visualize_images = visualize_images

def configure_optimizers(self):

if self.freeze:
params = self.head.parameters()
return torch.optim.AdamW(params, betas=(0.9, 0.999), lr=self.learning_rate)
else:
lr = self.learning_rate
params = [{"params": self.head.parameters(), "lr": lr}]

# this bit is specific to Zoobot EffNet
# TODO check these are blocks not individual layers
encoder_blocks = list(self.encoder.children())

# for n, l in enumerate(encoder_blocks):
# print('\n')
# print(n)
# print(l)

# layers with no parameters don't count
# TODO double-check is_tuneable
tuneable_blocks = [b for b in encoder_blocks if is_tuneable(b)]

assert self.n_layers <= len(
tuneable_blocks
), f"Network only has {len(tuneable_blocks)} tuneable blocks, {self.n_layers} specified for finetuning"

# Append parameters of layers for finetuning along with decayed learning rate
blocks_to_tune = tuneable_blocks[:self.n_layers]
blocks_to_tune.reverse() # highest block to lowest block
for i, layer in enumerate(blocks_to_tune):
lr = self.learning_rate
params = [{"params": self.head.parameters(), "lr": lr}]

# this bit is specific to Zoobot EffNet
# may not yet work fr MaxViT (help wanted!)
encoder_blocks = self.encoder.blocks

assert self.n_blocks <= len(
encoder_blocks
), f"Network only has {len(encoder_blocks)} tuneable blocks, {self.n_blocks} specified for finetuning"

blocks_to_tune = list(encoder_blocks.named_children())
# take n blocks, ordered highest layer to lowest layer
blocks_to_tune.reverse()
# will finetune all params in first N
blocks_to_tune = blocks_to_tune[:self.n_blocks]
# optionally, can finetune batchnorm params in remaining layers
remaining_blocks = blocks_to_tune[self.n_blocks:]

# Append parameters of layers for finetuning along with decayed learning rate
for i, (_, block) in enumerate(blocks_to_tune): # _ is the block name e.g. '3'
params.append({
"params": block.parameters(),
"lr": lr * (self.lr_decay**i)
})

logging.debug(params)

# optionally, for the remaining layers (not otherwise finetuned) you can choose to still FT the batchnorm layers
for i, (_, block) in enumerate(remaining_blocks):
if self.always_train_batchnorm:
params.append({
"params": layer.parameters(),
"params": get_batch_norm_params_lighting(block),
"lr": lr * (self.lr_decay**i)
})

# Initialize AdamW optimizer
opt = torch.optim.AdamW(
params, weight_decay=self.weight_decay, betas=(0.9, 0.999)) # higher weight decay is typically good
# TODO this actually breaks training because the generator only iterates once!
# total_params = sum(p.numel() for param_set in params.copy() for p in param_set['params'])
# logging.info('Total params to fit: {}'.format(total_params))

# Initialize AdamW optimizer
opt = torch.optim.AdamW(params, weight_decay=self.weight_decay) # lr included in params dict

return opt
return opt


def forward(self, x: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -501,6 +511,16 @@ def is_tuneable(block_of_layers):
else:
# currently, allowed to include batchnorm
return True

def get_batch_norm_params_lighting(parent_module, current_params=[]):
for child_module in parent_module.children():
if isinstance(child_module, torch.nn.BatchNorm2d):
current_params += child_module.parameters()
else:
current_params = get_batch_norm_params_lighting(child_module, current_params)
return current_params



# when ready (don't peek often, you'll overfit)
# trainer.test(model, dataloaders=datamodule)
Expand Down

0 comments on commit 8a1e2ff

Please sign in to comment.