diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fd6bae19e..59d6f2cd0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,28 +1,28 @@ exclude: '^(versioneer.py|src/graphnet/_version.py|docs/)' repos: - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 24.10.0 hooks: - id: black language_version: python3 args: [--config=black.toml] - repo: https://github.com/pycqa/flake8 - rev: 4.0.1 + rev: 7.1.1 hooks: - id: flake8 language_version: python3 - repo: https://github.com/pycqa/docformatter - rev: v1.5.0 + rev: v1.7.5 hooks: - id: docformatter language_version: python3 - repo: https://github.com/pycqa/pydocstyle - rev: 6.1.1 + rev: 6.3.0 hooks: - id: pydocstyle language_version: python3 - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.982 + rev: v1.13.0 hooks: - id: mypy args: [--follow-imports=silent, --disallow-untyped-defs, --disallow-incomplete-defs, --disallow-untyped-calls] diff --git a/docker/gnn-benchmarking/apply.py b/docker/gnn-benchmarking/apply.py index 87744b94a..909025bac 100644 --- a/docker/gnn-benchmarking/apply.py +++ b/docker/gnn-benchmarking/apply.py @@ -1,6 +1,5 @@ """Script for applying GraphNeTModule in IceTray chain.""" - import argparse from glob import glob from os import makedirs @@ -37,9 +36,7 @@ def main( # Get GCD file gcd_pattern = "GeoCalibDetector" gcd_candidates = [p for p in input_files if gcd_pattern in p] - assert ( - len(gcd_candidates) == 1 - ), f"Did not get exactly one GCD-file candidate in `{dirname(input_files[0])}: {gcd_candidates}" + assert len(gcd_candidates) == 1, "Did not get exactly one GCD-file candidate" gcd_file = gcd_candidates[0] # Get all input I3-files @@ -78,8 +75,10 @@ def main( """The main function must get an input folder and output folder! Args: - input_folder (str): The input folder where i3 files of a given dataset are located. - output_folder (str): The output folder where processed i3 files will be saved. + input_folder (str): The input folder where i3 files of a + given dataset are located. + output_folder (str): The output folder where processed i3 + files will be saved.
""" parser = argparse.ArgumentParser() diff --git a/examples/01_icetray/03_i3_deployer_example.py b/examples/01_icetray/03_i3_deployer_example.py index 28d73c00d..bd2de9f43 100644 --- a/examples/01_icetray/03_i3_deployer_example.py +++ b/examples/01_icetray/03_i3_deployer_example.py @@ -50,7 +50,7 @@ def main() -> None: model_config = f"{base_path}/{model_name}/{model_name}_config.yml" state_dict = f"{base_path}/{model_name}/{model_name}_state_dict.pth" output_folder = f"{EXAMPLE_OUTPUT_DIR}/i3_deployment/upgrade_03_04" - gcd_file = f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/GeoCalibDetectorStatus_ICUpgrade.v58.mixed.V0.i3.bz2" + gcd_file = f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/GeoCalibDetectorStatus_ICUpgrade.v58.mixed.V0.i3.bz2" # noqa: E501 input_files = [] for folder in input_folders: input_files.extend(glob(join(folder, "*.i3.gz"))) diff --git a/examples/01_icetray/04_i3_module_in_native_icetray_example.py b/examples/01_icetray/04_i3_module_in_native_icetray_example.py index 09e9b358e..c4f797b02 100644 --- a/examples/01_icetray/04_i3_module_in_native_icetray_example.py +++ b/examples/01_icetray/04_i3_module_in_native_icetray_example.py @@ -86,7 +86,7 @@ def main() -> None: model_config = f"{base_path}/{model_name}/{model_name}_config.yml" state_dict = f"{base_path}/{model_name}/{model_name}_state_dict.pth" output_folder = f"{EXAMPLE_OUTPUT_DIR}/i3_deployment/upgrade" - gcd_file = f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/GeoCalibDetectorStatus_ICUpgrade.v58.mixed.V0.i3.bz2" + gcd_file = f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/GeoCalibDetectorStatus_ICUpgrade.v58.mixed.V0.i3.bz2" # noqa:E501 features = FEATURES.UPGRADE input_files = [] for folder in input_folders: diff --git a/examples/04_training/01_train_dynedge.py b/examples/04_training/01_train_dynedge.py index c61df789b..6ee6e0223 100644 --- a/examples/04_training/01_train_dynedge.py +++ b/examples/04_training/01_train_dynedge.py @@ -70,9 +70,9 @@ def main( "gpus": gpus, "max_epochs": max_epochs, }, - "dataset_reference": SQLiteDataset - if path.endswith(".db") - else ParquetDataset, + "dataset_reference": ( + SQLiteDataset if path.endswith(".db") else ParquetDataset + ), } archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_model_without_configs") diff --git a/examples/04_training/02_train_tito_model.py b/examples/04_training/02_train_tito_model.py index 50c5f7fd6..991ed7744 100644 --- a/examples/04_training/02_train_tito_model.py +++ b/examples/04_training/02_train_tito_model.py @@ -72,9 +72,9 @@ def main( "gpus": gpus, "max_epochs": max_epochs, }, - "dataset_reference": SQLiteDataset - if path.endswith(".db") - else ParquetDataset, + "dataset_reference": ( + SQLiteDataset if path.endswith(".db") else ParquetDataset + ), } graph_definition = KNNGraph(detector=Prometheus()) diff --git a/examples/04_training/05_train_RNN_TITO.py b/examples/04_training/05_train_RNN_TITO.py index 8e521bd1d..6f75d0364 100644 --- a/examples/04_training/05_train_RNN_TITO.py +++ b/examples/04_training/05_train_RNN_TITO.py @@ -76,9 +76,9 @@ def main( "gpus": gpus, "max_epochs": max_epochs, }, - "dataset_reference": SQLiteDataset - if path.endswith(".db") - else ParquetDataset, + "dataset_reference": ( + SQLiteDataset if path.endswith(".db") else ParquetDataset + ), } graph_definition = KNNGraph( diff --git a/examples/04_training/06_train_icemix_model.py b/examples/04_training/06_train_icemix_model.py index 6f8b58228..d7b298d76 100644 --- a/examples/04_training/06_train_icemix_model.py +++ 
b/examples/04_training/06_train_icemix_model.py @@ -1,7 +1,8 @@ """Example of training Model. This example is based on Icemix solution proposed in -https://github.com/DrHB/icecube-2nd-place.git (2nd place solution). +https://github.com/DrHB/icecube-2nd-place.git +(2nd place solution). """ import os @@ -78,9 +79,9 @@ def main( "max_epochs": max_epochs, "distribution_strategy": "ddp_find_unused_parameters_true", }, - "dataset_reference": SQLiteDataset - if path.endswith(".db") - else ParquetDataset, + "dataset_reference": ( + SQLiteDataset if path.endswith(".db") else ParquetDataset + ), } graph_definition = KNNGraph( diff --git a/examples/04_training/07_train_normalizing_flow.py b/examples/04_training/07_train_normalizing_flow.py index baa3eec85..4048e9923 100644 --- a/examples/04_training/07_train_normalizing_flow.py +++ b/examples/04_training/07_train_normalizing_flow.py @@ -4,7 +4,6 @@ from typing import Any, Dict, List, Optional from pytorch_lightning.loggers import WandbLogger -import torch from torch.optim.adam import Adam from graphnet.constants import EXAMPLE_DATA_DIR, EXAMPLE_OUTPUT_DIR diff --git a/examples/05_liquido/01_convert_h5.py b/examples/05_liquido/01_convert_h5.py index ed1569c51..eaa5610fd 100644 --- a/examples/05_liquido/01_convert_h5.py +++ b/examples/05_liquido/01_convert_h5.py @@ -1,4 +1,5 @@ """Example of converting H5 files from LiquidO to SQLite and Parquet.""" + import os from graphnet.constants import EXAMPLE_OUTPUT_DIR, TEST_DATA_DIR diff --git a/examples/06_prometheus/01_convert_prometheus.py b/examples/06_prometheus/01_convert_prometheus.py index 1a6c818df..63b071372 100644 --- a/examples/06_prometheus/01_convert_prometheus.py +++ b/examples/06_prometheus/01_convert_prometheus.py @@ -1,4 +1,5 @@ """Example of converting files from Prometheus to SQLite and Parquet.""" + import os from graphnet.constants import EXAMPLE_OUTPUT_DIR, TEST_DATA_DIR diff --git a/setup.cfg b/setup.cfg index 8873542f4..476acd91c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,6 +17,11 @@ omit = [flake8] exclude = versioneer.py +# Ignore unused imports in __init__ files +per-file-ignores= + __init__.py:F401 + src/graphnet/utilities/imports.py:F401 +ignore=E203,W503 [docformatter] wrap-summaries = 79 diff --git a/setup.py b/setup.py index fa2b71ad2..a0293e7b8 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -# type: ignore[no-untyped-call] +# mypy: disable-error-code="no-untyped-call" """Setup script for the GraphNeT package.""" from setuptools import setup, find_packages @@ -39,7 +39,7 @@ "MarkupSafe<=2.1", "mypy", "myst-parser", - "pre-commit", + "pre-commit<4.0", "pydocstyle", "pylint", "pytest", diff --git a/src/graphnet/__init__.py b/src/graphnet/__init__.py index 34f46afc3..28e4f21f6 100644 --- a/src/graphnet/__init__.py +++ b/src/graphnet/__init__.py @@ -32,4 +32,6 @@ from . import _version -__version__ = _version.get_versions()["version"] # type: ignore[no-untyped-call] +__version__ = _version.get_versions()[ # type: ignore[no-untyped-call] + "version" +] diff --git a/src/graphnet/data/__init__.py b/src/graphnet/data/__init__.py index 512d414d9..bd95b4a2a 100644 --- a/src/graphnet/data/__init__.py +++ b/src/graphnet/data/__init__.py @@ -3,6 +3,7 @@ `graphnet.data` enables converting domain-specific data to industry-standard, intermediate file formats and reading this data. 
""" + from .extractors.icecube.utilities.i3_filters import I3Filter, I3FilterMask from .dataconverter import DataConverter from .pre_configured import I3ToParquetConverter diff --git a/src/graphnet/data/curated_datamodule.py b/src/graphnet/data/curated_datamodule.py index 63b691c9d..a206783bc 100644 --- a/src/graphnet/data/curated_datamodule.py +++ b/src/graphnet/data/curated_datamodule.py @@ -31,8 +31,8 @@ def __init__( features: Optional[List[str]] = None, backend: str = "parquet", train_dataloader_kwargs: Optional[Dict[str, Any]] = None, - validation_dataloader_kwargs: Dict[str, Any] = None, - test_dataloader_kwargs: Dict[str, Any] = None, + validation_dataloader_kwargs: Optional[Dict[str, Any]] = None, + test_dataloader_kwargs: Optional[Dict[str, Any]] = None, ) -> None: """Construct CuratedDataset. diff --git a/src/graphnet/data/dataconverter.py b/src/graphnet/data/dataconverter.py index 6bc9e9572..cd04847f4 100644 --- a/src/graphnet/data/dataconverter.py +++ b/src/graphnet/data/dataconverter.py @@ -1,6 +1,6 @@ """Contains `DataConverter`.""" -from typing import List, Union, OrderedDict, Dict, Tuple, Any, Optional, Type +from typing import List, Union, OrderedDict, Dict, Tuple, Any, Optional from abc import ABC from tqdm import tqdm @@ -28,7 +28,8 @@ def init_global_index(index: Synchronized, output_files: List[str]) -> None: """Make `global_index` available to pool workers.""" global global_index, global_output_files # type: ignore[name-defined] - global_index, global_output_files = (index, output_files) # type: ignore[name-defined] + global_index = index # type: ignore[name-defined] + global_output_files = output_files # type: ignore[name-defined] class DataConverter(ABC, Logger): @@ -116,10 +117,9 @@ def _launch_jobs( ) -> None: """Multi Processing Logic. - Spawns worker pool, - distributes the input files evenly across workers. - declare event_no as globally accessible variable across workers. - starts jobs. + Spawns worker pool, distributes the input files evenly across workers. + declare event_no as globally accessible variable across workers. starts + jobs. Will call process_file in parallel. """ @@ -138,8 +138,8 @@ def _launch_jobs( def _process_file(self, file_path: Union[str, I3FileSet]) -> None: """Process a single file. - Calls file reader to recieve extracted output, event ids - is assigned to the extracted data and is handed to save method. + Calls file reader to recieve extracted output, event ids is assigned to + the extracted data and is handed to save method. This function is called in parallel. """ @@ -247,7 +247,8 @@ def _count_rows( n_rows = 1 except ValueError as e: self.error( - f"Features from {extractor_name} ({extractor_dict.keys()}) have different lengths." + f"Features from {extractor_name} ({extractor_dict.keys()}) " + "have different lengths." 
) raise e return n_rows @@ -276,7 +277,8 @@ def get_map_function( n_workers = min(self._num_workers, nb_files) if n_workers > 1: self.info( - f"Starting pool of {n_workers} workers to process {nb_files} {unit}" + f"Starting pool of {n_workers} workers to process" + f" {nb_files} {unit}" ) manager = Manager() @@ -292,7 +294,8 @@ def get_map_function( else: self.info( - f"Processing {nb_files} {unit} in main thread (not multiprocessing)" + f"Processing {nb_files} {unit} in main thread" + " (not multiprocessing)" ) map_fn = map # type: ignore pool = None diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index 802a64a7d..978cdf52a 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -1,4 +1,5 @@ """Base `Dataloader` class(es) used in `graphnet`.""" + from typing import Dict, Any, Optional, List, Tuple, Union, Type import pytorch_lightning as pl from copy import deepcopy @@ -26,9 +27,9 @@ def __init__( dataset_args: Dict[str, Any], selection: Optional[Union[List[int], List[List[int]]]] = None, test_selection: Optional[Union[List[int], List[List[int]]]] = None, - train_dataloader_kwargs: Dict[str, Any] = None, - validation_dataloader_kwargs: Dict[str, Any] = None, - test_dataloader_kwargs: Dict[str, Any] = None, + train_dataloader_kwargs: Optional[Dict[str, Any]] = None, + validation_dataloader_kwargs: Optional[Dict[str, Any]] = None, + test_dataloader_kwargs: Optional[Dict[str, Any]] = None, train_val_split: Optional[List[float]] = [0.9, 0.10], split_seed: int = 42, ) -> None: diff --git a/src/graphnet/data/dataset/__init__.py b/src/graphnet/data/dataset/__init__.py index ed1c55ef5..c5d2bbc86 100644 --- a/src/graphnet/data/dataset/__init__.py +++ b/src/graphnet/data/dataset/__init__.py @@ -1,4 +1,5 @@ """Dataset classes for training in GraphNeT.""" + # Configuration from graphnet.utilities.imports import has_torch_package diff --git a/src/graphnet/data/dataset/parquet/__init__.py b/src/graphnet/data/dataset/parquet/__init__.py index edfc62b4e..f1f61b6a6 100644 --- a/src/graphnet/data/dataset/parquet/__init__.py +++ b/src/graphnet/data/dataset/parquet/__init__.py @@ -1,4 +1,5 @@ """Datasets using parquet backend.""" + # Configuration from graphnet.utilities.imports import has_torch_package diff --git a/src/graphnet/data/dataset/samplers.py b/src/graphnet/data/dataset/samplers.py index c43455447..726508805 100644 --- a/src/graphnet/data/dataset/samplers.py +++ b/src/graphnet/data/dataset/samplers.py @@ -34,7 +34,7 @@ ) from collections import defaultdict -from multiprocessing import Pool, cpu_count, get_context +from multiprocessing import get_context import numpy as np import torch diff --git a/src/graphnet/data/dataset/sqlite/__init__.py b/src/graphnet/data/dataset/sqlite/__init__.py index c44d66184..3ea73a068 100644 --- a/src/graphnet/data/dataset/sqlite/__init__.py +++ b/src/graphnet/data/dataset/sqlite/__init__.py @@ -1,4 +1,5 @@ """Datasets using SQLite backend.""" + from graphnet.utilities.imports import has_torch_package if has_torch_package(): diff --git a/src/graphnet/data/extractors/__init__.py b/src/graphnet/data/extractors/__init__.py index ad76d2742..21c5d7e31 100644 --- a/src/graphnet/data/extractors/__init__.py +++ b/src/graphnet/data/extractors/__init__.py @@ -1,3 +1,4 @@ """Module containing data-specific extractor modules.""" + from .extractor import Extractor from .combine_extractors import CombinedExtractor diff --git a/src/graphnet/data/extractors/combine_extractors.py
b/src/graphnet/data/extractors/combine_extractors.py index a978e9b72..af89defe2 100644 --- a/src/graphnet/data/extractors/combine_extractors.py +++ b/src/graphnet/data/extractors/combine_extractors.py @@ -1,4 +1,5 @@ """Module for combining multiple extractors into a single extractor.""" + from typing import TYPE_CHECKING from graphnet.utilities.imports import has_icecube_package @@ -20,7 +21,11 @@ def __init__(self, extractors: List[I3Extractor], extractor_name: str): """Construct CombinedExtractor. Args: - extractors: List of extractors to combine. The extractors must all return data on the same level; e.g. all event-level data or pulse-level data. Mixing tables that contain event-level and pulse-level information will fail. + extractors: List of extractors to combine. + The extractors must all return data on the same level; + e.g. all event-level data or pulse-level data. + Mixing tables that contain event-level and + pulse-level information will fail. extractor_name: Name of the new extractor. """ super().__init__(extractor_name=extractor_name) diff --git a/src/graphnet/data/extractors/extractor.py b/src/graphnet/data/extractors/extractor.py index abd00668f..159524fbf 100644 --- a/src/graphnet/data/extractors/extractor.py +++ b/src/graphnet/data/extractors/extractor.py @@ -1,4 +1,5 @@ """Base I3Extractor class(es).""" + from typing import Any, Union from abc import ABC, abstractmethod import pandas as pd @@ -26,9 +27,10 @@ def __init__(self, extractor_name: str): """Construct Extractor. Args: - extractor_name: Name of the `Extractor` instance. Used to keep track of the - provenance of different data, and to name tables to which this - data is saved. E.g. "mc_truth". + extractor_name: Name of the `Extractor` instance. + Used to keep track of the provenance of different + data, and to name tables to which this data is + saved. E.g. "mc_truth". """ # Member variable(s) self._extractor_name: str = extractor_name diff --git a/src/graphnet/data/extractors/icecube/i3extractor.py b/src/graphnet/data/extractors/icecube/i3extractor.py index 3f2fc92d2..07bf38030 100644 --- a/src/graphnet/data/extractors/icecube/i3extractor.py +++ b/src/graphnet/data/extractors/icecube/i3extractor.py @@ -24,9 +24,9 @@ def __init__(self, extractor_name: str): """Construct I3Extractor. Args: - extractor_name: Name of the `I3Extractor` instance. Used to keep track of the - provenance of different data, and to name tables to which this - data is saved. + extractor_name: Name of the `I3Extractor` instance. Used to keep + track of the provenance of different data, and to name tables + to which this data is saved. """ # Member variable(s) self._i3_file: str = "" diff --git a/src/graphnet/data/extractors/icecube/i3genericextractor.py b/src/graphnet/data/extractors/icecube/i3genericextractor.py index c79b7329b..f41add7d0 100644 --- a/src/graphnet/data/extractors/icecube/i3genericextractor.py +++ b/src/graphnet/data/extractors/icecube/i3genericextractor.py @@ -201,9 +201,9 @@ def _extract_per_pulse_attribute( ) -> Optional[Dict[str, Any]]: """Extract per-pulse attribute `key` from `frame`. - A per-pulse attribute (e.g., dataclasses.I3MapKeyUInt) is a - dictionary- like mapping from an OM key to some attribute, e.g., - an integer or a vector properties. + A per-pulse attribute (e.g., dataclasses.I3MapKeyUInt) is a dictionary- + like mapping from an OM key to some attribute, e.g., an integer or a + vector properties. 
""" result = self._extract_pulse_series_map(frame, key) @@ -264,12 +264,12 @@ def _flatten_result_mctree( flatten_nested_dictionary(res) for res in result_particles ] - result_primaries_transposed: Dict[ - str, List[Any] - ] = transpose_list_of_dicts(result_primaries) - result_particles_transposed: Dict[ - str, List[Any] - ] = transpose_list_of_dicts(result_particles) + result_primaries_transposed: Dict[str, List[Any]] = ( + transpose_list_of_dicts(result_primaries) + ) + result_particles_transposed: Dict[str, List[Any]] = ( + transpose_list_of_dicts(result_particles) + ) # Remove `majorID`, which has unsupported unit64 dtype. # Keep only one instances of `minorID`. diff --git a/src/graphnet/data/extractors/icecube/i3truthextractor.py b/src/graphnet/data/extractors/icecube/i3truthextractor.py index b715e57ab..4db330fc0 100644 --- a/src/graphnet/data/extractors/icecube/i3truthextractor.py +++ b/src/graphnet/data/extractors/icecube/i3truthextractor.py @@ -121,9 +121,10 @@ def __call__( "L7_oscNext_bool": padding_value, } - # Only InIceSplit P frames contain ML appropriate I3RecoPulseSeriesMap etc. - # At low levels i3files contain several other P frame splits (e.g NullSplit), - # we remove those here. + # Only InIceSplit P frames contain ML appropriate + # for example I3RecoPulseSeriesMap, etc. + # At low levels i3 files contain several other P frame splits + # (e.g NullSplit). We remove those here. if frame["I3EventHeader"].sub_event_stream not in [ "InIceSplit", "Final", @@ -181,7 +182,10 @@ def __call__( energy_cascade, inelasticity, ) = self._get_primary_track_energy_and_inelasticity(frame) - except RuntimeError: # track energy fails on northeren tracks with ""Hadrons" has no mass implemented. Cannot get total energy." + except ( + RuntimeError + ): # track energy fails on northeren tracks with ""Hadrons" + # has no mass implemented. Cannot get total energy." energy_track, energy_cascade, inelasticity = ( padding_value, padding_value, @@ -216,9 +220,10 @@ def __call__( muon_final = self._muon_stopped(output, self._borders) output.update( { - "position_x": muon_final[ - "x" - ], # position_xyz has no meaning for muons. These will now be updated to muon final position, given track length/azimuth/zenith + "position_x": muon_final["x"], + # position_xyz has no meaning for muons. + # These will now be updated to muon final position, + # given track length/azimuth/zenith "position_y": muon_final["y"], "position_z": muon_final["z"], "stopped_muon": muon_final["stopped"], @@ -362,10 +367,11 @@ def _get_primary_particle_interaction_type_and_elasticity( MCInIcePrimary = frame[self._mctree][0] if ( MCInIcePrimary.energy != MCInIcePrimary.energy - ): # This is a nan check. Only happens for some muons where second item in MCTree is primary. Weird! - MCInIcePrimary = frame[self._mctree][ - 1 - ] # For some strange reason the second entry is identical in all variables and has no nans (always muon) + ): # This is a nan check. Only happens for some muons + # where second item in MCTree is primary. Weird! 
+ MCInIcePrimary = frame[self._mctree][1] + # For some strange reason the second entry is identical in + # all variables and has no nans (always muon) else: MCInIcePrimary = None try: diff --git a/src/graphnet/data/extractors/icecube/utilities/i3_filters.py b/src/graphnet/data/extractors/icecube/utilities/i3_filters.py index ca83f4217..db115bd21 100644 --- a/src/graphnet/data/extractors/icecube/utilities/i3_filters.py +++ b/src/graphnet/data/extractors/icecube/utilities/i3_filters.py @@ -1,4 +1,5 @@ """Filter classes for filtering I3-frames when converting I3-files.""" + from abc import abstractmethod from graphnet.utilities.logging import Logger from typing import List @@ -64,7 +65,7 @@ def _keep_frame(self, frame: "icetray.I3Frame") -> bool: class I3FilterMask(I3Filter): - """checks list of filters from the FilterMask in I3 frames.""" + """Checks list of filters from the FilterMask in I3 frames.""" def __init__(self, filter_names: List[str], filter_any: bool = True): """Initialize I3FilterMask. @@ -95,7 +96,8 @@ def _keep_frame(self, frame: "icetray.I3Frame") -> bool: for filter_name in self._filter_names: if filter_name not in frame["FilterMask"]: self.warning_once( - f"FilterMask {filter_name} not found in frame. skipping filter." + f"FilterMask {filter_name} not found in frame. " + "Skipping filter." ) continue elif frame["FilterMask"][filter].condition_passed is True: @@ -104,18 +106,20 @@ def _keep_frame(self, frame: "icetray.I3Frame") -> bool: bool_list.append(False) if len(bool_list) == 0: self.warning_once( - "None of the FilterMask filters found in frame, FilterMask filters will not be applied." + "None of the FilterMask filters found in frame. " + "FilterMask filters will not be applied." ) return any(bool_list) or len(bool_list) == 0 else: # Require all filters to pass in order to keep the frame. for filter_name in self._filter_names: if filter_name not in frame["FilterMask"]: self.warning_once( - f"FilterMask {filter_name} not found in frame, skipping filter." + f"FilterMask {filter_name} not found in frame. " + "Skipping filter." ) continue elif frame["FilterMask"][filter].condition_passed is True: - continue # current filter passed, continue to next filter + continue # current filter passed, go to next filter else: return ( False # current filter failed so frame is skipped. @@ -123,6 +127,7 @@ def _keep_frame(self, frame: "icetray.I3Frame") -> bool: return True else: self.warning_once( - "FilterMask not found in frame, FilterMask filters will not be applied." + "FilterMask not found in frame. " + "FilterMask filters will not be applied."
) return True diff --git a/src/graphnet/data/extractors/internal/parquet_extractor.py b/src/graphnet/data/extractors/internal/parquet_extractor.py index ec8bb5db2..0f9bd8e6b 100644 --- a/src/graphnet/data/extractors/internal/parquet_extractor.py +++ b/src/graphnet/data/extractors/internal/parquet_extractor.py @@ -1,4 +1,5 @@ """Parquet Extractor for conversion from internal parquet format.""" + import polars as pol import pandas as pd diff --git a/src/graphnet/data/extractors/liquido/__init__.py b/src/graphnet/data/extractors/liquido/__init__.py index 023c55e12..9ead11289 100644 --- a/src/graphnet/data/extractors/liquido/__init__.py +++ b/src/graphnet/data/extractors/liquido/__init__.py @@ -1,2 +1,3 @@ """Module containing different extractors for LiquidO files.""" + from .h5_extractor import H5Extractor, H5HitExtractor, H5TruthExtractor diff --git a/src/graphnet/data/extractors/liquido/h5_extractor.py b/src/graphnet/data/extractors/liquido/h5_extractor.py index 3075a9cde..231e3349f 100644 --- a/src/graphnet/data/extractors/liquido/h5_extractor.py +++ b/src/graphnet/data/extractors/liquido/h5_extractor.py @@ -1,4 +1,5 @@ """H5 Extractor for LiquidO data files.""" + from typing import List import numpy as np import pandas as pd diff --git a/src/graphnet/data/extractors/prometheus/__init__.py b/src/graphnet/data/extractors/prometheus/__init__.py index 09f6a9ea3..da31b2a2f 100644 --- a/src/graphnet/data/extractors/prometheus/__init__.py +++ b/src/graphnet/data/extractors/prometheus/__init__.py @@ -1,4 +1,5 @@ """Extractors for extracting data from parquet files Prometheus.""" + from .prometheus_extractor import ( PrometheusExtractor, PrometheusTruthExtractor, diff --git a/src/graphnet/data/extractors/prometheus/prometheus_extractor.py b/src/graphnet/data/extractors/prometheus/prometheus_extractor.py index 9e4d02973..a65d315b4 100644 --- a/src/graphnet/data/extractors/prometheus/prometheus_extractor.py +++ b/src/graphnet/data/extractors/prometheus/prometheus_extractor.py @@ -1,4 +1,5 @@ """Parquet Extractor for conversion of simulation files from PROMETHEUS.""" + from typing import List import pandas as pd import numpy as np diff --git a/src/graphnet/data/parquet/__init__.py b/src/graphnet/data/parquet/__init__.py index 2c41ca75d..b0725ea3c 100644 --- a/src/graphnet/data/parquet/__init__.py +++ b/src/graphnet/data/parquet/__init__.py @@ -1,2 +1,3 @@ """Module for deprecated parquet methods.""" + from .deprecated_methods import ParquetDataConverter diff --git a/src/graphnet/data/parquet/deprecated_methods.py b/src/graphnet/data/parquet/deprecated_methods.py index ae2593813..62b95aeee 100644 --- a/src/graphnet/data/parquet/deprecated_methods.py +++ b/src/graphnet/data/parquet/deprecated_methods.py @@ -2,6 +2,7 @@ This code will be removed in GraphNeT 2.0. """ + from typing import List, Union from graphnet.data.extractors.icecube import I3Extractor @@ -26,8 +27,9 @@ def __init__( """Convert I3 files to Parquet. Args: - gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is - found in subfolder. `I3Reader` will recursively search + gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no + GCD file is found in subfolder. + `I3Reader` will recursively search the input directory for I3-GCD file pairs. By IceCube convention, a folder containing i3 files will have an accompanying GCD file. 
However, in some cases, this diff --git a/src/graphnet/data/pre_configured/__init__.py b/src/graphnet/data/pre_configured/__init__.py index c19679325..8214fdcc5 100644 --- a/src/graphnet/data/pre_configured/__init__.py +++ b/src/graphnet/data/pre_configured/__init__.py @@ -1,4 +1,5 @@ """Module for pre-configured converter modules.""" + from .dataconverters import ( I3ToParquetConverter, I3ToSQLiteConverter, diff --git a/src/graphnet/data/pre_configured/dataconverters.py b/src/graphnet/data/pre_configured/dataconverters.py index 1194da3c2..a7e8f324f 100644 --- a/src/graphnet/data/pre_configured/dataconverters.py +++ b/src/graphnet/data/pre_configured/dataconverters.py @@ -25,10 +25,12 @@ def __init__( """Convert I3 files to Parquet. Args: - gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is - found in subfolder. `I3Reader` will recursively search - the input directory for I3-GCD file pairs. By IceCube - convention, a folder containing i3 files will have an + gcd_rescue: gcd_rescue: Path to a GCD file that will be used if + no GCD file is found in subfolder. `I3Reader` will + recursively search the input directory for I3-GCD file + pairs. + By IceCube convention, + a folder containing i3 files will have an accompanying GCD file. However, in some cases, this convention is broken. In cases where a folder contains i3 files but no GCD file, the `gcd_rescue` is used @@ -70,10 +72,11 @@ def __init__( """Convert I3 files to SQLite. Args: - gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is - found in subfolder. `I3Reader` will recursively search - the input directory for I3-GCD file pairs. By IceCube - convention, a folder containing i3 files will have an + gcd_rescue: gcd_rescue: Path to a GCD file that will be used if + no GCD file is found in subfolder. `I3Reader` will + recursively search the input directory for I3-GCD file + pairs. By IceCube convention, + a folder containing i3 files will have an accompanying GCD file. However, in some cases, this convention is broken. In cases where a folder contains i3 files but no GCD file, the `gcd_rescue` is used diff --git a/src/graphnet/data/readers/__init__.py b/src/graphnet/data/readers/__init__.py index 658da0b0c..911a00255 100644 --- a/src/graphnet/data/readers/__init__.py +++ b/src/graphnet/data/readers/__init__.py @@ -1,4 +1,5 @@ """Modules for reading experiment-specific data and applying Extractors.""" + from .graphnet_file_reader import GraphNeTFileReader from .i3reader import I3Reader from .internal_parquet_reader import ParquetReader diff --git a/src/graphnet/data/readers/graphnet_file_reader.py b/src/graphnet/data/readers/graphnet_file_reader.py index f30a19fcf..6d145f1c0 100644 --- a/src/graphnet/data/readers/graphnet_file_reader.py +++ b/src/graphnet/data/readers/graphnet_file_reader.py @@ -125,7 +125,9 @@ def _validate_extractors( ) -> None: for extractor in extractors: try: - assert isinstance(extractor, tuple(self.accepted_extractors)) # type: ignore + assert isinstance( + extractor, tuple(self.accepted_extractors) # type: ignore + ) except AssertionError as e: self.error( f"{extractor.__class__.__name__}" @@ -164,5 +166,7 @@ def _validate_file(self, file: str) -> None: assert file.lower().endswith(tuple(self.accepted_file_extensions)) except AssertionError: self.error( - f'{self.__class__.__name__} accepts {self.accepted_file_extensions} but {file.split("/")[-1]} has extension {os.path.splitext(file)[1]}.' 
+ f"{self.__class__.__name__} accepts " + f'{self.accepted_file_extensions} but {file.split("/")[-1]} ' + f"has extension {os.path.splitext(file)[1]}." ) diff --git a/src/graphnet/data/readers/i3reader.py b/src/graphnet/data/readers/i3reader.py index a2d1d0c83..39a090eb3 100644 --- a/src/graphnet/data/readers/i3reader.py +++ b/src/graphnet/data/readers/i3reader.py @@ -1,6 +1,6 @@ """Module containing different I3Reader.""" -from typing import List, Union, OrderedDict +from typing import List, Union, OrderedDict, Optional from graphnet.utilities.imports import has_icecube_package from graphnet.data.extractors.icecube.utilities.i3_filters import ( @@ -27,7 +27,7 @@ class I3Reader(GraphNeTFileReader): def __init__( self, gcd_rescue: str, - i3_filters: Union[I3Filter, List[I3Filter]] = None, + i3_filters: Optional[Union[I3Filter, List[I3Filter]]] = None, icetray_verbose: int = 0, ): """Initialize `I3Reader`. @@ -65,7 +65,9 @@ def __init__( # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) - def __call__(self, file_path: I3FileSet) -> List[OrderedDict]: # type: ignore + def __call__( + self, file_path: I3FileSet + ) -> List[OrderedDict]: # noqa: E501 # type: ignore """Extract data from single I3 file. Args: diff --git a/src/graphnet/data/sqlite/__init__.py b/src/graphnet/data/sqlite/__init__.py index 436a86f2d..8fbfcfcb3 100644 --- a/src/graphnet/data/sqlite/__init__.py +++ b/src/graphnet/data/sqlite/__init__.py @@ -1,2 +1,3 @@ """Module for deprecated methods using sqlite.""" + from .deprecated_methods import SQLiteDataConverter diff --git a/src/graphnet/data/sqlite/deprecated_methods.py b/src/graphnet/data/sqlite/deprecated_methods.py index 30b563c59..4d7e0e69f 100644 --- a/src/graphnet/data/sqlite/deprecated_methods.py +++ b/src/graphnet/data/sqlite/deprecated_methods.py @@ -3,13 +3,10 @@ This code will be removed in GraphNeT 2.0. """ -from typing import List, Union, Type +from typing import List, Union from graphnet.data.extractors.icecube import I3Extractor -from graphnet.data.extractors.icecube.utilities.i3_filters import ( - I3Filter, - NullSplitI3Filter, -) +from graphnet.data.extractors.icecube.utilities.i3_filters import I3Filter from graphnet.data import I3ToSQLiteConverter @@ -28,10 +25,11 @@ def __init__( """Convert I3 files to Parquet. Args: - gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is - found in subfolder. `I3Reader` will recursively search - the input directory for I3-GCD file pairs. By IceCube - convention, a folder containing i3 files will have an + gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no + GCD file is found in subfolder. `I3Reader` will + recursively search the input directory for I3-GCD file + pairs. By IceCube convention, + a folder containing i3 files will have an accompanying GCD file. However, in some cases, this convention is broken. 
In cases where a folder contains i3 files but no GCD file, the `gcd_rescue` is used diff --git a/src/graphnet/data/utilities/__init__.py b/src/graphnet/data/utilities/__init__.py index ad4f0c7db..fb5d6fad6 100644 --- a/src/graphnet/data/utilities/__init__.py +++ b/src/graphnet/data/utilities/__init__.py @@ -1,4 +1,5 @@ """Utilities for use across `graphnet.data`.""" + from .sqlite_utilities import create_table_and_save_to_sql from .sqlite_utilities import get_primary_keys from .sqlite_utilities import query_database diff --git a/src/graphnet/data/utilities/parquet_to_sqlite.py b/src/graphnet/data/utilities/parquet_to_sqlite.py index f429d5e6b..07bc86ecb 100644 --- a/src/graphnet/data/utilities/parquet_to_sqlite.py +++ b/src/graphnet/data/utilities/parquet_to_sqlite.py @@ -1,3 +1,3 @@ """Utilities for converting files from Parquet to SQLite.""" -from graphnet.data.pre_configured import ParquetToSQLiteConverter +from graphnet.data.pre_configured import ParquetToSQLiteConverter # noqa: F401 diff --git a/src/graphnet/data/utilities/random.py b/src/graphnet/data/utilities/random.py index 1084f782c..7aaa8c4be 100644 --- a/src/graphnet/data/utilities/random.py +++ b/src/graphnet/data/utilities/random.py @@ -9,7 +9,8 @@ def pairwise_shuffle( ) -> Tuple[List[str], List[str]]: """Shuffle the I3 file list and the correponding gcd file list. - This is handy because it ensures a more even extraction load for each worker. + This is handy because it ensures a more even extraction load for each + worker. Args: files_list: List of I3 file paths. diff --git a/src/graphnet/data/utilities/sqlite_utilities.py b/src/graphnet/data/utilities/sqlite_utilities.py index 95755ef44..3458f4485 100644 --- a/src/graphnet/data/utilities/sqlite_utilities.py +++ b/src/graphnet/data/utilities/sqlite_utilities.py @@ -1,7 +1,7 @@ """SQLite-specific utility functions for use in `graphnet.data`.""" import os.path -from typing import List, Dict, Tuple +from typing import List, Dict, Tuple, Union import pandas as pd import sqlalchemy @@ -30,7 +30,9 @@ def query_database(database: str, query: str) -> pd.DataFrame: return pd.read_sql(query, conn) -def get_primary_keys(database: str) -> Tuple[Dict[str, str], str]: +def get_primary_keys( + database: str, +) -> Tuple[Dict[str, Union[str, None]], Union[str, None]]: """Get name of primary key column for each table in database. 
Args: @@ -50,7 +52,7 @@ def get_primary_keys(database: str) -> Tuple[Dict[str, str], str]: integer_primary_key = {} for table in table_names: - query = f"SELECT l.name FROM pragma_table_info('{table}') as l WHERE l.pk = 1;" + query = f"SELECT l.name FROM pragma_table_info('{table}') as l WHERE l.pk = 1;" # noqa: E501 first_primary_key = [ key[0] for key in conn.execute(query).fetchall() ] @@ -78,7 +80,7 @@ def database_table_exists(database_path: str, table_name: str) -> bool: """Check whether `table_name` exists in database at `database_path`.""" if not database_exists(database_path): return False - query = f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}';" + query = f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}';" # noqa: E501 with sqlite3.connect(database_path) as conn: result = pd.read_sql(query, conn) return len(result) == 1 diff --git a/src/graphnet/data/writers/__init__.py b/src/graphnet/data/writers/__init__.py index ad3e2748e..cf70660ae 100644 --- a/src/graphnet/data/writers/__init__.py +++ b/src/graphnet/data/writers/__init__.py @@ -1,4 +1,5 @@ """Modules for saving interim dataformat to various data backends.""" + from .graphnet_writer import GraphNeTWriter from .parquet_writer import ParquetWriter from .sqlite_writer import SQLiteWriter diff --git a/src/graphnet/datasets/__init__.py b/src/graphnet/datasets/__init__.py index 74893bb83..c01f50bc3 100644 --- a/src/graphnet/datasets/__init__.py +++ b/src/graphnet/datasets/__init__.py @@ -1,3 +1,4 @@ """Contains pre-converted datasets ready for training.""" + from .test_dataset import TestDataset from .prometheus_datasets import TRIDENTSmall, BaikalGVDSmall, PONESmall diff --git a/src/graphnet/datasets/prometheus_datasets.py b/src/graphnet/datasets/prometheus_datasets.py index 14cb2bc78..8dc028278 100644 --- a/src/graphnet/datasets/prometheus_datasets.py +++ b/src/graphnet/datasets/prometheus_datasets.py @@ -1,4 +1,5 @@ """Public datasets from Prometheus Simulation.""" + from typing import Dict, Any, List, Tuple, Union import os from sklearn.model_selection import train_test_split diff --git a/src/graphnet/datasets/test_dataset.py b/src/graphnet/datasets/test_dataset.py index 0a30a962d..1c3d2b679 100644 --- a/src/graphnet/datasets/test_dataset.py +++ b/src/graphnet/datasets/test_dataset.py @@ -1,4 +1,5 @@ """A CuratedDataset for unit tests.""" + from typing import Dict, Any, List, Tuple, Union import os diff --git a/src/graphnet/deployment/__init__.py b/src/graphnet/deployment/__init__.py index c26dfc1b7..b0d0fc429 100644 --- a/src/graphnet/deployment/__init__.py +++ b/src/graphnet/deployment/__init__.py @@ -3,5 +3,6 @@ `graphnet.deployment` allows for using trained models for inference in domain- specific reconstruction chains. 
""" + from .deployer import Deployer from .deployment_module import DeploymentModule diff --git a/src/graphnet/deployment/deployer.py b/src/graphnet/deployment/deployer.py index 4ac67f4d8..37473d195 100644 --- a/src/graphnet/deployment/deployer.py +++ b/src/graphnet/deployment/deployer.py @@ -1,4 +1,5 @@ """Contains the graphnet deployment module.""" + import random from abc import abstractmethod, ABC import multiprocessing @@ -9,7 +10,7 @@ from .deployment_module import DeploymentModule from graphnet.utilities.logging import Logger -if has_torch_package or TYPE_CHECKING: +if has_torch_package() or TYPE_CHECKING: import torch diff --git a/src/graphnet/deployment/deployment_module.py b/src/graphnet/deployment/deployment_module.py index b91490ad0..72026a21b 100644 --- a/src/graphnet/deployment/deployment_module.py +++ b/src/graphnet/deployment/deployment_module.py @@ -1,4 +1,5 @@ """Class(es) for deploying GraphNeT models in icetray as I3Modules.""" + from abc import abstractmethod from typing import Any, List, Union, Dict diff --git a/src/graphnet/deployment/i3modules/__init__.py b/src/graphnet/deployment/i3modules/__init__.py index e2fd05a43..0fe5115da 100644 --- a/src/graphnet/deployment/i3modules/__init__.py +++ b/src/graphnet/deployment/i3modules/__init__.py @@ -5,5 +5,5 @@ detector configurations. """ -from .deprecated_methods import * +from .deprecated_methods import * # noqa: F403 from graphnet.deployment.icecube import I3InferenceModule, I3PulseCleanerModule diff --git a/src/graphnet/deployment/i3modules/deprecated_methods.py b/src/graphnet/deployment/i3modules/deprecated_methods.py index 6acdc8d33..987dcdef1 100644 --- a/src/graphnet/deployment/i3modules/deprecated_methods.py +++ b/src/graphnet/deployment/i3modules/deprecated_methods.py @@ -1,4 +1,5 @@ """Contains deprecated methods.""" + from typing import Union, Sequence # from graphnet.deployment.icecube import I3Deployer, I3InferenceModule diff --git a/src/graphnet/deployment/icecube/__init__.py b/src/graphnet/deployment/icecube/__init__.py index 15c7485ef..184d29ef9 100644 --- a/src/graphnet/deployment/icecube/__init__.py +++ b/src/graphnet/deployment/icecube/__init__.py @@ -1,4 +1,5 @@ """Deployment modules specific to IceCube.""" + from .inference_module import I3InferenceModule from .cleaning_module import I3PulseCleanerModule from .i3deployer import I3Deployer diff --git a/src/graphnet/deployment/icecube/cleaning_module.py b/src/graphnet/deployment/icecube/cleaning_module.py index 27e5f2260..e2fee80cd 100644 --- a/src/graphnet/deployment/icecube/cleaning_module.py +++ b/src/graphnet/deployment/icecube/cleaning_module.py @@ -2,6 +2,7 @@ Contains functionality for writing model predictions to i3 files. 
""" + from typing import List, Union, TYPE_CHECKING, Dict, Any, Tuple import numpy as np @@ -117,13 +118,13 @@ def __call__(self, frame: I3Frame) -> bool: # checking the prediction for each pulse # (Adds the actual pulsemap to dictionary) if self._total_pulsemap_name not in frame.keys(): - data_dict[ - self._total_pulsemap_name - ] = dataclasses.I3RecoPulseSeriesMapMask( - frame, - self._pulsemap, - lambda om_key, index, pulse: predictions_map[om_key][index] - >= self._threshold, + data_dict[self._total_pulsemap_name] = ( + dataclasses.I3RecoPulseSeriesMapMask( + frame, + self._pulsemap, + lambda om_key, index, pulse: predictions_map[om_key][index] + >= self._threshold, + ) ) # Submit predictions and general pulsemap @@ -138,19 +139,19 @@ def __call__(self, frame: I3Frame) -> bool: ) if f"{self._total_pulsemap_name}_mDOMs_Only" not in frame.keys(): - data[ - f"{self._total_pulsemap_name}_mDOMs_Only" - ] = dataclasses.I3RecoPulseSeriesMap(mDOMMap) + data[f"{self._total_pulsemap_name}_mDOMs_Only"] = ( + dataclasses.I3RecoPulseSeriesMap(mDOMMap) + ) if f"{self._total_pulsemap_name}_dEggs_Only" not in frame.keys(): - data[ - f"{self._total_pulsemap_name}_dEggs_Only" - ] = dataclasses.I3RecoPulseSeriesMap(DEggMap) + data[f"{self._total_pulsemap_name}_dEggs_Only"] = ( + dataclasses.I3RecoPulseSeriesMap(DEggMap) + ) if f"{self._total_pulsemap_name}_pDOMs_Only" not in frame.keys(): - data[ - f"{self._total_pulsemap_name}_pDOMs_Only" - ] = dataclasses.I3RecoPulseSeriesMap(IceCubeMap) + data[f"{self._total_pulsemap_name}_pDOMs_Only"] = ( + dataclasses.I3RecoPulseSeriesMap(IceCubeMap) + ) # Submits the additional pulsemaps to the frame frame = self._add_to_frame(frame=frame, data=data) @@ -211,7 +212,7 @@ def _construct_prediction_map( for om_key, pulses in pulsemap.items(): num_pulses = len(pulses) predictions_map[om_key] = predictions[ - idx : idx + num_pulses + idx : idx + num_pulses # noqa: E203 ].tolist() idx += num_pulses diff --git a/src/graphnet/deployment/icecube/inference_module.py b/src/graphnet/deployment/icecube/inference_module.py index 9631cc3e0..e963850a5 100644 --- a/src/graphnet/deployment/icecube/inference_module.py +++ b/src/graphnet/deployment/icecube/inference_module.py @@ -2,6 +2,7 @@ Contains functionality for writing model predictions to i3 files. 
""" + from typing import List, Union, Optional, TYPE_CHECKING, Dict, Any import numpy as np @@ -119,13 +120,13 @@ def _create_dictionary( for i in range(dim): try: assert len(predictions[:, i]) == 1 - data[ - self.model_name + "_" + self.prediction_columns[i] - ] = I3Double(float(predictions[:, i][0])) + data[self.model_name + "_" + self.prediction_columns[i]] = ( + I3Double(float(predictions[:, i][0])) + ) except IndexError: - data[ - self.model_name + "_" + self.prediction_columns[i] - ] = I3Double(predictions[0]) + data[self.model_name + "_" + self.prediction_columns[i]] = ( + I3Double(predictions[0]) + ) return data def _apply_model(self, data: Data) -> np.ndarray: diff --git a/src/graphnet/exceptions/__init__.py b/src/graphnet/exceptions/__init__.py index 9d7796808..71fd44001 100644 --- a/src/graphnet/exceptions/__init__.py +++ b/src/graphnet/exceptions/__init__.py @@ -1,2 +1,3 @@ """Custom Exceptions for GraphNeT.""" + from .exceptions import ColumnMissingException diff --git a/src/graphnet/models/__init__.py b/src/graphnet/models/__init__.py index 12d4cbcc5..e3e4f23b9 100644 --- a/src/graphnet/models/__init__.py +++ b/src/graphnet/models/__init__.py @@ -6,6 +6,7 @@ existing, purpose-built components and chain them together to form a complete GNN """ + from graphnet.utilities.imports import has_jammy_flows_package from .model import Model from .standard_model import StandardModel diff --git a/src/graphnet/models/coarsening.py b/src/graphnet/models/coarsening.py index d40f0c009..d481a3d62 100644 --- a/src/graphnet/models/coarsening.py +++ b/src/graphnet/models/coarsening.py @@ -8,7 +8,6 @@ from torch_geometric.data import Data, Batch from sklearn.cluster import DBSCAN -# from torch_geometric.utils import unbatch_edge_index from graphnet.models.components.pool import ( group_by, avg_pool, @@ -28,7 +27,7 @@ # NOTE: From [https://github.com/pyg-team/pytorch_geometric/pull/4903] # TODO: Remove once bumping to torch_geometric>=2.1.0 -# See [https://github.com/pyg-team/pytorch_geometric/blob/master/CHANGELOG.md] +# See [https://github.com/pyg-team/pytorch_geometric/blob/master/CHANGELOG.md] # noqa: E501 def unbatch_edge_index(edge_index: Tensor, batch: Tensor) -> List[Tensor]: @@ -170,15 +169,18 @@ def _reconstruct_batch(self, original: Data, pooled: Data) -> Data: return pooled def _add_slice_dict(self, original: Data, pooled: Data) -> Data: - # Copy original slice_dict and count nodes in each graph in pooled batch + # Copy original slice_dict and count nodes in each + # graph in pooled batch slice_dict = deepcopy(original._slice_dict) _, counts = torch.unique_consecutive(pooled.batch, return_counts=True) - # Reconstruct the entry in slice_dict for pulsemaps - only these are affected by pooling + # Reconstruct the entry in slice_dict for pulsemaps - + # only these are affected by pooling pulsemap_slice = [0] for i in range(len(counts)): pulsemap_slice.append(pulsemap_slice[i] + counts[i].item()) - # Identifies pulsemap entries in slice_dict and set them to pulsemap_slice + # Identifies pulsemap entries in slice_dict and + # set them to pulsemap_slice for field in slice_dict.keys(): if (original._num_graphs) == slice_dict[field][-1]: pass # not pulsemap, so skip diff --git a/src/graphnet/models/components/embedding.py b/src/graphnet/models/components/embedding.py index 1b49cd901..08d699931 100644 --- a/src/graphnet/models/components/embedding.py +++ b/src/graphnet/models/components/embedding.py @@ -1,4 +1,5 @@ """Classes for performing embedding of input data.""" + import torch import 
torch.nn as nn from torch.functional import Tensor @@ -33,9 +34,7 @@ def __init__( super().__init__() if dim % 2 != 0: raise ValueError(f"dim has to be even. Got: {dim}") - self.scale = ( - nn.Parameter(torch.ones(1) * dim**-0.5) if scaled else 1.0 - ) + self.scale = nn.Parameter(torch.ones(1) * dim**-0.5) if scaled else 1.0 self.dim = dim self.n_freq = torch.Tensor([n_freq]) diff --git a/src/graphnet/models/components/layers.py b/src/graphnet/models/components/layers.py index bd2ed8daf..4f04067f9 100644 --- a/src/graphnet/models/components/layers.py +++ b/src/graphnet/models/components/layers.py @@ -9,7 +9,6 @@ from torch_geometric.typing import Adj, PairTensor from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.inits import reset -from torch_geometric.data import Data import torch.nn as nn from torch.nn.functional import linear from torch.nn.modules import TransformerEncoder, TransformerEncoderLayer diff --git a/src/graphnet/models/components/pool.py b/src/graphnet/models/components/pool.py index 3af15990c..592c2fce2 100644 --- a/src/graphnet/models/components/pool.py +++ b/src/graphnet/models/components/pool.py @@ -9,11 +9,11 @@ from torch_geometric.nn.pool.pool import pool_edge, pool_batch, pool_pos from torch_scatter import scatter, scatter_std -from torch_geometric.nn.pool import ( - avg_pool, +from torch_geometric.nn.pool import ( # noqa:F401 max_pool, - avg_pool_x, max_pool_x, + avg_pool_x, + avg_pool, ) @@ -90,8 +90,8 @@ def group_by(data: Union[Data, Batch], keys: List[str]) -> LongTensor: This grouping is done with in each event in case of batching. This allows for, e.g., assigning the same index to all pulses on the same PMT or DOM in the same event. This can be used for coarsening graphs, e.g., from pulse- - level to DOM-level by aggregating feature across each group returned by this - method. + level to DOM-level by aggregating feature across each group returned by + this method. Example: Given: @@ -140,7 +140,7 @@ def sum_pool_x( batch: LongTensor, size: Optional[int] = None, ) -> Tensor: - r"""Sum-pool node features according to the clustering defined in `cluster`. + r"""Sum-pool node features according to the cluster defined in `cluster`. Args: cluster: Cluster vector :math:`\mathbf{c} \in \{ 0, @@ -172,7 +172,7 @@ def std_pool_x( batch: LongTensor, size: Optional[int] = None, ) -> Tensor: - r"""Std-pool node features according to the clustering defined in `cluster`. + r"""Std-pool node features according to the cluster defined in `cluster`. Args: cluster: Cluster vector :math:`\mathbf{c} \in \{ 0, @@ -201,7 +201,7 @@ def std_pool_x( def sum_pool( cluster: LongTensor, data: Data, transform: Optional[Callable] = None ) -> Data: - r"""Pool and coarsen graph according to the clustering defined in `cluster`. + r"""Pool and coarsen graph according to the cluster defined in `cluster`. All nodes within the same cluster will be represented as one node. Final node features are defined by the *sum* of features of all nodes @@ -235,7 +235,7 @@ def sum_pool( def std_pool( cluster: LongTensor, data: Data, transform: Optional[Callable] = None ) -> Data: - r"""Pool and coarsen graph according to the clustering defined in `cluster`. + r"""Pool and coarsen graph according to the cluster defined in `cluster`. All nodes within the same cluster will be represented as one node. 
Final node features are defined by the *std* of features of all nodes diff --git a/src/graphnet/models/detector/detector.py b/src/graphnet/models/detector/detector.py index 9b9fc61b0..df28f3191 100644 --- a/src/graphnet/models/detector/detector.py +++ b/src/graphnet/models/detector/detector.py @@ -66,7 +66,9 @@ def _standardize( ) -> Data: for idx, feature in enumerate(input_feature_names): try: - input_features[:, idx] = self.feature_map()[feature]( # type: ignore + input_features[:, idx] = self.feature_map()[ + feature + ]( # noqa: E501 # type: ignore input_features[:, idx] ) except KeyError as e: diff --git a/src/graphnet/models/gnn/RNN_tito.py b/src/graphnet/models/gnn/RNN_tito.py index 8b9157c58..26539eb72 100644 --- a/src/graphnet/models/gnn/RNN_tito.py +++ b/src/graphnet/models/gnn/RNN_tito.py @@ -1,4 +1,5 @@ """RNN_DynEdge model implementation.""" + from typing import List, Optional, Tuple import torch @@ -43,24 +44,42 @@ def __init__( Args: nb_inputs (int): Number of input features. - time_series_columns (List[int]): The indices of the input data that should be treated as time series data. The first index should be the charge column. + time_series_columns (List[int]): The indices of the input data that + should be treated as time series data. + The first index should be the charge column. nb_neighbours (int, optional): Number of neighbours to consider. Defaults to 8. rnn_layers (int, optional): Number of RNN layers. Defaults to 1. - rnn_hidden_size (int, optional): Size of the hidden state of the RNN. Also determines the size of the output of the RNN. + rnn_hidden_size (int, optional): Size of the hidden state of the + RNN. Also determines the size of the output of the RNN. Defaults to 64. - rnn_dropout (float, optional): Dropout to use in the RNN. Defaults to 0.5. + rnn_dropout (float, optional): Dropout to use in the RNN. + Defaults to 0.5. features_subset (List[int], optional): The subset of latent - features on each node that are used as metric dimensions when performing the k-nearest neighbours clustering. Defaults to [0,1,2,3] - dyntrans_layer_sizes (List[Tuple[int, ...]], optional): List of tuples representing the sizes of the hidden layers of the DynTrans model. - post_processing_layer_sizes (List[int], optional): List of integers representing the sizes of the hidden layers of the post-processing model. - readout_layer_sizes (List[int], optional): List of integers representing the sizes of the hidden layers of the readout model. - global_pooling_schemes (Union[str, List[str]], optional): Pooling schemes to use. Defaults to None. - embedding_dim (int, optional): Embedding dimension of the RNN. Defaults to None ie. no embedding. - n_head (int, optional): Number of heads to use in the DynTrans model. Defaults to 16. - use_global_features (bool, optional): Whether to use global features after pooling. Defaults to True. - use_post_processing_layers (bool, optional): Whether to use post-processing layers after the DynTrans layers. Defaults to True. + features on each node that are used as metric dimensions when + performing the k-nearest neighbours clustering. + Defaults to [0,1,2,3] + dyntrans_layer_sizes (List[Tuple[int, ...]], optional): List of + tuples representing the sizes of the hidden layers of + the DynTrans model. + post_processing_layer_sizes (List[int], optional): List of + integers representing the sizes of the hidden layers of the + post-processing model. 
+ readout_layer_sizes (List[int], optional): List of integers + representing the sizes of the hidden layers of the + readout model. + global_pooling_schemes (Union[str, List[str]], optional): Pooling + schemes to use. Defaults to None. + embedding_dim (int, optional): Embedding dimension of the RNN. + Defaults to None ie. no embedding. + n_head (int, optional): Number of heads to use in the DynTrans + model. Defaults to 16. + use_global_features (bool, optional): Whether to use global + features after pooling. Defaults to True. + use_post_processing_layers (bool, optional): Whether to use + post-processing layers after the DynTrans layers. + Defaults to True. """ self._nb_neighbours = nb_neighbours self._nb_inputs = nb_inputs diff --git a/src/graphnet/models/gnn/dynedge.py b/src/graphnet/models/gnn/dynedge.py index a80ce82ce..aabfabb88 100644 --- a/src/graphnet/models/gnn/dynedge.py +++ b/src/graphnet/models/gnn/dynedge.py @@ -1,4 +1,5 @@ """Implementation of the DynEdge GNN model architecture.""" + from typing import List, Optional, Tuple, Union import torch diff --git a/src/graphnet/models/gnn/dynedge_kaggle_tito.py b/src/graphnet/models/gnn/dynedge_kaggle_tito.py index 88d4c8811..dded5616c 100644 --- a/src/graphnet/models/gnn/dynedge_kaggle_tito.py +++ b/src/graphnet/models/gnn/dynedge_kaggle_tito.py @@ -34,12 +34,12 @@ class DynEdgeTITO(GNN): def __init__( self, nb_inputs: int, - features_subset: List[int] = None, + features_subset: Optional[List[int]] = None, dyntrans_layer_sizes: Optional[List[Tuple[int, ...]]] = None, global_pooling_schemes: List[str] = ["max"], use_global_features: bool = True, use_post_processing_layers: bool = True, - post_processing_layer_sizes: List[int] = None, + post_processing_layer_sizes: Optional[List[int]] = None, readout_layer_sizes: Optional[List[int]] = None, n_head: int = 8, nb_neighbours: int = 8, @@ -57,9 +57,12 @@ def __init__( global_pooling_schemes: The list global pooling schemes to use. Options are: "min", "max", "mean", and "sum". use_global_features: Whether to use global features after pooling. - use_post_processing_layers: Whether to use post-processing layers after the `DynTrans` layers. - post_processing_layer_sizes: The layer sizes used in the post-processing layers. Defaults to [336, 256]. - readout_layer_sizes: The layer sizes used in the readout layers. Defaults to [256, 128]. + use_post_processing_layers: Whether to use post-processing layers + after the `DynTrans` layers. + post_processing_layer_sizes: The layer sizes used in the + post-processing layers. Defaults to [336, 256]. + readout_layer_sizes: The layer sizes used in the readout layers. + Defaults to [256, 128]. n_head: The number of heads to use in the `DynTrans` layer. nb_neighbours: The number of neighbours to use in the `DynTrans` layer. diff --git a/src/graphnet/models/gnn/icemix.py b/src/graphnet/models/gnn/icemix.py index a073e3ca8..6e7ebe647 100644 --- a/src/graphnet/models/gnn/icemix.py +++ b/src/graphnet/models/gnn/icemix.py @@ -1,15 +1,14 @@ -"""Implementation of IceMix architecture used in. - - IceCube - Neutrinos in Deep Ice -Reconstruct the direction of neutrinos from the Universe to the South Pole +"""Implementation of IceMix. +This method was a solution submitted to the IceCube - Neutrinos in Deep Ice Kaggle competition. 
Solution by DrHB: https://github.com/DrHB/icecube-2nd-place """ + import torch import torch.nn as nn -from typing import Set, Dict, Any +from typing import Set, Dict, Any, Optional from graphnet.models.components.layers import ( Block_rel, @@ -42,7 +41,7 @@ def __init__( n_rel: int = 1, scaled_emb: bool = False, include_dynedge: bool = False, - dynedge_args: Dict[str, Any] = None, + dynedge_args: Optional[Dict[str, Any]] = None, n_features: int = 6, ): """Construct `DeepIce`. diff --git a/src/graphnet/models/gnn/particlenet.py b/src/graphnet/models/gnn/particlenet.py index cf2d00998..02d060372 100644 --- a/src/graphnet/models/gnn/particlenet.py +++ b/src/graphnet/models/gnn/particlenet.py @@ -1,5 +1,6 @@ """Implementation of the ParticleNet GNN model architecture.""" -from typing import List, Optional, Callable, Tuple, Union + +from typing import List, Optional, Tuple, Union import torch from torch import Tensor, LongTensor diff --git a/src/graphnet/models/graphs/__init__.py b/src/graphnet/models/graphs/__init__.py index a07d1308d..e5db7d735 100644 --- a/src/graphnet/models/graphs/__init__.py +++ b/src/graphnet/models/graphs/__init__.py @@ -5,6 +5,5 @@ and their features. """ - from .graph_definition import GraphDefinition from .graphs import KNNGraph, EdgelessGraph diff --git a/src/graphnet/models/graphs/edges/__init__.py b/src/graphnet/models/graphs/edges/__init__.py index 40c8bbeab..ad11df041 100644 --- a/src/graphnet/models/graphs/edges/__init__.py +++ b/src/graphnet/models/graphs/edges/__init__.py @@ -4,5 +4,6 @@ graph-manipulation.´EdgeDefinition´ defines how edges are drawn between nodes and their features. """ + from .edges import EdgeDefinition, KNNEdges, RadialEdges, EuclideanEdges from .minkowski import MinkowskiKNNEdges diff --git a/src/graphnet/models/graphs/edges/edges.py b/src/graphnet/models/graphs/edges/edges.py index 584cf7ad9..d953ae0a0 100644 --- a/src/graphnet/models/graphs/edges/edges.py +++ b/src/graphnet/models/graphs/edges/edges.py @@ -34,7 +34,9 @@ def forward(self, graph: Data) -> Data: @abstractmethod def _construct_edges(self, graph: Data) -> Data: - """Construct edges and assign them to graph. I.e. ´graph.edge_index = edge_index´. + """Construct edges and assign them to the graph. + + I.e. ´graph.edge_index = edge_index´. 
Args: graph: graph without edges @@ -127,16 +129,12 @@ def __init__( self, sigma: float, threshold: float = 0.0, - columns: List[int] = None, + columns: List[int] = [0, 1, 2], ): """Construct `EuclideanEdges`.""" # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) - # Check(s) - if columns is None: - columns = [0, 1, 2] - # Member variable(s) self._sigma = sigma self._threshold = threshold @@ -161,9 +159,7 @@ def _construct_edges(self, graph: Data) -> Data: ) distance_matrix = calculate_distance_matrix(xyz_coords) - affinity_matrix = torch.exp( - -0.5 * distance_matrix**2 / self._sigma**2 - ) + affinity_matrix = torch.exp(-0.5 * distance_matrix**2 / self._sigma**2) # Use softmax to normalise all adjacencies to one for each node exp_row_sums = torch.exp(affinity_matrix).sum(axis=1) diff --git a/src/graphnet/models/graphs/edges/minkowski.py b/src/graphnet/models/graphs/edges/minkowski.py index 2526de1cb..9e454c4fd 100644 --- a/src/graphnet/models/graphs/edges/minkowski.py +++ b/src/graphnet/models/graphs/edges/minkowski.py @@ -1,4 +1,5 @@ """Module containing EdgeDefinitions based on the Minkowski Metric.""" + from typing import Optional, List import torch diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index 0338225b8..67c5065a2 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -5,7 +5,6 @@ passed to dataloaders during training and deployment. """ - from typing import Any, List, Optional, Dict, Callable, Union import torch from torch_geometric.data import Data @@ -24,7 +23,7 @@ class GraphDefinition(Model): def __init__( self, detector: Detector, - node_definition: NodeDefinition = None, + node_definition: Optional[NodeDefinition] = None, edge_definition: Optional[EdgeDefinition] = None, input_feature_names: Optional[List[str]] = None, dtype: Optional[torch.dtype] = torch.float, @@ -33,7 +32,7 @@ def __init__( add_inactive_sensors: bool = False, sensor_mask: Optional[List[int]] = None, string_mask: Optional[List[int]] = None, - sort_by: str = None, + sort_by: Optional[str] = None, repeat_labels: bool = False, ): """Construct ´GraphDefinition´. The ´detector´ holds. @@ -97,7 +96,9 @@ def __init__( if input_feature_names is None: # Assume all features in Detector is used. - input_feature_names = list(self._detector.feature_map().keys()) # type: ignore + input_feature_names = list( + self._detector.feature_map().keys() + ) # noqa: E501 # type: ignore self._input_feature_names = input_feature_names # Set input data column names for node definition @@ -110,10 +111,13 @@ def __init__( if sort_by is not None: assert isinstance(sort_by, str) try: - sort_by = self.output_feature_names.index(sort_by) # type: ignore + sort_by = self.output_feature_names.index( # type: ignore + sort_by + ) # type: ignore except ValueError as e: self.error( - f"{sort_by} not in node features {self.output_feature_names}." + f"{sort_by} not in node " + f"features {self.output_feature_names}." ) raise e self._sort_by = sort_by @@ -159,10 +163,11 @@ def forward( # type: ignore """Construct graph as ´Data´ object. Args: - input_features: Input features for graph construction. Shape ´[num_rows, d]´ + input_features: Input features for graph construction. + Shape ´[num_rows, d]´ input_feature_names: name of each column. Shape ´[,d]´. truth_dicts: Dictionary containing truth labels. - custom_label_functions: Custom label functions. 
See https://github.com/graphnet-team/graphnet/blob/main/GETTING_STARTED.md#adding-custom-truth-labels.
+            custom_label_functions: Custom label functions.
             loss_weight_column: Name of column that holds loss weight.
                 Defaults to None.
             loss_weight: Loss weight associated with event. Defaults to None.
@@ -253,7 +258,7 @@ def _resolve_masks(self) -> None:
             if self._string_mask is not None:
                 assert (
                     1 == 2
-                ), """Got arguments for both `sensor_mask`and `string_mask`. Please specify only one. """
+                ), "Please specify only one of `sensor_mask` and `string_mask`."

         if (self._sensor_mask is None) & (self._string_mask is not None):
             self._sensor_mask = self._convert_string_to_sensor_mask()
diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py
index c0f0cf2db..e642ed06c 100644
--- a/src/graphnet/models/graphs/graphs.py
+++ b/src/graphnet/models/graphs/graphs.py
@@ -16,7 +16,7 @@ class KNNGraph(GraphDefinition):
     def __init__(
         self,
         detector: Detector,
-        node_definition: NodeDefinition = None,
+        node_definition: Optional[NodeDefinition] = None,
         input_feature_names: Optional[List[str]] = None,
         dtype: Optional[torch.dtype] = torch.float,
         perturbation_dict: Optional[Dict[str, float]] = None,
@@ -37,10 +37,11 @@ def __init__(
                 feature should be randomly perturbed. Defaults to None.
             seed: seed or Generator used to randomly sample perturbations.
-                Defaults to None.
-            nb_nearest_neighbours: Number of edges for each node. Defaults to 8.
-            columns: node feature columns used for distance calculation
-                . Defaults to [0, 1, 2].
+                Defaults to None.
+            nb_nearest_neighbours: Number of edges for each node.
+                Defaults to 8.
+            columns: node feature columns used for distance calculation.
+                Defaults to [0, 1, 2].
         """
 
         # Base class constructor
         super().__init__(
@@ -67,7 +68,7 @@ class EdgelessGraph(GraphDefinition):
     def __init__(
         self,
         detector: Detector,
-        node_definition: NodeDefinition = None,
+        node_definition: Optional[NodeDefinition] = None,
         input_feature_names: Optional[List[str]] = None,
         dtype: Optional[torch.dtype] = torch.float,
         perturbation_dict: Optional[Dict[str, float]] = None,
diff --git a/src/graphnet/models/graphs/nodes/nodes.py b/src/graphnet/models/graphs/nodes/nodes.py
index 4e094e6be..558ec96f4 100644
--- a/src/graphnet/models/graphs/nodes/nodes.py
+++ b/src/graphnet/models/graphs/nodes/nodes.py
@@ -241,7 +241,8 @@ def __init__(
             id_columns: List of columns that uniquely identify a DOM.
             time_column: Name of time column.
             charge_column: Name of charge column.
-            max_activations: Maximum number of activations to include in the time series.
+            max_activations: Maximum number of activations to include in
+                the time series.
         """
         self._keys = keys
         super().__init__(input_feature_names=self._keys)
@@ -251,9 +252,8 @@ def __init__(
             self._charge_index: Optional[int] = self._keys.index(charge_column)
         except ValueError:
             self.warning(
-                "Charge column with name {} not found. Running without.".format(
-                    charge_column
-                )
+                f"Charge column with name {charge_column} not found. "
+                "Running without."
) self._charge_index = None @@ -271,7 +271,8 @@ def _construct_nodes(self, x: torch.Tensor) -> Data: x = x.numpy() if x.shape[0] == 0: return Data(x=torch.tensor(np.column_stack([x, []]))) - # if there is no charge column add a dummy column of zeros with the same shape as the time column + # if there is no charge column add a dummy column + # of zeros with the same shape as the time column if self._charge_index is None: charge_index: int = len(self._keys) x = np.insert(x, charge_index, np.zeros(x.shape[0]), axis=1) diff --git a/src/graphnet/models/graphs/utils.py b/src/graphnet/models/graphs/utils.py index ea8445f90..77669eaeb 100644 --- a/src/graphnet/models/graphs/utils.py +++ b/src/graphnet/models/graphs/utils.py @@ -1,6 +1,6 @@ """Utility functions for construction of graphs.""" -from typing import List, Tuple +from typing import List, Tuple, Optional import os import numpy as np import pandas as pd @@ -173,7 +173,7 @@ def cluster_summarize_with_percentiles( def ice_transparency( - z_offset: float = None, z_scaling: float = None + z_offset: Optional[float] = None, z_scaling: Optional[float] = None ) -> Tuple[interp1d, interp1d]: """Return interpolation functions for optical properties of IceCube. @@ -200,9 +200,9 @@ def ice_transparency( z_scaling = z_scaling or 500.0 df["z_norm"] = (df["depth"] + z_offset) / z_scaling - df[ - ["scattering_len_norm", "absorption_len_norm"] - ] = RobustScaler().fit_transform(df[["scattering_len", "absorption_len"]]) + df[["scattering_len_norm", "absorption_len_norm"]] = ( + RobustScaler().fit_transform(df[["scattering_len", "absorption_len"]]) + ) f_scattering = interp1d(df["z_norm"], df["scattering_len_norm"]) f_absorption = interp1d(df["z_norm"], df["absorption_len_norm"]) diff --git a/src/graphnet/models/model.py b/src/graphnet/models/model.py index 7c51c7952..650da29a9 100644 --- a/src/graphnet/models/model.py +++ b/src/graphnet/models/model.py @@ -50,7 +50,8 @@ def save_state_dict(self, path: str) -> None: """Save model `state_dict` to `path`.""" if not path.endswith(".pth"): self.info( - "It is recommended to use the .pth suffix for state_dict files." + "It is recommended to use the .pth suffix " + "for state_dict files." 
)
         state_dict = self.state_dict()
         for key, value in state_dict.items():
@@ -74,7 +75,8 @@ def load_state_dict(  # type: ignore[override]
             )
         if state_dict_altered:
             self.warning(
-                "DeprecationWarning: State dicts with `_gnn` entries will be deprecated in GraphNeT 2.0"
+                "DeprecationWarning: State dicts with `_gnn`"
+                " entries will be deprecated in GraphNeT 2.0"
             )
         return super().load_state_dict(state_dict, **kargs)
 
diff --git a/src/graphnet/models/normalizing_flow.py b/src/graphnet/models/normalizing_flow.py
index d62cf7c42..403208790 100644
--- a/src/graphnet/models/normalizing_flow.py
+++ b/src/graphnet/models/normalizing_flow.py
@@ -1,6 +1,6 @@
 """Standard model class(es)."""
 
-from typing import Any, Dict, List, Optional, Union, Type
+from typing import Dict, List, Optional, Union, Type
 import torch
 from torch import Tensor
 from torch_geometric.data import Data
@@ -26,7 +26,7 @@ def __init__(
         self,
         graph_definition: GraphDefinition,
         target_labels: str,
-        backbone: GNN = None,
+        backbone: Optional[GNN] = None,
         condition_on: Union[str, List[str], None] = None,
         flow_layers: str = "gggt",
         optimizer_class: Type[torch.optim.Optimizer] = Adam,
diff --git a/src/graphnet/models/rnn/node_rnn.py b/src/graphnet/models/rnn/node_rnn.py
index 45f7d643d..ebdf52041 100644
--- a/src/graphnet/models/rnn/node_rnn.py
+++ b/src/graphnet/models/rnn/node_rnn.py
@@ -2,6 +2,7 @@
 
 (cannot be used as a standalone model)
 """
+
 import torch
 
 from graphnet.models.gnn.gnn import GNN
@@ -42,13 +43,20 @@ def __init__(
 
         Args:
             nb_inputs: Number of features in the input data.
-            hidden_size: Number of features for the RNN output and hidden layers.
+            hidden_size: Number of features for the RNN output and hidden
+                layers.
             num_layers: Number of layers in the RNN.
-            time_series_columns: The indices of the input data that should be treated as time series data. The first index should be the charge column.
-            nb_neighbours: Number of neighbours to use when reconstructing the graph representation. Defaults to 8.
-            features_subset: The subset of latent features on each node that are used as metric dimensions when performing the k-nearest neighbours clustering. Defaults to [0,1,2,3]
+            time_series_columns: The indices of the input data that should be
+                treated as time series data. The first index should be
+                the charge column.
+            nb_neighbours: Number of neighbours to use when reconstructing the
+                graph representation. Defaults to 8.
+            features_subset: The subset of latent features on each node that
+                are used as metric dimensions when performing the k-nearest
+                neighbours clustering. Defaults to [0,1,2,3]
             dropout: Dropout fraction to use in the RNN. Defaults to 0.5.
-            embedding_dim: Embedding dimension of the RNN. Defaults to no embedding.
+            embedding_dim: Embedding dimension of the RNN.
+                Defaults to no embedding.
""" self._hidden_size = hidden_size self._num_layers = num_layers diff --git a/src/graphnet/models/standard_averaged_model.py b/src/graphnet/models/standard_averaged_model.py index 362e5c2b2..5d4b7cc43 100644 --- a/src/graphnet/models/standard_averaged_model.py +++ b/src/graphnet/models/standard_averaged_model.py @@ -1,4 +1,5 @@ """Averaged Standard model class(es).""" + from typing import Any, Callable, Dict, List, Optional, Union, Type from collections import OrderedDict diff --git a/src/graphnet/models/standard_model.py b/src/graphnet/models/standard_model.py index bef153097..1d77cfa33 100644 --- a/src/graphnet/models/standard_model.py +++ b/src/graphnet/models/standard_model.py @@ -1,6 +1,6 @@ """Standard model class(es).""" -from typing import Any, Dict, List, Optional, Union, Type +from typing import Dict, List, Optional, Union, Type import torch from torch import Tensor from torch_geometric.data import Data @@ -26,7 +26,7 @@ def __init__( self, graph_definition: GraphDefinition, tasks: Union[StandardLearnedTask, List[StandardLearnedTask]], - backbone: Model = None, + backbone: Optional[Model] = None, gnn: Optional[GNN] = None, optimizer_class: Type[torch.optim.Optimizer] = Adam, optimizer_kwargs: Optional[Dict] = None, diff --git a/src/graphnet/models/task/reconstruction.py b/src/graphnet/models/task/reconstruction.py index e9b3cdaa5..5408aa5b9 100644 --- a/src/graphnet/models/task/reconstruction.py +++ b/src/graphnet/models/task/reconstruction.py @@ -50,9 +50,8 @@ class DirectionReconstructionWithKappa(StandardLearnedTask): """Reconstructs direction with kappa from the 3D-vMF distribution.""" # Requires three features: untransformed points in (x,y,z)-space. - default_target_labels = [ - "direction" - ] # contains dir_x, dir_y, dir_z see https://github.com/graphnet-team/graphnet/blob/95309556cfd46a4046bc4bd7609888aab649e295/src/graphnet/training/labels.py#L29 + default_target_labels = ["direction"] # contains dir_x, dir_y, dir_z + # see Direction label in /src/graphnet/training/labels.py default_prediction_labels = [ "dir_x_pred", "dir_y_pred", @@ -86,7 +85,8 @@ def _forward(self, x: Tensor) -> Tensor: class ZenithReconstructionWithKappa(ZenithReconstruction): """Reconstructs zenith angle and associated kappa (1/var).""" - # Requires one feature in addition to `ZenithReconstruction`: kappa (unceratinty; 1/variance). + # Requires one feature in addition to `ZenithReconstruction`: + # kappa (unceratinty; 1/variance). default_target_labels = ["zenith"] default_prediction_labels = ["zenith_pred", "zenith_kappa"] nb_inputs = 2 @@ -148,7 +148,8 @@ def _forward(self, x: Tensor) -> Tensor: class EnergyReconstructionWithUncertainty(EnergyReconstruction): """Reconstructs energy and associated uncertainty (log(var)).""" - # Requires one feature in addition to `EnergyReconstruction`: log-variance (uncertainty). + # Requires one feature in addition to `EnergyReconstruction`: + # log-variance (uncertainty). 
default_target_labels = ["energy"] default_prediction_labels = ["energy_pred", "energy_sigma"] nb_inputs = 2 diff --git a/src/graphnet/models/task/task.py b/src/graphnet/models/task/task.py index 0b9101107..604fc601b 100644 --- a/src/graphnet/models/task/task.py +++ b/src/graphnet/models/task/task.py @@ -13,7 +13,9 @@ if TYPE_CHECKING: # Avoid cyclic dependency - from graphnet.training.loss_functions import LossFunction # type: ignore[attr-defined] + from graphnet.training.loss_functions import ( + LossFunction, + ) # noqa: E501 # type: ignore[attr-defined] from graphnet.models import Model from graphnet.utilities.decorators import final @@ -108,12 +110,12 @@ def __init__( self._inference = False self._loss_weight = loss_weight - self._transform_prediction_training: Callable[ - [Tensor], Tensor - ] = lambda x: x - self._transform_prediction_inference: Callable[ - [Tensor], Tensor - ] = lambda x: x + self._transform_prediction_training: Callable[[Tensor], Tensor] = ( + lambda x: x + ) + self._transform_prediction_inference: Callable[[Tensor], Tensor] = ( + lambda x: x + ) self._transform_target: Callable[[Tensor], Tensor] = lambda x: x self._validate_and_set_transforms( transform_prediction_and_target, @@ -158,10 +160,10 @@ def _validate_and_set_transforms( assert not ( (transform_prediction_and_target is not None) and (transform_target is not None) - ), "Please specify at most one of `transform_prediction_and_target` and `transform_target`" + ), "Please specify at most one of `transform_prediction_and_target` and `transform_target`" # noqa: E501 if (transform_target is not None) != (transform_inference is not None): self.warning( - "Setting one of `transform_target` and `transform_inference`, but not " + "Setting one of `transform_target` and `transform_inference`, but not " # noqa: E501 "the other." ) @@ -434,7 +436,9 @@ def nb_inputs(self) -> Union[int, None]: # type: ignore """Return number of conditional inputs assumed by task.""" return self._hidden_size - def _forward(self, x: Optional[Tensor], y: Tensor) -> Tensor: # type: ignore + def _forward( + self, x: Optional[Tensor], y: Tensor + ) -> Tensor: # noqa: E501 # type: ignore y = y / self._norm if x is not None: if x.shape[0] != y.shape[0]: diff --git a/src/graphnet/training/callbacks.py b/src/graphnet/training/callbacks.py index c319dccc1..c8eb1f0b3 100644 --- a/src/graphnet/training/callbacks.py +++ b/src/graphnet/training/callbacks.py @@ -6,7 +6,6 @@ import warnings import numpy as np -import torch from tqdm.std import Bar from pytorch_lightning import LightningModule, Trainer @@ -55,7 +54,8 @@ def __init__( raise ValueError("Milestones must be increasing") if len(milestones) != len(factors): raise ValueError( - "Only multiplicative factor must be specified for each milestone." + "Only multiplicative factor must be specified" + " for each milestone." ) self.milestones = milestones diff --git a/src/graphnet/training/loss_functions.py b/src/graphnet/training/loss_functions.py index d3fc43f7e..6d7c27c06 100644 --- a/src/graphnet/training/loss_functions.py +++ b/src/graphnet/training/loss_functions.py @@ -14,7 +14,6 @@ from torch import nn from torch.nn.functional import ( one_hot, - cross_entropy, binary_cross_entropy, softplus, ) @@ -101,7 +100,7 @@ def _log_cosh(cls, x: Tensor) -> Tensor: # pylint: disable=invalid-name """Numerically stable version on log(cosh(x)). Used to avoid `inf` for even moderately large differences. 
- See [https://github.com/keras-team/keras/blob/v2.6.0/keras/losses.py#L1580-L1617] + See [https://github.com/keras-team/keras/blob/v2.6.0/keras/losses.py#L1580-L1617] # noqa: E501 """ return x + softplus(-2.0 * x) - np.log(2.0) @@ -213,30 +212,31 @@ class LogCMK(torch.autograd.Function): Copyright (c) 2019 Max Ryabinin - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: + Permission is hereby granted, free of charge, to any person obtaining a + copy of this software and associated documentation files (the "Software"), + to deal in the Software without restriction, including without limitation + the rights to use, copy, modify, merge, publish, distribute, sublicense, + and/or sell copies of the Software, and to permit persons to whom the + Software is furnished to do so, subject to the following conditions: - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE. + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + DEALINGS IN THE SOFTWARE. _____________________ - From [https://github.com/mryab/vmf_loss/blob/master/losses.py] - Modified to use modified Bessel function instead of exponentially scaled ditto - (i.e. `.ive` -> `.iv`) as indiciated in [1812.04616] in spite of suggestion in - Sec. 8.2 of this paper. The change has been validated through comparison with - exact calculations for `m=2` and `m=3` and found to yield the correct results. + From [https://github.com/mryab/vmf_loss/blob/master/losses.py] Modified to + use modified Bessel function instead of exponentially scaled ditto + (i.e. `.ive` -> `.iv`) as indicated in [1812.04616] in spite of suggestion + in Sec. 8.2 of this paper. The change has been validated through comparison + with exact calculations for `m=2` and `m=3` and found to yield the correct + results. """ @staticmethod @@ -358,7 +358,7 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: class VonMisesFisher2DLoss(VonMisesFisherLoss): - """von Mises-Fisher loss function vectors in the 2D plane.""" + """Von Mises-Fisher loss function vectors in the 2D plane.""" def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: """Calculate von Mises-Fisher loss for an angle in the 2D plane. 
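A note on the `_log_cosh` helper whose docstring link is reflowed above: it relies on the identity log(cosh(x)) = x + softplus(-2x) - log(2), which stays finite even where cosh(x) itself overflows. The following is a minimal standalone sketch, illustrative only and not part of this patch; the function names are hypothetical and only `torch` plus the standard library are assumed.

    import math

    import torch
    from torch.nn.functional import softplus

    def log_cosh_naive(x: torch.Tensor) -> torch.Tensor:
        # Overflows to `inf` once cosh(x) exceeds the float32 range (roughly x > 89).
        return torch.log(torch.cosh(x))

    def log_cosh_stable(x: torch.Tensor) -> torch.Tensor:
        # Same identity as the `_log_cosh` classmethod shown in the hunk above.
        return x + softplus(-2.0 * x) - math.log(2.0)

    x = torch.tensor([0.5, 10.0, 200.0])
    print(log_cosh_naive(x))   # approx. [0.1201, 9.3069, inf]
    print(log_cosh_stable(x))  # approx. [0.1201, 9.3069, 199.3069]

Both forms agree where the naive expression is valid; only the rewritten form remains finite for large arguments, which is the behaviour the `test_log_cosh` unit test touched at the end of this patch asserts.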
@@ -422,7 +422,7 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: class VonMisesFisher3DLoss(VonMisesFisherLoss): - """von Mises-Fisher loss function vectors in the 3D plane.""" + """Von Mises-Fisher loss function vectors in the 3D plane.""" def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: """Calculate von Mises-Fisher loss for a direction in the 3D. @@ -453,7 +453,7 @@ class EnsembleLoss(LossFunction): def __init__( self, loss_functions: List[LossFunction], - loss_factors: List[float] = None, + loss_factors: Optional[List[float]] = None, prediction_keys: Optional[List[List[int]]] = None, *args: Any, **kwargs: Any, diff --git a/src/graphnet/training/utils.py b/src/graphnet/training/utils.py index 4e6f7c64d..b6e705846 100644 --- a/src/graphnet/training/utils.py +++ b/src/graphnet/training/utils.py @@ -79,10 +79,10 @@ def make_dataloader( selection: Optional[List[int]] = None, num_workers: int = 10, persistent_workers: bool = True, - node_truth: List[str] = None, + node_truth: Optional[List[str]] = None, truth_table: str = "truth", node_truth_table: Optional[str] = None, - string_selection: List[int] = None, + string_selection: Optional[List[int]] = None, loss_weight_table: Optional[str] = None, loss_weight_column: Optional[str] = None, index_column: str = "event_no", diff --git a/src/graphnet/training/weight_fitting.py b/src/graphnet/training/weight_fitting.py index e66c2d4c5..9e4b66040 100644 --- a/src/graphnet/training/weight_fitting.py +++ b/src/graphnet/training/weight_fitting.py @@ -39,9 +39,9 @@ def _get_truth( ) -> pd.DataFrame: """Return truth `variable`, optionally only for `selection` events.""" if selection is None: - query = f"select {self._index_column}, {variable} from {self._truth_table}" + query = f"select {self._index_column}, {variable} from {self._truth_table}" # noqa: E501 else: - query = f"select {self._index_column}, {variable} from {self._truth_table} where {self._index_column} in {str(tuple(selection))}" + query = f"select {self._index_column}, {variable} from {self._truth_table} where {self._index_column} in {str(tuple(selection))}" # noqa: E501 with sqlite3.connect(self._database_path) as con: data = pd.read_sql(query, con) return data @@ -160,10 +160,12 @@ def _fit_weights(self, truth: pd.DataFrame) -> pd.DataFrame: # Histogram `truth_values` bin_counts, _ = np.histogram(truth[self._variable], bins=self._bins) - # Get reweighting for each bin to achieve uniformity. (NB: No normalisation applied.) + # Get reweighting for each bin to achieve uniformity. + # (NB: No normalisation applied.) bin_weights = 1.0 / np.where(bin_counts == 0, np.nan, bin_counts) - # For each sample in `truth_values`, get the weight in the corresponding bin + # For each sample in `truth_values`, get the weight in + # the corresponding bin ix = np.digitize(truth[self._variable], bins=self._bins) - 1 sample_weights = bin_weights[ix] sample_weights = sample_weights / sample_weights.mean() @@ -207,10 +209,12 @@ def _fit_weights( # type: ignore[override] # Histogram `truth_values` bin_counts, _ = np.histogram(truth[self._variable], bins=self._bins) - # Get reweighting for each bin to achieve uniformity. (NB: No normalisation applied.) + # Get reweighting for each bin to achieve uniformity. + # (NB: No normalisation applied.) 
bin_weights = 1.0 / np.where(bin_counts == 0, np.nan, bin_counts) - # For each sample in `truth_values`, get the weight in the corresponding bin + # For each sample in `truth_values`, + # get the weight in the corresponding bin ix = np.digitize(truth[self._variable], bins=self._bins) - 1 sample_weights = bin_weights[ix] sample_weights = sample_weights / sample_weights.mean() diff --git a/src/graphnet/utilities/config/base_config.py b/src/graphnet/utilities/config/base_config.py index 6867317e1..0bc826ad7 100644 --- a/src/graphnet/utilities/config/base_config.py +++ b/src/graphnet/utilities/config/base_config.py @@ -1,6 +1,5 @@ """Base config class(es).""" -from abc import abstractmethod from collections import OrderedDict import inspect import sys diff --git a/src/graphnet/utilities/config/configurable.py b/src/graphnet/utilities/config/configurable.py index 3a0976123..e11c0e695 100644 --- a/src/graphnet/utilities/config/configurable.py +++ b/src/graphnet/utilities/config/configurable.py @@ -1,6 +1,6 @@ """Bases for all configurable classes in `graphnet`.""" -from abc import ABC, abstractclassmethod +from abc import ABC, abstractmethod from typing import Any, Union from graphnet.utilities.config.base_config import BaseConfig @@ -34,6 +34,6 @@ def save_config(self, path: str) -> None: """Save Config to `path` as YAML file.""" self.config.dump(path) - @abstractclassmethod + @abstractmethod # type: ignore def from_config(cls, source: Union[BaseConfig, str]) -> Any: """Construct instance from `source` configuration.""" diff --git a/src/graphnet/utilities/config/dataset_config.py b/src/graphnet/utilities/config/dataset_config.py index 42f68dff1..79e03c184 100644 --- a/src/graphnet/utilities/config/dataset_config.py +++ b/src/graphnet/utilities/config/dataset_config.py @@ -1,4 +1,5 @@ """Config classes for the `graphnet.data.dataset` module.""" + import warnings from abc import ABCMeta from functools import wraps diff --git a/src/graphnet/utilities/config/model_config.py b/src/graphnet/utilities/config/model_config.py index 18b06def9..5f32f559a 100644 --- a/src/graphnet/utilities/config/model_config.py +++ b/src/graphnet/utilities/config/model_config.py @@ -1,4 +1,5 @@ """Config classes for the `graphnet.models` module.""" + from abc import ABCMeta from functools import wraps import inspect @@ -68,9 +69,9 @@ def __init__(self, **data: Any) -> None: value = data["arguments"][arg] if isinstance(value, (tuple, list)): for ix, elem in enumerate(value): - data["arguments"][arg][ - ix - ] = self._parse_if_model_config_entry(elem) + data["arguments"][arg][ix] = ( + self._parse_if_model_config_entry(elem) + ) else: data["arguments"][arg] = self._parse_if_model_config_entry( value diff --git a/src/graphnet/utilities/deprecation_tools.py b/src/graphnet/utilities/deprecation_tools.py index 3ba051aba..778a67236 100644 --- a/src/graphnet/utilities/deprecation_tools.py +++ b/src/graphnet/utilities/deprecation_tools.py @@ -1,4 +1,5 @@ """Utility functions for handling deprecation transitions.""" + from typing import Dict, Tuple from copy import deepcopy from torch import Tensor diff --git a/src/graphnet/utilities/imports.py b/src/graphnet/utilities/imports.py index 1c143280a..5ce5e45d0 100644 --- a/src/graphnet/utilities/imports.py +++ b/src/graphnet/utilities/imports.py @@ -56,7 +56,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return test_function(*args, **kwargs) else: Logger(log_folder=None).info( - f"Function `{test_function.__name__}` not used since `icecube` isn't available." 
+ f"Function `{test_function.__name__}` " + "not used since `icecube` isn't available." ) return diff --git a/tests/data/test_dataconverters_and_datasets.py b/tests/data/test_dataconverters_and_datasets.py index f370dbe17..229fc91df 100644 --- a/tests/data/test_dataconverters_and_datasets.py +++ b/tests/data/test_dataconverters_and_datasets.py @@ -21,15 +21,10 @@ from graphnet.data.dataset import ParquetDataset, SQLiteDataset from graphnet.data.sqlite import SQLiteDataConverter from graphnet.data.utilities.parquet_to_sqlite import ParquetToSQLiteConverter -from graphnet.utilities.imports import has_icecube_package from graphnet.models.graphs import KNNGraph from graphnet.models.graphs.nodes import NodesAsPulses from graphnet.models.detector import IceCubeDeepCore -if has_icecube_package(): - from icecube import dataio # pyright: reportMissingImports=false - - # Global variable(s) TEST_DATA_DIR = os.path.join( graphnet.constants.TEST_DATA_DIR, "i3", "oscNext_genie_level7_v02" @@ -199,7 +194,9 @@ def test_parquet_to_sqlite_converter() -> None: ) dataset_from_parquet = SQLiteDataset(path, **opt) # type: ignore[arg-type] - dataset = SQLiteDataset(get_file_path("sqlite"), **opt) # type: ignore[arg-type] + dataset = SQLiteDataset( + get_file_path("sqlite"), **opt # type: ignore[arg-type] + ) assert len(dataset_from_parquet) == len(dataset) for ix in range(len(dataset)): diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index a3e0a0921..dd2e03abf 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -64,7 +64,8 @@ def dataset_setup(dataset_ref: pytest.FixtureRequest) -> tuple: dataset_ref: The dataset reference. Returns: - A tuple with the dataset reference, dataset kwargs, and dataloader kwargs. + A tuple with the dataset reference, + dataset kwargs, and dataloader kwargs. """ # Grab public dataset paths data_path = ( @@ -127,10 +128,12 @@ def test_single_dataset_without_selections( """Verify GraphNeTDataModule behavior when no test selection is provided. Args: - dataset_setup: Tuple with dataset reference, dataset arguments, and dataloader arguments. + dataset_setup: Tuple with dataset reference, + dataset arguments, and dataloader arguments. Raises: - Exception: If the test dataloader is accessed without providing a test selection. + Exception: If the test dataloader is accessed + without providing a test selection. """ dataset_ref, dataset_kwargs, dataloader_kwargs = dataset_setup @@ -165,8 +168,9 @@ def test_single_dataset_with_selections( """Test that selection functionality of DataModule behaves as expected. Args: - dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple containing the dataset reference, - dataset arguments, and dataloader arguments. + dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple + containing the dataset reference, dataset arguments, + and dataloader arguments. 
Returns: None @@ -200,7 +204,9 @@ def test_single_dataset_with_selections( if isinstance(dataset_ref, SQLiteDataset): a = len(train_dataloader.dataset) + len(val_dataloader.dataset) assert a == len(train_val_selection) # type: ignore - assert len(test_dataloader.dataset) == len(test_selection) # type: ignore + assert len(test_dataloader.dataset) == len( + test_selection + ) # noqa: E501 # type: ignore elif isinstance(dataset_ref, ParquetDataset): # Parquet dataset selection is batches not events a = train_dataloader.dataset._indices + val_dataloader.dataset._indices @@ -225,8 +231,9 @@ def test_dataloader_args( """Test that arguments to dataloaders are propagated correctly. Args: - dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple containing the dataset reference, - dataset keyword arguments, and dataloader keyword arguments. + dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple + containing the dataset reference, dataset keyword arguments, + and dataloader keyword arguments. Returns: None @@ -265,8 +272,9 @@ def test_ensemble_dataset_without_selections( """Test ensemble dataset functionality without selections. Args: - dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple containing the dataset reference, - dataset keyword arguments, and dataloader keyword arguments. + dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple + containing the dataset reference, dataset keyword arguments, + and dataloader keyword arguments. Returns: None @@ -303,8 +311,9 @@ def test_ensemble_dataset_with_selections( """Test ensemble dataset functionality with selections. Args: - dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple containing the dataset reference, - dataset keyword arguments, and dataloader keyword arguments. + dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple + containing the dataset reference, dataset keyword arguments, + and dataloader keyword arguments. Returns: None diff --git a/tests/deployment/queso_test.py b/tests/deployment/queso_test.py index cf761bf52..25bd2435c 100644 --- a/tests/deployment/queso_test.py +++ b/tests/deployment/queso_test.py @@ -165,7 +165,7 @@ def test_deployment() -> None: features = FEATURES.UPGRADE input_folders = [f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998"] output_folder = f"{TEST_DATA_DIR}/output/QUESO_test" - gcd_file = f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/GeoCalibDetectorStatus_ICUpgrade.v58.mixed.V0.i3.bz2" + gcd_file = f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/GeoCalibDetectorStatus_ICUpgrade.v58.mixed.V0.i3.bz2" # noqa: E501 features = FEATURES.UPGRADE input_files = [] for folder in input_folders: @@ -214,11 +214,12 @@ def verify_QUESO_integrity() -> None: original_predictions[frame][model], equal_nan=True, ) - except AssertionError as e: - print( - f"Mismatch found in {model}: {new_predictions[frame][model]} vs. {original_predictions[frame][model]}" + except AssertionError: + raise AssertionError( + f"Mismatch found in {model}: " + f"{new_predictions[frame][model]} vs. 
" + f"{original_predictions[frame][model]}" ) - raise e return diff --git a/tests/examples/01_icetray/test_icetray_examples.py b/tests/examples/01_icetray/test_icetray_examples.py index cbdf470fd..5c1942698 100644 --- a/tests/examples/01_icetray/test_icetray_examples.py +++ b/tests/examples/01_icetray/test_icetray_examples.py @@ -1,4 +1,5 @@ """Test for examples in 01_icetray.""" + import runpy import os import pytest diff --git a/tests/examples/02_data/test_data_examples.py b/tests/examples/02_data/test_data_examples.py index 8faaca5c0..ea30f57d5 100644 --- a/tests/examples/02_data/test_data_examples.py +++ b/tests/examples/02_data/test_data_examples.py @@ -1,4 +1,5 @@ """Tests for examples in 02_data.""" + import runpy import os import pytest diff --git a/tests/examples/03_weights/test_weights_examples.py b/tests/examples/03_weights/test_weights_examples.py index 5dddee264..029f6d9fe 100644 --- a/tests/examples/03_weights/test_weights_examples.py +++ b/tests/examples/03_weights/test_weights_examples.py @@ -1,4 +1,5 @@ """Test for examples in 03_weights.""" + import runpy import os from graphnet.constants import GRAPHNET_ROOT_DIR diff --git a/tests/examples/04_training/test_training_examples.py b/tests/examples/04_training/test_training_examples.py index e97b4bb3c..c6c33dc80 100644 --- a/tests/examples/04_training/test_training_examples.py +++ b/tests/examples/04_training/test_training_examples.py @@ -1,4 +1,5 @@ """Test for examples in 04_training.""" + import runpy import os from glob import glob diff --git a/tests/examples/05_liquido/test_liquido_examples.py b/tests/examples/05_liquido/test_liquido_examples.py index 9738ed861..8103d855a 100644 --- a/tests/examples/05_liquido/test_liquido_examples.py +++ b/tests/examples/05_liquido/test_liquido_examples.py @@ -1,4 +1,5 @@ """Tests for examples in 05_liquido.""" + import runpy import os import pytest diff --git a/tests/examples/06_prometheus/test_prometheus.py b/tests/examples/06_prometheus/test_prometheus.py index 3805edca7..147cfb6a6 100644 --- a/tests/examples/06_prometheus/test_prometheus.py +++ b/tests/examples/06_prometheus/test_prometheus.py @@ -1,4 +1,5 @@ """Tests for examples in 06_prometheus.""" + import runpy import os import pytest diff --git a/tests/models/test_graph_definition.py b/tests/models/test_graph_definition.py index 091cc0bac..f2538fdd7 100644 --- a/tests/models/test_graph_definition.py +++ b/tests/models/test_graph_definition.py @@ -72,7 +72,7 @@ def get_event( with sqlite3.connect(database) as con: query = f"SELECT event_no FROM {truth_table} limit 1" event_no = pd.read_sql(query, con) - query = f'SELECT {query_features} FROM {pulsemap} WHERE event_no = {str(event_no["event_no"][0])}' + query = f'SELECT {query_features} FROM {pulsemap} WHERE event_no = {str(event_no["event_no"][0])}' # noqa: E501 df = pd.read_sql(query, con) return np.array(df) @@ -86,11 +86,11 @@ def test_geometry_tables() -> None: ), "IceCube86": os.path.join( TEST_DATA_DIR, - "sqlite/oscNext_genie_level7_v02/oscNext_genie_level7_v02_first_5_frames.db", + "sqlite/oscNext_genie_level7_v02/oscNext_genie_level7_v02_first_5_frames.db", # noqa: E501 ), "IceCubeUpgrade": os.path.join( TEST_DATA_DIR, - "sqlite/upgrade_genie_step4_140028_000998_first_5_frames/upgrade_genie_step4_140028_000998_first_5_frames.db", + "sqlite/upgrade_genie_step4_140028_000998_first_5_frames/upgrade_genie_step4_140028_000998_first_5_frames.db", # noqa: E501 ), } meta = { diff --git a/tests/models/test_minkowski.py b/tests/models/test_minkowski.py index 98fed817c..9929f304d 
100644 --- a/tests/models/test_minkowski.py +++ b/tests/models/test_minkowski.py @@ -1,5 +1,5 @@ """Unit tests for minkowski based edges.""" -import pytest + import torch from torch_geometric.data.data import Data diff --git a/tests/models/test_node_definition.py b/tests/models/test_node_definition.py index 4c199abd6..b849fef94 100644 --- a/tests/models/test_node_definition.py +++ b/tests/models/test_node_definition.py @@ -1,4 +1,5 @@ """Unit tests for node definitions.""" + import numpy as np import pandas as pd import sqlite3 @@ -19,7 +20,7 @@ def test_percentile_cluster() -> None: with sqlite3.connect(database) as con: query = "select event_no from mc_truth limit 1" event_no = pd.read_sql(query, con) - query = f'select sensor_pos_x, sensor_pos_y, sensor_pos_z, t from total where event_no = {str(event_no["event_no"][0])}' + query = f'select sensor_pos_x, sensor_pos_y, sensor_pos_z, t from total where event_no = {str(event_no["event_no"][0])}' # noqa: E501 df = pd.read_sql(query, con) # Save original feature names, create variables. diff --git a/tests/training/test_dataloader_utilities.py b/tests/training/test_dataloader_utilities.py index 423b2f34b..d5b9d55df 100644 --- a/tests/training/test_dataloader_utilities.py +++ b/tests/training/test_dataloader_utilities.py @@ -29,7 +29,10 @@ # Unit test(s) def test_none_selection() -> None: """Test agreement of the two ways to calculate this loss.""" - (train_dataloader, test_dataloader,) = make_train_validation_dataloader( + ( + train_dataloader, + test_dataloader, + ) = make_train_validation_dataloader( db=TEST_SQLITE_DATA, graph_definition=graph_definition, selection=None, diff --git a/tests/training/test_loss_functions.py b/tests/training/test_loss_functions.py index 373a50964..6d978ed3b 100644 --- a/tests/training/test_loss_functions.py +++ b/tests/training/test_loss_functions.py @@ -27,7 +27,11 @@ def _compute_elementwise_gradient(outputs: Tensor, inputs: Tensor) -> Tensor: nb_elements = inputs.size(dim=0) elementwise_gradients = torch.stack( [ - grad(outputs=outputs[ix], inputs=inputs, retain_graph=True,)[ + grad( + outputs=outputs[ix], + inputs=inputs, + retain_graph=True, + )[ 0 ][ix] for ix in range(nb_elements) @@ -51,8 +55,8 @@ def test_log_cosh(dtype: torch.dtype = torch.float32) -> None: # (1) Loss functions should not return `inf` losses, even for large # differences between prediction and target. This is not necessarily - # true for the directly calculated loss (reference) where cosh(x) may go - # to `inf` for x >~ 100. + # true for the directly calculated loss (reference) where cosh(x) + # may go to `inf` for x >~ 100. assert torch.all(torch.isfinite(losses)) # (2) For the inputs where the reference loss _is_ valid, the two