Commit 9cc12f0: merge from main

Aske-Rosted committed May 28, 2024
2 parents 6fcd7d2 + fc5d955
Showing 4 changed files with 138 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/graphnet/models/graphs/__init__.py
@@ -7,4 +7,4 @@


from .graph_definition import GraphDefinition
from .graphs import KNNGraph
from .graphs import KNNGraph, EdgelessGraph
5 changes: 4 additions & 1 deletion src/graphnet/models/graphs/graph_definition.py
@@ -24,7 +24,7 @@ class GraphDefinition(Model):
def __init__(
self,
detector: Detector,
node_definition: NodeDefinition = NodesAsPulses(),
node_definition: Optional[NodeDefinition] = None,
edge_definition: Optional[EdgeDefinition] = None,
input_feature_names: Optional[List[str]] = None,
dtype: Optional[torch.dtype] = torch.float,
@@ -69,6 +69,9 @@ def __init__(
# Base class constructor
super().__init__(name=__name__, class_name=self.__class__.__name__)

if node_definition is None:
node_definition = NodesAsPulses()

# Member Variables
self._detector = detector
self._edge_definition = edge_definition
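For context on the change above: a default such as `NodesAsPulses()` is evaluated once, when the signature is defined, and that single instance is then shared by every call that omits the argument; replacing it with `None` plus an in-body check (as the diff does) presumably avoids that sharing. A minimal sketch of the difference, using a hypothetical `Node` stand-in rather than any graphnet class:

from typing import Optional


class Node:
    """Hypothetical stand-in for a stateful NodeDefinition."""

    def __init__(self) -> None:
        self.calls = 0


def shared_default(node: Node = Node()) -> Node:
    # The default Node() is built once, at definition time,
    # and reused by every call that omits the argument.
    node.calls += 1
    return node


def fresh_default(node: Optional[Node] = None) -> Node:
    # The pattern adopted in the diff: build a new instance per call.
    if node is None:
        node = Node()
    node.calls += 1
    return node


assert shared_default() is shared_default()    # same object both times
assert fresh_default() is not fresh_default()  # a new object on each call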
41 changes: 41 additions & 0 deletions src/graphnet/models/graphs/graphs.py
@@ -54,3 +54,44 @@ def __init__(
perturbation_dict=perturbation_dict,
seed=seed,
)


class EdgelessGraph(GraphDefinition):
"""A Data representation without edge assignment.
I.e the resulting representation is created without an EdgeDefinition.
"""

def __init__(
self,
detector: Detector,
node_definition: Optional[NodeDefinition] = None,
input_feature_names: Optional[List[str]] = None,
dtype: Optional[torch.dtype] = torch.float,
perturbation_dict: Optional[Dict[str, float]] = None,
seed: Optional[Union[int, Generator]] = None,
) -> None:
"""Construct isolated nodes graph representation.
Args:
detector: Detector that represents your data.
node_definition: Definition of nodes in the graph.
input_feature_names: Name of input feature columns.
dtype: data type for node features.
perturbation_dict: Dictionary mapping a feature name to a standard
deviation according to which the values for this
feature should be randomly perturbed. Defaults
to None.
seed: seed or Generator used to randomly sample perturbations.
Defaults to None.
"""
# Base class constructor
super().__init__(
detector=detector,
node_definition=node_definition or NodesAsPulses(),
edge_definition=None,
dtype=dtype,
input_feature_names=input_feature_names,
perturbation_dict=perturbation_dict,
seed=seed,
)
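A hedged usage sketch of the new class (assumes graphnet is installed; IceCube86 and the listed feature names are illustrative choices, not part of this diff):

from graphnet.models.detector.icecube import IceCube86
from graphnet.models.graphs import EdgelessGraph

# Node-only representation: every pulse becomes a node, no edges are built.
graph_definition = EdgelessGraph(
    detector=IceCube86(),
    input_feature_names=["dom_x", "dom_y", "dom_z", "dom_time", "charge"],
)
# graph_definition can then be passed wherever a GraphDefinition is expected,
# e.g. when constructing a dataset or a model.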
92 changes: 92 additions & 0 deletions src/graphnet/training/loss_functions.py
@@ -71,6 +71,8 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:
"""Implement loss calculation."""
# Check(s)
assert prediction.dim() == 2
if target.dim() != prediction.dim():
target = target.squeeze(1)
assert prediction.size() == target.size()

elements = torch.mean((prediction - target) ** 2, dim=-1)
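A minimal shape illustration of the added check, with illustrative tensors only; it assumes the target may carry an extra singleton dimension (e.g. [N, 1, D]) that has to be squeezed to match a [N, D] prediction:

import torch

prediction = torch.randn(8, 3)  # [N, D]
target = torch.randn(8, 1, 3)   # [N, 1, D]: extra singleton dimension

if target.dim() != prediction.dim():
    target = target.squeeze(1)  # -> [N, D]

assert prediction.size() == target.size()
elements = torch.mean((prediction - target) ** 2, dim=-1)  # per-event loss, shape [N]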
@@ -443,3 +445,93 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:
kappa = prediction[:, 3]
p = kappa.unsqueeze(1) * prediction[:, [0, 1, 2]]
return self._evaluate(p, target)


class EnsembleLoss(LossFunction):
"""Chain multiple loss functions together."""

def __init__(
self,
loss_functions: List[LossFunction],
loss_factors: Optional[List[float]] = None,
prediction_keys: Optional[List[List[int]]] = None,
*args: Any,
**kwargs: Any,
) -> None:
"""Chain multiple loss functions together.
Optionally apply a weight to each loss function contribution.
E.g. Loss = RMSE*0.5 + LogCoshLoss*1.5
Args:
loss_functions: A list of loss functions to use.
Each loss function contributes a term to the overall loss.
loss_factors: An optional list of factors that each loss function
contribution will be multiplied by. Must be ordered according
to `loss_functions`. If not given, the weights default to 1.
prediction_keys: An optional list of lists of indices for which
prediction columns to use for each loss function. If not
given, all columns are used for all loss functions.
"""
if loss_factors is None:
# Default every loss function to a weight of 1, i.e. no weighting.
loss_factors = np.repeat(1, len(loss_functions)).tolist()

assert len(loss_functions) == len(loss_factors)
self._factors = loss_factors
self._loss_functions = loss_functions

if prediction_keys is not None:
self._prediction_keys: Optional[List[List[int]]] = prediction_keys
else:
self._prediction_keys = None
super().__init__(*args, **kwargs)

def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:
"""Calculate loss using multiple loss functions.
Args:
prediction: Output of the model.
target: Target tensor, extracted from graph object.
Returns:
Elementwise loss terms. Shape [N,]
"""
if self._prediction_keys is None:
prediction_keys = [list(range(prediction.size(1)))] * len(
self._loss_functions
)
else:
prediction_keys = self._prediction_keys
for k, (loss_function, prediction_key) in enumerate(
zip(self._loss_functions, prediction_keys)
):
if k == 0:
elements = self._factors[k] * loss_function._forward(
prediction=prediction[:, prediction_key], target=target
)
else:
elements += self._factors[k] * loss_function._forward(
prediction=prediction[:, prediction_key], target=target
)
return elements
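A hedged usage sketch of EnsembleLoss, assuming RMSELoss and LogCoshLoss are available in this module as referenced elsewhere in the file; the weights and column indices are illustrative only:

combined = EnsembleLoss(
    loss_functions=[RMSELoss(), LogCoshLoss()],
    loss_factors=[1.0, 0.5],
    prediction_keys=[[0, 1, 2], [3]],
)
# For a prediction of shape [N, 4], the per-event loss is then roughly
#   1.0 * RMSELoss(prediction[:, [0, 1, 2]], target)
# + 0.5 * LogCoshLoss(prediction[:, [3]], target)
# Note that the same target is passed to every term, so its shape must
# be compatible with each of the chained loss functions.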


class RMSEVonMisesFisher3DLoss(EnsembleLoss):
"""Combine the VonMisesFisher3DLoss with RMSELoss."""

def __init__(self, vmfs_factor: float = 0.05) -> None:
"""VonMisesFisher3DLoss with a RMSE penality term.
The VonMisesFisher3DLoss will be weighted with `vmfs_factor`.
Args:
vmfs_factor: A factor applied to the VonMisesFisher3DLoss term.
Defaults to 0.05.
"""
super().__init__(
loss_functions=[RMSELoss(), VonMisesFisher3DLoss()],
loss_factors=[1, vmfs_factor],
prediction_keys=[[0, 1, 2], [0, 1, 2, 3]],
)
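As wired above, the combined loss expects a four-column prediction (a 3D direction vector plus kappa): the RMSE term sees columns [0, 1, 2] and the von Mises-Fisher term sees all four. A hedged sketch of the shapes involved, with illustrative tensors only:

import torch

loss = RMSEVonMisesFisher3DLoss(vmfs_factor=0.05)

prediction = torch.randn(16, 4)                                      # columns: [x, y, z, kappa]
target = torch.nn.functional.normalize(torch.randn(16, 3), dim=-1)   # unit direction vectors

elements = loss._forward(prediction, target)  # elementwise loss terms, shape [16]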
