Skip to content

Commit

Permalink
frame_aligned_point_error
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed May 1, 2024
1 parent 724df21 commit 61774f6
Showing 1 changed file with 87 additions and 34 deletions.
121 changes: 87 additions & 34 deletions src/beignet/nn/functional/_frame_aligned_point_error.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# ruff: noqa: E501
from typing import Tuple

import torch
from torch import Tensor
Expand All @@ -7,60 +7,113 @@


def frame_aligned_point_error(
    input: Tuple[Tuple[Tensor, Tensor], Tensor],
    target: Tuple[Tuple[Tensor, Tensor], Tensor],
    mask: Tuple[Tensor, Tensor],
    z: float,
    mp: Tensor | None = None,
    wk: float | None = None,
) -> Tensor:
    r"""
    Score a set of predicted atom coordinates, $\left\{\vec{x}_{j}\right\}$,
    under a set of predicted local frames, $\left\{T_{i}\right\}$, against the
    corresponding target atom coordinates,
    $\left\{\vec{x}_{j}^{\mathrm{True}}\right\}$, and target local frames,
    $\left\{T_{i}^{\mathrm{True}}\right\}$.

    Each atom position $\vec{x}_{j}$ is expressed relative to every frame
    $T_{i}$, and the corresponding true position
    $\vec{x}_{j}^{\mathrm{True}}$ relative to the true frame
    $T_{i}^{\mathrm{True}}$. The deviation is computed as a robust L2 norm
    (a small $\epsilon$, taken from the floating-point type of the input,
    keeps gradients numerically well behaved for small differences). The
    $N_{\mathrm{frames}} \times N_{\mathrm{points}}$ deviations are
    optionally clamped and divided by a length scale, $Z$, to make the loss
    unitless.

    Parameters
    ----------
    input : ((Tensor, Tensor), Tensor)
        A pair of predicted local frames, $\left\{T_{i}\right\}$, and
        predicted atom coordinates, $\left\{\vec{x}_{j}\right\}$ — note the
        frames come *first*. A frame is represented as a pair of rotation
        matrices and corresponding translations. The rotation matrices must
        have the shape $(\ldots, 3, 3)$, the translations the shape
        $(\ldots, 3)$, and the atom positions the shape
        $(\ldots, \text{points}, 3)$.

    target : ((Tensor, Tensor), Tensor)
        A pair of target local frames,
        $\left\{T_{i}^{\mathrm{True}}\right\}$, and target atom coordinates,
        $\left\{\vec{x}_{j}^{\mathrm{True}}\right\}$, with the same shapes
        as ``input``.

    mask : (Tensor, Tensor)
        A pair of binary masks: ``mask[0]`` masks the frames,
        $[\ldots, \text{frames}]$, and ``mask[1]`` masks the positions,
        $[\ldots, \text{points}]$.

    z : float
        Length scale by which the loss is divided (e.g., $10\,\text{\AA}$).

    mp : Tensor | None, optional
        $[\ldots, \text{frames}, \text{points}]$ mask used to separate
        intra-chain from inter-chain losses. When provided, the loss is
        averaged only over the pairs it selects.

    wk : float | None, optional
        Cutoff above which distance errors are disregarded (clamped).

    Returns
    -------
    output : Tensor
        The frame-aligned point error, reduced over the frame and point
        dimensions; the remaining (leading) dimensions are preserved.
    """
    transform, input = input

    target_transform, target = target

    # Regularizer for square roots and denominators, matched to the input
    # precision so it is negligible relative to representable differences.
    epsilon = torch.finfo(input.dtype).eps

    # Express every predicted position in every predicted local frame:
    # x_ij = T_i^{-1} ∘ x_j, broadcast over a new frame dimension.
    input = beignet.apply_transform(
        input[..., None, :, :],
        beignet.invert_transform(
            transform,
        ),
    )

    # Likewise for the target positions in the target frames.
    target = beignet.apply_transform(
        target[..., None, :, :],
        beignet.invert_transform(
            target_transform,
        ),
    )

    # Robust per-(frame, point) L2 deviation.
    output = torch.sqrt(torch.sum((input - target) ** 2, dim=-1) + epsilon)

    if wk is not None:
        output = torch.clamp(output, 0, wk)

    # Normalize by the length scale and zero out masked frames and points.
    output = output / z * mask[0][..., None] * mask[1][..., None, :]

    if mp is not None:
        # Average over only the (frame, point) pairs selected by `mp`.
        output = torch.sum(output * mp, dim=[-1, -2]) / (
            torch.sum(mask[0][..., None] * mask[1][..., None, :] * mp, dim=[-2, -1])
            + epsilon
        )
    else:
        # Mean over frames, then mean over points, each normalized by the
        # respective mask counts (epsilon guards empty masks).
        output = torch.sum(
            torch.sum(output, dim=-1)
            / (torch.sum(mask[0], dim=-1)[..., None] + epsilon),
            dim=-1,
        ) / (torch.sum(mask[1], dim=-1) + epsilon)

    return output

0 comments on commit 61774f6

Please sign in to comment.