From bdfd89aa1781b441563036d28f67288d57c5efe3 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 28 Sep 2024 16:56:11 -0700 Subject: [PATCH] take a small optimization step for the torch rsa calculation for model selection --- alphafold3_pytorch/alphafold3.py | 66 +++++++++++++++++++++----------- pyproject.toml | 2 +- 2 files changed, 45 insertions(+), 23 deletions(-) diff --git a/alphafold3_pytorch/alphafold3.py b/alphafold3_pytorch/alphafold3.py index c0879900..07eb990e 100644 --- a/alphafold3_pytorch/alphafold3.py +++ b/alphafold3_pytorch/alphafold3.py @@ -5707,6 +5707,36 @@ def _compute_unresolved_rasa( return unresolved_rasa.mean() + @typecheck + def calc_atom_access_surface_score_from_structure( + self, + structure: Structure, + **kwargs + ) -> Float['m']: + + # use the structure as source of truth, matching what xluo did + + structure_atom_pos = [] + structure_atom_type_for_radii = [] + side_atom_index = len(self.atom_type_index) + + for atom in structure.get_atoms(): + + one_atom_pos = list(atom.get_vector()) + one_atom_type = self.atom_type_index.get(atom.name, side_atom_index) + + structure_atom_pos.append(one_atom_pos) + structure_atom_type_for_radii.append(one_atom_type) + + structure_atom_pos: Float[' m 3'] = tensor(structure_atom_pos) + structure_atom_type_for_radii: Int[' m'] = tensor(structure_atom_type_for_radii) + + return self.calc_atom_access_surface_score( + atom_pos = structure_atom_pos, + atom_type = structure_atom_type_for_radii, + **kwargs + ) + @typecheck def calc_atom_access_surface_score( self, @@ -5749,7 +5779,7 @@ def calc_atom_access_surface_score( lat.sin() ), dim = -1) - # first get atom relative positions + distance + # get atom relative positions + distance # for determining whether to include pairs of atom in calculation for the `free` adjective atom_rel_pos = einx.subtract('j c, i c -> i j c', atom_pos, atom_pos) @@ -5762,13 +5792,23 @@ def calc_atom_access_surface_score( (atom_rel_dist_sq > atom_distance_min_thres) ) + # max included in calculation per row + + max_included = include_in_free_calc.long().sum(dim = -1).amax() + + include_indices = include_in_free_calc.long().topk(max_included, dim = -1).indices + + include_in_free_calc = einx.get_at('i [m], i j -> i j', include_in_free_calc, include_indices) + atom_rel_pos = einx.get_at('i [m] c, i j -> i j c', atom_rel_pos, include_indices) + target_atom_radii_sq = einx.get_at('[m], i j -> i j', atom_radii_sq, include_indices) + # overall logic surface_dots = einx.multiply('m, sd c -> m sd c', atom_radii, unit_surface_dots) dist_from_surface_dots_sq = einx.subtract('i j c, i sd c -> i sd j c', atom_rel_pos, surface_dots).pow(2).sum(dim = -1) - target_atom_close_to_surface_dots = einx.less('j, i sd j -> i sd j', atom_radii_sq, dist_from_surface_dots_sq) + target_atom_close_to_surface_dots = einx.less('i j, i sd j -> i sd j', target_atom_radii_sq, dist_from_surface_dots_sq) target_atom_close_or_not_included = einx.logical_or('i sd j, i j -> i sd j', target_atom_close_to_surface_dots, ~include_in_free_calc) @@ -5831,28 +5871,10 @@ def _inhouse_compute_unresolved_rasa( chain_atom_mask, ) - # use the structure as source of truth, matching what xluo did - - structure_atom_pos = [] - structure_atom_type_for_radii = [] - side_atom_index = len(self.atom_type_index) - - for atom in structure.get_atoms(): - - one_atom_pos = list(atom.get_vector()) - one_atom_type = self.atom_type_index.get(atom.name, side_atom_index) - - structure_atom_pos.append(one_atom_pos) - structure_atom_type_for_radii.append(one_atom_type) - - structure_atom_pos: Float[' m 3'] = tensor(structure_atom_pos) - structure_atom_type_for_radii: Int[' m'] = tensor(structure_atom_type_for_radii) - # per atom rsa calculation - per_atom_access_surface_score = self.calc_atom_access_surface_score( - structure_atom_pos, - structure_atom_type_for_radii, + per_atom_access_surface_score = self.calc_atom_access_surface_score_from_structure( + structure, **rsa_calc_kwargs ) diff --git a/pyproject.toml b/pyproject.toml index 6b7da72f..036feb9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "alphafold3-pytorch" -version = "0.5.48" +version = "0.5.49" description = "Alphafold 3 - Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" },