they removed all the bias from the triangle multiplicative modules
lucidrains committed May 25, 2024
1 parent 1e3c221 commit 27a7818
Showing 2 changed files with 12 additions and 21 deletions.
31 changes: 11 additions & 20 deletions alphafold3_pytorch/alphafold3.py
@@ -433,18 +433,14 @@ def __init__(
         dim_hidden = default(dim_hidden, dim)
         self.norm = nn.LayerNorm(dim)
 
-        self.left_proj = Linear(dim, dim_hidden)
-        self.right_proj = Linear(dim, dim_hidden)
-
-        self.left_gate = Linear(dim, dim_hidden)
-        self.right_gate = Linear(dim, dim_hidden)
-        self.out_gate = Linear(dim, dim_hidden)
+        self.left_right_proj = nn.Sequential(
+            LinearNoBias(dim, dim_hidden * 4),
+            nn.GLU(dim = -1)
+        )
 
-        # initialize all gating to be identity
+        self.left_right_gate = LinearNoBias(dim, dim_hidden * 2)
 
-        for gate in (self.left_gate, self.right_gate, self.out_gate):
-            nn.init.constant_(gate.weight, 0.)
-            nn.init.constant_(gate.bias, 1.)
+        self.out_gate = LinearNoBias(dim, dim_hidden)
 
         if mix == 'outgoing':
             self.mix_einsum_eq = '... i k d, ... j k d -> ... i j d'
@@ -454,7 +450,7 @@ def __init__(
         self.to_out_norm = nn.LayerNorm(dim_hidden)
 
         self.to_out = Sequential(
-            Linear(dim_hidden, dim),
+            LinearNoBias(dim_hidden, dim),
             Dropout(dropout, dropout_type = dropout_type)
         )
 
@@ -470,24 +466,19 @@ def forward(
 
         x = self.norm(x)
 
-        left = self.left_proj(x)
-        right = self.right_proj(x)
+        left, right = self.left_right_proj(x).chunk(2, dim = -1)
 
         if exists(mask):
             left = left * mask
             right = right * mask
 
-        left_gate = self.left_gate(x).sigmoid()
-        right_gate = self.right_gate(x).sigmoid()
-        out_gate = self.out_gate(x).sigmoid()
-
-        left = left * left_gate
-        right = right * right_gate
-
         out = einsum(left, right, self.mix_einsum_eq)
 
         out = self.to_out_norm(out)
 
+        out_gate = self.out_gate(x).sigmoid()
         out = out * out_gate
 
         return self.to_out(out)
 
 # there are two types of attention in this paper, triangle and attention-pair-bias
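To make the restructuring easier to follow, here is a self-contained sketch of the triangle multiplicative module as it stands after this commit. This is a reconstruction, not code from the repository: it assumes LinearNoBias is nn.Linear with bias = False (consistent with how it is used in the diff), uses plain torch.einsum (which takes the equation string first, unlike the repo's einsum), and omits the repo's Dropout wrapper and the left_right_gate registered above. The key idea is the fusion: nn.GLU splits its input in half along the given dim and returns first_half * sigmoid(second_half), so a single bias-free dim_hidden * 4 projection supplies the left/right values and their gates in one matmul.

import torch
from torch import nn

class TriangleMultiplicationSketch(nn.Module):
    def __init__(self, dim, dim_hidden = None, mix = 'outgoing'):
        super().__init__()
        dim_hidden = dim_hidden if dim_hidden is not None else dim
        self.norm = nn.LayerNorm(dim)

        # one bias-free matrix replaces the four separate projection and
        # gate Linears; GLU returns first_half * sigmoid(second_half)
        self.left_right_proj = nn.Sequential(
            nn.Linear(dim, dim_hidden * 4, bias = False),  # LinearNoBias
            nn.GLU(dim = -1)
        )

        self.out_gate = nn.Linear(dim, dim_hidden, bias = False)

        # 'outgoing' mixes rows i and j over their shared index k; 'ingoing' is the transpose
        if mix == 'outgoing':
            self.mix_einsum_eq = '... i k d, ... j k d -> ... i j d'
        else:
            self.mix_einsum_eq = '... k i d, ... k j d -> ... i j d'

        self.to_out_norm = nn.LayerNorm(dim_hidden)
        self.to_out = nn.Linear(dim_hidden, dim, bias = False)  # Dropout omitted in this sketch

    def forward(self, x, mask = None):
        # x: pairwise representation of shape (batch, i, j, dim)
        x = self.norm(x)

        # fused projection, then split into left and right halves
        left, right = self.left_right_proj(x).chunk(2, dim = -1)

        if mask is not None:  # mask: (batch, i, j, 1)
            left = left * mask
            right = right * mask

        # torch.einsum takes the equation first
        out = torch.einsum(self.mix_einsum_eq, left, right)

        out = self.to_out_norm(out)
        out = out * self.out_gate(x).sigmoid()  # output gate, also bias-free now
        return self.to_out(out)

x = torch.randn(1, 16, 16, 64)
print(TriangleMultiplicationSketch(64)(x).shape)  # torch.Size([1, 16, 16, 64])

One side effect worth noting: the deleted init loop had started every gate at identity (weight 0, bias 1), which is no longer expressible once the bias is gone, so under a zero-weight init the gates now open at sigmoid(0) = 0.5. Presumably that is the tradeoff accepted in exchange for matching the bias-free reference.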
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.0.37"
+version = "0.0.38"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
