Skip to content

Commit

Permalink
Merge pull request #2369 from brianhou0208/fix_reduction
Browse files Browse the repository at this point in the history
Fix feature_info.reduction
  • Loading branch information
rwightman authored Dec 19, 2024
2 parents ea23107 + ab0a70d commit a02b1a8
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
14 changes: 7 additions & 7 deletions timm/models/davit.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,19 +556,19 @@ def __init__(

dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
stages = []
for stage_idx in range(num_stages):
out_chs = embed_dims[stage_idx]
for i in range(num_stages):
out_chs = embed_dims[i]
stage = DaVitStage(
in_chs,
out_chs,
depth=depths[stage_idx],
downsample=stage_idx > 0,
depth=depths[i],
downsample=i > 0,
attn_types=attn_types,
num_heads=num_heads[stage_idx],
num_heads=num_heads[i],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path_rates=dpr[stage_idx],
drop_path_rates=dpr[i],
norm_layer=norm_layer,
norm_layer_cl=norm_layer_cl,
ffn=ffn,
Expand All @@ -579,7 +579,7 @@ def __init__(
)
in_chs = out_chs
stages.append(stage)
self.feature_info += [dict(num_chs=out_chs, reduction=2, module=f'stages.{stage_idx}')]
self.feature_info += [dict(num_chs=out_chs, reduction=2**(i+2), module=f'stages.{i}')]

self.stages = nn.Sequential(*stages)

Expand Down
2 changes: 1 addition & 1 deletion timm/models/efficientformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def __init__(
)
prev_dim = embed_dims[i]
stages.append(stage)
self.feature_info += [dict(num_chs=embed_dims[i], reduction=2**(1+i), module=f'stages.{i}')]
self.feature_info += [dict(num_chs=embed_dims[i], reduction=2**(i+2), module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)

# Classifier head
Expand Down
2 changes: 1 addition & 1 deletion timm/models/metaformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ def __init__(
**kwargs,
)]
prev_dim = dims[i]
self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')]
self.feature_info += [dict(num_chs=dims[i], reduction=2**(i+2), module=f'stages.{i}')]

self.stages = nn.Sequential(*stages)

Expand Down

0 comments on commit a02b1a8

Please sign in to comment.