From ce9302b59bcbb542706f00803601c69780b55ecf Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Sun, 28 Jan 2024 15:09:39 +0100 Subject: [PATCH 01/33] restructure --- src/graphnet/data/dataclasses.py | 10 + src/graphnet/data/dataconverter_new.py | 254 +++++++++++++++++++++ src/graphnet/data/extractors/__init__.py | 1 + src/graphnet/data/extractors/extractor.py | 107 +++++++++ src/graphnet/data/readers.py | 265 ++++++++++++++++++++++ src/graphnet/data/writers.py | 59 +++++ 6 files changed, 696 insertions(+) create mode 100644 src/graphnet/data/dataclasses.py create mode 100644 src/graphnet/data/dataconverter_new.py create mode 100644 src/graphnet/data/extractors/extractor.py create mode 100644 src/graphnet/data/readers.py create mode 100644 src/graphnet/data/writers.py diff --git a/src/graphnet/data/dataclasses.py b/src/graphnet/data/dataclasses.py new file mode 100644 index 000000000..98b837693 --- /dev/null +++ b/src/graphnet/data/dataclasses.py @@ -0,0 +1,10 @@ +"""Module containing experiment-specific dataclasses.""" + + +from dataclasses import dataclass + + +@dataclass +class I3FileSet: # noqa: D101 + i3_file: str + gcd_file: str diff --git a/src/graphnet/data/dataconverter_new.py b/src/graphnet/data/dataconverter_new.py new file mode 100644 index 000000000..0456d26c8 --- /dev/null +++ b/src/graphnet/data/dataconverter_new.py @@ -0,0 +1,254 @@ +"""Contains `DataConverter`.""" +from typing import List, Union, OrderedDict, Dict, Tuple, Any, Optional, Type +from abc import abstractmethod, ABC + +from tqdm import tqdm +import numpy as np +from multiprocessing import Manager, Pool, Value +import multiprocessing.pool +from multiprocessing.sharedctypes import Synchronized +import pandas as pd +import os + +from graphnet.utilities.decorators import final +from graphnet.utilities.logging import Logger +from .readers import GraphNeTFileReader +from .writers import GraphNeTFileSaveMethod +from .extractors import Extractor +from .dataclasses import I3FileSet + + +def init_global_index(index: Synchronized, output_files: List[str]) -> None: + """Make `global_index` available to pool workers.""" + global global_index, global_output_files # type: ignore[name-defined] + global_index, global_output_files = (index, output_files) # type: ignore[name-defined] + + +class DataConverter(ABC, Logger): + """A finalized data conversion class in GraphNeT. + + `DataConverter` provides parallel processing of file conversion and + extraction from experiment-specific file formats to graphnet-supported data + formats. This class also assigns event id's to training examples. + """ + + def __init__( + self, + file_reader: Type[GraphNeTFileReader], + save_method: Type[GraphNeTFileSaveMethod], + extractors: Union[Type[Extractor], List[Type[Extractor]]], + index_column: str = "event_no", + num_workers: int = 1, + ) -> None: + """Initialize `DataConverter`. + + Args: + file_reader: The method used for reading and applying `Extractors`. + save_method: The method used to save the interim data format to + a graphnet supported file format. + extractors: The `Extractor`(s) that will be applied to the input + files. + index_column: Name of the event id column added to the events. + Defaults to "event_no". + num_workers: The number of CPUs used for parallel processing. + Defaults to 1 (no multiprocessing). 
+ """ + # Member Variable Assignment + self._file_reader = file_reader + self._save_method = save_method + self._num_workers = num_workers + self._index_column = index_column + self._index = 0 + self._output_files: List[str] = [] + + # Set Extractors. Will throw error if extractors are incompatible + # with reader. + self._file_reader.set_extractors(extractors) + + @final + def __call__( + self, input_dir: Union[str, List[str]], output_dir: str + ) -> None: + """Extract data from files in `input_dir` and save to disk. + + Args: + input_dir: A directory that contains the input files. + The directory will be searched recursively for files + matching the file extension. + output_dir: The directory to save the files to. Input folder + structure is not respected. + """ + # Get the file reader to produce a list of input files + # in the directory + input_files = self._file_reader.find_files(path=input_dir) # type: ignore + self._launch_jobs(input_files=input_files, output_dir=output_dir) + + @final + def _launch_jobs( + self, input_files: Union[List[str], List[I3FileSet]] + ) -> None: + """Multi Processing Logic. + + Spawns worker pool, + distributes the input files evenly across workers. + declare event_no as globally accessible variable across workers. + starts jobs. + + Will call process_file in parallel. + """ + # Get appropriate mapping function + map_fn, pool = self.get_map_function(nb_files=len(input_files)) + + # Iterate over files + for _ in map_fn( + self._process_file, + tqdm(input_files, unit="file(s)", colour="green"), + ): + self.debug("processing file.") + + self._update_shared_variables(pool) + + @final + def _process_file(self, file_path: str) -> None: + """Process a single file. + + Calls file reader to recieve extracted output, event ids + is assigned to the extracted data and is handed to save method. + + This function is called in parallel. + """ + # Read and apply extractors + data = self._file_reader(file_path=file_path) + + # Assign event_no's to each event in data + data = self._assign_event_no(data=data) + + # Create output file name + output_file_name = self._create_file_name(input_file_path=file_path) + + # Apply save method + self._save_method(data=data, file_name=output_file_name) + + @final + def _create_file_name(self, input_file_path: str) -> str: + """Convert input file path to an output file name.""" + path_without_extension = os.path.splitext(input_file_path)[0] + base_file_name = path_without_extension.split("/")[-1] + return base_file_name + self._save_method.file_extension() # type: ignore + + @final + def _assign_event_no( + self, data: List[OrderedDict[str, Any]] + ) -> Dict[str, pd.DataFrame]: + + # Request event_no's for the entire file + event_nos = self._request_event_nos(n_ids=len(data)) + + # Dict holding pd.DataFrame's + dataframe_dict: Dict = {} + # Loop through events (again..) 
to assign event_nos + for k in range(len(data)): + for extractor_name in data[k].keys(): + n_rows = self._count_rows( + event_dict=data[k], extractor_name=extractor_name + ) + + data[k][extractor_name][self._index_column] = np.repeat( + event_nos[k], n_rows + ).tolist() + df = pd.DataFrame( + data[k][extractor_name], index=[0] if n_rows == 1 else None + ) + if extractor_name in dataframe_dict.keys(): + dataframe_dict[extractor_name].append(df) + else: + dataframe_dict[extractor_name] = [df] + + return dataframe_dict + + @final + def _count_rows( + self, event_dict: OrderedDict[str, Any], extractor_name: str + ) -> int: + """Count number of rows that features from `extractor_name` have.""" + try: + extractor_dict = event_dict[extractor_name] + # If all features in extractor_name have the same length + # this line of code will execute without error and result + # in an array with shape [num_features, n_rows_in_feature] + n_rows = np.asarray(list(extractor_dict.values())).shape[1] + except ValueError as e: + self.error( + f"Features from {extractor_name} ({extractor_dict.keys()}) have different lengths." + ) + raise e + + return n_rows + + def _request_event_nos(self, n_ids: int) -> List[int]: + + # Get new, unique index and increment value + if self._num_workers > 1: + with global_index.get_lock(): # type: ignore[name-defined] + starting_index = global_index.value # type: ignore[name-defined] + event_nos = np.arange( + starting_index, starting_index + n_ids, 1 + ).tolist() + global_index.value += n_ids # type: ignore[name-defined] + else: + starting_index = self._index + event_nos = np.arange( + starting_index, starting_index + n_ids, 1 + ).tolist() + self._index += n_ids + + return event_nos + + @final + def get_map_function( + self, nb_files: int, unit: str = "file(s)" + ) -> Tuple[Any, Optional[multiprocessing.pool.Pool]]: + """Identify map function to use (pure python or multiprocess).""" + # Choose relevant map-function given the requested number of workers. + n_workers = min(self._num_workers, nb_files) + if n_workers > 1: + self.info( + f"Starting pool of {n_workers} workers to process {nb_files} {unit}" + ) + + manager = Manager() + index = Value("i", 0) + output_files = manager.list() + + pool = Pool( + processes=n_workers, + initializer=init_global_index, + initargs=(index, output_files), + ) + map_fn = pool.imap + + else: + self.info( + f"Processing {nb_files} {unit} in main thread (not multiprocessing)" + ) + map_fn = map # type: ignore + pool = None + + return map_fn, pool + + @final + def _update_shared_variables( + self, pool: Optional[multiprocessing.pool.Pool] + ) -> None: + """Update `self._index` and `self._output_files`. + + If `pool` is set, it means that multiprocessing was used. In this case, + the worker processes will not have been able to write directly to + `self._index` and `self._output_files`, and we need to get them synced + up. + """ + if pool: + # Extract information from shared variables to member variables. 
+ index, output_files = pool._initargs # type: ignore + self._index += index.value + self._output_files.extend(list(sorted(output_files[:]))) diff --git a/src/graphnet/data/extractors/__init__.py b/src/graphnet/data/extractors/__init__.py index e1d4895bf..ec0ecfe5e 100644 --- a/src/graphnet/data/extractors/__init__.py +++ b/src/graphnet/data/extractors/__init__.py @@ -18,3 +18,4 @@ from .i3pisaextractor import I3PISAExtractor from .i3ntmuonlabelsextractor import I3NTMuonLabelExtractor from .i3quesoextractor import I3QUESOExtractor +from .extractor import Extractor diff --git a/src/graphnet/data/extractors/extractor.py b/src/graphnet/data/extractors/extractor.py new file mode 100644 index 000000000..795d05cf1 --- /dev/null +++ b/src/graphnet/data/extractors/extractor.py @@ -0,0 +1,107 @@ +"""Base I3Extractor class(es).""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from graphnet.utilities.imports import has_icecube_package +from graphnet.utilities.logging import Logger + +if has_icecube_package() or TYPE_CHECKING: + from icecube import icetray, dataio # pyright: reportMissingImports=false + + +class Extractor(ABC, Logger): + """Base class for extracting information from data files. + + All classes inheriting from `Extractor` should implement the `__call__` + method, and should return a pure python dictionary on the form + + output = [{'var1: .., + ... , + 'var_n': ..}] + + Variables can be scalar or array-like of shape [n, 1], where n denotes the + number of elements in the array, and 1 the number of columns. + + An extractor is used in conjunction with a specific `FileReader`. + """ + + def __init__(self, extractor_name: str): + """Construct Extractor. + + Args: + extractor_name: Name of the `Extractor` instance. Used to keep track of the + provenance of different data, and to name tables to which this + data is saved. E.g. "mc_truth". + """ + # Member variable(s) + self._extractor_name: str = extractor_name + + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + @abstractmethod + def __call__(self, frame: "icetray.I3Frame") -> dict: + """Extract information from frame.""" + pass + + @property + def name(self) -> str: + """Get the name of the `I3Extractor` instance.""" + return self._extractor_name + + +class I3Extractor(Extractor): + """Base class for extracting information from physics I3-frames. + + Contains functionality required to extract data from i3 files, used by + the IceCube Neutrino Observatory. + + All classes inheriting from `I3Extractor` should implement the `__call__` + method. + """ + + def __init__(self, extractor_name: str): + """Construct I3Extractor. + + Args: + extractor_name: Name of the `I3Extractor` instance. Used to keep track of the + provenance of different data, and to name tables to which this + data is saved. + """ + # Member variable(s) + self._i3_file: str = "" + self._gcd_file: str = "" + self._gcd_dict: Dict[int, Any] = {} + self._calibration: Optional["icetray.I3Frame.Calibration"] = None + + # Base class constructor + super().__init__(extractor_name=extractor_name) + + def set_gcd(self, gcd_file: str, i3_file: str) -> None: + """Load the geospatial information contained in the GCD-file.""" + # If no GCD file is provided, search the I3 file for frames containing + # geometry (G) and calibration (C) information. 
+ gcd = dataio.I3File(gcd_file or i3_file) + + try: + g_frame = gcd.pop_frame(icetray.I3Frame.Geometry) + except RuntimeError: + self.error( + "No GCD file was provided and no G-frame was found. Exiting." + ) + raise + else: + self._gcd_dict = g_frame["I3Geometry"].omgeo + + try: + c_frame = gcd.pop_frame(icetray.I3Frame.Calibration) + except RuntimeError: + self.warning("No GCD file was provided and no C-frame was found.") + else: + self._calibration = c_frame["I3Calibration"] + + @abstractmethod + def __call__(self, frame: "icetray.I3Frame") -> dict: + """Extract information from frame.""" + pass diff --git a/src/graphnet/data/readers.py b/src/graphnet/data/readers.py new file mode 100644 index 000000000..487ca99f7 --- /dev/null +++ b/src/graphnet/data/readers.py @@ -0,0 +1,265 @@ +"""Module containing different FileReader classes in GraphNeT. + +These methods are used to open and apply `Extractors` to experiment-specific +file formats. +""" + +from typing import List, Union, OrderedDict, Type +from abc import abstractmethod, ABC +import glob +import os + +from graphnet.utilities.decorators import final +from graphnet.utilities.logging import Logger +from graphnet.utilities.imports import has_icecube_package +from graphnet.data.filters import I3Filter, NullSplitI3Filter + +from .dataclasses import I3FileSet + +from .extractors.extractor import ( + Extractor, + I3Extractor, +) # , I3GenericExtractor +from graphnet.utilities.filesys import find_i3_files + +if has_icecube_package(): + from icecube import icetray, dataio # pyright: reportMissingImports=false + + +class GraphNeTFileReader(Logger, ABC): + """A generic base class for FileReaders in GraphNeT. + + Classes inheriting from `GraphNeTFileReader` must implement a + `__call__` method that opens a file, applies `Extractor`(s) and returns + a list of ordered dictionaries. + + In addition, Classes inheriting from `GraphNeTFileReader` must set + class properties `accepted_file_extensions` and `accepted_extractors`. + """ + + @abstractmethod + def __call__(self, file_path: str) -> List[OrderedDict]: + """Open and apply extractors to a single file. + + The `output` must be a list of dictionaries, where the number of events + in the file `n_events` satisfies `len(output) = n_events`. I.e each + element in the list is a dictionary, and each field in the dictionary + is the output of a single extractor. + """ + + @property + def accepted_file_extensions(self) -> List[str]: + """Return list of accepted file extensions.""" + return self._accepted_file_extensions # type: ignore + + @property + def accepted_extractors(self) -> List[Extractor]: + """Return list of compatible `Extractor`(s).""" + return self._accepted_extractors # type: ignore + + @property + def extracor_names(self) -> List[str]: + """Return list of table names produced by extractors.""" + return [extractor.name for extractor in self._extractors] # type: ignore + + def find_files( + self, path: Union[str, List[str]] + ) -> Union[List[str], List[I3FileSet]]: + """Search directory for input files recursively. + + This method may be overwritten by custom implementations. + + Args: + path: path to directory. + + Returns: + List of files matching accepted file extensions. + """ + if isinstance(path, str): + path = [path] + files = [] + for dir in path: + for accepted_file_extension in self.accepted_file_extensions: + files.extend(glob.glob(dir + f"/*{accepted_file_extension}")) + + # Check that files are OK. 
+ self.validate_files(files) + return files + + @final + def set_extractors(self, extractors: List[Extractor]) -> None: + """Set `Extractor`(s) as member variable. + + Args: + extractors: A list of `Extractor`(s) to set as member variable. + """ + self._validate_extractors(extractors) + self._extractors = extractors + + @final + def _validate_extractors(self, extractors: List[Extractor]) -> None: + for extractor in extractors: + try: + assert isinstance(extractor, tuple(self.accepted_extractors)) # type: ignore + except AssertionError as e: + self.error( + f"{extractor.__class__.__name__} is not supported by {self.__class__.__name__}" + ) + raise e + + @final + def validate_files( + self, input_files: Union[List[str], List[I3FileSet]] + ) -> None: + """Check that the input files are accepted by the reader. + + Args: + input_files: Path(s) to input file(s). + """ + for input_file in input_files: + # Handle filepath vs. FileSet cases + if isinstance(input_file, I3FileSet): + self._validate_file(input_file.i3_file) + self._validate_file(input_file.gcd_file) + else: + self._validate_file(input_file) + + @final + def _validate_file(self, file: str) -> None: + """Validate a single file path. + + Args: + file: path to file. + + Returns: + None + """ + try: + assert file.lower().endswith(tuple(self.accepted_file_extensions)) + except AssertionError: + self.error( + f'{self.__class__.__name__} accepts {self.accepted_file_extensions} but {file.split("/")[-1]} has extension {os.path.splitext(file)[1]}.' + ) + + +class I3Reader(GraphNeTFileReader): + """A class for reading .i3 files from the IceCube Neutrino Observatory. + + Note that this class relies on IceCube-specific software, and therefore + must be run in a software environment that contains IceTray. + """ + + def __init__( + self, + gcd_rescue: str, + i3_filters: Union[ + Type[I3Filter], List[Type[I3Filter]] + ] = NullSplitI3Filter, + icetray_verbose: int = 0, + ): + """Initialize `I3Reader`. + + Args: + gcd_rescue: Path to a GCD file that will be used if no GCD file is + found in subfolder. `I3Reader` will recursively search + the input directory for I3-GCD file pairs. By IceCube + convention, a folder containing i3 files will have an + accompanying GCD file. However, in some cases, this + convention is broken. In cases where a folder contains + i3 files but no GCD file, the `gcd_rescue` is used + instead. + i3_filters: Instances of `I3Filter` to filter PFrames. Defaults to + `NullSplitI3Filter`. + icetray_verbose: Set the level of verbosity of icetray. + Defaults to 0. + """ + # Set verbosity + if icetray_verbose == 0: + icetray.I3Logger.global_logger = icetray.I3NullLogger() + + # Set Member Variables + self._accepted_file_extensions = [".bz2", ".zst", ".gz"] + self._accepted_extractors = [I3Extractor] + self._gcd_rescue = gcd_rescue + self._i3filters = ( + i3_filters if isinstance(i3_filters, list) else [i3_filters] + ) + + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + def __call__(self, file_path: I3FileSet) -> List[OrderedDict]: # type: ignore + """Extract data from single I3 file. + + Args: + fileset: Path to I3 file and corresponding GCD file. + + Returns: + Extracted data. 
+ """ + # Set I3-GCD file pair in extractor + for extractor in self._extractors: + extractor.set_files(file_path.i3_file, file_path.gcd_file) # type: ignore + + # Open I3 file + i3_file_io = dataio.I3File(file_path.i3_file, "r") + data = list() + while i3_file_io.more(): + try: + frame = i3_file_io.pop_physics() + except Exception as e: + if "I3" in str(e): + continue + # check if frame should be skipped + if self._skip_frame(frame): + continue + + # Try to extract data from I3Frame + results = [extractor(frame) for extractor in self._extractors] + + data_dict = OrderedDict(zip(self.extracor_names, results)) + + # If an I3GenericExtractor is used, we want each automatically + # parsed key to be stored as a separate table. + # for extractor in self._extractors: + # if isinstance(extractor, I3GenericExtractor): + # data_dict.update(data_dict.pop(extractor._name)) + + data.append(data_dict) + return data + + def find_files(self, path: Union[str, List[str]]) -> List[I3FileSet]: + """Recursively search directory for I3 and GCD file pairs. + + Args: + path: directory to search recursively. + + Returns: + List I3 and GCD file pairs as I3FileSets + """ + # Find all I3 and GCD files in the specified directories. + i3_files, gcd_files = find_i3_files( + path, + self._gcd_rescue, + ) + + # Pack as I3FileSets + filesets = [ + I3FileSet(i3_file, gcd_file) + for i3_file, gcd_file in zip(i3_files, gcd_files) + ] + return filesets + + def _skip_frame(self, frame: "icetray.I3Frame") -> bool: + """Check the user defined filters. + + Returns: + bool: True if frame should be skipped, False otherwise. + """ + if self._i3filters is None: + return False # No filters defined, so we keep the frame + + for filter in self._i3filters: + if not filter(frame): + return True # keep_frame call false, skip the frame. + return False # All filter keep_frame calls true, keep the frame. diff --git a/src/graphnet/data/writers.py b/src/graphnet/data/writers.py new file mode 100644 index 000000000..d02eef2b4 --- /dev/null +++ b/src/graphnet/data/writers.py @@ -0,0 +1,59 @@ +"""Module containing `GraphNeTFileSaveMethod`(s). + +These modules are used to save the interim data format from `DataConverter` to +a deep-learning friendly file format. +""" + +import os +from typing import List, Union, OrderedDict, Any +from abc import abstractmethod, ABC + +from graphnet.utilities.decorators import final +from graphnet.utilities.logging import Logger + + +class GraphNeTFileSaveMethod(Logger, ABC): + """Generic base class for saving interim data format in `DataConverter`. + + Classes inheriting from `GraphNeTFileSaveMethod` must implement the + `save_file` method, which recieves the interim data format from + from a single file. + + In addition, classes inheriting from `GraphNeTFileSaveMethod` must + set the `file_extension` property. + """ + + @abstractmethod + def _save_file( + self, data: OrderedDict[str, Any], output_file_path: str + ) -> None: + """Save the interim data format from a single input file. + + Args: + data: the interim data from a single input file. + output_file_path: output file path. + """ + return + + @final + def __call__( + self, data: OrderedDict[str, Any], file_name: str, out_dir: str + ) -> None: + """Save data. + + Args: + data: data to be saved. + file_name: name of input file. Will be used to generate output + file name. + out_dir: directory to save data to. 
+ """ + output_file_path = os.path.join( + out_dir, file_name, self.file_extension + ) + self._save_file(data=data, output_file_path=output_file_path) + return + + @property + def file_extension(self) -> str: + """Return file extension used to store the data.""" + return self._file_extension # type: ignore From f33ae78c79b202ffd7ab7c94c8345a5c3be60f75 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Tue, 6 Feb 2024 19:26:45 +0100 Subject: [PATCH 02/33] first test --- src/graphnet/data/dataconverter_new.py | 69 +++++++++++----- .../data/extractors/i3featureextractor.py | 2 +- .../data/extractors/i3truthextractor.py | 2 +- src/graphnet/data/readers.py | 10 ++- src/graphnet/data/writers.py | 81 +++++++++++++++++-- 5 files changed, 129 insertions(+), 35 deletions(-) diff --git a/src/graphnet/data/dataconverter_new.py b/src/graphnet/data/dataconverter_new.py index 0456d26c8..eb51495d1 100644 --- a/src/graphnet/data/dataconverter_new.py +++ b/src/graphnet/data/dataconverter_new.py @@ -65,6 +65,9 @@ def __init__( # with reader. self._file_reader.set_extractors(extractors) + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + @final def __call__( self, input_dir: Union[str, List[str]], output_dir: str @@ -78,14 +81,17 @@ def __call__( output_dir: The directory to save the files to. Input folder structure is not respected. """ + # Set outdir + self._output_dir = output_dir # Get the file reader to produce a list of input files # in the directory input_files = self._file_reader.find_files(path=input_dir) # type: ignore - self._launch_jobs(input_files=input_files, output_dir=output_dir) + self._launch_jobs(input_files=input_files) @final def _launch_jobs( - self, input_files: Union[List[str], List[I3FileSet]] + self, + input_files: Union[List[str], List[I3FileSet]], ) -> None: """Multi Processing Logic. @@ -109,7 +115,7 @@ def _launch_jobs( self._update_shared_variables(pool) @final - def _process_file(self, file_path: str) -> None: + def _process_file(self, file_path: Union[str, I3FileSet]) -> None: """Process a single file. 
Calls file reader to recieve extracted output, event ids @@ -119,22 +125,30 @@ def _process_file(self, file_path: str) -> None: """ # Read and apply extractors data = self._file_reader(file_path=file_path) + n_events = len(data) # type: ignore - # Assign event_no's to each event in data + # Assign event_no's to each event in data and transform to pd.DataFrame data = self._assign_event_no(data=data) # Create output file name output_file_name = self._create_file_name(input_file_path=file_path) # Apply save method - self._save_method(data=data, file_name=output_file_name) + self._save_method( + data=data, + file_name=output_file_name, + n_events=n_events, + output_dir=self._output_dir, + ) @final - def _create_file_name(self, input_file_path: str) -> str: + def _create_file_name(self, input_file_path: Union[str, I3FileSet]) -> str: """Convert input file path to an output file name.""" + if isinstance(input_file_path, I3FileSet): + input_file_path = input_file_path.i3_file path_without_extension = os.path.splitext(input_file_path)[0] base_file_name = path_without_extension.split("/")[-1] - return base_file_name + self._save_method.file_extension() # type: ignore + return base_file_name # type: ignore @final def _assign_event_no( @@ -152,18 +166,23 @@ def _assign_event_no( n_rows = self._count_rows( event_dict=data[k], extractor_name=extractor_name ) - - data[k][extractor_name][self._index_column] = np.repeat( - event_nos[k], n_rows - ).tolist() - df = pd.DataFrame( - data[k][extractor_name], index=[0] if n_rows == 1 else None - ) - if extractor_name in dataframe_dict.keys(): - dataframe_dict[extractor_name].append(df) - else: - dataframe_dict[extractor_name] = [df] - + if n_rows > 0: + data[k][extractor_name][self._index_column] = np.repeat( + event_nos[k], n_rows + ).tolist() + df = pd.DataFrame( + data[k][extractor_name], + index=[0] if n_rows == 1 else None, + ) + if extractor_name in dataframe_dict.keys(): + dataframe_dict[extractor_name].append(df) + else: + dataframe_dict[extractor_name] = [df] + # Merge each list of dataframes + for key in dataframe_dict.keys(): + dataframe_dict[key] = pd.concat( + dataframe_dict[key], axis=0 + ).reset_index(drop=True) return dataframe_dict @final @@ -171,18 +190,24 @@ def _count_rows( self, event_dict: OrderedDict[str, Any], extractor_name: str ) -> int: """Count number of rows that features from `extractor_name` have.""" + extractor_dict = event_dict[extractor_name] + try: - extractor_dict = event_dict[extractor_name] # If all features in extractor_name have the same length # this line of code will execute without error and result # in an array with shape [num_features, n_rows_in_feature] - n_rows = np.asarray(list(extractor_dict.values())).shape[1] + # unless the list is empty! + + shape = np.asarray(list(extractor_dict.values())).shape + if len(shape) > 1: + n_rows = shape[1] + else: + n_rows = 1 except ValueError as e: self.error( f"Features from {extractor_name} ({extractor_dict.keys()}) have different lengths." 
) raise e - return n_rows def _request_event_nos(self, n_ids: int) -> List[int]: diff --git a/src/graphnet/data/extractors/i3featureextractor.py b/src/graphnet/data/extractors/i3featureextractor.py index f1f578453..f351f0f3a 100644 --- a/src/graphnet/data/extractors/i3featureextractor.py +++ b/src/graphnet/data/extractors/i3featureextractor.py @@ -1,7 +1,7 @@ """I3Extractor class(es) for extracting specific, reconstructed features.""" from typing import TYPE_CHECKING, Any, Dict, List -from graphnet.data.extractors.i3extractor import I3Extractor +from graphnet.data.extractors.extractor import I3Extractor from graphnet.data.extractors.utilities.frames import ( get_om_keys_and_pulseseries, ) diff --git a/src/graphnet/data/extractors/i3truthextractor.py b/src/graphnet/data/extractors/i3truthextractor.py index bcfe694c3..d04be69b2 100644 --- a/src/graphnet/data/extractors/i3truthextractor.py +++ b/src/graphnet/data/extractors/i3truthextractor.py @@ -4,7 +4,7 @@ import matplotlib.path as mpath from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple -from graphnet.data.extractors.i3extractor import I3Extractor +from graphnet.data.extractors.extractor import I3Extractor from graphnet.data.extractors.utilities.frames import ( frame_is_montecarlo, frame_is_noise, diff --git a/src/graphnet/data/readers.py b/src/graphnet/data/readers.py index 487ca99f7..6dd9bd63d 100644 --- a/src/graphnet/data/readers.py +++ b/src/graphnet/data/readers.py @@ -103,7 +103,8 @@ def _validate_extractors(self, extractors: List[Extractor]) -> None: assert isinstance(extractor, tuple(self.accepted_extractors)) # type: ignore except AssertionError as e: self.error( - f"{extractor.__class__.__name__} is not supported by {self.__class__.__name__}" + f"{extractor.__class__.__name__}" + f" is not supported by {self.__class__.__name__}" ) raise e @@ -154,7 +155,7 @@ def __init__( gcd_rescue: str, i3_filters: Union[ Type[I3Filter], List[Type[I3Filter]] - ] = NullSplitI3Filter, + ] = NullSplitI3Filter(), # type: ignore icetray_verbose: int = 0, ): """Initialize `I3Reader`. @@ -199,7 +200,10 @@ def __call__(self, file_path: I3FileSet) -> List[OrderedDict]: # type: ignore """ # Set I3-GCD file pair in extractor for extractor in self._extractors: - extractor.set_files(file_path.i3_file, file_path.gcd_file) # type: ignore + assert isinstance(extractor, I3Extractor) + extractor.set_gcd( + i3_file=file_path.i3_file, gcd_file=file_path.gcd_file + ) # type: ignore # Open I3 file i3_file_io = dataio.I3File(file_path.i3_file, "r") diff --git a/src/graphnet/data/writers.py b/src/graphnet/data/writers.py index d02eef2b4..d23b21ac8 100644 --- a/src/graphnet/data/writers.py +++ b/src/graphnet/data/writers.py @@ -5,11 +5,17 @@ """ import os -from typing import List, Union, OrderedDict, Any +from typing import List, Union, Dict, Any, OrderedDict from abc import abstractmethod, ABC from graphnet.utilities.decorators import final from graphnet.utilities.logging import Logger +from graphnet.data.sqlite.sqlite_utilities import ( + create_table, + create_table_and_save_to_sql, +) + +import pandas as pd class GraphNeTFileSaveMethod(Logger, ABC): @@ -25,19 +31,26 @@ class GraphNeTFileSaveMethod(Logger, ABC): @abstractmethod def _save_file( - self, data: OrderedDict[str, Any], output_file_path: str + self, + data: Dict[str, pd.DataFrame], + output_file_path: str, + n_events: int, ) -> None: """Save the interim data format from a single input file. Args: data: the interim data from a single input file. output_file_path: output file path. 
+ n_events: Number of events container in `data`. """ - return @final def __call__( - self, data: OrderedDict[str, Any], file_name: str, out_dir: str + self, + data: Dict[str, pd.DataFrame], + file_name: str, + output_dir: str, + n_events: int, ) -> None: """Save data. @@ -45,15 +58,67 @@ def __call__( data: data to be saved. file_name: name of input file. Will be used to generate output file name. - out_dir: directory to save data to. + output_dir: directory to save data to. + n_events: Number of events in `data`. """ - output_file_path = os.path.join( - out_dir, file_name, self.file_extension + # make dir + os.makedirs(output_dir, exist_ok=True) + output_file_path = ( + os.path.join(output_dir, file_name) + self.file_extension + ) + + self._save_file( + data=data, output_file_path=output_file_path, n_events=n_events ) - self._save_file(data=data, output_file_path=output_file_path) return @property def file_extension(self) -> str: """Return file extension used to store the data.""" return self._file_extension # type: ignore + + +class SQLiteSaveMethod(GraphNeTFileSaveMethod): + """A method for saving GraphNeT's interim dataformat to SQLite.""" + + _file_extension = ".db" + + def _save_file( + self, + data: Dict[str, pd.DataFrame], + output_file_path: str, + n_events: int, + ) -> None: + """Save data to SQLite database.""" + # Check(s) + if os.path.exists(output_file_path): + self.warning( + f"Output file {output_file_path} already exists. Appending." + ) + + # Concatenate data + if len(data) == 0: + self.warning( + "No data was extracted from the processed I3 file(s). " + f"No data saved to {output_file_path}" + ) + return + + saved_any = False + # Save each dataframe to SQLite database + self.debug(f"Saving to {output_file_path}") + for table, df in data.items(): + if len(df) > 0: + create_table_and_save_to_sql( + df, + table, + output_file_path, + default_type="FLOAT", + integer_primary_key=len(df) <= n_events, + ) + saved_any = True + + if saved_any: + self.debug("- Done saving") + else: + self.warning(f"No data saved to {output_file_path}") From c207ec6ea53b25d3d0d6b19c6dd7a01faaa4874d Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Thu, 8 Feb 2024 15:00:45 +0100 Subject: [PATCH 03/33] restructure --- examples/01_icetray/01_convert_i3_files.py | 2 +- .../02_compare_sqlite_and_parquet.py | 2 +- src/graphnet/data/__init__.py | 5 +- src/graphnet/data/dataconverter.py | 716 ++++++------------ src/graphnet/data/dataconverter_new.py | 279 ------- src/graphnet/data/extractors/__init__.py | 21 +- src/graphnet/data/extractors/extractor.py | 56 -- src/graphnet/data/extractors/i3extractor.py | 106 --- .../data/extractors/i3particleextractor.py | 43 -- .../data/extractors/icecube/__init__.py | 20 + .../data/extractors/icecube/i3extractor.py | 66 ++ .../{ => icecube}/i3featureextractor.py | 9 +- .../{ => icecube}/i3genericextractor.py | 6 +- .../{ => icecube}/i3hybridrecoextractor.py | 2 +- .../{ => icecube}/i3ntmuonlabelsextractor.py | 2 +- .../extractors/icecube/i3particleextractor.py | 44 ++ .../{ => icecube}/i3pisaextractor.py | 2 +- .../{ => icecube}/i3quesoextractor.py | 2 +- .../{ => icecube}/i3retroextractor.py | 4 +- .../{ => icecube}/i3splinempeextractor.py | 2 +- .../{ => icecube}/i3truthextractor.py | 4 +- .../{ => icecube}/i3tumextractor.py | 2 +- .../{ => icecube}/utilities/__init__.py | 0 .../{ => icecube}/utilities/collections.py | 0 .../{ => icecube}/utilities/frames.py | 0 .../icecube/utilities/i3_filters.py} | 0 .../{ => icecube}/utilities/types.py | 4 +- 
src/graphnet/data/parquet/__init__.py | 2 - .../data/parquet/parquet_dataconverter.py | 52 -- src/graphnet/data/pipeline.py | 4 +- src/graphnet/data/readers/__init__.py | 3 + .../data/readers/graphnet_file_reader.py | 132 ++++ .../data/{readers.py => readers/i3reader.py} | 144 +--- src/graphnet/data/sqlite/__init__.py | 4 - .../data/sqlite/sqlite_dataconverter.py | 349 --------- src/graphnet/data/utilities/__init__.py | 3 + .../data/utilities/parquet_to_sqlite.py | 4 +- .../{sqlite => utilities}/sqlite_utilities.py | 54 +- src/graphnet/data/writers/__init__.py | 4 + .../graphnet_writer.py} | 69 +- src/graphnet/data/writers/parquet_writer.py | 34 + src/graphnet/data/writers/sqlite_writer.py | 224 ++++++ .../deployment/i3modules/graphnet_module.py | 4 +- src/graphnet/models/graphs/edges/minkowski.py | 11 +- src/graphnet/training/weight_fitting.py | 4 +- .../data/test_dataconverters_and_datasets.py | 2 +- tests/data/test_i3extractor.py | 2 +- tests/data/test_i3genericextractor.py | 8 +- 48 files changed, 879 insertions(+), 1633 deletions(-) delete mode 100644 src/graphnet/data/dataconverter_new.py delete mode 100644 src/graphnet/data/extractors/i3extractor.py delete mode 100644 src/graphnet/data/extractors/i3particleextractor.py create mode 100644 src/graphnet/data/extractors/icecube/__init__.py create mode 100644 src/graphnet/data/extractors/icecube/i3extractor.py rename src/graphnet/data/extractors/{ => icecube}/i3featureextractor.py (97%) rename src/graphnet/data/extractors/{ => icecube}/i3genericextractor.py (98%) rename src/graphnet/data/extractors/{ => icecube}/i3hybridrecoextractor.py (96%) rename src/graphnet/data/extractors/{ => icecube}/i3ntmuonlabelsextractor.py (96%) create mode 100644 src/graphnet/data/extractors/icecube/i3particleextractor.py rename src/graphnet/data/extractors/{ => icecube}/i3pisaextractor.py (94%) rename src/graphnet/data/extractors/{ => icecube}/i3quesoextractor.py (94%) rename src/graphnet/data/extractors/{ => icecube}/i3retroextractor.py (97%) rename src/graphnet/data/extractors/{ => icecube}/i3splinempeextractor.py (93%) rename src/graphnet/data/extractors/{ => icecube}/i3truthextractor.py (99%) rename src/graphnet/data/extractors/{ => icecube}/i3tumextractor.py (94%) rename src/graphnet/data/extractors/{ => icecube}/utilities/__init__.py (100%) rename src/graphnet/data/extractors/{ => icecube}/utilities/collections.py (100%) rename src/graphnet/data/extractors/{ => icecube}/utilities/frames.py (100%) rename src/graphnet/data/{filters.py => extractors/icecube/utilities/i3_filters.py} (100%) rename src/graphnet/data/extractors/{ => icecube}/utilities/types.py (98%) delete mode 100644 src/graphnet/data/parquet/__init__.py delete mode 100644 src/graphnet/data/parquet/parquet_dataconverter.py create mode 100644 src/graphnet/data/readers/__init__.py create mode 100644 src/graphnet/data/readers/graphnet_file_reader.py rename src/graphnet/data/{readers.py => readers/i3reader.py} (51%) delete mode 100644 src/graphnet/data/sqlite/__init__.py delete mode 100644 src/graphnet/data/sqlite/sqlite_dataconverter.py rename src/graphnet/data/{sqlite => utilities}/sqlite_utilities.py (72%) create mode 100644 src/graphnet/data/writers/__init__.py rename src/graphnet/data/{writers.py => writers/graphnet_writer.py} (57%) create mode 100644 src/graphnet/data/writers/parquet_writer.py create mode 100644 src/graphnet/data/writers/sqlite_writer.py diff --git a/examples/01_icetray/01_convert_i3_files.py b/examples/01_icetray/01_convert_i3_files.py index 88dcf714a..9a39f95e7 100644 
--- a/examples/01_icetray/01_convert_i3_files.py +++ b/examples/01_icetray/01_convert_i3_files.py @@ -3,7 +3,7 @@ import os from graphnet.constants import EXAMPLE_OUTPUT_DIR, TEST_DATA_DIR -from graphnet.data.extractors import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractorIceCubeUpgrade, I3FeatureExtractorIceCube86, I3RetroExtractor, diff --git a/examples/01_icetray/02_compare_sqlite_and_parquet.py b/examples/01_icetray/02_compare_sqlite_and_parquet.py index 99250d4b0..d3874c5f2 100644 --- a/examples/01_icetray/02_compare_sqlite_and_parquet.py +++ b/examples/01_icetray/02_compare_sqlite_and_parquet.py @@ -7,7 +7,7 @@ from graphnet.data.sqlite import SQLiteDataConverter from graphnet.data.parquet import ParquetDataConverter from graphnet.data.dataset import SQLiteDataset, ParquetDataset -from graphnet.data.extractors import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractorIceCube86, I3TruthExtractor, I3RetroExtractor, diff --git a/src/graphnet/data/__init__.py b/src/graphnet/data/__init__.py index fbb1ee095..e7eb84ca4 100644 --- a/src/graphnet/data/__init__.py +++ b/src/graphnet/data/__init__.py @@ -1,6 +1,7 @@ """Modules for converting and ingesting data. `graphnet.data` enables converting domain-specific data to industry-standard, -intermediate file formats and reading this data. +intermediate file formats and reading this data. """ -from .filters import I3Filter, I3FilterMask +from .extractors.icecube.utilities.i3_filters import I3Filter, I3FilterMask +from .dataconverter import DataConverter diff --git a/src/graphnet/data/dataconverter.py b/src/graphnet/data/dataconverter.py index 2a67ddce9..efae14a2f 100644 --- a/src/graphnet/data/dataconverter.py +++ b/src/graphnet/data/dataconverter.py @@ -1,496 +1,251 @@ -"""Base `DataConverter` class(es) used in GraphNeT.""" -# type: ignore[name-defined] # Due to use of `init_global_index`. 
- -from abc import ABC, abstractmethod -from collections import OrderedDict -from dataclasses import dataclass -from functools import wraps -import itertools +"""Contains `DataConverter`.""" +from typing import List, Union, OrderedDict, Dict, Tuple, Any, Optional, Type +from abc import abstractmethod, ABC + +from tqdm import tqdm +import numpy as np from multiprocessing import Manager, Pool, Value import multiprocessing.pool from multiprocessing.sharedctypes import Synchronized +import pandas as pd import os -import re -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Tuple, - TypeVar, - Union, - cast, -) +from glob import glob -import numpy as np -import pandas as pd -from tqdm import tqdm -from graphnet.data.utilities.random import pairwise_shuffle -from graphnet.data.extractors import ( - I3Extractor, - I3ExtractorCollection, - I3FeatureExtractor, - I3TruthExtractor, - I3GenericExtractor, -) from graphnet.utilities.decorators import final -from graphnet.utilities.filesys import find_i3_files -from graphnet.utilities.imports import has_icecube_package from graphnet.utilities.logging import Logger -from graphnet.data.filters import I3Filter, NullSplitI3Filter - -if has_icecube_package(): - from icecube import icetray, dataio # pyright: reportMissingImports=false +from .readers.graphnet_file_reader import GraphNeTFileReader +from .writers.graphnet_writer import GraphNeTWriter +from .extractors import Extractor +from .dataclasses import I3FileSet -SAVE_STRATEGIES = [ - "1:1", - "sequential_batched", - "pattern_batched", -] - - -# Utility classes -@dataclass -class FileSet: # noqa: D101 - i3_file: str - gcd_file: str - - -# Utility method(s) def init_global_index(index: Synchronized, output_files: List[str]) -> None: """Make `global_index` available to pool workers.""" global global_index, global_output_files # type: ignore[name-defined] global_index, global_output_files = (index, output_files) # type: ignore[name-defined] -F = TypeVar("F", bound=Callable[..., Any]) - - -def cache_output_files(process_method: F) -> F: - """Decorate `process_method` to cache output file names.""" - - @wraps(process_method) - def wrapper(self: Any, *args: Any) -> Any: - try: - # Using multiprocessing - output_files = global_output_files # type: ignore[name-defined] - except NameError: # `global_output_files` not set - # Running on main process - output_files = self._output_files - - output_file = process_method(self, *args) - output_files.append(output_file) - return output_file - - return cast(F, wrapper) - - class DataConverter(ABC, Logger): - """Base class for converting I3-files to intermediate file format.""" + """A finalized data conversion class in GraphNeT. - @property - @abstractmethod - def file_suffix(self) -> str: - """Suffix to use on output files.""" + `DataConverter` provides parallel processing of file conversion and + extraction from experiment-specific file formats to graphnet-supported data + formats. This class also assigns event id's to training examples. 
+ """ def __init__( self, - extractors: List[I3Extractor], - outdir: str, - gcd_rescue: Optional[str] = None, - *, - nb_files_to_batch: Optional[int] = None, - sequential_batch_pattern: Optional[str] = None, - input_file_batch_pattern: Optional[str] = None, - workers: int = 1, + file_reader: Type[GraphNeTFileReader], + save_method: Type[GraphNeTWriter], + extractors: Union[Type[Extractor], List[Type[Extractor]]], index_column: str = "event_no", - icetray_verbose: int = 0, - i3_filters: List[I3Filter] = [], - ): - """Construct DataConverter. - - When using `input_file_batch_pattern`, regular expressions are used to - group files according to their names. All files that match a certain - pattern up to wildcards are grouped into the same output file. This - output file has the same name as the input files that are group into it, - with wildcards replaced with "x". Periods (.) and wildcards (*) have a - special meaning: Periods are interpreted as literal periods, and not as - matching any character (as in standard regex); and wildcards are - interpreted as ".*" in standard regex. - - For instance, the pattern "[A-Z]{1}_[0-9]{5}*.i3.zst" will find all I3 - files whose names contain: - - one capital letter, followed by - - an underscore, followed by - - five numbers, followed by - - any string of characters ending in ".i3.zst" - - This means that, e.g., the files: - - upgrade_genie_step4_141020_A_000000.i3.zst - - upgrade_genie_step4_141020_A_000001.i3.zst - - ... - - upgrade_genie_step4_141020_A_000008.i3.zst - - upgrade_genie_step4_141020_A_000009.i3.zst - would be grouped into the output file named - "upgrade_genie_step4_141020_A_00000x." but the file - - upgrade_genie_step4_141020_A_000010.i3.zst - would end up in a separate group, named - "upgrade_genie_step4_141020_A_00001x.". - """ - # Check(s) - if not isinstance(extractors, (list, tuple)): - extractors = [extractors] - - assert ( - len(extractors) > 0 - ), "Please specify at least one argument of type I3Extractor" - - for extractor in extractors: - assert isinstance( - extractor, I3Extractor - ), f"{type(extractor)} is not a subclass of I3Extractor" - - # Infer saving strategy - save_strategy = self._infer_save_strategy( - nb_files_to_batch, - sequential_batch_pattern, - input_file_batch_pattern, - ) + num_workers: int = 1, + ) -> None: + """Initialize `DataConverter`. - # Member variables - self._outdir = outdir - self._gcd_rescue = gcd_rescue - self._save_strategy = save_strategy - self._nb_files_to_batch = nb_files_to_batch - self._sequential_batch_pattern = sequential_batch_pattern - self._input_file_batch_pattern = input_file_batch_pattern - self._workers = workers - - # I3Filters (NullSplitI3Filter is always included) - self._i3filters = [NullSplitI3Filter()] + i3_filters - - for filter in self._i3filters: - assert isinstance( - filter, I3Filter - ), f"{type(filter)} is not a subclass of I3Filter" - - # Create I3Extractors - self._extractors = I3ExtractorCollection(*extractors) - - # Create shorthand of names of all pulsemaps queried - self._table_names = [extractor.name for extractor in self._extractors] - self._pulsemaps = [ - extractor.name - for extractor in self._extractors - if isinstance(extractor, I3FeatureExtractor) - ] - - # Placeholders for keeping track of sequential event indices and output files + Args: + file_reader: The method used for reading and applying `Extractors`. + save_method: The method used to save the interim data format to + a graphnet supported file format. 
+ extractors: The `Extractor`(s) that will be applied to the input + files. + index_column: Name of the event id column added to the events. + Defaults to "event_no". + num_workers: The number of CPUs used for parallel processing. + Defaults to 1 (no multiprocessing). + """ + # Member Variable Assignment + self._file_reader = file_reader + self._save_method = save_method + self._num_workers = num_workers self._index_column = index_column self._index = 0 self._output_files: List[str] = [] - # Set verbosity - if icetray_verbose == 0: - icetray.I3Logger.global_logger = icetray.I3NullLogger() + # Set Extractors. Will throw error if extractors are incompatible + # with reader. + self._file_reader.set_extractors(extractors) # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) @final def __call__( - self, - directories: Union[str, List[str]], - recursive: Optional[bool] = True, + self, input_dir: Union[str, List[str]], output_dir: str ) -> None: - """Convert I3-files in `directories. + """Extract data from files in `input_dir` and save to disk. Args: - directories: One or more directories, the I3 files within which - should be converted to an intermediate file format. - recursive: Whether or not to search the directories recursively. + input_dir: A directory that contains the input files. + The directory will be searched recursively for files + matching the file extension. + output_dir: The directory to save the files to. Input folder + structure is not respected. """ - # Find all I3 and GCD files in the specified directories. - i3_files, gcd_files = find_i3_files( - directories, self._gcd_rescue, recursive + # Set outdir + self._output_dir = output_dir + # Get the file reader to produce a list of input files + # in the directory + input_files = self._file_reader.find_files(path=input_dir) # type: ignore + self._launch_jobs(input_files=input_files) + self._output_files = glob( + os.path.join( + self._output_dir, f"*{self._save_method.file_extension}" + ) ) - if len(i3_files) == 0: - self.error(f"No files found in {directories}.") - return - - # Save a record of the found I3 files in the output directory. - self._save_filenames(i3_files) - - # Shuffle I3 files to get a more uniform load on worker nodes. - i3_files, gcd_files = pairwise_shuffle(i3_files, gcd_files) - - # Process the files - filesets = [ - FileSet(i3_file, gcd_file) - for i3_file, gcd_file in zip(i3_files, gcd_files) - ] - self.execute(filesets) @final - def execute(self, filesets: List[FileSet]) -> None: - """General method for processing a set of I3 files. - - The files are converted individually according to the inheriting class/ - intermediate file format. - - Args: - filesets: List of paths to I3 and corresponding GCD files. - """ - # Make sure output directory exists. - self.info(f"Saving results to {self._outdir}") - os.makedirs(self._outdir, exist_ok=True) - - # Iterate over batches of files. - try: - if self._save_strategy == "sequential_batched": - # Define batches - assert self._nb_files_to_batch is not None - assert self._sequential_batch_pattern is not None - batches = np.array_split( - np.asarray(filesets), - int(np.ceil(len(filesets) / self._nb_files_to_batch)), - ) - batches = [ - ( - group.tolist(), - self._sequential_batch_pattern.format(ix_batch), - ) - for ix_batch, group in enumerate(batches) - ] - self.info( - f"Will batch {len(filesets)} input files into {len(batches)} groups." 
- ) - - # Iterate over batches - pool = self._iterate_over_batches_of_files(batches) - - elif self._save_strategy == "pattern_batched": - # Define batches - groups: Dict[str, List[FileSet]] = OrderedDict() - for fileset in sorted(filesets, key=lambda f: f.i3_file): - group = re.sub( - self._sub_from, - self._sub_to, - os.path.basename(fileset.i3_file), - ) - if group not in groups: - groups[group] = list() - groups[group].append(fileset) - - self.info( - f"Will batch {len(filesets)} input files into {len(groups)} groups" - ) - if len(groups) <= 20: - for group, group_filesets in groups.items(): - self.info( - f"> {group}: {len(group_filesets):3d} file(s)" - ) - - batches = [ - (list(group_filesets), group) - for group, group_filesets in groups.items() - ] - - # Iterate over batches - pool = self._iterate_over_batches_of_files(batches) - - elif self._save_strategy == "1:1": - pool = self._iterate_over_individual_files(filesets) - - else: - assert False, "Shouldn't reach here." - - self._update_shared_variables(pool) - - except KeyboardInterrupt: - self.warning("[ctrl+c] Exciting gracefully.") - - @abstractmethod - def save_data(self, data: List[OrderedDict], output_file: str) -> None: - """Implementation-specific method for saving data to file. - - Args: - data: List of extracted features. - output_file: Name of output file. - """ - - @abstractmethod - def merge_files( - self, output_file: str, input_files: Optional[List[str]] = None + def _launch_jobs( + self, + input_files: Union[List[str], List[I3FileSet]], ) -> None: - """Implementation-specific method for merging output files. + """Multi Processing Logic. - Args: - output_file: Name of the output file containing the merged results. - input_files: Intermediate files to be merged, according to the - specific implementation. Default to None, meaning that all - files output by the current instance are merged. - - Raises: - NotImplementedError: If the method has not been implemented for the - backend in question. - """ + Spawns worker pool, + distributes the input files evenly across workers. + declare event_no as globally accessible variable across workers. + starts jobs. - # Internal methods - def _iterate_over_individual_files( - self, args: List[FileSet] - ) -> Optional[multiprocessing.pool.Pool]: + Will call process_file in parallel. + """ # Get appropriate mapping function - map_fn, pool = self.get_map_function(len(args)) + map_fn, pool = self.get_map_function(nb_files=len(input_files)) # Iterate over files for _ in map_fn( - self._process_file, tqdm(args, unit="file(s)", colour="green") + self._process_file, + tqdm(input_files, unit="file(s)", colour="green"), ): - self.debug( - "Saving with 1:1 strategy on the individual worker processes" - ) + self.debug("processing file.") - return pool + self._update_shared_variables(pool) - def _iterate_over_batches_of_files( - self, args: List[Tuple[List[FileSet], str]] - ) -> Optional[multiprocessing.pool.Pool]: - """Iterate over a batch of files and save results on worker process.""" - # Get appropriate mapping function - map_fn, pool = self.get_map_function(len(args), unit="batch(es)") - - # Iterate over batches of files - for _ in map_fn( - self._process_batch, tqdm(args, unit="batch(es)", colour="green") - ): - self.debug("Saving with batched strategy") - - return pool + @final + def _process_file(self, file_path: Union[str, I3FileSet]) -> None: + """Process a single file. 
- def _update_shared_variables( - self, pool: Optional[multiprocessing.pool.Pool] - ) -> None: - """Update `self._index` and `self._output_files`. + Calls file reader to recieve extracted output, event ids + is assigned to the extracted data and is handed to save method. - If `pool` is set, it means that multiprocessing was used. In this case, - the worker processes will not have been able to write directly to - `self._index` and `self._output_files`, and we need to get them synced - up. + This function is called in parallel. """ - if pool: - # Extract information from shared variables to member variables. - index, output_files = pool._initargs # type: ignore - self._index += index.value - self._output_files.extend(list(sorted(output_files[:]))) - - @cache_output_files - def _process_file( - self, - fileset: FileSet, - ) -> str: - - # Process individual files - data = self._extract_data(fileset) - - # Save data - output_file = self._get_output_file(fileset.i3_file) - self.save_data(data, output_file) - - return output_file - - @cache_output_files - def _process_batch(self, args: Tuple[List[FileSet], str]) -> str: - # Unpack arguments - filesets, output_file_name = args - - # Process individual files - data = list( - itertools.chain.from_iterable(map(self._extract_data, filesets)) + # Read and apply extractors + data = self._file_reader(file_path=file_path) + n_events = len(data) # type: ignore + + # Assign event_no's to each event in data and transform to pd.DataFrame + data = self._assign_event_no(data=data) + + # Create output file name + output_file_name = self._create_file_name(input_file_path=file_path) + + # Apply save method + self._save_method( + data=data, + file_name=output_file_name, + n_events=n_events, + output_dir=self._output_dir, ) - # Save batched data - output_file = self._get_output_file(output_file_name) - self.save_data(data, output_file) - - return output_file - - def _extract_data(self, fileset: FileSet) -> List[OrderedDict]: - """Extract data from single I3 file. - - If the saving strategy is 1:1 (i.e., each I3 file is converted to a - corresponding intermediate file) the data is saved to such a file, and - no data is return from the method. + @final + def _create_file_name(self, input_file_path: Union[str, I3FileSet]) -> str: + """Convert input file path to an output file name.""" + if isinstance(input_file_path, I3FileSet): + input_file_path = input_file_path.i3_file + path_without_extension = os.path.splitext(input_file_path)[0] + base_file_name = path_without_extension.split("/")[-1] + return base_file_name # type: ignore - The above distincting is to allow worker processes to save files rather - than sending it back to the main process. + @final + def _assign_event_no( + self, data: List[OrderedDict[str, Any]] + ) -> Dict[str, pd.DataFrame]: + + # Request event_no's for the entire file + event_nos = self._request_event_nos(n_ids=len(data)) + + # Dict holding pd.DataFrame's + dataframe_dict: Dict = {} + # Loop through events (again..) 
to assign event_nos + for k in range(len(data)): + for extractor_name in data[k].keys(): + n_rows = self._count_rows( + event_dict=data[k], extractor_name=extractor_name + ) + if n_rows > 0: + data[k][extractor_name][self._index_column] = np.repeat( + event_nos[k], n_rows + ).tolist() + df = pd.DataFrame( + data[k][extractor_name], + index=[0] if n_rows == 1 else None, + ) + if extractor_name in dataframe_dict.keys(): + dataframe_dict[extractor_name].append(df) + else: + dataframe_dict[extractor_name] = [df] + # Merge each list of dataframes + for key in dataframe_dict.keys(): + dataframe_dict[key] = pd.concat( + dataframe_dict[key], axis=0 + ).reset_index(drop=True) + return dataframe_dict - Args: - fileset: Path to I3 file and corresponding GCD file. + @final + def _count_rows( + self, event_dict: OrderedDict[str, Any], extractor_name: str + ) -> int: + """Count number of rows that features from `extractor_name` have.""" + extractor_dict = event_dict[extractor_name] - Returns: - Extracted data. - """ - # Infer whether method is being run using multiprocessing try: - global_index # type: ignore[name-defined] - multi_processing = True - except NameError: - multi_processing = False - - self._extractors.set_files(fileset.i3_file, fileset.gcd_file) - i3_file_io = dataio.I3File(fileset.i3_file, "r") - data = list() - while i3_file_io.more(): - try: - frame = i3_file_io.pop_physics() - except Exception as e: - if "I3" in str(e): - continue - # check if frame should be skipped - if self._skip_frame(frame): - continue - - # Try to extract data from I3Frame - results = self._extractors(frame) - - data_dict = OrderedDict(zip(self._table_names, results)) - - # If an I3GenericExtractor is used, we want each automatically - # parsed key to be stored as a separate table. - for extractor in self._extractors: - if isinstance(extractor, I3GenericExtractor): - data_dict.update(data_dict.pop(extractor._name)) - - # Get new, unique index and increment value - if multi_processing: - with global_index.get_lock(): # type: ignore[name-defined] - index = global_index.value # type: ignore[name-defined] - global_index.value += 1 # type: ignore[name-defined] + # If all features in extractor_name have the same length + # this line of code will execute without error and result + # in an array with shape [num_features, n_rows_in_feature] + # unless the list is empty! + + shape = np.asarray(list(extractor_dict.values())).shape + if len(shape) > 1: + n_rows = shape[1] else: - index = self._index - self._index += 1 - - # Attach index to all tables - for table in data_dict.keys(): - data_dict[table][self._index_column] = index - - data.append(data_dict) + n_rows = 1 + except ValueError as e: + self.error( + f"Features from {extractor_name} ({extractor_dict.keys()}) have different lengths." 
+ ) + raise e + return n_rows + + def _request_event_nos(self, n_ids: int) -> List[int]: + + # Get new, unique index and increment value + if self._num_workers > 1: + with global_index.get_lock(): # type: ignore[name-defined] + starting_index = global_index.value # type: ignore[name-defined] + event_nos = np.arange( + starting_index, starting_index + n_ids, 1 + ).tolist() + global_index.value += n_ids # type: ignore[name-defined] + else: + starting_index = self._index + event_nos = np.arange( + starting_index, starting_index + n_ids, 1 + ).tolist() + self._index += n_ids - return data + return event_nos + @final def get_map_function( - self, nb_files: int, unit: str = "I3 file(s)" + self, nb_files: int, unit: str = "file(s)" ) -> Tuple[Any, Optional[multiprocessing.pool.Pool]]: """Identify map function to use (pure python or multiprocess).""" # Choose relevant map-function given the requested number of workers. - workers = min(self._workers, nb_files) - if workers > 1: + n_workers = min(self._num_workers, nb_files) + if n_workers > 1: self.info( - f"Starting pool of {workers} workers to process {nb_files} {unit}" + f"Starting pool of {n_workers} workers to process {nb_files} {unit}" ) manager = Manager() @@ -498,7 +253,7 @@ def get_map_function( output_files = manager.list() pool = Pool( - processes=workers, + processes=n_workers, initializer=init_global_index, initargs=(index, output_files), ) @@ -513,75 +268,50 @@ def get_map_function( return map_fn, pool - def _infer_save_strategy( - self, - nb_files_to_batch: Optional[int] = None, - sequential_batch_pattern: Optional[str] = None, - input_file_batch_pattern: Optional[str] = None, - ) -> str: - if input_file_batch_pattern is not None: - save_strategy = "pattern_batched" - - assert ( - "*" in input_file_batch_pattern - ), "Argument `input_file_batch_pattern` should contain at least one wildcard (*)" - - fields = [ - "(" + field + ")" - for field in input_file_batch_pattern.replace( - ".", r"\." - ).split("*") - ] - nb_fields = len(fields) - self._sub_from = ".*".join(fields) - self._sub_to = "x".join([f"\\{ix + 1}" for ix in range(nb_fields)]) - - if sequential_batch_pattern is not None: - self.warning("Argument `sequential_batch_pattern` ignored.") - if nb_files_to_batch is not None: - self.warning("Argument `nb_files_to_batch` ignored.") - - elif (nb_files_to_batch is not None) or ( - sequential_batch_pattern is not None - ): - save_strategy = "sequential_batched" + @final + def _update_shared_variables( + self, pool: Optional[multiprocessing.pool.Pool] + ) -> None: + """Update `self._index` and `self._output_files`. - assert (nb_files_to_batch is not None) and ( - sequential_batch_pattern is not None - ), "Please specify both `nb_files_to_batch` and `sequential_batch_pattern` for sequential batching." + If `pool` is set, it means that multiprocessing was used. In this case, + the worker processes will not have been able to write directly to + `self._index` and `self._output_files`, and we need to get them synced + up. + """ + if pool: + # Extract information from shared variables to member variables. 
+ index, output_files = pool._initargs # type: ignore + self._index += index.value + self._output_files.extend(list(sorted(output_files[:]))) - else: - save_strategy = "1:1" - - return save_strategy - - def _save_filenames(self, i3_files: List[str]) -> None: - """Save I3 file names in CSV format.""" - self.debug("Saving input file names to config CSV.") - config_dir = os.path.join(self._outdir, "config") - os.makedirs(config_dir, exist_ok=True) - df_i3_files = pd.DataFrame(data=i3_files, columns=["filename"]) - df_i3_files.to_csv(os.path.join(config_dir, "i3files.csv")) - - def _get_output_file(self, input_file: str) -> str: - assert isinstance(input_file, str) - basename = os.path.basename(input_file) - output_file = os.path.join( - self._outdir, - re.sub(r"\.i3\..*", "", basename) + "." + self.file_suffix, - ) - return output_file + @final + def merge_files(self, files: Optional[List[str]] = None) -> None: + """Merge converted files. - def _skip_frame(self, frame: "icetray.I3Frame") -> bool: - """Check the user defined filters. + `DataConverter` will call the `.merge_files` method in the + `GraphNeTWriter` module that it was instantiated with. - Returns: - bool: True if frame should be skipped, False otherwise. + Args: + files: Intermediate files to be merged. """ - if self._i3filters is None: - return False # No filters defined, so we keep the frame + if (files is None) & (len(self._output_files) > 0): + # If no input files are given, but output files from conversion + # is available. + files_to_merge = self._output_files + elif files is not None: + # Proceed to merge specified by user. + files_to_merge = files + else: + # Raise error + self.error( + "This DataConverter does not have output files set," + "and you must therefore specify argument `files`." + ) + assert files is not None - for filter in self._i3filters: - if not filter(frame): - return True # keep_frame call false, skip the frame. - return False # All filter keep_frame calls true, keep the frame. + # Merge files + self._save_method.merge_files( # type:ignore + files=files_to_merge, + output_dir=os.path.join(self._output_dir, "merged"), + ) diff --git a/src/graphnet/data/dataconverter_new.py b/src/graphnet/data/dataconverter_new.py deleted file mode 100644 index eb51495d1..000000000 --- a/src/graphnet/data/dataconverter_new.py +++ /dev/null @@ -1,279 +0,0 @@ -"""Contains `DataConverter`.""" -from typing import List, Union, OrderedDict, Dict, Tuple, Any, Optional, Type -from abc import abstractmethod, ABC - -from tqdm import tqdm -import numpy as np -from multiprocessing import Manager, Pool, Value -import multiprocessing.pool -from multiprocessing.sharedctypes import Synchronized -import pandas as pd -import os - -from graphnet.utilities.decorators import final -from graphnet.utilities.logging import Logger -from .readers import GraphNeTFileReader -from .writers import GraphNeTFileSaveMethod -from .extractors import Extractor -from .dataclasses import I3FileSet - - -def init_global_index(index: Synchronized, output_files: List[str]) -> None: - """Make `global_index` available to pool workers.""" - global global_index, global_output_files # type: ignore[name-defined] - global_index, global_output_files = (index, output_files) # type: ignore[name-defined] - - -class DataConverter(ABC, Logger): - """A finalized data conversion class in GraphNeT. - - `DataConverter` provides parallel processing of file conversion and - extraction from experiment-specific file formats to graphnet-supported data - formats. 
This class also assigns event id's to training examples. - """ - - def __init__( - self, - file_reader: Type[GraphNeTFileReader], - save_method: Type[GraphNeTFileSaveMethod], - extractors: Union[Type[Extractor], List[Type[Extractor]]], - index_column: str = "event_no", - num_workers: int = 1, - ) -> None: - """Initialize `DataConverter`. - - Args: - file_reader: The method used for reading and applying `Extractors`. - save_method: The method used to save the interim data format to - a graphnet supported file format. - extractors: The `Extractor`(s) that will be applied to the input - files. - index_column: Name of the event id column added to the events. - Defaults to "event_no". - num_workers: The number of CPUs used for parallel processing. - Defaults to 1 (no multiprocessing). - """ - # Member Variable Assignment - self._file_reader = file_reader - self._save_method = save_method - self._num_workers = num_workers - self._index_column = index_column - self._index = 0 - self._output_files: List[str] = [] - - # Set Extractors. Will throw error if extractors are incompatible - # with reader. - self._file_reader.set_extractors(extractors) - - # Base class constructor - super().__init__(name=__name__, class_name=self.__class__.__name__) - - @final - def __call__( - self, input_dir: Union[str, List[str]], output_dir: str - ) -> None: - """Extract data from files in `input_dir` and save to disk. - - Args: - input_dir: A directory that contains the input files. - The directory will be searched recursively for files - matching the file extension. - output_dir: The directory to save the files to. Input folder - structure is not respected. - """ - # Set outdir - self._output_dir = output_dir - # Get the file reader to produce a list of input files - # in the directory - input_files = self._file_reader.find_files(path=input_dir) # type: ignore - self._launch_jobs(input_files=input_files) - - @final - def _launch_jobs( - self, - input_files: Union[List[str], List[I3FileSet]], - ) -> None: - """Multi Processing Logic. - - Spawns worker pool, - distributes the input files evenly across workers. - declare event_no as globally accessible variable across workers. - starts jobs. - - Will call process_file in parallel. - """ - # Get appropriate mapping function - map_fn, pool = self.get_map_function(nb_files=len(input_files)) - - # Iterate over files - for _ in map_fn( - self._process_file, - tqdm(input_files, unit="file(s)", colour="green"), - ): - self.debug("processing file.") - - self._update_shared_variables(pool) - - @final - def _process_file(self, file_path: Union[str, I3FileSet]) -> None: - """Process a single file. - - Calls file reader to recieve extracted output, event ids - is assigned to the extracted data and is handed to save method. - - This function is called in parallel. 
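# --- Reviewer sketch (not part of the patch) --------------------------------
# Minimal, standalone illustration of the pattern this patch uses to keep
# event numbers unique when `_process_file` runs in parallel: a shared
# multiprocessing.Value is handed to every worker through a pool initializer,
# and each file reserves a contiguous block of ids under the lock (compare
# `init_global_index` and `_request_event_nos` above). Function names here are
# illustrative only and do not exist in the codebase.
from multiprocessing import Pool, Value


def _init_worker(index):
    # Make the shared counter available as a module-level global in the worker.
    global shared_index
    shared_index = index


def _reserve_ids(n_events):
    # Atomically reserve `n_events` consecutive ids for one input file.
    with shared_index.get_lock():
        start = shared_index.value
        shared_index.value += n_events
    return list(range(start, start + n_events))


if __name__ == "__main__":
    counter = Value("i", 0)
    with Pool(2, initializer=_init_worker, initargs=(counter,)) as pool:
        blocks = pool.map(_reserve_ids, [3, 5, 2])
    # Three disjoint blocks covering 0..9; the exact split depends on which
    # worker grabs the lock first, but no id is ever handed out twice.
    print(blocks)
# -----------------------------------------------------------------------------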
- """ - # Read and apply extractors - data = self._file_reader(file_path=file_path) - n_events = len(data) # type: ignore - - # Assign event_no's to each event in data and transform to pd.DataFrame - data = self._assign_event_no(data=data) - - # Create output file name - output_file_name = self._create_file_name(input_file_path=file_path) - - # Apply save method - self._save_method( - data=data, - file_name=output_file_name, - n_events=n_events, - output_dir=self._output_dir, - ) - - @final - def _create_file_name(self, input_file_path: Union[str, I3FileSet]) -> str: - """Convert input file path to an output file name.""" - if isinstance(input_file_path, I3FileSet): - input_file_path = input_file_path.i3_file - path_without_extension = os.path.splitext(input_file_path)[0] - base_file_name = path_without_extension.split("/")[-1] - return base_file_name # type: ignore - - @final - def _assign_event_no( - self, data: List[OrderedDict[str, Any]] - ) -> Dict[str, pd.DataFrame]: - - # Request event_no's for the entire file - event_nos = self._request_event_nos(n_ids=len(data)) - - # Dict holding pd.DataFrame's - dataframe_dict: Dict = {} - # Loop through events (again..) to assign event_nos - for k in range(len(data)): - for extractor_name in data[k].keys(): - n_rows = self._count_rows( - event_dict=data[k], extractor_name=extractor_name - ) - if n_rows > 0: - data[k][extractor_name][self._index_column] = np.repeat( - event_nos[k], n_rows - ).tolist() - df = pd.DataFrame( - data[k][extractor_name], - index=[0] if n_rows == 1 else None, - ) - if extractor_name in dataframe_dict.keys(): - dataframe_dict[extractor_name].append(df) - else: - dataframe_dict[extractor_name] = [df] - # Merge each list of dataframes - for key in dataframe_dict.keys(): - dataframe_dict[key] = pd.concat( - dataframe_dict[key], axis=0 - ).reset_index(drop=True) - return dataframe_dict - - @final - def _count_rows( - self, event_dict: OrderedDict[str, Any], extractor_name: str - ) -> int: - """Count number of rows that features from `extractor_name` have.""" - extractor_dict = event_dict[extractor_name] - - try: - # If all features in extractor_name have the same length - # this line of code will execute without error and result - # in an array with shape [num_features, n_rows_in_feature] - # unless the list is empty! - - shape = np.asarray(list(extractor_dict.values())).shape - if len(shape) > 1: - n_rows = shape[1] - else: - n_rows = 1 - except ValueError as e: - self.error( - f"Features from {extractor_name} ({extractor_dict.keys()}) have different lengths." - ) - raise e - return n_rows - - def _request_event_nos(self, n_ids: int) -> List[int]: - - # Get new, unique index and increment value - if self._num_workers > 1: - with global_index.get_lock(): # type: ignore[name-defined] - starting_index = global_index.value # type: ignore[name-defined] - event_nos = np.arange( - starting_index, starting_index + n_ids, 1 - ).tolist() - global_index.value += n_ids # type: ignore[name-defined] - else: - starting_index = self._index - event_nos = np.arange( - starting_index, starting_index + n_ids, 1 - ).tolist() - self._index += n_ids - - return event_nos - - @final - def get_map_function( - self, nb_files: int, unit: str = "file(s)" - ) -> Tuple[Any, Optional[multiprocessing.pool.Pool]]: - """Identify map function to use (pure python or multiprocess).""" - # Choose relevant map-function given the requested number of workers. 
- n_workers = min(self._num_workers, nb_files) - if n_workers > 1: - self.info( - f"Starting pool of {n_workers} workers to process {nb_files} {unit}" - ) - - manager = Manager() - index = Value("i", 0) - output_files = manager.list() - - pool = Pool( - processes=n_workers, - initializer=init_global_index, - initargs=(index, output_files), - ) - map_fn = pool.imap - - else: - self.info( - f"Processing {nb_files} {unit} in main thread (not multiprocessing)" - ) - map_fn = map # type: ignore - pool = None - - return map_fn, pool - - @final - def _update_shared_variables( - self, pool: Optional[multiprocessing.pool.Pool] - ) -> None: - """Update `self._index` and `self._output_files`. - - If `pool` is set, it means that multiprocessing was used. In this case, - the worker processes will not have been able to write directly to - `self._index` and `self._output_files`, and we need to get them synced - up. - """ - if pool: - # Extract information from shared variables to member variables. - index, output_files = pool._initargs # type: ignore - self._index += index.value - self._output_files.extend(list(sorted(output_files[:]))) diff --git a/src/graphnet/data/extractors/__init__.py b/src/graphnet/data/extractors/__init__.py index ec0ecfe5e..c6f4f325e 100644 --- a/src/graphnet/data/extractors/__init__.py +++ b/src/graphnet/data/extractors/__init__.py @@ -1,21 +1,2 @@ -"""Collection of I3Extractors, extracting pure-python data from I3Frames.""" - -from .i3extractor import I3Extractor, I3ExtractorCollection -from .i3featureextractor import ( - I3FeatureExtractor, - I3FeatureExtractorIceCube86, - I3FeatureExtractorIceCubeDeepCore, - I3FeatureExtractorIceCubeUpgrade, - I3PulseNoiseTruthFlagIceCubeUpgrade, -) -from .i3truthextractor import I3TruthExtractor -from .i3retroextractor import I3RetroExtractor -from .i3splinempeextractor import I3SplineMPEICExtractor -from .i3particleextractor import I3ParticleExtractor -from .i3tumextractor import I3TUMExtractor -from .i3hybridrecoextractor import I3GalacticPlaneHybridRecoExtractor -from .i3genericextractor import I3GenericExtractor -from .i3pisaextractor import I3PISAExtractor -from .i3ntmuonlabelsextractor import I3NTMuonLabelExtractor -from .i3quesoextractor import I3QUESOExtractor +"""Module containing data-specific extractor modules.""" from .extractor import Extractor diff --git a/src/graphnet/data/extractors/extractor.py b/src/graphnet/data/extractors/extractor.py index 795d05cf1..b5e5ed37c 100644 --- a/src/graphnet/data/extractors/extractor.py +++ b/src/graphnet/data/extractors/extractor.py @@ -49,59 +49,3 @@ def __call__(self, frame: "icetray.I3Frame") -> dict: def name(self) -> str: """Get the name of the `I3Extractor` instance.""" return self._extractor_name - - -class I3Extractor(Extractor): - """Base class for extracting information from physics I3-frames. - - Contains functionality required to extract data from i3 files, used by - the IceCube Neutrino Observatory. - - All classes inheriting from `I3Extractor` should implement the `__call__` - method. - """ - - def __init__(self, extractor_name: str): - """Construct I3Extractor. - - Args: - extractor_name: Name of the `I3Extractor` instance. Used to keep track of the - provenance of different data, and to name tables to which this - data is saved. 
- """ - # Member variable(s) - self._i3_file: str = "" - self._gcd_file: str = "" - self._gcd_dict: Dict[int, Any] = {} - self._calibration: Optional["icetray.I3Frame.Calibration"] = None - - # Base class constructor - super().__init__(extractor_name=extractor_name) - - def set_gcd(self, gcd_file: str, i3_file: str) -> None: - """Load the geospatial information contained in the GCD-file.""" - # If no GCD file is provided, search the I3 file for frames containing - # geometry (G) and calibration (C) information. - gcd = dataio.I3File(gcd_file or i3_file) - - try: - g_frame = gcd.pop_frame(icetray.I3Frame.Geometry) - except RuntimeError: - self.error( - "No GCD file was provided and no G-frame was found. Exiting." - ) - raise - else: - self._gcd_dict = g_frame["I3Geometry"].omgeo - - try: - c_frame = gcd.pop_frame(icetray.I3Frame.Calibration) - except RuntimeError: - self.warning("No GCD file was provided and no C-frame was found.") - else: - self._calibration = c_frame["I3Calibration"] - - @abstractmethod - def __call__(self, frame: "icetray.I3Frame") -> dict: - """Extract information from frame.""" - pass diff --git a/src/graphnet/data/extractors/i3extractor.py b/src/graphnet/data/extractors/i3extractor.py deleted file mode 100644 index 90a982387..000000000 --- a/src/graphnet/data/extractors/i3extractor.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Base I3Extractor class(es).""" - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, List, Optional - -from graphnet.utilities.imports import has_icecube_package -from graphnet.utilities.logging import Logger - -if has_icecube_package() or TYPE_CHECKING: - from icecube import icetray, dataio # pyright: reportMissingImports=false - - -class I3Extractor(ABC, Logger): - """Base class for extracting information from physics I3-frames. - - All classes inheriting from `I3Extractor` should implement the `__call__` - method, and can be applied directly on `icetray.I3Frame` objects to return - extracted, pure-python data. - """ - - def __init__(self, name: str): - """Construct I3Extractor. - - Args: - name: Name of the `I3Extractor` instance. Used to keep track of the - provenance of different data, and to name tables to which this - data is saved. - """ - # Member variable(s) - self._i3_file: str = "" - self._gcd_file: str = "" - self._gcd_dict: Dict[int, Any] = {} - self._calibration: Optional["icetray.I3Frame.Calibration"] = None - self._name: str = name - - # Base class constructor - super().__init__(name=__name__, class_name=self.__class__.__name__) - - def set_files(self, i3_file: str, gcd_file: str) -> None: - """Store references to the I3- and GCD-files being processed.""" - # @TODO: Is it necessary to set the `i3_file`? It is only used in one - # place in `I3TruthExtractor`, and there only in a way that might - # be solved another way. - self._i3_file = i3_file - self._gcd_file = gcd_file - self._load_gcd_data() - - def _load_gcd_data(self) -> None: - """Load the geospatial information contained in the GCD-file.""" - # If no GCD file is provided, search the I3 file for frames containing - # geometry (G) and calibration (C) information. - gcd_file = dataio.I3File(self._gcd_file or self._i3_file) - - try: - g_frame = gcd_file.pop_frame(icetray.I3Frame.Geometry) - except RuntimeError: - self.error( - "No GCD file was provided and no G-frame was found. Exiting." 
- ) - raise - else: - self._gcd_dict = g_frame["I3Geometry"].omgeo - - try: - c_frame = gcd_file.pop_frame(icetray.I3Frame.Calibration) - except RuntimeError: - self.warning("No GCD file was provided and no C-frame was found.") - else: - self._calibration = c_frame["I3Calibration"] - - @abstractmethod - def __call__(self, frame: "icetray.I3Frame") -> dict: - """Extract information from frame.""" - pass - - @property - def name(self) -> str: - """Get the name of the `I3Extractor` instance.""" - return self._name - - -class I3ExtractorCollection(list): - """Class to manage multiple I3Extractors.""" - - def __init__(self, *extractors: I3Extractor): - """Construct I3ExtractorCollection. - - Args: - *extractors: List of `I3Extractor`s to be treated as a single - collection. - """ - # Check(s) - for extractor in extractors: - assert isinstance(extractor, I3Extractor) - - # Base class constructor - super().__init__(extractors) - - def set_files(self, i3_file: str, gcd_file: str) -> None: - """Store references to the I3- and GCD-files being processed.""" - for extractor in self: - extractor.set_files(i3_file, gcd_file) - - def __call__(self, frame: "icetray.I3Frame") -> List[dict]: - """Extract information from frame for each member `I3Extractor`.""" - return [extractor(frame) for extractor in self] diff --git a/src/graphnet/data/extractors/i3particleextractor.py b/src/graphnet/data/extractors/i3particleextractor.py deleted file mode 100644 index bd37424d2..000000000 --- a/src/graphnet/data/extractors/i3particleextractor.py +++ /dev/null @@ -1,43 +0,0 @@ -"""I3Extractor class(es) for extracting I3Particle properties.""" - -from typing import TYPE_CHECKING, Dict - -from graphnet.data.extractors.i3extractor import I3Extractor - -if TYPE_CHECKING: - from icecube import icetray # pyright: reportMissingImports=false - - -class I3ParticleExtractor(I3Extractor): - """Class for extracting I3Particle properties. - - Can be used to extract predictions from other algorithms for comparisons - with GraphNeT. 
- """ - - def __init__(self, name: str): - """Construct I3ParticleExtractor.""" - # Base class constructor - super().__init__(name) - - def __call__(self, frame: "icetray.I3Frame") -> Dict[str, float]: - """Extract I3Particle properties from I3Particle in frame.""" - output = {} - if self._name in frame: - output.update( - { - "zenith_" + self._name: frame[self._name].dir.zenith, - "azimuth_" + self._name: frame[self._name].dir.azimuth, - "dir_x_" + self._name: frame[self._name].dir.x, - "dir_y_" + self._name: frame[self._name].dir.y, - "dir_z_" + self._name: frame[self._name].dir.z, - "pos_x_" + self._name: frame[self._name].pos.x, - "pos_y_" + self._name: frame[self._name].pos.y, - "pos_z_" + self._name: frame[self._name].pos.z, - "time_" + self._name: frame[self._name].time, - "speed_" + self._name: frame[self._name].speed, - "energy_" + self._name: frame[self._name].energy, - } - ) - - return output diff --git a/src/graphnet/data/extractors/icecube/__init__.py b/src/graphnet/data/extractors/icecube/__init__.py new file mode 100644 index 000000000..11befe581 --- /dev/null +++ b/src/graphnet/data/extractors/icecube/__init__.py @@ -0,0 +1,20 @@ +"""Collection of I3Extractors, extracting pure-python data from I3Frames.""" + +from .i3extractor import I3Extractor +from .i3featureextractor import ( + I3FeatureExtractor, + I3FeatureExtractorIceCube86, + I3FeatureExtractorIceCubeDeepCore, + I3FeatureExtractorIceCubeUpgrade, + I3PulseNoiseTruthFlagIceCubeUpgrade, +) +from .i3truthextractor import I3TruthExtractor +from .i3retroextractor import I3RetroExtractor +from .i3splinempeextractor import I3SplineMPEICExtractor +from .i3particleextractor import I3ParticleExtractor +from .i3tumextractor import I3TUMExtractor +from .i3hybridrecoextractor import I3GalacticPlaneHybridRecoExtractor +from .i3genericextractor import I3GenericExtractor +from .i3pisaextractor import I3PISAExtractor +from .i3ntmuonlabelsextractor import I3NTMuonLabelExtractor +from .i3quesoextractor import I3QUESOExtractor diff --git a/src/graphnet/data/extractors/icecube/i3extractor.py b/src/graphnet/data/extractors/icecube/i3extractor.py new file mode 100644 index 000000000..d997efcc4 --- /dev/null +++ b/src/graphnet/data/extractors/icecube/i3extractor.py @@ -0,0 +1,66 @@ +"""Base I3Extractor class(es).""" + +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Dict, Optional + +from graphnet.utilities.imports import has_icecube_package +from graphnet.data.extractors import Extractor + +if has_icecube_package() or TYPE_CHECKING: + from icecube import icetray, dataio # pyright: reportMissingImports=false + + +class I3Extractor(Extractor): + """Base class for extracting information from physics I3-frames. + + Contains functionality required to extract data from i3 files, used by + the IceCube Neutrino Observatory. + + All classes inheriting from `I3Extractor` should implement the `__call__` + method. + """ + + def __init__(self, extractor_name: str): + """Construct I3Extractor. + + Args: + extractor_name: Name of the `I3Extractor` instance. Used to keep track of the + provenance of different data, and to name tables to which this + data is saved. 
+ """ + # Member variable(s) + self._i3_file: str = "" + self._gcd_file: str = "" + self._gcd_dict: Dict[int, Any] = {} + self._calibration: Optional["icetray.I3Frame.Calibration"] = None + + # Base class constructor + super().__init__(extractor_name=extractor_name) + + def set_gcd(self, gcd_file: str, i3_file: str) -> None: + """Load the geospatial information contained in the GCD-file.""" + # If no GCD file is provided, search the I3 file for frames containing + # geometry (G) and calibration (C) information. + gcd = dataio.I3File(gcd_file or i3_file) + + try: + g_frame = gcd.pop_frame(icetray.I3Frame.Geometry) + except RuntimeError: + self.error( + "No GCD file was provided and no G-frame was found. Exiting." + ) + raise + else: + self._gcd_dict = g_frame["I3Geometry"].omgeo + + try: + c_frame = gcd.pop_frame(icetray.I3Frame.Calibration) + except RuntimeError: + self.warning("No GCD file was provided and no C-frame was found.") + else: + self._calibration = c_frame["I3Calibration"] + + @abstractmethod + def __call__(self, frame: "icetray.I3Frame") -> dict: + """Extract information from frame.""" + pass diff --git a/src/graphnet/data/extractors/i3featureextractor.py b/src/graphnet/data/extractors/icecube/i3featureextractor.py similarity index 97% rename from src/graphnet/data/extractors/i3featureextractor.py rename to src/graphnet/data/extractors/icecube/i3featureextractor.py index f351f0f3a..258bb368c 100644 --- a/src/graphnet/data/extractors/i3featureextractor.py +++ b/src/graphnet/data/extractors/icecube/i3featureextractor.py @@ -1,17 +1,14 @@ """I3Extractor class(es) for extracting specific, reconstructed features.""" from typing import TYPE_CHECKING, Any, Dict, List -from graphnet.data.extractors.extractor import I3Extractor -from graphnet.data.extractors.utilities.frames import ( +from .i3extractor import I3Extractor +from graphnet.data.extractors.icecube.utilities.frames import ( get_om_keys_and_pulseseries, ) from graphnet.utilities.imports import has_icecube_package if has_icecube_package() or TYPE_CHECKING: - from icecube import ( - icetray, - dataclasses, - ) # pyright: reportMissingImports=false + from icecube import icetray # pyright: reportMissingImports=false class I3FeatureExtractor(I3Extractor): diff --git a/src/graphnet/data/extractors/i3genericextractor.py b/src/graphnet/data/extractors/icecube/i3genericextractor.py similarity index 98% rename from src/graphnet/data/extractors/i3genericextractor.py rename to src/graphnet/data/extractors/icecube/i3genericextractor.py index 6a86303e7..e907181d0 100644 --- a/src/graphnet/data/extractors/i3genericextractor.py +++ b/src/graphnet/data/extractors/icecube/i3genericextractor.py @@ -2,12 +2,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union -from graphnet.data.extractors.i3extractor import I3Extractor -from graphnet.data.extractors.utilities.types import ( +from graphnet.data.extractors.icecube import I3Extractor +from graphnet.data.extractors.icecube.utilities.types import ( cast_object_to_pure_python, cast_pulse_series_to_pure_python, ) -from graphnet.data.extractors.utilities.collections import ( +from graphnet.data.extractors.icecube.utilities.collections import ( transpose_list_of_dicts, serialise, flatten_nested_dictionary, diff --git a/src/graphnet/data/extractors/i3hybridrecoextractor.py b/src/graphnet/data/extractors/icecube/i3hybridrecoextractor.py similarity index 96% rename from src/graphnet/data/extractors/i3hybridrecoextractor.py rename to 
src/graphnet/data/extractors/icecube/i3hybridrecoextractor.py index 74f445120..90525bcab 100644 --- a/src/graphnet/data/extractors/i3hybridrecoextractor.py +++ b/src/graphnet/data/extractors/icecube/i3hybridrecoextractor.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Dict -from graphnet.data.extractors.i3extractor import I3Extractor +from graphnet.data.extractors.icecube import I3Extractor if TYPE_CHECKING: from icecube import icetray # pyright: reportMissingImports=false diff --git a/src/graphnet/data/extractors/i3ntmuonlabelsextractor.py b/src/graphnet/data/extractors/icecube/i3ntmuonlabelsextractor.py similarity index 96% rename from src/graphnet/data/extractors/i3ntmuonlabelsextractor.py rename to src/graphnet/data/extractors/icecube/i3ntmuonlabelsextractor.py index 1ca3e8bcb..039b13cfe 100644 --- a/src/graphnet/data/extractors/i3ntmuonlabelsextractor.py +++ b/src/graphnet/data/extractors/icecube/i3ntmuonlabelsextractor.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Dict -from graphnet.data.extractors.i3extractor import I3Extractor +from graphnet.data.extractors.icecube.i3extractor import I3Extractor if TYPE_CHECKING: from icecube import icetray # pyright: reportMissingImports=false diff --git a/src/graphnet/data/extractors/icecube/i3particleextractor.py b/src/graphnet/data/extractors/icecube/i3particleextractor.py new file mode 100644 index 000000000..a50c11d21 --- /dev/null +++ b/src/graphnet/data/extractors/icecube/i3particleextractor.py @@ -0,0 +1,44 @@ +"""I3Extractor class(es) for extracting I3Particle properties.""" + +from typing import TYPE_CHECKING, Dict + +from graphnet.data.extractors.icecube import I3Extractor + +if TYPE_CHECKING: + from icecube import icetray # pyright: reportMissingImports=false + + +class I3ParticleExtractor(I3Extractor): + """Class for extracting I3Particle properties. + + Can be used to extract predictions from other algorithms for comparisons + with GraphNeT. 
+ """ + + def __init__(self, extractor_name: str): + """Construct I3ParticleExtractor.""" + # Base class constructor + super().__init__(extractor_name=extractor_name) + + def __call__(self, frame: "icetray.I3Frame") -> Dict[str, float]: + """Extract I3Particle properties from I3Particle in frame.""" + output = {} + name = self._extractor_name + if name in frame: + output.update( + { + "zenith_" + name: frame[name].dir.zenith, + "azimuth_" + name: frame[name].dir.azimuth, + "dir_x_" + name: frame[name].dir.x, + "dir_y_" + name: frame[name].dir.y, + "dir_z_" + name: frame[name].dir.z, + "pos_x_" + name: frame[name].pos.x, + "pos_y_" + name: frame[name].pos.y, + "pos_z_" + name: frame[name].pos.z, + "time_" + name: frame[name].time, + "speed_" + name: frame[name].speed, + "energy_" + name: frame[name].energy, + } + ) + + return output diff --git a/src/graphnet/data/extractors/i3pisaextractor.py b/src/graphnet/data/extractors/icecube/i3pisaextractor.py similarity index 94% rename from src/graphnet/data/extractors/i3pisaextractor.py rename to src/graphnet/data/extractors/icecube/i3pisaextractor.py index fd5a09583..f14a8046a 100644 --- a/src/graphnet/data/extractors/i3pisaextractor.py +++ b/src/graphnet/data/extractors/icecube/i3pisaextractor.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Dict -from graphnet.data.extractors.i3extractor import I3Extractor +from graphnet.data.extractors.icecube.i3extractor import I3Extractor if TYPE_CHECKING: from icecube import icetray # pyright: reportMissingImports=false diff --git a/src/graphnet/data/extractors/i3quesoextractor.py b/src/graphnet/data/extractors/icecube/i3quesoextractor.py similarity index 94% rename from src/graphnet/data/extractors/i3quesoextractor.py rename to src/graphnet/data/extractors/icecube/i3quesoextractor.py index b72b20046..e29c72a41 100644 --- a/src/graphnet/data/extractors/i3quesoextractor.py +++ b/src/graphnet/data/extractors/icecube/i3quesoextractor.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Dict -from graphnet.data.extractors.i3extractor import I3Extractor +from graphnet.data.extractors.icecube.i3extractor import I3Extractor if TYPE_CHECKING: from icecube import icetray # pyright: reportMissingImports=false diff --git a/src/graphnet/data/extractors/i3retroextractor.py b/src/graphnet/data/extractors/icecube/i3retroextractor.py similarity index 97% rename from src/graphnet/data/extractors/i3retroextractor.py rename to src/graphnet/data/extractors/icecube/i3retroextractor.py index cd55d01f4..aaeb773b4 100644 --- a/src/graphnet/data/extractors/i3retroextractor.py +++ b/src/graphnet/data/extractors/icecube/i3retroextractor.py @@ -2,8 +2,8 @@ from typing import TYPE_CHECKING, Any, Dict -from graphnet.data.extractors.i3extractor import I3Extractor -from graphnet.data.extractors.utilities.frames import ( +from graphnet.data.extractors.icecube import I3Extractor +from graphnet.data.extractors.icecube.utilities.frames import ( frame_is_montecarlo, frame_is_noise, ) diff --git a/src/graphnet/data/extractors/i3splinempeextractor.py b/src/graphnet/data/extractors/icecube/i3splinempeextractor.py similarity index 93% rename from src/graphnet/data/extractors/i3splinempeextractor.py rename to src/graphnet/data/extractors/icecube/i3splinempeextractor.py index e47b2e71d..1439ada51 100644 --- a/src/graphnet/data/extractors/i3splinempeextractor.py +++ b/src/graphnet/data/extractors/icecube/i3splinempeextractor.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Dict -from graphnet.data.extractors.i3extractor import I3Extractor 
+from graphnet.data.extractors.icecube import I3Extractor if TYPE_CHECKING: from icecube import icetray # pyright: reportMissingImports=false diff --git a/src/graphnet/data/extractors/i3truthextractor.py b/src/graphnet/data/extractors/icecube/i3truthextractor.py similarity index 99% rename from src/graphnet/data/extractors/i3truthextractor.py rename to src/graphnet/data/extractors/icecube/i3truthextractor.py index d04be69b2..b715e57ab 100644 --- a/src/graphnet/data/extractors/i3truthextractor.py +++ b/src/graphnet/data/extractors/icecube/i3truthextractor.py @@ -4,8 +4,8 @@ import matplotlib.path as mpath from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple -from graphnet.data.extractors.extractor import I3Extractor -from graphnet.data.extractors.utilities.frames import ( +from .i3extractor import I3Extractor +from .utilities.frames import ( frame_is_montecarlo, frame_is_noise, ) diff --git a/src/graphnet/data/extractors/i3tumextractor.py b/src/graphnet/data/extractors/icecube/i3tumextractor.py similarity index 94% rename from src/graphnet/data/extractors/i3tumextractor.py rename to src/graphnet/data/extractors/icecube/i3tumextractor.py index 38cbca146..685b0a78e 100644 --- a/src/graphnet/data/extractors/i3tumextractor.py +++ b/src/graphnet/data/extractors/icecube/i3tumextractor.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Dict -from graphnet.data.extractors.i3extractor import I3Extractor +from graphnet.data.extractors.icecube import I3Extractor if TYPE_CHECKING: from icecube import icetray # pyright: reportMissingImports=false diff --git a/src/graphnet/data/extractors/utilities/__init__.py b/src/graphnet/data/extractors/icecube/utilities/__init__.py similarity index 100% rename from src/graphnet/data/extractors/utilities/__init__.py rename to src/graphnet/data/extractors/icecube/utilities/__init__.py diff --git a/src/graphnet/data/extractors/utilities/collections.py b/src/graphnet/data/extractors/icecube/utilities/collections.py similarity index 100% rename from src/graphnet/data/extractors/utilities/collections.py rename to src/graphnet/data/extractors/icecube/utilities/collections.py diff --git a/src/graphnet/data/extractors/utilities/frames.py b/src/graphnet/data/extractors/icecube/utilities/frames.py similarity index 100% rename from src/graphnet/data/extractors/utilities/frames.py rename to src/graphnet/data/extractors/icecube/utilities/frames.py diff --git a/src/graphnet/data/filters.py b/src/graphnet/data/extractors/icecube/utilities/i3_filters.py similarity index 100% rename from src/graphnet/data/filters.py rename to src/graphnet/data/extractors/icecube/utilities/i3_filters.py diff --git a/src/graphnet/data/extractors/utilities/types.py b/src/graphnet/data/extractors/icecube/utilities/types.py similarity index 98% rename from src/graphnet/data/extractors/utilities/types.py rename to src/graphnet/data/extractors/icecube/utilities/types.py index cf58e8357..32ecae0ff 100644 --- a/src/graphnet/data/extractors/utilities/types.py +++ b/src/graphnet/data/extractors/icecube/utilities/types.py @@ -4,11 +4,11 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from graphnet.data.extractors.utilities.collections import ( +from graphnet.data.extractors.icecube.utilities.collections import ( transpose_list_of_dicts, flatten_nested_dictionary, ) -from graphnet.data.extractors.utilities.frames import ( +from graphnet.data.extractors.icecube.utilities.frames import ( get_om_keys_and_pulseseries, ) from graphnet.utilities.imports import 
has_icecube_package diff --git a/src/graphnet/data/parquet/__init__.py b/src/graphnet/data/parquet/__init__.py deleted file mode 100644 index 616d89c16..000000000 --- a/src/graphnet/data/parquet/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Parquet-specific implementation of data classes.""" -from .parquet_dataconverter import ParquetDataConverter diff --git a/src/graphnet/data/parquet/parquet_dataconverter.py b/src/graphnet/data/parquet/parquet_dataconverter.py deleted file mode 100644 index 68531c8e2..000000000 --- a/src/graphnet/data/parquet/parquet_dataconverter.py +++ /dev/null @@ -1,52 +0,0 @@ -"""DataConverter for the Parquet backend.""" - -from collections import OrderedDict -import os -from typing import List, Optional - -import awkward - -from graphnet.data.dataconverter import DataConverter # type: ignore[attr-defined] - - -class ParquetDataConverter(DataConverter): - """Class for converting I3-files to Parquet format.""" - - # Class variables - file_suffix: str = "parquet" - - # Abstract method implementation(s) - def save_data(self, data: List[OrderedDict], output_file: str) -> None: - """Save data to parquet file.""" - # Check(s) - if os.path.exists(output_file): - self.warning( - f"Output file {output_file} already exists. Overwriting." - ) - - self.debug(f"Saving to {output_file}") - self.debug( - f"- Data has {len(data)} events and {len(data[0])} tables for each" - ) - - awkward.to_parquet(awkward.from_iter(data), output_file) - - self.debug("- Done saving") - self._output_files.append(output_file) - - def merge_files( - self, output_file: str, input_files: Optional[List[str]] = None - ) -> None: - """Parquet-specific method for merging output files. - - Args: - output_file: Name of the output file containing the merged results. - input_files: Intermediate files to be merged, according to the - specific implementation. Default to None, meaning that all - files output by the current instance are merged. - - Raises: - NotImplementedError: If the method has not been implemented for the - Parquet backend. - """ - raise NotImplementedError() diff --git a/src/graphnet/data/pipeline.py b/src/graphnet/data/pipeline.py index d97415bb0..9973c763f 100644 --- a/src/graphnet/data/pipeline.py +++ b/src/graphnet/data/pipeline.py @@ -13,7 +13,9 @@ import torch from torch.utils.data import DataLoader -from graphnet.data.sqlite.sqlite_utilities import create_table_and_save_to_sql +from graphnet.data.utilities.sqlite_utilities import ( + create_table_and_save_to_sql, +) from graphnet.training.utils import get_predictions, make_dataloader from graphnet.models.graphs import GraphDefinition diff --git a/src/graphnet/data/readers/__init__.py b/src/graphnet/data/readers/__init__.py new file mode 100644 index 000000000..0755bd35a --- /dev/null +++ b/src/graphnet/data/readers/__init__.py @@ -0,0 +1,3 @@ +"""Modules for reading experiment-specific data and applying Extractors.""" +from .graphnet_file_reader import GraphNeTFileReader +from .i3reader import I3Reader diff --git a/src/graphnet/data/readers/graphnet_file_reader.py b/src/graphnet/data/readers/graphnet_file_reader.py new file mode 100644 index 000000000..ab6464e13 --- /dev/null +++ b/src/graphnet/data/readers/graphnet_file_reader.py @@ -0,0 +1,132 @@ +"""Module containing different FileReader classes in GraphNeT. + +These methods are used to open and apply `Extractors` to experiment-specific +file formats. 
+""" + +from typing import List, Union, OrderedDict +from abc import abstractmethod, ABC +import glob +import os + +from graphnet.utilities.decorators import final +from graphnet.utilities.logging import Logger +from graphnet.data.dataclasses import I3FileSet +from graphnet.data.extractors.extractor import Extractor + + +class GraphNeTFileReader(Logger, ABC): + """A generic base class for FileReaders in GraphNeT. + + Classes inheriting from `GraphNeTFileReader` must implement a + `__call__` method that opens a file, applies `Extractor`(s) and returns + a list of ordered dictionaries. + + In addition, Classes inheriting from `GraphNeTFileReader` must set + class properties `accepted_file_extensions` and `accepted_extractors`. + """ + + @abstractmethod + def __call__(self, file_path: str) -> List[OrderedDict]: + """Open and apply extractors to a single file. + + The `output` must be a list of dictionaries, where the number of events + in the file `n_events` satisfies `len(output) = n_events`. I.e each + element in the list is a dictionary, and each field in the dictionary + is the output of a single extractor. + """ + + @property + def accepted_file_extensions(self) -> List[str]: + """Return list of accepted file extensions.""" + return self._accepted_file_extensions # type: ignore + + @property + def accepted_extractors(self) -> List[Extractor]: + """Return list of compatible `Extractor`(s).""" + return self._accepted_extractors # type: ignore + + @property + def extracor_names(self) -> List[str]: + """Return list of table names produced by extractors.""" + return [extractor.name for extractor in self._extractors] # type: ignore + + def find_files( + self, path: Union[str, List[str]] + ) -> Union[List[str], List[I3FileSet]]: + """Search directory for input files recursively. + + This method may be overwritten by custom implementations. + + Args: + path: path to directory. + + Returns: + List of files matching accepted file extensions. + """ + if isinstance(path, str): + path = [path] + files = [] + for dir in path: + for accepted_file_extension in self.accepted_file_extensions: + files.extend(glob.glob(dir + f"/*{accepted_file_extension}")) + + # Check that files are OK. + self.validate_files(files) + return files + + @final + def set_extractors(self, extractors: List[Extractor]) -> None: + """Set `Extractor`(s) as member variable. + + Args: + extractors: A list of `Extractor`(s) to set as member variable. + """ + self._validate_extractors(extractors) + self._extractors = extractors + + @final + def _validate_extractors(self, extractors: List[Extractor]) -> None: + for extractor in extractors: + try: + assert isinstance(extractor, tuple(self.accepted_extractors)) # type: ignore + except AssertionError as e: + self.error( + f"{extractor.__class__.__name__}" + f" is not supported by {self.__class__.__name__}" + ) + raise e + + @final + def validate_files( + self, input_files: Union[List[str], List[I3FileSet]] + ) -> None: + """Check that the input files are accepted by the reader. + + Args: + input_files: Path(s) to input file(s). + """ + for input_file in input_files: + # Handle filepath vs. FileSet cases + if isinstance(input_file, I3FileSet): + self._validate_file(input_file.i3_file) + self._validate_file(input_file.gcd_file) + else: + self._validate_file(input_file) + + @final + def _validate_file(self, file: str) -> None: + """Validate a single file path. + + Args: + file: path to file. 
+ + Returns: + None + """ + try: + assert file.lower().endswith(tuple(self.accepted_file_extensions)) + except AssertionError: + self.error( + f'{self.__class__.__name__} accepts {self.accepted_file_extensions} but {file.split("/")[-1]} has extension {os.path.splitext(file)[1]}.' + ) diff --git a/src/graphnet/data/readers.py b/src/graphnet/data/readers/i3reader.py similarity index 51% rename from src/graphnet/data/readers.py rename to src/graphnet/data/readers/i3reader.py index 6dd9bd63d..926c2395a 100644 --- a/src/graphnet/data/readers.py +++ b/src/graphnet/data/readers/i3reader.py @@ -1,148 +1,22 @@ -"""Module containing different FileReader classes in GraphNeT. - -These methods are used to open and apply `Extractors` to experiment-specific -file formats. -""" +"""Module containing different I3Reader.""" from typing import List, Union, OrderedDict, Type -from abc import abstractmethod, ABC -import glob -import os -from graphnet.utilities.decorators import final -from graphnet.utilities.logging import Logger from graphnet.utilities.imports import has_icecube_package -from graphnet.data.filters import I3Filter, NullSplitI3Filter - -from .dataclasses import I3FileSet - -from .extractors.extractor import ( - Extractor, - I3Extractor, -) # , I3GenericExtractor +from graphnet.data.extractors.icecube.utilities.i3_filters import ( + I3Filter, + NullSplitI3Filter, +) +from graphnet.data.extractors.icecube import I3Extractor +from graphnet.data.dataclasses import I3FileSet from graphnet.utilities.filesys import find_i3_files +from .graphnet_file_reader import GraphNeTFileReader + if has_icecube_package(): from icecube import icetray, dataio # pyright: reportMissingImports=false -class GraphNeTFileReader(Logger, ABC): - """A generic base class for FileReaders in GraphNeT. - - Classes inheriting from `GraphNeTFileReader` must implement a - `__call__` method that opens a file, applies `Extractor`(s) and returns - a list of ordered dictionaries. - - In addition, Classes inheriting from `GraphNeTFileReader` must set - class properties `accepted_file_extensions` and `accepted_extractors`. - """ - - @abstractmethod - def __call__(self, file_path: str) -> List[OrderedDict]: - """Open and apply extractors to a single file. - - The `output` must be a list of dictionaries, where the number of events - in the file `n_events` satisfies `len(output) = n_events`. I.e each - element in the list is a dictionary, and each field in the dictionary - is the output of a single extractor. - """ - - @property - def accepted_file_extensions(self) -> List[str]: - """Return list of accepted file extensions.""" - return self._accepted_file_extensions # type: ignore - - @property - def accepted_extractors(self) -> List[Extractor]: - """Return list of compatible `Extractor`(s).""" - return self._accepted_extractors # type: ignore - - @property - def extracor_names(self) -> List[str]: - """Return list of table names produced by extractors.""" - return [extractor.name for extractor in self._extractors] # type: ignore - - def find_files( - self, path: Union[str, List[str]] - ) -> Union[List[str], List[I3FileSet]]: - """Search directory for input files recursively. - - This method may be overwritten by custom implementations. - - Args: - path: path to directory. - - Returns: - List of files matching accepted file extensions. 
- """ - if isinstance(path, str): - path = [path] - files = [] - for dir in path: - for accepted_file_extension in self.accepted_file_extensions: - files.extend(glob.glob(dir + f"/*{accepted_file_extension}")) - - # Check that files are OK. - self.validate_files(files) - return files - - @final - def set_extractors(self, extractors: List[Extractor]) -> None: - """Set `Extractor`(s) as member variable. - - Args: - extractors: A list of `Extractor`(s) to set as member variable. - """ - self._validate_extractors(extractors) - self._extractors = extractors - - @final - def _validate_extractors(self, extractors: List[Extractor]) -> None: - for extractor in extractors: - try: - assert isinstance(extractor, tuple(self.accepted_extractors)) # type: ignore - except AssertionError as e: - self.error( - f"{extractor.__class__.__name__}" - f" is not supported by {self.__class__.__name__}" - ) - raise e - - @final - def validate_files( - self, input_files: Union[List[str], List[I3FileSet]] - ) -> None: - """Check that the input files are accepted by the reader. - - Args: - input_files: Path(s) to input file(s). - """ - for input_file in input_files: - # Handle filepath vs. FileSet cases - if isinstance(input_file, I3FileSet): - self._validate_file(input_file.i3_file) - self._validate_file(input_file.gcd_file) - else: - self._validate_file(input_file) - - @final - def _validate_file(self, file: str) -> None: - """Validate a single file path. - - Args: - file: path to file. - - Returns: - None - """ - try: - assert file.lower().endswith(tuple(self.accepted_file_extensions)) - except AssertionError: - self.error( - f'{self.__class__.__name__} accepts {self.accepted_file_extensions} but {file.split("/")[-1]} has extension {os.path.splitext(file)[1]}.' - ) - - class I3Reader(GraphNeTFileReader): """A class for reading .i3 files from the IceCube Neutrino Observatory. diff --git a/src/graphnet/data/sqlite/__init__.py b/src/graphnet/data/sqlite/__init__.py deleted file mode 100644 index e4ac554a7..000000000 --- a/src/graphnet/data/sqlite/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -"""SQLite-specific implementation of data classes.""" -from .sqlite_dataconverter import SQLiteDataConverter -from .sqlite_utilities import create_table_and_save_to_sql -from .sqlite_utilities import run_sql_code, save_to_sql diff --git a/src/graphnet/data/sqlite/sqlite_dataconverter.py b/src/graphnet/data/sqlite/sqlite_dataconverter.py deleted file mode 100644 index 1750b7a33..000000000 --- a/src/graphnet/data/sqlite/sqlite_dataconverter.py +++ /dev/null @@ -1,349 +0,0 @@ -"""DataConverter for the SQLite backend.""" - -from collections import OrderedDict -import os -from typing import Any, Dict, List, Optional, Tuple, Union - -import pandas as pd -import sqlalchemy -import sqlite3 -from tqdm import tqdm - -from graphnet.data.dataconverter import DataConverter # type: ignore[attr-defined] -from graphnet.data.sqlite.sqlite_utilities import ( - create_table, - create_table_and_save_to_sql, -) - - -class SQLiteDataConverter(DataConverter): - """Class for converting I3-file(s) to SQLite format.""" - - # Class variables - file_suffix = "db" - - # Abstract method implementation(s) - def save_data(self, data: List[OrderedDict], output_file: str) -> None: - """Save data to SQLite database.""" - # Check(s) - if os.path.exists(output_file): - self.warning( - f"Output file {output_file} already exists. Appending." - ) - - # Concatenate data - if len(data) == 0: - self.warning( - "No data was extracted from the processed I3 file(s). 
" - f"No data saved to {output_file}" - ) - return - - saved_any = False - dataframe_list: OrderedDict = OrderedDict( - [(key, []) for key in data[0]] - ) - for data_dict in data: - for key, data_values in data_dict.items(): - df = construct_dataframe(data_values) - - if self.any_pulsemap_is_non_empty(data_dict) and len(df) > 0: - # only include data_dict in temp. databases if at least one pulsemap is non-empty, - # and the current extractor (df) is also non-empty (also since truth is always non-empty) - dataframe_list[key].append(df) - - dataframe = OrderedDict( - [ - ( - key, - pd.concat(dfs, ignore_index=True, sort=True) - if dfs - else pd.DataFrame(), - ) - for key, dfs in dataframe_list.items() - ] - ) - # Can delete dataframe_list here to free up memory. - - # Save each dataframe to SQLite database - self.debug(f"Saving to {output_file}") - for table, df in dataframe.items(): - if len(df) > 0: - create_table_and_save_to_sql( - df, - table, - output_file, - default_type="FLOAT", - integer_primary_key=not ( - is_pulse_map(table) or is_mc_tree(table) - ), - ) - saved_any = True - - if saved_any: - self.debug("- Done saving") - else: - self.warning(f"No data saved to {output_file}") - - def merge_files( - self, - output_file: str, - input_files: Optional[List[str]] = None, - max_table_size: Optional[int] = None, - ) -> None: - """SQLite-specific method for merging output files/databases. - - Args: - output_file: Name of the output file containing the merged results. - input_files: Intermediate files/databases to be merged, according - to the specific implementation. Default to None, meaning that - all files/databases output by the current instance are merged. - max_table_size: The maximum number of rows in any given table. - If any one table exceed this limit, a new database will be - created. - """ - if max_table_size: - self.warning( - f"Merging got max_table_size of {max_table_size}. Will attempt to create databases with a maximum row count of this size." - ) - self.max_table_size = max_table_size - self._partition_count = 1 - - if input_files is None: - self.info("Merging files output by current instance.") - self._input_files = self._output_files - else: - self._input_files = input_files - - if not output_file.endswith("." + self.file_suffix): - output_file = ".".join([output_file, self.file_suffix]) - - if os.path.exists(output_file): - self.warning( - f"Target path for merged database, {output_file}, already exists." 
- ) - - if len(self._input_files) > 0: - self.info(f"Merging {len(self._input_files)} database files") - # Create one empty database table for each extraction - self._merged_table_names = self._extract_table_names( - self._input_files - ) - if self.max_table_size: - output_file = self._adjust_output_file_name(output_file) - self._create_empty_tables(output_file) - self._row_counts = self._initialize_row_counts() - # Merge temporary databases into newly created one - self._merge_temporary_databases(output_file, self._input_files) - else: - self.warning("No temporary database files found!") - - # Internal methods - def _adjust_output_file_name(self, output_file: str) -> str: - if "_part_" in output_file: - root = ( - output_file.split("_part_")[0] - + output_file.split("_part_")[1][1:] - ) - else: - root = output_file - str_list = root.split(".db") - return str_list[0] + f"_part_{self._partition_count}" + ".db" - - def _update_row_counts( - self, results: "OrderedDict[str, pd.DataFrame]" - ) -> None: - for table_name, data in results.items(): - self._row_counts[table_name] += len(data) - return - - def _initialize_row_counts(self) -> Dict[str, int]: - """Build dictionary with row counts. Initialized with 0. - - Returns: - Dictionary where every field is a table name that contains - corresponding row counts. - """ - row_counts = {} - for table_name in self._merged_table_names: - row_counts[table_name] = 0 - return row_counts - - def _create_empty_tables(self, output_file: str) -> None: - """Create tables for output database. - - Args: - output_file: Path to database. - """ - for table_name in self._merged_table_names: - column_names = self._extract_column_names( - self._input_files, table_name - ) - if len(column_names) > 1: - create_table( - column_names, - table_name, - output_file, - default_type="FLOAT", - integer_primary_key=not ( - is_pulse_map(table_name) or is_mc_tree(table_name) - ), - ) - - def _get_tables_in_database(self, db: str) -> Tuple[str, ...]: - with sqlite3.connect(db) as conn: - table_names = tuple( - [ - p[0] - for p in ( - conn.execute( - "SELECT name FROM sqlite_master WHERE type='table';" - ).fetchall() - ) - ] - ) - return table_names - - def _extract_table_names( - self, db: Union[str, List[str]] - ) -> Tuple[str, ...]: - """Get the names of all tables in database `db`.""" - if isinstance(db, str): - db = [db] - results = [self._get_tables_in_database(path) for path in db] - # @TODO: Check... - if all([results[0] == r for r in results]): - return results[0] - else: - unique_tables = [] - for tables in results: - for table in tables: - if table not in unique_tables: - unique_tables.append(table) - return tuple(unique_tables) - - def _extract_column_names( - self, db_paths: List[str], table_name: str - ) -> List[str]: - for db_path in db_paths: - tables_in_database = self._get_tables_in_database(db_path) - if table_name in tables_in_database: - with sqlite3.connect(db_path) as con: - query = f"select * from {table_name} limit 1" - columns = pd.read_sql(query, con).columns - if len(columns): - return columns - return [] - - def any_pulsemap_is_non_empty(self, data_dict: Dict[str, Dict]) -> bool: - """Check whether there are non-empty pulsemaps extracted from P frame. - - Takes in the data extracted from the P frame, then retrieves the - values, if there are any, from the pulsemap key(s) (e.g - SplitInIcePulses). If at least one of the pulsemaps is non-empty then - return true. If no pulsemaps exist, i.e., if no `I3FeatureExtractor` is - called e.g. 
because `I3GenericExtractor` is used instead, always return - True. - """ - if len(self._pulsemaps) == 0: - return True - - pulsemap_dicts = [data_dict[pulsemap] for pulsemap in self._pulsemaps] - return any(d["dom_x"] for d in pulsemap_dicts) - - def _submit_to_database( - self, database: str, key: str, data: pd.DataFrame - ) -> None: - """Submit data to the database with specified key.""" - if len(data) == 0: - self.info(f"No data provided for {key}.") - return - engine = sqlalchemy.create_engine("sqlite:///" + database) - data.to_sql(key, engine, index=False, if_exists="append") - engine.dispose() - - def _extract_everything(self, db: str) -> "OrderedDict[str, pd.DataFrame]": - """Extract everything from the temporary database `db`. - - Args: - db: Path to temporary database. - - Returns: - Dictionary containing the data for each extracted table. - """ - results = OrderedDict() - table_names = self._extract_table_names(db) - with sqlite3.connect(db) as conn: - for table_name in table_names: - query = f"select * from {table_name}" - try: - data = pd.read_sql(query, conn) - except: # noqa: E722 - data = [] - results[table_name] = data - return results - - def _merge_temporary_databases( - self, - output_file: str, - input_files: List[str], - ) -> None: - """Merge the temporary databases. - - Args: - output_file: path to the final database - input_files: list of names of temporary databases - """ - file_count = 0 - for input_file in tqdm(input_files, colour="green"): - results = self._extract_everything(input_file) - for table_name, data in results.items(): - self._submit_to_database(output_file, table_name, data) - file_count += 1 - if (self.max_table_size is not None) & ( - file_count < len(input_files) - ): - self._update_row_counts(results) - maximum_row_count_reached = False - for table in self._row_counts.keys(): - assert self.max_table_size is not None - if self._row_counts[table] >= self.max_table_size: - maximum_row_count_reached = True - if maximum_row_count_reached: - self._partition_count += 1 - output_file = self._adjust_output_file_name(output_file) - self.info( - f"Maximum row count reached. Creating new partition at {output_file}" - ) - self._create_empty_tables(output_file) - self._row_counts = self._initialize_row_counts() - - -# Implementation-specific utility function(s) -def construct_dataframe(extraction: Dict[str, Any]) -> pd.DataFrame: - """Convert extraction to pandas.DataFrame. - - Args: - extraction: Dictionary with the extracted data. - - Returns: - Extraction as pandas.DataFrame. 
- """ - all_scalars = True - for value in extraction.values(): - if isinstance(value, (list, tuple, dict)): - all_scalars = False - break - - out = pd.DataFrame(extraction, index=[0] if all_scalars else None) - return out - - -def is_pulse_map(table_name: str) -> bool: - """Check whether `table_name` corresponds to a pulse map.""" - return "pulse" in table_name.lower() or "series" in table_name.lower() - - -def is_mc_tree(table_name: str) -> bool: - """Check whether `table_name` corresponds to an MC tree.""" - return "I3MCTree" in table_name diff --git a/src/graphnet/data/utilities/__init__.py b/src/graphnet/data/utilities/__init__.py index 0dd9e0600..ad4f0c7db 100644 --- a/src/graphnet/data/utilities/__init__.py +++ b/src/graphnet/data/utilities/__init__.py @@ -1 +1,4 @@ """Utilities for use across `graphnet.data`.""" +from .sqlite_utilities import create_table_and_save_to_sql +from .sqlite_utilities import get_primary_keys +from .sqlite_utilities import query_database diff --git a/src/graphnet/data/utilities/parquet_to_sqlite.py b/src/graphnet/data/utilities/parquet_to_sqlite.py index 146e69ce8..11114698e 100644 --- a/src/graphnet/data/utilities/parquet_to_sqlite.py +++ b/src/graphnet/data/utilities/parquet_to_sqlite.py @@ -9,7 +9,9 @@ import pandas as pd from tqdm.auto import trange -from graphnet.data.sqlite.sqlite_utilities import create_table_and_save_to_sql +from graphnet.data.utilities.sqlite_utilities import ( + create_table_and_save_to_sql, +) from graphnet.utilities.logging import Logger diff --git a/src/graphnet/data/sqlite/sqlite_utilities.py b/src/graphnet/data/utilities/sqlite_utilities.py similarity index 72% rename from src/graphnet/data/sqlite/sqlite_utilities.py rename to src/graphnet/data/utilities/sqlite_utilities.py index 23bae802d..cfa308ba2 100644 --- a/src/graphnet/data/sqlite/sqlite_utilities.py +++ b/src/graphnet/data/utilities/sqlite_utilities.py @@ -1,7 +1,7 @@ """SQLite-specific utility functions for use in `graphnet.data`.""" import os.path -from typing import List +from typing import List, Dict, Tuple import pandas as pd import sqlalchemy @@ -16,6 +16,58 @@ def database_exists(database_path: str) -> bool: return os.path.exists(database_path) +def query_database(database: str, query: str) -> pd.DataFrame: + """Execute query on database, and return result. + + Args: + database: path to database. + query: query to be executed. + + Returns: + DataFrame containing the result of the query. + """ + with sqlite3.connect(database) as conn: + return pd.read_sql(query, conn) + + +def get_primary_keys(database: str) -> Tuple[Dict[str, str], str]: + """Get name of primary key column for each table in database. + + Args: + database: path to database. + + Returns: + A dictionary containing the names of primary keys in each table of + `database`. E.g. {'truth': "event_no", + 'SplitInIcePulses': None} + Name of the primary key. 
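+
+    Example (illustrative, hypothetical path):
+        `table_keys, index_column = get_primary_keys("/path/to/merged.db")`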
+ """ + with sqlite3.connect(database) as conn: + query = 'SELECT name FROM sqlite_master WHERE type == "table"' + table_names = [table[0] for table in conn.execute(query).fetchall()] + + integer_primary_key = {} + for table in table_names: + query = f"SELECT l.name FROM pragma_table_info('{table}') as l WHERE l.pk = 1;" + first_primary_key = [ + key[0] for key in conn.execute(query).fetchall() + ] + integer_primary_key[table] = ( + first_primary_key[0] if len(first_primary_key) else None + ) + + # Get the primary key column name + primary_key_candidates = [] + for val in set(integer_primary_key.values()): + if val is not None: + primary_key_candidates.append(val) + + # There should only be one primary key: + assert len(primary_key_candidates) == 1 + + return integer_primary_key, primary_key_candidates[0] + + def database_table_exists(database_path: str, table_name: str) -> bool: """Check whether `table_name` exists in database at `database_path`.""" if not database_exists(database_path): diff --git a/src/graphnet/data/writers/__init__.py b/src/graphnet/data/writers/__init__.py new file mode 100644 index 000000000..ad3e2748e --- /dev/null +++ b/src/graphnet/data/writers/__init__.py @@ -0,0 +1,4 @@ +"""Modules for saving interim dataformat to various data backends.""" +from .graphnet_writer import GraphNeTWriter +from .parquet_writer import ParquetWriter +from .sqlite_writer import SQLiteWriter diff --git a/src/graphnet/data/writers.py b/src/graphnet/data/writers/graphnet_writer.py similarity index 57% rename from src/graphnet/data/writers.py rename to src/graphnet/data/writers/graphnet_writer.py index d23b21ac8..04ee079f4 100644 --- a/src/graphnet/data/writers.py +++ b/src/graphnet/data/writers/graphnet_writer.py @@ -5,20 +5,16 @@ """ import os -from typing import List, Union, Dict, Any, OrderedDict +from typing import Dict, List from abc import abstractmethod, ABC from graphnet.utilities.decorators import final from graphnet.utilities.logging import Logger -from graphnet.data.sqlite.sqlite_utilities import ( - create_table, - create_table_and_save_to_sql, -) import pandas as pd -class GraphNeTFileSaveMethod(Logger, ABC): +class GraphNeTWriter(Logger, ABC): """Generic base class for saving interim data format in `DataConverter`. Classes inheriting from `GraphNeTFileSaveMethod` must implement the @@ -43,6 +39,21 @@ def _save_file( output_file_path: output file path. n_events: Number of events container in `data`. """ + raise NotImplementedError + + @abstractmethod + def merge_files( + self, + files: List[str], + output_dir: str, + ) -> None: + """Merge smaller files. + + Args: + files: Files to be merged. + output_dir: The directory to store the merged files in. + """ + raise NotImplementedError @final def __call__( @@ -76,49 +87,3 @@ def __call__( def file_extension(self) -> str: """Return file extension used to store the data.""" return self._file_extension # type: ignore - - -class SQLiteSaveMethod(GraphNeTFileSaveMethod): - """A method for saving GraphNeT's interim dataformat to SQLite.""" - - _file_extension = ".db" - - def _save_file( - self, - data: Dict[str, pd.DataFrame], - output_file_path: str, - n_events: int, - ) -> None: - """Save data to SQLite database.""" - # Check(s) - if os.path.exists(output_file_path): - self.warning( - f"Output file {output_file_path} already exists. Appending." - ) - - # Concatenate data - if len(data) == 0: - self.warning( - "No data was extracted from the processed I3 file(s). 
" - f"No data saved to {output_file_path}" - ) - return - - saved_any = False - # Save each dataframe to SQLite database - self.debug(f"Saving to {output_file_path}") - for table, df in data.items(): - if len(df) > 0: - create_table_and_save_to_sql( - df, - table, - output_file_path, - default_type="FLOAT", - integer_primary_key=len(df) <= n_events, - ) - saved_any = True - - if saved_any: - self.debug("- Done saving") - else: - self.warning(f"No data saved to {output_file_path}") diff --git a/src/graphnet/data/writers/parquet_writer.py b/src/graphnet/data/writers/parquet_writer.py new file mode 100644 index 000000000..a8e74f11f --- /dev/null +++ b/src/graphnet/data/writers/parquet_writer.py @@ -0,0 +1,34 @@ +"""DataConverter for the Parquet backend.""" + +import os +from typing import List, Optional, Dict + +import awkward +import pandas as pd + +from .graphnet_writer import GraphNeTWriter + + +class ParquetWriter(GraphNeTWriter): + """Class for writing interim data format to Parquet.""" + + # Class variables + file_suffix: str = ".parquet" + + # Abstract method implementation(s) + def _save_file( + self, + data: Dict[str, pd.DataFrame], + output_file_path: str, + n_events: int, + ) -> None: + """Save data to parquet file.""" + # Check(s) + if os.path.exists(output_file_path): + self.warning( + f"Output file {output_file_path} already exists. Overwriting." + ) + + self.debug(f"Saving to {output_file_path}") + awkward.to_parquet(awkward.from_iter(data), output_file_path) + self.debug("- Done saving") diff --git a/src/graphnet/data/writers/sqlite_writer.py b/src/graphnet/data/writers/sqlite_writer.py new file mode 100644 index 000000000..e9f400c53 --- /dev/null +++ b/src/graphnet/data/writers/sqlite_writer.py @@ -0,0 +1,224 @@ +"""Module containing `GraphNeTFileSaveMethod`(s). + +These modules are used to save the interim data format from `DataConverter` to +a deep-learning friendly file format. +""" + +import os +from tqdm import tqdm +from typing import List, Dict, Optional + +from graphnet.data.utilities import ( + create_table_and_save_to_sql, + get_primary_keys, + query_database, +) +import pandas as pd +from .graphnet_writer import GraphNeTWriter + + +class SQLiteWriter(GraphNeTWriter): + """A method for saving GraphNeT's interim dataformat to SQLite.""" + + def __init__( + self, + merged_database_name: str = "merged.db", + max_table_size: Optional[int] = None, + ) -> None: + """Initialize `SQLiteWriter`. + + Args: + merged_database_name: name of the database, not path, that files + will be merged into. Defaults to "merged.db". + max_table_size: The maximum number of rows in any given table. + If given, the merging proceedure splits the databases into + partitions each with a maximum table size of max_table_size. + Note that the size is approximate. This feature is useful if + you have many events, as tables exceeding + 400 million rows tend to be noticably slower to query. + Defaults to None (All events are put into a single database). 
+ """ + # Member Variables + self._file_extension = ".db" + self._max_table_size = max_table_size + self._database_name = merged_database_name + + # Add file extension to database name if forgotten + if not self._database_name.endswith(self._file_extension): + self._database_name = self._database_name + self._file_extension + + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + def _save_file( + self, + data: Dict[str, pd.DataFrame], + output_file_path: str, + n_events: int, + ) -> None: + """Save data to SQLite database.""" + # Check(s) + if os.path.exists(output_file_path): + self.warning( + f"Output file {output_file_path} already exists. Appending." + ) + + # Concatenate data + if len(data) == 0: + self.warning( + "No data was extracted from the processed I3 file(s). " + f"No data saved to {output_file_path}" + ) + return + + saved_any = False + # Save each dataframe to SQLite database + self.debug(f"Saving to {output_file_path}") + for table, df in data.items(): + if len(df) > 0: + create_table_and_save_to_sql( + df, + table, + output_file_path, + default_type="FLOAT", + integer_primary_key=len(df) <= n_events, + ) + saved_any = True + + if saved_any: + self.debug("- Done saving") + else: + self.warning(f"No data saved to {output_file_path}") + + def merge_files( + self, + files: List[str], + output_dir: str, + ) -> None: + """SQLite-specific method for merging output files/databases. + + Args: + files: paths to SQLite databases that needs to be merged. + output_dir: path to store the merged database(s) in. + database_name: name, not path, of database. E.g. "my_database". + max_table_size: The maximum number of rows in any given table. + If given, the merging proceedure splits the databases into + partitions each with a maximum table size of max_table_size. + Note that the size is approximate. This feature is useful if + you have many events, as tables exceeding + 400 million rows tend to be noticably slower to query. + Defaults to None (All events are put into a single database.) + """ + # Warnings + if self._max_table_size: + self.warning( + f"Merging got max_table_size of {self._max_table_size}." + " Will attempt to create databases with a maximum row count of" + " this size." + ) + + # Set variables + self._partition_count = 1 + + # Construct full database path + database_path = os.path.join(output_dir, self._database_name) + print(database_path) + # Start merging if files are given + if len(files) > 0: + os.makedirs(output_dir, exist_ok=True) + self.info(f"Merging {len(files)} database files") + self._merge_databases(files=files, database_path=database_path) + else: + self.warning("No database files given! Exiting.") + + def _merge_databases( + self, + files: List[str], + database_path: str, + ) -> None: + """Merge the temporary databases. + + Args: + files: List of files to be merged. + database_path: Path to a database, can be an empty path, where the + databases listed in `files` will be merged into. If no database + exists at the given path, one will be created. 
+ """ + if os.path.exists(database_path): + self.warning( + "Target path for merged database", + f"{database_path}, already exists.", + ) + + if self._max_table_size is not None: + database_path = self._adjust_output_path(database_path) + self._row_counts: Dict[str, int] = {} + self._largest_table = 0 + + # Merge temporary databases into newly created one + for file_count, input_file in tqdm(enumerate(files), colour="green"): + + # Extract table names and index column name in database + tables, primary_key = get_primary_keys(database=input_file) + + for table_name in tables.keys(): + # Extract all data in the table from the given database + df = query_database( + database=input_file, query=f"SELECT * FROM {table_name}" + ) + + # Infer whether the table was previously indexed with + # A primary key or not. len(tables[table]) = 0 if not. + integer_primary_key = ( + True if tables[table_name] is not None else False + ) + + # Submit to new database + create_table_and_save_to_sql( + df=df, + table_name=table_name, + database_path=database_path, + index_column=primary_key, + integer_primary_key=integer_primary_key, + ) + + # Update row counts if needed + if self._max_table_size is not None: + self._update_row_counts(df=df, table_name=table_name) + + if (self._max_table_size is not None) & (file_count < len(files)): + assert self._max_table_size is not None # mypy... + if self._largest_table >= self._max_table_size: + # Increment partition, reset counts, adjust output path + self._partition_count += 1 + self._row_counts = {} + self._largest_table = 0 + database_path = self._adjust_output_path(database_path) + self.info( + "Maximum row count reached." + f" Creating new partition at {database_path}" + ) + + # Internal methods + + def _adjust_output_path(self, output_file: str) -> str: + """Adjust the file path to reflect that it is a partition.""" + path_without_extension, extension = os.path.splitext(output_file) + if "_part_" in path_without_extension: + # if true, this is already a partition. 
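+            # Strip the existing "_part_<N>" suffix (e.g. "merged_part_1"
+            # becomes "merged") so the current partition count can be
+            # appended again below.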
+ database_name = path_without_extension.split("_part_")[:-1][0] + else: + database_name = path_without_extension + # split into multiple lines to avoid one long + database_name = database_name + f"_part_{self._partition_count}" + database_name = database_name + extension + return database_name + + def _update_row_counts(self, df: pd.DataFrame, table_name: str) -> None: + if table_name in self._row_counts.keys(): + self._row_counts[table_name] += len(df) + else: + self._row_counts[table_name] = len(df) + + self._largest_table = max(self._row_counts.values()) + return diff --git a/src/graphnet/deployment/i3modules/graphnet_module.py b/src/graphnet/deployment/i3modules/graphnet_module.py index d3aa878e0..a385413b3 100644 --- a/src/graphnet/deployment/i3modules/graphnet_module.py +++ b/src/graphnet/deployment/i3modules/graphnet_module.py @@ -7,7 +7,7 @@ import torch from torch_geometric.data import Data, Batch -from graphnet.data.extractors import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractor, I3FeatureExtractorIceCubeUpgrade, ) @@ -70,7 +70,7 @@ def __init__( self._i3_extractors = [pulsemap_extractor] for i3_extractor in self._i3_extractors: - i3_extractor.set_files(i3_file="", gcd_file=self._gcd_file) + i3_extractor.set_gcd(i3_file="", gcd_file=self._gcd_file) @abstractmethod def __call__(self, frame: I3Frame) -> bool: diff --git a/src/graphnet/models/graphs/edges/minkowski.py b/src/graphnet/models/graphs/edges/minkowski.py index 5d1134ec5..2526de1cb 100644 --- a/src/graphnet/models/graphs/edges/minkowski.py +++ b/src/graphnet/models/graphs/edges/minkowski.py @@ -69,12 +69,13 @@ def _construct_edges(self, graph: Data) -> Data: row = [] col = [] for batch in range(x.shape[0]): + x_masked = x[batch][mask[batch]] distance_mat = compute_minkowski_distance_mat( - x_masked := x[batch][mask[batch]], - x_masked, - self.c, - self.space_coords, - self.time_coord, + x=x_masked, + y=x_masked, + c=self.c, + space_coords=self.space_coords, + time_coord=self.time_coord, ) num_points = x_masked.shape[0] num_edges = min(self.nb_nearest_neighbours, num_points) diff --git a/src/graphnet/training/weight_fitting.py b/src/graphnet/training/weight_fitting.py index a52c91b29..97411bbe5 100644 --- a/src/graphnet/training/weight_fitting.py +++ b/src/graphnet/training/weight_fitting.py @@ -7,7 +7,9 @@ import pandas as pd import sqlite3 -from graphnet.data.sqlite.sqlite_utilities import create_table_and_save_to_sql +from graphnet.data.utilities.sqlite_utilities import ( + create_table_and_save_to_sql, +) from graphnet.utilities.logging import Logger diff --git a/tests/data/test_dataconverters_and_datasets.py b/tests/data/test_dataconverters_and_datasets.py index 480f11d4d..53a73a5f4 100644 --- a/tests/data/test_dataconverters_and_datasets.py +++ b/tests/data/test_dataconverters_and_datasets.py @@ -11,7 +11,7 @@ from graphnet.constants import TEST_OUTPUT_DIR from graphnet.data.constants import FEATURES, TRUTH from graphnet.data.dataconverter import DataConverter -from graphnet.data.extractors import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractorIceCube86, I3TruthExtractor, I3RetroExtractor, diff --git a/tests/data/test_i3extractor.py b/tests/data/test_i3extractor.py index 3fa19f078..f1c8c3ff7 100644 --- a/tests/data/test_i3extractor.py +++ b/tests/data/test_i3extractor.py @@ -1,6 +1,6 @@ """Unit tests for I3Extractor class.""" -from graphnet.data.extractors import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractorIceCube86, I3TruthExtractor, 
I3RetroExtractor, diff --git a/tests/data/test_i3genericextractor.py b/tests/data/test_i3genericextractor.py index 314fa5f44..e77727eaf 100644 --- a/tests/data/test_i3genericextractor.py +++ b/tests/data/test_i3genericextractor.py @@ -5,7 +5,7 @@ import numpy as np import graphnet.constants -from graphnet.data.extractors import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractorIceCube86, I3TruthExtractor, I3GenericExtractor, @@ -40,9 +40,9 @@ def test_i3genericextractor(test_data_dir: str = TEST_DATA_DIR) -> None: i3_file = os.path.join(test_data_dir, FILE_NAME) + ".i3.gz" gcd_file = os.path.join(test_data_dir, GCD_FILE) - generic_extractor.set_files(i3_file, gcd_file) - truth_extractor.set_files(i3_file, gcd_file) - feature_extractor.set_files(i3_file, gcd_file) + generic_extractor.set_gcd(i3_file, gcd_file) + truth_extractor.set_gcd(i3_file, gcd_file) + feature_extractor.set_gcd(i3_file, gcd_file) i3_file_io = dataio.I3File(i3_file, "r") ix_test = 5 From 31d040f0e32bb6a9580e3307e1bf30baf0ecce55 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Fri, 9 Feb 2024 14:55:43 +0100 Subject: [PATCH 04/33] add NotImplementedError in parquet_writer --- src/graphnet/data/dataconverter.py | 4 ++- src/graphnet/data/writers/parquet_writer.py | 27 ++++++++++++--------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/src/graphnet/data/dataconverter.py b/src/graphnet/data/dataconverter.py index efae14a2f..76796d844 100644 --- a/src/graphnet/data/dataconverter.py +++ b/src/graphnet/data/dataconverter.py @@ -311,7 +311,9 @@ def merge_files(self, files: Optional[List[str]] = None) -> None: assert files is not None # Merge files + merge_path = os.path.join(self._output_dir, "merged") + self.info(f"Merging files to {merge_path}") self._save_method.merge_files( # type:ignore files=files_to_merge, - output_dir=os.path.join(self._output_dir, "merged"), + output_dir=merge_path, ) diff --git a/src/graphnet/data/writers/parquet_writer.py b/src/graphnet/data/writers/parquet_writer.py index a8e74f11f..fa07d266d 100644 --- a/src/graphnet/data/writers/parquet_writer.py +++ b/src/graphnet/data/writers/parquet_writer.py @@ -13,7 +13,7 @@ class ParquetWriter(GraphNeTWriter): """Class for writing interim data format to Parquet.""" # Class variables - file_suffix: str = ".parquet" + _file_extension = ".parquet" # Abstract method implementation(s) def _save_file( @@ -22,13 +22,18 @@ def _save_file( output_file_path: str, n_events: int, ) -> None: - """Save data to parquet file.""" - # Check(s) - if os.path.exists(output_file_path): - self.warning( - f"Output file {output_file_path} already exists. Overwriting." - ) - - self.debug(f"Saving to {output_file_path}") - awkward.to_parquet(awkward.from_iter(data), output_file_path) - self.debug("- Done saving") + """Save data to parquet.""" + raise NotImplementedError + + def merge_files(self, files: List[str], output_dir: str) -> None: + """Merge parquet files. + + Args: + files: input files for merging. + output_dir: directory to store merged file(s) in. 
+ + Raises: + NotImplementedError: _description_ + """ + self.error(f"{self.__class__.__name__} does not have a merge method.") + raise NotImplementedError From b787a0783ca3ab2a14419318300a369d7a2f4684 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Fri, 9 Feb 2024 14:55:55 +0100 Subject: [PATCH 05/33] docstring --- src/graphnet/data/writers/parquet_writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphnet/data/writers/parquet_writer.py b/src/graphnet/data/writers/parquet_writer.py index fa07d266d..df64dc91a 100644 --- a/src/graphnet/data/writers/parquet_writer.py +++ b/src/graphnet/data/writers/parquet_writer.py @@ -33,7 +33,7 @@ def merge_files(self, files: List[str], output_dir: str) -> None: output_dir: directory to store merged file(s) in. Raises: - NotImplementedError: _description_ + NotImplementedError """ self.error(f"{self.__class__.__name__} does not have a merge method.") raise NotImplementedError From 56ccf038bf281d265203b5e4775b73de4dbe43ac Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Fri, 9 Feb 2024 23:24:45 +0100 Subject: [PATCH 06/33] deprecation warnings --- src/graphnet/data/__init__.py | 2 + src/graphnet/data/dataconverter.py | 11 +- src/graphnet/data/parquet/__init__.py | 2 + .../data/parquet/deprecated_methods.py | 63 +++++++++++ src/graphnet/data/pre_configured/__init__.py | 2 + .../data/pre_configured/dataconverters.py | 106 ++++++++++++++++++ src/graphnet/data/sqlite/__init__.py | 2 + .../data/sqlite/deprecated_methods.py | 64 +++++++++++ tests/deployment/queso_test.py | 2 +- 9 files changed, 246 insertions(+), 8 deletions(-) create mode 100644 src/graphnet/data/parquet/__init__.py create mode 100644 src/graphnet/data/parquet/deprecated_methods.py create mode 100644 src/graphnet/data/pre_configured/__init__.py create mode 100644 src/graphnet/data/pre_configured/dataconverters.py create mode 100644 src/graphnet/data/sqlite/__init__.py create mode 100644 src/graphnet/data/sqlite/deprecated_methods.py diff --git a/src/graphnet/data/__init__.py b/src/graphnet/data/__init__.py index e7eb84ca4..77cbc1af8 100644 --- a/src/graphnet/data/__init__.py +++ b/src/graphnet/data/__init__.py @@ -5,3 +5,5 @@ """ from .extractors.icecube.utilities.i3_filters import I3Filter, I3FilterMask from .dataconverter import DataConverter +from .pre_configured import I3ToParquetConverter +from .pre_configured import I3ToSQLiteConverter diff --git a/src/graphnet/data/dataconverter.py b/src/graphnet/data/dataconverter.py index 76796d844..bdba2a733 100644 --- a/src/graphnet/data/dataconverter.py +++ b/src/graphnet/data/dataconverter.py @@ -38,6 +38,7 @@ def __init__( self, file_reader: Type[GraphNeTFileReader], save_method: Type[GraphNeTWriter], + outdir: str, extractors: Union[Type[Extractor], List[Type[Extractor]]], index_column: str = "event_no", num_workers: int = 1, @@ -48,6 +49,7 @@ def __init__( file_reader: The method used for reading and applying `Extractors`. save_method: The method used to save the interim data format to a graphnet supported file format. + outdir: The directory to save the files in. extractors: The `Extractor`(s) that will be applied to the input files. index_column: Name of the event id column added to the events. @@ -61,6 +63,7 @@ def __init__( self._num_workers = num_workers self._index_column = index_column self._index = 0 + self._output_dir = outdir self._output_files: List[str] = [] # Set Extractors. 
Will throw error if extractors are incompatible @@ -71,20 +74,14 @@ def __init__( super().__init__(name=__name__, class_name=self.__class__.__name__) @final - def __call__( - self, input_dir: Union[str, List[str]], output_dir: str - ) -> None: + def __call__(self, input_dir: Union[str, List[str]]) -> None: """Extract data from files in `input_dir` and save to disk. Args: input_dir: A directory that contains the input files. The directory will be searched recursively for files matching the file extension. - output_dir: The directory to save the files to. Input folder - structure is not respected. """ - # Set outdir - self._output_dir = output_dir # Get the file reader to produce a list of input files # in the directory input_files = self._file_reader.find_files(path=input_dir) # type: ignore diff --git a/src/graphnet/data/parquet/__init__.py b/src/graphnet/data/parquet/__init__.py new file mode 100644 index 000000000..2c41ca75d --- /dev/null +++ b/src/graphnet/data/parquet/__init__.py @@ -0,0 +1,2 @@ +"""Module for deprecated parquet methods.""" +from .deprecated_methods import ParquetDataConverter diff --git a/src/graphnet/data/parquet/deprecated_methods.py b/src/graphnet/data/parquet/deprecated_methods.py new file mode 100644 index 000000000..299cdbae8 --- /dev/null +++ b/src/graphnet/data/parquet/deprecated_methods.py @@ -0,0 +1,63 @@ +"""Module containing deprecated data conversion code. + +This code will be removed in GraphNeT 2.0. +""" +from typing import List, Union, Type + +from graphnet.data.extractors.icecube import I3Extractor +from graphnet.data.extractors.icecube.utilities.i3_filters import ( + I3Filter, + NullSplitI3Filter, +) +from graphnet.data import I3ToParquetConverter + + +class ParquetDataConverter(I3ToParquetConverter): + """Method for converting i3 files to parquet files.""" + + def __init__( + self, + gcd_rescue: str, + extractors: Union[Type[I3Extractor], List[Type[I3Extractor]]], + outdir: str, + index_column: str = "event_no", + num_workers: int = 1, + i3_filters: Union[ + Type[I3Filter], List[Type[I3Filter]] + ] = NullSplitI3Filter(), # type: ignore + ): + """Convert I3 files to Parquet. + + Args: + gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is + found in subfolder. `I3Reader` will recursively search + the input directory for I3-GCD file pairs. By IceCube + convention, a folder containing i3 files will have an + accompanying GCD file. However, in some cases, this + convention is broken. In cases where a folder contains + i3 files but no GCD file, the `gcd_rescue` is used + instead. + extractors: The `Extractor`(s) that will be applied to the input + files. + outdir: The directory to save the files in. + icetray_verbose: Set the level of verbosity of icetray. + Defaults to 0. + index_column: Name of the event id column added to the events. + Defaults to "event_no". + num_workers: The number of CPUs used for parallel processing. + Defaults to 1 (no multiprocessing). + i3_filters: Instances of `I3Filter` to filter PFrames. Defaults to + `NullSplitI3Filter`. + """ + self.warning( + f"{self.__class__.__name__} will be deprecated in " + "GraphNeT 2.0. Please use I3ToParquetConverter instead." 
+ ) + super().__init__( + extractors=extractors, + num_workers=num_workers, + index_column=index_column, + i3_filters=i3_filters, + outdir=outdir, + gcd_rescue=gcd_rescue, + ) diff --git a/src/graphnet/data/pre_configured/__init__.py b/src/graphnet/data/pre_configured/__init__.py new file mode 100644 index 000000000..f56f0de18 --- /dev/null +++ b/src/graphnet/data/pre_configured/__init__.py @@ -0,0 +1,2 @@ +"""Module for pre-configured converter modules.""" +from .dataconverters import I3ToParquetConverter, I3ToSQLiteConverter diff --git a/src/graphnet/data/pre_configured/dataconverters.py b/src/graphnet/data/pre_configured/dataconverters.py new file mode 100644 index 000000000..fcd26fd49 --- /dev/null +++ b/src/graphnet/data/pre_configured/dataconverters.py @@ -0,0 +1,106 @@ +"""Pre-configured combinations of writers and readers.""" + +from typing import List, Union, Type + +from graphnet.data import DataConverter +from graphnet.data.readers import I3Reader +from graphnet.data.writers import ParquetWriter, SQLiteWriter +from graphnet.data.extractors.icecube import I3Extractor +from graphnet.data.extractors.icecube.utilities.i3_filters import ( + I3Filter, + NullSplitI3Filter, +) + + +class I3ToParquetConverter(DataConverter): + """Preconfigured DataConverter for converting i3 files to parquet files.""" + + def __init__( + self, + gcd_rescue: str, + extractors: Union[Type[I3Extractor], List[Type[I3Extractor]]], + outdir: str, + index_column: str = "event_no", + num_workers: int = 1, + i3_filters: Union[ + Type[I3Filter], List[Type[I3Filter]] + ] = NullSplitI3Filter(), # type: ignore + ): + """Convert I3 files to Parquet. + + Args: + gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is + found in subfolder. `I3Reader` will recursively search + the input directory for I3-GCD file pairs. By IceCube + convention, a folder containing i3 files will have an + accompanying GCD file. However, in some cases, this + convention is broken. In cases where a folder contains + i3 files but no GCD file, the `gcd_rescue` is used + instead. + extractors: The `Extractor`(s) that will be applied to the input + files. + outdir: The directory to save the files in. + icetray_verbose: Set the level of verbosity of icetray. + Defaults to 0. + index_column: Name of the event id column added to the events. + Defaults to "event_no". + num_workers: The number of CPUs used for parallel processing. + Defaults to 1 (no multiprocessing). + i3_filters: Instances of `I3Filter` to filter PFrames. Defaults to + `NullSplitI3Filter`. + """ + super().__init__( + file_reader=I3Reader(gcd_rescue=gcd_rescue, i3_filters=i3_filters), # type: ignore + save_method=ParquetWriter(), # type: ignore + extractors=extractors, # type: ignore + num_workers=num_workers, + index_column=index_column, + outdir=outdir, + ) + + +class I3ToSQLiteConverter(DataConverter): + """Preconfigured DataConverter for converting i3 files to SQLite files.""" + + def __init__( + self, + gcd_rescue: str, + extractors: Union[Type[I3Extractor], List[Type[I3Extractor]]], + outdir: str, + index_column: str = "event_no", + num_workers: int = 1, + i3_filters: Union[ + Type[I3Filter], List[Type[I3Filter]] + ] = NullSplitI3Filter(), # type: ignore + ): + """Convert I3 files to Parquet. + + Args: + gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is + found in subfolder. `I3Reader` will recursively search + the input directory for I3-GCD file pairs. 
By IceCube + convention, a folder containing i3 files will have an + accompanying GCD file. However, in some cases, this + convention is broken. In cases where a folder contains + i3 files but no GCD file, the `gcd_rescue` is used + instead. + extractors: The `Extractor`(s) that will be applied to the input + files. + outdir: The directory to save the files in. + icetray_verbose: Set the level of verbosity of icetray. + Defaults to 0. + index_column: Name of the event id column added to the events. + Defaults to "event_no". + num_workers: The number of CPUs used for parallel processing. + Defaults to 1 (no multiprocessing). + i3_filters: Instances of `I3Filter` to filter PFrames. Defaults to + `NullSplitI3Filter`. + """ + super().__init__( + file_reader=I3Reader(gcd_rescue=gcd_rescue, i3_filters=i3_filters), # type: ignore + save_method=SQLiteWriter(), # type: ignore + extractors=extractors, # type: ignore + num_workers=num_workers, + index_column=index_column, + outdir=outdir, + ) diff --git a/src/graphnet/data/sqlite/__init__.py b/src/graphnet/data/sqlite/__init__.py new file mode 100644 index 000000000..436a86f2d --- /dev/null +++ b/src/graphnet/data/sqlite/__init__.py @@ -0,0 +1,2 @@ +"""Module for deprecated methods using sqlite.""" +from .deprecated_methods import SQLiteDataConverter diff --git a/src/graphnet/data/sqlite/deprecated_methods.py b/src/graphnet/data/sqlite/deprecated_methods.py new file mode 100644 index 000000000..3dd1e04b5 --- /dev/null +++ b/src/graphnet/data/sqlite/deprecated_methods.py @@ -0,0 +1,64 @@ +"""Module containing deprecated data conversion code. + +This code will be removed in GraphNeT 2.0. +""" + +from typing import List, Union, Type + +from graphnet.data.extractors.icecube import I3Extractor +from graphnet.data.extractors.icecube.utilities.i3_filters import ( + I3Filter, + NullSplitI3Filter, +) +from graphnet.data import I3ToSQLiteConverter + + +class SQLiteDataConverter(I3ToSQLiteConverter): + """Method for converting i3 files to SQLite files.""" + + def __init__( + self, + gcd_rescue: str, + extractors: Union[Type[I3Extractor], List[Type[I3Extractor]]], + outdir: str, + index_column: str = "event_no", + num_workers: int = 1, + i3_filters: Union[ + Type[I3Filter], List[Type[I3Filter]] + ] = NullSplitI3Filter(), # type: ignore + ): + """Convert I3 files to Parquet. + + Args: + gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is + found in subfolder. `I3Reader` will recursively search + the input directory for I3-GCD file pairs. By IceCube + convention, a folder containing i3 files will have an + accompanying GCD file. However, in some cases, this + convention is broken. In cases where a folder contains + i3 files but no GCD file, the `gcd_rescue` is used + instead. + extractors: The `Extractor`(s) that will be applied to the input + files. + outdir: The directory to save the files in. + icetray_verbose: Set the level of verbosity of icetray. + Defaults to 0. + index_column: Name of the event id column added to the events. + Defaults to "event_no". + num_workers: The number of CPUs used for parallel processing. + Defaults to 1 (no multiprocessing). + i3_filters: Instances of `I3Filter` to filter PFrames. Defaults to + `NullSplitI3Filter`. + """ + self.warning( + f"{self.__class__.__name__} will be deprecated in " + "GraphNeT 2.0. Please use I3ToSQLiteConverter instead." 
+ ) + super().__init__( + extractors=extractors, + num_workers=num_workers, + index_column=index_column, + i3_filters=i3_filters, + outdir=outdir, + gcd_rescue=gcd_rescue, + ) diff --git a/tests/deployment/queso_test.py b/tests/deployment/queso_test.py index d1258ed89..5c0088f5d 100644 --- a/tests/deployment/queso_test.py +++ b/tests/deployment/queso_test.py @@ -8,7 +8,7 @@ import pytest from graphnet.data.constants import FEATURES -from graphnet.data.extractors.i3featureextractor import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractorIceCubeUpgrade, ) from graphnet.constants import ( From 92834439402a95b7b7c42c4792a21d10c8d316b3 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Sun, 11 Feb 2024 11:39:38 +0100 Subject: [PATCH 07/33] add legacy parquet writer --- src/graphnet/data/dataconverter.py | 12 +++++++----- src/graphnet/data/writers/graphnet_writer.py | 5 +++++ src/graphnet/data/writers/parquet_writer.py | 14 +++++++++++++- src/graphnet/data/writers/sqlite_writer.py | 1 + 4 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/graphnet/data/dataconverter.py b/src/graphnet/data/dataconverter.py index bdba2a733..7160a0c2e 100644 --- a/src/graphnet/data/dataconverter.py +++ b/src/graphnet/data/dataconverter.py @@ -182,11 +182,13 @@ def _assign_event_no( dataframe_dict[extractor_name].append(df) else: dataframe_dict[extractor_name] = [df] - # Merge each list of dataframes - for key in dataframe_dict.keys(): - dataframe_dict[key] = pd.concat( - dataframe_dict[key], axis=0 - ).reset_index(drop=True) + + # Merge each list of dataframes if wanted by writer + if self._save_method.expects_merged_dataframes: + for key in dataframe_dict.keys(): + dataframe_dict[key] = pd.concat( + dataframe_dict[key], axis=0 + ).reset_index(drop=True) return dataframe_dict @final diff --git a/src/graphnet/data/writers/graphnet_writer.py b/src/graphnet/data/writers/graphnet_writer.py index 04ee079f4..330a3d868 100644 --- a/src/graphnet/data/writers/graphnet_writer.py +++ b/src/graphnet/data/writers/graphnet_writer.py @@ -87,3 +87,8 @@ def __call__( def file_extension(self) -> str: """Return file extension used to store the data.""" return self._file_extension # type: ignore + + @property + def expects_merged_dataframes(self) -> bool: + """Return if writer expects input to be merged dataframes or not.""" + return self._merge_dataframes # type: ignore diff --git a/src/graphnet/data/writers/parquet_writer.py b/src/graphnet/data/writers/parquet_writer.py index df64dc91a..755a829c1 100644 --- a/src/graphnet/data/writers/parquet_writer.py +++ b/src/graphnet/data/writers/parquet_writer.py @@ -14,6 +14,7 @@ class ParquetWriter(GraphNeTWriter): # Class variables _file_extension = ".parquet" + _merge_dataframes = False # Abstract method implementation(s) def _save_file( @@ -23,7 +24,18 @@ def _save_file( n_events: int, ) -> None: """Save data to parquet.""" - raise NotImplementedError + # Check(s) + + if n_events > 0: + events = [] + for k in range(n_events): + event = {} + for table in data.keys(): + event[table] = data[table][k].to_dict(orient="list") + + events.append(event) + + awkward.to_parquet(awkward.from_iter(events), output_file_path) def merge_files(self, files: List[str], output_dir: str) -> None: """Merge parquet files. 
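Illustrative sketch (not part of the patch; table, column and file names are invented): the hunk above makes `ParquetWriter` serialize one record per event, where each record maps a table name to its columns, matching the unmerged per-event DataFrames the converter now hands over when `_merge_dataframes` is False. A self-contained example of that layout:

    import awkward as ak
    import pandas as pd

    # One DataFrame per event and per table, as produced upstream when the
    # writer does not expect merged dataframes.
    data = {
        "truth": [pd.DataFrame({"energy": [1.2]}), pd.DataFrame({"energy": [3.4]})],
        "pulses": [
            pd.DataFrame({"dom_x": [0.1, 0.2], "charge": [1.0, 0.5]}),
            pd.DataFrame({"dom_x": [0.3], "charge": [2.0]}),
        ],
    }

    # Re-shape to one dict per event: {table: {column: [values]}} ...
    events = [
        {table: frames[k].to_dict(orient="list") for table, frames in data.items()}
        for k in range(2)  # two events in this toy example
    ]

    # ... and write the whole list of events as a single Parquet file.
    ak.to_parquet(ak.from_iter(events), "toy_events.parquet")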
diff --git a/src/graphnet/data/writers/sqlite_writer.py b/src/graphnet/data/writers/sqlite_writer.py index e9f400c53..d7cc48297 100644 --- a/src/graphnet/data/writers/sqlite_writer.py +++ b/src/graphnet/data/writers/sqlite_writer.py @@ -40,6 +40,7 @@ def __init__( """ # Member Variables self._file_extension = ".db" + self._merge_dataframes = True self._max_table_size = max_table_size self._database_name = merged_database_name From 198dce4f1204a4ec1189684990f18253db9d151c Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Sun, 11 Feb 2024 11:42:09 +0100 Subject: [PATCH 08/33] remove is_pulse_map unit test --- tests/data/test_dataconverters_and_datasets.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/data/test_dataconverters_and_datasets.py b/tests/data/test_dataconverters_and_datasets.py index 53a73a5f4..e1d9e773b 100644 --- a/tests/data/test_dataconverters_and_datasets.py +++ b/tests/data/test_dataconverters_and_datasets.py @@ -19,7 +19,6 @@ from graphnet.data.parquet import ParquetDataConverter from graphnet.data.dataset import ParquetDataset, SQLiteDataset from graphnet.data.sqlite import SQLiteDataConverter -from graphnet.data.sqlite.sqlite_dataconverter import is_pulse_map from graphnet.data.utilities.parquet_to_sqlite import ParquetToSQLiteConverter from graphnet.utilities.imports import has_icecube_package from graphnet.models.graphs import KNNGraph @@ -52,17 +51,6 @@ def get_file_path(backend: str) -> str: return path -# Unit test(s) -def test_is_pulsemap_check() -> None: - """Test behaviour of `is_pulsemap_check`.""" - assert is_pulse_map("SplitInIcePulses") is True - assert is_pulse_map("SRTInIcePulses") is True - assert is_pulse_map("InIceDSTPulses") is True - assert is_pulse_map("RTTWOfflinePulses") is True - assert is_pulse_map("truth") is False - assert is_pulse_map("retro") is False - - @pytest.mark.order(1) @pytest.mark.parametrize("backend", ["sqlite", "parquet"]) def test_dataconverter( From 9cf58916d96b97a0353a2b12deef7f1710cd670c Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Sun, 11 Feb 2024 12:12:32 +0100 Subject: [PATCH 09/33] change num_workers -> workers in deprecated methods. 
Adjust output file name in dataconverter --- src/graphnet/data/dataconverter.py | 7 ++++--- src/graphnet/data/parquet/deprecated_methods.py | 14 +++++++------- src/graphnet/data/sqlite/deprecated_methods.py | 14 +++++++------- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/graphnet/data/dataconverter.py b/src/graphnet/data/dataconverter.py index 7160a0c2e..f50912e4e 100644 --- a/src/graphnet/data/dataconverter.py +++ b/src/graphnet/data/dataconverter.py @@ -150,9 +150,10 @@ def _create_file_name(self, input_file_path: Union[str, I3FileSet]) -> str: """Convert input file path to an output file name.""" if isinstance(input_file_path, I3FileSet): input_file_path = input_file_path.i3_file - path_without_extension = os.path.splitext(input_file_path)[0] - base_file_name = path_without_extension.split("/")[-1] - return base_file_name # type: ignore + file_name = os.path.basename(input_file_path) + index_of_dot = file_name.index(".") + file_name_without_extension = file_name[:index_of_dot] + return file_name_without_extension # type: ignore @final def _assign_event_no( diff --git a/src/graphnet/data/parquet/deprecated_methods.py b/src/graphnet/data/parquet/deprecated_methods.py index 299cdbae8..717e798bb 100644 --- a/src/graphnet/data/parquet/deprecated_methods.py +++ b/src/graphnet/data/parquet/deprecated_methods.py @@ -21,7 +21,7 @@ def __init__( extractors: Union[Type[I3Extractor], List[Type[I3Extractor]]], outdir: str, index_column: str = "event_no", - num_workers: int = 1, + workers: int = 1, i3_filters: Union[ Type[I3Filter], List[Type[I3Filter]] ] = NullSplitI3Filter(), # type: ignore @@ -44,20 +44,20 @@ def __init__( Defaults to 0. index_column: Name of the event id column added to the events. Defaults to "event_no". - num_workers: The number of CPUs used for parallel processing. + workers: The number of CPUs used for parallel processing. Defaults to 1 (no multiprocessing). i3_filters: Instances of `I3Filter` to filter PFrames. Defaults to `NullSplitI3Filter`. """ - self.warning( - f"{self.__class__.__name__} will be deprecated in " - "GraphNeT 2.0. Please use I3ToParquetConverter instead." - ) super().__init__( extractors=extractors, - num_workers=num_workers, + num_workers=workers, index_column=index_column, i3_filters=i3_filters, outdir=outdir, gcd_rescue=gcd_rescue, ) + self.warning( + f"{self.__class__.__name__} will be deprecated in " + "GraphNeT 2.0. Please use I3ToParquetConverter instead." + ) diff --git a/src/graphnet/data/sqlite/deprecated_methods.py b/src/graphnet/data/sqlite/deprecated_methods.py index 3dd1e04b5..f3da0d10f 100644 --- a/src/graphnet/data/sqlite/deprecated_methods.py +++ b/src/graphnet/data/sqlite/deprecated_methods.py @@ -22,7 +22,7 @@ def __init__( extractors: Union[Type[I3Extractor], List[Type[I3Extractor]]], outdir: str, index_column: str = "event_no", - num_workers: int = 1, + workers: int = 1, i3_filters: Union[ Type[I3Filter], List[Type[I3Filter]] ] = NullSplitI3Filter(), # type: ignore @@ -45,20 +45,20 @@ def __init__( Defaults to 0. index_column: Name of the event id column added to the events. Defaults to "event_no". - num_workers: The number of CPUs used for parallel processing. + workers: The number of CPUs used for parallel processing. Defaults to 1 (no multiprocessing). i3_filters: Instances of `I3Filter` to filter PFrames. Defaults to `NullSplitI3Filter`. """ - self.warning( - f"{self.__class__.__name__} will be deprecated in " - "GraphNeT 2.0. Please use I3ToSQLiteConverter instead." 
- ) super().__init__( extractors=extractors, - num_workers=num_workers, + num_workers=workers, index_column=index_column, i3_filters=i3_filters, outdir=outdir, gcd_rescue=gcd_rescue, ) + self.warning( + f"{self.__class__.__name__} will be deprecated in " + "GraphNeT 2.0. Please use I3ToSQLiteConverter instead." + ) From e38ded5cb763988d8b7e6fb9e09f87db3e96bbc5 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Sun, 11 Feb 2024 14:14:48 +0100 Subject: [PATCH 10/33] fix examples --- examples/01_icetray/01_convert_i3_files.py | 3 --- examples/01_icetray/03_i3_deployer_example.py | 2 +- examples/01_icetray/04_i3_module_in_native_icetray_example.py | 2 +- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/examples/01_icetray/01_convert_i3_files.py b/examples/01_icetray/01_convert_i3_files.py index 9a39f95e7..279adf6e0 100644 --- a/examples/01_icetray/01_convert_i3_files.py +++ b/examples/01_icetray/01_convert_i3_files.py @@ -72,9 +72,6 @@ def main_icecube_upgrade(backend: str) -> None: ], outdir, workers=workers, - # nb_files_to_batch=10, - # sequential_batch_pattern="temp_{:03d}", - # input_file_batch_pattern="[A-Z]{1}_[0-9]{5}*.i3.zst", icetray_verbose=1, ) converter(inputs) diff --git a/examples/01_icetray/03_i3_deployer_example.py b/examples/01_icetray/03_i3_deployer_example.py index f55aa769c..28d73c00d 100644 --- a/examples/01_icetray/03_i3_deployer_example.py +++ b/examples/01_icetray/03_i3_deployer_example.py @@ -10,7 +10,7 @@ PRETRAINED_MODEL_DIR, ) from graphnet.data.constants import FEATURES, TRUTH -from graphnet.data.extractors.i3featureextractor import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractorIceCubeUpgrade, ) from graphnet.utilities.argparse import ArgumentParser diff --git a/examples/01_icetray/04_i3_module_in_native_icetray_example.py b/examples/01_icetray/04_i3_module_in_native_icetray_example.py index 74da5e499..ab8c1b58c 100644 --- a/examples/01_icetray/04_i3_module_in_native_icetray_example.py +++ b/examples/01_icetray/04_i3_module_in_native_icetray_example.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, List, Sequence from graphnet.data.constants import FEATURES -from graphnet.data.extractors.i3featureextractor import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractorIceCubeUpgrade, ) from graphnet.constants import ( From 19ce2392f017b59c15ba2f052658f84cb1cd5d59 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Sun, 11 Feb 2024 14:17:28 +0100 Subject: [PATCH 11/33] remove unused imports in extractor.py --- src/graphnet/data/extractors/extractor.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/graphnet/data/extractors/extractor.py b/src/graphnet/data/extractors/extractor.py index b5e5ed37c..3e9a8f715 100644 --- a/src/graphnet/data/extractors/extractor.py +++ b/src/graphnet/data/extractors/extractor.py @@ -1,13 +1,12 @@ """Base I3Extractor class(es).""" - +from typing import TYPE_CHECKING from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, List, Optional from graphnet.utilities.imports import has_icecube_package from graphnet.utilities.logging import Logger if has_icecube_package() or TYPE_CHECKING: - from icecube import icetray, dataio # pyright: reportMissingImports=false + from icecube import icetray # pyright: reportMissingImports=false class Extractor(ABC, Logger): From b6bf74a62829b28db67bfc6ed3930fd090398405 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Sun, 11 Feb 2024 14:40:29 +0100 Subject: [PATCH 12/33] Type hint test --- 
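Sketch for context (not applied by the patch): the hunks below only correct the constructor annotations from `Type[...]` to instances, matching how `DataConverter` is already constructed by the pre-configured converters. A minimal instance-based usage, with hypothetical paths:

    from graphnet.data import DataConverter
    from graphnet.data.readers import I3Reader
    from graphnet.data.writers import SQLiteWriter
    from graphnet.data.extractors.icecube import I3TruthExtractor

    converter = DataConverter(
        file_reader=I3Reader(gcd_rescue="/path/to/rescue_gcd.i3.gz"),  # hypothetical path
        save_method=SQLiteWriter(),
        outdir="/path/to/output",  # hypothetical path
        extractors=[I3TruthExtractor()],
        num_workers=1,
    )
    converter("/path/to/i3_files")  # one intermediate .db file per input file
    converter.merge_files()         # merged database(s) under <outdir>/merged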
src/graphnet/data/dataconverter.py | 8 ++++---- src/graphnet/data/readers/graphnet_file_reader.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/graphnet/data/dataconverter.py b/src/graphnet/data/dataconverter.py index f50912e4e..12926e510 100644 --- a/src/graphnet/data/dataconverter.py +++ b/src/graphnet/data/dataconverter.py @@ -36,10 +36,10 @@ class DataConverter(ABC, Logger): def __init__( self, - file_reader: Type[GraphNeTFileReader], - save_method: Type[GraphNeTWriter], + file_reader: GraphNeTFileReader, + save_method: GraphNeTWriter, outdir: str, - extractors: Union[Type[Extractor], List[Type[Extractor]]], + extractors: Union[Extractor, List[Extractor]], index_column: str = "event_no", num_workers: int = 1, ) -> None: @@ -68,7 +68,7 @@ def __init__( # Set Extractors. Will throw error if extractors are incompatible # with reader. - self._file_reader.set_extractors(extractors) + self._file_reader.set_extractors(extractors=extractors) # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) diff --git a/src/graphnet/data/readers/graphnet_file_reader.py b/src/graphnet/data/readers/graphnet_file_reader.py index ab6464e13..c9d859335 100644 --- a/src/graphnet/data/readers/graphnet_file_reader.py +++ b/src/graphnet/data/readers/graphnet_file_reader.py @@ -27,7 +27,7 @@ class properties `accepted_file_extensions` and `accepted_extractors`. """ @abstractmethod - def __call__(self, file_path: str) -> List[OrderedDict]: + def __call__(self, file_path: Union[str, I3FileSet]) -> List[OrderedDict]: """Open and apply extractors to a single file. The `output` must be a list of dictionaries, where the number of events From a10b8334fc12f105d0cea20673e0a60b65245630 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Sun, 11 Feb 2024 14:41:31 +0100 Subject: [PATCH 13/33] type hint test --- src/graphnet/data/dataconverter.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/graphnet/data/dataconverter.py b/src/graphnet/data/dataconverter.py index 12926e510..77c4839c1 100644 --- a/src/graphnet/data/dataconverter.py +++ b/src/graphnet/data/dataconverter.py @@ -84,7 +84,7 @@ def __call__(self, input_dir: Union[str, List[str]]) -> None: """ # Get the file reader to produce a list of input files # in the directory - input_files = self._file_reader.find_files(path=input_dir) # type: ignore + input_files = self._file_reader.find_files(path=input_dir) self._launch_jobs(input_files=input_files) self._output_files = glob( os.path.join( @@ -129,7 +129,7 @@ def _process_file(self, file_path: Union[str, I3FileSet]) -> None: """ # Read and apply extractors data = self._file_reader(file_path=file_path) - n_events = len(data) # type: ignore + n_events = len(data) # Assign event_no's to each event in data and transform to pd.DataFrame data = self._assign_event_no(data=data) @@ -153,7 +153,7 @@ def _create_file_name(self, input_file_path: Union[str, I3FileSet]) -> str: file_name = os.path.basename(input_file_path) index_of_dot = file_name.index(".") file_name_without_extension = file_name[:index_of_dot] - return file_name_without_extension # type: ignore + return file_name_without_extension @final def _assign_event_no( @@ -313,7 +313,7 @@ def merge_files(self, files: Optional[List[str]] = None) -> None: # Merge files merge_path = os.path.join(self._output_dir, "merged") self.info(f"Merging files to {merge_path}") - self._save_method.merge_files( # type:ignore + self._save_method.merge_files( files=files_to_merge, 
output_dir=merge_path, ) From 9a3288c989c4d016218b91ae27031b013cd6c402 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Sun, 11 Feb 2024 14:54:37 +0100 Subject: [PATCH 14/33] type hints --- src/graphnet/data/readers/graphnet_file_reader.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/graphnet/data/readers/graphnet_file_reader.py b/src/graphnet/data/readers/graphnet_file_reader.py index c9d859335..c24ab12a4 100644 --- a/src/graphnet/data/readers/graphnet_file_reader.py +++ b/src/graphnet/data/readers/graphnet_file_reader.py @@ -26,6 +26,9 @@ class GraphNeTFileReader(Logger, ABC): class properties `accepted_file_extensions` and `accepted_extractors`. """ + _accepted_file_extensions: List[str] = [] + _accepted_extractors: List[Extractor] = [] + @abstractmethod def __call__(self, file_path: Union[str, I3FileSet]) -> List[OrderedDict]: """Open and apply extractors to a single file. @@ -39,17 +42,17 @@ def __call__(self, file_path: Union[str, I3FileSet]) -> List[OrderedDict]: @property def accepted_file_extensions(self) -> List[str]: """Return list of accepted file extensions.""" - return self._accepted_file_extensions # type: ignore + return self._accepted_file_extensions @property def accepted_extractors(self) -> List[Extractor]: """Return list of compatible `Extractor`(s).""" - return self._accepted_extractors # type: ignore + return self._accepted_extractors @property def extracor_names(self) -> List[str]: """Return list of table names produced by extractors.""" - return [extractor.name for extractor in self._extractors] # type: ignore + return [extractor.name for extractor in self._extractors] def find_files( self, path: Union[str, List[str]] From b619017fa151a4b7e408224bbf953ac7b300e2dc Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Sun, 11 Feb 2024 14:56:05 +0100 Subject: [PATCH 15/33] delete I3GenericExtractor unit test --- tests/data/test_i3genericextractor.py | 97 --------------------------- 1 file changed, 97 deletions(-) delete mode 100644 tests/data/test_i3genericextractor.py diff --git a/tests/data/test_i3genericextractor.py b/tests/data/test_i3genericextractor.py deleted file mode 100644 index e77727eaf..000000000 --- a/tests/data/test_i3genericextractor.py +++ /dev/null @@ -1,97 +0,0 @@ -"""Unit tests for I3GenericExtractor class.""" - -import os - -import numpy as np - -import graphnet.constants -from graphnet.data.extractors.icecube import ( - I3FeatureExtractorIceCube86, - I3TruthExtractor, - I3GenericExtractor, -) -from graphnet.utilities.imports import has_icecube_package - -if has_icecube_package(): - from icecube import dataio # pyright: reportMissingImports=false - -# Global variable(s) -TEST_DATA_DIR = os.path.join( - graphnet.constants.TEST_DATA_DIR, "i3", "oscNext_genie_level7_v02" -) -FILE_NAME = "oscNext_genie_level7_v02_first_5_frames" -GCD_FILE = ( - "GeoCalibDetectorStatus_AVG_55697-57531_PASS2_SPE_withScaledNoise.i3.gz" -) - - -# Unit test(s) -def test_i3genericextractor(test_data_dir: str = TEST_DATA_DIR) -> None: - """Test the implementation of `I3GenericExtractor`.""" - # Constants(s) - mc_tree = "I3MCTree" - pulse_series = "SRTInIcePulses" - - # Constructor I3Extractor instance(s) - generic_extractor = I3GenericExtractor(keys=[mc_tree, pulse_series]) - truth_extractor = I3TruthExtractor() - feature_extractor = I3FeatureExtractorIceCube86(pulse_series) - - i3_file = os.path.join(test_data_dir, FILE_NAME) + ".i3.gz" - gcd_file = os.path.join(test_data_dir, GCD_FILE) - - generic_extractor.set_gcd(i3_file, gcd_file) - 
truth_extractor.set_gcd(i3_file, gcd_file) - feature_extractor.set_gcd(i3_file, gcd_file) - - i3_file_io = dataio.I3File(i3_file, "r") - ix_test = 5 - while i3_file_io.more(): - try: - frame = i3_file_io.pop_physics() - except: # noqa: E722 - continue - - generic_data = generic_extractor(frame) - truth_data = truth_extractor(frame) - feature_data = feature_extractor(frame) - - if ix_test == 5: - print(list(generic_data[pulse_series].keys())) - print(list(truth_data.keys())) - print(list(feature_data.keys())) - - # Truth vs. generic - key_pairs = [ - ("energy", "energy"), - ("zenith", "dir__zenith"), - ("azimuth", "dir__azimuth"), - ("pid", "pdg_encoding"), - ] - - for truth_key, generic_key in key_pairs: - assert ( - truth_data[truth_key] - == generic_data[f"{mc_tree}__primaries"][generic_key][0] - ) - - # Reco vs. generic - key_pairs = [ - ("charge", "charge"), - ("dom_time", "time"), - ("dom_x", "position__x"), - ("dom_y", "position__y"), - ("dom_z", "position__z"), - ("width", "width"), - ("pmt_area", "area"), - ("rde", "relative_dom_eff"), - ] - - for reco_key, generic_key in key_pairs: - assert np.allclose( - feature_data[reco_key], generic_data[pulse_series][generic_key] - ) - - ix_test -= 1 - if ix_test == 0: - break From 636e116db6ea2a5bf1481e19536684dcdd4df5aa Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Sun, 11 Feb 2024 20:40:38 +0100 Subject: [PATCH 16/33] test --- tests/data/test_i3extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/data/test_i3extractor.py b/tests/data/test_i3extractor.py index f1c8c3ff7..ce40626c0 100644 --- a/tests/data/test_i3extractor.py +++ b/tests/data/test_i3extractor.py @@ -1,4 +1,4 @@ -"""Unit tests for I3Extractor class.""" +"""Unit tests for I3Extractor.""" from graphnet.data.extractors.icecube import ( I3FeatureExtractorIceCube86, From 31a448659611db1e2b0d7c4704783fc35b6f7e29 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Mon, 12 Feb 2024 08:36:49 +0100 Subject: [PATCH 17/33] test --- src/graphnet/data/dataconverter.py | 4 ++-- src/graphnet/data/readers/graphnet_file_reader.py | 10 +++++++--- src/graphnet/data/readers/i3reader.py | 12 +++--------- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/graphnet/data/dataconverter.py b/src/graphnet/data/dataconverter.py index 77c4839c1..aabf465ae 100644 --- a/src/graphnet/data/dataconverter.py +++ b/src/graphnet/data/dataconverter.py @@ -157,8 +157,8 @@ def _create_file_name(self, input_file_path: Union[str, I3FileSet]) -> str: @final def _assign_event_no( - self, data: List[OrderedDict[str, Any]] - ) -> Dict[str, pd.DataFrame]: + self, data: List[OrderedDict] + ) -> Union[Dict[str, pd.DataFrame], Dict[str, List[pd.DataFrame]]]: # Request event_no's for the entire file event_nos = self._request_event_nos(n_ids=len(data)) diff --git a/src/graphnet/data/readers/graphnet_file_reader.py b/src/graphnet/data/readers/graphnet_file_reader.py index c24ab12a4..87f829fac 100644 --- a/src/graphnet/data/readers/graphnet_file_reader.py +++ b/src/graphnet/data/readers/graphnet_file_reader.py @@ -4,7 +4,7 @@ file formats. """ -from typing import List, Union, OrderedDict +from typing import List, Union, OrderedDict, Any from abc import abstractmethod, ABC import glob import os @@ -27,7 +27,7 @@ class properties `accepted_file_extensions` and `accepted_extractors`. 
""" _accepted_file_extensions: List[str] = [] - _accepted_extractors: List[Extractor] = [] + _accepted_extractors: List[Any] = [] @abstractmethod def __call__(self, file_path: Union[str, I3FileSet]) -> List[OrderedDict]: @@ -79,13 +79,17 @@ def find_files( return files @final - def set_extractors(self, extractors: List[Extractor]) -> None: + def set_extractors( + self, extractors: Union[Extractor, List[Extractor]] + ) -> None: """Set `Extractor`(s) as member variable. Args: extractors: A list of `Extractor`(s) to set as member variable. """ self._validate_extractors(extractors) + if not isinstance(extractors, list): + extractors = [extractors] self._extractors = extractors @final diff --git a/src/graphnet/data/readers/i3reader.py b/src/graphnet/data/readers/i3reader.py index 926c2395a..523367943 100644 --- a/src/graphnet/data/readers/i3reader.py +++ b/src/graphnet/data/readers/i3reader.py @@ -27,9 +27,7 @@ class I3Reader(GraphNeTFileReader): def __init__( self, gcd_rescue: str, - i3_filters: Union[ - Type[I3Filter], List[Type[I3Filter]] - ] = NullSplitI3Filter(), # type: ignore + i3_filters: Union[I3Filter, List[I3Filter]] = None, icetray_verbose: int = 0, ): """Initialize `I3Reader`. @@ -52,6 +50,8 @@ def __init__( if icetray_verbose == 0: icetray.I3Logger.global_logger = icetray.I3NullLogger() + if i3_filters is None: + i3_filters = [NullSplitI3Filter()] # Set Member Variables self._accepted_file_extensions = [".bz2", ".zst", ".gz"] self._accepted_extractors = [I3Extractor] @@ -97,12 +97,6 @@ def __call__(self, file_path: I3FileSet) -> List[OrderedDict]: # type: ignore data_dict = OrderedDict(zip(self.extracor_names, results)) - # If an I3GenericExtractor is used, we want each automatically - # parsed key to be stored as a separate table. - # for extractor in self._extractors: - # if isinstance(extractor, I3GenericExtractor): - # data_dict.update(data_dict.pop(extractor._name)) - data.append(data_dict) return data From dec554cc7535e949ad7be1201ea7e892e78695af Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Mon, 12 Feb 2024 08:37:47 +0100 Subject: [PATCH 18/33] update import in graphnet.pisa --- src/graphnet/pisa/fitting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphnet/pisa/fitting.py b/src/graphnet/pisa/fitting.py index dfcc20a37..5408f9bfc 100644 --- a/src/graphnet/pisa/fitting.py +++ b/src/graphnet/pisa/fitting.py @@ -23,7 +23,7 @@ from pisa.analysis.analysis import Analysis from pisa import ureg -from graphnet.data.sqlite import create_table_and_save_to_sql +from graphnet.data.utilities import create_table_and_save_to_sql mpl.use("pdf") plt.rc("font", family="serif") From 4c428140d36cdcc64b7103bbd7806487bc8dac8f Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Mon, 12 Feb 2024 08:40:35 +0100 Subject: [PATCH 19/33] polish 01-01 --- examples/01_icetray/01_convert_i3_files.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/01_icetray/01_convert_i3_files.py b/examples/01_icetray/01_convert_i3_files.py index 279adf6e0..6a9d010ec 100644 --- a/examples/01_icetray/01_convert_i3_files.py +++ b/examples/01_icetray/01_convert_i3_files.py @@ -42,12 +42,12 @@ def main_icecube86(backend: str) -> None: inputs = [f"{TEST_DATA_DIR}/i3/oscNext_genie_level7_v02"] outdir = f"{EXAMPLE_OUTPUT_DIR}/convert_i3_files/ic86" - converter: DataConverter = CONVERTER_CLASS[backend]( - [ + converter = CONVERTER_CLASS[backend]( + extractors=[ I3FeatureExtractorIceCube86("SRTInIcePulses"), I3TruthExtractor(), ], - outdir, + outdir=outdir, ) 
     converter(inputs)
     if backend == "sqlite":

From 10f4b36c6976078193ff667bdb2461b643b3fc2d Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe
Date: Mon, 12 Feb 2024 08:47:31 +0100
Subject: [PATCH 20/33] mypy

---
 src/graphnet/data/readers/graphnet_file_reader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/graphnet/data/readers/graphnet_file_reader.py b/src/graphnet/data/readers/graphnet_file_reader.py
index 87f829fac..d168ba09d 100644
--- a/src/graphnet/data/readers/graphnet_file_reader.py
+++ b/src/graphnet/data/readers/graphnet_file_reader.py
@@ -87,9 +87,9 @@ def set_extractors(
         Args:
             extractors: A list of `Extractor`(s) to set as member variable.
         """
-        self._validate_extractors(extractors)
         if not isinstance(extractors, list):
             extractors = [extractors]
+        self._validate_extractors(extractors)
         self._extractors = extractors
 
     @final

From 4c22db923be333f78504324031ebc428555ad61d Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe
Date: Mon, 12 Feb 2024 09:18:37 +0100
Subject: [PATCH 21/33] polish 01-01

---
 examples/01_icetray/01_convert_i3_files.py | 15 ++++++++--
 src/graphnet/data/dataconverter.py         |  7 +++--
 .../data/parquet/deprecated_methods.py     |  7 ++---
 .../data/pre_configured/dataconverters.py  | 29 +++++++------------
 .../data/sqlite/deprecated_methods.py      |  6 ++--
 5 files changed, 32 insertions(+), 32 deletions(-)

diff --git a/examples/01_icetray/01_convert_i3_files.py b/examples/01_icetray/01_convert_i3_files.py
index 6a9d010ec..870fd09f4 100644
--- a/examples/01_icetray/01_convert_i3_files.py
+++ b/examples/01_icetray/01_convert_i3_files.py
@@ -1,6 +1,7 @@
 """Example of converting I3-files to SQLite and Parquet."""
 
 import os
+from glob import glob
 
 from graphnet.constants import EXAMPLE_OUTPUT_DIR, TEST_DATA_DIR
 from graphnet.data.extractors.icecube import (
@@ -41,6 +42,9 @@ def main_icecube86(backend: str) -> None:
 
     inputs = [f"{TEST_DATA_DIR}/i3/oscNext_genie_level7_v02"]
     outdir = f"{EXAMPLE_OUTPUT_DIR}/convert_i3_files/ic86"
+    gcd_rescue = glob(
+        f"{TEST_DATA_DIR}/i3/oscNext_genie_level7_v02/*GeoCalib*"
+    )[0]
 
     converter = CONVERTER_CLASS[backend](
         extractors=[
@@ -48,6 +52,8 @@ def main_icecube86(backend: str) -> None:
             I3TruthExtractor(),
         ],
         outdir=outdir,
+        gcd_rescue=gcd_rescue,
+        workers=1,
     )
     converter(inputs)
     if backend == "sqlite":
@@ -61,18 +67,21 @@ def main_icecube_upgrade(backend: str) -> None:
 
     inputs = [f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998"]
     outdir = f"{EXAMPLE_OUTPUT_DIR}/convert_i3_files/upgrade"
+    gcd_rescue = glob(
+        f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/*GeoCalib*"
+    )[0]
     workers = 1
 
     converter: DataConverter = CONVERTER_CLASS[backend](
-        [
+        extractors=[
             I3TruthExtractor(),
             I3RetroExtractor(),
             I3FeatureExtractorIceCubeUpgrade("I3RecoPulseSeriesMap_mDOM"),
             I3FeatureExtractorIceCubeUpgrade("I3RecoPulseSeriesMap_DEgg"),
         ],
-        outdir,
+        outdir=outdir,
         workers=workers,
-        icetray_verbose=1,
+        gcd_rescue=gcd_rescue,
     )
     converter(inputs)
     if backend == "sqlite":
diff --git a/src/graphnet/data/dataconverter.py b/src/graphnet/data/dataconverter.py
index aabf465ae..828efa199 100644
--- a/src/graphnet/data/dataconverter.py
+++ b/src/graphnet/data/dataconverter.py
@@ -17,6 +17,7 @@
 from .readers.graphnet_file_reader import GraphNeTFileReader
 from .writers.graphnet_writer import GraphNeTWriter
 from .extractors import Extractor
+from .extractors.icecube import I3Extractor
 from .dataclasses import I3FileSet
 
 
@@ -39,7 +40,7 @@ def __init__(
         file_reader: GraphNeTFileReader,
         save_method: GraphNeTWriter,
         outdir: str,
-        extractors: Union[Extractor,
List[Extractor]], + extractors: Union[List[Extractor], List[I3Extractor]], index_column: str = "event_no", num_workers: int = 1, ) -> None: @@ -68,6 +69,8 @@ def __init__( # Set Extractors. Will throw error if extractors are incompatible # with reader. + if not isinstance(extractors, list): + extractors = [extractors] self._file_reader.set_extractors(extractors=extractors) # Base class constructor @@ -132,7 +135,7 @@ def _process_file(self, file_path: Union[str, I3FileSet]) -> None: n_events = len(data) # Assign event_no's to each event in data and transform to pd.DataFrame - data = self._assign_event_no(data=data) + data = self._assign_event_no(data=data) # type: ignore # Create output file name output_file_name = self._create_file_name(input_file_path=file_path) diff --git a/src/graphnet/data/parquet/deprecated_methods.py b/src/graphnet/data/parquet/deprecated_methods.py index 717e798bb..423e1aa00 100644 --- a/src/graphnet/data/parquet/deprecated_methods.py +++ b/src/graphnet/data/parquet/deprecated_methods.py @@ -7,7 +7,6 @@ from graphnet.data.extractors.icecube import I3Extractor from graphnet.data.extractors.icecube.utilities.i3_filters import ( I3Filter, - NullSplitI3Filter, ) from graphnet.data import I3ToParquetConverter @@ -18,13 +17,11 @@ class ParquetDataConverter(I3ToParquetConverter): def __init__( self, gcd_rescue: str, - extractors: Union[Type[I3Extractor], List[Type[I3Extractor]]], + extractors: List[I3Extractor], outdir: str, index_column: str = "event_no", workers: int = 1, - i3_filters: Union[ - Type[I3Filter], List[Type[I3Filter]] - ] = NullSplitI3Filter(), # type: ignore + i3_filters: Union[I3Filter, List[I3Filter]] = None, # type: ignore ): """Convert I3 files to Parquet. diff --git a/src/graphnet/data/pre_configured/dataconverters.py b/src/graphnet/data/pre_configured/dataconverters.py index fcd26fd49..63d8e61ab 100644 --- a/src/graphnet/data/pre_configured/dataconverters.py +++ b/src/graphnet/data/pre_configured/dataconverters.py @@ -6,10 +6,7 @@ from graphnet.data.readers import I3Reader from graphnet.data.writers import ParquetWriter, SQLiteWriter from graphnet.data.extractors.icecube import I3Extractor -from graphnet.data.extractors.icecube.utilities.i3_filters import ( - I3Filter, - NullSplitI3Filter, -) +from graphnet.data.extractors.icecube.utilities.i3_filters import I3Filter class I3ToParquetConverter(DataConverter): @@ -18,13 +15,11 @@ class I3ToParquetConverter(DataConverter): def __init__( self, gcd_rescue: str, - extractors: Union[Type[I3Extractor], List[Type[I3Extractor]]], + extractors: List[I3Extractor], outdir: str, index_column: str = "event_no", num_workers: int = 1, - i3_filters: Union[ - Type[I3Filter], List[Type[I3Filter]] - ] = NullSplitI3Filter(), # type: ignore + i3_filters: Union[I3Filter, List[I3Filter]] = None, # type: ignore ): """Convert I3 files to Parquet. @@ -50,9 +45,9 @@ def __init__( `NullSplitI3Filter`. 
""" super().__init__( - file_reader=I3Reader(gcd_rescue=gcd_rescue, i3_filters=i3_filters), # type: ignore - save_method=ParquetWriter(), # type: ignore - extractors=extractors, # type: ignore + file_reader=I3Reader(gcd_rescue=gcd_rescue, i3_filters=i3_filters), + save_method=ParquetWriter(), + extractors=extractors, num_workers=num_workers, index_column=index_column, outdir=outdir, @@ -65,13 +60,11 @@ class I3ToSQLiteConverter(DataConverter): def __init__( self, gcd_rescue: str, - extractors: Union[Type[I3Extractor], List[Type[I3Extractor]]], + extractors: List[I3Extractor], outdir: str, index_column: str = "event_no", num_workers: int = 1, - i3_filters: Union[ - Type[I3Filter], List[Type[I3Filter]] - ] = NullSplitI3Filter(), # type: ignore + i3_filters: Union[I3Filter, List[I3Filter]] = None, # type: ignore ): """Convert I3 files to Parquet. @@ -97,9 +90,9 @@ def __init__( `NullSplitI3Filter`. """ super().__init__( - file_reader=I3Reader(gcd_rescue=gcd_rescue, i3_filters=i3_filters), # type: ignore - save_method=SQLiteWriter(), # type: ignore - extractors=extractors, # type: ignore + file_reader=I3Reader(gcd_rescue=gcd_rescue, i3_filters=i3_filters), + save_method=SQLiteWriter(), + extractors=extractors, num_workers=num_workers, index_column=index_column, outdir=outdir, diff --git a/src/graphnet/data/sqlite/deprecated_methods.py b/src/graphnet/data/sqlite/deprecated_methods.py index f3da0d10f..30b563c59 100644 --- a/src/graphnet/data/sqlite/deprecated_methods.py +++ b/src/graphnet/data/sqlite/deprecated_methods.py @@ -19,13 +19,11 @@ class SQLiteDataConverter(I3ToSQLiteConverter): def __init__( self, gcd_rescue: str, - extractors: Union[Type[I3Extractor], List[Type[I3Extractor]]], + extractors: List[I3Extractor], outdir: str, index_column: str = "event_no", workers: int = 1, - i3_filters: Union[ - Type[I3Filter], List[Type[I3Filter]] - ] = NullSplitI3Filter(), # type: ignore + i3_filters: Union[I3Filter, List[I3Filter]] = None, # type: ignore ): """Convert I3 files to Parquet. From 664189bf97bda05d359658388d86b83fc32384be Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Mon, 12 Feb 2024 10:22:15 +0100 Subject: [PATCH 22/33] mypy.. --- src/graphnet/data/readers/graphnet_file_reader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/graphnet/data/readers/graphnet_file_reader.py b/src/graphnet/data/readers/graphnet_file_reader.py index d168ba09d..13a01faf9 100644 --- a/src/graphnet/data/readers/graphnet_file_reader.py +++ b/src/graphnet/data/readers/graphnet_file_reader.py @@ -13,6 +13,7 @@ from graphnet.utilities.logging import Logger from graphnet.data.dataclasses import I3FileSet from graphnet.data.extractors.extractor import Extractor +from graphnet.data.extractors.icecube import I3Extractor class GraphNeTFileReader(Logger, ABC): @@ -80,7 +81,7 @@ def find_files( @final def set_extractors( - self, extractors: Union[Extractor, List[Extractor]] + self, extractors: Union[List[Extractor], List[I3Extractor]] ) -> None: """Set `Extractor`(s) as member variable. 
From 83a011c91b9cf89cbca5e20a4baa4c825b33503b Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Mon, 12 Feb 2024 10:45:29 +0100 Subject: [PATCH 23/33] mypy --- src/graphnet/data/writers/graphnet_writer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/graphnet/data/writers/graphnet_writer.py b/src/graphnet/data/writers/graphnet_writer.py index 330a3d868..518823b5a 100644 --- a/src/graphnet/data/writers/graphnet_writer.py +++ b/src/graphnet/data/writers/graphnet_writer.py @@ -5,7 +5,7 @@ """ import os -from typing import Dict, List +from typing import Dict, List, Union from abc import abstractmethod, ABC from graphnet.utilities.decorators import final @@ -58,7 +58,7 @@ def merge_files( @final def __call__( self, - data: Dict[str, pd.DataFrame], + data: Union[Dict[str, pd.DataFrame], Dict[str, List[pd.DataFrame]]], file_name: str, output_dir: str, n_events: int, From 6aca5d8013e6c0128eb69e8138fc3e82f6631a14 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Mon, 12 Feb 2024 10:50:42 +0100 Subject: [PATCH 24/33] polish 01-01 --- examples/01_icetray/01_convert_i3_files.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/01_icetray/01_convert_i3_files.py b/examples/01_icetray/01_convert_i3_files.py index 870fd09f4..9f0795cb1 100644 --- a/examples/01_icetray/01_convert_i3_files.py +++ b/examples/01_icetray/01_convert_i3_files.py @@ -57,7 +57,7 @@ def main_icecube86(backend: str) -> None: ) converter(inputs) if backend == "sqlite": - converter.merge_files(os.path.join(outdir, "merged")) + converter.merge_files() def main_icecube_upgrade(backend: str) -> None: @@ -85,7 +85,7 @@ def main_icecube_upgrade(backend: str) -> None: ) converter(inputs) if backend == "sqlite": - converter.merge_files(os.path.join(outdir, "merged")) + converter.merge_files() if __name__ == "__main__": From 398c27ac9c1f3a85c4e309274d3781efa648d604 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Mon, 12 Feb 2024 11:03:35 +0100 Subject: [PATCH 25/33] mypy --- src/graphnet/data/readers/graphnet_file_reader.py | 4 +++- src/graphnet/data/writers/graphnet_writer.py | 2 +- src/graphnet/data/writers/parquet_writer.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/graphnet/data/readers/graphnet_file_reader.py b/src/graphnet/data/readers/graphnet_file_reader.py index 13a01faf9..c590c6424 100644 --- a/src/graphnet/data/readers/graphnet_file_reader.py +++ b/src/graphnet/data/readers/graphnet_file_reader.py @@ -94,7 +94,9 @@ def set_extractors( self._extractors = extractors @final - def _validate_extractors(self, extractors: List[Extractor]) -> None: + def _validate_extractors( + self, extractors: Union[List[Extractor], List[I3Extractor]] + ) -> None: for extractor in extractors: try: assert isinstance(extractor, tuple(self.accepted_extractors)) # type: ignore diff --git a/src/graphnet/data/writers/graphnet_writer.py b/src/graphnet/data/writers/graphnet_writer.py index 518823b5a..f6ec03029 100644 --- a/src/graphnet/data/writers/graphnet_writer.py +++ b/src/graphnet/data/writers/graphnet_writer.py @@ -28,7 +28,7 @@ class GraphNeTWriter(Logger, ABC): @abstractmethod def _save_file( self, - data: Dict[str, pd.DataFrame], + data: Union[Dict[str, pd.DataFrame], Dict[str, List[pd.DataFrame]]], output_file_path: str, n_events: int, ) -> None: diff --git a/src/graphnet/data/writers/parquet_writer.py b/src/graphnet/data/writers/parquet_writer.py index 755a829c1..18e524ca9 100644 --- a/src/graphnet/data/writers/parquet_writer.py +++ 
b/src/graphnet/data/writers/parquet_writer.py @@ -19,7 +19,7 @@ class ParquetWriter(GraphNeTWriter): # Abstract method implementation(s) def _save_file( self, - data: Dict[str, pd.DataFrame], + data: Dict[str, List[pd.DataFrame]], output_file_path: str, n_events: int, ) -> None: From bd627a3e794dd7d3e5e55208f570c79650107918 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Mon, 12 Feb 2024 11:15:47 +0100 Subject: [PATCH 26/33] mypy please --- src/graphnet/data/dataconverter.py | 4 ++-- src/graphnet/data/readers/i3reader.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/graphnet/data/dataconverter.py b/src/graphnet/data/dataconverter.py index 828efa199..70d5ae89e 100644 --- a/src/graphnet/data/dataconverter.py +++ b/src/graphnet/data/dataconverter.py @@ -131,11 +131,11 @@ def _process_file(self, file_path: Union[str, I3FileSet]) -> None: This function is called in parallel. """ # Read and apply extractors - data = self._file_reader(file_path=file_path) + data: List[OrderedDict] = self._file_reader(file_path=file_path) n_events = len(data) # Assign event_no's to each event in data and transform to pd.DataFrame - data = self._assign_event_no(data=data) # type: ignore + data: Union[Dict[str, pd.DataFrame], Dict[str, List[pd.DataFrame]]] = self._assign_event_no(data=data) # type: ignore # Create output file name output_file_name = self._create_file_name(input_file_path=file_path) diff --git a/src/graphnet/data/readers/i3reader.py b/src/graphnet/data/readers/i3reader.py index 523367943..ed5fd7c1f 100644 --- a/src/graphnet/data/readers/i3reader.py +++ b/src/graphnet/data/readers/i3reader.py @@ -77,7 +77,7 @@ def __call__(self, file_path: I3FileSet) -> List[OrderedDict]: # type: ignore assert isinstance(extractor, I3Extractor) extractor.set_gcd( i3_file=file_path.i3_file, gcd_file=file_path.gcd_file - ) # type: ignore + ) # Open I3 file i3_file_io = dataio.I3File(file_path.i3_file, "r") From e0b4ba4e3bf719bb599fe6545b30022ce84ff0c5 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Mon, 12 Feb 2024 11:21:02 +0100 Subject: [PATCH 27/33] mypy... 
--- src/graphnet/data/dataconverter.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/graphnet/data/dataconverter.py b/src/graphnet/data/dataconverter.py index 70d5ae89e..43929cedd 100644 --- a/src/graphnet/data/dataconverter.py +++ b/src/graphnet/data/dataconverter.py @@ -135,14 +135,17 @@ def _process_file(self, file_path: Union[str, I3FileSet]) -> None: n_events = len(data) # Assign event_no's to each event in data and transform to pd.DataFrame - data: Union[Dict[str, pd.DataFrame], Dict[str, List[pd.DataFrame]]] = self._assign_event_no(data=data) # type: ignore + dataframes = self._assign_event_no(data=data) + + # Delete `data` to save memory + del data # Create output file name output_file_name = self._create_file_name(input_file_path=file_path) # Apply save method self._save_method( - data=data, + data=dataframes, file_name=output_file_name, n_events=n_events, output_dir=self._output_dir, From 23f0a9c945b2b8b24b5e9ea7ccfe3cf0fc8eea41 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Mon, 12 Feb 2024 11:28:40 +0100 Subject: [PATCH 28/33] add comment in dataconverter --- src/graphnet/data/dataconverter.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/graphnet/data/dataconverter.py b/src/graphnet/data/dataconverter.py index 43929cedd..69d13be50 100644 --- a/src/graphnet/data/dataconverter.py +++ b/src/graphnet/data/dataconverter.py @@ -132,6 +132,8 @@ def _process_file(self, file_path: Union[str, I3FileSet]) -> None: """ # Read and apply extractors data: List[OrderedDict] = self._file_reader(file_path=file_path) + + # Count number of events n_events = len(data) # Assign event_no's to each event in data and transform to pd.DataFrame From f3720caaafb8c0f05794f1b6b401910285d43f14 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Mon, 12 Feb 2024 12:13:22 +0100 Subject: [PATCH 29/33] polish extractor.py --- src/graphnet/data/extractors/extractor.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/graphnet/data/extractors/extractor.py b/src/graphnet/data/extractors/extractor.py index 3e9a8f715..d03419870 100644 --- a/src/graphnet/data/extractors/extractor.py +++ b/src/graphnet/data/extractors/extractor.py @@ -1,13 +1,9 @@ """Base I3Extractor class(es).""" -from typing import TYPE_CHECKING +from typing import Any from abc import ABC, abstractmethod -from graphnet.utilities.imports import has_icecube_package from graphnet.utilities.logging import Logger -if has_icecube_package() or TYPE_CHECKING: - from icecube import icetray # pyright: reportMissingImports=false - class Extractor(ABC, Logger): """Base class for extracting information from data files. 
@@ -40,11 +36,11 @@ def __init__(self, extractor_name: str): super().__init__(name=__name__, class_name=self.__class__.__name__) @abstractmethod - def __call__(self, frame: "icetray.I3Frame") -> dict: - """Extract information from frame.""" + def __call__(self, data: Any) -> dict: + """Extract information from data.""" pass @property def name(self) -> str: - """Get the name of the `I3Extractor` instance.""" + """Get the name of the `Extractor` instance.""" return self._extractor_name From c1a8ae2fafef732c2ca0fe6289e1ee16a942b478 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Wed, 14 Feb 2024 10:28:02 +0100 Subject: [PATCH 30/33] update `extractor.py` docstring --- src/graphnet/data/extractors/extractor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/graphnet/data/extractors/extractor.py b/src/graphnet/data/extractors/extractor.py index d03419870..ce743f63d 100644 --- a/src/graphnet/data/extractors/extractor.py +++ b/src/graphnet/data/extractors/extractor.py @@ -11,9 +11,9 @@ class Extractor(ABC, Logger): All classes inheriting from `Extractor` should implement the `__call__` method, and should return a pure python dictionary on the form - output = [{'var1: .., + output = {'var1: .., ... , - 'var_n': ..}] + 'var_n': ..} Variables can be scalar or array-like of shape [n, 1], where n denotes the number of elements in the array, and 1 the number of columns. From 4abd266a4c3ab905242e967488a1338dd309ce3a Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Wed, 14 Feb 2024 10:54:17 +0100 Subject: [PATCH 31/33] refactor `set_gcd` to make it more readable --- .../data/extractors/icecube/i3extractor.py | 52 ++++++++++++++----- 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/src/graphnet/data/extractors/icecube/i3extractor.py b/src/graphnet/data/extractors/icecube/i3extractor.py index d997efcc4..e03f3e9f1 100644 --- a/src/graphnet/data/extractors/icecube/i3extractor.py +++ b/src/graphnet/data/extractors/icecube/i3extractor.py @@ -37,28 +37,52 @@ def __init__(self, extractor_name: str): # Base class constructor super().__init__(extractor_name=extractor_name) - def set_gcd(self, gcd_file: str, i3_file: str) -> None: - """Load the geospatial information contained in the GCD-file.""" - # If no GCD file is provided, search the I3 file for frames containing - # geometry (G) and calibration (C) information. - gcd = dataio.I3File(gcd_file or i3_file) + def set_gcd(self, i3_file: str, gcd_file: Optional[str] = None) -> None: + """Extract GFrame and CFrame from i3/gcd-file pair. + Information from these frames will be set as member variables of + `I3Extractor.` + + Args: + i3_file: Path to i3 file that is being converted. + gcd_file: Path to GCD file. Defaults to None. If no GCD file is + given, the method will attempt to find C and G frames in + the i3 file instead. If either one of those are not + present, `RuntimeErrors` will be raised. + """ + if gcd_file is None: + # If no GCD file is provided, search the I3 file for frames + # containing geometry (GFrame) and calibration (CFrame) + gcd = dataio.I3File(i3_file) + else: + # Ideally ends here + gcd = dataio.I3File(gcd_file) + + # Get GFrame try: g_frame = gcd.pop_frame(icetray.I3Frame.Geometry) - except RuntimeError: + # If the line above fails, it means that no gcd file was given + # and that the i3 file does not have a G-Frame in it. + except RuntimeError as e: self.error( - "No GCD file was provided and no G-frame was found. Exiting." 
+ "No GCD file was provided " + f"and no G-frame was found in {i3_file.split('/')[-1]}." ) - raise - else: - self._gcd_dict = g_frame["I3Geometry"].omgeo + raise e + # Get CFrame try: c_frame = gcd.pop_frame(icetray.I3Frame.Calibration) - except RuntimeError: - self.warning("No GCD file was provided and no C-frame was found.") - else: - self._calibration = c_frame["I3Calibration"] + except RuntimeError as e: + self.warning( + "No GCD file was provided and no C-frame " + f"was found in {i3_file.split('/')[-1]}." + ) + raise e + + # Save information as member variables of I3Extractor + self._gcd_dict = g_frame["I3Geometry"].omgeo + self._calibration = c_frame["I3Calibration"] @abstractmethod def __call__(self, frame: "icetray.I3Frame") -> dict: From 3f407b07ee7f595c04b1ab91c1ce59314010e5e2 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Wed, 14 Feb 2024 10:55:23 +0100 Subject: [PATCH 32/33] add another comment in `set_gcd` --- src/graphnet/data/extractors/icecube/i3extractor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/graphnet/data/extractors/icecube/i3extractor.py b/src/graphnet/data/extractors/icecube/i3extractor.py index e03f3e9f1..3f2fc92d2 100644 --- a/src/graphnet/data/extractors/icecube/i3extractor.py +++ b/src/graphnet/data/extractors/icecube/i3extractor.py @@ -73,6 +73,8 @@ def set_gcd(self, i3_file: str, gcd_file: Optional[str] = None) -> None: # Get CFrame try: c_frame = gcd.pop_frame(icetray.I3Frame.Calibration) + # If the line above fails, it means that no gcd file was given + # and that the i3 file does not have a C-Frame in it. except RuntimeError as e: self.warning( "No GCD file was provided and no C-frame " From 852efcb4789c1e00f14fcc1456735aa78b960838 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Thu, 15 Feb 2024 09:54:09 +0100 Subject: [PATCH 33/33] Update docstring in `I3ToSQLiteConverter` --- src/graphnet/data/pre_configured/dataconverters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphnet/data/pre_configured/dataconverters.py b/src/graphnet/data/pre_configured/dataconverters.py index 63d8e61ab..6db89c46e 100644 --- a/src/graphnet/data/pre_configured/dataconverters.py +++ b/src/graphnet/data/pre_configured/dataconverters.py @@ -66,7 +66,7 @@ def __init__( num_workers: int = 1, i3_filters: Union[I3Filter, List[I3Filter]] = None, # type: ignore ): - """Convert I3 files to Parquet. + """Convert I3 files to SQLite. Args: gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is