docformatter --all-files
RasmusOrsoe committed Nov 2, 2024
1 parent 8183405 commit 4200a43
Showing 6 changed files with 65 additions and 60 deletions.
9 changes: 5 additions & 4 deletions examples/04_training/06_train_icemix_model.py
@@ -1,7 +1,8 @@
"""Example of training Model.
This example is based on Icemix solution proposed in
- https://github.com/DrHB/icecube-2nd-place.git (2nd place solution).
+ https://github.com/DrHB/icecube-2nd-place.git
+ (2nd place solution).
"""

import os
@@ -78,9 +79,9 @@ def main(
"max_epochs": max_epochs,
"distribution_strategy": "ddp_find_unused_parameters_true",
},
"dataset_reference": SQLiteDataset
if path.endswith(".db")
else ParquetDataset,
"dataset_reference": (
SQLiteDataset if path.endswith(".db") else ParquetDataset
),
}

graph_definition = KNNGraph(
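The only functional content in the hunk above is the choice of dataset backend by file extension; the rest is formatting. A minimal, self-contained sketch of that pattern follows (the empty classes and the `dataset_config` helper are stand-ins for illustration; only the `SQLiteDataset`/`ParquetDataset` names and the `.db` check come from the diff):

```python
# Minimal sketch of the conditional above; the classes are empty stand-ins,
# not GraphNeT's real dataset implementations.
from typing import Any, Dict, Type


class SQLiteDataset:
    """Stand-in for a SQLite-backed dataset."""


class ParquetDataset:
    """Stand-in for a Parquet-backed dataset."""


def dataset_config(path: str) -> Dict[str, Any]:
    """Pick the dataset backend from the file extension."""
    dataset_reference: Type = (
        SQLiteDataset if path.endswith(".db") else ParquetDataset
    )
    return {"path": path, "dataset_reference": dataset_reference}


print(dataset_config("events.db")["dataset_reference"].__name__)       # SQLiteDataset
print(dataset_config("events.parquet")["dataset_reference"].__name__)  # ParquetDataset
```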
25 changes: 14 additions & 11 deletions src/graphnet/data/dataconverter.py
@@ -1,6 +1,6 @@
"""Contains `DataConverter`."""

- from typing import List, Union, OrderedDict, Dict, Tuple, Any, Optional, Type
+ from typing import List, Union, OrderedDict, Dict, Tuple, Any, Optional
from abc import ABC

from tqdm import tqdm
@@ -28,7 +28,8 @@
def init_global_index(index: Synchronized, output_files: List[str]) -> None:
"""Make `global_index` available to pool workers."""
global global_index, global_output_files # type: ignore[name-defined]
- global_index, global_output_files = (index, output_files) # type: ignore[name-defined]
+ global_index = index # type: ignore[name-defined]
+ global_output_files = output_files # type: ignore[name-defined]


class DataConverter(ABC, Logger):
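For context on `init_global_index` above: it is a pool initializer, so the shared counter and output-file list become module-level globals inside every worker process. A runnable sketch of that pattern follows, assuming a hypothetical `reserve_event_nos` task (only `init_global_index` itself is taken from the diff):

```python
# Pool-initializer pattern behind `init_global_index`: a shared counter is
# handed to each worker at start-up. `reserve_event_nos` is illustrative only.
from multiprocessing import Pool, Value
from multiprocessing.sharedctypes import Synchronized
from typing import List


def init_global_index(index: Synchronized, output_files: List[str]) -> None:
    """Make `global_index` available to pool workers."""
    global global_index, global_output_files  # set once per worker process
    global_index = index
    global_output_files = output_files


def reserve_event_nos(n_events: int) -> int:
    """Atomically reserve a block of event numbers; return the first one."""
    with global_index.get_lock():
        start = global_index.value
        global_index.value += n_events
    return start


if __name__ == "__main__":
    index = Value("i", 0)
    with Pool(2, initializer=init_global_index, initargs=(index, [])) as pool:
        print(pool.map(reserve_event_nos, [10, 5, 20]))  # e.g. [0, 10, 15]
```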
@@ -116,10 +117,9 @@ def _launch_jobs(
) -> None:
"""Multi Processing Logic.
- Spawns worker pool,
- distributes the input files evenly across workers.
- declare event_no as globally accessible variable across workers.
- starts jobs.
+ Spawns worker pool, distributes the input files evenly across workers.
+ declare event_no as globally accessible variable across workers. starts
+ jobs.
Will call process_file in parallel.
"""
@@ -138,8 +138,8 @@
def _process_file(self, file_path: Union[str, I3FileSet]) -> None:
"""Process a single file.
- Calls file reader to receive extracted output, event ids
- is assigned to the extracted data and is handed to save method.
+ Calls file reader to receive extracted output, event ids is assigned to
+ the extracted data and is handed to save method.
This function is called in parallel.
"""
@@ -247,7 +247,8 @@ def _count_rows(
n_rows = 1
except ValueError as e:
self.error(
f"Features from {extractor_name} ({extractor_dict.keys()}) have different lengths."
f"Features from {extractor_name} ({extractor_dict.keys()}) "
"have different lengths."
)
raise e
return n_rows
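The reflowed error message above belongs to a consistency check: every list-valued feature returned by an extractor must have the same length. A simplified, self-contained version of that check (not the actual `_count_rows` body):

```python
# Simplified sketch of the length check behind the error message above.
from typing import Any, Dict


def count_rows(extractor_name: str, extractor_dict: Dict[str, Any]) -> int:
    """Return the number of rows one extractor produced for a single event."""
    lengths = {
        len(value)
        for value in extractor_dict.values()
        if isinstance(value, (list, tuple))
    }
    if len(lengths) > 1:
        raise ValueError(
            f"Features from {extractor_name} ({extractor_dict.keys()}) "
            "have different lengths."
        )
    return lengths.pop() if lengths else 1  # scalars count as a single row


print(count_rows("pulses", {"dom_x": [1.0, 2.0], "dom_time": [10.2, 11.5]}))  # 2
print(count_rows("truth", {"energy": 10.3, "zenith": 1.2}))                   # 1
```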
@@ -276,7 +277,8 @@ def get_map_function(
n_workers = min(self._num_workers, nb_files)
if n_workers > 1:
self.info(
f"Starting pool of {n_workers} workers to process {nb_files} {unit}"
f"Starting pool of {n_workers} workers to process"
" {nb_files} {unit}"
)

manager = Manager()
@@ -292,7 +294,8 @@

else:
self.info(
f"Processing {nb_files} {unit} in main thread (not multiprocessing)"
f"Processing {nb_files} {unit} in main thread"
"(not multiprocessing)"
)
map_fn = map # type: ignore
pool = None
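Both messages above sit in the branch that decides between a worker pool and the plain built-in `map`. A condensed stand-in for that branch (simplified signature, not the real `get_map_function`):

```python
# Simplified stand-in for the pool-vs-main-thread branch above.
from multiprocessing.pool import Pool
from typing import Callable, Optional, Tuple


def get_map_function(
    num_workers: int, nb_files: int
) -> Tuple[Callable, int, Optional[Pool]]:
    """Return a map-like callable, the worker count and the pool (if any)."""
    n_workers = min(num_workers, nb_files)
    if n_workers > 1:
        # Parallel case: results are consumed lazily via `imap`.
        pool: Optional[Pool] = Pool(processes=n_workers)
        map_fn: Callable = pool.imap
    else:
        # Serial fallback: plain built-in map in the main thread.
        pool = None
        map_fn = map
    return map_fn, n_workers, pool


map_fn, n_workers, pool = get_map_function(num_workers=1, nb_files=3)
print(list(map_fn(str.upper, ["a", "b", "c"])), n_workers, pool)  # ['A', 'B', 'C'] 1 None
```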
18 changes: 9 additions & 9 deletions src/graphnet/data/extractors/icecube/i3genericextractor.py
@@ -201,9 +201,9 @@ def _extract_per_pulse_attribute(
) -> Optional[Dict[str, Any]]:
"""Extract per-pulse attribute `key` from `frame`.
- A per-pulse attribute (e.g., dataclasses.I3MapKeyUInt) is a
- dictionary- like mapping from an OM key to some attribute, e.g.,
- an integer or a vector properties.
+ A per-pulse attribute (e.g., dataclasses.I3MapKeyUInt) is a dictionary-
+ like mapping from an OM key to some attribute, e.g., an integer or a
+ vector properties.
"""
result = self._extract_pulse_series_map(frame, key)
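For readers without icetray installed, the docstring above describes a mapping from optical-module (OM) keys to values. A plain-Python picture of such a per-pulse attribute and its flattened form (the tuple key and column names are illustrative assumptions, not icetray types):

```python
# Plain-Python stand-in for a per-pulse attribute: an OM-keyed mapping
# flattened into per-column lists. No icetray types are used here.
from typing import Dict, List, Tuple

OMKey = Tuple[int, int]  # (string number, OM number), simplified stand-in

charge_map: Dict[OMKey, int] = {
    (21, 7): 3,
    (21, 8): 5,
}

flattened: Dict[str, List[int]] = {
    "string": [om[0] for om in charge_map],
    "om": [om[1] for om in charge_map],
    "value": list(charge_map.values()),
}
print(flattened)  # {'string': [21, 21], 'om': [7, 8], 'value': [3, 5]}
```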

@@ -264,12 +264,12 @@ def _flatten_result_mctree(
flatten_nested_dictionary(res) for res in result_particles
]

- result_primaries_transposed: Dict[
- str, List[Any]
- ] = transpose_list_of_dicts(result_primaries)
- result_particles_transposed: Dict[
- str, List[Any]
- ] = transpose_list_of_dicts(result_particles)
+ result_primaries_transposed: Dict[str, List[Any]] = (
+ transpose_list_of_dicts(result_primaries)
+ )
+ result_particles_transposed: Dict[str, List[Any]] = (
+ transpose_list_of_dicts(result_particles)
+ )

# Remove `majorID`, which has unsupported uint64 dtype.
# Keep only one instance of `minorID`.
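Both reformatted assignments above call `transpose_list_of_dicts`; a minimal illustration of what that helper is assumed to do (the implementation below is a sketch, not GraphNeT's):

```python
# Illustrative transpose step: a list of per-particle dicts becomes one
# dict of column lists, matching the annotations in the hunk above.
from typing import Any, Dict, List


def transpose_list_of_dicts(rows: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Turn [{k: v1}, {k: v2}, ...] into {k: [v1, v2, ...]}."""
    keys = rows[0].keys() if rows else []
    return {key: [row[key] for row in rows] for key in keys}


result_particles = [
    {"pdg_encoding": 14, "energy": 120.0},
    {"pdg_encoding": -13, "energy": 87.5},
]
print(transpose_list_of_dicts(result_particles))
# {'pdg_encoding': [14, -13], 'energy': [120.0, 87.5]}
```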
17 changes: 11 additions & 6 deletions src/graphnet/data/extractors/icecube/utilities/i3_filters.py
@@ -1,4 +1,5 @@
"""Filter classes for filtering I3-frames when converting I3-files."""

from abc import abstractmethod
from graphnet.utilities.logging import Logger
from typing import List
@@ -64,7 +65,7 @@ def _keep_frame(self, frame: "icetray.I3Frame") -> bool:


class I3FilterMask(I3Filter):
"""checks list of filters from the FilterMask in I3 frames."""
"""Checks list of filters from the FilterMask in I3 frames."""

def __init__(self, filter_names: List[str], filter_any: bool = True):
"""Initialize I3FilterMask.
@@ -95,7 +96,8 @@ def _keep_frame(self, frame: "icetray.I3Frame") -> bool:
for filter_name in self._filter_names:
if filter_name not in frame["FilterMask"]:
self.warning_once(
f"FilterMask {filter_name} not found in frame. skipping filter."
f"FilterMask {filter_name} not found in frame. "
"Skipping filter."
)
continue
elif frame["FilterMask"][filter].condition_passed is True:
@@ -104,25 +106,28 @@ def _keep_frame(self, frame: "icetray.I3Frame") -> bool:
bool_list.append(False)
if len(bool_list) == 0:
self.warning_once(
"None of the FilterMask filters found in frame, FilterMask filters will not be applied."
"None of the FilterMask filters found in frame."
"FilterMask filters will not be applied."
)
return any(bool_list) or len(bool_list) == 0
else: # Require all filters to pass in order to keep the frame.
for filter_name in self._filter_names:
if filter_name not in frame["FilterMask"]:
self.warning_once(
f"FilterMask {filter_name} not found in frame, skipping filter."
f"FilterMask {filter_name} not found in frame."
"Skipping filter."
)
continue
elif frame["FilterMask"][filter].condition_passed is True:
- continue # current filter passed, continue to next filter
+ continue # current filter passed, go to next filter
else:
return (
False # current filter failed so frame is skipped.
)
return True
else:
self.warning_once(
"FilterMask not found in frame, FilterMask filters will not be applied."
"FilterMask not found in frame."
"FilterMask filters will not be applied."
)
return True
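The two branches above implement "keep the frame if any requested filter passed" versus "only if all passed", with a fallback to keeping the frame when none of the requested filters are present. A compact sketch of that decision, with a plain dict standing in for the I3 `FilterMask` object and illustrative filter names:

```python
# Plain-dict stand-in for the FilterMask keep/skip decision above.
from typing import Dict, List


def keep_frame(
    filter_mask: Dict[str, bool], filter_names: List[str], filter_any: bool
) -> bool:
    """Return True if the frame should be kept under the chosen policy."""
    results = [
        filter_mask[name] for name in filter_names if name in filter_mask
    ]
    if not results:
        # None of the requested filters were found: do not filter the frame.
        return True
    return any(results) if filter_any else all(results)


mask = {"MuonFilter_13": True, "CascadeFilter_13": False}
print(keep_frame(mask, ["MuonFilter_13", "CascadeFilter_13"], filter_any=True))   # True
print(keep_frame(mask, ["MuonFilter_13", "CascadeFilter_13"], filter_any=False))  # False
```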
14 changes: 5 additions & 9 deletions src/graphnet/models/graphs/edges/edges.py
@@ -34,7 +34,9 @@ def forward(self, graph: Data) -> Data:

@abstractmethod
def _construct_edges(self, graph: Data) -> Data:
"""Construct edges and assign them to graph. I.e. ´graph.edge_index = edge_index´.
"""Construct edges and assign them to the graph.
I.e. ´graph.edge_index = edge_index´.
Args:
graph: graph without edges
@@ -127,16 +129,12 @@ def __init__(
self,
sigma: float,
threshold: float = 0.0,
- columns: List[int] = None,
+ columns: List[int] = [0, 1, 2],
):
"""Construct `EuclideanEdges`."""
# Base class constructor
super().__init__(name=__name__, class_name=self.__class__.__name__)

- # Check(s)
- if columns is None:
- columns = [0, 1, 2]

# Member variable(s)
self._sigma = sigma
self._threshold = threshold
@@ -161,9 +159,7 @@ def _construct_edges(self, graph: Data) -> Data:
)

distance_matrix = calculate_distance_matrix(xyz_coords)
- affinity_matrix = torch.exp(
- -0.5 * distance_matrix**2 / self._sigma**2
- )
+ affinity_matrix = torch.exp(-0.5 * distance_matrix**2 / self._sigma**2)

# Use softmax to normalise all adjacencies to one for each node
exp_row_sums = torch.exp(affinity_matrix).sum(axis=1)
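The collapsed `affinity_matrix` line above is a Gaussian (RBF) kernel over pairwise node distances, and the `exp_row_sums` line that follows normalises each row softmax-style. A small runnable sketch of those two steps, with `torch.cdist` standing in for `calculate_distance_matrix`:

```python
# Gaussian affinity from pairwise distances, then softmax-style row
# normalisation; torch.cdist stands in for `calculate_distance_matrix`.
import torch

xyz_coords = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
sigma = 1.0

distance_matrix = torch.cdist(xyz_coords, xyz_coords)  # (N, N) Euclidean
affinity_matrix = torch.exp(-0.5 * distance_matrix**2 / sigma**2)

exp_row_sums = torch.exp(affinity_matrix).sum(axis=1)
weighted_adj = torch.exp(affinity_matrix) / exp_row_sums.unsqueeze(1)
print(weighted_adj.sum(axis=1))  # each row sums to 1
```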
42 changes: 21 additions & 21 deletions src/graphnet/training/loss_functions.py
@@ -14,7 +14,6 @@
from torch import nn
from torch.nn.functional import (
one_hot,
- cross_entropy,
binary_cross_entropy,
softplus,
)
@@ -101,7 +100,7 @@ def _log_cosh(cls, x: Tensor) -> Tensor: # pylint: disable=invalid-name
"""Numerically stable version on log(cosh(x)).
Used to avoid `inf` for even moderately large differences.
- See [https://github.com/keras-team/keras/blob/v2.6.0/keras/losses.py#L1580-L1617]
+ See [https://github.com/keras-team/keras/blob/v2.6.0/keras/losses.py#L1580-L1617] # noqa: E501
"""
return x + softplus(-2.0 * x) - np.log(2.0)
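The returned expression uses the identity log(cosh(x)) = x + softplus(-2x) - log(2), which stays finite where cosh itself overflows. A quick numerical check:

```python
# Numerical check of log(cosh(x)) = x + softplus(-2x) - log(2): the stable
# form stays finite where the naive form overflows in float32.
import numpy as np
import torch
from torch.nn.functional import softplus

x = torch.tensor([0.5, 10.0, 100.0])
stable = x + softplus(-2.0 * x) - np.log(2.0)
naive = torch.log(torch.cosh(x))
print(stable)  # approx. [0.1201, 9.3069, 99.3069]
print(naive)   # approx. [0.1201, 9.3069, inf]; cosh(100) overflows
```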

@@ -213,30 +212,31 @@ class LogCMK(torch.autograd.Function):
Copyright (c) 2019 Max Ryabinin
- Permission is hereby granted, free of charge, to any person obtaining a copy
- of this software and associated documentation files (the "Software"), to deal
- in the Software without restriction, including without limitation the rights
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- copies of the Software, and to permit persons to whom the Software is
- furnished to do so, subject to the following conditions:
+ Permission is hereby granted, free of charge, to any person obtaining a
+ copy of this software and associated documentation files (the "Software"),
+ to deal in the Software without restriction, including without limitation
+ the rights to use, copy, modify, merge, publish, distribute, sublicense,
+ and/or sell copies of the Software, and to permit persons to whom the
+ Software is furnished to do so, subject to the following conditions:
- The above copyright notice and this permission notice shall be included in all
- copies or substantial portions of the Software.
+ The above copyright notice and this permission notice shall be included in
+ all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- SOFTWARE.
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ DEALINGS IN THE SOFTWARE.
_____________________
- From [https://github.com/mryab/vmf_loss/blob/master/losses.py]
- Modified to use modified Bessel function instead of exponentially scaled ditto
- (i.e. `.ive` -> `.iv`) as indiciated in [1812.04616] in spite of suggestion in
- Sec. 8.2 of this paper. The change has been validated through comparison with
- exact calculations for `m=2` and `m=3` and found to yield the correct results.
+ From [https://github.com/mryab/vmf_loss/blob/master/losses.py] Modified to
+ use modified Bessel function instead of exponentially scaled ditto
+ (i.e. `.ive` -> `.iv`) as indicated in [1812.04616] in spite of suggestion
+ in Sec. 8.2 of this paper. The change has been validated through comparison
+ with exact calculations for `m=2` and `m=3` and found to yield the correct
+ results.
"""

@staticmethod
@@ -358,7 +358,7 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:


class VonMisesFisher2DLoss(VonMisesFisherLoss):
"""von Mises-Fisher loss function vectors in the 2D plane."""
"""Von Mises-Fisher loss function vectors in the 2D plane."""

def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:
"""Calculate von Mises-Fisher loss for an angle in the 2D plane.
@@ -422,7 +422,7 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:


class VonMisesFisher3DLoss(VonMisesFisherLoss):
"""von Mises-Fisher loss function vectors in the 3D plane."""
"""Von Mises-Fisher loss function vectors in the 3D plane."""

def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:
"""Calculate von Mises-Fisher loss for a direction in the 3D.
@@ -453,7 +453,7 @@ class EnsembleLoss(LossFunction):
def __init__(
self,
loss_functions: List[LossFunction],
- loss_factors: List[float] = None,
+ loss_factors: Optional[List[float]] = None,
prediction_keys: Optional[List[List[int]]] = None,
*args: Any,
**kwargs: Any,
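The `Optional[List[float]]` annotation above belongs to `EnsembleLoss`, which combines several loss functions with optional per-term weights. A hedged sketch of that idea as a free function (equal weights when `loss_factors` is None; this is not the actual GraphNeT implementation):

```python
# Illustrative weighted sum of several loss terms; not GraphNeT's EnsembleLoss.
from typing import Callable, List, Optional

import torch
from torch import Tensor
from torch.nn.functional import l1_loss, mse_loss


def ensemble_loss(
    loss_functions: List[Callable[[Tensor, Tensor], Tensor]],
    prediction: Tensor,
    target: Tensor,
    loss_factors: Optional[List[float]] = None,
) -> Tensor:
    """Weighted sum of loss terms; equal weights if `loss_factors` is None."""
    if loss_factors is None:
        loss_factors = [1.0] * len(loss_functions)
    terms = [
        factor * fn(prediction, target)
        for factor, fn in zip(loss_factors, loss_functions)
    ]
    return torch.stack(terms).sum()


prediction = torch.randn(8, 1)
target = torch.randn(8, 1)
print(ensemble_loss([mse_loss, l1_loss], prediction, target, loss_factors=[0.5, 1.0]))
```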
