run flake8 --all-files
RasmusOrsoe committed Nov 2, 2024
1 parent ddf8b45 commit bb99a3e
Showing 36 changed files with 150 additions and 118 deletions.
2 changes: 1 addition & 1 deletion examples/01_icetray/03_i3_deployer_example.py
@@ -50,7 +50,7 @@ def main() -> None:
model_config = f"{base_path}/{model_name}/{model_name}_config.yml"
state_dict = f"{base_path}/{model_name}/{model_name}_state_dict.pth"
output_folder = f"{EXAMPLE_OUTPUT_DIR}/i3_deployment/upgrade_03_04"
gcd_file = f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/GeoCalibDetectorStatus_ICUpgrade.v58.mixed.V0.i3.bz2"
gcd_file = f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/GeoCalibDetectorStatus_ICUpgrade.v58.mixed.V0.i3.bz2" # noqa: E501
input_files = []
for folder in input_folders:
input_files.extend(glob(join(folder, "*.i3.gz")))
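Note: the "# noqa: E501" markers added in this commit suppress flake8's line-too-long check (E501) for a single line — used here because the long GCD file path cannot reasonably be wrapped. A minimal sketch of the pattern, with a hypothetical path:

# E501 flags lines longer than the configured maximum length.
# Appending "# noqa: E501" silences only that check, only on this line;
# every other flake8 check still applies.
gcd_file = "/data/tests/i3/upgrade_genie_step4_140028_000998/GeoCalibDetectorStatus_example.i3.bz2"  # noqa: E501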
@@ -86,7 +86,7 @@ def main() -> None:
model_config = f"{base_path}/{model_name}/{model_name}_config.yml"
state_dict = f"{base_path}/{model_name}/{model_name}_state_dict.pth"
output_folder = f"{EXAMPLE_OUTPUT_DIR}/i3_deployment/upgrade"
gcd_file = f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/GeoCalibDetectorStatus_ICUpgrade.v58.mixed.V0.i3.bz2"
gcd_file = f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/GeoCalibDetectorStatus_ICUpgrade.v58.mixed.V0.i3.bz2" # noqa:E501
features = FEATURES.UPGRADE
input_files = []
for folder in input_folders:
1 change: 0 additions & 1 deletion examples/04_training/07_train_normalizing_flow.py
@@ -4,7 +4,6 @@
from typing import Any, Dict, List, Optional

from pytorch_lightning.loggers import WandbLogger
import torch
from torch.optim.adam import Adam

from graphnet.constants import EXAMPLE_DATA_DIR, EXAMPLE_OUTPUT_DIR
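For context: the deleted "import torch" line is presumably an unused import — nothing in the visible script references torch directly, while the Adam optimizer is imported separately — so flake8 reports F401 ("imported but unused") and the usual fix is simply to drop the line. A minimal illustration with hypothetical module contents:

import math            # F401: flagged if math is never referenced below
from glob import glob  # kept, because it is actually used

files = glob("*.i3.gz")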
4 changes: 3 additions & 1 deletion setup.cfg
@@ -18,7 +18,9 @@ omit =
exclude =
versioneer.py
# Ignore unused imports in __init__ files
per-file-ignores=__init__.py:F401
per-file-ignores=
__init__.py:F401
src/graphnet/utilities/imports.py:F401
ignore=E203,W503

[docformatter]
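For context: F401 is flake8's "imported but unused" check. Package __init__.py files (and graphnet's utilities/imports.py) import names purely to re-export them, so the check is disabled for those files via per-file-ignores rather than per-line noqa comments. A minimal sketch of the situation being allowed, with hypothetical module and class names:

# Hypothetical package __init__.py: the import exists only so users can
# write "from package import PublicThing"; flake8 would otherwise flag it
# as F401 because nothing in this file references the name.
from .internal_module import PublicThing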
4 changes: 3 additions & 1 deletion src/graphnet/__init__.py
@@ -32,4 +32,6 @@

from . import _version

__version__ = _version.get_versions()["version"] # type: ignore[no-untyped-call]
__version__ = _version.get_versions()[ # type: ignore[no-untyped-call]
"version"
]
4 changes: 2 additions & 2 deletions src/graphnet/data/curated_datamodule.py
@@ -31,8 +31,8 @@ def __init__(
features: Optional[List[str]] = None,
backend: str = "parquet",
train_dataloader_kwargs: Optional[Dict[str, Any]] = None,
validation_dataloader_kwargs: Dict[str, Any] = None,
test_dataloader_kwargs: Dict[str, Any] = None,
validation_dataloader_kwargs: Optional[Dict[str, Any]] = None,
test_dataloader_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""Construct CuratedDataset.
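The change above makes implicit Optional annotations explicit: a parameter annotated Dict[str, Any] but defaulting to None is really Optional[Dict[str, Any]], and type checkers such as mypy (with implicit Optional disabled) flag the shorter form. A minimal sketch of the pattern, with a hypothetical function name:

from typing import Any, Dict, Optional

# Before: kwargs: Dict[str, Any] = None   (implicit Optional)
# After:  the None default is reflected in the annotation.
def build_dataloader_kwargs(kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    return dict(kwargs or {})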
2 changes: 1 addition & 1 deletion src/graphnet/data/dataset/samplers.py
@@ -34,7 +34,7 @@
)

from collections import defaultdict
from multiprocessing import Pool, cpu_count, get_context
from multiprocessing import get_context

import numpy as np
import torch
6 changes: 3 additions & 3 deletions src/graphnet/data/extractors/icecube/i3extractor.py
@@ -24,9 +24,9 @@ def __init__(self, extractor_name: str):
"""Construct I3Extractor.
Args:
extractor_name: Name of the `I3Extractor` instance. Used to keep track of the
provenance of different data, and to name tables to which this
data is saved.
extractor_name: Name of the `I3Extractor` instance. Used to keep
track of the provenance of different data, and to name tables
to which this data is saved.
"""
# Member variable(s)
self._i3_file: str = ""
19 changes: 11 additions & 8 deletions src/graphnet/data/pre_configured/dataconverters.py
@@ -25,10 +25,12 @@ def __init__(
"""Convert I3 files to Parquet.
Args:
gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is
found in subfolder. `I3Reader` will recursively search
the input directory for I3-GCD file pairs. By IceCube
convention, a folder containing i3 files will have an
gcd_rescue: gcd_rescue: Path to a GCD file that will be used if
no GCD file is found in subfolder. `I3Reader` will
recursively search the input directory for I3-GCD file
pairs.
By IceCube convention,
a folder containing i3 files will have an
accompanying GCD file. However, in some cases, this
convention is broken. In cases where a folder contains
i3 files but no GCD file, the `gcd_rescue` is used
@@ -70,10 +72,11 @@ def __init__(
"""Convert I3 files to SQLite.
Args:
gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is
found in subfolder. `I3Reader` will recursively search
the input directory for I3-GCD file pairs. By IceCube
convention, a folder containing i3 files will have an
gcd_rescue: gcd_rescue: Path to a GCD file that will be used if
no GCD file is found in subfolder. `I3Reader` will
recursively search the input directory for I3-GCD file
pairs. By IceCube convention,
a folder containing i3 files will have an
accompanying GCD file. However, in some cases, this
convention is broken. In cases where a folder contains
i3 files but no GCD file, the `gcd_rescue` is used
8 changes: 6 additions & 2 deletions src/graphnet/data/readers/graphnet_file_reader.py
@@ -125,7 +125,9 @@ def _validate_extractors(
) -> None:
for extractor in extractors:
try:
assert isinstance(extractor, tuple(self.accepted_extractors)) # type: ignore
assert isinstance(
extractor, tuple(self.accepted_extractors) # type: ignore
)
except AssertionError as e:
self.error(
f"{extractor.__class__.__name__}"
@@ -164,5 +166,7 @@ def _validate_file(self, file: str) -> None:
assert file.lower().endswith(tuple(self.accepted_file_extensions))
except AssertionError:
self.error(
f'{self.__class__.__name__} accepts {self.accepted_file_extensions} but {file.split("/")[-1]} has extension {os.path.splitext(file)[1]}.'
f"{self.__class__.__name__} accepts "
f'{self.accepted_file_extensions} but {file.split("/")[-1]} '
f"has extension {os.path.splitext(file)[1]}."
)
8 changes: 5 additions & 3 deletions src/graphnet/data/readers/i3reader.py
@@ -1,6 +1,6 @@
"""Module containing different I3Reader."""

from typing import List, Union, OrderedDict
from typing import List, Union, OrderedDict, Optional

from graphnet.utilities.imports import has_icecube_package
from graphnet.data.extractors.icecube.utilities.i3_filters import (
@@ -27,7 +27,7 @@ class I3Reader(GraphNeTFileReader):
def __init__(
self,
gcd_rescue: str,
i3_filters: Union[I3Filter, List[I3Filter]] = None,
i3_filters: Optional[Union[I3Filter, List[I3Filter]]] = None,
icetray_verbose: int = 0,
):
"""Initialize `I3Reader`.
@@ -65,7 +65,9 @@ def __init__(
# Base class constructor
super().__init__(name=__name__, class_name=self.__class__.__name__)

def __call__(self, file_path: I3FileSet) -> List[OrderedDict]: # type: ignore
def __call__(
self, file_path: I3FileSet
) -> List[OrderedDict]: # noqa: E501 # type: ignore
"""Extract data from single I3 file.
Args:
16 changes: 7 additions & 9 deletions src/graphnet/data/sqlite/deprecated_methods.py
@@ -3,13 +3,10 @@
This code will be removed in GraphNeT 2.0.
"""

from typing import List, Union, Type
from typing import List, Union

from graphnet.data.extractors.icecube import I3Extractor
from graphnet.data.extractors.icecube.utilities.i3_filters import (
I3Filter,
NullSplitI3Filter,
)
from graphnet.data.extractors.icecube.utilities.i3_filters import I3Filter
from graphnet.data import I3ToSQLiteConverter


@@ -28,10 +25,11 @@ def __init__(
"""Convert I3 files to Parquet.
Args:
gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is
found in subfolder. `I3Reader` will recursively search
the input directory for I3-GCD file pairs. By IceCube
convention, a folder containing i3 files will have an
gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no
GCD file is found in subfolder. `I3Reader` will
recursively search the input directory for I3-GCD file
pairs. By IceCube convention,
a folder containing i3 files will have an
accompanying GCD file. However, in some cases, this
convention is broken. In cases where a folder contains
i3 files but no GCD file, the `gcd_rescue` is used
2 changes: 1 addition & 1 deletion src/graphnet/data/utilities/parquet_to_sqlite.py
@@ -1,3 +1,3 @@
"""Utilities for converting files from Parquet to SQLite."""

from graphnet.data.pre_configured import ParquetToSQLiteConverter
from graphnet.data.pre_configured import ParquetToSQLiteConverter # noqa: F401
3 changes: 2 additions & 1 deletion src/graphnet/data/utilities/random.py
@@ -9,7 +9,8 @@ def pairwise_shuffle(
) -> Tuple[List[str], List[str]]:
"""Shuffle the I3 file list and the correponding gcd file list.
This is handy because it ensures a more even extraction load for each worker.
This is handy because it ensures a more even extraction load for each
worker.
Args:
files_list: List of I3 file paths.
10 changes: 6 additions & 4 deletions src/graphnet/data/utilities/sqlite_utilities.py
@@ -1,7 +1,7 @@
"""SQLite-specific utility functions for use in `graphnet.data`."""

import os.path
from typing import List, Dict, Tuple
from typing import List, Dict, Tuple, Union

import pandas as pd
import sqlalchemy
@@ -30,7 +30,9 @@ def query_database(database: str, query: str) -> pd.DataFrame:
return pd.read_sql(query, conn)


def get_primary_keys(database: str) -> Tuple[Dict[str, str], str]:
def get_primary_keys(
database: str,
) -> Tuple[Dict[str, Union[str, None]], Union[str, None]]:
"""Get name of primary key column for each table in database.
Args:
@@ -50,7 +52,7 @@ def get_primary_keys(database: str) -> Tuple[Dict[str, str], str]:

integer_primary_key = {}
for table in table_names:
query = f"SELECT l.name FROM pragma_table_info('{table}') as l WHERE l.pk = 1;"
query = f"SELECT l.name FROM pragma_table_info('{table}') as l WHERE l.pk = 1;" # noqa: E501
first_primary_key = [
key[0] for key in conn.execute(query).fetchall()
]
@@ -78,7 +80,7 @@ def database_table_exists(database_path: str, table_name: str) -> bool:
"""Check whether `table_name` exists in database at `database_path`."""
if not database_exists(database_path):
return False
query = f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}';"
query = f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}';" # noqa: E501
with sqlite3.connect(database_path) as conn:
result = pd.read_sql(query, conn)
return len(result) == 1
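The widened annotation on get_primary_keys records that a table without an integer primary key yields None for its entry, and that the second return value may likewise be None. A minimal sketch of how a caller might guard for that, assuming a hypothetical database path:

from graphnet.data.utilities.sqlite_utilities import get_primary_keys

# primary_keys maps table name -> primary-key column name, or None when
# the table defines no integer primary key; the second return value is
# ignored here.
primary_keys, _ = get_primary_keys("/path/to/events.db")
for table, key in primary_keys.items():
    if key is None:
        print(f"table '{table}' has no integer primary key")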
2 changes: 1 addition & 1 deletion src/graphnet/deployment/i3modules/__init__.py
@@ -5,5 +5,5 @@
detector configurations.
"""

from .deprecated_methods import *
from .deprecated_methods import * # noqa: F403
from graphnet.deployment.icecube import I3InferenceModule, I3PulseCleanerModule
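Here "# noqa: F403" acknowledges a deliberate star import: flake8 reports F403 ("'from module import *' used; unable to detect undefined names"), which is noisy but harmless for a thin compatibility shim that only re-exposes deprecated names. A minimal sketch with a hypothetical module name:

# Compatibility shim: re-expose everything from the old module location.
# F403 warns that star imports obscure which names are defined; the noqa
# marker records that this is intentional.
from .old_module_location import *  # noqa: F403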
12 changes: 7 additions & 5 deletions src/graphnet/models/coarsening.py
@@ -8,7 +8,6 @@
from torch_geometric.data import Data, Batch
from sklearn.cluster import DBSCAN

# from torch_geometric.utils import unbatch_edge_index
from graphnet.models.components.pool import (
group_by,
avg_pool,
@@ -28,7 +27,7 @@

# NOTE: From [https://github.com/pyg-team/pytorch_geometric/pull/4903]
# TODO: Remove once bumping to torch_geometric>=2.1.0
# See [https://github.com/pyg-team/pytorch_geometric/blob/master/CHANGELOG.md]
# See [https://github.com/pyg-team/pytorch_geometric/blob/master/CHANGELOG.md] # noqa: E501


def unbatch_edge_index(edge_index: Tensor, batch: Tensor) -> List[Tensor]:
@@ -170,15 +169,18 @@ def _reconstruct_batch(self, original: Data, pooled: Data) -> Data:
return pooled

def _add_slice_dict(self, original: Data, pooled: Data) -> Data:
# Copy original slice_dict and count nodes in each graph in pooled batch
# Copy original slice_dict and count nodes in each
# graph in pooled batch
slice_dict = deepcopy(original._slice_dict)
_, counts = torch.unique_consecutive(pooled.batch, return_counts=True)
# Reconstruct the entry in slice_dict for pulsemaps - only these are affected by pooling
# Reconstruct the entry in slice_dict for pulsemaps -
# only these are affected by pooling
pulsemap_slice = [0]
for i in range(len(counts)):
pulsemap_slice.append(pulsemap_slice[i] + counts[i].item())

# Identifies pulsemap entries in slice_dict and set them to pulsemap_slice
# Identifies pulsemap entries in slice_dict and
# set them to pulsemap_slice
for field in slice_dict.keys():
if (original._num_graphs) == slice_dict[field][-1]:
pass # not pulsemap, so skip
1 change: 0 additions & 1 deletion src/graphnet/models/components/layers.py
@@ -9,7 +9,6 @@
from torch_geometric.typing import Adj, PairTensor
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.data import Data
import torch.nn as nn
from torch.nn.functional import linear
from torch.nn.modules import TransformerEncoder, TransformerEncoderLayer
18 changes: 9 additions & 9 deletions src/graphnet/models/components/pool.py
@@ -9,11 +9,11 @@
from torch_geometric.nn.pool.pool import pool_edge, pool_batch, pool_pos
from torch_scatter import scatter, scatter_std

from torch_geometric.nn.pool import (
avg_pool,
from torch_geometric.nn.pool import ( # noqa:F401
max_pool,
avg_pool_x,
max_pool_x,
avg_pool_x,
avg_pool,
)


@@ -90,8 +90,8 @@ def group_by(data: Union[Data, Batch], keys: List[str]) -> LongTensor:
This grouping is done with in each event in case of batching. This allows
for, e.g., assigning the same index to all pulses on the same PMT or DOM in
the same event. This can be used for coarsening graphs, e.g., from pulse-
level to DOM-level by aggregating feature across each group returned by this
method.
level to DOM-level by aggregating feature across each group returned by
this method.
Example:
Given:
@@ -140,7 +140,7 @@ def sum_pool_x(
batch: LongTensor,
size: Optional[int] = None,
) -> Tensor:
r"""Sum-pool node features according to the clustering defined in `cluster`.
r"""Sum-pool node features according to the cluster defined in `cluster`.
Args:
cluster: Cluster vector :math:`\mathbf{c} \in \{ 0,
@@ -172,7 +172,7 @@ def std_pool_x(
batch: LongTensor,
size: Optional[int] = None,
) -> Tensor:
r"""Std-pool node features according to the clustering defined in `cluster`.
r"""Std-pool node features according to the cluster defined in `cluster`.
Args:
cluster: Cluster vector :math:`\mathbf{c} \in \{ 0,
@@ -201,7 +201,7 @@ def std_pool_x(
def sum_pool(
cluster: LongTensor, data: Data, transform: Optional[Callable] = None
) -> Data:
r"""Pool and coarsen graph according to the clustering defined in `cluster`.
r"""Pool and coarsen graph according to the cluster defined in `cluster`.
All nodes within the same cluster will be represented as one node.
Final node features are defined by the *sum* of features of all nodes
@@ -235,7 +235,7 @@ def sum_pool(
def std_pool(
cluster: LongTensor, data: Data, transform: Optional[Callable] = None
) -> Data:
r"""Pool and coarsen graph according to the clustering defined in `cluster`.
r"""Pool and coarsen graph according to the cluster defined in `cluster`.
All nodes within the same cluster will be represented as one node.
Final node features are defined by the *std* of features of all nodes
4 changes: 3 additions & 1 deletion src/graphnet/models/detector/detector.py
@@ -66,7 +66,9 @@ def _standardize(
) -> Data:
for idx, feature in enumerate(input_feature_names):
try:
input_features[:, idx] = self.feature_map()[feature]( # type: ignore
input_features[:, idx] = self.feature_map()[
feature
]( # noqa: E501 # type: ignore
input_features[:, idx]
)
except KeyError as e:
(Diffs for the remaining changed files are not shown.)
