Fixed unfused attn2d scale
laclouis5 committed Jan 1, 2025
1 parent 851e074 commit b0c47c5
Showing 2 changed files with 30 additions and 5 deletions.
27 changes: 26 additions & 1 deletion tests/test_layers.py
@@ -1,7 +1,8 @@
 import pytest
 import torch
+import torch.nn as nn
 
-from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn
+from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d
 
 import importlib
 import os
@@ -119,3 +120,27 @@ def test_get_act_fn_none():
     assert get_act_fn(None) is None
     assert get_act_fn('') is None
 
+
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("expand_first", [True, False])
+@pytest.mark.parametrize("head_first", [True, False])
+@pytest.mark.parametrize("attn_mask", [True, False])
+def test_attn2d(bias, expand_first, head_first, attn_mask):
+    x = torch.randn(1, 128, 32, 48)
+    attn = Attention2d(
+        128, 128, num_heads=4, bias=bias, expand_first=expand_first, head_first=head_first
+    )
+
+    if attn_mask:
+        mask = torch.randint(0, 1, size=(32 * 48, 32 * 48), dtype=torch.float32)
+    else:
+        mask = None
+
+    o1 = attn(x, mask)
+    attn.fused_attn = False
+    o2 = attn(x, mask)
+
+    assert torch.allclose(o1, o2, atol=1e-5), f"{torch.abs(o1 - o2).max()}"
+
+
+
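
The parity check at the heart of the new test can also be run standalone. A minimal sketch (not part of this commit; it assumes timm is installed and exposes Attention2d from timm.layers, as the import above implies):

import torch
from timm.layers import Attention2d

x = torch.randn(1, 128, 32, 48)    # NCHW feature map, same shape as in the test
attn = Attention2d(128, 128, num_heads=4)

with torch.no_grad():
    o1 = attn(x)                   # fused scaled_dot_product_attention path (if available)
    attn.fused_attn = False
    o2 = attn(x)                   # manual (unfused) path whose scale this commit fixes

print(torch.abs(o1 - o2).max())    # expected to be tiny (~1e-6) after the fix
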
8 changes: 4 additions & 4 deletions timm/layers/attention2d.py
@@ -312,7 +312,6 @@ def __init__(
         self.num_heads = num_heads
         self.dim_head = dim_attn // num_heads
         self.head_first = head_first
-        self.scale = num_heads ** -0.5
         self.fused_attn = use_fused_attn()
 
         self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
@@ -337,14 +336,15 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
                 dropout_p=self.attn_drop.p if self.training else 0.,
             ).transpose(-1, -2).reshape(B, -1, H, W)
         else:
-            q = q * self.scale
-            attn = q.transpose(-2, -1) @ k
+            q = q.transpose(-1, -2)
+            v = v.transpose(-1, -2)
+            attn = q @ k * q.size(-1) ** -0.5
             if attn_mask is not None:
                 # NOTE: assumes mask is float and in correct shape
                 attn = attn + attn_mask
             attn = attn.softmax(dim=-1)
             attn = self.attn_drop(attn)
-            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
+            x = (attn @ v).transpose(-1, -2).reshape(B, -1, H, W)
 
         x = self.proj(x)
         x = self.proj_drop(x)
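
The bug being fixed: the unfused branch scaled queries by self.scale = num_heads ** -0.5, while the fused torch.nn.functional.scaled_dot_product_attention path scales logits by head_dim ** -0.5, so the two paths disagreed whenever num_heads != dim_head (in the test above, num_heads=4 vs dim_head=32). After this change the scale is derived from the head dimension, q.size(-1) once q is transposed to (B, num_heads, N, dim_head). A minimal sketch of that equivalence (illustrative only, not the timm code; it uses the plain (B, heads, N, head_dim) layout rather than the conv-style layout in the diff):

import torch
import torch.nn.functional as F

B, num_heads, head_dim, N = 1, 4, 32, 64
q = torch.randn(B, num_heads, N, head_dim)
k = torch.randn(B, num_heads, N, head_dim)
v = torch.randn(B, num_heads, N, head_dim)

fused = F.scaled_dot_product_attention(q, k, v)

# Unfused equivalent: logits are scaled by head_dim ** -0.5 (i.e. q.size(-1) ** -0.5),
# not by num_heads ** -0.5 as the old code did.
attn = (q @ k.transpose(-2, -1)) * q.size(-1) ** -0.5
unfused = attn.softmax(dim=-1) @ v

assert torch.allclose(fused, unfused, atol=1e-5)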
