Skip to content

Commit

Permalink
support greyscale models
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalmsley committed May 13, 2024
1 parent 8c71a87 commit 2db053c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
3 changes: 3 additions & 0 deletions docs/pretrained_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ Zoobot includes weights for the following pretrained models:

Missing a model you need? Reach out! There's a good chance we can train any model supported by `timm <https://github.com/huggingface/pytorch-image-models>`_.

.. note::

New in Zoobot v2.0.1: greyscale (single channel) versions are available `here <https://huggingface.co/collections/mwalmsley/zoobot-encoders-greyscale-66427c51133285ca01b490c6>`_.

Which model should I use?
===========================
Expand Down
9 changes: 8 additions & 1 deletion zoobot/pytorch/training/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,14 @@ def __init__(

if name is not None:
assert encoder is None, 'Cannot pass both name and encoder to use'
self.encoder = timm.create_model(name, num_classes=0, pretrained=True)
if 'greyscale' in name:
# I'm not sure why timm is happy to convert color model stem to greyscale
# but doesn't correctly load greyscale model without this hack
logging.info('Loading greyscale model (auto-detected from name)')
timm_kwargs = {'in_chans': 1}
else:
timm_kwargs = {}
self.encoder = timm.create_model(name, num_classes=0, pretrained=True, **timm_kwargs)
self.encoder_dim = self.encoder.num_features

elif zoobot_checkpoint_loc is not None:
Expand Down

0 comments on commit 2db053c

Please sign in to comment.