From 26c2f9fb9a91d9cdc0cf791104bd1c0122ea1bb6 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Fri, 27 Sep 2024 10:45:35 -0700 Subject: [PATCH] add some more logic before getting to atom radii tomorrow --- alphafold3_pytorch/alphafold3.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/alphafold3_pytorch/alphafold3.py b/alphafold3_pytorch/alphafold3.py index 9fe43e7a..18f3c5db 100644 --- a/alphafold3_pytorch/alphafold3.py +++ b/alphafold3_pytorch/alphafold3.py @@ -345,6 +345,17 @@ def mean_pool_with_lens( lens: Int['b n'] ) -> Float['b n d']: + summed, mask = sum_pool_with_lens(feats, lens) + avg = einx.divide('b n d, b n', summed, lens.clamp(min = 1)) + avg = einx.where('b n, b n d, -> b n d', mask, avg, 0.) + return avg + +@typecheck +def sum_pool_with_lens( + feats: Float['b m d'], + lens: Int['b n'] +) -> tuple[Float['b n d'], Bool['b n']]: + seq_len = feats.shape[1] mask = lens > 0 @@ -364,9 +375,7 @@ def mean_pool_with_lens( # subtract cumsum at one index from the previous one summed = sel_cumsum[:, 1:] - sel_cumsum[:, :-1] - avg = einx.divide('b n d, b n', summed, lens.clamp(min = 1)) - avg = einx.where('b n, b n d, -> b n d', mask, avg, 0.) - return avg + return summed, mask @typecheck def mean_pool_fixed_windows_with_mask( @@ -5785,18 +5794,26 @@ def _inhouse_compute_unresolved_rasa( per_atom_accessible_surface_score = reduce(score * radius ** 2, 'm sd -> m') + # sum up all surface scores for atoms per residue + # the final score seems to be the average of the rsa across all residues (selected by `chain_unresolved_residue_mask`) + + rasa, mask = sum_pool_with_lens( + rearrange(per_atom_accessible_surface_score, '... -> 1 ...'), + rearrange(chain_molecule_atom_lens, '... -> 1 ...') + ) + + rasa = einx.where('b n, b n, -> b n', mask, rasa, 0.) + + rasa = rearrange(rasa, '1 ... -> ...') + # rest written by @xluo - rasa = [] aatypes = [] for residue in structure.get_residues(): - rsa = float(dssp_dict.get((residue.get_full_id()[2], residue.id))[3]) - rasa.append(rsa) aatype = dssp_dict.get((residue.get_full_id()[2], residue.id))[1] aatypes.append(residue_constants.restype_order[aatype]) - rasa = torch.tensor(rasa, dtype=dtype, device=device) aatypes = torch.tensor(aatypes, device=device).int() unresolved_aatypes = aatypes[chain_unresolved_residue_mask]