just tempt some student into trying it

lucidrains committed Dec 2, 2024
1 parent 0791dfe commit 4dc99bd

Showing 3 changed files with 29 additions and 1 deletion.
9 changes: 9 additions & 0 deletions README.md
@@ -514,3 +514,12 @@ docker run -v .:/data --gpus all -it af3
    url = {https://api.semanticscholar.org/CorpusID:273532030}
}
```

```bibtex
@inproceedings{Duvvuri2024LASERAW,
    title = {LASER: Attention with Exponential Transformation},
    author = {Sai Surya Duvvuri and Inderjit S. Dhillon},
    year = {2024},
    url = {https://api.semanticscholar.org/CorpusID:273849947}
}
```
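
For reference, the exponential transformation from the LASER paper cited above, as wired into `alphafold3_pytorch/attention.py` in the diff below, runs ordinary softmax attention over elementwise-exponentiated values and maps the result back with a log (the implementation additionally shifts by the per-channel max of the values for numerical stability). A rough sketch, assuming scaled dot-product attention inside `Attention.attend`:

```latex
\mathrm{LaserAttention}(Q, K, V) = \log\Big(\mathrm{softmax}\big(Q K^\top / \sqrt{d}\big)\, \exp(V)\Big)
```

with `exp` and `log` applied elementwise.
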
19 changes: 19 additions & 0 deletions alphafold3_pytorch/attention.py
@@ -40,6 +40,9 @@ def pack_one(t, pattern):
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

def log(t, eps = 1e-20):
    return t.clamp(min = eps).log()

def softclamp(t, value):
    return (t / value).tanh() * value

@@ -181,6 +184,7 @@ def __init__(
        query_bias = True,
        window_size = None,
        num_memory_kv: int = 0,
        laser = False,
        enable_attn_softclamp = False,
        attn_softclamp_value = 50.,
        softmax_full_precision = False
@@ -222,6 +226,10 @@ def __init__(
            self.memory_kv = nn.Parameter(torch.zeros(2, heads, num_memory_kv, dim_head))
            nn.init.normal_(self.memory_kv, std = 0.02)

        # laser attention

        self.laser = laser

        # gating of value
        # allows attention to attend to nothing

@@ -262,6 +270,12 @@ def forward(

        q, k, v = tuple(self.split_heads(t) for t in (q, k, v))

        # maybe laser

        if self.laser:
            v_max = v.amax(dim = -2, keepdim = True)
            v = (v - v_max).exp()

        # attention

        out = self.attend(
@@ -272,6 +286,11 @@ def forward(
            memory_kv = self.memory_kv
        )

        # maybe laser

        if self.laser:
            out = log(out) + v_max

        # merge heads

        out = self.merge_heads(out)
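
Below is a minimal, self-contained sketch (not part of this commit) of what the `laser` branches above do: exponentiate the values with a max-shift, take the usual attention-weighted sum, then log the result and undo the shift, which is equivalent to an attention-weighted log-sum-exp over the values. The helper name `laser_weighted_values` and the toy shapes are illustrative assumptions, not repository API.

```python
import torch

def laser_weighted_values(attn, v, eps = 1e-20):
    # attn: (..., i, j) softmax attention weights, v: (..., j, d) values
    v_max = v.amax(dim = -2, keepdim = True)    # shift values for numerical stability
    out = attn @ (v - v_max).exp()              # attention-weighted sum of exponentiated values
    return out.clamp(min = eps).log() + v_max   # map back to value space, undoing the shift

# toy check against a direct attention-weighted log-sum-exp
attn = torch.randn(2, 4, 8, 8).softmax(dim = -1)
v = torch.randn(2, 4, 8, 16)

direct = torch.logsumexp(attn.log().unsqueeze(-1) + v.unsqueeze(-3), dim = -2)
assert torch.allclose(laser_weighted_values(attn, v), direct, atol = 1e-5)
```

With the `Attention` module patched as above, this path is switched on by constructing it with `laser = True`.
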
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "alphafold3-pytorch"
version = "0.6.8"
version = "0.6.9"
description = "Alphafold 3 - Pytorch"
authors = [
    { name = "Phil Wang", email = "[email protected]" },