diff --git a/examples/04_training/06_train_icemix_model.py b/examples/04_training/06_train_icemix_model.py index 6f8b58228..d7b298d76 100644 --- a/examples/04_training/06_train_icemix_model.py +++ b/examples/04_training/06_train_icemix_model.py @@ -1,7 +1,8 @@ """Example of training Model. This example is based on Icemix solution proposed in -https://github.com/DrHB/icecube-2nd-place.git (2nd place solution). +https://github.com/DrHB/icecube-2nd-place.git +(2nd place solution). """ import os @@ -78,9 +79,9 @@ def main( "max_epochs": max_epochs, "distribution_strategy": "ddp_find_unused_parameters_true", }, - "dataset_reference": SQLiteDataset - if path.endswith(".db") - else ParquetDataset, + "dataset_reference": ( + SQLiteDataset if path.endswith(".db") else ParquetDataset + ), } graph_definition = KNNGraph( diff --git a/src/graphnet/data/dataconverter.py b/src/graphnet/data/dataconverter.py index 6bc9e9572..cd04847f4 100644 --- a/src/graphnet/data/dataconverter.py +++ b/src/graphnet/data/dataconverter.py @@ -1,6 +1,6 @@ """Contains `DataConverter`.""" -from typing import List, Union, OrderedDict, Dict, Tuple, Any, Optional, Type +from typing import List, Union, OrderedDict, Dict, Tuple, Any, Optional from abc import ABC from tqdm import tqdm @@ -28,7 +28,8 @@ def init_global_index(index: Synchronized, output_files: List[str]) -> None: """Make `global_index` available to pool workers.""" global global_index, global_output_files # type: ignore[name-defined] - global_index, global_output_files = (index, output_files) # type: ignore[name-defined] + global_index = index # type: ignore[name-defined] + global_output_files = output_files # type: ignore[name-defined] class DataConverter(ABC, Logger): @@ -116,10 +117,9 @@ def _launch_jobs( ) -> None: """Multi Processing Logic. - Spawns worker pool, - distributes the input files evenly across workers. - declare event_no as globally accessible variable across workers. - starts jobs. 
+ Spawns worker pool, distributes the input files evenly across workers. + declare event_no as globally accessible variable across workers. starts + jobs. Will call process_file in parallel. """ @@ -138,8 +138,8 @@ def _launch_jobs( def _process_file(self, file_path: Union[str, I3FileSet]) -> None: """Process a single file. - Calls file reader to recieve extracted output, event ids - is assigned to the extracted data and is handed to save method. + Calls file reader to receive extracted output, event ids is assigned to + the extracted data and is handed to save method. This function is called in parallel. """ @@ -247,7 +247,8 @@ def _count_rows( n_rows = 1 except ValueError as e: self.error( - f"Features from {extractor_name} ({extractor_dict.keys()}) have different lengths." + f"Features from {extractor_name} ({extractor_dict.keys()}) " + "have different lengths." ) raise e return n_rows @@ -276,7 +277,8 @@ def get_map_function( n_workers = min(self._num_workers, nb_files) if n_workers > 1: self.info( - f"Starting pool of {n_workers} workers to process {nb_files} {unit}" + f"Starting pool of {n_workers} workers to process" + f" {nb_files} {unit}" ) manager = Manager() @@ -292,7 +294,8 @@ def get_map_function( else: self.info( - f"Processing {nb_files} {unit} in main thread (not multiprocessing)" + f"Processing {nb_files} {unit} in main thread" + " (not multiprocessing)" ) map_fn = map # type: ignore pool = None diff --git a/src/graphnet/data/extractors/icecube/i3genericextractor.py b/src/graphnet/data/extractors/icecube/i3genericextractor.py index c79b7329b..f41add7d0 100644 --- a/src/graphnet/data/extractors/icecube/i3genericextractor.py +++ b/src/graphnet/data/extractors/icecube/i3genericextractor.py @@ -201,9 +201,9 @@ def _extract_per_pulse_attribute( ) -> Optional[Dict[str, Any]]: """Extract per-pulse attribute `key` from `frame`.
- A per-pulse attribute (e.g., dataclasses.I3MapKeyUInt) is a - dictionary- like mapping from an OM key to some attribute, e.g., - an integer or a vector properties. + A per-pulse attribute (e.g., dataclasses.I3MapKeyUInt) is a dictionary- + like mapping from an OM key to some attribute, e.g., an integer or a + vector properties. """ result = self._extract_pulse_series_map(frame, key) @@ -264,12 +264,12 @@ def _flatten_result_mctree( flatten_nested_dictionary(res) for res in result_particles ] - result_primaries_transposed: Dict[ - str, List[Any] - ] = transpose_list_of_dicts(result_primaries) - result_particles_transposed: Dict[ - str, List[Any] - ] = transpose_list_of_dicts(result_particles) + result_primaries_transposed: Dict[str, List[Any]] = ( + transpose_list_of_dicts(result_primaries) + ) + result_particles_transposed: Dict[str, List[Any]] = ( + transpose_list_of_dicts(result_particles) + ) # Remove `majorID`, which has unsupported unit64 dtype. # Keep only one instances of `minorID`. diff --git a/src/graphnet/data/extractors/icecube/utilities/i3_filters.py b/src/graphnet/data/extractors/icecube/utilities/i3_filters.py index ca83f4217..db115bd21 100644 --- a/src/graphnet/data/extractors/icecube/utilities/i3_filters.py +++ b/src/graphnet/data/extractors/icecube/utilities/i3_filters.py @@ -1,4 +1,5 @@ """Filter classes for filtering I3-frames when converting I3-files.""" + from abc import abstractmethod from graphnet.utilities.logging import Logger from typing import List @@ -64,7 +65,7 @@ def _keep_frame(self, frame: "icetray.I3Frame") -> bool: class I3FilterMask(I3Filter): - """checks list of filters from the FilterMask in I3 frames.""" + """Checks list of filters from the FilterMask in I3 frames.""" def __init__(self, filter_names: List[str], filter_any: bool = True): """Initialize I3FilterMask. 
@@ -95,7 +96,8 @@ def _keep_frame(self, frame: "icetray.I3Frame") -> bool: for filter_name in self._filter_names: if filter_name not in frame["FilterMask"]: self.warning_once( - f"FilterMask {filter_name} not found in frame. skipping filter." + f"FilterMask {filter_name} not found in frame. " + "Skipping filter." ) continue elif frame["FilterMask"][filter].condition_passed is True: @@ -104,18 +106,20 @@ def _keep_frame(self, frame: "icetray.I3Frame") -> bool: bool_list.append(False) if len(bool_list) == 0: self.warning_once( - "None of the FilterMask filters found in frame, FilterMask filters will not be applied." + "None of the FilterMask filters found in frame. " + "FilterMask filters will not be applied." ) return any(bool_list) or len(bool_list) == 0 else: # Require all filters to pass in order to keep the frame. for filter_name in self._filter_names: if filter_name not in frame["FilterMask"]: self.warning_once( - f"FilterMask {filter_name} not found in frame, skipping filter." + f"FilterMask {filter_name} not found in frame. " + "Skipping filter." ) continue elif frame["FilterMask"][filter].condition_passed is True: - continue # current filter passed, continue to next filter + continue # current filter passed, go to next filter else: return ( False # current filter failed so frame is skipped. @@ -123,6 +127,7 @@ def _keep_frame(self, frame: "icetray.I3Frame") -> bool: return True else: self.warning_once( - "FilterMask not found in frame, FilterMask filters will not be applied." + "FilterMask not found in frame. " + "FilterMask filters will not be applied." ) return True diff --git a/src/graphnet/models/graphs/edges/edges.py b/src/graphnet/models/graphs/edges/edges.py index 584cf7ad9..d953ae0a0 100644 --- a/src/graphnet/models/graphs/edges/edges.py +++ b/src/graphnet/models/graphs/edges/edges.py @@ -34,7 +34,9 @@ def forward(self, graph: Data) -> Data: @abstractmethod def _construct_edges(self, graph: Data) -> Data: - """Construct edges and assign them to graph.
I.e. ´graph.edge_index = edge_index´. + """Construct edges and assign them to the graph. + + I.e. ´graph.edge_index = edge_index´. Args: graph: graph without edges @@ -127,16 +129,12 @@ def __init__( self, sigma: float, threshold: float = 0.0, - columns: List[int] = None, + columns: List[int] = [0, 1, 2], ): """Construct `EuclideanEdges`.""" # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) - # Check(s) - if columns is None: - columns = [0, 1, 2] - # Member variable(s) self._sigma = sigma self._threshold = threshold @@ -161,9 +159,7 @@ def _construct_edges(self, graph: Data) -> Data: ) distance_matrix = calculate_distance_matrix(xyz_coords) - affinity_matrix = torch.exp( - -0.5 * distance_matrix**2 / self._sigma**2 - ) + affinity_matrix = torch.exp(-0.5 * distance_matrix**2 / self._sigma**2) # Use softmax to normalise all adjacencies to one for each node exp_row_sums = torch.exp(affinity_matrix).sum(axis=1) diff --git a/src/graphnet/training/loss_functions.py b/src/graphnet/training/loss_functions.py index d3fc43f7e..6d7c27c06 100644 --- a/src/graphnet/training/loss_functions.py +++ b/src/graphnet/training/loss_functions.py @@ -14,7 +14,6 @@ from torch import nn from torch.nn.functional import ( one_hot, - cross_entropy, binary_cross_entropy, softplus, ) @@ -101,7 +100,7 @@ def _log_cosh(cls, x: Tensor) -> Tensor: # pylint: disable=invalid-name """Numerically stable version on log(cosh(x)). Used to avoid `inf` for even moderately large differences. 
- See [https://github.com/keras-team/keras/blob/v2.6.0/keras/losses.py#L1580-L1617] + See [https://github.com/keras-team/keras/blob/v2.6.0/keras/losses.py#L1580-L1617] # noqa: E501 """ return x + softplus(-2.0 * x) - np.log(2.0) @@ -213,30 +212,31 @@ class LogCMK(torch.autograd.Function): Copyright (c) 2019 Max Ryabinin - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: + Permission is hereby granted, free of charge, to any person obtaining a + copy of this software and associated documentation files (the "Software"), + to deal in the Software without restriction, including without limitation + the rights to use, copy, modify, merge, publish, distribute, sublicense, + and/or sell copies of the Software, and to permit persons to whom the + Software is furnished to do so, subject to the following conditions: - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE. 
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + DEALINGS IN THE SOFTWARE. _____________________ - From [https://github.com/mryab/vmf_loss/blob/master/losses.py] - Modified to use modified Bessel function instead of exponentially scaled ditto - (i.e. `.ive` -> `.iv`) as indiciated in [1812.04616] in spite of suggestion in - Sec. 8.2 of this paper. The change has been validated through comparison with - exact calculations for `m=2` and `m=3` and found to yield the correct results. + From [https://github.com/mryab/vmf_loss/blob/master/losses.py] Modified to + use modified Bessel function instead of exponentially scaled ditto + (i.e. `.ive` -> `.iv`) as indicated in [1812.04616] in spite of suggestion + in Sec. 8.2 of this paper. The change has been validated through comparison + with exact calculations for `m=2` and `m=3` and found to yield the correct + results. """ @staticmethod @@ -358,7 +358,7 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: class VonMisesFisher2DLoss(VonMisesFisherLoss): - """von Mises-Fisher loss function vectors in the 2D plane.""" + """Von Mises-Fisher loss function vectors in the 2D plane.""" def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: """Calculate von Mises-Fisher loss for an angle in the 2D plane. @@ -422,7 +422,7 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: class VonMisesFisher3DLoss(VonMisesFisherLoss): - """von Mises-Fisher loss function vectors in the 3D plane.""" + """Von Mises-Fisher loss function vectors in the 3D plane.""" def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: """Calculate von Mises-Fisher loss for a direction in the 3D. 
@@ -453,7 +453,7 @@ class EnsembleLoss(LossFunction): def __init__( self, loss_functions: List[LossFunction], - loss_factors: List[float] = None, + loss_factors: Optional[List[float]] = None, prediction_keys: Optional[List[List[int]]] = None, *args: Any, **kwargs: Any,