From b2f8e395926348574eb1c1595f05c88ba058042d Mon Sep 17 00:00:00 2001 From: Mike Walmsley Date: Tue, 27 Feb 2024 16:35:22 -0500 Subject: [PATCH] add convnext support --- zoobot/pytorch/training/finetune.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/zoobot/pytorch/training/finetune.py b/zoobot/pytorch/training/finetune.py index 7902db7e..4b461555 100644 --- a/zoobot/pytorch/training/finetune.py +++ b/zoobot/pytorch/training/finetune.py @@ -159,7 +159,7 @@ def configure_optimizers(self): logging.info(f'Encoder architecture to finetune: {type(self.encoder)}') - if isinstance(self.encoder, timm.models.EfficientNet): + if isinstance(self.encoder, timm.models.EfficientNet): # includes v2 # TODO for now, these count as separate layers, not ideal early_tuneable_layers = [self.encoder.conv_stem, self.encoder.bn1] encoder_blocks = list(self.encoder.blocks) @@ -177,6 +177,9 @@ def configure_optimizers(self): ] elif isinstance(self.encoder, timm.models.MaxxVit): blocks_to_tune = [self.encoder.stem] + [stage for stage in self.encoder.stages] + elif isinstance(self.encoder, timm.models.ConvNeXt): # stem + 4 stages, for all sizes + # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py#L264 + blocks_to_tune = [self.encoder.stem] + [stage for stage in self.encoder.stages] else: raise ValueError(f'Encoder architecture not automatically recognised: {type(self.encoder)}')