-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
132 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -167,4 +167,5 @@ hparams.yaml | |
|
||
data/pretrained_models | ||
|
||
*.tar | ||
*.tar | ||
*.ckpt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import pytest | ||
|
||
import timm | ||
import torch | ||
|
||
|
||
def test_get_encoder():
    # Pull the pretrained EfficientNet-B0 Zoobot encoder from the HF Hub and
    # confirm one fake galaxy image maps to a 1280-dim representation.
    encoder = timm.create_model(
        "hf_hub:mwalmsley/zoobot-encoder-efficientnet_b0", pretrained=True
    )
    dummy_batch = torch.rand(1, 3, 224, 224)
    representation = encoder(dummy_batch)
    assert representation.shape == (1, 1280)
|
||
|
||
def test_get_finetuned():
    # Loading straight from a URL hits a pickle problem via lightning
    # (see https://huggingface.co/mwalmsley/zoobot-finetuned-is_tidal/resolve/main/3.ckpt),
    # so download the checkpoint to local disk first and load it from there.
    from huggingface_hub import hf_hub_download

    REPO_ID = "mwalmsley/zoobot-finetuned-is_tidal"
    FILENAME = "4.ckpt"

    checkpoint_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

    from zoobot.pytorch.training import finetune
    # hub_name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano' is saved as an
    # hparam in the checkpoint, so it need not be passed here.
    classifier = finetune.FinetuneableZoobotClassifier.load_from_checkpoint(
        checkpoint_path, map_location='cpu'
    )
    # binary is-tidal classifier: one fake image in, two class scores out
    assert classifier(torch.rand(1, 3, 224, 224)).shape == (1, 2)
|
||
|
||
|
||
# def test_get_finetuned_from_local(): | ||
# # checkpoint_loc = '/home/walml/repos/zoobot/tests/convnext_nano_finetuned_linear_is-lsb.ckpt' | ||
# checkpoint_loc = '/home/walml/repos/zoobot-foundation/results/finetune/is-lsb/debug/checkpoints/4.ckpt' | ||
|
||
# from zoobot.pytorch.training import finetune | ||
# # if originally trained with a direct in-memory checkpoint, must specify the hub name manually. otherwise it's saved as an hparam. | ||
# model = finetune.FinetuneableZoobotClassifier.load_from_checkpoint(checkpoint_loc, map_location='cpu') # hub_name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ) | ||
# assert model(torch.rand(1, 3, 224, 224)).shape == (1, 2) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,53 @@ | ||
import logging | ||
import pytorch_lightning as pl | ||
|
||
from timm import create_model | ||
|
||
|
||
class ZoobotEncoder(pl.LightningModule):
    """Very simple wrapper to turn a plain PyTorch encoder into a LightningModule.

    Useful when we want to use Lightning machinery (e.g. Trainer.predict) to make
    predictions with the encoder, i.e. to get representations.

    NOTE(review): the scraped diff interleaved the pre-commit constructor
    (``__init__(self, encoder, pyramid=False)`` plus a ``pyramid`` branch) with the
    post-commit one; this body is the reconstructed post-commit class with the
    required ``super().__init__()`` call retained.
    """

    def __init__(self, encoder) -> None:
        """
        Args:
            encoder: plain pytorch module, e.g. Sequential, mapping images to
                representations.
        """
        # LightningModule bookkeeping requires the parent constructor to run first
        super().__init__()
        logging.info('ZoobotEncoder: using provided in-memory encoder')
        self.encoder = encoder  # plain pytorch module e.g. Sequential

    def forward(self, x):
        """Apply the wrapped encoder; unwraps single-element list batches first."""
        # some dataloaders yield [tensor] rather than tensor — unwrap and recurse
        if isinstance(x, list) and len(x) == 1:
            return self(x[0])
        return self.encoder(x)

    @classmethod
    def load_from_name(cls, name: str):
        """
        e.g. ZoobotEncoder.load_from_name('hf_hub:mwalmsley/zoobot-encoder-convnext_nano')

        Args:
            name (str): huggingface hub name to load

        Returns:
            ZoobotEncoder: wrapper around the downloaded timm model
        """
        timm_model = create_model(name)
        return cls(timm_model)
|
||
|
||
|
||
|
||
|
||
# NOTE(review): the lines below are the PRE-commit version of ZoobotEncoder as
# rendered by the diff scraper (left/deleted column). The assignment at
# `self.encoder =` is truncated mid-statement and the trailing ` | ||` markers
# are diff-table artifacts, so this duplicate class is not valid Python.
# Kept byte-identical here; the surviving implementation is the class above.
class ZoobotEncoder(pl.LightningModule): | ||
# very simple wrapper to turn pytorch model into lightning module | ||
# useful when we want to use lightning to make predictions with our encoder | ||
# (i.e. to get representations) | ||
 | ||
# pretrained_cfg, pretrained_cfg_overlay=timm_kwargs | ||
def __init__(self, architecture_name=None, channels=None, timm_kwargs={}) -> None: | ||
super().__init__() | ||
 | ||
logging.info('ZoobotEncoder: using timm encoder') | ||
self.encoder = | ||
 | ||
# if pyramid: | ||
# raise NotImplementedError('Will eventually support resetting timm classifier to get FPN features') | ||
 | ||
 | ||
# def save_timm_encoder(): | ||