Skip to content

Commit

Permalink
add some more logic before getting to atom radii tomorrow
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 27, 2024
1 parent f81055a commit 26c2f9f
Showing 1 changed file with 24 additions and 7 deletions.
31 changes: 24 additions & 7 deletions alphafold3_pytorch/alphafold3.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,17 @@ def mean_pool_with_lens(
lens: Int['b n']
) -> Float['b n d']:

summed, mask = sum_pool_with_lens(feats, lens)
avg = einx.divide('b n d, b n', summed, lens.clamp(min = 1))
avg = einx.where('b n, b n d, -> b n d', mask, avg, 0.)
return avg

@typecheck
def sum_pool_with_lens(
feats: Float['b m d'],
lens: Int['b n']
) -> tuple[Float['b n d'], Bool['b n']]:

seq_len = feats.shape[1]

mask = lens > 0
Expand All @@ -364,9 +375,7 @@ def mean_pool_with_lens(
# subtract cumsum at one index from the previous one
summed = sel_cumsum[:, 1:] - sel_cumsum[:, :-1]

avg = einx.divide('b n d, b n', summed, lens.clamp(min = 1))
avg = einx.where('b n, b n d, -> b n d', mask, avg, 0.)
return avg
return summed, mask

@typecheck
def mean_pool_fixed_windows_with_mask(
Expand Down Expand Up @@ -5785,18 +5794,26 @@ def _inhouse_compute_unresolved_rasa(

per_atom_accessible_surface_score = reduce(score * radius ** 2, 'm sd -> m')

# sum up all surface scores for atoms per residue
# the final score seems to be the average of the rsa across all residues (selected by `chain_unresolved_residue_mask`)

rasa, mask = sum_pool_with_lens(
rearrange(per_atom_accessible_surface_score, '... -> 1 ...'),
rearrange(chain_molecule_atom_lens, '... -> 1 ...')
)

rasa = einx.where('b n, b n, -> b n', mask, rasa, 0.)

rasa = rearrange(rasa, '1 ... -> ...')

# rest written by @xluo

rasa = []
aatypes = []
for residue in structure.get_residues():
rsa = float(dssp_dict.get((residue.get_full_id()[2], residue.id))[3])
rasa.append(rsa)

aatype = dssp_dict.get((residue.get_full_id()[2], residue.id))[1]
aatypes.append(residue_constants.restype_order[aatype])

rasa = torch.tensor(rasa, dtype=dtype, device=device)
aatypes = torch.tensor(aatypes, device=device).int()

unresolved_aatypes = aatypes[chain_unresolved_residue_mask]
Expand Down

0 comments on commit 26c2f9f

Please sign in to comment.