diff --git a/src/graphnet/data/__init__.py b/src/graphnet/data/__init__.py
index 1eca4f6cd..fbb1ee095 100644
--- a/src/graphnet/data/__init__.py
+++ b/src/graphnet/data/__init__.py
@@ -3,3 +3,4 @@
 `graphnet.data` enables converting domain-specific data to industry-standard,
 intermediate file formats and reading this data.
 """
+from .filters import I3Filter, I3FilterMask
diff --git a/src/graphnet/data/dataconverter.py b/src/graphnet/data/dataconverter.py
index dc0deabd0..41cec5eec 100644
--- a/src/graphnet/data/dataconverter.py
+++ b/src/graphnet/data/dataconverter.py
@@ -39,6 +39,7 @@
 from graphnet.utilities.filesys import find_i3_files
 from graphnet.utilities.imports import has_icecube_package
 from graphnet.utilities.logging import Logger
+from graphnet.data.filters import I3Filter, NullSplitI3Filter
 
 if has_icecube_package():
     from icecube import icetray, dataio  # pyright: reportMissingImports=false
@@ -107,6 +108,7 @@ def __init__(
         workers: int = 1,
         index_column: str = "event_no",
         icetray_verbose: int = 0,
+        i3_filters: List[I3Filter] = [],
     ):
         """Construct DataConverter.
 
@@ -167,6 +169,14 @@ def __init__(
         self._input_file_batch_pattern = input_file_batch_pattern
         self._workers = workers
 
+        # I3Filters (NullSplitI3Filter is always included)
+        self._i3filters = [NullSplitI3Filter()] + i3_filters
+
+        for i3_filter in self._i3filters:
+            assert isinstance(
+                i3_filter, I3Filter
+            ), f"{type(i3_filter)} is not a subclass of I3Filter"
+
         # Create I3Extractors
         self._extractors = I3ExtractorCollection(*extractors)
 
@@ -433,6 +443,7 @@ def _extract_data(self, fileset: FileSet) -> List[OrderedDict]:
             except Exception as e:
                 if "I3" in str(e):
                     continue
+            # check if frame should be skipped
             if self._skip_frame(frame):
                 continue
 
@@ -555,14 +566,15 @@ def _get_output_file(self, input_file: str) -> str:
         return output_file
 
     def _skip_frame(self, frame: "icetray.I3Frame") -> bool:
-        """Check if frame should be skipped.
-
-        Args:
-            frame: I3Frame to check.
+        """Check the user defined filters.
 
         Returns:
-            True if frame is a null split frame, else False.
+            bool: True if frame should be skipped, False otherwise.
         """
-        if frame["I3EventHeader"].sub_event_stream == "NullSplit":
-            return True
-        return False
+        if self._i3filters is None:
+            return False  # No filters defined, so we keep the frame
+
+        for i3_filter in self._i3filters:
+            if not i3_filter(frame):
+                return True  # keep_frame call false, skip the frame.
+        return False  # All filter keep_frame calls true, keep the frame.
diff --git a/src/graphnet/data/filters.py b/src/graphnet/data/filters.py
new file mode 100644
index 000000000..ca83f4217
--- /dev/null
+++ b/src/graphnet/data/filters.py
@@ -0,0 +1,126 @@
+"""Filter classes for filtering I3-frames when converting I3-files."""
+from abc import abstractmethod
+from graphnet.utilities.logging import Logger
+from typing import List
+
+from graphnet.utilities.imports import has_icecube_package
+
+if has_icecube_package():
+    from icecube import icetray
+
+
+class I3Filter(Logger):
+    """A generic filter for I3-frames."""
+
+    @abstractmethod
+    def _keep_frame(self, frame: "icetray.I3Frame") -> bool:
+        """Return True if the frame is kept, False otherwise.
+
+        Args:
+            frame: I3-frame
+                The I3-frame to check.
+
+        Returns:
+            bool: True if the frame is kept, False otherwise.
+        """
+        raise NotImplementedError
+
+    def __call__(self, frame: "icetray.I3Frame") -> bool:
+        """Return True if the frame passes the filter, False otherwise.
+
+        Args:
+            frame: I3-frame
+                The I3-frame to check.
+
+        Returns:
+            bool: True if the frame passes the filter, False otherwise.
+        """
+        pass_flag = self._keep_frame(frame)
+        if not isinstance(pass_flag, bool):
+            raise TypeError(
+                f"Expected _keep_frame to return bool, got {type(pass_flag)}."
+            )
+        return pass_flag
+
+
+class NullSplitI3Filter(I3Filter):
+    """A filter that skips all null-split frames."""
+
+    def _keep_frame(self, frame: "icetray.I3Frame") -> bool:
+        """Check that frame is not a null-split frame.
+
+        Returns False if the frame is a null-split frame, True otherwise.
+
+        Args:
+            frame: I3-frame
+                The I3-frame to check.
+        """
+        if frame.Has("I3EventHeader"):
+            if frame["I3EventHeader"].sub_event_stream == "NullSplit":
+                return False
+        return True
+
+
+class I3FilterMask(I3Filter):
+    """Checks list of filters from the FilterMask in I3 frames."""
+
+    def __init__(self, filter_names: List[str], filter_any: bool = True):
+        """Initialize I3FilterMask.
+
+        Args:
+            filter_names: List[str]
+                A list of filter names to check for.
+            filter_any: bool
+                standard: True
+                If True, the frame is kept if any of the filter names are present.
+                If False, the frame is kept if all of the filter names are present.
+        """
+        self._filter_names = filter_names
+        self._filter_any = filter_any
+
+    def _keep_frame(self, frame: "icetray.I3Frame") -> bool:
+        """Check if current frame should be kept.
+
+        Args:
+            frame: I3-frame
+                The I3-frame to check.
+        """
+        if "FilterMask" in frame:
+            if (
+                self._filter_any is True
+            ):  # Require any of the filters to pass to keep the frame
+                bool_list = []
+                for filter_name in self._filter_names:
+                    if filter_name not in frame["FilterMask"]:
+                        self.warning_once(
+                            f"FilterMask {filter_name} not found in frame. skipping filter."
+                        )
+                        continue
+                    elif frame["FilterMask"][filter_name].condition_passed is True:
+                        bool_list.append(True)
+                    else:
+                        bool_list.append(False)
+                if len(bool_list) == 0:
+                    self.warning_once(
+                        "None of the FilterMask filters found in frame, FilterMask filters will not be applied."
+                    )
+                return any(bool_list) or len(bool_list) == 0
+            else:  # Require all filters to pass in order to keep the frame.
+                for filter_name in self._filter_names:
+                    if filter_name not in frame["FilterMask"]:
+                        self.warning_once(
+                            f"FilterMask {filter_name} not found in frame, skipping filter."
+                        )
+                        continue
+                    elif frame["FilterMask"][filter_name].condition_passed is True:
+                        continue  # current filter passed, continue to next filter
+                    else:
+                        return (
+                            False  # current filter failed so frame is skipped.
+                        )
+                return True
+        else:
+            self.warning_once(
+                "FilterMask not found in frame, FilterMask filters will not be applied."
+            )
+            return True