From 64066cb155c9fc5fd7a18636d6afe060b265a027 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Thu, 3 Oct 2024 16:24:09 -0500
Subject: [PATCH 1/2] Update attention.py

---
 alphafold3_pytorch/attention.py | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/alphafold3_pytorch/attention.py b/alphafold3_pytorch/attention.py
index 73e49714..a3159b9f 100644
--- a/alphafold3_pytorch/attention.py
+++ b/alphafold3_pytorch/attention.py
@@ -178,7 +178,6 @@ def __init__(
         num_memory_kv: int = 0,
         enable_attn_softclamp = False,
         attn_softclamp_value = 50.,
-        init_gate_bias = -2.,
         softmax_full_precision = False
     ):
         super().__init__()
@@ -224,11 +223,7 @@ def __init__(
         self.to_gates = None
 
         if gate_output:
-            gate_linear = nn.Linear(dim, dim_inner)
-            nn.init.zeros_(gate_linear.weight)
-            nn.init.constant_(gate_linear.bias, init_gate_bias)
-
-            self.to_gates = gate_linear
+            self.to_gates = nn.Sequential(nn.Linear(dim, dim_inner, bias = False), nn.Sigmoid())
 
     @typecheck
     def forward(
@@ -266,7 +261,7 @@ def forward(
 
         if exists(self.to_gates):
             gates = self.to_gates(seq)
-            out = out * gates.sigmoid()
+            out = out * gates
 
         # combine heads
 

From 9b96216a054576f936986ed0f86ff4993ff92d37 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Thu, 3 Oct 2024 16:46:59 -0500
Subject: [PATCH 2/2] Update attention.py

---
 alphafold3_pytorch/attention.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/alphafold3_pytorch/attention.py b/alphafold3_pytorch/attention.py
index a3159b9f..3b4eb725 100644
--- a/alphafold3_pytorch/attention.py
+++ b/alphafold3_pytorch/attention.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 from beartype.typing import NamedTuple, Tuple
+from functools import partial
 
 import torch
 from torch import nn, Tensor
@@ -18,6 +19,10 @@
     typecheck
 )
 
+# alias
+
+LinearNoBias = partial(nn.Linear, bias = False)
+
 # helpers
 
 def exists(val):
@@ -208,8 +213,8 @@ def __init__(
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
         self.to_q = nn.Linear(dim, dim_inner, bias = query_bias)
-        self.to_kv = nn.Linear(dim, dim_inner * 2, bias = False)
-        self.to_out = nn.Linear(dim_inner, dim, bias = False)
+        self.to_kv = LinearNoBias(dim, dim_inner * 2)
+        self.to_out = LinearNoBias(dim_inner, dim)
 
         self.memory_kv = None
 
@@ -223,7 +228,7 @@ def __init__(
         self.to_gates = None
 
         if gate_output:
-            self.to_gates = nn.Sequential(nn.Linear(dim, dim_inner, bias = False), nn.Sigmoid())
+            self.to_gates = nn.Sequential(LinearNoBias(dim, dim_inner), nn.Sigmoid())
 
     @typecheck
     def forward(