diff --git a/zoobot/pytorch/training/finetune.py b/zoobot/pytorch/training/finetune.py
index 7902db7e..4b461555 100644
--- a/zoobot/pytorch/training/finetune.py
+++ b/zoobot/pytorch/training/finetune.py
@@ -159,7 +159,7 @@ def configure_optimizers(self):
 
         logging.info(f'Encoder architecture to finetune: {type(self.encoder)}')
 
-        if isinstance(self.encoder, timm.models.EfficientNet):
+        if isinstance(self.encoder, timm.models.EfficientNet): # includes v2
             # TODO for now, these count as separate layers, not ideal
             early_tuneable_layers = [self.encoder.conv_stem, self.encoder.bn1]
             encoder_blocks = list(self.encoder.blocks)
@@ -177,6 +177,9 @@ def configure_optimizers(self):
             ]
         elif isinstance(self.encoder, timm.models.MaxxVit):
             blocks_to_tune = [self.encoder.stem] + [stage for stage in self.encoder.stages]
+        elif isinstance(self.encoder, timm.models.ConvNeXt): # stem + 4 blocks, for all sizes
+            # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py#L264
+            blocks_to_tune = [self.encoder.stem] + [stage for stage in self.encoder.stages]
         else:
             raise ValueError(f'Encoder architecture not automatically recognised: {type(self.encoder)}')
 
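
A minimal sketch (not part of the patch) of the grouping the new ConvNeXt branch produces. It assumes timm is installed; the model name 'convnext_nano' is an illustrative choice, since timm's ConvNeXt class exposes a .stem module plus a .stages sequence for every size, mirroring the MaxxVit branch above.

    import timm

    # Any ConvNeXt variant works here; 'convnext_nano' is just an illustrative choice.
    # num_classes=0 drops the classification head, matching use as a bare encoder.
    encoder = timm.create_model('convnext_nano', pretrained=False, num_classes=0)
    assert isinstance(encoder, timm.models.ConvNeXt)

    # Same grouping as the patch: the stem plus each downsampling stage,
    # giving the coarse "blocks" that block-wise finetuning iterates over.
    blocks_to_tune = [encoder.stem] + [stage for stage in encoder.stages]
    for i, block in enumerate(blocks_to_tune):
        n_params = sum(p.numel() for p in block.parameters())
        print(f'block {i}: {type(block).__name__}, {n_params:,} params')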