From 02d29debda1fa6e51c9b85b44d7a13614d9396c2 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 13 Sep 2024 19:51:33 -0700 Subject: [PATCH] Add global_pool to mambaout __init__ and pass to heads --- timm/models/mambaout.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/timm/models/mambaout.py b/timm/models/mambaout.py index c2f2f07b46..bda69b1124 100644 --- a/timm/models/mambaout.py +++ b/timm/models/mambaout.py @@ -300,6 +300,7 @@ def __init__( self, in_chans=3, num_classes=1000, + global_pool='avg', depths=(3, 3, 9, 3), dims=(96, 192, 384, 576), norm_layer=LayerNorm, @@ -369,7 +370,7 @@ def __init__( self.head = MlpHead( prev_dim, num_classes, - pool_type='avg', + pool_type=global_pool, drop_rate=drop_rate, norm_layer=norm_layer, ) @@ -379,7 +380,7 @@ def __init__( prev_dim, num_classes, hidden_size=int(prev_dim * 4), - pool_type='avg', + pool_type=global_pool, norm_layer=norm_layer, drop_rate=drop_rate, )