Skip to content

Commit

Permalink
add convnext support
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalmsley committed Feb 27, 2024
1 parent 9eb0d89 commit b2f8e39
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion zoobot/pytorch/training/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def configure_optimizers(self):

logging.info(f'Encoder architecture to finetune: {type(self.encoder)}')

if isinstance(self.encoder, timm.models.EfficientNet):
if isinstance(self.encoder, timm.models.EfficientNet): # includes v2
# TODO for now, these count as separate layers, not ideal
early_tuneable_layers = [self.encoder.conv_stem, self.encoder.bn1]
encoder_blocks = list(self.encoder.blocks)
Expand All @@ -177,6 +177,9 @@ def configure_optimizers(self):
]
elif isinstance(self.encoder, timm.models.MaxxVit):
blocks_to_tune = [self.encoder.stem] + [stage for stage in self.encoder.stages]
elif isinstance(self.encoder, timm.models.ConvNeXt): # stem + 3 blocks, for all sizes
# https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py#L264
blocks_to_tune = [self.encoder.stem] + [stage for stage in self.encoder.stages]
else:
raise ValueError(f'Encoder architecture not automatically recognised: {type(self.encoder)}')

Expand Down

0 comments on commit b2f8e39

Please sign in to comment.