From fb4aa2703b00a97039cb19dc063e3c854c1d5af1 Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Tue, 14 May 2024 11:24:34 +0900 Subject: [PATCH] add RMSEVonMisesFisher3DLoss --- src/graphnet/training/loss_functions.py | 54 +++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/src/graphnet/training/loss_functions.py b/src/graphnet/training/loss_functions.py index 624a5fa53..30bde11a0 100644 --- a/src/graphnet/training/loss_functions.py +++ b/src/graphnet/training/loss_functions.py @@ -21,6 +21,7 @@ from graphnet.models.model import Model from graphnet.utilities.decorators import final +import importlib class LossFunction(Model): @@ -443,3 +444,56 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: kappa = prediction[:, 3] p = kappa.unsqueeze(1) * prediction[:, [0, 1, 2]] return self._evaluate(p, target) + + +# class LossCombiner(LossFunction): +# """Combine multiple loss functions into a single loss function.""" + +# def __init__(self, loss_functions: List[str], **kwargs: Any) -> None: +# """Construct `LossCombiner`.""" + +# super().__init__(**kwargs) +# self._loss_functions = [] +# for loss_function in loss_functions: +# loss = importlib.import_module(f"graphnet.training.loss_functions.{loss_function}") +# self._loss_functions.append(loss) + + +# def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: +# """Calculate combined loss.""" +# for count, loss in enumerate(self._loss_functions): +# if count == 0: +# elements = loss.forward(prediction, target) +# else: +# elements += loss.forward(prediction, target) +# return elements + + +class RMSEVonMisesFisher3DLoss(VonMisesFisherLoss): + """von Mises-Fisher loss function vectors in the 3D plane.""" + + def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: + """Calculate von Mises-Fisher loss for a direction in the 3D. + + Args: + prediction: Output of the model. Must have shape [N, 4] where + columns 0, 1, 2 are predictions of `direction` and last column + is an estimate of `kappa`. + target: Target tensor, extracted from graph object. + + Returns: + Elementwise von Mises-Fisher loss terms. Shape [N,] + """ + target = target.reshape(-1, 3) + # Check(s) + assert prediction.dim() == 2 and prediction.size()[1] == 4 + assert target.dim() == 2 + assert prediction.size()[0] == target.size()[0] + + kappa = prediction[:, 3] + p = kappa.unsqueeze(1) * prediction[:, [0, 1, 2]] + elements = 0.05 * self._evaluate(p, target) + elements += torch.sqrt( + torch.mean((prediction[:, :-1] - target) ** 2, dim=-1) + ) + return elements