Skip to content

Commit

Permalink
Fix classifier input dim for mnv3 after last changes
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Jun 7, 2024
1 parent a5a2ad2 commit 5ee0676
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion timm/models/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def __init__(
self.norm_head = nn.Identity()
self.act2 = act_layer(inplace=True)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()

efficientnet_init_weights(self)

Expand Down

0 comments on commit 5ee0676

Please sign in to comment.