diff --git a/_modules/graphnet/models/components/embedding.html b/_modules/graphnet/models/components/embedding.html index f73cb60b1..06128b82d 100644 --- a/_modules/graphnet/models/components/embedding.html +++ b/_modules/graphnet/models/components/embedding.html @@ -452,6 +452,7 @@

Source code for f"{n_features} features." ) elif n_features >= 6: + hidden_dim = 6 * seq_length else: hidden_dim = int((n_features + 0.5) * seq_length) diff --git a/_modules/graphnet/models/graphs/graph_definition.html b/_modules/graphnet/models/graphs/graph_definition.html index 5b13ae1db..595450b28 100644 --- a/_modules/graphnet/models/graphs/graph_definition.html +++ b/_modules/graphnet/models/graphs/graph_definition.html @@ -377,7 +377,7 @@

Source code def __init__( self, detector: Detector, - node_definition: NodeDefinition = NodesAsPulses(), + node_definition: NodeDefinition = None, edge_definition: Optional[EdgeDefinition] = None, input_feature_names: Optional[List[str]] = None, dtype: Optional[torch.dtype] = torch.float, @@ -422,6 +422,9 @@

Source code # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) + if node_definition is None: + node_definition = NodesAsPulses() + # Member Variables self._detector = detector self._edge_definition = edge_definition diff --git a/_modules/graphnet/models/graphs/graphs.html b/_modules/graphnet/models/graphs/graphs.html index 55767ca84..b7d2da420 100644 --- a/_modules/graphnet/models/graphs/graphs.html +++ b/_modules/graphnet/models/graphs/graphs.html @@ -408,6 +408,50 @@

Source code for graphn seed=seed, ) + + +
+[docs] +class EdgelessGraph(GraphDefinition): + """A Data representation without edge assignment. + + I.e the resulting representation is created without an EdgeDefinition. + """ + + def __init__( + self, + detector: Detector, + node_definition: NodeDefinition = None, + input_feature_names: Optional[List[str]] = None, + dtype: Optional[torch.dtype] = torch.float, + perturbation_dict: Optional[Dict[str, float]] = None, + seed: Optional[Union[int, Generator]] = None, + ) -> None: + """Construct isolated nodes graph representation. + + Args: + detector: Detector that represents your data. + node_definition: Definition of nodes in the graph. + input_feature_names: Name of input feature columns. + dtype: data type for node features. + perturbation_dict: Dictionary mapping a feature name to a standard + deviation according to which the values for this + feature should be randomly perturbed. Defaults + to None. + seed: seed or Generator used to randomly sample perturbations. + Defaults to None. + """ + # Base class constructor + super().__init__( + detector=detector, + node_definition=node_definition or NodesAsPulses(), + edge_definition=None, + dtype=dtype, + input_feature_names=input_feature_names, + perturbation_dict=perturbation_dict, + seed=seed, + )
+ diff --git a/_modules/graphnet/training/loss_functions.html b/_modules/graphnet/training/loss_functions.html index 2b907c806..331fc0db2 100644 --- a/_modules/graphnet/training/loss_functions.html +++ b/_modules/graphnet/training/loss_functions.html @@ -430,6 +430,8 @@

Source code for gra """Implement loss calculation.""" # Check(s) assert prediction.dim() == 2 + if target.dim() != prediction.dim(): + target = target.squeeze(1) assert prediction.size() == target.size() elements = torch.mean((prediction - target) ** 2, dim=-1) @@ -845,6 +847,102 @@

Source code for gra p = kappa.unsqueeze(1) * prediction[:, [0, 1, 2]] return self._evaluate(p, target) + + +
+[docs] +class EnsembleLoss(LossFunction): + """Chain multiple loss functions together.""" + + def __init__( + self, + loss_functions: List[LossFunction], + loss_factors: List[float] = None, + prediction_keys: Optional[List[List[int]]] = None, + *args: Any, + **kwargs: Any, + ) -> None: + """Chain multiple loss functions together. + + Optionally apply a weight to each loss function contribution. + + E.g. Loss = RMSE*0.5 + LogCoshLoss*1.5 + + Args: + loss_functions: A list of loss functions to use. + Each loss function contributes a term to the overall loss. + loss_factors: An optional list of factors that will be mulitplied + to each loss function contribution. Must be ordered according + to `loss_functions`. If not given, the weights default to 1. + prediction_keys: An optional list of lists of indices for which + prediction columns to use for each loss function. If not + given, all columns are used for all loss functions. + """ + if loss_factors is None: + # add weight of 1 - i.e no discrimination + loss_factors = np.repeat(1, len(loss_functions)).tolist() + + assert len(loss_functions) == len(loss_factors) + self._factors = loss_factors + self._loss_functions = loss_functions + + if prediction_keys is not None: + self._prediction_keys: Optional[List[List[int]]] = prediction_keys + else: + self._prediction_keys = None + super().__init__(*args, **kwargs) + + def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: + """Calculate loss using multiple loss functions. + + Args: + prediction: Output of the model. + target: Target tensor, extracted from graph object. + + Returns: + Elementwise loss terms. Shape [N,] + """ + if self._prediction_keys is None: + prediction_keys = [list(range(prediction.size(1)))] * len( + self._loss_functions + ) + else: + prediction_keys = self._prediction_keys + for k, (loss_function, prediction_key) in enumerate( + zip(self._loss_functions, prediction_keys) + ): + if k == 0: + elements = self._factors[k] * loss_function._forward( + prediction=prediction[:, prediction_key], target=target + ) + else: + elements += self._factors[k] * loss_function._forward( + prediction=prediction[:, prediction_key], target=target + ) + return elements
+ + + +
+[docs] +class RMSEVonMisesFisher3DLoss(EnsembleLoss): + """Combine the VonMisesFisher3DLoss with RMSELoss.""" + + def __init__(self, vmfs_factor: float = 0.05) -> None: + """VonMisesFisher3DLoss with a RMSE penality term. + + The VonMisesFisher3DLoss will be weighted with `vmfs_factor`. + + Args: + vmfs_factor: A factor applied to the VonMisesFisher3DLoss term. + Defaults ot 0.05. + """ + super().__init__( + loss_functions=[RMSELoss(), VonMisesFisher3DLoss()], + loss_factors=[1, vmfs_factor], + prediction_keys=[[0, 1, 2], [0, 1, 2, 3]], + )
+ diff --git a/api/graphnet.models.graphs.graph_definition.html b/api/graphnet.models.graphs.graph_definition.html index bdcff8cfd..a7c2a4bc3 100644 --- a/api/graphnet.models.graphs.graph_definition.html +++ b/api/graphnet.models.graphs.graph_definition.html @@ -600,14 +600,7 @@
Parameters:
@@ -553,6 +562,8 @@
  • graphs
  • @@ -601,6 +612,35 @@
    +
    +
    +class graphnet.models.graphs.graphs.EdgelessGraph(*args, **kwargs)[source]
    +

    Bases: GraphDefinition

    +

    A Data representation without edge assignment.

    +

    I.e the resulting representation is created without an EdgeDefinition.

    +

    Construct isolated nodes graph representation.

    +
    +
    Parameters:
    +
      +
    • detector (Detector) – Detector that represents your data.

    • +
    • node_definition (Optional[NodeDefinition], default: None) – Definition of nodes in the graph.

    • +
    • input_feature_names (Optional[List[str]], default: None) – Name of input feature columns.

    • +
    • dtype (Optional[dtype], default: torch.float32) – data type for node features.

    • +
    • perturbation_dict (Optional[Dict[str, float]], default: None) – Dictionary mapping a feature name to a standard +deviation according to which the values for this +feature should be randomly perturbed. Defaults +to None.

    • +
    • seed (Union[int, Generator, None], default: None) – seed or Generator used to randomly sample perturbations. +Defaults to None.

    • +
    • args (Any)

    • +
    • kwargs (Any)

    • +
    +
    +
    Return type:
    +

    object

    +
    +
    +
    diff --git a/api/graphnet.models.graphs.html b/api/graphnet.models.graphs.html index c515a008c..1aa55999d 100644 --- a/api/graphnet.models.graphs.html +++ b/api/graphnet.models.graphs.html @@ -572,6 +572,7 @@
  • graphs
  • utils
  • utils
  • @@ -559,6 +563,20 @@ VonMisesFisher3DLoss + +
  • + + + EnsembleLoss + + +
  • +
  • + + + RMSEVonMisesFisher3DLoss + +
  • @@ -649,6 +667,10 @@
  • EuclideanDistanceLoss
  • VonMisesFisher3DLoss +
  • +
  • EnsembleLoss +
  • +
  • RMSEVonMisesFisher3DLoss
  • @@ -992,6 +1014,60 @@ +
    +
    +class graphnet.training.loss_functions.EnsembleLoss(*args, **kwargs)[source]
    +

    Bases: LossFunction

    +

    Chain multiple loss functions together.

    +

    Chain multiple loss functions together.

    +
    +

    Optionally apply a weight to each loss function contribution.

    +

    E.g. Loss = RMSE*0.5 + LogCoshLoss*1.5

    +
    +
    +
    Parameters:
    +
      +
    • loss_functions (List[LossFunction]) – A list of loss functions to use. +Each loss function contributes a term to the overall loss.

    • +
    • loss_factors (Optional[List[float]], default: None) – An optional list of factors that will be mulitplied

    • +
    • according (to each loss function contribution. Must be ordered)

    • +
    • given (to loss_functions. If not)

    • +
    • 1. (the weights default to)

    • +
    • prediction_keys (Optional[List[List[int]]], default: None) – An optional list of lists of indices for which +prediction columns to use for each loss function. If not +given, all columns are used for all loss functions.

    • +
    • args (Any)

    • +
    • kwargs (Any)

    • +
    +
    +
    Return type:
    +

    object

    +
    +
    +
    +
    +
    +class graphnet.training.loss_functions.RMSEVonMisesFisher3DLoss(*args, **kwargs)[source]
    +

    Bases: EnsembleLoss

    +

    Combine the VonMisesFisher3DLoss with RMSELoss.

    +

    VonMisesFisher3DLoss with a RMSE penality term.

    +
    +

    The VonMisesFisher3DLoss will be weighted with vmfs_factor.

    +
    +
    +
    Parameters:
    +
      +
    • vmfs_factor (float, default: 0.05) – A factor applied to the VonMisesFisher3DLoss term.

    • +
    • 0.05. (Defaults ot)

    • +
    • args (Any)

    • +
    • kwargs (Any)

    • +
    +
    +
    Return type:
    +

    object

    +
    +
    +
    diff --git a/genindex.html b/genindex.html index 2994875e1..1a626ee27 100644 --- a/genindex.html +++ b/genindex.html @@ -681,6 +681,8 @@

    E

  • EdgeConvTito (class in graphnet.models.components.layers)
  • EdgeDefinition (class in graphnet.models.graphs.edges.edges) +
  • +
  • EdgelessGraph (class in graphnet.models.graphs.graphs)
  • EnergyReconstruction (class in graphnet.models.task.reconstruction)
  • @@ -692,12 +694,14 @@

    E

  • EnsembleDataset (class in graphnet.data.dataset.dataset)
  • -
  • eps_like() (in module graphnet.utilities.maths) +
  • EnsembleLoss (class in graphnet.training.loss_functions)
  • -
  • ERDAHostedDataset (class in graphnet.data.curated_datamodule) +
  • eps_like() (in module graphnet.utilities.maths)
  • - +