complete register tokens
lucidrains committed May 25, 2024
1 parent 936210c commit f746f7b
Showing 3 changed files with 42 additions and 6 deletions.
alphafold3_pytorch/alphafold3.py (37 additions & 4 deletions)
@@ -1323,6 +1323,7 @@ def __init__(
         dim_pairwise = 128,
         attn_window_size = None,
         attn_pair_bias_kwargs: dict = dict(),
+        num_register_tokens = 0,
         serial = False
     ):
         super().__init__()
@@ -1365,6 +1366,12 @@ def __init__(
 
         self.serial = serial
 
+        self.has_registers = num_register_tokens > 0
+        self.num_registers = num_register_tokens
+
+        if self.has_registers:
+            self.registers = nn.Parameter(torch.zeros(num_register_tokens, dim))
+
     @typecheck
     def forward(
         self,
@@ -1376,6 +1383,21 @@ def forward(
     ):
         serial = self.serial
 
+        # register tokens
+
+        if self.has_registers:
+            num_registers = self.num_registers
+            registers = repeat(self.registers, 'r d -> b r d', b = noised_repr.shape[0])
+            noised_repr, registers_ps = pack((registers, noised_repr), 'b * d')
+
+            single_repr = F.pad(single_repr, (0, 0, num_registers, 0), value = 0.)
+            pairwise_repr = F.pad(pairwise_repr, (0, 0, num_registers, 0, num_registers, 0), value = 0.)
+
+            if exists(mask):
+                mask = F.pad(mask, (num_registers, 0), value = True)
+
+        # main transformer
+
         for attn, transition in self.layers:
 
             attn_out = attn(
@@ -1398,6 +1420,11 @@
 
             noised_repr = noised_repr + ff_out
 
+        # splice out registers
+
+        if self.has_registers:
+            _, noised_repr = unpack(noised_repr, registers_ps, 'b * d')
+
         return noised_repr
 
 class AtomToTokenPooler(Module):
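The change above follows a pack/pad/unpack pattern: learned register tokens are prepended along the token dimension, the conditioning tensors and mask are front-padded to match, and the registers are discarded after the transformer layers. Below is a minimal standalone sketch of the shape bookkeeping, not the repository's code: the dimensions are toy values, and a plain zeros tensor stands in for the learned nn.Parameter.

import torch
import torch.nn.functional as F
from einops import repeat, pack, unpack

b, n, d, d_pair, num_registers = 2, 5, 8, 4, 3

noised_repr   = torch.randn(b, n, d)
single_repr   = torch.randn(b, n, d)
pairwise_repr = torch.randn(b, n, n, d_pair)
mask          = torch.ones(b, n, dtype = torch.bool)

registers = torch.zeros(num_registers, d)             # stands in for the nn.Parameter
registers = repeat(registers, 'r d -> b r d', b = b)  # broadcast across the batch

# pack concatenates along the '*' axis and records the split sizes
noised_repr, registers_ps = pack((registers, noised_repr), 'b * d')
assert noised_repr.shape == (b, num_registers + n, d)

# F.pad pads from the last dim backwards: (0, 0) leaves the feature dim
# untouched, each (num_registers, 0) prepends along a token dim
single_repr   = F.pad(single_repr, (0, 0, num_registers, 0), value = 0.)
pairwise_repr = F.pad(pairwise_repr, (0, 0, num_registers, 0, num_registers, 0), value = 0.)
mask          = F.pad(mask, (num_registers, 0), value = True)

assert single_repr.shape   == (b, num_registers + n, d)
assert pairwise_repr.shape == (b, num_registers + n, num_registers + n, d_pair)

# ... transformer layers would run here ...

# unpack splits at the recorded boundary; drop the register tokens
_, noised_repr = unpack(noised_repr, registers_ps, 'b * d')
assert noised_repr.shape == (b, n, d)

Because pack records the split sizes in registers_ps, unpack recovers exactly the original n tokens whatever num_registers is; the padded positions simply give the registers zeroed conditioning and a valid mask entry.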
@@ -1487,7 +1514,10 @@ def __init__(
         token_transformer_heads = 16,
         atom_decoder_depth = 3,
         atom_decoder_heads = 4,
-        serial = False
+        serial = False,
+        atom_encoder_kwargs: dict = dict(),
+        atom_decoder_kwargs: dict = dict(),
+        token_transformer_kwargs: dict = dict()
     ):
         super().__init__()
 
@@ -1543,7 +1573,8 @@ def __init__(
             attn_window_size = atoms_per_window,
             depth = atom_encoder_depth,
             heads = atom_encoder_heads,
-            serial = serial
+            serial = serial,
+            **atom_encoder_kwargs
         )
 
         self.atom_feats_to_pooled_token = AtomToTokenPooler(
@@ -1565,7 +1596,8 @@ def __init__(
             dim_pairwise = dim_pairwise,
             depth = token_transformer_depth,
             heads = token_transformer_heads,
-            serial = serial
+            serial = serial,
+            **token_transformer_kwargs
         )
 
         self.attended_token_norm = nn.LayerNorm(dim_token)
@@ -1581,7 +1613,8 @@ def __init__(
             attn_window_size = atoms_per_window,
             depth = atom_decoder_depth,
             heads = atom_decoder_heads,
-            serial = serial
+            serial = serial,
+            **atom_decoder_kwargs
         )
 
         self.atom_feat_to_atom_pos_update = nn.Sequential(
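The three new *_kwargs dicts all use the same plumbing: the parent module accepts one plain dict per sub-transformer and splats it into that sub-module's constructor, so new options like num_register_tokens reach the right place without adding top-level parameters. A condensed sketch of the pattern, with hypothetical class names standing in for the repository's modules:

import torch
import torch.nn as nn

class SubTransformer(nn.Module):
    # condensed stand-in for the DiffusionTransformer configured above
    def __init__(self, dim, depth = 1, num_register_tokens = 0):
        super().__init__()
        self.registers = nn.Parameter(torch.zeros(num_register_tokens, dim))

class Parent(nn.Module):
    # condensed stand-in for the parent module
    def __init__(self, dim, token_transformer_kwargs: dict = dict()):
        super().__init__()
        # the ** splat forwards arbitrary sub-module options unchanged
        self.token_transformer = SubTransformer(dim = dim, **token_transformer_kwargs)

parent = Parent(dim = 64, token_transformer_kwargs = dict(num_register_tokens = 2))
assert parent.token_transformer.registers.shape == (2, 64)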
pyproject.toml (1 addition & 1 deletion)
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.0.40"
+version = "0.0.41"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
tests/test_af3.py (4 additions & 1 deletion)
@@ -233,7 +233,10 @@ def test_diffusion_module():
         dim_pairwise_rel_pos_feats = 12,
         atom_encoder_depth = 1,
         atom_decoder_depth = 1,
-        token_transformer_depth = 1
+        token_transformer_depth = 1,
+        token_transformer_kwargs = dict(
+            num_register_tokens = 2
+        )
     )
 
     atom_pos_update = diffusion_module(
