From fb4aa2703b00a97039cb19dc063e3c854c1d5af1 Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Tue, 14 May 2024 11:24:34 +0900 Subject: [PATCH 1/9] add RMSEVonMisesFisher3DLoss --- src/graphnet/training/loss_functions.py | 54 +++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/src/graphnet/training/loss_functions.py b/src/graphnet/training/loss_functions.py index 624a5fa53..30bde11a0 100644 --- a/src/graphnet/training/loss_functions.py +++ b/src/graphnet/training/loss_functions.py @@ -21,6 +21,7 @@ from graphnet.models.model import Model from graphnet.utilities.decorators import final +import importlib class LossFunction(Model): @@ -443,3 +444,56 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: kappa = prediction[:, 3] p = kappa.unsqueeze(1) * prediction[:, [0, 1, 2]] return self._evaluate(p, target) + + +# class LossCombiner(LossFunction): +# """Combine multiple loss functions into a single loss function.""" + +# def __init__(self, loss_functions: List[str], **kwargs: Any) -> None: +# """Construct `LossCombiner`.""" + +# super().__init__(**kwargs) +# self._loss_functions = [] +# for loss_function in loss_functions: +# loss = importlib.import_module(f"graphnet.training.loss_functions.{loss_function}") +# self._loss_functions.append(loss) + + +# def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: +# """Calculate combined loss.""" +# for count, loss in enumerate(self._loss_functions): +# if count == 0: +# elements = loss.forward(prediction, target) +# else: +# elements += loss.forward(prediction, target) +# return elements + + +class RMSEVonMisesFisher3DLoss(VonMisesFisherLoss): + """von Mises-Fisher loss function vectors in the 3D plane.""" + + def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: + """Calculate von Mises-Fisher loss for a direction in the 3D. + + Args: + prediction: Output of the model. Must have shape [N, 4] where + columns 0, 1, 2 are predictions of `direction` and last column + is an estimate of `kappa`. + target: Target tensor, extracted from graph object. + + Returns: + Elementwise von Mises-Fisher loss terms. 
Shape [N,] + """ + target = target.reshape(-1, 3) + # Check(s) + assert prediction.dim() == 2 and prediction.size()[1] == 4 + assert target.dim() == 2 + assert prediction.size()[0] == target.size()[0] + + kappa = prediction[:, 3] + p = kappa.unsqueeze(1) * prediction[:, [0, 1, 2]] + elements = 0.05 * self._evaluate(p, target) + elements += torch.sqrt( + torch.mean((prediction[:, :-1] - target) ** 2, dim=-1) + ) + return elements From 169f8fd43bd72da9483a0df543373f6ac00beb9b Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Tue, 14 May 2024 11:56:59 +0900 Subject: [PATCH 2/9] add isolated nodes --- src/graphnet/models/graphs/__init__.py | 2 +- src/graphnet/models/graphs/graphs.py | 38 ++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/src/graphnet/models/graphs/__init__.py b/src/graphnet/models/graphs/__init__.py index ea5066307..6974cdddc 100644 --- a/src/graphnet/models/graphs/__init__.py +++ b/src/graphnet/models/graphs/__init__.py @@ -7,4 +7,4 @@ from .graph_definition import GraphDefinition -from .graphs import KNNGraph +from .graphs import KNNGraph, IsolatedNodes diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py index d486bba0a..bdb442f09 100644 --- a/src/graphnet/models/graphs/graphs.py +++ b/src/graphnet/models/graphs/graphs.py @@ -54,3 +54,41 @@ def __init__( perturbation_dict=perturbation_dict, seed=seed, ) + + +class IsolatedNodes(GraphDefinition): + """A Graph representation where each node is isolated.""" + + def __init__( + self, + detector: Detector, + node_definition: NodeDefinition = None, + input_feature_names: Optional[List[str]] = None, + dtype: Optional[torch.dtype] = torch.float, + perturbation_dict: Optional[Dict[str, float]] = None, + seed: Optional[Union[int, Generator]] = None, + ) -> None: + """Construct isolated nodes graph representation. + + Args: + detector: Detector that represents your data. + node_definition: Definition of nodes in the graph. + input_feature_names: Name of input feature columns. + dtype: data type for node features. + perturbation_dict: Dictionary mapping a feature name to a standard + deviation according to which the values for this + feature should be randomly perturbed. Defaults + to None. + seed: seed or Generator used to randomly sample perturbations. + Defaults to None. 
+ """ + # Base class constructor + super().__init__( + detector=detector, + node_definition=node_definition or NodesAsPulses(), + edge_definition=None, + dtype=dtype, + input_feature_names=input_feature_names, + perturbation_dict=perturbation_dict, + seed=seed, + ) From 661111b651d224e5244fbcb31efd5400445496dc Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Tue, 14 May 2024 13:34:45 +0900 Subject: [PATCH 3/9] cleaning --- src/graphnet/training/loss_functions.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/src/graphnet/training/loss_functions.py b/src/graphnet/training/loss_functions.py index 30bde11a0..aa646d744 100644 --- a/src/graphnet/training/loss_functions.py +++ b/src/graphnet/training/loss_functions.py @@ -446,29 +446,6 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: return self._evaluate(p, target) -# class LossCombiner(LossFunction): -# """Combine multiple loss functions into a single loss function.""" - -# def __init__(self, loss_functions: List[str], **kwargs: Any) -> None: -# """Construct `LossCombiner`.""" - -# super().__init__(**kwargs) -# self._loss_functions = [] -# for loss_function in loss_functions: -# loss = importlib.import_module(f"graphnet.training.loss_functions.{loss_function}") -# self._loss_functions.append(loss) - - -# def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: -# """Calculate combined loss.""" -# for count, loss in enumerate(self._loss_functions): -# if count == 0: -# elements = loss.forward(prediction, target) -# else: -# elements += loss.forward(prediction, target) -# return elements - - class RMSEVonMisesFisher3DLoss(VonMisesFisherLoss): """von Mises-Fisher loss function vectors in the 3D plane.""" From 3e3390472b13b3cfacfc84f465dd9cb2bebbea94 Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Tue, 14 May 2024 13:43:50 +0900 Subject: [PATCH 4/9] cleanup --- src/graphnet/training/loss_functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/graphnet/training/loss_functions.py b/src/graphnet/training/loss_functions.py index aa646d744..b41aaa6f5 100644 --- a/src/graphnet/training/loss_functions.py +++ b/src/graphnet/training/loss_functions.py @@ -21,7 +21,6 @@ from graphnet.models.model import Model from graphnet.utilities.decorators import final -import importlib class LossFunction(Model): From c9eadcc5ec9fddae85bff9962f15195719aa75e2 Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Mon, 20 May 2024 14:09:42 +0900 Subject: [PATCH 5/9] docstring update --- src/graphnet/models/components/embedding.py | 2 +- src/graphnet/models/graphs/__init__.py | 2 +- src/graphnet/models/graphs/graphs.py | 7 +++++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/graphnet/models/components/embedding.py b/src/graphnet/models/components/embedding.py index e97ca90e7..1b49cd901 100644 --- a/src/graphnet/models/components/embedding.py +++ b/src/graphnet/models/components/embedding.py @@ -84,7 +84,6 @@ def __init__( super().__init__() self.sin_emb = SinusoidalPosEmb(dim=seq_length, scaled=scaled) - self.aux_emb = nn.Embedding(2, seq_length // 2) self.sin_emb2 = SinusoidalPosEmb(dim=seq_length // 2, scaled=scaled) if n_features < 4: @@ -93,6 +92,7 @@ def __init__( f"{n_features} features." 
) elif n_features >= 6: + self.aux_emb = nn.Embedding(2, seq_length // 2) hidden_dim = 6 * seq_length else: hidden_dim = int((n_features + 0.5) * seq_length) diff --git a/src/graphnet/models/graphs/__init__.py b/src/graphnet/models/graphs/__init__.py index 6974cdddc..a07d1308d 100644 --- a/src/graphnet/models/graphs/__init__.py +++ b/src/graphnet/models/graphs/__init__.py @@ -7,4 +7,4 @@ from .graph_definition import GraphDefinition -from .graphs import KNNGraph, IsolatedNodes +from .graphs import KNNGraph, EdgelessGraph diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py index bdb442f09..0289b943d 100644 --- a/src/graphnet/models/graphs/graphs.py +++ b/src/graphnet/models/graphs/graphs.py @@ -56,8 +56,11 @@ def __init__( ) -class IsolatedNodes(GraphDefinition): - """A Graph representation where each node is isolated.""" +class EdgelessGraph(GraphDefinition): + """A Data representation without edge assignment. + + I.e the resulting representation is created without an EdgeDefinition. + """ def __init__( self, From b9c3195b48e5586967248660d21d841b7cc624b0 Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Mon, 20 May 2024 14:55:25 +0900 Subject: [PATCH 6/9] large refactor --- src/graphnet/training/loss_functions.py | 93 ++++++++++++++++++++----- 1 file changed, 75 insertions(+), 18 deletions(-) diff --git a/src/graphnet/training/loss_functions.py b/src/graphnet/training/loss_functions.py index b41aaa6f5..6468e5296 100644 --- a/src/graphnet/training/loss_functions.py +++ b/src/graphnet/training/loss_functions.py @@ -445,31 +445,88 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: return self._evaluate(p, target) -class RMSEVonMisesFisher3DLoss(VonMisesFisherLoss): - """von Mises-Fisher loss function vectors in the 3D plane.""" +class EnsembleLoss(LossFunction): + """Chain multiple loss functions together.""" + + def __init__( + self, + loss_functions: List[LossFunction], + loss_factors: List[float] = None, + prediction_keys: Optional[List[List[int]]] = None, + ) -> None: + """Chain multiple loss functions together. + + Optionally apply a weight to each loss function contribution. + + E.g. Loss = RMSE*0.5 + LogCoshLoss*1.5 + + Args: + loss_functions: A list of loss functions to use. + Each loss function contributes a term to the overall loss. + loss_factors: An optional list of factors that will be mulitplied + to each loss function contribution. Must be ordered according + to `loss_functions`. If not given, the weights default to 1. + prediction_keys: An optional list of lists of indices for which + prediction columns to use for each loss function. If not + given, all columns are used for all loss functions. + """ + if loss_factors is None: + # add weight of 1 - i.e no discrimination + loss_factors = np.repeat(1, len(loss_functions)).tolist() + + assert len(loss_functions) == len(loss_factors) + self._factors = loss_factors + self._loss_functions = loss_functions + + if prediction_keys is not None: + self._prediction_keys: Optional[List[List[int]]] = prediction_keys + else: + self._prediction_keys = None def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: - """Calculate von Mises-Fisher loss for a direction in the 3D. + """Calculate loss using multiple loss functions. Args: - prediction: Output of the model. Must have shape [N, 4] where - columns 0, 1, 2 are predictions of `direction` and last column - is an estimate of `kappa`. + prediction: Output of the model. target: Target tensor, extracted from graph object. 
Returns: - Elementwise von Mises-Fisher loss terms. Shape [N,] + Elementwise loss terms. Shape [N,] """ - target = target.reshape(-1, 3) - # Check(s) - assert prediction.dim() == 2 and prediction.size()[1] == 4 - assert target.dim() == 2 - assert prediction.size()[0] == target.size()[0] + if self._prediction_keys is None: + prediction_keys = [list(range(prediction.size(1)))] * len( + self._loss_functions + ) + else: + prediction_keys = self._prediction_keys + for k, (loss_function, prediction_key) in enumerate( + zip(self._loss_functions, prediction_keys) + ): + if k == 0: + elements = self._factors[k] * loss_function._forward( + prediction=prediction[prediction_key], target=target + ) + else: + elements += self._factors[k] * loss_function._forward( + prediction=prediction[prediction_key], target=target + ) + return elements - kappa = prediction[:, 3] - p = kappa.unsqueeze(1) * prediction[:, [0, 1, 2]] - elements = 0.05 * self._evaluate(p, target) - elements += torch.sqrt( - torch.mean((prediction[:, :-1] - target) ** 2, dim=-1) + +class RMSEVonMisesFisher3DLoss(EnsembleLoss): + """Combine the VonMisesFisher3DLoss with RMSELoss.""" + + def __init__(self, vmfs_factor: float = 0.05) -> None: + """VonMisesFisher3DLoss with a RMSE penality term. + + The VonMisesFisher3DLoss will be weighted with `vmfs_factor`. + + Args: + vmfs_factor: A factor applied to the VonMisesFisher3DLoss term. + Defaults ot 0.05. + """ + super().__init__( + loss_functions=[RMSELoss(), VonMisesFisher3DLoss()], + loss_factors=[1, vmfs_factor], + prediction_keys=[[0, 1, 2], [0, 1, 2, 3]], ) - return elements From 338814c2fc5bad8081d5f8ace77acb7c0c9246fa Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Mon, 20 May 2024 15:31:01 +0900 Subject: [PATCH 7/9] fixing --- src/graphnet/training/loss_functions.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/graphnet/training/loss_functions.py b/src/graphnet/training/loss_functions.py index 6468e5296..d3fc43f7e 100644 --- a/src/graphnet/training/loss_functions.py +++ b/src/graphnet/training/loss_functions.py @@ -71,6 +71,8 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: """Implement loss calculation.""" # Check(s) assert prediction.dim() == 2 + if target.dim() != prediction.dim(): + target = target.squeeze(1) assert prediction.size() == target.size() elements = torch.mean((prediction - target) ** 2, dim=-1) @@ -453,6 +455,8 @@ def __init__( loss_functions: List[LossFunction], loss_factors: List[float] = None, prediction_keys: Optional[List[List[int]]] = None, + *args: Any, + **kwargs: Any, ) -> None: """Chain multiple loss functions together. @@ -482,6 +486,7 @@ def __init__( self._prediction_keys: Optional[List[List[int]]] = prediction_keys else: self._prediction_keys = None + super().__init__(*args, **kwargs) def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: """Calculate loss using multiple loss functions. 
@@ -504,11 +509,11 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: ): if k == 0: elements = self._factors[k] * loss_function._forward( - prediction=prediction[prediction_key], target=target + prediction=prediction[:, prediction_key], target=target ) else: elements += self._factors[k] * loss_function._forward( - prediction=prediction[prediction_key], target=target + prediction=prediction[:, prediction_key], target=target ) return elements From ba214615cc04a78085ed959dc29ed7e5cb21c5a6 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Mon, 20 May 2024 15:47:27 +0200 Subject: [PATCH 8/9] Change defaulting behavior of GraphDefinition --- src/graphnet/models/graphs/graph_definition.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index 6366fc390..e384425f9 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -24,7 +24,7 @@ class GraphDefinition(Model): def __init__( self, detector: Detector, - node_definition: NodeDefinition = NodesAsPulses(), + node_definition: NodeDefinition = None, edge_definition: Optional[EdgeDefinition] = None, input_feature_names: Optional[List[str]] = None, dtype: Optional[torch.dtype] = torch.float, @@ -69,6 +69,9 @@ def __init__( # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) + if node_definition is None: + node_definition = NodesAsPulses() + # Member Variables self._detector = detector self._edge_definition = edge_definition From 96c7c695ed7cb93dd69b258c49cb4d70712c3e1c Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Tue, 21 May 2024 10:49:10 +0900 Subject: [PATCH 9/9] revert embedding changes --- src/graphnet/models/components/embedding.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/graphnet/models/components/embedding.py b/src/graphnet/models/components/embedding.py index 1b49cd901..9539fc444 100644 --- a/src/graphnet/models/components/embedding.py +++ b/src/graphnet/models/components/embedding.py @@ -84,6 +84,7 @@ def __init__( super().__init__() self.sin_emb = SinusoidalPosEmb(dim=seq_length, scaled=scaled) + self.aux_emb = nn.Embedding(2, seq_length // 2) self.sin_emb2 = SinusoidalPosEmb(dim=seq_length // 2, scaled=scaled) if n_features < 4: @@ -92,7 +93,7 @@ def __init__( f"{n_features} features." ) elif n_features >= 6: - self.aux_emb = nn.Embedding(2, seq_length // 2) + hidden_dim = 6 * seq_length else: hidden_dim = int((n_features + 0.5) * seq_length)
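
Reviewer note (not part of the patch series): the sketch below shows how the `EnsembleLoss` introduced in PATCH 6/9 and fixed up in PATCH 7/9 would be wired together, using the same `RMSELoss`/`VonMisesFisher3DLoss` combination that `RMSEVonMisesFisher3DLoss` hard-codes. The tensor shapes and the call through the `LossFunction` base-class `forward` wrapper are assumptions inferred from the docstrings above, not verified against the merged code. The `EdgelessGraph` from PATCH 2/9 and 5/9 is used the same way as `KNNGraph`, just without an `edge_definition`.

# Minimal usage sketch under the assumptions stated above.
import torch
from torch.nn.functional import normalize

from graphnet.training.loss_functions import (
    EnsembleLoss,
    RMSELoss,
    RMSEVonMisesFisher3DLoss,
    VonMisesFisher3DLoss,
)

# Fake model output: three direction components plus a kappa estimate.
direction = normalize(torch.randn(8, 3), dim=-1)
kappa = torch.rand(8, 1) + 1.0  # keep the concentration estimate positive
prediction = torch.cat([direction, kappa], dim=-1)  # shape [8, 4]
target = normalize(torch.randn(8, 3), dim=-1)       # true unit directions

# Should be equivalent to RMSEVonMisesFisher3DLoss(vmfs_factor=0.05):
# the RMSE term sees only the direction columns, the vMF term all four.
loss_fn = EnsembleLoss(
    loss_functions=[RMSELoss(), VonMisesFisher3DLoss()],
    loss_factors=[1, 0.05],
    prediction_keys=[[0, 1, 2], [0, 1, 2, 3]],
)

loss = loss_fn(prediction, target)  # reduction handled by the LossFunction base class
shortcut = RMSEVonMisesFisher3DLoss(vmfs_factor=0.05)(prediction, target)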