just tempt some student into trying laser attention
lucidrains committed Dec 2, 2024
1 parent 0791dfe · commit 34f6c59
Showing 3 changed files with 42 additions and 1 deletion.
9 changes: 9 additions & 0 deletions README.md
@@ -514,3 +514,12 @@ docker run -v .:/data --gpus all -it af3
url = {https://api.semanticscholar.org/CorpusID:273532030}
}
```

```bibtex
@inproceedings{Duvvuri2024LASERAW,
title = {LASER: Attention with Exponential Transformation},
author = {Sai Surya Duvvuri and Inderjit S. Dhillon},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273849947}
}
```
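
For context on the method cited above: LASER attention (Duvvuri & Dhillon, 2024) applies the softmax attention weights to exponentially transformed values and takes the logarithm of the aggregate, i.e. a log-sum-exp weighted aggregation of the values, which the paper argues improves gradient flow through the attention softmax. Below is a minimal standalone sketch of that idea, not the repository's exact code; the `laser_attention` name and tensor shapes are illustrative.

```python
import torch
from einops import einsum

def laser_attention(q, k, v, scale = None, eps = 1e-20):
    # scaled dot-product attention weights, as usual
    scale = scale if scale is not None else q.shape[-1] ** -0.5
    sim = einsum(q, k, 'b h i d, b h j d -> b h i j') * scale
    attn = sim.softmax(dim = -1)

    # exponentiate the values, subtracting the per-key max for numerical stability
    v_max = v.amax(dim = -2, keepdim = True)
    v_exp = (v - v_max).exp()

    # aggregate in exp-space, then return to log-space and add the max back
    out = einsum(attn, v_exp, 'b h i j, b h j d -> b h i d')
    return out.clamp(min = eps).log() + v_max

q = k = v = torch.randn(1, 8, 16, 32)
out = laser_attention(q, k, v)  # (1, 8, 16, 32)
```

The max subtraction before the `exp` and its re-addition after the `log` is the standard log-sum-exp stabilization, mirroring the changes to `attention.py` below.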
32 changes: 32 additions & 0 deletions alphafold3_pytorch/attention.py
@@ -40,6 +40,9 @@ def pack_one(t, pattern):
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]

def log(t, eps = 1e-20):
return t.clamp(min = eps).log()

def softclamp(t, value):
return (t / value).tanh() * value

@@ -181,6 +184,7 @@ def __init__(
query_bias = True,
window_size = None,
num_memory_kv: int = 0,
laser = False,
enable_attn_softclamp = False,
attn_softclamp_value = 50.,
softmax_full_precision = False
@@ -202,6 +206,7 @@ def __init__(
dim_inner = dim_head * heads

self.attend = Attend(
laser = laser,
dropout = dropout,
window_size = window_size,
enable_attn_softclamp = enable_attn_softclamp,
@@ -299,6 +304,7 @@ class Attend(Module):
def __init__(
self,
dropout = 0.,
laser = False,
window_size = None,
scale: float | None = None,
enable_attn_softclamp = False,
@@ -327,6 +333,10 @@ def __init__(

self.attn_dropout = nn.Dropout(dropout)

# laser attention

self.laser = laser

# softclamp attention logits
# being adopted by a number of recent llms (gemma, grok)

@@ -447,10 +457,21 @@ def local_attn(

attn = sim.softmax(dim = -1)

# maybe laser

if self.laser:
v_max = v.amax(dim = -2, keepdim = True)
v = (v - v_max).exp()

# aggregate

out = einsum(attn, v, "... i j, ... j d -> ... i d")

# maybe laser

if self.laser:
out = log(out) + v_max

# un-window the output

out = rearrange(out, "b h n w d -> b h (n w) d")
@@ -546,8 +567,19 @@ def forward(

attn = self.attn_dropout(attn)

# maybe laser

if self.laser:
v_max = v.amax(dim = -2, keepdim = True)
v = (v - v_max).exp()

# aggregate values

out = einsum(attn, v, "b h i j, b h j d -> b h i d")

# maybe laser

if self.laser:
out = log(out) + v_max

return out
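
A hedged usage sketch for the new flag: `laser` is threaded from the outer attention module's `__init__` into `Attend`, so the behavior is chosen per layer at construction time. The import of an `Attention` class, the `dim`, `dim_head`, and `heads` arguments, and the bare forward call below are assumptions about parts of the module not shown in this diff.

```python
import torch
from alphafold3_pytorch.attention import Attention

# `laser = True` is the flag added in this commit; the remaining constructor
# arguments and the forward signature are assumed, not shown in the diff above
attn = Attention(
    dim = 64,
    dim_head = 32,
    heads = 8,
    laser = True
)

seq = torch.randn(2, 16, 64)
out = attn(seq)  # expected shape: (2, 16, 64)
```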
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "alphafold3-pytorch"
version = "0.6.8"
version = "0.6.10"
description = "Alphafold 3 - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" },
