Commit: make models portable
mwalmsley committed Mar 19, 2024
1 parent cad9786 commit 0fe5cec
Showing 6 changed files with 132 additions and 29 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -167,4 +167,5 @@ hparams.yaml

data/pretrained_models

*.tar
*.tar
*.ckpt
1 change: 1 addition & 0 deletions setup.py
@@ -112,6 +112,7 @@
'pyarrow', # to read parquet, which is very handy for big datasets
# for saving metrics to weights&biases (cloud service, free within limits)
'wandb',
'huggingface_hub', # login may be required
'setuptools', # no longer pinned
'galaxy-datasets>=0.0.15' # for dataset loading in both TF and Torch (see github/mwalmsley/galaxy-datasets)
]
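The new huggingface_hub dependency above notes that login may be required. A minimal sketch of authenticating before downloading private or gated checkpoints (public Zoobot weights download anonymously); the token is a placeholder generated in your Hub account settings, not something taken from this commit:

from huggingface_hub import login

# Prompts for a token interactively; pass token='hf_...' to skip the prompt.
# Only needed for private or gated repos.
login()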
38 changes: 38 additions & 0 deletions tests/test_from_hub.py
@@ -0,0 +1,38 @@
import pytest

import timm
import torch


def test_get_encoder():
model = timm.create_model("hf_hub:mwalmsley/zoobot-encoder-efficientnet_b0", pretrained=True)
assert model(torch.rand(1, 3, 224, 224)).shape == (1, 1280)


def test_get_finetuned():
# checkpoint_loc = 'https://huggingface.co/mwalmsley/zoobot-finetuned-is_tidal/resolve/main/3.ckpt' pickle problem via lightning
# checkpoint_loc = '/home/walml/Downloads/3.ckpt' # works when downloaded manually

from huggingface_hub import hf_hub_download

REPO_ID = "mwalmsley/zoobot-finetuned-is_tidal"
FILENAME = "4.ckpt"

downloaded_loc = hf_hub_download(
repo_id=REPO_ID,
filename=FILENAME,
)
from zoobot.pytorch.training import finetune
model = finetune.FinetuneableZoobotClassifier.load_from_checkpoint(downloaded_loc, map_location='cpu') # hub_name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano',
assert model(torch.rand(1, 3, 224, 224)).shape == (1, 2)



# def test_get_finetuned_from_local():
# # checkpoint_loc = '/home/walml/repos/zoobot/tests/convnext_nano_finetuned_linear_is-lsb.ckpt'
# checkpoint_loc = '/home/walml/repos/zoobot-foundation/results/finetune/is-lsb/debug/checkpoints/4.ckpt'

# from zoobot.pytorch.training import finetune
# # if originally trained with a direct in-memory checkpoint, must specify the hub name manually. otherwise it's saved as an hparam.
# model = finetune.FinetuneableZoobotClassifier.load_from_checkpoint(checkpoint_loc, map_location='cpu') # hub_name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', )
# assert model(torch.rand(1, 3, 224, 224)).shape == (1, 2)
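A usage sketch mirroring test_get_encoder above: pull the published encoder from the Hub via timm and extract representations outside a test. The 1280-dimensional output matches the efficientnet_b0 feature dimension asserted in the test; the random batch is a placeholder for real galaxy cutouts.

import timm
import torch

encoder = timm.create_model("hf_hub:mwalmsley/zoobot-encoder-efficientnet_b0", pretrained=True)
encoder.eval()  # inference mode

images = torch.rand(8, 3, 224, 224)  # placeholder batch; swap in real galaxy cutouts
with torch.no_grad():
    representations = encoder(images)

assert representations.shape == (8, 1280)  # pooled features, as in the test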
5 changes: 4 additions & 1 deletion zoobot/pytorch/estimators/define_model.py
@@ -480,4 +480,7 @@ def schema_to_campaigns(schema):
if __name__ == '__main__':
encoder = get_pytorch_encoder(channels=1)
dim = get_encoder_dim(encoder, channels=1)
print(dim)
print(dim)


ZoobotTree.load_from_checkpoint
64 changes: 44 additions & 20 deletions zoobot/pytorch/training/finetune.py
@@ -43,11 +43,17 @@ class FinetuneableZoobotAbstract(pl.LightningModule):
Both :class:`FinetuneableZoobotClassifier` and :class:`FinetuneableZoobotTree`
can (and should) be passed any of these arguments to customise finetuning.
You could subclass this class to solve new finetuning tasks (like regression) - see :ref:`advanced_finetuning`.
Any FinetuneableZoobot model can be loaded in one of three ways:
- HuggingFace name e.g. FinetuneableZoobotX(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...). Recommended.
- Any PyTorch model in memory e.g. FinetuneableZoobotX(encoder=some_model, ...)
- ZoobotTree checkpoint e.g. FinetuneableZoobotX(zoobot_checkpoint_loc='path/to/zoobot_tree.ckpt', ...)
You could subclass this class to solve new finetuning tasks - see :ref:`advanced_finetuning`.
Args:
checkpoint_loc (str, optional): Path to encoder checkpoint to load (likely a saved ZoobotTree). Defaults to None.
encoder (pl.LightningModule, optional): Alternatively, pass an encoder directly. Load with :func:`zoobot.pytorch.training.finetune.load_pretrained_encoder`.
name (str, optional): Name of a model on HuggingFace Hub, e.g. 'hf_hub:mwalmsley/zoobot-encoder-convnext_nano'. Defaults to None.
encoder (torch.nn.Module, optional): A PyTorch model already loaded in memory. Defaults to None.
zoobot_checkpoint_loc (str, optional): Path to a ZoobotTree lightning checkpoint to load. Loaded with :func:`zoobot.pytorch.training.finetune.load_pretrained_encoder`. Defaults to None.
encoder_dim (int, optional): Output dimension of encoder. Defaults to 1280 (EfficientNetB0's encoder dim).
lr_decay (float, optional): For each layer i below the head, reduce the learning rate by lr_decay ^ i. Defaults to 0.75.
weight_decay (float, optional): AdamW weight decay arg (i.e. L2 penalty). Defaults to 0.05.
@@ -61,26 +67,39 @@ class FinetuneableZoobotAbstract(pl.LightningModule):

def __init__(
self,
# can provide either zoobot_checkpoint_loc, and will load this model as encoder...
zoobot_checkpoint_loc=None,

# load a pretrained timm encoder saved on huggingface hub
# (aimed at most users, easiest way to load published models)
name=None,

# ...or directly pass any model to use as encoder (if you do this, you will need to keep it around for later)
encoder=None,
# (aimed at tinkering with new architectures e.g. SSL)
encoder=None, # use any torch model already loaded in memory (must have .forward() method)

# load a pretrained zoobottree model and grab the encoder (a timm model)
# requires the exact same zoobot version used for training, not very portable
# (aimed at supervised experiments)
zoobot_checkpoint_loc=None,

# finetuning settings
n_blocks=0, # how many layers deep to FT
lr_decay=0.75,
weight_decay=0.05,
learning_rate=1e-4, # 10x lower than typical, you may like to experiment
dropout_prob=0.5,
always_train_batchnorm=False, # temporarily deprecated
prog_bar=True,
visualize_images=False, # upload examples to wandb, good for debugging
seed=42,
n_layers=0, # for backward compat., n_blocks preferred
# these args are for the optional learning rate scheduler, best not to use unless you've tuned everything else already
cosine_schedule=False,
warmup_epochs=10,
max_cosine_epochs=100,
max_learning_rate_reduction_factor=0.01,
from_scratch=False
# escape hatch for 'from scratch' baselines
from_scratch=False,
# debugging utils
prog_bar=True,
visualize_images=False, # upload examples to wandb, good for debugging
seed=42
):
super().__init__()

@@ -95,17 +114,22 @@ def __init__(
self.save_hyperparameters(ignore=['encoder']) # never serialise the encoder, way too heavy
# if you need the encoder to recreate, pass when loading checkpoint e.g.
# FinetuneableZoobotTree.load_from_checkpoint(loc, encoder=encoder)

if zoobot_checkpoint_loc is not None:
assert encoder is None, 'Cannot pass both checkpoint to load and encoder to use'
self.encoder = load_pretrained_zoobot(zoobot_checkpoint_loc)

if name is not None:
assert encoder is None, 'Cannot pass both name and encoder to use'
self.encoder = timm.create_model(name, pretrained=True)
self.encoder_dim = self.encoder.num_features

elif zoobot_checkpoint_loc is not None:
assert encoder is None, 'Cannot pass both checkpoint to load and encoder to use'
self.encoder = load_pretrained_zoobot(zoobot_checkpoint_loc) # extracts the timm encoder
self.encoder_dim = self.encoder.num_features
else:
assert zoobot_checkpoint_loc is None, 'Cannot pass both checkpoint to load and encoder to use'
assert encoder is not None, 'Must pass either checkpoint to load or encoder to use'
self.encoder = encoder

# TODO read as encoder property
self.encoder_dim = define_model.get_encoder_dim(self.encoder)
assert zoobot_checkpoint_loc is None, 'Cannot pass both checkpoint to load and encoder to use'
assert encoder is not None, 'Must pass either checkpoint to load or encoder to use'
self.encoder = encoder
# work out encoder dim 'manually'
self.encoder_dim = define_model.get_encoder_dim(self.encoder)

# for backwards compat.
if n_layers:
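The updated docstring lists three ways to provide the encoder, and the rewritten __init__ above checks them in order: name first, then zoobot_checkpoint_loc, then a raw encoder. A sketch of the three call patterns; the num_classes keyword is assumed from the classifier subclass and is not part of this diff:

import timm
from zoobot.pytorch.training import finetune

# 1. Recommended: load a published encoder from the Hugging Face Hub by name
model = finetune.FinetuneableZoobotClassifier(
    name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', num_classes=2)

# 2. Pass any torch module already in memory (e.g. when experimenting with new architectures)
encoder = timm.create_model('hf_hub:mwalmsley/zoobot-encoder-convnext_nano', pretrained=True)
model = finetune.FinetuneableZoobotClassifier(encoder=encoder, num_classes=2)

# 3. Extract the encoder from a ZoobotTree lightning checkpoint
#    (requires the same zoobot version used for training, so less portable)
model = finetune.FinetuneableZoobotClassifier(
    zoobot_checkpoint_loc='path/to/zoobot_tree.ckpt', num_classes=2)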
50 changes: 43 additions & 7 deletions zoobot/pytorch/training/representations.py
@@ -1,17 +1,53 @@
import logging
import pytorch_lightning as pl

from timm import create_model


class ZoobotEncoder(pl.LightningModule):
# very simple wrapper to turn pytorch model into lightning module
# useful when we want to use lightning to make predictions with our encoder
# (i.e. to get representations)

def __init__(self, encoder, pyramid=False) -> None:
super().__init__()
def __init__(self, encoder):
logging.info('ZoobotEncoder: using provided in-memory encoder')
self.encoder = encoder # plain pytorch module e.g. Sequential
if pyramid:
raise NotImplementedError('Will eventually support resetting timm classifier to get FPN features')


def forward(self, x):
if isinstance(x, list) and len(x) == 1:
return self(x[0])
return self.encoder(x)

@classmethod
def load_from_name(cls, name: str):
"""
e.g. ZoobotEncoder.load_from_name('hf_hub:mwalmsley/zoobot-encoder-convnext_nano')
Args:
name (str): huggingface hub name to load
Returns:
nn.Module: timm model
"""
timm_model = create_model(name)
return cls(timm_model)





class ZoobotEncoder(pl.LightningModule):
# very simple wrapper to turn pytorch model into lightning module
# useful when we want to use lightning to make predictions with our encoder
# (i.e. to get representations)

# pretrained_cfg, pretrained_cfg_overlay=timm_kwargs
def __init__(self, architecture_name=None, channels=None, timm_kwargs={}) -> None:
super().__init__()

logging.info('ZoobotEncoder: using timm encoder')
self.encoder = create_model(architecture_name, in_chans=channels, **timm_kwargs)  # assumed completion via timm create_model; this line is left unfinished in the commit

# if pyramid:
# raise NotImplementedError('Will eventually support resetting timm classifier to get FPN features')


# def save_timm_encoder():
