Skip to content

Commit

Permalink
get the rest of the rsa logic in, for how they determine which pair o…
Browse files Browse the repository at this point in the history
…f atoms to include in the "free" calculation against the fibonacci sphere surface dots
  • Loading branch information
lucidrains committed Sep 28, 2024
1 parent 1ed4cd8 commit d873666
Showing 1 changed file with 23 additions and 8 deletions.
31 changes: 23 additions & 8 deletions alphafold3_pytorch/alphafold3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5717,7 +5717,8 @@ def _inhouse_compute_unresolved_rasa(
molecule_atom_lens: Int[" n"],
atom_pos: Float["m 3"],
atom_mask: Bool[" m"],
fibonacci_sphere_n = 200 # they use 200 in mkdssp
fibonacci_sphere_n = 200, # they use 200 in mkdssp, but can be tailored for efficiency
atom_distance_min_thres = 1e-4
) -> Float[""]:
"""Compute the unresolved relative solvent accessible surface area (RASA) for proteins.
using inhouse rebuilt RSA calculation
Expand Down Expand Up @@ -5779,11 +5780,12 @@ def _inhouse_compute_unresolved_rasa(

atom_radii: Float[' m'] = self.atom_radii[structure_atom_type_for_radii]

atom_radii_sq = atom_radii.pow(2) # they use square of distance / radius to save on sqrt
water_radii = self.atom_radii[-1]

# they use the water molecule radii for some stuff
# atom radii is always summed with water radii

water_radii = self.atom_radii[-1]
atom_radii += water_radii
atom_radii_sq = atom_radii.pow(2) # always use square of radii or distance for comparison - save on sqrt

# write custom RSA function here

Expand All @@ -5809,19 +5811,32 @@ def _inhouse_compute_unresolved_rasa(
lat.sin()
), dim = -1)

# overall logic
# first get atom relative positions + distance
# for determining whether to include pairs of atom in calculation for the `free` adjective

atom_rel_pos = einx.subtract('j c, i c -> i j c', structure_atom_pos, structure_atom_pos)
atom_rel_dist_sq = atom_rel_pos.pow(2).sum(dim = -1)

surface_dots = einx.multiply('m, sd c -> m sd c', atom_radii + water_radii, unit_surface_dots)
max_distance_include = einx.add('i, j -> i j', atom_radii, atom_radii).pow(2)

include_in_free_calc = (
(atom_rel_dist_sq < max_distance_include) &
(atom_rel_dist_sq > atom_distance_min_thres)
)

# overall logic

surface_dots = einx.multiply('m, sd c -> m sd c', atom_radii, unit_surface_dots)

dist_from_surface_dots_sq = einx.subtract('i j c, i sd c -> i sd j c', atom_rel_pos, surface_dots).pow(2).sum(dim = -1)

target_atom_close_to_surface_dots = einx.less('j, i sd j -> i sd j', atom_radii_sq, dist_from_surface_dots_sq)

free = reduce(target_atom_close_to_surface_dots, 'i sd j -> i sd', 'all')
target_atom_close_or_not_included = einx.logical_or('i sd j, i j -> i sd j', target_atom_close_to_surface_dots, ~include_in_free_calc)

is_free = reduce(target_atom_close_or_not_included, 'i sd j -> i sd', 'all') # basically the most important line, calculating whether an atom is free by some distance measure

score = reduce(free.float() * weight, 'm sd -> m', 'sum')
score = reduce(is_free.float() * weight, 'm sd -> m', 'sum')

per_atom_accessible_surface_score = score * atom_radii_sq

Expand Down

0 comments on commit d873666

Please sign in to comment.