diff --git a/src/graphnet/data/extractors/__init__.py b/src/graphnet/data/extractors/__init__.py index c6f4f325e..ad76d2742 100644 --- a/src/graphnet/data/extractors/__init__.py +++ b/src/graphnet/data/extractors/__init__.py @@ -1,2 +1,3 @@ """Module containing data-specific extractor modules.""" from .extractor import Extractor +from .combine_extractors import CombinedExtractor diff --git a/src/graphnet/data/extractors/combine_extractors.py b/src/graphnet/data/extractors/combine_extractors.py new file mode 100644 index 000000000..a04ab818d --- /dev/null +++ b/src/graphnet/data/extractors/combine_extractors.py @@ -0,0 +1,34 @@ +"""Module for combining multiple extractors into a single extractor.""" +from graphnet.data.extractors.icecube.i3extractor import I3Extractor +from typing import List, Dict + +from icecube import icetray + + +class CombinedExtractor(I3Extractor): + """Class for combining multiple extractors. + + This class is used to combine multiple extractors into a single extractor + with a new name. + """ + + def __init__(self, extractors: List[I3Extractor], extractor_name: str): + """Construct CombinedExtractor. + + Args: + extractors: List of extractors to combine. + extractor_name: Name of the new extractor. + """ + super().__init__(extractor_name=extractor_name) + self._extractors = extractors + + def __call__(self, frame: "icetray.I3Frame") -> Dict[str, float]: + """Extract data from frame using all extractors. + + Args: + frame: I3Frame to extract data from. + """ + output = {} + for extractor in self._extractors: + output.update(extractor(frame)) + return output