diff --git a/docker/gnn-benchmarking/apply.py b/docker/gnn-benchmarking/apply.py index 87744b94a..909025bac 100644 --- a/docker/gnn-benchmarking/apply.py +++ b/docker/gnn-benchmarking/apply.py @@ -1,6 +1,5 @@ """Script for applying GraphNeTModule in IceTray chain.""" - import argparse from glob import glob from os import makedirs @@ -37,9 +36,7 @@ def main( # Get GCD file gcd_pattern = "GeoCalibDetector" gcd_candidates = [p for p in input_files if gcd_pattern in p] - assert ( - len(gcd_candidates) == 1 - ), f"Did not get exactly one GCD-file candidate in `{dirname(input_files[0])}: {gcd_candidates}" + assert len(gcd_candidates) == 1, "Did not get exactly one GCD-file candidate" gcd_file = gcd_candidates[0] # Get all input I3-files @@ -78,8 +75,10 @@ def main( """The main function must get an input folder and output folder! Args: - input_folder (str): The input folder where i3 files of a given dataset are located. - output_folder (str): The output folder where processed i3 files will be saved. + input_folder (str): The input folder where i3 files of a + given dataset are located. + output_folder (str): The output folder where processed i3 + files will be saved. """ parser = argparse.ArgumentParser() diff --git a/examples/04_training/01_train_dynedge.py b/examples/04_training/01_train_dynedge.py index c61df789b..6ee6e0223 100644 --- a/examples/04_training/01_train_dynedge.py +++ b/examples/04_training/01_train_dynedge.py @@ -70,9 +70,9 @@ def main( "gpus": gpus, "max_epochs": max_epochs, }, - "dataset_reference": SQLiteDataset - if path.endswith(".db") - else ParquetDataset, + "dataset_reference": ( + SQLiteDataset if path.endswith(".db") else ParquetDataset + ), } archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_model_without_configs") diff --git a/examples/04_training/02_train_tito_model.py b/examples/04_training/02_train_tito_model.py index 50c5f7fd6..991ed7744 100644 --- a/examples/04_training/02_train_tito_model.py +++ b/examples/04_training/02_train_tito_model.py @@ -72,9 +72,9 @@ def main( "gpus": gpus, "max_epochs": max_epochs, }, - "dataset_reference": SQLiteDataset - if path.endswith(".db") - else ParquetDataset, + "dataset_reference": ( + SQLiteDataset if path.endswith(".db") else ParquetDataset + ), } graph_definition = KNNGraph(detector=Prometheus()) diff --git a/examples/04_training/05_train_RNN_TITO.py b/examples/04_training/05_train_RNN_TITO.py index 8e521bd1d..6f75d0364 100644 --- a/examples/04_training/05_train_RNN_TITO.py +++ b/examples/04_training/05_train_RNN_TITO.py @@ -76,9 +76,9 @@ def main( "gpus": gpus, "max_epochs": max_epochs, }, - "dataset_reference": SQLiteDataset - if path.endswith(".db") - else ParquetDataset, + "dataset_reference": ( + SQLiteDataset if path.endswith(".db") else ParquetDataset + ), } graph_definition = KNNGraph( diff --git a/examples/05_liquido/01_convert_h5.py b/examples/05_liquido/01_convert_h5.py index ed1569c51..eaa5610fd 100644 --- a/examples/05_liquido/01_convert_h5.py +++ b/examples/05_liquido/01_convert_h5.py @@ -1,4 +1,5 @@ """Example of converting H5 files from LiquidO to SQLite and Parquet.""" + import os from graphnet.constants import EXAMPLE_OUTPUT_DIR, TEST_DATA_DIR diff --git a/examples/06_prometheus/01_convert_prometheus.py b/examples/06_prometheus/01_convert_prometheus.py index 1a6c818df..63b071372 100644 --- a/examples/06_prometheus/01_convert_prometheus.py +++ b/examples/06_prometheus/01_convert_prometheus.py @@ -1,4 +1,5 @@ """Example of converting files from Prometheus to SQLite and Parquet.""" + import os from graphnet.constants
import EXAMPLE_OUTPUT_DIR, TEST_DATA_DIR diff --git a/setup.cfg b/setup.cfg index 8873542f4..7bddbf477 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,6 +17,9 @@ omit = [flake8] exclude = versioneer.py +# Ignore unused imports in __init__ files +per-file-ignores=__init__.py:F401 +ignore=E203,W503 [docformatter] wrap-summaries = 79 diff --git a/src/graphnet/data/__init__.py b/src/graphnet/data/__init__.py index 512d414d9..bd95b4a2a 100644 --- a/src/graphnet/data/__init__.py +++ b/src/graphnet/data/__init__.py @@ -3,6 +3,7 @@ `graphnet.data` enables converting domain-specific data to industry-standard, intermediate file formats and reading this data. """ + from .extractors.icecube.utilities.i3_filters import I3Filter, I3FilterMask from .dataconverter import DataConverter from .pre_configured import I3ToParquetConverter diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index 802a64a7d..978cdf52a 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -1,4 +1,5 @@ """Base `Dataloader` class(es) used in `graphnet`.""" + from typing import Dict, Any, Optional, List, Tuple, Union, Type import pytorch_lightning as pl from copy import deepcopy @@ -26,9 +27,9 @@ def __init__( dataset_args: Dict[str, Any], selection: Optional[Union[List[int], List[List[int]]]] = None, test_selection: Optional[Union[List[int], List[List[int]]]] = None, - train_dataloader_kwargs: Dict[str, Any] = None, - validation_dataloader_kwargs: Dict[str, Any] = None, - test_dataloader_kwargs: Dict[str, Any] = None, + train_dataloader_kwargs: Optional[Dict[str, Any]] = None, + validation_dataloader_kwargs: Optional[Dict[str, Any]] = None, + test_dataloader_kwargs: Optional[Dict[str, Any]] = None, train_val_split: Optional[List[float]] = [0.9, 0.10], split_seed: int = 42, ) -> None: diff --git a/src/graphnet/data/dataset/__init__.py b/src/graphnet/data/dataset/__init__.py index ed1c55ef5..c5d2bbc86 100644 --- a/src/graphnet/data/dataset/__init__.py +++ b/src/graphnet/data/dataset/__init__.py @@ -1,4 +1,5 @@ """Dataset classes for training in GraphNeT.""" + # Configuration from graphnet.utilities.imports import has_torch_package diff --git a/src/graphnet/data/dataset/parquet/__init__.py b/src/graphnet/data/dataset/parquet/__init__.py index edfc62b4e..f1f61b6a6 100644 --- a/src/graphnet/data/dataset/parquet/__init__.py +++ b/src/graphnet/data/dataset/parquet/__init__.py @@ -1,4 +1,5 @@ """Datasets using parquet backend.""" + # Configuration from graphnet.utilities.imports import has_torch_package diff --git a/src/graphnet/data/dataset/sqlite/__init__.py b/src/graphnet/data/dataset/sqlite/__init__.py index c44d66184..3ea73a068 100644 --- a/src/graphnet/data/dataset/sqlite/__init__.py +++ b/src/graphnet/data/dataset/sqlite/__init__.py @@ -1,4 +1,5 @@ """Datasets using SQLite backend.""" + from graphnet.utilities.imports import has_torch_package if has_torch_package(): diff --git a/src/graphnet/data/extractors/__init__.py b/src/graphnet/data/extractors/__init__.py index ad76d2742..21c5d7e31 100644 --- a/src/graphnet/data/extractors/__init__.py +++ b/src/graphnet/data/extractors/__init__.py @@ -1,3 +1,4 @@ """Module containing data-specific extractor modules.""" + from .extractor import Extractor from .combine_extractors import CombinedExtractor diff --git a/src/graphnet/data/extractors/combine_extractors.py b/src/graphnet/data/extractors/combine_extractors.py index a978e9b72..af89defe2 100644 --- a/src/graphnet/data/extractors/combine_extractors.py +++ 
b/src/graphnet/data/extractors/combine_extractors.py @@ -1,4 +1,5 @@ """Module for combining multiple extractors into a single extractor.""" + from typing import TYPE_CHECKING from graphnet.utilities.imports import has_icecube_package @@ -20,7 +21,11 @@ def __init__(self, extractors: List[I3Extractor], extractor_name: str): """Construct CombinedExtractor. Args: - extractors: List of extractors to combine. The extractors must all return data on the same level; e.g. all event-level data or pulse-level data. Mixing tables that contain event-level and pulse-level information will fail. + extractors: List of extractors to combine. + The extractors must all return data on the same level; + e.g. all event-level data or pulse-level data. + Mixing tables that contain event-level and + pulse-level information will fail. extractor_name: Name of the new extractor. """ super().__init__(extractor_name=extractor_name) diff --git a/src/graphnet/data/extractors/extractor.py b/src/graphnet/data/extractors/extractor.py index abd00668f..159524fbf 100644 --- a/src/graphnet/data/extractors/extractor.py +++ b/src/graphnet/data/extractors/extractor.py @@ -1,4 +1,5 @@ """Base I3Extractor class(es).""" + from typing import Any, Union from abc import ABC, abstractmethod import pandas as pd @@ -26,9 +27,10 @@ def __init__(self, extractor_name: str): """Construct Extractor. Args: - extractor_name: Name of the `Extractor` instance. Used to keep track of the - provenance of different data, and to name tables to which this - data is saved. E.g. "mc_truth". + extractor_name: Name of the `Extractor` instance. + Used to keep track of the provenance of different + data, and to name tables to which this data is + saved. E.g. "mc_truth". """ # Member variable(s) self._extractor_name: str = extractor_name diff --git a/src/graphnet/data/extractors/icecube/i3truthextractor.py b/src/graphnet/data/extractors/icecube/i3truthextractor.py index b715e57ab..4db330fc0 100644 --- a/src/graphnet/data/extractors/icecube/i3truthextractor.py +++ b/src/graphnet/data/extractors/icecube/i3truthextractor.py @@ -121,9 +121,10 @@ def __call__( "L7_oscNext_bool": padding_value, } - # Only InIceSplit P frames contain ML appropriate I3RecoPulseSeriesMap etc. - # At low levels i3files contain several other P frame splits (e.g NullSplit), - # we remove those here. + # Only InIceSplit P frames contain ML-appropriate data, + # for example I3RecoPulseSeriesMap, etc. + # At low levels i3 files contain several other P frame splits + # (e.g. NullSplit). We remove those here. if frame["I3EventHeader"].sub_event_stream not in [ "InIceSplit", "Final", @@ -181,7 +182,10 @@ def __call__( energy_cascade, inelasticity, ) = self._get_primary_track_energy_and_inelasticity(frame) - except RuntimeError: # track energy fails on northeren tracks with ""Hadrons" has no mass implemented. Cannot get total energy." + except ( + RuntimeError + ): # track energy fails on northern tracks with '"Hadrons" + # has no mass implemented. Cannot get total energy.' energy_track, energy_cascade, inelasticity = ( padding_value, padding_value, @@ -216,9 +220,10 @@ def __call__( muon_final = self._muon_stopped(output, self._borders) output.update( { - "position_x": muon_final[ - "x" - ], # position_xyz has no meaning for muons. These will now be updated to muon final position, given track length/azimuth/zenith + "position_x": muon_final["x"], + # position_xyz has no meaning for muons. 
+ # These will now be updated to muon final position, + # given track length/azimuth/zenith "position_y": muon_final["y"], "position_z": muon_final["z"], "stopped_muon": muon_final["stopped"], @@ -362,10 +367,11 @@ def _get_primary_particle_interaction_type_and_elasticity( MCInIcePrimary = frame[self._mctree][0] if ( MCInIcePrimary.energy != MCInIcePrimary.energy - ): # This is a nan check. Only happens for some muons where second item in MCTree is primary. Weird! - MCInIcePrimary = frame[self._mctree][ - 1 - ] # For some strange reason the second entry is identical in all variables and has no nans (always muon) + ): # This is a nan check. Only happens for some muons + # where second item in MCTree is primary. Weird! + MCInIcePrimary = frame[self._mctree][1] + # For some strange reason the second entry is identical in + # all variables and has no nans (always muon) else: MCInIcePrimary = None try: diff --git a/src/graphnet/data/extractors/internal/parquet_extractor.py b/src/graphnet/data/extractors/internal/parquet_extractor.py index ec8bb5db2..0f9bd8e6b 100644 --- a/src/graphnet/data/extractors/internal/parquet_extractor.py +++ b/src/graphnet/data/extractors/internal/parquet_extractor.py @@ -1,4 +1,5 @@ """Parquet Extractor for conversion from internal parquet format.""" + import polars as pol import pandas as pd diff --git a/src/graphnet/data/extractors/liquido/__init__.py b/src/graphnet/data/extractors/liquido/__init__.py index 023c55e12..9ead11289 100644 --- a/src/graphnet/data/extractors/liquido/__init__.py +++ b/src/graphnet/data/extractors/liquido/__init__.py @@ -1,2 +1,3 @@ """Module containing different extractors for LiquidO files.""" + from .h5_extractor import H5Extractor, H5HitExtractor, H5TruthExtractor diff --git a/src/graphnet/data/extractors/liquido/h5_extractor.py b/src/graphnet/data/extractors/liquido/h5_extractor.py index 3075a9cde..231e3349f 100644 --- a/src/graphnet/data/extractors/liquido/h5_extractor.py +++ b/src/graphnet/data/extractors/liquido/h5_extractor.py @@ -1,4 +1,5 @@ """H5 Extractor for LiquidO data files.""" + from typing import List import numpy as np import pandas as pd diff --git a/src/graphnet/data/extractors/prometheus/__init__.py b/src/graphnet/data/extractors/prometheus/__init__.py index 09f6a9ea3..da31b2a2f 100644 --- a/src/graphnet/data/extractors/prometheus/__init__.py +++ b/src/graphnet/data/extractors/prometheus/__init__.py @@ -1,4 +1,5 @@ """Extractors for extracting data from parquet files Prometheus.""" + from .prometheus_extractor import ( PrometheusExtractor, PrometheusTruthExtractor, diff --git a/src/graphnet/data/extractors/prometheus/prometheus_extractor.py b/src/graphnet/data/extractors/prometheus/prometheus_extractor.py index 9e4d02973..a65d315b4 100644 --- a/src/graphnet/data/extractors/prometheus/prometheus_extractor.py +++ b/src/graphnet/data/extractors/prometheus/prometheus_extractor.py @@ -1,4 +1,5 @@ """Parquet Extractor for conversion of simulation files from PROMETHEUS.""" + from typing import List import pandas as pd import numpy as np diff --git a/src/graphnet/data/parquet/__init__.py b/src/graphnet/data/parquet/__init__.py index 2c41ca75d..b0725ea3c 100644 --- a/src/graphnet/data/parquet/__init__.py +++ b/src/graphnet/data/parquet/__init__.py @@ -1,2 +1,3 @@ """Module for deprecated parquet methods.""" + from .deprecated_methods import ParquetDataConverter diff --git a/src/graphnet/data/parquet/deprecated_methods.py b/src/graphnet/data/parquet/deprecated_methods.py index ae2593813..62b95aeee 100644 --- 
a/src/graphnet/data/parquet/deprecated_methods.py +++ b/src/graphnet/data/parquet/deprecated_methods.py @@ -2,6 +2,7 @@ This code will be removed in GraphNeT 2.0. """ + from typing import List, Union from graphnet.data.extractors.icecube import I3Extractor @@ -26,8 +27,9 @@ def __init__( """Convert I3 files to Parquet. Args: - gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is - found in subfolder. `I3Reader` will recursively search + gcd_rescue: Path to a GCD file that will be used if no + GCD file is found in subfolder. + `I3Reader` will recursively search the input directory for I3-GCD file pairs. By IceCube convention, a folder containing i3 files will have an accompanying GCD file. However, in some cases, this diff --git a/src/graphnet/data/pre_configured/__init__.py b/src/graphnet/data/pre_configured/__init__.py index c19679325..8214fdcc5 100644 --- a/src/graphnet/data/pre_configured/__init__.py +++ b/src/graphnet/data/pre_configured/__init__.py @@ -1,4 +1,5 @@ """Module for pre-configured converter modules.""" + from .dataconverters import ( I3ToParquetConverter, I3ToSQLiteConverter, diff --git a/src/graphnet/data/readers/__init__.py b/src/graphnet/data/readers/__init__.py index 658da0b0c..911a00255 100644 --- a/src/graphnet/data/readers/__init__.py +++ b/src/graphnet/data/readers/__init__.py @@ -1,4 +1,5 @@ """Modules for reading experiment-specific data and applying Extractors.""" + from .graphnet_file_reader import GraphNeTFileReader from .i3reader import I3Reader from .internal_parquet_reader import ParquetReader diff --git a/src/graphnet/data/sqlite/__init__.py b/src/graphnet/data/sqlite/__init__.py index 436a86f2d..8fbfcfcb3 100644 --- a/src/graphnet/data/sqlite/__init__.py +++ b/src/graphnet/data/sqlite/__init__.py @@ -1,2 +1,3 @@ """Module for deprecated methods using sqlite.""" + from .deprecated_methods import SQLiteDataConverter diff --git a/src/graphnet/data/utilities/__init__.py b/src/graphnet/data/utilities/__init__.py index ad4f0c7db..fb5d6fad6 100644 --- a/src/graphnet/data/utilities/__init__.py +++ b/src/graphnet/data/utilities/__init__.py @@ -1,4 +1,5 @@ """Utilities for use across `graphnet.data`.""" + from .sqlite_utilities import create_table_and_save_to_sql from .sqlite_utilities import get_primary_keys from .sqlite_utilities import query_database diff --git a/src/graphnet/data/writers/__init__.py b/src/graphnet/data/writers/__init__.py index ad3e2748e..cf70660ae 100644 --- a/src/graphnet/data/writers/__init__.py +++ b/src/graphnet/data/writers/__init__.py @@ -1,4 +1,5 @@ """Modules for saving interim dataformat to various data backends.""" + from .graphnet_writer import GraphNeTWriter from .parquet_writer import ParquetWriter from .sqlite_writer import SQLiteWriter diff --git a/src/graphnet/datasets/__init__.py b/src/graphnet/datasets/__init__.py index 74893bb83..c01f50bc3 100644 --- a/src/graphnet/datasets/__init__.py +++ b/src/graphnet/datasets/__init__.py @@ -1,3 +1,4 @@ """Contains pre-converted datasets ready for training.""" + from .test_dataset import TestDataset from .prometheus_datasets import TRIDENTSmall, BaikalGVDSmall, PONESmall diff --git a/src/graphnet/datasets/prometheus_datasets.py b/src/graphnet/datasets/prometheus_datasets.py index 14cb2bc78..8dc028278 100644 --- a/src/graphnet/datasets/prometheus_datasets.py +++ b/src/graphnet/datasets/prometheus_datasets.py @@ -1,4 +1,5 @@ """Public datasets from Prometheus Simulation.""" + from typing import Dict, Any, List, Tuple, Union import os from 
sklearn.model_selection import train_test_split diff --git a/src/graphnet/datasets/test_dataset.py b/src/graphnet/datasets/test_dataset.py index 0a30a962d..1c3d2b679 100644 --- a/src/graphnet/datasets/test_dataset.py +++ b/src/graphnet/datasets/test_dataset.py @@ -1,4 +1,5 @@ """A CuratedDataset for unit tests.""" + from typing import Dict, Any, List, Tuple, Union import os diff --git a/src/graphnet/deployment/__init__.py b/src/graphnet/deployment/__init__.py index c26dfc1b7..b0d0fc429 100644 --- a/src/graphnet/deployment/__init__.py +++ b/src/graphnet/deployment/__init__.py @@ -3,5 +3,6 @@ `graphnet.deployment` allows for using trained models for inference in domain- specific reconstruction chains. """ + from .deployer import Deployer from .deployment_module import DeploymentModule diff --git a/src/graphnet/deployment/deployer.py b/src/graphnet/deployment/deployer.py index 4ac67f4d8..37473d195 100644 --- a/src/graphnet/deployment/deployer.py +++ b/src/graphnet/deployment/deployer.py @@ -1,4 +1,5 @@ """Contains the graphnet deployment module.""" + import random from abc import abstractmethod, ABC import multiprocessing @@ -9,7 +10,7 @@ from .deployment_module import DeploymentModule from graphnet.utilities.logging import Logger -if has_torch_package or TYPE_CHECKING: +if has_torch_package() or TYPE_CHECKING: import torch diff --git a/src/graphnet/deployment/deployment_module.py b/src/graphnet/deployment/deployment_module.py index b91490ad0..72026a21b 100644 --- a/src/graphnet/deployment/deployment_module.py +++ b/src/graphnet/deployment/deployment_module.py @@ -1,4 +1,5 @@ """Class(es) for deploying GraphNeT models in icetray as I3Modules.""" + from abc import abstractmethod from typing import Any, List, Union, Dict diff --git a/src/graphnet/deployment/i3modules/deprecated_methods.py b/src/graphnet/deployment/i3modules/deprecated_methods.py index 6acdc8d33..987dcdef1 100644 --- a/src/graphnet/deployment/i3modules/deprecated_methods.py +++ b/src/graphnet/deployment/i3modules/deprecated_methods.py @@ -1,4 +1,5 @@ """Contains deprecated methods.""" + from typing import Union, Sequence # from graphnet.deployment.icecube import I3Deployer, I3InferenceModule diff --git a/src/graphnet/deployment/icecube/__init__.py b/src/graphnet/deployment/icecube/__init__.py index 15c7485ef..184d29ef9 100644 --- a/src/graphnet/deployment/icecube/__init__.py +++ b/src/graphnet/deployment/icecube/__init__.py @@ -1,4 +1,5 @@ """Deployment modules specific to IceCube.""" + from .inference_module import I3InferenceModule from .cleaning_module import I3PulseCleanerModule from .i3deployer import I3Deployer diff --git a/src/graphnet/deployment/icecube/cleaning_module.py b/src/graphnet/deployment/icecube/cleaning_module.py index 27e5f2260..e2fee80cd 100644 --- a/src/graphnet/deployment/icecube/cleaning_module.py +++ b/src/graphnet/deployment/icecube/cleaning_module.py @@ -2,6 +2,7 @@ Contains functionality for writing model predictions to i3 files. 
""" + from typing import List, Union, TYPE_CHECKING, Dict, Any, Tuple import numpy as np @@ -117,13 +118,13 @@ def __call__(self, frame: I3Frame) -> bool: # checking the prediction for each pulse # (Adds the actual pulsemap to dictionary) if self._total_pulsemap_name not in frame.keys(): - data_dict[ - self._total_pulsemap_name - ] = dataclasses.I3RecoPulseSeriesMapMask( - frame, - self._pulsemap, - lambda om_key, index, pulse: predictions_map[om_key][index] - >= self._threshold, + data_dict[self._total_pulsemap_name] = ( + dataclasses.I3RecoPulseSeriesMapMask( + frame, + self._pulsemap, + lambda om_key, index, pulse: predictions_map[om_key][index] + >= self._threshold, + ) ) # Submit predictions and general pulsemap @@ -138,19 +139,19 @@ def __call__(self, frame: I3Frame) -> bool: ) if f"{self._total_pulsemap_name}_mDOMs_Only" not in frame.keys(): - data[ - f"{self._total_pulsemap_name}_mDOMs_Only" - ] = dataclasses.I3RecoPulseSeriesMap(mDOMMap) + data[f"{self._total_pulsemap_name}_mDOMs_Only"] = ( + dataclasses.I3RecoPulseSeriesMap(mDOMMap) + ) if f"{self._total_pulsemap_name}_dEggs_Only" not in frame.keys(): - data[ - f"{self._total_pulsemap_name}_dEggs_Only" - ] = dataclasses.I3RecoPulseSeriesMap(DEggMap) + data[f"{self._total_pulsemap_name}_dEggs_Only"] = ( + dataclasses.I3RecoPulseSeriesMap(DEggMap) + ) if f"{self._total_pulsemap_name}_pDOMs_Only" not in frame.keys(): - data[ - f"{self._total_pulsemap_name}_pDOMs_Only" - ] = dataclasses.I3RecoPulseSeriesMap(IceCubeMap) + data[f"{self._total_pulsemap_name}_pDOMs_Only"] = ( + dataclasses.I3RecoPulseSeriesMap(IceCubeMap) + ) # Submits the additional pulsemaps to the frame frame = self._add_to_frame(frame=frame, data=data) @@ -211,7 +212,7 @@ def _construct_prediction_map( for om_key, pulses in pulsemap.items(): num_pulses = len(pulses) predictions_map[om_key] = predictions[ - idx : idx + num_pulses + idx : idx + num_pulses # noqa: E203 ].tolist() idx += num_pulses diff --git a/src/graphnet/deployment/icecube/inference_module.py b/src/graphnet/deployment/icecube/inference_module.py index 9631cc3e0..e963850a5 100644 --- a/src/graphnet/deployment/icecube/inference_module.py +++ b/src/graphnet/deployment/icecube/inference_module.py @@ -2,6 +2,7 @@ Contains functionality for writing model predictions to i3 files. 
""" + from typing import List, Union, Optional, TYPE_CHECKING, Dict, Any import numpy as np @@ -119,13 +120,13 @@ def _create_dictionary( for i in range(dim): try: assert len(predictions[:, i]) == 1 - data[ - self.model_name + "_" + self.prediction_columns[i] - ] = I3Double(float(predictions[:, i][0])) + data[self.model_name + "_" + self.prediction_columns[i]] = ( + I3Double(float(predictions[:, i][0])) + ) except IndexError: - data[ - self.model_name + "_" + self.prediction_columns[i] - ] = I3Double(predictions[0]) + data[self.model_name + "_" + self.prediction_columns[i]] = ( + I3Double(predictions[0]) + ) return data def _apply_model(self, data: Data) -> np.ndarray: diff --git a/src/graphnet/exceptions/__init__.py b/src/graphnet/exceptions/__init__.py index 9d7796808..71fd44001 100644 --- a/src/graphnet/exceptions/__init__.py +++ b/src/graphnet/exceptions/__init__.py @@ -1,2 +1,3 @@ """Custom Exceptions for GraphNeT.""" + from .exceptions import ColumnMissingException diff --git a/src/graphnet/models/__init__.py b/src/graphnet/models/__init__.py index 12d4cbcc5..e3e4f23b9 100644 --- a/src/graphnet/models/__init__.py +++ b/src/graphnet/models/__init__.py @@ -6,6 +6,7 @@ existing, purpose-built components and chain them together to form a complete GNN """ + from graphnet.utilities.imports import has_jammy_flows_package from .model import Model from .standard_model import StandardModel diff --git a/src/graphnet/models/components/embedding.py b/src/graphnet/models/components/embedding.py index 1b49cd901..08d699931 100644 --- a/src/graphnet/models/components/embedding.py +++ b/src/graphnet/models/components/embedding.py @@ -1,4 +1,5 @@ """Classes for performing embedding of input data.""" + import torch import torch.nn as nn from torch.functional import Tensor @@ -33,9 +34,7 @@ def __init__( super().__init__() if dim % 2 != 0: raise ValueError(f"dim has to be even. Got: {dim}") - self.scale = ( - nn.Parameter(torch.ones(1) * dim**-0.5) if scaled else 1.0 - ) + self.scale = nn.Parameter(torch.ones(1) * dim**-0.5) if scaled else 1.0 self.dim = dim self.n_freq = torch.Tensor([n_freq]) diff --git a/src/graphnet/models/gnn/RNN_tito.py b/src/graphnet/models/gnn/RNN_tito.py index 8b9157c58..26539eb72 100644 --- a/src/graphnet/models/gnn/RNN_tito.py +++ b/src/graphnet/models/gnn/RNN_tito.py @@ -1,4 +1,5 @@ """RNN_DynEdge model implementation.""" + from typing import List, Optional, Tuple import torch @@ -43,24 +44,42 @@ def __init__( Args: nb_inputs (int): Number of input features. - time_series_columns (List[int]): The indices of the input data that should be treated as time series data. The first index should be the charge column. + time_series_columns (List[int]): The indices of the input data that + should be treated as time series data. + The first index should be the charge column. nb_neighbours (int, optional): Number of neighbours to consider. Defaults to 8. rnn_layers (int, optional): Number of RNN layers. Defaults to 1. - rnn_hidden_size (int, optional): Size of the hidden state of the RNN. Also determines the size of the output of the RNN. + rnn_hidden_size (int, optional): Size of the hidden state of the + RNN. Also determines the size of the output of the RNN. Defaults to 64. - rnn_dropout (float, optional): Dropout to use in the RNN. Defaults to 0.5. + rnn_dropout (float, optional): Dropout to use in the RNN. + Defaults to 0.5. 
features_subset (List[int], optional): The subset of latent - features on each node that are used as metric dimensions when performing the k-nearest neighbours clustering. Defaults to [0,1,2,3] - dyntrans_layer_sizes (List[Tuple[int, ...]], optional): List of tuples representing the sizes of the hidden layers of the DynTrans model. - post_processing_layer_sizes (List[int], optional): List of integers representing the sizes of the hidden layers of the post-processing model. - readout_layer_sizes (List[int], optional): List of integers representing the sizes of the hidden layers of the readout model. - global_pooling_schemes (Union[str, List[str]], optional): Pooling schemes to use. Defaults to None. - embedding_dim (int, optional): Embedding dimension of the RNN. Defaults to None ie. no embedding. - n_head (int, optional): Number of heads to use in the DynTrans model. Defaults to 16. - use_global_features (bool, optional): Whether to use global features after pooling. Defaults to True. - use_post_processing_layers (bool, optional): Whether to use post-processing layers after the DynTrans layers. Defaults to True. + features on each node that are used as metric dimensions when + performing the k-nearest neighbours clustering. + Defaults to [0,1,2,3] + dyntrans_layer_sizes (List[Tuple[int, ...]], optional): List of + tuples representing the sizes of the hidden layers of + the DynTrans model. + post_processing_layer_sizes (List[int], optional): List of + integers representing the sizes of the hidden layers of the + post-processing model. + readout_layer_sizes (List[int], optional): List of integers + representing the sizes of the hidden layers of the + readout model. + global_pooling_schemes (Union[str, List[str]], optional): Pooling + schemes to use. Defaults to None. + embedding_dim (int, optional): Embedding dimension of the RNN. + Defaults to None ie. no embedding. + n_head (int, optional): Number of heads to use in the DynTrans + model. Defaults to 16. + use_global_features (bool, optional): Whether to use global + features after pooling. Defaults to True. + use_post_processing_layers (bool, optional): Whether to use + post-processing layers after the DynTrans layers. + Defaults to True. """ self._nb_neighbours = nb_neighbours self._nb_inputs = nb_inputs diff --git a/src/graphnet/models/gnn/dynedge.py b/src/graphnet/models/gnn/dynedge.py index a80ce82ce..aabfabb88 100644 --- a/src/graphnet/models/gnn/dynedge.py +++ b/src/graphnet/models/gnn/dynedge.py @@ -1,4 +1,5 @@ """Implementation of the DynEdge GNN model architecture.""" + from typing import List, Optional, Tuple, Union import torch diff --git a/src/graphnet/models/gnn/dynedge_kaggle_tito.py b/src/graphnet/models/gnn/dynedge_kaggle_tito.py index 88d4c8811..dded5616c 100644 --- a/src/graphnet/models/gnn/dynedge_kaggle_tito.py +++ b/src/graphnet/models/gnn/dynedge_kaggle_tito.py @@ -34,12 +34,12 @@ class DynEdgeTITO(GNN): def __init__( self, nb_inputs: int, - features_subset: List[int] = None, + features_subset: Optional[List[int]] = None, dyntrans_layer_sizes: Optional[List[Tuple[int, ...]]] = None, global_pooling_schemes: List[str] = ["max"], use_global_features: bool = True, use_post_processing_layers: bool = True, - post_processing_layer_sizes: List[int] = None, + post_processing_layer_sizes: Optional[List[int]] = None, readout_layer_sizes: Optional[List[int]] = None, n_head: int = 8, nb_neighbours: int = 8, @@ -57,9 +57,12 @@ def __init__( global_pooling_schemes: The list global pooling schemes to use. 
Options are: "min", "max", "mean", and "sum". use_global_features: Whether to use global features after pooling. - use_post_processing_layers: Whether to use post-processing layers after the `DynTrans` layers. - post_processing_layer_sizes: The layer sizes used in the post-processing layers. Defaults to [336, 256]. - readout_layer_sizes: The layer sizes used in the readout layers. Defaults to [256, 128]. + use_post_processing_layers: Whether to use post-processing layers + after the `DynTrans` layers. + post_processing_layer_sizes: The layer sizes used in the + post-processing layers. Defaults to [336, 256]. + readout_layer_sizes: The layer sizes used in the readout layers. + Defaults to [256, 128]. n_head: The number of heads to use in the `DynTrans` layer. nb_neighbours: The number of neighbours to use in the `DynTrans` layer. diff --git a/src/graphnet/models/gnn/icemix.py b/src/graphnet/models/gnn/icemix.py index a073e3ca8..6e7ebe647 100644 --- a/src/graphnet/models/gnn/icemix.py +++ b/src/graphnet/models/gnn/icemix.py @@ -1,15 +1,14 @@ -"""Implementation of IceMix architecture used in. - - IceCube - Neutrinos in Deep Ice -Reconstruct the direction of neutrinos from the Universe to the South Pole +"""Implementation of IceMix. +This method was a solution submitted to the IceCube - Neutrinos in Deep Ice Kaggle competition. Solution by DrHB: https://github.com/DrHB/icecube-2nd-place """ + import torch import torch.nn as nn -from typing import Set, Dict, Any +from typing import Set, Dict, Any, Optional from graphnet.models.components.layers import ( Block_rel, @@ -42,7 +41,7 @@ def __init__( n_rel: int = 1, scaled_emb: bool = False, include_dynedge: bool = False, - dynedge_args: Dict[str, Any] = None, + dynedge_args: Optional[Dict[str, Any]] = None, n_features: int = 6, ): """Construct `DeepIce`. diff --git a/src/graphnet/models/gnn/particlenet.py b/src/graphnet/models/gnn/particlenet.py index cf2d00998..02d060372 100644 --- a/src/graphnet/models/gnn/particlenet.py +++ b/src/graphnet/models/gnn/particlenet.py @@ -1,5 +1,6 @@ """Implementation of the ParticleNet GNN model architecture.""" -from typing import List, Optional, Callable, Tuple, Union + +from typing import List, Optional, Tuple, Union import torch from torch import Tensor, LongTensor diff --git a/src/graphnet/models/graphs/__init__.py b/src/graphnet/models/graphs/__init__.py index a07d1308d..e5db7d735 100644 --- a/src/graphnet/models/graphs/__init__.py +++ b/src/graphnet/models/graphs/__init__.py @@ -5,6 +5,5 @@ and their features. """ - from .graph_definition import GraphDefinition from .graphs import KNNGraph, EdgelessGraph diff --git a/src/graphnet/models/graphs/edges/__init__.py b/src/graphnet/models/graphs/edges/__init__.py index 40c8bbeab..ad11df041 100644 --- a/src/graphnet/models/graphs/edges/__init__.py +++ b/src/graphnet/models/graphs/edges/__init__.py @@ -4,5 +4,6 @@ graph-manipulation.´EdgeDefinition´ defines how edges are drawn between nodes and their features. 
""" + from .edges import EdgeDefinition, KNNEdges, RadialEdges, EuclideanEdges from .minkowski import MinkowskiKNNEdges diff --git a/src/graphnet/models/graphs/edges/minkowski.py b/src/graphnet/models/graphs/edges/minkowski.py index 2526de1cb..9e454c4fd 100644 --- a/src/graphnet/models/graphs/edges/minkowski.py +++ b/src/graphnet/models/graphs/edges/minkowski.py @@ -1,4 +1,5 @@ """Module containing EdgeDefinitions based on the Minkowski Metric.""" + from typing import Optional, List import torch diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index 0338225b8..67c5065a2 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -5,7 +5,6 @@ passed to dataloaders during training and deployment. """ - from typing import Any, List, Optional, Dict, Callable, Union import torch from torch_geometric.data import Data @@ -24,7 +23,7 @@ class GraphDefinition(Model): def __init__( self, detector: Detector, - node_definition: NodeDefinition = None, + node_definition: Optional[NodeDefinition] = None, edge_definition: Optional[EdgeDefinition] = None, input_feature_names: Optional[List[str]] = None, dtype: Optional[torch.dtype] = torch.float, @@ -33,7 +32,7 @@ def __init__( add_inactive_sensors: bool = False, sensor_mask: Optional[List[int]] = None, string_mask: Optional[List[int]] = None, - sort_by: str = None, + sort_by: Optional[str] = None, repeat_labels: bool = False, ): """Construct ´GraphDefinition´. The ´detector´ holds. @@ -97,7 +96,9 @@ def __init__( if input_feature_names is None: # Assume all features in Detector is used. - input_feature_names = list(self._detector.feature_map().keys()) # type: ignore + input_feature_names = list( + self._detector.feature_map().keys() + ) # noqa: E501 # type: ignore self._input_feature_names = input_feature_names # Set input data column names for node definition @@ -110,10 +111,13 @@ def __init__( if sort_by is not None: assert isinstance(sort_by, str) try: - sort_by = self.output_feature_names.index(sort_by) # type: ignore + sort_by = self.output_feature_names.index( # type: ignore + sort_by + ) # type: ignore except ValueError as e: self.error( - f"{sort_by} not in node features {self.output_feature_names}." + f"{sort_by} not in node " + f"features {self.output_feature_names}." ) raise e self._sort_by = sort_by @@ -159,10 +163,11 @@ def forward( # type: ignore """Construct graph as ´Data´ object. Args: - input_features: Input features for graph construction. Shape ´[num_rows, d]´ + input_features: Input features for graph construction. + Shape ´[num_rows, d]´ input_feature_names: name of each column. Shape ´[,d]´. truth_dicts: Dictionary containing truth labels. - custom_label_functions: Custom label functions. See https://github.com/graphnet-team/graphnet/blob/main/GETTING_STARTED.md#adding-custom-truth-labels. + custom_label_functions: Custom label functions. loss_weight_column: Name of column that holds loss weight. Defaults to None. loss_weight: Loss weight associated with event. Defaults to None. @@ -253,7 +258,7 @@ def _resolve_masks(self) -> None: if self._string_mask is not None: assert ( 1 == 2 - ), """Got arguments for both `sensor_mask`and `string_mask`. Please specify only one. """ + ), "Please specify only one of `sensor_mask`and `string_mask`." 
if (self._sensor_mask is None) & (self._string_mask is not None): self._sensor_mask = self._convert_string_to_sensor_mask() diff --git a/src/graphnet/models/graphs/utils.py b/src/graphnet/models/graphs/utils.py index 9dd21ee60..9385bd33b 100644 --- a/src/graphnet/models/graphs/utils.py +++ b/src/graphnet/models/graphs/utils.py @@ -1,6 +1,6 @@ """Utility functions for construction of graphs.""" -from typing import List, Tuple +from typing import List, Tuple, Optional import os import numpy as np import pandas as pd @@ -129,7 +129,8 @@ def cluster_summarize_with_percentiles( then each row in the returned array will correspond to a DOM, and the time and charge for each DOM will be summarized by percentiles. Returned output array has dimensions - `[n_clusters, len(percentiles)*len(summarization_indices) + len(cluster_indices)]` + `[n_clusters, len(percentiles)*len(summarization_indices) + + len(cluster_indices)]` Args: x: Array to be clustered @@ -166,7 +167,7 @@ def cluster_summarize_with_percentiles( def ice_transparency( - z_offset: float = None, z_scaling: float = None + z_offset: Optional[float] = None, z_scaling: Optional[float] = None ) -> Tuple[interp1d, interp1d]: """Return interpolation functions for optical properties of IceCube. @@ -193,9 +194,9 @@ def ice_transparency( z_scaling = z_scaling or 500.0 df["z_norm"] = (df["depth"] + z_offset) / z_scaling - df[ - ["scattering_len_norm", "absorption_len_norm"] - ] = RobustScaler().fit_transform(df[["scattering_len", "absorption_len"]]) + df[["scattering_len_norm", "absorption_len_norm"]] = ( + RobustScaler().fit_transform(df[["scattering_len", "absorption_len"]]) + ) f_scattering = interp1d(df["z_norm"], df["scattering_len_norm"]) f_absorption = interp1d(df["z_norm"], df["absorption_len_norm"]) diff --git a/src/graphnet/models/rnn/node_rnn.py b/src/graphnet/models/rnn/node_rnn.py index 45f7d643d..ebdf52041 100644 --- a/src/graphnet/models/rnn/node_rnn.py +++ b/src/graphnet/models/rnn/node_rnn.py @@ -2,6 +2,7 @@ (cannot be used as a standalone model) """ + import torch from graphnet.models.gnn.gnn import GNN @@ -42,13 +43,20 @@ def __init__( Args: nb_inputs: Number of features in the input data. - hidden_size: Number of features for the RNN output and hidden layers. + hidden_size: Number of features for the RNN output and hidden + layers. num_layers: Number of layers in the RNN. - time_series_columns: The indices of the input data that should be treated as time series data. The first index should be the charge column. - nb_neighbours: Number of neighbours to use when reconstructing the graph representation. Defaults to 8. - features_subset: The subset of latent features on each node that are used as metric dimensions when performing the k-nearest neighbours clustering. Defaults to [0,1,2,3] + time_series_columns: The indices of the input data that should be + treated as time series data. The first index should be + the charge column. + nb_neighbours: Number of neighbours to use when reconstructing the + graph representation. Defaults to 8. + features_subset: The subset of latent features on each node that + are used as metric dimensions when performing the k-nearest + neighbours clustering. Defaults to [0,1,2,3] dropout: Dropout fraction to use in the RNN. Defaults to 0.5. - embedding_dim: Embedding dimension of the RNN. Defaults to no embedding. + embedding_dim: Embedding dimension of the RNN. + Defaults to no embedding. 
""" self._hidden_size = hidden_size self._num_layers = num_layers diff --git a/src/graphnet/models/standard_averaged_model.py b/src/graphnet/models/standard_averaged_model.py index 362e5c2b2..5d4b7cc43 100644 --- a/src/graphnet/models/standard_averaged_model.py +++ b/src/graphnet/models/standard_averaged_model.py @@ -1,4 +1,5 @@ """Averaged Standard model class(es).""" + from typing import Any, Callable, Dict, List, Optional, Union, Type from collections import OrderedDict diff --git a/src/graphnet/models/task/task.py b/src/graphnet/models/task/task.py index 0b9101107..604fc601b 100644 --- a/src/graphnet/models/task/task.py +++ b/src/graphnet/models/task/task.py @@ -13,7 +13,9 @@ if TYPE_CHECKING: # Avoid cyclic dependency - from graphnet.training.loss_functions import LossFunction # type: ignore[attr-defined] + from graphnet.training.loss_functions import ( + LossFunction, + ) # noqa: E501 # type: ignore[attr-defined] from graphnet.models import Model from graphnet.utilities.decorators import final @@ -108,12 +110,12 @@ def __init__( self._inference = False self._loss_weight = loss_weight - self._transform_prediction_training: Callable[ - [Tensor], Tensor - ] = lambda x: x - self._transform_prediction_inference: Callable[ - [Tensor], Tensor - ] = lambda x: x + self._transform_prediction_training: Callable[[Tensor], Tensor] = ( + lambda x: x + ) + self._transform_prediction_inference: Callable[[Tensor], Tensor] = ( + lambda x: x + ) self._transform_target: Callable[[Tensor], Tensor] = lambda x: x self._validate_and_set_transforms( transform_prediction_and_target, @@ -158,10 +160,10 @@ def _validate_and_set_transforms( assert not ( (transform_prediction_and_target is not None) and (transform_target is not None) - ), "Please specify at most one of `transform_prediction_and_target` and `transform_target`" + ), "Please specify at most one of `transform_prediction_and_target` and `transform_target`" # noqa: E501 if (transform_target is not None) != (transform_inference is not None): self.warning( - "Setting one of `transform_target` and `transform_inference`, but not " + "Setting one of `transform_target` and `transform_inference`, but not " # noqa: E501 "the other." 
) @@ -434,7 +436,9 @@ def nb_inputs(self) -> Union[int, None]: # type: ignore """Return number of conditional inputs assumed by task.""" return self._hidden_size - def _forward(self, x: Optional[Tensor], y: Tensor) -> Tensor: # type: ignore + def _forward( + self, x: Optional[Tensor], y: Tensor + ) -> Tensor: # noqa: E501 # type: ignore y = y / self._norm if x is not None: if x.shape[0] != y.shape[0]: diff --git a/src/graphnet/utilities/config/dataset_config.py b/src/graphnet/utilities/config/dataset_config.py index 42f68dff1..79e03c184 100644 --- a/src/graphnet/utilities/config/dataset_config.py +++ b/src/graphnet/utilities/config/dataset_config.py @@ -1,4 +1,5 @@ """Config classes for the `graphnet.data.dataset` module.""" + import warnings from abc import ABCMeta from functools import wraps diff --git a/src/graphnet/utilities/config/model_config.py b/src/graphnet/utilities/config/model_config.py index 18b06def9..5f32f559a 100644 --- a/src/graphnet/utilities/config/model_config.py +++ b/src/graphnet/utilities/config/model_config.py @@ -1,4 +1,5 @@ """Config classes for the `graphnet.models` module.""" + from abc import ABCMeta from functools import wraps import inspect @@ -68,9 +69,9 @@ def __init__(self, **data: Any) -> None: value = data["arguments"][arg] if isinstance(value, (tuple, list)): for ix, elem in enumerate(value): - data["arguments"][arg][ - ix - ] = self._parse_if_model_config_entry(elem) + data["arguments"][arg][ix] = ( + self._parse_if_model_config_entry(elem) + ) else: data["arguments"][arg] = self._parse_if_model_config_entry( value diff --git a/src/graphnet/utilities/deprecation_tools.py b/src/graphnet/utilities/deprecation_tools.py index 3ba051aba..778a67236 100644 --- a/src/graphnet/utilities/deprecation_tools.py +++ b/src/graphnet/utilities/deprecation_tools.py @@ -1,4 +1,5 @@ """Utility functions for handling deprecation transitions.""" + from typing import Dict, Tuple from copy import deepcopy from torch import Tensor diff --git a/tests/examples/01_icetray/test_icetray_examples.py b/tests/examples/01_icetray/test_icetray_examples.py index cbdf470fd..5c1942698 100644 --- a/tests/examples/01_icetray/test_icetray_examples.py +++ b/tests/examples/01_icetray/test_icetray_examples.py @@ -1,4 +1,5 @@ """Test for examples in 01_icetray.""" + import runpy import os import pytest diff --git a/tests/examples/02_data/test_data_examples.py b/tests/examples/02_data/test_data_examples.py index 8faaca5c0..ea30f57d5 100644 --- a/tests/examples/02_data/test_data_examples.py +++ b/tests/examples/02_data/test_data_examples.py @@ -1,4 +1,5 @@ """Tests for examples in 02_data.""" + import runpy import os import pytest diff --git a/tests/examples/03_weights/test_weights_examples.py b/tests/examples/03_weights/test_weights_examples.py index 5dddee264..029f6d9fe 100644 --- a/tests/examples/03_weights/test_weights_examples.py +++ b/tests/examples/03_weights/test_weights_examples.py @@ -1,4 +1,5 @@ """Test for examples in 03_weights.""" + import runpy import os from graphnet.constants import GRAPHNET_ROOT_DIR diff --git a/tests/examples/04_training/test_training_examples.py b/tests/examples/04_training/test_training_examples.py index e97b4bb3c..c6c33dc80 100644 --- a/tests/examples/04_training/test_training_examples.py +++ b/tests/examples/04_training/test_training_examples.py @@ -1,4 +1,5 @@ """Test for examples in 04_training.""" + import runpy import os from glob import glob diff --git a/tests/examples/05_liquido/test_liquido_examples.py 
b/tests/examples/05_liquido/test_liquido_examples.py index 9738ed861..8103d855a 100644 --- a/tests/examples/05_liquido/test_liquido_examples.py +++ b/tests/examples/05_liquido/test_liquido_examples.py @@ -1,4 +1,5 @@ """Tests for examples in 05_liquido.""" + import runpy import os import pytest diff --git a/tests/examples/06_prometheus/test_prometheus.py b/tests/examples/06_prometheus/test_prometheus.py index 3805edca7..147cfb6a6 100644 --- a/tests/examples/06_prometheus/test_prometheus.py +++ b/tests/examples/06_prometheus/test_prometheus.py @@ -1,4 +1,5 @@ """Tests for examples in 06_prometheus.""" + import runpy import os import pytest diff --git a/tests/models/test_minkowski.py b/tests/models/test_minkowski.py index 98fed817c..9929f304d 100644 --- a/tests/models/test_minkowski.py +++ b/tests/models/test_minkowski.py @@ -1,5 +1,5 @@ """Unit tests for minkowski based edges.""" -import pytest + import torch from torch_geometric.data.data import Data diff --git a/tests/models/test_node_definition.py b/tests/models/test_node_definition.py index 4c199abd6..b849fef94 100644 --- a/tests/models/test_node_definition.py +++ b/tests/models/test_node_definition.py @@ -1,4 +1,5 @@ """Unit tests for node definitions.""" + import numpy as np import pandas as pd import sqlite3 @@ -19,7 +20,7 @@ def test_percentile_cluster() -> None: with sqlite3.connect(database) as con: query = "select event_no from mc_truth limit 1" event_no = pd.read_sql(query, con) - query = f'select sensor_pos_x, sensor_pos_y, sensor_pos_z, t from total where event_no = {str(event_no["event_no"][0])}' + query = f'select sensor_pos_x, sensor_pos_y, sensor_pos_z, t from total where event_no = {str(event_no["event_no"][0])}' # noqa: E501 df = pd.read_sql(query, con) # Save original feature names, create variables. diff --git a/tests/training/test_dataloader_utilities.py b/tests/training/test_dataloader_utilities.py index 423b2f34b..d5b9d55df 100644 --- a/tests/training/test_dataloader_utilities.py +++ b/tests/training/test_dataloader_utilities.py @@ -29,7 +29,10 @@ # Unit test(s) def test_none_selection() -> None: """Test agreement of the two ways to calculate this loss.""" - (train_dataloader, test_dataloader,) = make_train_validation_dataloader( + ( + train_dataloader, + test_dataloader, + ) = make_train_validation_dataloader( db=TEST_SQLITE_DATA, graph_definition=graph_definition, selection=None, diff --git a/tests/training/test_loss_functions.py b/tests/training/test_loss_functions.py index 373a50964..6d978ed3b 100644 --- a/tests/training/test_loss_functions.py +++ b/tests/training/test_loss_functions.py @@ -27,7 +27,11 @@ def _compute_elementwise_gradient(outputs: Tensor, inputs: Tensor) -> Tensor: nb_elements = inputs.size(dim=0) elementwise_gradients = torch.stack( [ - grad(outputs=outputs[ix], inputs=inputs, retain_graph=True,)[ + grad( + outputs=outputs[ix], + inputs=inputs, + retain_graph=True, + )[ 0 ][ix] for ix in range(nb_elements) @@ -51,8 +55,8 @@ def test_log_cosh(dtype: torch.dtype = torch.float32) -> None: # (1) Loss functions should not return `inf` losses, even for large # differences between prediction and target. This is not necessarily - # true for the directly calculated loss (reference) where cosh(x) may go - # to `inf` for x >~ 100. + # true for the directly calculated loss (reference) where cosh(x) + # may go to `inf` for x >~ 100. assert torch.all(torch.isfinite(losses)) # (2) For the inputs where the reference loss _is_ valid, the two