Add global_pool to mambaout __init__ and pass to heads
rwightman committed Sep 14, 2024
1 parent b0cfd9d commit 02d29de
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions timm/models/mambaout.py
@@ -300,6 +300,7 @@ def __init__(
             self,
             in_chans=3,
             num_classes=1000,
+            global_pool='avg',
             depths=(3, 3, 9, 3),
             dims=(96, 192, 384, 576),
             norm_layer=LayerNorm,
@@ -369,7 +370,7 @@ def __init__(
             self.head = MlpHead(
                 prev_dim,
                 num_classes,
-                pool_type='avg',
+                pool_type=global_pool,
                 drop_rate=drop_rate,
                 norm_layer=norm_layer,
             )
@@ -379,7 +380,7 @@ def __init__(
                 prev_dim,
                 num_classes,
                 hidden_size=int(prev_dim * 4),
-                pool_type='avg',
+                pool_type=global_pool,
                 norm_layer=norm_layer,
                 drop_rate=drop_rate,
             )
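For context, the change means the pooling mode is no longer hard-coded to 'avg' in the heads and can instead be set when constructing the model. A minimal usage sketch, assuming 'mambaout_base' is a registered MambaOut variant and that timm.create_model forwards extra kwargs such as global_pool to the model constructor (which pooling values are accepted depends on the head implementation):

import torch
import timm

# Sketch only: build a MambaOut model and pass the new global_pool argument
# explicitly (here the default 'avg'; other values depend on what the
# MlpHead / NormMlpClassifierHead heads accept).
model = timm.create_model(
    'mambaout_base',      # assumption: registered MambaOut model name
    pretrained=False,
    num_classes=10,
    global_pool='avg',    # constructor argument added by this commit
)
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # expected: torch.Size([1, 10])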

0 comments on commit 02d29de
