
Commit

black --all-files
RasmusOrsoe committed Nov 2, 2024
1 parent 4200a43 commit ddf8b45
Showing 67 changed files with 234 additions and 126 deletions.
11 changes: 5 additions & 6 deletions docker/gnn-benchmarking/apply.py
@@ -1,6 +1,5 @@
"""Script for applying GraphNeTModule in IceTray chain."""


import argparse
from glob import glob
from os import makedirs
@@ -37,9 +36,7 @@ def main(
    # Get GCD file
    gcd_pattern = "GeoCalibDetector"
    gcd_candidates = [p for p in input_files if gcd_pattern in p]
-   assert (
-       len(gcd_candidates) == 1
-   ), f"Did not get exactly one GCD-file candidate in `{dirname(input_files[0])}: {gcd_candidates}"
+   assert len(gcd_candidates) == 1, "Did not get exactly one GCD-file "
    gcd_file = gcd_candidates[0]

# Get all input I3-files
@@ -78,8 +75,10 @@ def main(
"""The main function must get an input folder and output folder!
Args:
input_folder (str): The input folder where i3 files of a given dataset are located.
output_folder (str): The output folder where processed i3 files will be saved.
input_folder (str): The input folder where i3 files of a
given dataset are located.
output_folder (str): The output folder where processed i3
files will be saved.
"""
parser = argparse.ArgumentParser()

6 changes: 3 additions & 3 deletions examples/04_training/01_train_dynedge.py
@@ -70,9 +70,9 @@ def main(
"gpus": gpus,
"max_epochs": max_epochs,
},
"dataset_reference": SQLiteDataset
if path.endswith(".db")
else ParquetDataset,
"dataset_reference": (
SQLiteDataset if path.endswith(".db") else ParquetDataset
),
}

archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_model_without_configs")
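The `dataset_reference` rewrite above (repeated in the two TITO examples below) is Black's standard treatment of a conditional expression that overflows the line limit: the whole expression is wrapped in parentheses instead of breaking before `if`/`else`. A minimal sketch, not part of the commit, of reproducing this with Black's Python API; `SQLiteDataset`, `ParquetDataset`, and `path` come from the diff above, and the line length of 79 is an assumption made here:

import black

src = (
    'config = {"dataset_reference": SQLiteDataset'
    ' if path.endswith(".db") else ParquetDataset}\n'
)
# format_str parses and reprints the snippet; the names need not be defined.
print(black.format_str(src, mode=black.Mode(line_length=79)))
# With a recent Black release the value should come back as a
# parenthesized conditional, matching the hunk above:
#     "dataset_reference": (
#         SQLiteDataset if path.endswith(".db") else ParquetDataset
#     ),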
6 changes: 3 additions & 3 deletions examples/04_training/02_train_tito_model.py
@@ -72,9 +72,9 @@ def main(
"gpus": gpus,
"max_epochs": max_epochs,
},
"dataset_reference": SQLiteDataset
if path.endswith(".db")
else ParquetDataset,
"dataset_reference": (
SQLiteDataset if path.endswith(".db") else ParquetDataset
),
}

graph_definition = KNNGraph(detector=Prometheus())
6 changes: 3 additions & 3 deletions examples/04_training/05_train_RNN_TITO.py
@@ -76,9 +76,9 @@ def main(
"gpus": gpus,
"max_epochs": max_epochs,
},
"dataset_reference": SQLiteDataset
if path.endswith(".db")
else ParquetDataset,
"dataset_reference": (
SQLiteDataset if path.endswith(".db") else ParquetDataset
),
}

graph_definition = KNNGraph(
1 change: 1 addition & 0 deletions examples/05_liquido/01_convert_h5.py
@@ -1,4 +1,5 @@
"""Example of converting H5 files from LiquidO to SQLite and Parquet."""

import os

from graphnet.constants import EXAMPLE_OUTPUT_DIR, TEST_DATA_DIR
1 change: 1 addition & 0 deletions examples/06_prometheus/01_convert_prometheus.py
@@ -1,4 +1,5 @@
"""Example of converting files from Prometheus to SQLite and Parquet."""

import os

from graphnet.constants import EXAMPLE_OUTPUT_DIR, TEST_DATA_DIR
3 changes: 3 additions & 0 deletions setup.cfg
@@ -17,6 +17,9 @@ omit =
[flake8]
exclude =
versioneer.py
+# Ignore unused imports in __init__ files
+per-file-ignores=__init__.py:F401
+ignore=E203,W503

[docformatter]
wrap-summaries = 79
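For context on the three new flake8 lines: F401 flags imports that are never used, but in an `__init__.py` such imports are usually deliberate re-exports, hence the per-file ignore; E203 and W503 are the two pycodestyle rules that conflict with Black's output, and Black's documentation recommends disabling them. A hypothetical `__init__.py` (not from this repo) showing why the exemption is needed:

# mypackage/__init__.py -- hypothetical example
# Both imports are "unused" within this file, but they re-export the
# package's public names: exactly the pattern F401 would otherwise flag.
from .reader import Reader
from .writer import Writer

__all__ = ["Reader", "Writer"]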
1 change: 1 addition & 0 deletions src/graphnet/data/__init__.py
@@ -3,6 +3,7 @@
`graphnet.data` enables converting domain-specific data to industry-standard,
intermediate file formats and reading this data.
"""

from .extractors.icecube.utilities.i3_filters import I3Filter, I3FilterMask
from .dataconverter import DataConverter
from .pre_configured import I3ToParquetConverter
7 changes: 4 additions & 3 deletions src/graphnet/data/datamodule.py
@@ -1,4 +1,5 @@
"""Base `Dataloader` class(es) used in `graphnet`."""

from typing import Dict, Any, Optional, List, Tuple, Union, Type
import pytorch_lightning as pl
from copy import deepcopy
@@ -26,9 +27,9 @@ def __init__(
        dataset_args: Dict[str, Any],
        selection: Optional[Union[List[int], List[List[int]]]] = None,
        test_selection: Optional[Union[List[int], List[List[int]]]] = None,
-       train_dataloader_kwargs: Dict[str, Any] = None,
-       validation_dataloader_kwargs: Dict[str, Any] = None,
-       test_dataloader_kwargs: Dict[str, Any] = None,
+       train_dataloader_kwargs: Optional[Dict[str, Any]] = None,
+       validation_dataloader_kwargs: Optional[Dict[str, Any]] = None,
+       test_dataloader_kwargs: Optional[Dict[str, Any]] = None,
        train_val_split: Optional[List[float]] = [0.9, 0.10],
        split_seed: int = 42,
    ) -> None:
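The `Optional[...]` additions above are a type-correctness fix, not just formatting: since PEP 484 dropped implicit Optional, an annotation like `Dict[str, Any] = None` declares a default the annotation itself disallows, and mypy rejects it. A minimal sketch (function names hypothetical):

from typing import Any, Dict, Optional


def make_loader(kwargs: Dict[str, Any] = None) -> None:
    # mypy: Incompatible default for argument "kwargs"
    ...


def make_loader_ok(kwargs: Optional[Dict[str, Any]] = None) -> None:
    # Accepted: the annotation now admits None.
    ...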
1 change: 1 addition & 0 deletions src/graphnet/data/dataset/__init__.py
@@ -1,4 +1,5 @@
"""Dataset classes for training in GraphNeT."""

# Configuration
from graphnet.utilities.imports import has_torch_package

1 change: 1 addition & 0 deletions src/graphnet/data/dataset/parquet/__init__.py
@@ -1,4 +1,5 @@
"""Datasets using parquet backend."""

# Configuration
from graphnet.utilities.imports import has_torch_package

1 change: 1 addition & 0 deletions src/graphnet/data/dataset/sqlite/__init__.py
@@ -1,4 +1,5 @@
"""Datasets using SQLite backend."""

from graphnet.utilities.imports import has_torch_package

if has_torch_package():
1 change: 1 addition & 0 deletions src/graphnet/data/extractors/__init__.py
@@ -1,3 +1,4 @@
"""Module containing data-specific extractor modules."""

from .extractor import Extractor
from .combine_extractors import CombinedExtractor
7 changes: 6 additions & 1 deletion src/graphnet/data/extractors/combine_extractors.py
@@ -1,4 +1,5 @@
"""Module for combining multiple extractors into a single extractor."""

from typing import TYPE_CHECKING

from graphnet.utilities.imports import has_icecube_package
@@ -20,7 +21,11 @@ def __init__(self, extractors: List[I3Extractor], extractor_name: str):
"""Construct CombinedExtractor.
Args:
extractors: List of extractors to combine. The extractors must all return data on the same level; e.g. all event-level data or pulse-level data. Mixing tables that contain event-level and pulse-level information will fail.
extractors: List of extractors to combine.
The extractors must all return data on the same level;
e.g. all event-level data or pulse-level data.
Mixing tables that contain event-level and
pulse-level information will fail.
extractor_name: Name of the new extractor.
"""
super().__init__(extractor_name=extractor_name)
8 changes: 5 additions & 3 deletions src/graphnet/data/extractors/extractor.py
@@ -1,4 +1,5 @@
"""Base I3Extractor class(es)."""

from typing import Any, Union
from abc import ABC, abstractmethod
import pandas as pd
@@ -26,9 +27,10 @@ def __init__(self, extractor_name: str):
"""Construct Extractor.
Args:
extractor_name: Name of the `Extractor` instance. Used to keep track of the
provenance of different data, and to name tables to which this
data is saved. E.g. "mc_truth".
extractor_name: Name of the `Extractor` instance.
Used to keep track of the provenance of different
data, and to name tables to which this data is
saved. E.g. "mc_truth".
"""
# Member variable(s)
self._extractor_name: str = extractor_name
28 changes: 17 additions & 11 deletions src/graphnet/data/extractors/icecube/i3truthextractor.py
@@ -121,9 +121,10 @@ def __call__(
"L7_oscNext_bool": padding_value,
}

# Only InIceSplit P frames contain ML appropriate I3RecoPulseSeriesMap etc.
# At low levels i3files contain several other P frame splits (e.g NullSplit),
# we remove those here.
# Only InIceSplit P frames contain ML appropriate
# for example I3RecoPulseSeriesMap, etc.
# At low levels i3 files contain several other P frame splits
# (e.g NullSplit). We remove those here.
if frame["I3EventHeader"].sub_event_stream not in [
"InIceSplit",
"Final",
@@ -181,7 +182,10 @@ def __call__(
                energy_cascade,
                inelasticity,
            ) = self._get_primary_track_energy_and_inelasticity(frame)
-       except RuntimeError:  # track energy fails on northeren tracks with ""Hadrons" has no mass implemented. Cannot get total energy."
+       except (
+           RuntimeError
+       ):  # track energy fails on northeren tracks with ""Hadrons"
+           # has no mass implemented. Cannot get total energy."
            energy_track, energy_cascade, inelasticity = (
                padding_value,
                padding_value,
@@ -216,9 +220,10 @@ def __call__(
        muon_final = self._muon_stopped(output, self._borders)
        output.update(
            {
-               "position_x": muon_final[
-                   "x"
-               ],  # position_xyz has no meaning for muons. These will now be updated to muon final position, given track length/azimuth/zenith
+               "position_x": muon_final["x"],
+               # position_xyz has no meaning for muons.
+               # These will now be updated to muon final position,
+               # given track length/azimuth/zenith
                "position_y": muon_final["y"],
                "position_z": muon_final["z"],
                "stopped_muon": muon_final["stopped"],
@@ -362,10 +367,11 @@ def _get_primary_particle_interaction_type_and_elasticity(
        MCInIcePrimary = frame[self._mctree][0]
        if (
            MCInIcePrimary.energy != MCInIcePrimary.energy
-       ):  # This is a nan check. Only happens for some muons where second item in MCTree is primary. Weird!
-           MCInIcePrimary = frame[self._mctree][
-               1
-           ]  # For some strange reason the second entry is identical in all variables and has no nans (always muon)
+       ):  # This is a nan check. Only happens for some muons
+           # where second item in MCTree is primary. Weird!
+           MCInIcePrimary = frame[self._mctree][1]
+           # For some strange reason the second entry is identical in
+           # all variables and has no nans (always muon)
        else:
            MCInIcePrimary = None
        try:
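A note on the condition being reformatted in the last hunk: `MCInIcePrimary.energy != MCInIcePrimary.energy` works as a NaN check because IEEE 754 makes NaN the only value that compares unequal to itself:

import math

x = float("nan")
print(x != x)          # True -- only NaN is unequal to itself
print(math.isnan(x))   # True -- the more explicit spelling
print(1.0 != 1.0)      # False for any ordinary number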
1 change: 1 addition & 0 deletions src/graphnet/data/extractors/internal/parquet_extractor.py
@@ -1,4 +1,5 @@
"""Parquet Extractor for conversion from internal parquet format."""

import polars as pol
import pandas as pd

1 change: 1 addition & 0 deletions src/graphnet/data/extractors/liquido/__init__.py
@@ -1,2 +1,3 @@
"""Module containing different extractors for LiquidO files."""

from .h5_extractor import H5Extractor, H5HitExtractor, H5TruthExtractor
1 change: 1 addition & 0 deletions src/graphnet/data/extractors/liquido/h5_extractor.py
@@ -1,4 +1,5 @@
"""H5 Extractor for LiquidO data files."""

from typing import List
import numpy as np
import pandas as pd
1 change: 1 addition & 0 deletions src/graphnet/data/extractors/prometheus/__init__.py
@@ -1,4 +1,5 @@
"""Extractors for extracting data from parquet files Prometheus."""

from .prometheus_extractor import (
PrometheusExtractor,
PrometheusTruthExtractor,
@@ -1,4 +1,5 @@
"""Parquet Extractor for conversion of simulation files from PROMETHEUS."""

from typing import List
import pandas as pd
import numpy as np
1 change: 1 addition & 0 deletions src/graphnet/data/parquet/__init__.py
@@ -1,2 +1,3 @@
"""Module for deprecated parquet methods."""

from .deprecated_methods import ParquetDataConverter
6 changes: 4 additions & 2 deletions src/graphnet/data/parquet/deprecated_methods.py
@@ -2,6 +2,7 @@
This code will be removed in GraphNeT 2.0.
"""

from typing import List, Union

from graphnet.data.extractors.icecube import I3Extractor
@@ -26,8 +27,9 @@ def __init__(
"""Convert I3 files to Parquet.
Args:
gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is
found in subfolder. `I3Reader` will recursively search
gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no
GCD file is found in subfolder.
`I3Reader` will recursively search
the input directory for I3-GCD file pairs. By IceCube
convention, a folder containing i3 files will have an
accompanying GCD file. However, in some cases, this
1 change: 1 addition & 0 deletions src/graphnet/data/pre_configured/__init__.py
@@ -1,4 +1,5 @@
"""Module for pre-configured converter modules."""

from .dataconverters import (
I3ToParquetConverter,
I3ToSQLiteConverter,
1 change: 1 addition & 0 deletions src/graphnet/data/readers/__init__.py
@@ -1,4 +1,5 @@
"""Modules for reading experiment-specific data and applying Extractors."""

from .graphnet_file_reader import GraphNeTFileReader
from .i3reader import I3Reader
from .internal_parquet_reader import ParquetReader
1 change: 1 addition & 0 deletions src/graphnet/data/sqlite/__init__.py
@@ -1,2 +1,3 @@
"""Module for deprecated methods using sqlite."""

from .deprecated_methods import SQLiteDataConverter
1 change: 1 addition & 0 deletions src/graphnet/data/utilities/__init__.py
@@ -1,4 +1,5 @@
"""Utilities for use across `graphnet.data`."""

from .sqlite_utilities import create_table_and_save_to_sql
from .sqlite_utilities import get_primary_keys
from .sqlite_utilities import query_database
1 change: 1 addition & 0 deletions src/graphnet/data/writers/__init__.py
@@ -1,4 +1,5 @@
"""Modules for saving interim dataformat to various data backends."""

from .graphnet_writer import GraphNeTWriter
from .parquet_writer import ParquetWriter
from .sqlite_writer import SQLiteWriter
1 change: 1 addition & 0 deletions src/graphnet/datasets/__init__.py
@@ -1,3 +1,4 @@
"""Contains pre-converted datasets ready for training."""

from .test_dataset import TestDataset
from .prometheus_datasets import TRIDENTSmall, BaikalGVDSmall, PONESmall
1 change: 1 addition & 0 deletions src/graphnet/datasets/prometheus_datasets.py
@@ -1,4 +1,5 @@
"""Public datasets from Prometheus Simulation."""

from typing import Dict, Any, List, Tuple, Union
import os
from sklearn.model_selection import train_test_split
1 change: 1 addition & 0 deletions src/graphnet/datasets/test_dataset.py
@@ -1,4 +1,5 @@
"""A CuratedDataset for unit tests."""

from typing import Dict, Any, List, Tuple, Union
import os

1 change: 1 addition & 0 deletions src/graphnet/deployment/__init__.py
@@ -3,5 +3,6 @@
`graphnet.deployment` allows for using trained models for inference in domain-
specific reconstruction chains.
"""

from .deployer import Deployer
from .deployment_module import DeploymentModule
3 changes: 2 additions & 1 deletion src/graphnet/deployment/deployer.py
@@ -1,4 +1,5 @@
"""Contains the graphnet deployment module."""

import random
from abc import abstractmethod, ABC
import multiprocessing
@@ -9,7 +10,7 @@
from .deployment_module import DeploymentModule
from graphnet.utilities.logging import Logger

-if has_torch_package or TYPE_CHECKING:
+if has_torch_package() or TYPE_CHECKING:
import torch


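The two added parentheses above close a real bug: without the call, `has_torch_package` is a function object, which is always truthy, so the guard never prevented the `torch` import. A stand-in illustration (the real `has_torch_package` lives in `graphnet.utilities.imports`):

def has_torch_package() -> bool:
    return False  # stand-in: pretend torch is not installed


if has_torch_package:  # BUG: tests the function object itself, always True
    print("always runs")

if has_torch_package():  # Correct: tests the returned bool
    print("never runs here")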
1 change: 1 addition & 0 deletions src/graphnet/deployment/deployment_module.py
@@ -1,4 +1,5 @@
"""Class(es) for deploying GraphNeT models in icetray as I3Modules."""

from abc import abstractmethod
from typing import Any, List, Union, Dict

1 change: 1 addition & 0 deletions src/graphnet/deployment/i3modules/deprecated_methods.py
@@ -1,4 +1,5 @@
"""Contains deprecated methods."""

from typing import Union, Sequence

# from graphnet.deployment.icecube import I3Deployer, I3InferenceModule
1 change: 1 addition & 0 deletions src/graphnet/deployment/icecube/__init__.py
@@ -1,4 +1,5 @@
"""Deployment modules specific to IceCube."""

from .inference_module import I3InferenceModule
from .cleaning_module import I3PulseCleanerModule
from .i3deployer import I3Deployer
(The remaining changed files of the 67 are not shown here.)
