Commit af85b89

Merge branch 'main' into train_example_comments

Aske-Rosted committed Nov 6, 2024
2 parents 3cde04d + 8903e35
Showing 111 changed files with 454 additions and 310 deletions.
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
@@ -1,28 +1,28 @@
 exclude: '^(versioneer.py|src/graphnet/_version.py|docs/)'
 repos:
 - repo: https://github.com/psf/black
-  rev: 22.3.0
+  rev: 24.10.0
   hooks:
   - id: black
     language_version: python3
     args: [--config=black.toml]
 - repo: https://github.com/pycqa/flake8
-  rev: 4.0.1
+  rev: 7.1.1
   hooks:
   - id: flake8
     language_version: python3
 - repo: https://github.com/pycqa/docformatter
-  rev: v1.5.0
+  rev: v1.7.5
   hooks:
   - id: docformatter
     language_version: python3
 - repo: https://github.com/pycqa/pydocstyle
-  rev: 6.1.1
+  rev: 6.3.0
   hooks:
   - id: pydocstyle
     language_version: python3
 - repo: https://github.com/pre-commit/mirrors-mypy
-  rev: v0.982
+  rev: v1.13.0
   hooks:
   - id: mypy
     args: [--follow-imports=silent, --disallow-untyped-defs, --disallow-incomplete-defs, --disallow-untyped-calls]
11 changes: 5 additions & 6 deletions docker/gnn-benchmarking/apply.py
@@ -1,6 +1,5 @@
 """Script for applying GraphNeTModule in IceTray chain."""

-
 import argparse
 from glob import glob
 from os import makedirs
@@ -37,9 +36,7 @@ def main(
     # Get GCD file
     gcd_pattern = "GeoCalibDetector"
     gcd_candidates = [p for p in input_files if gcd_pattern in p]
-    assert (
-        len(gcd_candidates) == 1
-    ), f"Did not get exactly one GCD-file candidate in `{dirname(input_files[0])}: {gcd_candidates}"
+    assert len(gcd_candidates) == 1, "Did not get exactly one GCD-file "
     gcd_file = gcd_candidates[0]

     # Get all input I3-files
@@ -78,8 +75,10 @@ def main(
     """The main function must get an input folder and output folder!

     Args:
-        input_folder (str): The input folder where i3 files of a given dataset are located.
-        output_folder (str): The output folder where processed i3 files will be saved.
+        input_folder (str): The input folder where i3 files of a
+            given dataset are located.
+        output_folder (str): The output folder where processed i3
+            files will be saved.
     """
     parser = argparse.ArgumentParser()

2 changes: 1 addition & 1 deletion docs/source/models/models.rst
@@ -303,7 +303,7 @@ Below is a snippet that defines a :code:`Model` that reconstructs the zenith angle
         nb_nearest_neighbours=8,
     )
     backbone = DynEdge(
-        nb_inputs=detector.nb_outputs,
+        nb_inputs=graph_definition.nb_outputs,
         global_pooling_schemes=["min", "max", "mean"],
     )
     task = ZenithReconstructionWithKappa(
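The docs fix above matters because the backbone consumes the node features produced by the graph definition, which can differ in number from the detector's raw outputs. A minimal sketch of the corrected wiring, assuming the usual graphnet import paths and a Prometheus detector for concreteness:

from graphnet.models.detector.prometheus import Prometheus
from graphnet.models.gnn import DynEdge
from graphnet.models.graphs import KNNGraph

graph_definition = KNNGraph(
    detector=Prometheus(),
    nb_nearest_neighbours=8,
)
backbone = DynEdge(
    # The graph definition may add engineered node features on top of the
    # detector outputs, so graph_definition.nb_outputs is the right size.
    nb_inputs=graph_definition.nb_outputs,
    global_pooling_schemes=["min", "max", "mean"],
)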
2 changes: 1 addition & 1 deletion examples/01_icetray/03_i3_deployer_example.py
@@ -50,7 +50,7 @@ def main() -> None:
     model_config = f"{base_path}/{model_name}/{model_name}_config.yml"
     state_dict = f"{base_path}/{model_name}/{model_name}_state_dict.pth"
     output_folder = f"{EXAMPLE_OUTPUT_DIR}/i3_deployment/upgrade_03_04"
-    gcd_file = f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/GeoCalibDetectorStatus_ICUpgrade.v58.mixed.V0.i3.bz2"
+    gcd_file = f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/GeoCalibDetectorStatus_ICUpgrade.v58.mixed.V0.i3.bz2"  # noqa: E501
     input_files = []
     for folder in input_folders:
         input_files.extend(glob(join(folder, "*.i3.gz")))
@@ -86,7 +86,7 @@ def main() -> None:
     model_config = f"{base_path}/{model_name}/{model_name}_config.yml"
     state_dict = f"{base_path}/{model_name}/{model_name}_state_dict.pth"
     output_folder = f"{EXAMPLE_OUTPUT_DIR}/i3_deployment/upgrade"
-    gcd_file = f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/GeoCalibDetectorStatus_ICUpgrade.v58.mixed.V0.i3.bz2"
+    gcd_file = f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/GeoCalibDetectorStatus_ICUpgrade.v58.mixed.V0.i3.bz2"  # noqa:E501
     features = FEATURES.UPGRADE
     input_files = []
     for folder in input_folders:
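Both hunks silence flake8's line-length check (E501) instead of wrapping the string, since splitting an f-string path across lines is easy to get wrong. The trailing directive applies to that single line only; illustratively, with a shortened hypothetical path:

# noqa suppresses the named flake8 code on this one line only.
gcd_file = "/data/i3/GeoCalibDetectorStatus_ICUpgrade.v58.mixed.V0.i3.bz2"  # noqa: E501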
6 changes: 3 additions & 3 deletions examples/04_training/01_train_dynedge.py
@@ -70,9 +70,9 @@ def main(
             "gpus": gpus,
             "max_epochs": max_epochs,
         },
-        "dataset_reference": SQLiteDataset
-        if path.endswith(".db")
-        else ParquetDataset,
+        "dataset_reference": (
+            SQLiteDataset if path.endswith(".db") else ParquetDataset
+        ),
     }

     archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_model_without_configs")
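This is the style the pinned Black 24.x rev applies to long conditional expressions: wrapped in parentheses rather than split bare across lines. The same reformat recurs in the TITO, RNN-TITO, and icemix examples below. The expression itself dispatches the dataset backend on the file extension; a self-contained sketch (hypothetical path, imports assumed to be re-exported from graphnet.data.dataset):

from typing import Any, Dict

from graphnet.data.dataset import ParquetDataset, SQLiteDataset

path = "/data/prometheus/events.db"  # hypothetical training input

config: Dict[str, Any] = {
    # Parentheses let Black wrap the conditional expression cleanly.
    "dataset_reference": (
        SQLiteDataset if path.endswith(".db") else ParquetDataset
    ),
}

# Both backends accept the same constructor kwargs, so the chosen class
# can be instantiated later without caring which one it is.
dataset_class = config["dataset_reference"]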
6 changes: 3 additions & 3 deletions examples/04_training/02_train_tito_model.py
@@ -72,9 +72,9 @@ def main(
             "gpus": gpus,
             "max_epochs": max_epochs,
         },
-        "dataset_reference": SQLiteDataset
-        if path.endswith(".db")
-        else ParquetDataset,
+        "dataset_reference": (
+            SQLiteDataset if path.endswith(".db") else ParquetDataset
+        ),
     }

     graph_definition = KNNGraph(detector=Prometheus())
6 changes: 3 additions & 3 deletions examples/04_training/05_train_RNN_TITO.py
@@ -76,9 +76,9 @@ def main(
             "gpus": gpus,
             "max_epochs": max_epochs,
         },
-        "dataset_reference": SQLiteDataset
-        if path.endswith(".db")
-        else ParquetDataset,
+        "dataset_reference": (
+            SQLiteDataset if path.endswith(".db") else ParquetDataset
+        ),
     }

     graph_definition = KNNGraph(
9 changes: 5 additions & 4 deletions examples/04_training/06_train_icemix_model.py
@@ -1,7 +1,8 @@
 """Example of training Model.

 This example is based on Icemix solution proposed in
-https://github.com/DrHB/icecube-2nd-place.git (2nd place solution).
+https://github.com/DrHB/icecube-2nd-place.git
+(2nd place solution).
 """

 import os
@@ -78,9 +79,9 @@ def main(
             "max_epochs": max_epochs,
             "distribution_strategy": "ddp_find_unused_parameters_true",
         },
-        "dataset_reference": SQLiteDataset
-        if path.endswith(".db")
-        else ParquetDataset,
+        "dataset_reference": (
+            SQLiteDataset if path.endswith(".db") else ParquetDataset
+        ),
     }

     graph_definition = KNNGraph(
1 change: 0 additions & 1 deletion examples/04_training/07_train_normalizing_flow.py
@@ -4,7 +4,6 @@
 from typing import Any, Dict, List, Optional

 from pytorch_lightning.loggers import WandbLogger
-import torch
 from torch.optim.adam import Adam

 from graphnet.constants import EXAMPLE_DATA_DIR, EXAMPLE_OUTPUT_DIR
1 change: 1 addition & 0 deletions examples/05_liquido/01_convert_h5.py
@@ -1,4 +1,5 @@
 """Example of converting H5 files from LiquidO to SQLite and Parquet."""
+
 import os

 from graphnet.constants import EXAMPLE_OUTPUT_DIR, TEST_DATA_DIR
1 change: 1 addition & 0 deletions examples/06_prometheus/01_convert_prometheus.py
@@ -1,4 +1,5 @@
 """Example of converting files from Prometheus to SQLite and Parquet."""
+
 import os

 from graphnet.constants import EXAMPLE_OUTPUT_DIR, TEST_DATA_DIR
5 changes: 5 additions & 0 deletions setup.cfg
@@ -17,6 +17,11 @@ omit =
 [flake8]
 exclude =
     versioneer.py
+# Ignore unused imports in __init__ files
+per-file-ignores=
+    __init__.py:F401
+    src/graphnet/utilities/imports.py:F401
 ignore=E203,W503

 [docformatter]
 wrap-summaries = 79
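The per-file-ignores entry lets package __init__.py files re-export names without inline suppressions; F401 is flake8's unused-import code, while E203 and W503 are the standard exclusions for Black-formatted code. A sketch of the re-export pattern this unlocks, in a hypothetical package:

# mypkg/__init__.py (hypothetical package layout)
"""Package namespace that re-exports its public classes."""

# flake8 would flag this as F401 ("imported but unused") because the
# name exists only to define the public API; the config entry above
# suppresses that check for __init__.py files.
from mypkg.extractor import Extractor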
4 changes: 2 additions & 2 deletions setup.py
@@ -1,4 +1,4 @@
-# type: ignore[no-untyped-call]
+# mypy: disable-error-code="no-untyped-call"
 """Setup script for the GraphNeT package."""

 from setuptools import setup, find_packages
@@ -39,7 +39,7 @@
         "MarkupSafe<=2.1",
         "mypy",
         "myst-parser",
-        "pre-commit",
+        "pre-commit<4.0",
         "pydocstyle",
         "pylint",
         "pytest",
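The first hunk swaps the top-of-file type-ignore comment for mypy's explicit inline configuration syntax; both silence no-untyped-call errors in the module, but the # mypy: form is the documented way to scope an error code to a whole file. A minimal sketch of the two forms, assuming versioneer's usual untyped helpers:

# mypy: disable-error-code="no-untyped-call"
"""File-level form: the directive above covers every call in this file."""

import versioneer

version = versioneer.get_version()

# Line-level form: silences a single call site instead.
cmdclass = versioneer.get_cmdclass()  # type: ignore[no-untyped-call]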
4 changes: 3 additions & 1 deletion src/graphnet/__init__.py
@@ -32,4 +32,6 @@

 from . import _version

-__version__ = _version.get_versions()["version"]  # type: ignore[no-untyped-call]
+__version__ = _version.get_versions()[  # type: ignore[no-untyped-call]
+    "version"
+]
1 change: 1 addition & 0 deletions src/graphnet/data/__init__.py
@@ -3,6 +3,7 @@

 `graphnet.data` enables converting domain-specific data to industry-standard,
 intermediate file formats and reading this data.
 """
+
 from .extractors.icecube.utilities.i3_filters import I3Filter, I3FilterMask
 from .dataconverter import DataConverter
 from .pre_configured import I3ToParquetConverter
4 changes: 2 additions & 2 deletions src/graphnet/data/curated_datamodule.py
@@ -31,8 +31,8 @@ def __init__(
         features: Optional[List[str]] = None,
         backend: str = "parquet",
         train_dataloader_kwargs: Optional[Dict[str, Any]] = None,
-        validation_dataloader_kwargs: Dict[str, Any] = None,
-        test_dataloader_kwargs: Dict[str, Any] = None,
+        validation_dataloader_kwargs: Optional[Dict[str, Any]] = None,
+        test_dataloader_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
         """Construct CuratedDataset.
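This hunk (and the matching one in datamodule.py below) fixes implicit-Optional annotations: a parameter annotated Dict[str, Any] cannot default to None under mypy's strict settings, since no-implicit-optional became the default behaviour in mypy 0.990. A minimal sketch of the corrected pattern, with a hypothetical helper:

from typing import Any, Dict, Optional


def resolve_dataloader_kwargs(
    overrides: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Return dataloader kwargs, filling in defaults when None is given."""
    kwargs: Dict[str, Any] = {"batch_size": 32, "num_workers": 4}
    if overrides is not None:
        kwargs.update(overrides)
    return kwargs


print(resolve_dataloader_kwargs({"batch_size": 128}))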
25 changes: 14 additions & 11 deletions src/graphnet/data/dataconverter.py
@@ -1,6 +1,6 @@
 """Contains `DataConverter`."""

-from typing import List, Union, OrderedDict, Dict, Tuple, Any, Optional, Type
+from typing import List, Union, OrderedDict, Dict, Tuple, Any, Optional
 from abc import ABC

 from tqdm import tqdm
@@ -28,7 +28,8 @@
 def init_global_index(index: Synchronized, output_files: List[str]) -> None:
     """Make `global_index` available to pool workers."""
     global global_index, global_output_files  # type: ignore[name-defined]
-    global_index, global_output_files = (index, output_files)  # type: ignore[name-defined]
+    global_index = index  # type: ignore[name-defined]
+    global_output_files = output_files  # type: ignore[name-defined]


 class DataConverter(ABC, Logger):
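Splitting the tuple assignment gives each statement its own type-ignore comment, which keeps both lines under the length limit once Black 24 reformats the file. The surrounding code is the standard multiprocessing initializer idiom, in which a shared value is handed to each pool worker and stashed in a module-level global; a runnable sketch of that idiom under hypothetical names:

from multiprocessing import Pool, Value
from multiprocessing.sharedctypes import Synchronized


def init_worker(index: Synchronized) -> None:
    """Make the shared counter available as a module-level global."""
    global global_index
    global_index = index


def process(path: str) -> str:
    """Claim the next chunk number for this worker's file."""
    with global_index.get_lock():
        global_index.value += 1
        claimed = global_index.value
    return f"{path}: chunk {claimed}"


if __name__ == "__main__":
    index = Value("i", 0)
    with Pool(2, initializer=init_worker, initargs=(index,)) as pool:
        print(pool.map(process, ["a.i3", "b.i3"]))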
@@ -116,10 +117,9 @@ def _launch_jobs(
     ) -> None:
         """Multi Processing Logic.

-        Spawns worker pool,
-        distributes the input files evenly across workers.
-        declare event_no as globally accessible variable across workers.
-        starts jobs.
+        Spawns worker pool, distributes the input files evenly across workers.
+        declare event_no as globally accessible variable across workers. starts
+        jobs.

         Will call process_file in parallel.
         """
@@ -138,8 +138,8 @@ def _launch_jobs(
     def _process_file(self, file_path: Union[str, I3FileSet]) -> None:
         """Process a single file.

-        Calls file reader to receive extracted output, event ids
-        is assigned to the extracted data and is handed to save method.
+        Calls file reader to receive extracted output, event ids is assigned to
+        the extracted data and is handed to save method.

         This function is called in parallel.
         """
@@ -247,7 +247,8 @@ def _count_rows(
                 n_rows = 1
         except ValueError as e:
             self.error(
-                f"Features from {extractor_name} ({extractor_dict.keys()}) have different lengths."
+                f"Features from {extractor_name} ({extractor_dict.keys()}) "
+                "have different lengths."
             )
             raise e
         return n_rows
@@ -276,7 +277,8 @@ def get_map_function(
         n_workers = min(self._num_workers, nb_files)
         if n_workers > 1:
             self.info(
-                f"Starting pool of {n_workers} workers to process {nb_files} {unit}"
+                f"Starting pool of {n_workers} workers to process"
+                f" {nb_files} {unit}"
            )

             manager = Manager()
@@ -292,7 +294,8 @@ def get_map_function(

         else:
             self.info(
-                f"Processing {nb_files} {unit} in main thread (not multiprocessing)"
+                f"Processing {nb_files} {unit} in main thread"
+                " (not multiprocessing)"
             )
             map_fn = map  # type: ignore
             pool = None
7 changes: 4 additions & 3 deletions src/graphnet/data/datamodule.py
@@ -1,4 +1,5 @@
 """Base `Dataloader` class(es) used in `graphnet`."""
+
 from typing import Dict, Any, Optional, List, Tuple, Union, Type
 import pytorch_lightning as pl
 from copy import deepcopy
@@ -26,9 +27,9 @@ def __init__(
         dataset_args: Dict[str, Any],
         selection: Optional[Union[List[int], List[List[int]]]] = None,
         test_selection: Optional[Union[List[int], List[List[int]]]] = None,
-        train_dataloader_kwargs: Dict[str, Any] = None,
-        validation_dataloader_kwargs: Dict[str, Any] = None,
-        test_dataloader_kwargs: Dict[str, Any] = None,
+        train_dataloader_kwargs: Optional[Dict[str, Any]] = None,
+        validation_dataloader_kwargs: Optional[Dict[str, Any]] = None,
+        test_dataloader_kwargs: Optional[Dict[str, Any]] = None,
         train_val_split: Optional[List[float]] = [0.9, 0.10],
         split_seed: int = 42,
     ) -> None:
1 change: 1 addition & 0 deletions src/graphnet/data/dataset/__init__.py
@@ -1,4 +1,5 @@
 """Dataset classes for training in GraphNeT."""
+
 # Configuration
 from graphnet.utilities.imports import has_torch_package
1 change: 1 addition & 0 deletions src/graphnet/data/dataset/parquet/__init__.py
@@ -1,4 +1,5 @@
 """Datasets using parquet backend."""
+
 # Configuration
 from graphnet.utilities.imports import has_torch_package
2 changes: 1 addition & 1 deletion src/graphnet/data/dataset/samplers.py
@@ -34,7 +34,7 @@
 )

 from collections import defaultdict
-from multiprocessing import Pool, cpu_count, get_context
+from multiprocessing import get_context

 import numpy as np
 import torch
1 change: 1 addition & 0 deletions src/graphnet/data/dataset/sqlite/__init__.py
@@ -1,4 +1,5 @@
 """Datasets using SQLite backend."""
+
 from graphnet.utilities.imports import has_torch_package

 if has_torch_package():
1 change: 1 addition & 0 deletions src/graphnet/data/extractors/__init__.py
@@ -1,3 +1,4 @@
 """Module containing data-specific extractor modules."""
+
 from .extractor import Extractor
 from .combine_extractors import CombinedExtractor
7 changes: 6 additions & 1 deletion src/graphnet/data/extractors/combine_extractors.py
@@ -1,4 +1,5 @@
 """Module for combining multiple extractors into a single extractor."""
+
 from typing import TYPE_CHECKING

 from graphnet.utilities.imports import has_icecube_package
@@ -20,7 +21,11 @@ def __init__(self, extractors: List[I3Extractor], extractor_name: str):
         """Construct CombinedExtractor.

         Args:
-            extractors: List of extractors to combine. The extractors must all return data on the same level; e.g. all event-level data or pulse-level data. Mixing tables that contain event-level and pulse-level information will fail.
+            extractors: List of extractors to combine.
+                The extractors must all return data on the same level;
+                e.g. all event-level data or pulse-level data.
+                Mixing tables that contain event-level and
+                pulse-level information will fail.
             extractor_name: Name of the new extractor.
         """
         super().__init__(extractor_name=extractor_name)
8 changes: 5 additions & 3 deletions src/graphnet/data/extractors/extractor.py
@@ -1,4 +1,5 @@
 """Base I3Extractor class(es)."""
+
 from typing import Any, Union
 from abc import ABC, abstractmethod
 import pandas as pd
@@ -26,9 +27,10 @@ def __init__(self, extractor_name: str):
         """Construct Extractor.

         Args:
-            extractor_name: Name of the `Extractor` instance. Used to keep track of the
-                provenance of different data, and to name tables to which this
-                data is saved. E.g. "mc_truth".
+            extractor_name: Name of the `Extractor` instance.
+                Used to keep track of the provenance of different
+                data, and to name tables to which this data is
+                saved. E.g. "mc_truth".
         """
         # Member variable(s)
         self._extractor_name: str = extractor_name