From ab0a70dfff5cf5f50c98c9e37b9e2bd562687856 Mon Sep 17 00:00:00 2001
From: Ryan <23580140+brianhou0208@users.noreply.github.com>
Date: Wed, 18 Dec 2024 21:12:40 +0800
Subject: [PATCH] fix feature_info.reduction

---
 timm/models/davit.py           | 14 +++++++-------
 timm/models/efficientformer.py |  2 +-
 timm/models/metaformer.py      |  2 +-
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/timm/models/davit.py b/timm/models/davit.py
index 5bda661061..b58bbbbfbe 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -556,19 +556,19 @@ def __init__(
         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
 
         stages = []
-        for stage_idx in range(num_stages):
-            out_chs = embed_dims[stage_idx]
+        for i in range(num_stages):
+            out_chs = embed_dims[i]
             stage = DaVitStage(
                 in_chs,
                 out_chs,
-                depth=depths[stage_idx],
-                downsample=stage_idx > 0,
+                depth=depths[i],
+                downsample=i > 0,
                 attn_types=attn_types,
-                num_heads=num_heads[stage_idx],
+                num_heads=num_heads[i],
                 window_size=window_size,
                 mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias,
-                drop_path_rates=dpr[stage_idx],
+                drop_path_rates=dpr[i],
                 norm_layer=norm_layer,
                 norm_layer_cl=norm_layer_cl,
                 ffn=ffn,
@@ -579,7 +579,7 @@ def __init__(
             )
             in_chs = out_chs
             stages.append(stage)
-            self.feature_info += [dict(num_chs=out_chs, reduction=2, module=f'stages.{stage_idx}')]
+            self.feature_info += [dict(num_chs=out_chs, reduction=2**(i+2), module=f'stages.{i}')]
 
         self.stages = nn.Sequential(*stages)
 
diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py
index a759dc41ed..669062d365 100644
--- a/timm/models/efficientformer.py
+++ b/timm/models/efficientformer.py
@@ -407,7 +407,7 @@ def __init__(
             )
             prev_dim = embed_dims[i]
             stages.append(stage)
-            self.feature_info += [dict(num_chs=embed_dims[i], reduction=2**(1+i), module=f'stages.{i}')]
+            self.feature_info += [dict(num_chs=embed_dims[i], reduction=2**(i+2), module=f'stages.{i}')]
         self.stages = nn.Sequential(*stages)
 
         # Classifier head
diff --git a/timm/models/metaformer.py b/timm/models/metaformer.py
index 9a568ff075..be68b2ba5e 100644
--- a/timm/models/metaformer.py
+++ b/timm/models/metaformer.py
@@ -541,7 +541,7 @@ def __init__(
                 **kwargs,
             )]
             prev_dim = dims[i]
-            self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')]
+            self.feature_info += [dict(num_chs=dims[i], reduction=2**(i+2), module=f'stages.{i}')]
 
         self.stages = nn.Sequential(*stages)
 
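Note (not part of the patch): the corrected values follow from the model topology. In all three
models the stem downsamples the input by 4, and every stage after the first halves the resolution
once more, so stage i produces features at stride 2**(i+2) (4, 8, 16, 32 for the standard
four-stage layouts). Below is a minimal sketch of how the corrected metadata can be checked
against real output shapes; 'davit_tiny' is just one example of an affected model, and the
224x224 input size is an arbitrary choice:

    import torch
    import timm

    # Build a feature-extraction variant of one of the patched models.
    model = timm.create_model('davit_tiny', pretrained=False, features_only=True)

    x = torch.randn(1, 3, 224, 224)
    features = model(x)

    # With the fix, stage i reports reduction 2**(i+2), which should match the
    # actual spatial size of each returned feature map.
    for i, (feat, red) in enumerate(zip(features, model.feature_info.reduction())):
        assert red == 2 ** (i + 2)
        assert feat.shape[-1] == 224 // red, (feat.shape, red)
        print(f'stage {i}: num_chs={feat.shape[1]}, reduction={red}, spatial={tuple(feat.shape[-2:])}')

Before the patch, davit and metaformer reported a constant reduction of 2 for every stage, and
efficientformer was off by a factor of 2 (2**(1+i) instead of 2**(i+2)), so consumers of
feature_info (e.g. FPN-style necks that key on stride) would compute wrong feature pyramid shapes.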