-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
363 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from ._adjacent_residue_phi_cosine import ADJACENT_RESIDUE_PHI_COSINE | ||
from ._adjacent_residue_psi_cosine import ADJACENT_RESIDUE_PSI_COSINE | ||
from ._amino_acid_1 import AMINO_ACID_1 | ||
from ._amino_acid_1_to_amino_acid_3 import AMINO_ACID_1_TO_AMINO_ACID_3 | ||
from ._amino_acid_3 import AMINO_ACID_3 | ||
from ._amino_acid_3_to_amino_acid_1 import AMINO_ACID_3_TO_AMINO_ACID_1 | ||
from ._amino_acid_3_to_atom_14 import AMINO_ACID_3_TO_ATOM_14 | ||
|
||
__all__ = [ | ||
"ADJACENT_RESIDUE_PHI_COSINE", | ||
"ADJACENT_RESIDUE_PSI_COSINE", | ||
"AMINO_ACID_1", | ||
"AMINO_ACID_1_TO_AMINO_ACID_3", | ||
"AMINO_ACID_3", | ||
"AMINO_ACID_3_TO_AMINO_ACID_1", | ||
"AMINO_ACID_3_TO_ATOM_14", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Two-element constant read by ``bond_length_violation_loss``: index 0 is the
# target cosine of the C-N-CA angle between adjacent residues, index 1 is the
# spread that gets multiplied by the tolerance factor.
ADJACENT_RESIDUE_PHI_COSINE = [
    -0.5203,
    0.0353,
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Two-element constant read by ``bond_length_violation_loss``: index 0 is the
# target cosine of the CA-C-N angle between adjacent residues. Index 1 is
# presumably the matching spread (standard deviation) -- TODO confirm.
ADJACENT_RESIDUE_PSI_COSINE = [
    -0.4473,
    0.0311,
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# One-letter amino acid codes, one entry per character, listed in the same
# order as the three-letter codes in ``AMINO_ACID_3``.
AMINO_ACID_1 = list("ARNDCQEGHILKMFPSTWYV")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Maps each one-letter amino acid code to its three-letter code.
AMINO_ACID_1_TO_AMINO_ACID_3 = dict(
    A="ALA",
    R="ARG",
    N="ASN",
    D="ASP",
    C="CYS",
    Q="GLN",
    E="GLU",
    G="GLY",
    H="HIS",
    I="ILE",
    L="LEU",
    K="LYS",
    M="MET",
    F="PHE",
    P="PRO",
    S="SER",
    T="THR",
    W="TRP",
    Y="TYR",
    V="VAL",
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Three-letter amino acid codes. ``bond_length_violation_loss`` compares
# residue-type integers against positions in this list (e.g. the index of
# "PRO"), so the order is significant.
AMINO_ACID_3 = (
    "ALA ARG ASN ASP CYS GLN GLU GLY HIS ILE "
    "LEU LYS MET PHE PRO SER THR TRP TYR VAL"
).split()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Maps each three-letter amino acid code to its one-letter code.
AMINO_ACID_3_TO_AMINO_ACID_1 = dict(
    ALA="A",
    ARG="R",
    ASN="N",
    ASP="D",
    CYS="C",
    GLN="Q",
    GLU="E",
    GLY="G",
    HIS="H",
    ILE="I",
    LEU="L",
    LYS="K",
    MET="M",
    PHE="F",
    PRO="P",
    SER="S",
    THR="T",
    TRP="W",
    TYR="Y",
    VAL="V",
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Atom names per residue type, one list of exactly 14 slots per three-letter
# code. Residues with fewer than 14 listed atoms are right-padded with empty
# strings; "UNK" has no named atoms at all.
AMINO_ACID_3_TO_ATOM_14 = {
    residue: [*atoms, *[""] * (14 - len(atoms))]
    for residue, atoms in {
        "ALA": ["N", "CA", "C", "O", "CB"],
        "ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"],
        "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2"],
        "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2"],
        "CYS": ["N", "CA", "C", "O", "CB", "SG"],
        "GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2"],
        "GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2"],
        "GLY": ["N", "CA", "C", "O"],
        "HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2"],
        "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1"],
        "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2"],
        "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ"],
        "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE"],
        "PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"],
        "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD"],
        "SER": ["N", "CA", "C", "O", "CB", "OG"],
        "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2"],
        "TRP": [
            "N", "CA", "C", "O", "CB", "CG", "CD1", "CD2",
            "NE1", "CE2", "CE3", "CZ2", "CZ3", "CH2",
        ],
        "TYR": [
            "N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH",
        ],
        "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2"],
        "UNK": [],
    }.items()
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
from typing import Sequence | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
from beignet.constants import ( | ||
ADJACENT_RESIDUE_PHI_COSINE, | ||
ADJACENT_RESIDUE_PSI_COSINE, | ||
AMINO_ACID_3, | ||
) | ||
|
||
|
||
def bond_length_violation_loss(
    pred_atom_positions: Tensor,  # (*, N, 37/14, 3)
    pred_atom_mask: Tensor,  # (*, N, 37/14)
    residue_index: Tensor,  # (*, N)
    amino_acid: Tensor,  # (*, N)
    tolerance_factor_soft: float = 12.0,
    tolerance_factor_hard: float = 12.0,
) -> dict[str, Tensor]:
    r"""
    Flat-bottom loss to penalize structural violations between residues.

    This is a loss penalizing any violation of the geometry around the peptide
    bond between consecutive amino acids: the C-N bond length and the cosines
    of the CA-C-N and C-N-CA bond angles. This loss corresponds to
    Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 44, 45.

    Parameters
    ----------
    pred_atom_positions : Tensor, shape=(*, N, 37/14, 3)
        Atom positions in atom37/14 representation. Atom slots 0, 1 and 2 are
        read as N, CA and C respectively.

    pred_atom_mask : Tensor, shape=(*, N, 37/14)
        Atom mask in atom37/14 representation.

    residue_index : Tensor, shape=(*, N)
        Residue index for given amino acid, this is assumed to be monotonically
        increasing. Consecutive residues whose indices do not differ by
        exactly one are treated as a chain gap and excluded.

    amino_acid : Tensor, shape=(*, N)
        Amino acid type of given residue, encoded as the position of the
        residue's three-letter code in ``AMINO_ACID_3``.

    tolerance_factor_soft : float, optional
        Soft tolerance factor measured in standard deviations of pdb
        distributions. Errors below ``tolerance * stddev`` contribute no
        loss. Default, 12.0.

    tolerance_factor_hard : float, optional
        Hard tolerance factor measured in standard deviations of pdb
        distributions. Errors above ``tolerance * stddev`` are flagged in
        the violation mask. Default, 12.0.

    Returns
    -------
    dict[str, Tensor]
        * ``"c_n_loss_mean"``: loss for peptide bond length violations.
        * ``"ca_c_n_loss_mean"``: loss for violations of bond angle around C
          spanned by CA, C, N.
        * ``"c_n_ca_loss_mean"``: loss for violations of bond angle around N
          spanned by C, N, CA.
        * ``"per_residue_loss_sum"``: sum of all losses for each residue,
          shape=(*, N).
        * ``"per_residue_violation_mask"``: mask denoting all residues with
          violation present, shape=(*, N).
    """
    # Each constant is a [mean, standard deviation] pair for the cosine of
    # the corresponding inter-residue bond angle.
    ca_c_n_cosine_mean, ca_c_n_cosine_stddev = ADJACENT_RESIDUE_PSI_COSINE
    c_n_ca_cosine_mean, c_n_ca_cosine_stddev = ADJACENT_RESIDUE_PHI_COSINE

    # The C-N bond to proline has slightly different length because of the ring.
    next_is_proline = amino_acid[..., 1:] == AMINO_ACID_3.index("PRO")

    c_n_length_mean = _gt_length(next_is_proline)
    c_n_length_stddev = _gt_stddev(next_is_proline)

    # Get the positions of the relevant backbone atoms.
    this_ca_pos = pred_atom_positions[..., :-1, 1, :]
    this_ca_mask = pred_atom_mask[..., :-1, 1]
    this_c_pos = pred_atom_positions[..., :-1, 2, :]
    this_c_mask = pred_atom_mask[..., :-1, 2]
    next_n_pos = pred_atom_positions[..., 1:, 0, :]
    next_n_mask = pred_atom_mask[..., 1:, 0]
    next_ca_pos = pred_atom_positions[..., 1:, 1, :]
    next_ca_mask = pred_atom_mask[..., 1:, 1]

    has_no_gap_mask = residue_index[..., 1:] - residue_index[..., :-1] == 1.0

    # Every term is only defined where all atoms it touches exist and the two
    # residues are adjacent in the chain.
    c_n_mask = this_c_mask * next_n_mask * has_no_gap_mask
    ca_c_n_mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
    c_n_ca_mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask

    c_n_length = _bond_length(next_n_pos, this_c_pos)
    ca_c_length = _bond_length(this_c_pos, this_ca_pos)
    n_ca_length = _bond_length(next_ca_pos, next_n_pos)

    c_n_unit_vector = (next_n_pos - this_c_pos) / c_n_length[..., None]
    c_ca_unit_vector = (this_ca_pos - this_c_pos) / ca_c_length[..., None]
    n_ca_unit_vector = (next_ca_pos - next_n_pos) / n_ca_length[..., None]

    # Peptide bond (C-N) length.
    c_n_error = _error(c_n_length, c_n_length_mean)
    c_n_loss_per_residue = _loss_per_residue(
        c_n_error, c_n_length_stddev, tolerance_factor_soft
    )
    c_n_loss = _loss(c_n_loss_per_residue, c_n_mask)

    # CA-C-N bond angle, compared through its cosine. The tolerance must be
    # the standard deviation of this cosine, not the bond-length deviation.
    ca_c_n_cosine = torch.sum(c_ca_unit_vector * c_n_unit_vector, dim=-1)
    ca_c_n_error = _error(ca_c_n_cosine, ca_c_n_cosine_mean)
    ca_c_n_loss_per_residue = _loss_per_residue(
        ca_c_n_error, ca_c_n_cosine_stddev, tolerance_factor_soft
    )
    ca_c_n_loss = _loss(ca_c_n_loss_per_residue, ca_c_n_mask)

    # C-N-CA bond angle, compared through its cosine.
    c_n_ca_cosine = _c_n_ca_cos_angle(c_n_unit_vector, n_ca_unit_vector)
    c_n_ca_error = _error(c_n_ca_cosine, c_n_ca_cosine_mean)
    c_n_ca_loss_per_residue = _loss_per_residue(
        c_n_ca_error, c_n_ca_cosine_stddev, tolerance_factor_soft
    )
    c_n_ca_loss = _loss(c_n_ca_loss_per_residue, c_n_ca_mask)

    per_residue_loss_sum = _per_residue_loss_sum(
        c_n_ca_loss_per_residue, c_n_loss_per_residue, ca_c_n_loss_per_residue
    )

    # Each error term is thresholded against its own standard deviation and
    # restricted by its own atom mask; a residue is in violation when any of
    # the three terms flags it.
    c_n_violation, ca_c_n_violation, c_n_ca_violation = (
        _per_residue_violation_mask([error], stddev, mask, tolerance_factor_hard)
        for error, stddev, mask in (
            (c_n_error, c_n_length_stddev, c_n_mask),
            (ca_c_n_error, ca_c_n_cosine_stddev, ca_c_n_mask),
            (c_n_ca_error, c_n_ca_cosine_stddev, c_n_ca_mask),
        )
    )
    violation_mask = torch.maximum(
        torch.maximum(c_n_violation, ca_c_n_violation), c_n_ca_violation
    )

    return {
        "c_n_loss_mean": c_n_loss,
        "ca_c_n_loss_mean": ca_c_n_loss,
        "c_n_ca_loss_mean": c_n_ca_loss,
        "per_residue_loss_sum": per_residue_loss_sum,
        "per_residue_violation_mask": violation_mask,
    }
|
||
|
||
def _c_n_ca_cos_angle(input, other): | ||
return torch.sum(-input * other, dim=-1) | ||
|
||
|
||
def _gt_stddev(input): | ||
a = [0.014, 0.016][0] | ||
b = [0.014, 0.016][1] | ||
return ~input * a + input * b | ||
|
||
|
||
def _gt_length(input): | ||
a = [1.329, 1.341][0] | ||
b = [1.329, 1.341][1] | ||
|
||
return ~input * a + input * b | ||
|
||
|
||
def _per_residue_violation_mask(inputs: Sequence[Tensor], target, mask, temperature): | ||
output = [] | ||
|
||
for input in inputs: | ||
output = [*output, ((input > target * temperature) * mask)] | ||
|
||
output = torch.max(torch.stack(output, dim=-2), dim=-2)[0] | ||
|
||
x = torch.nn.functional.pad(output, [0, 1]) | ||
y = torch.nn.functional.pad(output, [1, 0]) | ||
|
||
return torch.maximum(x, y) | ||
|
||
|
||
def _bond_length(input, other): | ||
output = torch.sum((other - input) ** 2, dim=-1) | ||
|
||
return torch.sqrt(output + torch.finfo(input.dtype).eps) | ||
|
||
|
||
def _loss(input, mask): | ||
output = torch.sum(input * mask, dim=-1) | ||
|
||
return output / (torch.sum(mask, dim=-1) + torch.finfo(input.dtype).eps) | ||
|
||
|
||
def _error(input, target): | ||
return torch.sqrt((input - target) ** 2 + torch.finfo(input.dtype).eps) | ||
|
||
|
||
def _loss_per_residue(input, target, temperature): | ||
return torch.nn.functional.relu(input - target * temperature) | ||
|
||
|
||
def _per_residue_loss_sum(a, b, c): | ||
""" | ||
Compute a per residue loss (equally distribute the loss to both | ||
neighbouring residues. | ||
""" | ||
output = a + b + c | ||
|
||
x = torch.nn.functional.pad(output, [0, 1]) | ||
y = torch.nn.functional.pad(output, [1, 0]) | ||
|
||
return 0.5 * (x + y) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import torch | ||
|
||
|
||
def global_distance_test(input, other, mask, cutoffs):
    """
    Average, over ``cutoffs``, of the fraction of masked positions whose
    distance between ``input`` and ``other`` is within the cutoff.

    Both coordinate tensors are cast to float. Distances are Euclidean along
    the last dimension. For every cutoff, the number of masked positions with
    a distance at or below it is divided by the mask sum over the last
    dimension and averaged over all remaining dimensions. The result is the
    plain mean of these per-cutoff scores.
    """
    valid_count = torch.sum(mask, dim=-1)

    offsets = input.float() - other.float()
    distances = torch.sqrt(torch.sum(offsets**2, dim=-1))

    scores = [
        torch.mean(torch.sum((distances <= cutoff) * mask, dim=-1) / valid_count)
        for cutoff in cutoffs
    ]

    return sum(scores) / len(scores)