Skip to content

Commit

Permalink
Merge pull request #273 from asogaard/fix-torch-in-data
Browse files Browse the repository at this point in the history
Fix torch import in graphnet.data
  • Loading branch information
asogaard authored Aug 31, 2022
2 parents dd390bd + 3ddce57 commit 2671f2d
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 36 deletions.
9 changes: 3 additions & 6 deletions src/graphnet/data/dataconverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,11 @@ def final(f): # Identity decorator
I3TruthExtractor,
)
from graphnet.utilities.filesys import find_i3_files
from graphnet.utilities.logging import LoggerMixin, get_logger
from graphnet.utilities.imports import has_icecube_package
from graphnet.utilities.logging import LoggerMixin

logger = get_logger()

try:
if has_icecube_package():
from icecube import icetray, dataio # pyright: reportMissingImports=false
except ImportError:
logger.warning("icecube package not available.")


SAVE_STRATEGIES = [
Expand Down
14 changes: 4 additions & 10 deletions src/graphnet/data/extractors/i3extractor.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
from abc import ABC, abstractmethod
from typing import List

from graphnet.utilities.logging import LoggerMixin, get_logger
from graphnet.utilities.imports import has_icecube_package
from graphnet.utilities.logging import LoggerMixin

logger = get_logger()

try:
from icecube import (
icetray,
dataio,
) # pyright: reportMissingImports=false
except ImportError:
logger.warning("icecube package not available.")
if has_icecube_package():
from icecube import icetray, dataio # pyright: reportMissingImports=false


class I3Extractor(ABC, LoggerMixin):
Expand Down
13 changes: 4 additions & 9 deletions src/graphnet/data/extractors/i3featureextractor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
from graphnet.data.extractors.i3extractor import I3Extractor
from graphnet.utilities.logging import get_logger

logger = get_logger()
try:
from icecube import (
dataclasses,
) # pyright: reportMissingImports=false
except ImportError:
logger.warning("icecube package not available.")
from graphnet.utilities.imports import has_icecube_package

if has_icecube_package():
from icecube import dataclasses # pyright: reportMissingImports=false


class I3FeatureExtractor(I3Extractor):
Expand Down
10 changes: 3 additions & 7 deletions src/graphnet/data/extractors/i3truthextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,14 @@
frame_is_montecarlo,
frame_is_noise,
)
from graphnet.utilities.logging import get_logger
from graphnet.utilities.imports import has_icecube_package

logger = get_logger()

try:
if has_icecube_package():
from icecube import (
dataclasses,
icetray,
phys_services,
) # pyright: reportMissingImports=false
except ImportError:
logger.warning("icecube package not available.")


class I3TruthExtractor(I3Extractor):
Expand Down Expand Up @@ -385,5 +381,5 @@ def _find_data_type(self, mc, input_file):
if "L2" in input_file: # not robust
sim_type = "dbang"
if sim_type == "lol":
logger.info("SIM TYPE NOT FOUND!")
self.logger.info("SIM TYPE NOT FOUND!")
return sim_type
8 changes: 7 additions & 1 deletion src/graphnet/data/parquet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
from graphnet.utilities.imports import has_torch_package

from .parquet_dataconverter import ParquetDataConverter
from .parquet_dataset import ParquetDataset

if has_torch_package():
from .parquet_dataset import ParquetDataset

del has_torch_package
10 changes: 8 additions & 2 deletions src/graphnet/data/sqlite/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from graphnet.utilities.imports import has_torch_package

from .sqlite_dataconverter import SQLiteDataConverter
from .sqlite_dataset import SQLiteDataset
from .sqlite_dataset_perturbed import SQLiteDatasetPerturbed
from .sqlite_utilities import run_sql_code, save_to_sql

if has_torch_package():
from .sqlite_dataset import SQLiteDataset
from .sqlite_dataset_perturbed import SQLiteDatasetPerturbed

del has_torch_package
19 changes: 18 additions & 1 deletion src/graphnet/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from functools import wraps

from graphnet.utilities.logging import get_logger
from graphnet.utilities.logging import get_logger, warn_once


logger = get_logger()
Expand All @@ -15,6 +15,23 @@ def has_icecube_package() -> bool:

return True
except ImportError:
warn_once(
logger,
"`icecube` not available. Some functionality may be missing.",
)
return False


def has_torch_package() -> bool:
"""Check whether the `torch` package is available."""
try:
import torch

return True
except ImportError:
warn_once(
logger, "`torch` not available. Some functionality may be missing."
)
return False


Expand Down
7 changes: 7 additions & 0 deletions src/graphnet/utilities/logging.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Consistent and configurable logging across the project."""

from collections import Counter
from functools import lru_cache
import re
from typing import Optional
import colorlog
Expand Down Expand Up @@ -53,6 +54,12 @@ def get_formatters() -> Tuple[logging.Formatter, colorlog.ColoredFormatter]:
return basic_formatter, colored_formatter


@lru_cache(1)
def warn_once(logger: logging.Logger, message: str):
"""Print `message` as warning exactly once."""
logger.warn(message)


class RepeatFilter(object):
"""Filter out repeat messages."""

Expand Down

0 comments on commit 2671f2d

Please sign in to comment.