From 83997dc01dd53af902af37afc98186293de8ae5e Mon Sep 17 00:00:00 2001 From: Allen Goodman <allen.goodman@icloud.com> Date: Wed, 1 May 2024 16:38:56 -0400 Subject: [PATCH] clash loss --- src/beignet/nn/functional/_clash_loss.py | 82 ++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/src/beignet/nn/functional/_clash_loss.py b/src/beignet/nn/functional/_clash_loss.py index e69de29bb2..579ebfedf5 100644 --- a/src/beignet/nn/functional/_clash_loss.py +++ b/src/beignet/nn/functional/_clash_loss.py @@ -0,0 +1,82 @@ +import torch +from torch import Tensor + + +def clash_loss( + input: Tensor, + target: (Tensor, Tensor), + mask: Tensor, + tighten=0.0, + epsilon=1e-10, +) -> (Tensor, Tensor, Tensor): + r""" + A one-sided flat-bottom-potential, that penalizes steric clashes: + + $$\mathcal{L}_{\text{clash}}=\sum_{i=1}^{N_{\text{non-bonded}}}\max{ + \left(\text{distance }_{\text{Van der Waals radii}}^{i}- + \tau- + \text{distance }_{\text{predicted}}^{i},0\right)},$$ + + where $N_{\text{non-bonded pairs}}$ is the number of all non-bonded atom + pairs, $\text{distance }_{\text{predicted}}^{i}$ is the distance of two + non-bonded atoms in the predicted structure, and + $\text{distance }_{\text{Van der Waals radii}}^{i}$ is the “clashing + distance” of two non-bonded atoms according to their Van der Waals radii. + The tolerance, $\tau$, $1.5\text{\r{A}}$. + + Parameters + ---------- + input : Tensor, shape=(..., N, 14, 3) + Predicted positions of atoms in global prediction frame. + + target : Tensor, shape=(..., N, 14), Tensor, shape=(..., N, 14) + Lower and upper bound on allowed distances. + + mask : Tensor, shape=(..., N, 14) + Mask denoting whether atom at positions exists for given amino acid type. + + tighten : float, optional + Extra factor to tighten loss. Default, 0.0. + + epsilon : float, optional + Small value to avoid division by zero. Default, 1e-10. + + Returns + ------- + output : Tensor, shape=(..., N, 14) + Sum of all clash losses per atom. + + mask : Tensor, shape=(..., N, 14) + Whether atom clashes with any other atom. + + clashes : Tensor, shape=(..., N) + Number of clashes per atom. + """ + distance_mask = torch.eye(14) + distance_mask = distance_mask[None] + distance_mask = 1.0 - distance_mask + shape = [*((1,) * len(mask.shape[:-2])), *distance_mask.shape] + distance_mask = torch.reshape(distance_mask, shape) + distance_mask = distance_mask * mask[..., :, :, None] + distance_mask = distance_mask * mask[..., :, None, :] + + distance = input[..., :, :, None, :] - input[..., :, None, :, :] + distance = torch.sqrt(torch.sum(distance**2, dim=-1) + epsilon) + + a, b = target + + a = torch.nn.functional.relu((a + tighten) - distance) + b = torch.nn.functional.relu(distance - (b - tighten)) + + loss = (a + b) * distance_mask + + violations = ((distance < a) | (distance > b)) * distance_mask + + return ( + torch.sum(loss, dim=-2) + torch.sum(loss, dim=-1), + torch.maximum( + torch.max(violations, dim=-2)[0], + torch.max(violations, dim=-1)[0], + ), + torch.sum(violations, dim=-2) + torch.sum(violations, dim=-1), + )