diff --git a/alphafold3_pytorch/alphafold3.py b/alphafold3_pytorch/alphafold3.py index 818d1d7f..e1f5d872 100644 --- a/alphafold3_pytorch/alphafold3.py +++ b/alphafold3_pytorch/alphafold3.py @@ -5717,7 +5717,8 @@ def _inhouse_compute_unresolved_rasa( molecule_atom_lens: Int[" n"], atom_pos: Float["m 3"], atom_mask: Bool[" m"], - fibonacci_sphere_n = 200 # they use 200 in mkdssp + fibonacci_sphere_n = 200, # they use 200 in mkdssp, but can be tailored for efficiency + atom_distance_min_thres = 1e-4 ) -> Float[""]: """Compute the unresolved relative solvent accessible surface area (RASA) for proteins. using inhouse rebuilt RSA calculation @@ -5779,11 +5780,12 @@ def _inhouse_compute_unresolved_rasa( atom_radii: Float[' m'] = self.atom_radii[structure_atom_type_for_radii] - atom_radii_sq = atom_radii.pow(2) # they use square of distance / radius to save on sqrt + water_radii = self.atom_radii[-1] - # they use the water molecule radii for some stuff + # atom radii is always summed with water radii - water_radii = self.atom_radii[-1] + atom_radii += water_radii + atom_radii_sq = atom_radii.pow(2) # always use square of radii or distance for comparison - save on sqrt # write custom RSA function here @@ -5809,19 +5811,32 @@ def _inhouse_compute_unresolved_rasa( lat.sin() ), dim = -1) - # overall logic + # first get atom relative positions + distance + # for determining whether to include pairs of atom in calculation for the `free` adjective atom_rel_pos = einx.subtract('j c, i c -> i j c', structure_atom_pos, structure_atom_pos) + atom_rel_dist_sq = atom_rel_pos.pow(2).sum(dim = -1) - surface_dots = einx.multiply('m, sd c -> m sd c', atom_radii + water_radii, unit_surface_dots) + max_distance_include = einx.add('i, j -> i j', atom_radii, atom_radii).pow(2) + + include_in_free_calc = ( + (atom_rel_dist_sq < max_distance_include) & + (atom_rel_dist_sq > atom_distance_min_thres) + ) + + # overall logic + + surface_dots = einx.multiply('m, sd c -> m sd c', atom_radii, unit_surface_dots) dist_from_surface_dots_sq = einx.subtract('i j c, i sd c -> i sd j c', atom_rel_pos, surface_dots).pow(2).sum(dim = -1) target_atom_close_to_surface_dots = einx.less('j, i sd j -> i sd j', atom_radii_sq, dist_from_surface_dots_sq) - free = reduce(target_atom_close_to_surface_dots, 'i sd j -> i sd', 'all') + target_atom_close_or_not_included = einx.logical_or('i sd j, i j -> i sd j', target_atom_close_to_surface_dots, ~include_in_free_calc) + + is_free = reduce(target_atom_close_or_not_included, 'i sd j -> i sd', 'all') # basically the most important line, calculating whether an atom is free by some distance measure - score = reduce(free.float() * weight, 'm sd -> m', 'sum') + score = reduce(is_free.float() * weight, 'm sd -> m', 'sum') per_atom_accessible_surface_score = score * atom_radii_sq