Skip to content

Commit

Permalink
take a small optimization step for the torch rsa calculation for mode…
Browse files Browse the repository at this point in the history
…l selection
  • Loading branch information
lucidrains committed Sep 28, 2024
1 parent 097cbff commit bdfd89a
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 23 deletions.
66 changes: 44 additions & 22 deletions alphafold3_pytorch/alphafold3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5707,6 +5707,36 @@ def _compute_unresolved_rasa(

return unresolved_rasa.mean()

@typecheck
def calc_atom_access_surface_score_from_structure(
self,
structure: Structure,
**kwargs
) -> Float['m']:

# use the structure as source of truth, matching what xluo did

structure_atom_pos = []
structure_atom_type_for_radii = []
side_atom_index = len(self.atom_type_index)

for atom in structure.get_atoms():

one_atom_pos = list(atom.get_vector())
one_atom_type = self.atom_type_index.get(atom.name, side_atom_index)

structure_atom_pos.append(one_atom_pos)
structure_atom_type_for_radii.append(one_atom_type)

structure_atom_pos: Float[' m 3'] = tensor(structure_atom_pos)
structure_atom_type_for_radii: Int[' m'] = tensor(structure_atom_type_for_radii)

return self.calc_atom_access_surface_score(
atom_pos = structure_atom_pos,
atom_type = structure_atom_type_for_radii,
**kwargs
)

@typecheck
def calc_atom_access_surface_score(
self,
Expand Down Expand Up @@ -5749,7 +5779,7 @@ def calc_atom_access_surface_score(
lat.sin()
), dim = -1)

# first get atom relative positions + distance
# get atom relative positions + distance
# for determining whether to include pairs of atom in calculation for the `free` adjective

atom_rel_pos = einx.subtract('j c, i c -> i j c', atom_pos, atom_pos)
Expand All @@ -5762,13 +5792,23 @@ def calc_atom_access_surface_score(
(atom_rel_dist_sq > atom_distance_min_thres)
)

# max included in calculation per row

max_included = include_in_free_calc.long().sum(dim = -1).amax()

include_indices = include_in_free_calc.long().topk(max_included, dim = -1).indices

include_in_free_calc = einx.get_at('i [m], i j -> i j', include_in_free_calc, include_indices)
atom_rel_pos = einx.get_at('i [m] c, i j -> i j c', atom_rel_pos, include_indices)
target_atom_radii_sq = einx.get_at('[m], i j -> i j', atom_radii_sq, include_indices)

# overall logic

surface_dots = einx.multiply('m, sd c -> m sd c', atom_radii, unit_surface_dots)

dist_from_surface_dots_sq = einx.subtract('i j c, i sd c -> i sd j c', atom_rel_pos, surface_dots).pow(2).sum(dim = -1)

target_atom_close_to_surface_dots = einx.less('j, i sd j -> i sd j', atom_radii_sq, dist_from_surface_dots_sq)
target_atom_close_to_surface_dots = einx.less('i j, i sd j -> i sd j', target_atom_radii_sq, dist_from_surface_dots_sq)

target_atom_close_or_not_included = einx.logical_or('i sd j, i j -> i sd j', target_atom_close_to_surface_dots, ~include_in_free_calc)

Expand Down Expand Up @@ -5831,28 +5871,10 @@ def _inhouse_compute_unresolved_rasa(
chain_atom_mask,
)

# use the structure as source of truth, matching what xluo did

structure_atom_pos = []
structure_atom_type_for_radii = []
side_atom_index = len(self.atom_type_index)

for atom in structure.get_atoms():

one_atom_pos = list(atom.get_vector())
one_atom_type = self.atom_type_index.get(atom.name, side_atom_index)

structure_atom_pos.append(one_atom_pos)
structure_atom_type_for_radii.append(one_atom_type)

structure_atom_pos: Float[' m 3'] = tensor(structure_atom_pos)
structure_atom_type_for_radii: Int[' m'] = tensor(structure_atom_type_for_radii)

# per atom rsa calculation

per_atom_access_surface_score = self.calc_atom_access_surface_score(
structure_atom_pos,
structure_atom_type_for_radii,
per_atom_access_surface_score = self.calc_atom_access_surface_score_from_structure(
structure,
**rsa_calc_kwargs
)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "alphafold3-pytorch"
version = "0.5.48"
version = "0.5.49"
description = "Alphafold 3 - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" },
Expand Down

0 comments on commit bdfd89a

Please sign in to comment.