diff --git a/src/graphnet/data/dataconverter.py b/src/graphnet/data/dataconverter.py index 16b18f0d5..807ef21f9 100644 --- a/src/graphnet/data/dataconverter.py +++ b/src/graphnet/data/dataconverter.py @@ -27,14 +27,11 @@ def final(f): # Identity decorator I3TruthExtractor, ) from graphnet.utilities.filesys import find_i3_files -from graphnet.utilities.logging import LoggerMixin, get_logger +from graphnet.utilities.imports import has_icecube_package +from graphnet.utilities.logging import LoggerMixin -logger = get_logger() - -try: +if has_icecube_package(): from icecube import icetray, dataio # pyright: reportMissingImports=false -except ImportError: - logger.warning("icecube package not available.") SAVE_STRATEGIES = [ diff --git a/src/graphnet/data/extractors/i3extractor.py b/src/graphnet/data/extractors/i3extractor.py index 959963642..3a613887d 100644 --- a/src/graphnet/data/extractors/i3extractor.py +++ b/src/graphnet/data/extractors/i3extractor.py @@ -1,17 +1,11 @@ from abc import ABC, abstractmethod from typing import List -from graphnet.utilities.logging import LoggerMixin, get_logger +from graphnet.utilities.imports import has_icecube_package +from graphnet.utilities.logging import LoggerMixin -logger = get_logger() - -try: - from icecube import ( - icetray, - dataio, - ) # pyright: reportMissingImports=false -except ImportError: - logger.warning("icecube package not available.") +if has_icecube_package(): + from icecube import icetray, dataio # pyright: reportMissingImports=false class I3Extractor(ABC, LoggerMixin): diff --git a/src/graphnet/data/extractors/i3featureextractor.py b/src/graphnet/data/extractors/i3featureextractor.py index e46d963d9..a384bd2f7 100644 --- a/src/graphnet/data/extractors/i3featureextractor.py +++ b/src/graphnet/data/extractors/i3featureextractor.py @@ -1,13 +1,8 @@ from graphnet.data.extractors.i3extractor import I3Extractor -from graphnet.utilities.logging import get_logger - -logger = get_logger() -try: - from icecube import ( - dataclasses, - ) # pyright: reportMissingImports=false -except ImportError: - logger.warning("icecube package not available.") +from graphnet.utilities.imports import has_icecube_package + +if has_icecube_package(): + from icecube import dataclasses # pyright: reportMissingImports=false class I3FeatureExtractor(I3Extractor): diff --git a/src/graphnet/data/extractors/i3truthextractor.py b/src/graphnet/data/extractors/i3truthextractor.py index 470ada0bd..917e0f6a5 100644 --- a/src/graphnet/data/extractors/i3truthextractor.py +++ b/src/graphnet/data/extractors/i3truthextractor.py @@ -7,18 +7,14 @@ frame_is_montecarlo, frame_is_noise, ) -from graphnet.utilities.logging import get_logger +from graphnet.utilities.imports import has_icecube_package -logger = get_logger() - -try: +if has_icecube_package(): from icecube import ( dataclasses, icetray, phys_services, ) # pyright: reportMissingImports=false -except ImportError: - logger.warning("icecube package not available.") class I3TruthExtractor(I3Extractor): @@ -385,5 +381,5 @@ def _find_data_type(self, mc, input_file): if "L2" in input_file: # not robust sim_type = "dbang" if sim_type == "lol": - logger.info("SIM TYPE NOT FOUND!") + self.logger.info("SIM TYPE NOT FOUND!") return sim_type diff --git a/src/graphnet/data/parquet/__init__.py b/src/graphnet/data/parquet/__init__.py index 94f52fc52..05dfbe0c4 100644 --- a/src/graphnet/data/parquet/__init__.py +++ b/src/graphnet/data/parquet/__init__.py @@ -1,2 +1,8 @@ +from graphnet.utilities.imports import has_torch_package + from .parquet_dataconverter import ParquetDataConverter -from .parquet_dataset import ParquetDataset + +if has_torch_package(): + from .parquet_dataset import ParquetDataset + +del has_torch_package diff --git a/src/graphnet/data/sqlite/__init__.py b/src/graphnet/data/sqlite/__init__.py index ac461ff65..9f1d84d23 100644 --- a/src/graphnet/data/sqlite/__init__.py +++ b/src/graphnet/data/sqlite/__init__.py @@ -1,4 +1,10 @@ +from graphnet.utilities.imports import has_torch_package + from .sqlite_dataconverter import SQLiteDataConverter -from .sqlite_dataset import SQLiteDataset -from .sqlite_dataset_perturbed import SQLiteDatasetPerturbed from .sqlite_utilities import run_sql_code, save_to_sql + +if has_torch_package(): + from .sqlite_dataset import SQLiteDataset + from .sqlite_dataset_perturbed import SQLiteDatasetPerturbed + +del has_torch_package diff --git a/src/graphnet/utilities/imports.py b/src/graphnet/utilities/imports.py index fac29d20e..20a1fd5cd 100644 --- a/src/graphnet/utilities/imports.py +++ b/src/graphnet/utilities/imports.py @@ -2,7 +2,7 @@ from functools import wraps -from graphnet.utilities.logging import get_logger +from graphnet.utilities.logging import get_logger, warn_once logger = get_logger() @@ -15,6 +15,23 @@ def has_icecube_package() -> bool: return True except ImportError: + warn_once( + logger, + "`icecube` not available. Some functionality may be missing.", + ) + return False + + +def has_torch_package() -> bool: + """Check whether the `torch` package is available.""" + try: + import torch + + return True + except ImportError: + warn_once( + logger, "`torch` not available. Some functionality may be missing." + ) return False diff --git a/src/graphnet/utilities/logging.py b/src/graphnet/utilities/logging.py index dc198fa43..b8b93449e 100644 --- a/src/graphnet/utilities/logging.py +++ b/src/graphnet/utilities/logging.py @@ -1,6 +1,7 @@ """Consistent and configurable logging across the project.""" from collections import Counter +from functools import lru_cache import re from typing import Optional import colorlog @@ -53,6 +54,12 @@ def get_formatters() -> Tuple[logging.Formatter, colorlog.ColoredFormatter]: return basic_formatter, colored_formatter +@lru_cache(1) +def warn_once(logger: logging.Logger, message: str): + """Print `message` as warning exactly once.""" + logger.warn(message) + + class RepeatFilter(object): """Filter out repeat messages."""