diff --git a/alphafold3_pytorch/alphafold3.py b/alphafold3_pytorch/alphafold3.py index b6619acc..12b58cfa 100644 --- a/alphafold3_pytorch/alphafold3.py +++ b/alphafold3_pytorch/alphafold3.py @@ -7407,6 +7407,8 @@ def forward( denoised_molecule_pos = denoised_atom_pos.gather(1, distogram_atom_coords_indices) + # get frames atom positions + # three_atoms = einx.get_at('b [m] c, b n three -> three b n c', atom_pos, atom_indices_for_frame) # pred_three_atoms = einx.get_at('b [m] c, b n three -> three b n c', denoised_atom_pos, atom_indices_for_frame) @@ -7421,10 +7423,8 @@ def forward( three_atoms = three_atom_pos.gather(2, atom_indices_for_frame) pred_three_atoms = three_denoised_atom_pos.gather(2, atom_indices_for_frame) - # compute frames - - frames, _ = self.rigid_from_three_points(three_atoms) - pred_frames, _ = self.rigid_from_three_points(pred_three_atoms) + frame_atoms = rearrange(three_atoms, "three b n c -> b n c three") + pred_frame_atoms = rearrange(pred_three_atoms, "three b n c -> b n c three") # determine mask # must be amino acid, nucleotide, or ligand with greater than 0 atoms @@ -7436,8 +7436,8 @@ def forward( align_error = self.compute_alignment_error( denoised_molecule_pos, molecule_pos, - pred_frames, - frames, + pred_frame_atoms, # In the paragraph 2 of section 4.3.2, the Phi_i denotes the coordinates of these frame atoms rather than the rotation matrix. + frame_atoms, mask=align_error_mask, )