diff --git a/alphafold3_pytorch/attention.py b/alphafold3_pytorch/attention.py
index fcc8be48..710f0ebb 100644
--- a/alphafold3_pytorch/attention.py
+++ b/alphafold3_pytorch/attention.py
@@ -32,6 +32,7 @@ def max_neg_value(t):
 # multi-head attention
 
 class Attention(Module):
+    @typecheck
     def __init__(
         self,
         *,
@@ -42,7 +43,9 @@ def __init__(
         gate_output = True,
         query_bias = True,
         flash = True,
-        efficient_attn_config: Config = Config(True, True, True)
+        efficient_attn_config: Config = Config(True, True, True),
+        dim_pairwise_repr: int | None = None,
+        max_seq_len: int = 8192
     ):
         super().__init__()
         """
@@ -52,6 +55,7 @@ def __init__(
         h - heads
         n - sequence
         d - dimension
+        e - dimension (pairwise rep)
         i - source sequence
         j - context sequence
         """
@@ -71,7 +75,8 @@ def __init__(
         self.to_kv = nn.Linear(dim, dim_inner * 2, bias = False)
         self.to_out = nn.Linear(dim_inner, dim, bias = False)
 
-        # used in alphafold2
+        # gating of value
+        # allows attention to attend to nothing
 
         self.to_gates = None
 
@@ -82,13 +87,32 @@ def __init__(
 
             self.to_gates = gate_linear
 
+        # for projecting features to attn bias
+
+        self.accept_feature_to_bias_attn = exists(dim_pairwise_repr)
+
+        if self.accept_feature_to_bias_attn:
+            self.max_seq_len = max_seq_len
+
+            # line 8 of Algorithm 24
+
+            self.to_attn_bias = nn.Sequential(
+                nn.LayerNorm(dim_pairwise_repr),
+                nn.Linear(dim_pairwise_repr, heads, bias = False),
+                Rearrange('... i j h -> ... h i j')
+            )
+
+            self.attn_bias_bias = nn.Parameter(torch.zeros(max_seq_len, max_seq_len))
+
     @typecheck
     def forward(
         self,
         seq: Float['b i d'],
         mask: Bool['b n']| None = None,
-        attn_bias: Float['b h i j'] | None = None,
-        context: Float['b j d'] | None = None
+        context: Float['b j d'] | None = None,
+        attn_bias: Float['... i j'] | None = None,
+        input_to_bias_attn: Float['... i j e'] | None = None,
+
     ) -> Float['b i d']:
 
         q = self.to_q(seq)
@@ -98,18 +122,38 @@ def forward(
 
         q, k, v = tuple(self.split_heads(t) for t in (q, k, v))
 
+        # inputs to project into attn bias - for alphafold3, pairwise rep
+
+        assert not (exists(input_to_bias_attn) ^ self.accept_feature_to_bias_attn), 'if passing in pairwise representation, must set dim_pairwise_repr on Attention.__init__'
+
+        if self.accept_feature_to_bias_attn:
+            i, j = q.shape[-2], k.shape[-2]
+
+            assert not exists(attn_bias)
+            assert i <= self.max_seq_len and j <= self.max_seq_len
+
+            attn_bias = self.to_attn_bias(input_to_bias_attn) + self.attn_bias_bias[:i, :j]
+
+        # attention
+
         out = self.attend(
             q, k, v,
             attn_bias = attn_bias,
             mask = mask
         )
 
+        # merge heads
+
         out = self.merge_heads(out)
 
+        # gate output
+
         if exists(self.to_gates):
             gates = self.to_gates(seq)
             out = out * gates.sigmoid()
 
+        # combine heads
+
         return self.to_out(out)
 
 # attending, both vanilla as well as in-built flash attention
@@ -124,12 +168,15 @@ def __init__(
     ):
         super().__init__()
         """
-        ein notation
+        ein notation:
 
         b - batch
         h - heads
+        n - sequence
         d - dimension
-        n, i, j - sequence (base sequence length, source, target)
+        e - dimension (pairwise rep)
+        i - source sequence
+        j - context sequence
         """
 
         self.scale = scale
@@ -171,8 +218,8 @@ def forward(
         q: Float['b h i d'],
         k: Float['b h j d'],
         v: Float['b h j d'],
-        attn_bias: Float['b h i j'] | None = None,
-        mask: Bool['b j'] | None = None
+        mask: Bool['b j'] | None = None,
+        attn_bias: Float['... i j'] | None = None,
     ) -> Float['b h i d']:
 
         can_use_flash = self.flash and not exists(attn_bias) # flash attention does not support attention bias with gradients
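
For context, here is a minimal usage sketch of the pairwise-bias path added above. Only `dim_pairwise_repr`, `max_seq_len`, and the `input_to_bias_attn` keyword are confirmed by the diff itself; the example assumes `dim` and `heads` are the usual `Attention.__init__` keyword arguments, that the remaining arguments keep their defaults, and that the module is importable as `alphafold3_pytorch.attention`. Tensor values are random placeholders.

```python
import torch
from alphafold3_pytorch.attention import Attention

# dim_pairwise_repr enables the LayerNorm -> Linear -> Rearrange projection of the
# pairwise representation into a per-head attention bias (line 8 of Algorithm 24),
# plus the learned (max_seq_len, max_seq_len) attn_bias_bias table

attn = Attention(
    dim = 64,
    heads = 8,
    dim_pairwise_repr = 32,
    max_seq_len = 1024
)

batch, seq_len = 2, 128

seq = torch.randn(batch, seq_len, 64)                     # Float['b i d']
pairwise_repr = torch.randn(batch, seq_len, seq_len, 32)  # Float['... i j e']

# when dim_pairwise_repr is set, forward requires input_to_bias_attn
# (the assert in forward enforces this) and derives attn_bias from it

out = attn(seq, input_to_bias_attn = pairwise_repr)       # Float['b i d']
assert out.shape == seq.shape
```

Note that whenever an `attn_bias` is derived this way, `Attend.forward` falls back to the non-flash path, since the in-built flash attention does not support an attention bias with gradients.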