From f746f7bebfed08c6e8ee5ce0474d3ea926d2477e Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sat, 25 May 2024 11:53:40 -0700
Subject: [PATCH] complete register tokens

---
 alphafold3_pytorch/alphafold3.py | 41 ++++++++++++++++++++++++++++----
 pyproject.toml                   |  2 +-
 tests/test_af3.py                |  5 +++-
 3 files changed, 42 insertions(+), 6 deletions(-)

diff --git a/alphafold3_pytorch/alphafold3.py b/alphafold3_pytorch/alphafold3.py
index 84628c84..f089e963 100644
--- a/alphafold3_pytorch/alphafold3.py
+++ b/alphafold3_pytorch/alphafold3.py
@@ -1323,6 +1323,7 @@ def __init__(
         dim_pairwise = 128,
         attn_window_size = None,
         attn_pair_bias_kwargs: dict = dict(),
+        num_register_tokens = 0,
         serial = False
     ):
         super().__init__()
@@ -1365,6 +1366,12 @@ def __init__(
 
         self.serial = serial
 
+        self.has_registers = num_register_tokens > 0
+        self.num_registers = num_register_tokens
+
+        if self.has_registers:
+            self.registers = nn.Parameter(torch.zeros(num_register_tokens, dim))
+
     @typecheck
     def forward(
         self,
@@ -1376,6 +1383,21 @@ def forward(
     ):
         serial = self.serial
 
+        # register tokens
+
+        if self.has_registers:
+            num_registers = self.num_registers
+            registers = repeat(self.registers, 'r d -> b r d', b = noised_repr.shape[0])
+            noised_repr, registers_ps = pack((registers, noised_repr), 'b * d')
+
+            single_repr = F.pad(single_repr, (0, 0, num_registers, 0), value = 0.)
+            pairwise_repr = F.pad(pairwise_repr, (0, 0, num_registers, 0, num_registers, 0), value = 0.)
+
+            if exists(mask):
+                mask = F.pad(mask, (num_registers, 0), value = True)
+
+        # main transformer
+
         for attn, transition in self.layers:
 
             attn_out = attn(
@@ -1398,6 +1420,11 @@ def forward(
 
             noised_repr = noised_repr + ff_out
 
+        # splice out registers
+
+        if self.has_registers:
+            _, noised_repr = unpack(noised_repr, registers_ps, 'b * d')
+
         return noised_repr
 
 class AtomToTokenPooler(Module):
@@ -1487,7 +1514,10 @@ def __init__(
         token_transformer_heads = 16,
         atom_decoder_depth = 3,
         atom_decoder_heads = 4,
-        serial = False
+        serial = False,
+        atom_encoder_kwargs: dict = dict(),
+        atom_decoder_kwargs: dict = dict(),
+        token_transformer_kwargs: dict = dict()
     ):
         super().__init__()
 
@@ -1543,7 +1573,8 @@ def __init__(
             attn_window_size = atoms_per_window,
             depth = atom_encoder_depth,
             heads = atom_encoder_heads,
-            serial = serial
+            serial = serial,
+            **atom_encoder_kwargs
         )
 
         self.atom_feats_to_pooled_token = AtomToTokenPooler(
@@ -1565,7 +1596,8 @@ def __init__(
             dim_pairwise = dim_pairwise,
             depth = token_transformer_depth,
             heads = token_transformer_heads,
-            serial = serial
+            serial = serial,
+            **token_transformer_kwargs
         )
 
         self.attended_token_norm = nn.LayerNorm(dim_token)
@@ -1581,7 +1613,8 @@ def __init__(
             attn_window_size = atoms_per_window,
             depth = atom_decoder_depth,
             heads = atom_decoder_heads,
-            serial = serial
+            serial = serial,
+            **atom_decoder_kwargs
        )
 
         self.atom_feat_to_atom_pos_update = nn.Sequential(
diff --git a/pyproject.toml b/pyproject.toml
index 1dce5d44..ef3a2bfa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.0.40"
+version = "0.0.41"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_af3.py b/tests/test_af3.py
index 253eaa9d..d47f6795 100644
--- a/tests/test_af3.py
+++ b/tests/test_af3.py
@@ -233,7 +233,10 @@ def test_diffusion_module():
         dim_pairwise_rel_pos_feats = 12,
         atom_encoder_depth = 1,
         atom_decoder_depth = 1,
-        token_transformer_depth = 1
+        token_transformer_depth = 1,
+        token_transformer_kwargs = dict(
+            num_register_tokens = 2
+        )
     )
 
     atom_pos_update = diffusion_module(
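
Note for reviewers: the register-token bookkeeping this patch adds to DiffusionTransformer.forward can be replayed in isolation. Below is a minimal sketch using only torch and einops; the tensor shapes and the names not taken from the patch (batch, seq_len, dim, dim_pairwise) are illustrative assumptions, not library API.

    # standalone sketch of the pack / pad / unpack flow added in this patch
    import torch
    import torch.nn.functional as F
    from einops import repeat, pack, unpack

    batch, seq_len, dim, dim_pairwise, num_registers = 2, 16, 64, 32, 4

    registers_param = torch.zeros(num_registers, dim)  # a learned nn.Parameter in the real module
    noised_repr   = torch.randn(batch, seq_len, dim)
    single_repr   = torch.randn(batch, seq_len, dim)
    pairwise_repr = torch.randn(batch, seq_len, seq_len, dim_pairwise)
    mask          = torch.ones(batch, seq_len, dtype = torch.bool)

    # prepend the registers to the token sequence, remembering the packed shapes
    registers = repeat(registers_param, 'r d -> b r d', b = batch)
    noised_repr, registers_ps = pack((registers, noised_repr), 'b * d')

    # zero-pad the conditioning tensors (and extend the mask) so all lengths stay aligned
    single_repr   = F.pad(single_repr, (0, 0, num_registers, 0), value = 0.)
    pairwise_repr = F.pad(pairwise_repr, (0, 0, num_registers, 0, num_registers, 0), value = 0.)
    mask          = F.pad(mask, (num_registers, 0), value = True)

    # ... the attention / transition layers of the transformer would run here ...

    # splice the registers back out, leaving only the real tokens
    _, noised_repr = unpack(noised_repr, registers_ps, 'b * d')
    assert noised_repr.shape == (batch, seq_len, dim)

As the updated test shows, callers opt in by passing token_transformer_kwargs = dict(num_register_tokens = 2) to DiffusionModule, which forwards those kwargs into the token transformer's constructor.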