From 1936f1e3655dd55d30f0931b7f47ab9b7c200036 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sun, 19 May 2024 10:53:00 -0700 Subject: [PATCH] complete the main alphafold2 flow, sans diffusion module and losses + sampling --- README.md | 50 +++++++ alphafold3_pytorch/alphafold3.py | 247 +++++++++++++++++++++++++++++-- tests/test_readme.py | 34 +++++ 3 files changed, 322 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index ae8cf2b6..b8384a33 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,56 @@ Implementation of A Getting a fair number of emails. You can chat with me about this work here +## Install + +```bash +$ pip install alphafold3-pytorch +``` + +## Usage + +```python +import torch +from alphafold3_pytorch import Alphafold3 + +alphafold3 = Alphafold3( + dim_atom_inputs = 77, + dim_additional_residue_feats = 33, + dim_template_feats = 44 +) + +# mock inputs + +seq_len = 16 +atom_seq_len = seq_len * 27 + +atom_inputs = torch.randn(2, atom_seq_len, 77) +atom_mask = torch.ones((2, atom_seq_len)).bool() +atompair_feats = torch.randn(2, atom_seq_len, atom_seq_len, 16) +additional_residue_feats = torch.randn(2, seq_len, 33) + +template_feats = torch.randn(2, 2, seq_len, seq_len, 44) +template_mask = torch.ones((2, 2)).bool() + +msa = torch.randn(2, 7, seq_len, 64) + +# train + +loss = alphafold3( + num_recycling_steps = 2, + atom_inputs = atom_inputs, + atom_mask = atom_mask, + atompair_feats = atompair_feats, + additional_residue_feats = additional_residue_feats, + msa = msa, + templates = template_feats, + template_mask = template_mask +) + +loss.backward() + +``` + ## Citations ```bibtex diff --git a/alphafold3_pytorch/alphafold3.py b/alphafold3_pytorch/alphafold3.py index 9c6f5b42..8f808a67 100644 --- a/alphafold3_pytorch/alphafold3.py +++ b/alphafold3_pytorch/alphafold3.py @@ -7,6 +7,7 @@ i - residue sequence length (source) j - residue sequence length (target) m - atom sequence length +c - coordinates (3 for spatial) d - feature dimension ds - feature dimension (single) dp - feature dimension (pairwise) @@ -861,7 +862,7 @@ def __init__( # final projection of mean pooled repr -> out self.to_out = nn.Sequential( - LinearNoBias(dim, dim), + LinearNoBias(dim, dim_pairwise), nn.ReLU() ) @@ -873,7 +874,7 @@ def forward( template_mask: Bool['b t'], pairwise_repr: Float['b n n dp'], mask: Bool['b n'] | None = None, - ) -> Float['b n n d']: + ) -> Float['b n n dp']: num_templates = templates.shape[1] @@ -884,7 +885,8 @@ def forward( v, merged_batch_ps = pack_one(v, '* i j d') - mask = repeat(mask, 'b n -> (b t) n', t = num_templates) + if exists(mask): + mask = repeat(mask, 'b n -> (b t) n', t = num_templates) for block in self.pairformer_stack: v = block( @@ -1815,7 +1817,7 @@ def forward( pairwise_repr: Float['b n n dp'], pred_atom_pos: Float['b n c'], mask: Bool['b n'] | None = None, - calc_pae_logits_and_loss = True + return_pae_logits = True ) -> ConfidenceHeadLogits[ Float['b pae n n'] | None, @@ -1854,7 +1856,7 @@ def forward( pae_logits = None - if calc_pae_logits_and_loss: + if return_pae_logits: pae_logits = self.to_pae_logits(pairwise_repr) # return all logits @@ -1863,21 +1865,248 @@ def forward( # main class +LossBreakdown = namedtuple('LossBreakdown', [ + 'distogram', + 'pae', + 'pdt', + 'plddt', + 'resolved' +]) + class Alphafold3(Module): + """ Algorithm 1 """ + + @typecheck def __init__( self, *, + dim_atom_inputs, + dim_additional_residue_feats, + dim_template_feats, + dim_template_model = 64, + atoms_per_window = 27, + dim_atom = 128, + dim_atompair = 16, + dim_input_embedder_token = 384, + dim_single = 384, + dim_pairwise = 128, + atompair_dist_bins: Float[' dist_bins'] = torch.linspace(3, 20, 37), + ignore_index = -1, + num_dist_bins = 38, + num_plddt_bins = 50, + num_pde_bins = 64, + num_pae_bins = 64, loss_confidence_weight = 1e-4, loss_distogram_weight = 1e-2, - loss_diffusion = 4. + loss_diffusion_weight = 4., + input_embedder_kwargs: dict = dict( + atom_transformer_blocks = 3, + atom_transformer_heads = 4, + atom_transformer_kwargs = dict() + ), + confidence_head_kwargs: dict = dict( + pairformer_depth = 4 + ), + template_embedder_kwargs: dict = dict( + pairformer_stack_depth = 2, + pairwise_block_kwargs = dict(), + ), + msa_module_kwargs: dict = dict( + depth = 4, + dim_msa = 64, + dim_msa_input = None, + outer_product_mean_dim_hidden = 32, + msa_pwa_dropout_row_prob = 0.15, + msa_pwa_heads = 8, + msa_pwa_dim_head = 32, + pairwise_block_kwargs = dict() + ), + pairformer_stack: dict = dict( + depth = 48, + pair_bias_attn_dim_head = 64, + pair_bias_attn_heads = 16, + dropout_row_prob = 0.25, + pairwise_block_kwargs = dict() + ) ): super().__init__() + self.atoms_per_window = atoms_per_window + + # input feature embedder + + self.input_embedder = InputFeatureEmbedder( + dim_atom_inputs = dim_atom_inputs, + dim_additional_residue_feats = dim_additional_residue_feats, + atoms_per_window = atoms_per_window, + dim_atom = dim_atom, + dim_atompair = dim_atompair, + dim_token = dim_input_embedder_token, + dim_single = dim_single, + dim_pairwise = dim_pairwise, + **input_embedder_kwargs + ) + + dim_single_inputs = dim_input_embedder_token + dim_additional_residue_feats + + # templates + + self.template_embedder = TemplateEmbedder( + dim_template_feats = dim_template_feats, + dim = dim_template_model, + dim_pairwise = dim_pairwise, + **template_embedder_kwargs + ) + + # msa + + self.msa_module = MSAModule( + dim_single = dim_single, + dim_pairwise = dim_pairwise, + **msa_module_kwargs + ) + + # main pairformer trunk, 48 layers + + self.pairformer = PairformerStack( + dim_single = dim_single, + dim_pairwise = dim_pairwise, + **pairformer_stack + ) + + # recycling related + + self.recycle_single = nn.Sequential( + nn.LayerNorm(dim_single), + LinearNoBias(dim_single, dim_single) + ) + + self.recycle_pairwise = nn.Sequential( + nn.LayerNorm(dim_pairwise), + LinearNoBias(dim_pairwise, dim_pairwise) + ) + + # logit heads + + self.distogram_head = DistogramHead( + dim_pairwise = dim_pairwise, + num_dist_bins = num_dist_bins + ) + + self.confidence_head = ConfidenceHead( + dim_single_inputs = dim_single_inputs, + atompair_dist_bins = atompair_dist_bins, + dim_single = dim_single, + dim_pairwise = dim_pairwise, + num_plddt_bins = num_plddt_bins, + num_pde_bins = num_pde_bins, + num_pae_bins = num_pae_bins, + **confidence_head_kwargs + ) + + # loss related + + self.ignore_index = ignore_index + self.loss_distogram_weight = loss_distogram_weight + self.loss_confidence_weight = loss_confidence_weight + self.loss_diffusion_weight = loss_diffusion_weight @typecheck def forward( self, *, - include_pae_loss = False # turned on in latter part of training - ): - return + atom_inputs: Float['b m dai'], + atom_mask: Bool['b m'], + atompair_feats: Float['b m m dap'], + additional_residue_feats: Float['b n rf'], + msa: Float['b s n d'], + templates: Float['b t n n dt'], + template_mask: Bool['b t'], + num_recycling_steps: int = 1, + distance_labels: Int['b n n'] | None = None, + pae_labels: Int['b n n'] | None = None, + pde_labels: Int['b n n'] | None = None, + plddt_labels: Int['b n'] | None = None, + resolved_labels: Int['b n'] | None = None, + ) -> Float['b m c'] | Float['']: + + # embed inputs + + ( + single_inputs, + single_init, + pairwise_init, + atom_feats, + atompair_feats + ) = self.input_embedder( + atom_inputs = atom_inputs, + atom_mask = atom_mask, + atompair_feats = atompair_feats, + additional_residue_feats = additional_residue_feats + ) + + w = self.atoms_per_window + + mask = reduce(atom_mask, 'b (n w) -> b n', w = w, reduction = 'any') + + # init recycled single and pairwise + + recycled_pairwise = recycled_single = None + single = pairwise = None + + # for each recycling step + + for _ in range(num_recycling_steps): + + # handle recycled single and pairwise if not first step + + recycled_single = recycled_pairwise = 0. + + if exists(single): + recycled_single = self.recycle_single(single) + + if exists(pairwise): + recycled_pairwise = self.recycle_pairwise(pairwise) + + single = single_init + recycled_single + pairwise = pairwise_init + recycled_pairwise + + # else go through main transformer trunk from alphafold2 + + # templates + + embedded_template = self.template_embedder( + templates = templates, + template_mask = template_mask, + pairwise_repr = pairwise, + mask = mask + ) + + pairwise = embedded_template + pairwise + + # msa + + embedded_msa = self.msa_module( + msa = msa, + single_repr = single, + pairwise_repr = pairwise, + mask = mask + ) + + pairwise = embedded_msa + pairwise + + # main attention trunk (pairformer) + + single, pairwise = self.pairformer( + single_repr = single, + pairwise_repr = pairwise, + mask = mask + ) + + # determine whether to return loss if any labels were to be passed in + # otherwise will sample the atomic coordinates + + labels = (distance_labels, pae_labels, pde_labels, plddt_labels, resolved_labels) + return_loss = any([*filter(exists, labels)]) + + return torch.tensor(0.) diff --git a/tests/test_readme.py b/tests/test_readme.py index 0ccec779..2fb61509 100644 --- a/tests/test_readme.py +++ b/tests/test_readme.py @@ -230,3 +230,37 @@ def test_distogram_head(): distogram_head = DistogramHead(dim_pairwise = 128) logits = distogram_head(pairwise_repr) + + +def test_alphafold3(): + seq_len = 16 + atom_seq_len = seq_len * 27 + + atom_inputs = torch.randn(2, atom_seq_len, 77) + atom_mask = torch.ones((2, atom_seq_len)).bool() + atompair_feats = torch.randn(2, atom_seq_len, atom_seq_len, 16) + additional_residue_feats = torch.randn(2, seq_len, 33) + + template_feats = torch.randn(2, 2, seq_len, seq_len, 44) + template_mask = torch.ones((2, 2)).bool() + + msa = torch.randn(2, 7, seq_len, 64) + + alphafold3 = Alphafold3( + dim_atom_inputs = 77, + dim_additional_residue_feats = 33, + dim_template_feats = 44 + ) + + loss = alphafold3( + num_recycling_steps = 2, + atom_inputs = atom_inputs, + atom_mask = atom_mask, + atompair_feats = atompair_feats, + additional_residue_feats = additional_residue_feats, + msa = msa, + templates = template_feats, + template_mask = template_mask + ) + + print(loss)