Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MQA V2 #2388

Merged
merged 2 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
import torch.nn as nn

from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d
from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d, MultiQueryAttentionV2

import importlib
import os
Expand Down Expand Up @@ -121,6 +121,23 @@ def test_get_act_fn_none():
assert get_act_fn('') is None


@pytest.mark.parametrize("dim", [128])
@pytest.mark.parametrize("dim_out", [128, 256])
@pytest.mark.parametrize("use_m", [True, False])
def test_mqa_v2(dim, dim_out, use_m):
mqa = MultiQueryAttentionV2(dim, dim_out)

x = torch.randn(1, dim, 32, 48)
if use_m:
m = torch.randn(1, dim, 16, 24)
else:
m = None

y = mqa(x, m=m)

assert (y.shape) == (1, dim_out, 32, 48)


@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("expand_first", [True, False])
@pytest.mark.parametrize("head_first", [True, False])
Expand All @@ -141,6 +158,3 @@ def test_attn2d(bias, expand_first, head_first, attn_mask):
o2 = attn(x, mask)

assert torch.allclose(o1, o2, atol=1e-5), f"{torch.abs(o1 - o2).max()}"



10 changes: 5 additions & 5 deletions timm/layers/attention2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,24 +59,24 @@ def _reshape_input(self, t):

def forward(self, x, m: Optional[torch.Tensor] = None):
"""Run layer computation."""
s = x.shape
m = m or x
b, _, h, w = x.shape
m = m if m is not None else x

reshaped_x = self._reshape_input(x)
reshaped_m = self._reshape_input(m)

q = torch.einsum('bnd,hkd->bnhk', reshaped_x, self.query_proj)
k = torch.einsum('bmd,dk->bmk', reshaped_m, self.key_proj)

attn = torch.einsum('bnhk,bmk->bnhm', q, k)
attn = torch.einsum('bnhk,bmk->bnhm', q, k) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)

v = torch.einsum('bmd,dv->bmv', reshaped_m, self.value_proj)
o = torch.einsum('bnhm,bmv->bnhv', attn, v)
result = torch.einsum('bnhv,dhv->bnd', o, self.out_proj)
result = torch.einsum('bnhv,dhv->bdn', o, self.out_proj)
result = self.proj_drop(result)
return result.reshape(s)
return result.reshape(b, -1, h, w)


class MultiQueryAttention2d(nn.Module):
Expand Down
Loading