diff --git a/examples/01_icetray/03_i3_deployer_example.py b/examples/01_icetray/03_i3_deployer_example.py index 28d73c00d..bd2de9f43 100644 --- a/examples/01_icetray/03_i3_deployer_example.py +++ b/examples/01_icetray/03_i3_deployer_example.py @@ -50,7 +50,7 @@ def main() -> None: model_config = f"{base_path}/{model_name}/{model_name}_config.yml" state_dict = f"{base_path}/{model_name}/{model_name}_state_dict.pth" output_folder = f"{EXAMPLE_OUTPUT_DIR}/i3_deployment/upgrade_03_04" - gcd_file = f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/GeoCalibDetectorStatus_ICUpgrade.v58.mixed.V0.i3.bz2" + gcd_file = f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/GeoCalibDetectorStatus_ICUpgrade.v58.mixed.V0.i3.bz2" # noqa: E501 input_files = [] for folder in input_folders: input_files.extend(glob(join(folder, "*.i3.gz"))) diff --git a/examples/01_icetray/04_i3_module_in_native_icetray_example.py b/examples/01_icetray/04_i3_module_in_native_icetray_example.py index 09e9b358e..c4f797b02 100644 --- a/examples/01_icetray/04_i3_module_in_native_icetray_example.py +++ b/examples/01_icetray/04_i3_module_in_native_icetray_example.py @@ -86,7 +86,7 @@ def main() -> None: model_config = f"{base_path}/{model_name}/{model_name}_config.yml" state_dict = f"{base_path}/{model_name}/{model_name}_state_dict.pth" output_folder = f"{EXAMPLE_OUTPUT_DIR}/i3_deployment/upgrade" - gcd_file = f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/GeoCalibDetectorStatus_ICUpgrade.v58.mixed.V0.i3.bz2" + gcd_file = f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/GeoCalibDetectorStatus_ICUpgrade.v58.mixed.V0.i3.bz2" # noqa:E501 features = FEATURES.UPGRADE input_files = [] for folder in input_folders: diff --git a/examples/04_training/07_train_normalizing_flow.py b/examples/04_training/07_train_normalizing_flow.py index baa3eec85..4048e9923 100644 --- a/examples/04_training/07_train_normalizing_flow.py +++ b/examples/04_training/07_train_normalizing_flow.py @@ -4,7 +4,6 @@ from typing import Any, Dict, List, Optional from pytorch_lightning.loggers import WandbLogger -import torch from torch.optim.adam import Adam from graphnet.constants import EXAMPLE_DATA_DIR, EXAMPLE_OUTPUT_DIR diff --git a/setup.cfg b/setup.cfg index 7bddbf477..476acd91c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,7 +18,9 @@ omit = exclude = versioneer.py # Ignore unused imports in __init__ files -per-file-ignores=__init__.py:F401 +per-file-ignores= + __init__.py:F401 + src/graphnet/utilities/imports.py:F401 ignore=E203,W503 [docformatter] diff --git a/src/graphnet/__init__.py b/src/graphnet/__init__.py index 34f46afc3..28e4f21f6 100644 --- a/src/graphnet/__init__.py +++ b/src/graphnet/__init__.py @@ -32,4 +32,6 @@ from . import _version -__version__ = _version.get_versions()["version"] # type: ignore[no-untyped-call] +__version__ = _version.get_versions()[ # type: ignore[no-untyped-call] + "version" +] diff --git a/src/graphnet/data/curated_datamodule.py b/src/graphnet/data/curated_datamodule.py index 63b691c9d..a206783bc 100644 --- a/src/graphnet/data/curated_datamodule.py +++ b/src/graphnet/data/curated_datamodule.py @@ -31,8 +31,8 @@ def __init__( features: Optional[List[str]] = None, backend: str = "parquet", train_dataloader_kwargs: Optional[Dict[str, Any]] = None, - validation_dataloader_kwargs: Dict[str, Any] = None, - test_dataloader_kwargs: Dict[str, Any] = None, + validation_dataloader_kwargs: Optional[Dict[str, Any]] = None, + test_dataloader_kwargs: Optional[Dict[str, Any]] = None, ) -> None: """Construct CuratedDataset. 
diff --git a/src/graphnet/data/dataset/samplers.py b/src/graphnet/data/dataset/samplers.py index c43455447..726508805 100644 --- a/src/graphnet/data/dataset/samplers.py +++ b/src/graphnet/data/dataset/samplers.py @@ -34,7 +34,7 @@ ) from collections import defaultdict -from multiprocessing import Pool, cpu_count, get_context +from multiprocessing import get_context import numpy as np import torch diff --git a/src/graphnet/data/extractors/icecube/i3extractor.py b/src/graphnet/data/extractors/icecube/i3extractor.py index 3f2fc92d2..07bf38030 100644 --- a/src/graphnet/data/extractors/icecube/i3extractor.py +++ b/src/graphnet/data/extractors/icecube/i3extractor.py @@ -24,9 +24,9 @@ def __init__(self, extractor_name: str): """Construct I3Extractor. Args: - extractor_name: Name of the `I3Extractor` instance. Used to keep track of the - provenance of different data, and to name tables to which this - data is saved. + extractor_name: Name of the `I3Extractor` instance. Used to keep + track of the provenance of different data, and to name tables + to which this data is saved. """ # Member variable(s) self._i3_file: str = "" diff --git a/src/graphnet/data/pre_configured/dataconverters.py b/src/graphnet/data/pre_configured/dataconverters.py index 1194da3c2..a7e8f324f 100644 --- a/src/graphnet/data/pre_configured/dataconverters.py +++ b/src/graphnet/data/pre_configured/dataconverters.py @@ -25,10 +25,12 @@ def __init__( """Convert I3 files to Parquet. Args: - gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is - found in subfolder. `I3Reader` will recursively search - the input directory for I3-GCD file pairs. By IceCube - convention, a folder containing i3 files will have an + gcd_rescue: Path to a GCD file that will be used if + no GCD file is found in subfolder. `I3Reader` will + recursively search the input directory for I3-GCD file + pairs. + By IceCube convention, + a folder containing i3 files will have an accompanying GCD file. However, in some cases, this convention is broken. In cases where a folder contains i3 files but no GCD file, the `gcd_rescue` is used @@ -70,10 +72,11 @@ def __init__( """Convert I3 files to SQLite. Args: - gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is - found in subfolder. `I3Reader` will recursively search - the input directory for I3-GCD file pairs. By IceCube - convention, a folder containing i3 files will have an + gcd_rescue: Path to a GCD file that will be used if + no GCD file is found in subfolder. `I3Reader` will + recursively search the input directory for I3-GCD file + pairs. By IceCube convention, + a folder containing i3 files will have an accompanying GCD file. However, in some cases, this convention is broken.
In cases where a folder contains i3 files but no GCD file, the `gcd_rescue` is used diff --git a/src/graphnet/data/readers/graphnet_file_reader.py b/src/graphnet/data/readers/graphnet_file_reader.py index f30a19fcf..6d145f1c0 100644 --- a/src/graphnet/data/readers/graphnet_file_reader.py +++ b/src/graphnet/data/readers/graphnet_file_reader.py @@ -125,7 +125,9 @@ def _validate_extractors( ) -> None: for extractor in extractors: try: - assert isinstance(extractor, tuple(self.accepted_extractors)) # type: ignore + assert isinstance( + extractor, tuple(self.accepted_extractors) # type: ignore + ) except AssertionError as e: self.error( f"{extractor.__class__.__name__}" @@ -164,5 +166,7 @@ def _validate_file(self, file: str) -> None: assert file.lower().endswith(tuple(self.accepted_file_extensions)) except AssertionError: self.error( - f'{self.__class__.__name__} accepts {self.accepted_file_extensions} but {file.split("/")[-1]} has extension {os.path.splitext(file)[1]}.' + f"{self.__class__.__name__} accepts " + f'{self.accepted_file_extensions} but {file.split("/")[-1]} ' + f"has extension {os.path.splitext(file)[1]}." ) diff --git a/src/graphnet/data/readers/i3reader.py b/src/graphnet/data/readers/i3reader.py index a2d1d0c83..39a090eb3 100644 --- a/src/graphnet/data/readers/i3reader.py +++ b/src/graphnet/data/readers/i3reader.py @@ -1,6 +1,6 @@ """Module containing different I3Reader.""" -from typing import List, Union, OrderedDict +from typing import List, Union, OrderedDict, Optional from graphnet.utilities.imports import has_icecube_package from graphnet.data.extractors.icecube.utilities.i3_filters import ( @@ -27,7 +27,7 @@ class I3Reader(GraphNeTFileReader): def __init__( self, gcd_rescue: str, - i3_filters: Union[I3Filter, List[I3Filter]] = None, + i3_filters: Optional[Union[I3Filter, List[I3Filter]]] = None, icetray_verbose: int = 0, ): """Initialize `I3Reader`. @@ -65,7 +65,9 @@ def __init__( # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) - def __call__(self, file_path: I3FileSet) -> List[OrderedDict]: # type: ignore + def __call__( + self, file_path: I3FileSet ) -> List[OrderedDict]: # noqa: E501 # type: ignore """Extract data from single I3 file. Args: diff --git a/src/graphnet/data/sqlite/deprecated_methods.py b/src/graphnet/data/sqlite/deprecated_methods.py index 30b563c59..4d7e0e69f 100644 --- a/src/graphnet/data/sqlite/deprecated_methods.py +++ b/src/graphnet/data/sqlite/deprecated_methods.py @@ -3,13 +3,10 @@ This code will be removed in GraphNeT 2.0. """ -from typing import List, Union, Type +from typing import List, Union from graphnet.data.extractors.icecube import I3Extractor -from graphnet.data.extractors.icecube.utilities.i3_filters import ( - I3Filter, - NullSplitI3Filter, -) +from graphnet.data.extractors.icecube.utilities.i3_filters import I3Filter from graphnet.data import I3ToSQLiteConverter @@ -28,10 +25,11 @@ def __init__( """Convert I3 files to Parquet. Args: - gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is - found in subfolder. `I3Reader` will recursively search - the input directory for I3-GCD file pairs. By IceCube - convention, a folder containing i3 files will have an + gcd_rescue: Path to a GCD file that will be used if no + GCD file is found in subfolder. `I3Reader` will + recursively search the input directory for I3-GCD file + pairs. By IceCube convention, + a folder containing i3 files will have an accompanying GCD file.
However, in some cases, this convention is broken. In cases where a folder contains i3 files but no GCD file, the `gcd_rescue` is used diff --git a/src/graphnet/data/utilities/parquet_to_sqlite.py b/src/graphnet/data/utilities/parquet_to_sqlite.py index f429d5e6b..07bc86ecb 100644 --- a/src/graphnet/data/utilities/parquet_to_sqlite.py +++ b/src/graphnet/data/utilities/parquet_to_sqlite.py @@ -1,3 +1,3 @@ """Utilities for converting files from Parquet to SQLite.""" -from graphnet.data.pre_configured import ParquetToSQLiteConverter +from graphnet.data.pre_configured import ParquetToSQLiteConverter # noqa: F401 diff --git a/src/graphnet/data/utilities/random.py b/src/graphnet/data/utilities/random.py index 1084f782c..7aaa8c4be 100644 --- a/src/graphnet/data/utilities/random.py +++ b/src/graphnet/data/utilities/random.py @@ -9,7 +9,8 @@ def pairwise_shuffle( ) -> Tuple[List[str], List[str]]: """Shuffle the I3 file list and the correponding gcd file list. - This is handy because it ensures a more even extraction load for each worker. + This is handy because it ensures a more even extraction load for each + worker. Args: files_list: List of I3 file paths. diff --git a/src/graphnet/data/utilities/sqlite_utilities.py b/src/graphnet/data/utilities/sqlite_utilities.py index 95755ef44..3458f4485 100644 --- a/src/graphnet/data/utilities/sqlite_utilities.py +++ b/src/graphnet/data/utilities/sqlite_utilities.py @@ -1,7 +1,7 @@ """SQLite-specific utility functions for use in `graphnet.data`.""" import os.path -from typing import List, Dict, Tuple +from typing import List, Dict, Tuple, Union import pandas as pd import sqlalchemy @@ -30,7 +30,9 @@ def query_database(database: str, query: str) -> pd.DataFrame: return pd.read_sql(query, conn) -def get_primary_keys(database: str) -> Tuple[Dict[str, str], str]: +def get_primary_keys( + database: str, +) -> Tuple[Dict[str, Union[str, None]], Union[str, None]]: """Get name of primary key column for each table in database. Args: @@ -50,7 +52,7 @@ def get_primary_keys(database: str) -> Tuple[Dict[str, str], str]: integer_primary_key = {} for table in table_names: - query = f"SELECT l.name FROM pragma_table_info('{table}') as l WHERE l.pk = 1;" + query = f"SELECT l.name FROM pragma_table_info('{table}') as l WHERE l.pk = 1;" # noqa: E501 first_primary_key = [ key[0] for key in conn.execute(query).fetchall() ] @@ -78,7 +80,7 @@ def database_table_exists(database_path: str, table_name: str) -> bool: """Check whether `table_name` exists in database at `database_path`.""" if not database_exists(database_path): return False - query = f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}';" + query = f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}';" # noqa: E501 with sqlite3.connect(database_path) as conn: result = pd.read_sql(query, conn) return len(result) == 1 diff --git a/src/graphnet/deployment/i3modules/__init__.py b/src/graphnet/deployment/i3modules/__init__.py index e2fd05a43..0fe5115da 100644 --- a/src/graphnet/deployment/i3modules/__init__.py +++ b/src/graphnet/deployment/i3modules/__init__.py @@ -5,5 +5,5 @@ detector configurations. 
""" -from .deprecated_methods import * +from .deprecated_methods import * # noqa: F403 from graphnet.deployment.icecube import I3InferenceModule, I3PulseCleanerModule diff --git a/src/graphnet/models/coarsening.py b/src/graphnet/models/coarsening.py index d40f0c009..d481a3d62 100644 --- a/src/graphnet/models/coarsening.py +++ b/src/graphnet/models/coarsening.py @@ -8,7 +8,6 @@ from torch_geometric.data import Data, Batch from sklearn.cluster import DBSCAN -# from torch_geometric.utils import unbatch_edge_index from graphnet.models.components.pool import ( group_by, avg_pool, @@ -28,7 +27,7 @@ # NOTE: From [https://github.com/pyg-team/pytorch_geometric/pull/4903] # TODO: Remove once bumping to torch_geometric>=2.1.0 -# See [https://github.com/pyg-team/pytorch_geometric/blob/master/CHANGELOG.md] +# See [https://github.com/pyg-team/pytorch_geometric/blob/master/CHANGELOG.md] # noqa: E501 def unbatch_edge_index(edge_index: Tensor, batch: Tensor) -> List[Tensor]: @@ -170,15 +169,18 @@ def _reconstruct_batch(self, original: Data, pooled: Data) -> Data: return pooled def _add_slice_dict(self, original: Data, pooled: Data) -> Data: - # Copy original slice_dict and count nodes in each graph in pooled batch + # Copy original slice_dict and count nodes in each + # graph in pooled batch slice_dict = deepcopy(original._slice_dict) _, counts = torch.unique_consecutive(pooled.batch, return_counts=True) - # Reconstruct the entry in slice_dict for pulsemaps - only these are affected by pooling + # Reconstruct the entry in slice_dict for pulsemaps - + # only these are affected by pooling pulsemap_slice = [0] for i in range(len(counts)): pulsemap_slice.append(pulsemap_slice[i] + counts[i].item()) - # Identifies pulsemap entries in slice_dict and set them to pulsemap_slice + # Identifies pulsemap entries in slice_dict and + # set them to pulsemap_slice for field in slice_dict.keys(): if (original._num_graphs) == slice_dict[field][-1]: pass # not pulsemap, so skip diff --git a/src/graphnet/models/components/layers.py b/src/graphnet/models/components/layers.py index bd2ed8daf..4f04067f9 100644 --- a/src/graphnet/models/components/layers.py +++ b/src/graphnet/models/components/layers.py @@ -9,7 +9,6 @@ from torch_geometric.typing import Adj, PairTensor from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.inits import reset -from torch_geometric.data import Data import torch.nn as nn from torch.nn.functional import linear from torch.nn.modules import TransformerEncoder, TransformerEncoderLayer diff --git a/src/graphnet/models/components/pool.py b/src/graphnet/models/components/pool.py index 3af15990c..592c2fce2 100644 --- a/src/graphnet/models/components/pool.py +++ b/src/graphnet/models/components/pool.py @@ -9,11 +9,11 @@ from torch_geometric.nn.pool.pool import pool_edge, pool_batch, pool_pos from torch_scatter import scatter, scatter_std -from torch_geometric.nn.pool import ( - avg_pool, +from torch_geometric.nn.pool import ( # noqa:F401 max_pool, - avg_pool_x, max_pool_x, + avg_pool_x, + avg_pool, ) @@ -90,8 +90,8 @@ def group_by(data: Union[Data, Batch], keys: List[str]) -> LongTensor: This grouping is done with in each event in case of batching. This allows for, e.g., assigning the same index to all pulses on the same PMT or DOM in the same event. This can be used for coarsening graphs, e.g., from pulse- - level to DOM-level by aggregating feature across each group returned by this - method. + level to DOM-level by aggregating feature across each group returned by + this method. 
Example: Given: @@ -140,7 +140,7 @@ def sum_pool_x( batch: LongTensor, size: Optional[int] = None, ) -> Tensor: - r"""Sum-pool node features according to the clustering defined in `cluster`. + r"""Sum-pool node features according to the cluster defined in `cluster`. Args: cluster: Cluster vector :math:`\mathbf{c} \in \{ 0, @@ -172,7 +172,7 @@ def std_pool_x( batch: LongTensor, size: Optional[int] = None, ) -> Tensor: - r"""Std-pool node features according to the clustering defined in `cluster`. + r"""Std-pool node features according to the cluster defined in `cluster`. Args: cluster: Cluster vector :math:`\mathbf{c} \in \{ 0, @@ -201,7 +201,7 @@ def std_pool_x( def sum_pool( cluster: LongTensor, data: Data, transform: Optional[Callable] = None ) -> Data: - r"""Pool and coarsen graph according to the clustering defined in `cluster`. + r"""Pool and coarsen graph according to the cluster defined in `cluster`. All nodes within the same cluster will be represented as one node. Final node features are defined by the *sum* of features of all nodes @@ -235,7 +235,7 @@ def sum_pool( def std_pool( cluster: LongTensor, data: Data, transform: Optional[Callable] = None ) -> Data: - r"""Pool and coarsen graph according to the clustering defined in `cluster`. + r"""Pool and coarsen graph according to the cluster defined in `cluster`. All nodes within the same cluster will be represented as one node. Final node features are defined by the *std* of features of all nodes diff --git a/src/graphnet/models/detector/detector.py b/src/graphnet/models/detector/detector.py index 9b9fc61b0..df28f3191 100644 --- a/src/graphnet/models/detector/detector.py +++ b/src/graphnet/models/detector/detector.py @@ -66,7 +66,9 @@ def _standardize( ) -> Data: for idx, feature in enumerate(input_feature_names): try: - input_features[:, idx] = self.feature_map()[feature]( # type: ignore + input_features[:, idx] = self.feature_map()[ + feature + ]( # noqa: E501 # type: ignore input_features[:, idx] ) except KeyError as e: diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py index c0f0cf2db..e642ed06c 100644 --- a/src/graphnet/models/graphs/graphs.py +++ b/src/graphnet/models/graphs/graphs.py @@ -16,7 +16,7 @@ class KNNGraph(GraphDefinition): def __init__( self, detector: Detector, - node_definition: NodeDefinition = None, + node_definition: Optional[NodeDefinition] = None, input_feature_names: Optional[List[str]] = None, dtype: Optional[torch.dtype] = torch.float, perturbation_dict: Optional[Dict[str, float]] = None, @@ -37,10 +37,11 @@ def __init__( feature should be randomly perturbed. Defaults to None. seed: seed or Generator used to randomly sample perturbations. - Defaults to None. - nb_nearest_neighbours: Number of edges for each node. Defaults to 8. - columns: node feature columns used for distance calculation - . Defaults to [0, 1, 2]. + Defaults to None. + nb_nearest_neighbours: Number of edges for each node. + Defaults to 8. + columns: node feature columns used for distance calculation. + Defaults to [0, 1, 2].
""" # Base class constructor super().__init__( @@ -67,7 +68,7 @@ class EdgelessGraph(GraphDefinition): def __init__( self, detector: Detector, - node_definition: NodeDefinition = None, + node_definition: Optional[NodeDefinition] = None, input_feature_names: Optional[List[str]] = None, dtype: Optional[torch.dtype] = torch.float, perturbation_dict: Optional[Dict[str, float]] = None, diff --git a/src/graphnet/models/graphs/nodes/nodes.py b/src/graphnet/models/graphs/nodes/nodes.py index 4e094e6be..558ec96f4 100644 --- a/src/graphnet/models/graphs/nodes/nodes.py +++ b/src/graphnet/models/graphs/nodes/nodes.py @@ -241,7 +241,8 @@ def __init__( id_columns: List of columns that uniquely identify a DOM. time_column: Name of time column. charge_column: Name of charge column. - max_activations: Maximum number of activations to include in the time series. + max_activations: Maximum number of activations to include in + the time series. """ self._keys = keys super().__init__(input_feature_names=self._keys) @@ -251,9 +252,8 @@ def __init__( self._charge_index: Optional[int] = self._keys.index(charge_column) except ValueError: self.warning( - "Charge column with name {} not found. Running without.".format( - charge_column - ) + f"Charge column with name {charge_column} not found. " + "Running without." ) self._charge_index = None @@ -271,7 +271,8 @@ def _construct_nodes(self, x: torch.Tensor) -> Data: x = x.numpy() if x.shape[0] == 0: return Data(x=torch.tensor(np.column_stack([x, []]))) - # if there is no charge column add a dummy column of zeros with the same shape as the time column + # if there is no charge column add a dummy column + # of zeros with the same shape as the time column if self._charge_index is None: charge_index: int = len(self._keys) x = np.insert(x, charge_index, np.zeros(x.shape[0]), axis=1) diff --git a/src/graphnet/models/model.py b/src/graphnet/models/model.py index 7c51c7952..650da29a9 100644 --- a/src/graphnet/models/model.py +++ b/src/graphnet/models/model.py @@ -50,7 +50,8 @@ def save_state_dict(self, path: str) -> None: """Save model `state_dict` to `path`.""" if not path.endswith(".pth"): self.info( - "It is recommended to use the .pth suffix for state_dict files." + "It is recommended to use the .pth suffix " + "for state_dict files."
) state_dict = self.state_dict() for key, value in state_dict.items(): @@ -74,7 +75,8 @@ def load_state_dict( # type: ignore[override] ) if state_dict_altered: self.warning( - "DeprecationWarning: State dicts with `_gnn` entries will be deprecated in GraphNeT 2.0" + "DeprecationWarning: State dicts with `_gnn`" + " entries will be deprecated in GraphNeT 2.0" ) return super().load_state_dict(state_dict, **kargs) diff --git a/src/graphnet/models/normalizing_flow.py b/src/graphnet/models/normalizing_flow.py index d62cf7c42..403208790 100644 --- a/src/graphnet/models/normalizing_flow.py +++ b/src/graphnet/models/normalizing_flow.py @@ -1,6 +1,6 @@ """Standard model class(es).""" -from typing import Any, Dict, List, Optional, Union, Type +from typing import Dict, List, Optional, Union, Type import torch from torch import Tensor from torch_geometric.data import Data @@ -26,7 +26,7 @@ def __init__( self, graph_definition: GraphDefinition, target_labels: str, - backbone: GNN = None, + backbone: Optional[GNN] = None, condition_on: Union[str, List[str], None] = None, flow_layers: str = "gggt", optimizer_class: Type[torch.optim.Optimizer] = Adam, diff --git a/src/graphnet/models/standard_model.py b/src/graphnet/models/standard_model.py index bef153097..1d77cfa33 100644 --- a/src/graphnet/models/standard_model.py +++ b/src/graphnet/models/standard_model.py @@ -1,6 +1,6 @@ """Standard model class(es).""" -from typing import Any, Dict, List, Optional, Union, Type +from typing import Dict, List, Optional, Union, Type import torch from torch import Tensor from torch_geometric.data import Data @@ -26,7 +26,7 @@ def __init__( self, graph_definition: GraphDefinition, tasks: Union[StandardLearnedTask, List[StandardLearnedTask]], - backbone: Model = None, + backbone: Optional[Model] = None, gnn: Optional[GNN] = None, optimizer_class: Type[torch.optim.Optimizer] = Adam, optimizer_kwargs: Optional[Dict] = None, diff --git a/src/graphnet/models/task/reconstruction.py b/src/graphnet/models/task/reconstruction.py index e9b3cdaa5..5408aa5b9 100644 --- a/src/graphnet/models/task/reconstruction.py +++ b/src/graphnet/models/task/reconstruction.py @@ -50,9 +50,8 @@ class DirectionReconstructionWithKappa(StandardLearnedTask): """Reconstructs direction with kappa from the 3D-vMF distribution.""" # Requires three features: untransformed points in (x,y,z)-space. - default_target_labels = [ "direction" ] # contains dir_x, dir_y, dir_z see https://github.com/graphnet-team/graphnet/blob/95309556cfd46a4046bc4bd7609888aab649e295/src/graphnet/training/labels.py#L29 + default_target_labels = ["direction"] # contains dir_x, dir_y, dir_z + # see Direction label in /src/graphnet/training/labels.py default_prediction_labels = [ "dir_x_pred", "dir_y_pred", @@ -86,7 +85,8 @@ def _forward(self, x: Tensor) -> Tensor: class ZenithReconstructionWithKappa(ZenithReconstruction): """Reconstructs zenith angle and associated kappa (1/var).""" - # Requires one feature in addition to `ZenithReconstruction`: kappa (unceratinty; 1/variance). + # Requires one feature in addition to `ZenithReconstruction`: + # kappa (uncertainty; 1/variance). default_target_labels = ["zenith"] default_prediction_labels = ["zenith_pred", "zenith_kappa"] nb_inputs = 2 @@ -148,7 +148,8 @@ def _forward(self, x: Tensor) -> Tensor: class EnergyReconstructionWithUncertainty(EnergyReconstruction): """Reconstructs energy and associated uncertainty (log(var)).""" - # Requires one feature in addition to `EnergyReconstruction`: log-variance (uncertainty).
+ # Requires one feature in addition to `EnergyReconstruction`: + # log-variance (uncertainty). default_target_labels = ["energy"] default_prediction_labels = ["energy_pred", "energy_sigma"] nb_inputs = 2 diff --git a/src/graphnet/training/callbacks.py b/src/graphnet/training/callbacks.py index c319dccc1..c8eb1f0b3 100644 --- a/src/graphnet/training/callbacks.py +++ b/src/graphnet/training/callbacks.py @@ -6,7 +6,6 @@ import warnings import numpy as np -import torch from tqdm.std import Bar from pytorch_lightning import LightningModule, Trainer @@ -55,7 +54,8 @@ def __init__( raise ValueError("Milestones must be increasing") if len(milestones) != len(factors): raise ValueError( - "Only multiplicative factor must be specified for each milestone." + "Only multiplicative factor must be specified" + " for each milestone." ) self.milestones = milestones diff --git a/src/graphnet/training/utils.py b/src/graphnet/training/utils.py index 4e6f7c64d..b6e705846 100644 --- a/src/graphnet/training/utils.py +++ b/src/graphnet/training/utils.py @@ -79,10 +79,10 @@ def make_dataloader( selection: Optional[List[int]] = None, num_workers: int = 10, persistent_workers: bool = True, - node_truth: List[str] = None, + node_truth: Optional[List[str]] = None, truth_table: str = "truth", node_truth_table: Optional[str] = None, - string_selection: List[int] = None, + string_selection: Optional[List[int]] = None, loss_weight_table: Optional[str] = None, loss_weight_column: Optional[str] = None, index_column: str = "event_no", diff --git a/src/graphnet/training/weight_fitting.py b/src/graphnet/training/weight_fitting.py index e66c2d4c5..9e4b66040 100644 --- a/src/graphnet/training/weight_fitting.py +++ b/src/graphnet/training/weight_fitting.py @@ -39,9 +39,9 @@ def _get_truth( ) -> pd.DataFrame: """Return truth `variable`, optionally only for `selection` events.""" if selection is None: - query = f"select {self._index_column}, {variable} from {self._truth_table}" + query = f"select {self._index_column}, {variable} from {self._truth_table}" # noqa: E501 else: - query = f"select {self._index_column}, {variable} from {self._truth_table} where {self._index_column} in {str(tuple(selection))}" + query = f"select {self._index_column}, {variable} from {self._truth_table} where {self._index_column} in {str(tuple(selection))}" # noqa: E501 with sqlite3.connect(self._database_path) as con: data = pd.read_sql(query, con) return data @@ -160,10 +160,12 @@ def _fit_weights(self, truth: pd.DataFrame) -> pd.DataFrame: # Histogram `truth_values` bin_counts, _ = np.histogram(truth[self._variable], bins=self._bins) - # Get reweighting for each bin to achieve uniformity. (NB: No normalisation applied.) + # Get reweighting for each bin to achieve uniformity. + # (NB: No normalisation applied.) bin_weights = 1.0 / np.where(bin_counts == 0, np.nan, bin_counts) - # For each sample in `truth_values`, get the weight in the corresponding bin + # For each sample in `truth_values`, get the weight in + # the corresponding bin ix = np.digitize(truth[self._variable], bins=self._bins) - 1 sample_weights = bin_weights[ix] sample_weights = sample_weights / sample_weights.mean() @@ -207,10 +209,12 @@ def _fit_weights( # type: ignore[override] # Histogram `truth_values` bin_counts, _ = np.histogram(truth[self._variable], bins=self._bins) - # Get reweighting for each bin to achieve uniformity. (NB: No normalisation applied.) + # Get reweighting for each bin to achieve uniformity. + # (NB: No normalisation applied.) 
bin_weights = 1.0 / np.where(bin_counts == 0, np.nan, bin_counts) - # For each sample in `truth_values`, get the weight in the corresponding bin + # For each sample in `truth_values`, + # get the weight in the corresponding bin ix = np.digitize(truth[self._variable], bins=self._bins) - 1 sample_weights = bin_weights[ix] sample_weights = sample_weights / sample_weights.mean() diff --git a/src/graphnet/utilities/config/base_config.py b/src/graphnet/utilities/config/base_config.py index 6867317e1..0bc826ad7 100644 --- a/src/graphnet/utilities/config/base_config.py +++ b/src/graphnet/utilities/config/base_config.py @@ -1,6 +1,5 @@ """Base config class(es).""" -from abc import abstractmethod from collections import OrderedDict import inspect import sys diff --git a/src/graphnet/utilities/config/configurable.py b/src/graphnet/utilities/config/configurable.py index 3a0976123..e11c0e695 100644 --- a/src/graphnet/utilities/config/configurable.py +++ b/src/graphnet/utilities/config/configurable.py @@ -1,6 +1,6 @@ """Bases for all configurable classes in `graphnet`.""" -from abc import ABC, abstractclassmethod +from abc import ABC, abstractmethod from typing import Any, Union from graphnet.utilities.config.base_config import BaseConfig @@ -34,6 +34,6 @@ def save_config(self, path: str) -> None: """Save Config to `path` as YAML file.""" self.config.dump(path) - @abstractclassmethod + @abstractmethod # type: ignore def from_config(cls, source: Union[BaseConfig, str]) -> Any: """Construct instance from `source` configuration.""" diff --git a/src/graphnet/utilities/imports.py b/src/graphnet/utilities/imports.py index 1c143280a..5ce5e45d0 100644 --- a/src/graphnet/utilities/imports.py +++ b/src/graphnet/utilities/imports.py @@ -56,7 +56,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return test_function(*args, **kwargs) else: Logger(log_folder=None).info( - f"Function `{test_function.__name__}` not used since `icecube` isn't available." + f"Function `{test_function.__name__}` " + "not used since `icecube` isn't available." 
) return diff --git a/tests/data/test_dataconverters_and_datasets.py b/tests/data/test_dataconverters_and_datasets.py index f370dbe17..229fc91df 100644 --- a/tests/data/test_dataconverters_and_datasets.py +++ b/tests/data/test_dataconverters_and_datasets.py @@ -21,15 +21,10 @@ from graphnet.data.dataset import ParquetDataset, SQLiteDataset from graphnet.data.sqlite import SQLiteDataConverter from graphnet.data.utilities.parquet_to_sqlite import ParquetToSQLiteConverter -from graphnet.utilities.imports import has_icecube_package from graphnet.models.graphs import KNNGraph from graphnet.models.graphs.nodes import NodesAsPulses from graphnet.models.detector import IceCubeDeepCore -if has_icecube_package(): - from icecube import dataio # pyright: reportMissingImports=false - - # Global variable(s) TEST_DATA_DIR = os.path.join( graphnet.constants.TEST_DATA_DIR, "i3", "oscNext_genie_level7_v02" @@ -199,7 +194,9 @@ def test_parquet_to_sqlite_converter() -> None: ) dataset_from_parquet = SQLiteDataset(path, **opt) # type: ignore[arg-type] - dataset = SQLiteDataset(get_file_path("sqlite"), **opt) # type: ignore[arg-type] + dataset = SQLiteDataset( + get_file_path("sqlite"), **opt # type: ignore[arg-type] + ) assert len(dataset_from_parquet) == len(dataset) for ix in range(len(dataset)): diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index a3e0a0921..dd2e03abf 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -64,7 +64,8 @@ def dataset_setup(dataset_ref: pytest.FixtureRequest) -> tuple: dataset_ref: The dataset reference. Returns: - A tuple with the dataset reference, dataset kwargs, and dataloader kwargs. + A tuple with the dataset reference, + dataset kwargs, and dataloader kwargs. """ # Grab public dataset paths data_path = ( @@ -127,10 +128,12 @@ def test_single_dataset_without_selections( """Verify GraphNeTDataModule behavior when no test selection is provided. Args: - dataset_setup: Tuple with dataset reference, dataset arguments, and dataloader arguments. + dataset_setup: Tuple with dataset reference, + dataset arguments, and dataloader arguments. Raises: - Exception: If the test dataloader is accessed without providing a test selection. + Exception: If the test dataloader is accessed + without providing a test selection. """ dataset_ref, dataset_kwargs, dataloader_kwargs = dataset_setup @@ -165,8 +168,9 @@ def test_single_dataset_with_selections( """Test that selection functionality of DataModule behaves as expected. Args: - dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple containing the dataset reference, - dataset arguments, and dataloader arguments. + dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple + containing the dataset reference, dataset arguments, + and dataloader arguments. Returns: None @@ -200,7 +204,9 @@ def test_single_dataset_with_selections( if isinstance(dataset_ref, SQLiteDataset): a = len(train_dataloader.dataset) + len(val_dataloader.dataset) assert a == len(train_val_selection) # type: ignore - assert len(test_dataloader.dataset) == len(test_selection) # type: ignore + assert len(test_dataloader.dataset) == len( + test_selection + ) # noqa: E501 # type: ignore elif isinstance(dataset_ref, ParquetDataset): # Parquet dataset selection is batches not events a = train_dataloader.dataset._indices + val_dataloader.dataset._indices @@ -225,8 +231,9 @@ def test_dataloader_args( """Test that arguments to dataloaders are propagated correctly. 
Args: - dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple containing the dataset reference, - dataset keyword arguments, and dataloader keyword arguments. + dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple + containing the dataset reference, dataset keyword arguments, + and dataloader keyword arguments. Returns: None @@ -265,8 +272,9 @@ def test_ensemble_dataset_without_selections( """Test ensemble dataset functionality without selections. Args: - dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple containing the dataset reference, - dataset keyword arguments, and dataloader keyword arguments. + dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple + containing the dataset reference, dataset keyword arguments, + and dataloader keyword arguments. Returns: None @@ -303,8 +311,9 @@ def test_ensemble_dataset_with_selections( """Test ensemble dataset functionality with selections. Args: - dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple containing the dataset reference, - dataset keyword arguments, and dataloader keyword arguments. + dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple + containing the dataset reference, dataset keyword arguments, + and dataloader keyword arguments. Returns: None diff --git a/tests/deployment/queso_test.py b/tests/deployment/queso_test.py index cf761bf52..25bd2435c 100644 --- a/tests/deployment/queso_test.py +++ b/tests/deployment/queso_test.py @@ -165,7 +165,7 @@ def test_deployment() -> None: features = FEATURES.UPGRADE input_folders = [f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998"] output_folder = f"{TEST_DATA_DIR}/output/QUESO_test" - gcd_file = f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/GeoCalibDetectorStatus_ICUpgrade.v58.mixed.V0.i3.bz2" + gcd_file = f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/GeoCalibDetectorStatus_ICUpgrade.v58.mixed.V0.i3.bz2" # noqa: E501 features = FEATURES.UPGRADE input_files = [] for folder in input_folders: @@ -214,11 +214,12 @@ def verify_QUESO_integrity() -> None: original_predictions[frame][model], equal_nan=True, ) - except AssertionError as e: - print( - f"Mismatch found in {model}: {new_predictions[frame][model]} vs. {original_predictions[frame][model]}" + except AssertionError: + raise AssertionError( + f"Mismatch found in {model}: " + f"{new_predictions[frame][model]} vs. 
" + f"{original_predictions[frame][model]}" ) - raise e return diff --git a/tests/models/test_graph_definition.py b/tests/models/test_graph_definition.py index 091cc0bac..f2538fdd7 100644 --- a/tests/models/test_graph_definition.py +++ b/tests/models/test_graph_definition.py @@ -72,7 +72,7 @@ def get_event( with sqlite3.connect(database) as con: query = f"SELECT event_no FROM {truth_table} limit 1" event_no = pd.read_sql(query, con) - query = f'SELECT {query_features} FROM {pulsemap} WHERE event_no = {str(event_no["event_no"][0])}' + query = f'SELECT {query_features} FROM {pulsemap} WHERE event_no = {str(event_no["event_no"][0])}' # noqa: E501 df = pd.read_sql(query, con) return np.array(df) @@ -86,11 +86,11 @@ def test_geometry_tables() -> None: ), "IceCube86": os.path.join( TEST_DATA_DIR, - "sqlite/oscNext_genie_level7_v02/oscNext_genie_level7_v02_first_5_frames.db", + "sqlite/oscNext_genie_level7_v02/oscNext_genie_level7_v02_first_5_frames.db", # noqa: E501 ), "IceCubeUpgrade": os.path.join( TEST_DATA_DIR, - "sqlite/upgrade_genie_step4_140028_000998_first_5_frames/upgrade_genie_step4_140028_000998_first_5_frames.db", + "sqlite/upgrade_genie_step4_140028_000998_first_5_frames/upgrade_genie_step4_140028_000998_first_5_frames.db", # noqa: E501 ), } meta = {