diff --git a/alphafold3_pytorch/attention.py b/alphafold3_pytorch/attention.py
index 73e49714..3b4eb725 100644
--- a/alphafold3_pytorch/attention.py
+++ b/alphafold3_pytorch/attention.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 from beartype.typing import NamedTuple, Tuple
+from functools import partial
 
 import torch
 from torch import nn, Tensor
@@ -18,6 +19,10 @@
     typecheck
 )
 
+# alias
+
+LinearNoBias = partial(nn.Linear, bias = False)
+
 # helpers
 
 def exists(val):
@@ -178,7 +183,6 @@ def __init__(
         num_memory_kv: int = 0,
         enable_attn_softclamp = False,
         attn_softclamp_value = 50.,
-        init_gate_bias = -2.,
         softmax_full_precision = False
     ):
         super().__init__()
@@ -209,8 +213,8 @@ def __init__(
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
         self.to_q = nn.Linear(dim, dim_inner, bias = query_bias)
-        self.to_kv = nn.Linear(dim, dim_inner * 2, bias = False)
-        self.to_out = nn.Linear(dim_inner, dim, bias = False)
+        self.to_kv = LinearNoBias(dim, dim_inner * 2)
+        self.to_out = LinearNoBias(dim_inner, dim)
 
         self.memory_kv = None
 
@@ -224,11 +228,7 @@ def __init__(
         self.to_gates = None
 
         if gate_output:
-            gate_linear = nn.Linear(dim, dim_inner)
-            nn.init.zeros_(gate_linear.weight)
-            nn.init.constant_(gate_linear.bias, init_gate_bias)
-
-            self.to_gates = gate_linear
+            self.to_gates = nn.Sequential(LinearNoBias(dim, dim_inner), nn.Sigmoid())
 
     @typecheck
     def forward(
@@ -266,7 +266,7 @@ def forward(
 
         if exists(self.to_gates):
             gates = self.to_gates(seq)
-            out = out * gates.sigmoid()
+            out = out * gates
 
         # combine heads
 
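
Illustrative sketch (not part of the patch): a minimal, self-contained example of what the gating path computes after this change, with the sigmoid folded into the gate module so the forward pass only multiplies. The dims and tensor shapes below are assumed for demonstration only.

import torch
from torch import nn
from functools import partial

# same alias the patch introduces
LinearNoBias = partial(nn.Linear, bias = False)

dim, dim_inner = 64, 64                                               # assumed example sizes
to_gates = nn.Sequential(LinearNoBias(dim, dim_inner), nn.Sigmoid())  # new gate: bias-free linear + sigmoid in one module

seq = torch.randn(2, 16, dim)        # (batch, length, dim) input sequence
out = torch.randn(2, 16, dim_inner)  # stand-in for the attention output before gating

gates = to_gates(seq)                # already in (0, 1); no separate .sigmoid() call needed
out = out * gates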