From 5ee06760dc4a7ac2115e0127acf188a5bb7227db Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 5 Jun 2024 08:15:17 -0700 Subject: [PATCH] Fix classifier input dim for mnv3 after last changes --- timm/models/mobilenetv3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 9d17770661..358f15d1c1 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -134,7 +134,7 @@ def __init__( self.norm_head = nn.Identity() self.act2 = act_layer(inplace=True) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled - self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity() efficientnet_init_weights(self)