diff --git a/examples/01_icetray/01_convert_i3_files.py b/examples/01_icetray/01_convert_i3_files.py index 88dcf714a..9f0795cb1 100644 --- a/examples/01_icetray/01_convert_i3_files.py +++ b/examples/01_icetray/01_convert_i3_files.py @@ -1,9 +1,10 @@ """Example of converting I3-files to SQLite and Parquet.""" import os +from glob import glob from graphnet.constants import EXAMPLE_OUTPUT_DIR, TEST_DATA_DIR -from graphnet.data.extractors import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractorIceCubeUpgrade, I3FeatureExtractorIceCube86, I3RetroExtractor, @@ -41,17 +42,22 @@ def main_icecube86(backend: str) -> None: inputs = [f"{TEST_DATA_DIR}/i3/oscNext_genie_level7_v02"] outdir = f"{EXAMPLE_OUTPUT_DIR}/convert_i3_files/ic86" + gcd_rescue = glob( + "{TEST_DATA_DIR}/i3/oscNext_genie_level7_v02/*GeoCalib*" + )[0] - converter: DataConverter = CONVERTER_CLASS[backend]( - [ + converter = CONVERTER_CLASS[backend]( + extractors=[ I3FeatureExtractorIceCube86("SRTInIcePulses"), I3TruthExtractor(), ], - outdir, + outdir=outdir, + gcd_rescue=gcd_rescue, + workers=1, ) converter(inputs) if backend == "sqlite": - converter.merge_files(os.path.join(outdir, "merged")) + converter.merge_files() def main_icecube_upgrade(backend: str) -> None: @@ -61,25 +67,25 @@ def main_icecube_upgrade(backend: str) -> None: inputs = [f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998"] outdir = f"{EXAMPLE_OUTPUT_DIR}/convert_i3_files/upgrade" + gcd_rescue = glob( + "{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/*GeoCalib*" + )[0] workers = 1 converter: DataConverter = CONVERTER_CLASS[backend]( - [ + extractors=[ I3TruthExtractor(), I3RetroExtractor(), I3FeatureExtractorIceCubeUpgrade("I3RecoPulseSeriesMap_mDOM"), I3FeatureExtractorIceCubeUpgrade("I3RecoPulseSeriesMap_DEgg"), ], - outdir, + outdir=outdir, workers=workers, - # nb_files_to_batch=10, - # sequential_batch_pattern="temp_{:03d}", - # input_file_batch_pattern="[A-Z]{1}_[0-9]{5}*.i3.zst", - icetray_verbose=1, + gcd_rescue=gcd_rescue, ) converter(inputs) if backend == "sqlite": - converter.merge_files(os.path.join(outdir, "merged")) + converter.merge_files() if __name__ == "__main__": diff --git a/examples/01_icetray/02_compare_sqlite_and_parquet.py b/examples/01_icetray/02_compare_sqlite_and_parquet.py index 99250d4b0..d3874c5f2 100644 --- a/examples/01_icetray/02_compare_sqlite_and_parquet.py +++ b/examples/01_icetray/02_compare_sqlite_and_parquet.py @@ -7,7 +7,7 @@ from graphnet.data.sqlite import SQLiteDataConverter from graphnet.data.parquet import ParquetDataConverter from graphnet.data.dataset import SQLiteDataset, ParquetDataset -from graphnet.data.extractors import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractorIceCube86, I3TruthExtractor, I3RetroExtractor, diff --git a/examples/01_icetray/03_i3_deployer_example.py b/examples/01_icetray/03_i3_deployer_example.py index f55aa769c..28d73c00d 100644 --- a/examples/01_icetray/03_i3_deployer_example.py +++ b/examples/01_icetray/03_i3_deployer_example.py @@ -10,7 +10,7 @@ PRETRAINED_MODEL_DIR, ) from graphnet.data.constants import FEATURES, TRUTH -from graphnet.data.extractors.i3featureextractor import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractorIceCubeUpgrade, ) from graphnet.utilities.argparse import ArgumentParser diff --git a/examples/01_icetray/04_i3_module_in_native_icetray_example.py b/examples/01_icetray/04_i3_module_in_native_icetray_example.py index 74da5e499..ab8c1b58c 100644 --- a/examples/01_icetray/04_i3_module_in_native_icetray_example.py +++ b/examples/01_icetray/04_i3_module_in_native_icetray_example.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, List, Sequence from graphnet.data.constants import FEATURES -from graphnet.data.extractors.i3featureextractor import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractorIceCubeUpgrade, ) from graphnet.constants import ( diff --git a/src/graphnet/data/__init__.py b/src/graphnet/data/__init__.py index fbb1ee095..77cbc1af8 100644 --- a/src/graphnet/data/__init__.py +++ b/src/graphnet/data/__init__.py @@ -1,6 +1,9 @@ """Modules for converting and ingesting data. `graphnet.data` enables converting domain-specific data to industry-standard, -intermediate file formats and reading this data. +intermediate file formats and reading this data. """ -from .filters import I3Filter, I3FilterMask +from .extractors.icecube.utilities.i3_filters import I3Filter, I3FilterMask +from .dataconverter import DataConverter +from .pre_configured import I3ToParquetConverter +from .pre_configured import I3ToSQLiteConverter diff --git a/src/graphnet/data/dataclasses.py b/src/graphnet/data/dataclasses.py new file mode 100644 index 000000000..98b837693 --- /dev/null +++ b/src/graphnet/data/dataclasses.py @@ -0,0 +1,10 @@ +"""Module containing experiment-specific dataclasses.""" + + +from dataclasses import dataclass + + +@dataclass +class I3FileSet: # noqa: D101 + i3_file: str + gcd_file: str diff --git a/src/graphnet/data/dataconverter.py b/src/graphnet/data/dataconverter.py index 2a67ddce9..69d13be50 100644 --- a/src/graphnet/data/dataconverter.py +++ b/src/graphnet/data/dataconverter.py @@ -1,496 +1,259 @@ -"""Base `DataConverter` class(es) used in GraphNeT.""" -# type: ignore[name-defined] # Due to use of `init_global_index`. - -from abc import ABC, abstractmethod -from collections import OrderedDict -from dataclasses import dataclass -from functools import wraps -import itertools +"""Contains `DataConverter`.""" +from typing import List, Union, OrderedDict, Dict, Tuple, Any, Optional, Type +from abc import abstractmethod, ABC + +from tqdm import tqdm +import numpy as np from multiprocessing import Manager, Pool, Value import multiprocessing.pool from multiprocessing.sharedctypes import Synchronized +import pandas as pd import os -import re -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Tuple, - TypeVar, - Union, - cast, -) +from glob import glob -import numpy as np -import pandas as pd -from tqdm import tqdm -from graphnet.data.utilities.random import pairwise_shuffle -from graphnet.data.extractors import ( - I3Extractor, - I3ExtractorCollection, - I3FeatureExtractor, - I3TruthExtractor, - I3GenericExtractor, -) from graphnet.utilities.decorators import final -from graphnet.utilities.filesys import find_i3_files -from graphnet.utilities.imports import has_icecube_package from graphnet.utilities.logging import Logger -from graphnet.data.filters import I3Filter, NullSplitI3Filter - -if has_icecube_package(): - from icecube import icetray, dataio # pyright: reportMissingImports=false +from .readers.graphnet_file_reader import GraphNeTFileReader +from .writers.graphnet_writer import GraphNeTWriter +from .extractors import Extractor +from .extractors.icecube import I3Extractor +from .dataclasses import I3FileSet -SAVE_STRATEGIES = [ - "1:1", - "sequential_batched", - "pattern_batched", -] - - -# Utility classes -@dataclass -class FileSet: # noqa: D101 - i3_file: str - gcd_file: str - - -# Utility method(s) def init_global_index(index: Synchronized, output_files: List[str]) -> None: """Make `global_index` available to pool workers.""" global global_index, global_output_files # type: ignore[name-defined] global_index, global_output_files = (index, output_files) # type: ignore[name-defined] -F = TypeVar("F", bound=Callable[..., Any]) - - -def cache_output_files(process_method: F) -> F: - """Decorate `process_method` to cache output file names.""" - - @wraps(process_method) - def wrapper(self: Any, *args: Any) -> Any: - try: - # Using multiprocessing - output_files = global_output_files # type: ignore[name-defined] - except NameError: # `global_output_files` not set - # Running on main process - output_files = self._output_files - - output_file = process_method(self, *args) - output_files.append(output_file) - return output_file - - return cast(F, wrapper) - - class DataConverter(ABC, Logger): - """Base class for converting I3-files to intermediate file format.""" + """A finalized data conversion class in GraphNeT. - @property - @abstractmethod - def file_suffix(self) -> str: - """Suffix to use on output files.""" + `DataConverter` provides parallel processing of file conversion and + extraction from experiment-specific file formats to graphnet-supported data + formats. This class also assigns event id's to training examples. + """ def __init__( self, - extractors: List[I3Extractor], + file_reader: GraphNeTFileReader, + save_method: GraphNeTWriter, outdir: str, - gcd_rescue: Optional[str] = None, - *, - nb_files_to_batch: Optional[int] = None, - sequential_batch_pattern: Optional[str] = None, - input_file_batch_pattern: Optional[str] = None, - workers: int = 1, + extractors: Union[List[Extractor], List[I3Extractor]], index_column: str = "event_no", - icetray_verbose: int = 0, - i3_filters: List[I3Filter] = [], - ): - """Construct DataConverter. - - When using `input_file_batch_pattern`, regular expressions are used to - group files according to their names. All files that match a certain - pattern up to wildcards are grouped into the same output file. This - output file has the same name as the input files that are group into it, - with wildcards replaced with "x". Periods (.) and wildcards (*) have a - special meaning: Periods are interpreted as literal periods, and not as - matching any character (as in standard regex); and wildcards are - interpreted as ".*" in standard regex. - - For instance, the pattern "[A-Z]{1}_[0-9]{5}*.i3.zst" will find all I3 - files whose names contain: - - one capital letter, followed by - - an underscore, followed by - - five numbers, followed by - - any string of characters ending in ".i3.zst" - - This means that, e.g., the files: - - upgrade_genie_step4_141020_A_000000.i3.zst - - upgrade_genie_step4_141020_A_000001.i3.zst - - ... - - upgrade_genie_step4_141020_A_000008.i3.zst - - upgrade_genie_step4_141020_A_000009.i3.zst - would be grouped into the output file named - "upgrade_genie_step4_141020_A_00000x." but the file - - upgrade_genie_step4_141020_A_000010.i3.zst - would end up in a separate group, named - "upgrade_genie_step4_141020_A_00001x.". - """ - # Check(s) - if not isinstance(extractors, (list, tuple)): - extractors = [extractors] - - assert ( - len(extractors) > 0 - ), "Please specify at least one argument of type I3Extractor" - - for extractor in extractors: - assert isinstance( - extractor, I3Extractor - ), f"{type(extractor)} is not a subclass of I3Extractor" - - # Infer saving strategy - save_strategy = self._infer_save_strategy( - nb_files_to_batch, - sequential_batch_pattern, - input_file_batch_pattern, - ) + num_workers: int = 1, + ) -> None: + """Initialize `DataConverter`. - # Member variables - self._outdir = outdir - self._gcd_rescue = gcd_rescue - self._save_strategy = save_strategy - self._nb_files_to_batch = nb_files_to_batch - self._sequential_batch_pattern = sequential_batch_pattern - self._input_file_batch_pattern = input_file_batch_pattern - self._workers = workers - - # I3Filters (NullSplitI3Filter is always included) - self._i3filters = [NullSplitI3Filter()] + i3_filters - - for filter in self._i3filters: - assert isinstance( - filter, I3Filter - ), f"{type(filter)} is not a subclass of I3Filter" - - # Create I3Extractors - self._extractors = I3ExtractorCollection(*extractors) - - # Create shorthand of names of all pulsemaps queried - self._table_names = [extractor.name for extractor in self._extractors] - self._pulsemaps = [ - extractor.name - for extractor in self._extractors - if isinstance(extractor, I3FeatureExtractor) - ] - - # Placeholders for keeping track of sequential event indices and output files + Args: + file_reader: The method used for reading and applying `Extractors`. + save_method: The method used to save the interim data format to + a graphnet supported file format. + outdir: The directory to save the files in. + extractors: The `Extractor`(s) that will be applied to the input + files. + index_column: Name of the event id column added to the events. + Defaults to "event_no". + num_workers: The number of CPUs used for parallel processing. + Defaults to 1 (no multiprocessing). + """ + # Member Variable Assignment + self._file_reader = file_reader + self._save_method = save_method + self._num_workers = num_workers self._index_column = index_column self._index = 0 + self._output_dir = outdir self._output_files: List[str] = [] - # Set verbosity - if icetray_verbose == 0: - icetray.I3Logger.global_logger = icetray.I3NullLogger() + # Set Extractors. Will throw error if extractors are incompatible + # with reader. + if not isinstance(extractors, list): + extractors = [extractors] + self._file_reader.set_extractors(extractors=extractors) # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) @final - def __call__( - self, - directories: Union[str, List[str]], - recursive: Optional[bool] = True, - ) -> None: - """Convert I3-files in `directories. + def __call__(self, input_dir: Union[str, List[str]]) -> None: + """Extract data from files in `input_dir` and save to disk. Args: - directories: One or more directories, the I3 files within which - should be converted to an intermediate file format. - recursive: Whether or not to search the directories recursively. + input_dir: A directory that contains the input files. + The directory will be searched recursively for files + matching the file extension. """ - # Find all I3 and GCD files in the specified directories. - i3_files, gcd_files = find_i3_files( - directories, self._gcd_rescue, recursive + # Get the file reader to produce a list of input files + # in the directory + input_files = self._file_reader.find_files(path=input_dir) + self._launch_jobs(input_files=input_files) + self._output_files = glob( + os.path.join( + self._output_dir, f"*{self._save_method.file_extension}" + ) ) - if len(i3_files) == 0: - self.error(f"No files found in {directories}.") - return - - # Save a record of the found I3 files in the output directory. - self._save_filenames(i3_files) - - # Shuffle I3 files to get a more uniform load on worker nodes. - i3_files, gcd_files = pairwise_shuffle(i3_files, gcd_files) - - # Process the files - filesets = [ - FileSet(i3_file, gcd_file) - for i3_file, gcd_file in zip(i3_files, gcd_files) - ] - self.execute(filesets) @final - def execute(self, filesets: List[FileSet]) -> None: - """General method for processing a set of I3 files. - - The files are converted individually according to the inheriting class/ - intermediate file format. - - Args: - filesets: List of paths to I3 and corresponding GCD files. - """ - # Make sure output directory exists. - self.info(f"Saving results to {self._outdir}") - os.makedirs(self._outdir, exist_ok=True) - - # Iterate over batches of files. - try: - if self._save_strategy == "sequential_batched": - # Define batches - assert self._nb_files_to_batch is not None - assert self._sequential_batch_pattern is not None - batches = np.array_split( - np.asarray(filesets), - int(np.ceil(len(filesets) / self._nb_files_to_batch)), - ) - batches = [ - ( - group.tolist(), - self._sequential_batch_pattern.format(ix_batch), - ) - for ix_batch, group in enumerate(batches) - ] - self.info( - f"Will batch {len(filesets)} input files into {len(batches)} groups." - ) - - # Iterate over batches - pool = self._iterate_over_batches_of_files(batches) - - elif self._save_strategy == "pattern_batched": - # Define batches - groups: Dict[str, List[FileSet]] = OrderedDict() - for fileset in sorted(filesets, key=lambda f: f.i3_file): - group = re.sub( - self._sub_from, - self._sub_to, - os.path.basename(fileset.i3_file), - ) - if group not in groups: - groups[group] = list() - groups[group].append(fileset) - - self.info( - f"Will batch {len(filesets)} input files into {len(groups)} groups" - ) - if len(groups) <= 20: - for group, group_filesets in groups.items(): - self.info( - f"> {group}: {len(group_filesets):3d} file(s)" - ) - - batches = [ - (list(group_filesets), group) - for group, group_filesets in groups.items() - ] - - # Iterate over batches - pool = self._iterate_over_batches_of_files(batches) - - elif self._save_strategy == "1:1": - pool = self._iterate_over_individual_files(filesets) - - else: - assert False, "Shouldn't reach here." - - self._update_shared_variables(pool) - - except KeyboardInterrupt: - self.warning("[ctrl+c] Exciting gracefully.") - - @abstractmethod - def save_data(self, data: List[OrderedDict], output_file: str) -> None: - """Implementation-specific method for saving data to file. - - Args: - data: List of extracted features. - output_file: Name of output file. - """ - - @abstractmethod - def merge_files( - self, output_file: str, input_files: Optional[List[str]] = None + def _launch_jobs( + self, + input_files: Union[List[str], List[I3FileSet]], ) -> None: - """Implementation-specific method for merging output files. + """Multi Processing Logic. - Args: - output_file: Name of the output file containing the merged results. - input_files: Intermediate files to be merged, according to the - specific implementation. Default to None, meaning that all - files output by the current instance are merged. - - Raises: - NotImplementedError: If the method has not been implemented for the - backend in question. - """ + Spawns worker pool, + distributes the input files evenly across workers. + declare event_no as globally accessible variable across workers. + starts jobs. - # Internal methods - def _iterate_over_individual_files( - self, args: List[FileSet] - ) -> Optional[multiprocessing.pool.Pool]: + Will call process_file in parallel. + """ # Get appropriate mapping function - map_fn, pool = self.get_map_function(len(args)) + map_fn, pool = self.get_map_function(nb_files=len(input_files)) # Iterate over files for _ in map_fn( - self._process_file, tqdm(args, unit="file(s)", colour="green") + self._process_file, + tqdm(input_files, unit="file(s)", colour="green"), ): - self.debug( - "Saving with 1:1 strategy on the individual worker processes" - ) - - return pool - - def _iterate_over_batches_of_files( - self, args: List[Tuple[List[FileSet], str]] - ) -> Optional[multiprocessing.pool.Pool]: - """Iterate over a batch of files and save results on worker process.""" - # Get appropriate mapping function - map_fn, pool = self.get_map_function(len(args), unit="batch(es)") + self.debug("processing file.") - # Iterate over batches of files - for _ in map_fn( - self._process_batch, tqdm(args, unit="batch(es)", colour="green") - ): - self.debug("Saving with batched strategy") + self._update_shared_variables(pool) - return pool + @final + def _process_file(self, file_path: Union[str, I3FileSet]) -> None: + """Process a single file. - def _update_shared_variables( - self, pool: Optional[multiprocessing.pool.Pool] - ) -> None: - """Update `self._index` and `self._output_files`. + Calls file reader to recieve extracted output, event ids + is assigned to the extracted data and is handed to save method. - If `pool` is set, it means that multiprocessing was used. In this case, - the worker processes will not have been able to write directly to - `self._index` and `self._output_files`, and we need to get them synced - up. + This function is called in parallel. """ - if pool: - # Extract information from shared variables to member variables. - index, output_files = pool._initargs # type: ignore - self._index += index.value - self._output_files.extend(list(sorted(output_files[:]))) - - @cache_output_files - def _process_file( - self, - fileset: FileSet, - ) -> str: + # Read and apply extractors + data: List[OrderedDict] = self._file_reader(file_path=file_path) - # Process individual files - data = self._extract_data(fileset) + # Count number of events + n_events = len(data) - # Save data - output_file = self._get_output_file(fileset.i3_file) - self.save_data(data, output_file) + # Assign event_no's to each event in data and transform to pd.DataFrame + dataframes = self._assign_event_no(data=data) - return output_file + # Delete `data` to save memory + del data - @cache_output_files - def _process_batch(self, args: Tuple[List[FileSet], str]) -> str: - # Unpack arguments - filesets, output_file_name = args + # Create output file name + output_file_name = self._create_file_name(input_file_path=file_path) - # Process individual files - data = list( - itertools.chain.from_iterable(map(self._extract_data, filesets)) + # Apply save method + self._save_method( + data=dataframes, + file_name=output_file_name, + n_events=n_events, + output_dir=self._output_dir, ) - # Save batched data - output_file = self._get_output_file(output_file_name) - self.save_data(data, output_file) - - return output_file - - def _extract_data(self, fileset: FileSet) -> List[OrderedDict]: - """Extract data from single I3 file. - - If the saving strategy is 1:1 (i.e., each I3 file is converted to a - corresponding intermediate file) the data is saved to such a file, and - no data is return from the method. + @final + def _create_file_name(self, input_file_path: Union[str, I3FileSet]) -> str: + """Convert input file path to an output file name.""" + if isinstance(input_file_path, I3FileSet): + input_file_path = input_file_path.i3_file + file_name = os.path.basename(input_file_path) + index_of_dot = file_name.index(".") + file_name_without_extension = file_name[:index_of_dot] + return file_name_without_extension - The above distincting is to allow worker processes to save files rather - than sending it back to the main process. + @final + def _assign_event_no( + self, data: List[OrderedDict] + ) -> Union[Dict[str, pd.DataFrame], Dict[str, List[pd.DataFrame]]]: + + # Request event_no's for the entire file + event_nos = self._request_event_nos(n_ids=len(data)) + + # Dict holding pd.DataFrame's + dataframe_dict: Dict = {} + # Loop through events (again..) to assign event_nos + for k in range(len(data)): + for extractor_name in data[k].keys(): + n_rows = self._count_rows( + event_dict=data[k], extractor_name=extractor_name + ) + if n_rows > 0: + data[k][extractor_name][self._index_column] = np.repeat( + event_nos[k], n_rows + ).tolist() + df = pd.DataFrame( + data[k][extractor_name], + index=[0] if n_rows == 1 else None, + ) + if extractor_name in dataframe_dict.keys(): + dataframe_dict[extractor_name].append(df) + else: + dataframe_dict[extractor_name] = [df] + + # Merge each list of dataframes if wanted by writer + if self._save_method.expects_merged_dataframes: + for key in dataframe_dict.keys(): + dataframe_dict[key] = pd.concat( + dataframe_dict[key], axis=0 + ).reset_index(drop=True) + return dataframe_dict - Args: - fileset: Path to I3 file and corresponding GCD file. + @final + def _count_rows( + self, event_dict: OrderedDict[str, Any], extractor_name: str + ) -> int: + """Count number of rows that features from `extractor_name` have.""" + extractor_dict = event_dict[extractor_name] - Returns: - Extracted data. - """ - # Infer whether method is being run using multiprocessing try: - global_index # type: ignore[name-defined] - multi_processing = True - except NameError: - multi_processing = False - - self._extractors.set_files(fileset.i3_file, fileset.gcd_file) - i3_file_io = dataio.I3File(fileset.i3_file, "r") - data = list() - while i3_file_io.more(): - try: - frame = i3_file_io.pop_physics() - except Exception as e: - if "I3" in str(e): - continue - # check if frame should be skipped - if self._skip_frame(frame): - continue - - # Try to extract data from I3Frame - results = self._extractors(frame) - - data_dict = OrderedDict(zip(self._table_names, results)) - - # If an I3GenericExtractor is used, we want each automatically - # parsed key to be stored as a separate table. - for extractor in self._extractors: - if isinstance(extractor, I3GenericExtractor): - data_dict.update(data_dict.pop(extractor._name)) - - # Get new, unique index and increment value - if multi_processing: - with global_index.get_lock(): # type: ignore[name-defined] - index = global_index.value # type: ignore[name-defined] - global_index.value += 1 # type: ignore[name-defined] + # If all features in extractor_name have the same length + # this line of code will execute without error and result + # in an array with shape [num_features, n_rows_in_feature] + # unless the list is empty! + + shape = np.asarray(list(extractor_dict.values())).shape + if len(shape) > 1: + n_rows = shape[1] else: - index = self._index - self._index += 1 - - # Attach index to all tables - for table in data_dict.keys(): - data_dict[table][self._index_column] = index - - data.append(data_dict) + n_rows = 1 + except ValueError as e: + self.error( + f"Features from {extractor_name} ({extractor_dict.keys()}) have different lengths." + ) + raise e + return n_rows + + def _request_event_nos(self, n_ids: int) -> List[int]: + + # Get new, unique index and increment value + if self._num_workers > 1: + with global_index.get_lock(): # type: ignore[name-defined] + starting_index = global_index.value # type: ignore[name-defined] + event_nos = np.arange( + starting_index, starting_index + n_ids, 1 + ).tolist() + global_index.value += n_ids # type: ignore[name-defined] + else: + starting_index = self._index + event_nos = np.arange( + starting_index, starting_index + n_ids, 1 + ).tolist() + self._index += n_ids - return data + return event_nos + @final def get_map_function( - self, nb_files: int, unit: str = "I3 file(s)" + self, nb_files: int, unit: str = "file(s)" ) -> Tuple[Any, Optional[multiprocessing.pool.Pool]]: """Identify map function to use (pure python or multiprocess).""" # Choose relevant map-function given the requested number of workers. - workers = min(self._workers, nb_files) - if workers > 1: + n_workers = min(self._num_workers, nb_files) + if n_workers > 1: self.info( - f"Starting pool of {workers} workers to process {nb_files} {unit}" + f"Starting pool of {n_workers} workers to process {nb_files} {unit}" ) manager = Manager() @@ -498,7 +261,7 @@ def get_map_function( output_files = manager.list() pool = Pool( - processes=workers, + processes=n_workers, initializer=init_global_index, initargs=(index, output_files), ) @@ -513,75 +276,52 @@ def get_map_function( return map_fn, pool - def _infer_save_strategy( - self, - nb_files_to_batch: Optional[int] = None, - sequential_batch_pattern: Optional[str] = None, - input_file_batch_pattern: Optional[str] = None, - ) -> str: - if input_file_batch_pattern is not None: - save_strategy = "pattern_batched" - - assert ( - "*" in input_file_batch_pattern - ), "Argument `input_file_batch_pattern` should contain at least one wildcard (*)" - - fields = [ - "(" + field + ")" - for field in input_file_batch_pattern.replace( - ".", r"\." - ).split("*") - ] - nb_fields = len(fields) - self._sub_from = ".*".join(fields) - self._sub_to = "x".join([f"\\{ix + 1}" for ix in range(nb_fields)]) - - if sequential_batch_pattern is not None: - self.warning("Argument `sequential_batch_pattern` ignored.") - if nb_files_to_batch is not None: - self.warning("Argument `nb_files_to_batch` ignored.") - - elif (nb_files_to_batch is not None) or ( - sequential_batch_pattern is not None - ): - save_strategy = "sequential_batched" + @final + def _update_shared_variables( + self, pool: Optional[multiprocessing.pool.Pool] + ) -> None: + """Update `self._index` and `self._output_files`. - assert (nb_files_to_batch is not None) and ( - sequential_batch_pattern is not None - ), "Please specify both `nb_files_to_batch` and `sequential_batch_pattern` for sequential batching." + If `pool` is set, it means that multiprocessing was used. In this case, + the worker processes will not have been able to write directly to + `self._index` and `self._output_files`, and we need to get them synced + up. + """ + if pool: + # Extract information from shared variables to member variables. + index, output_files = pool._initargs # type: ignore + self._index += index.value + self._output_files.extend(list(sorted(output_files[:]))) - else: - save_strategy = "1:1" - - return save_strategy - - def _save_filenames(self, i3_files: List[str]) -> None: - """Save I3 file names in CSV format.""" - self.debug("Saving input file names to config CSV.") - config_dir = os.path.join(self._outdir, "config") - os.makedirs(config_dir, exist_ok=True) - df_i3_files = pd.DataFrame(data=i3_files, columns=["filename"]) - df_i3_files.to_csv(os.path.join(config_dir, "i3files.csv")) - - def _get_output_file(self, input_file: str) -> str: - assert isinstance(input_file, str) - basename = os.path.basename(input_file) - output_file = os.path.join( - self._outdir, - re.sub(r"\.i3\..*", "", basename) + "." + self.file_suffix, - ) - return output_file + @final + def merge_files(self, files: Optional[List[str]] = None) -> None: + """Merge converted files. - def _skip_frame(self, frame: "icetray.I3Frame") -> bool: - """Check the user defined filters. + `DataConverter` will call the `.merge_files` method in the + `GraphNeTWriter` module that it was instantiated with. - Returns: - bool: True if frame should be skipped, False otherwise. + Args: + files: Intermediate files to be merged. """ - if self._i3filters is None: - return False # No filters defined, so we keep the frame - - for filter in self._i3filters: - if not filter(frame): - return True # keep_frame call false, skip the frame. - return False # All filter keep_frame calls true, keep the frame. + if (files is None) & (len(self._output_files) > 0): + # If no input files are given, but output files from conversion + # is available. + files_to_merge = self._output_files + elif files is not None: + # Proceed to merge specified by user. + files_to_merge = files + else: + # Raise error + self.error( + "This DataConverter does not have output files set," + "and you must therefore specify argument `files`." + ) + assert files is not None + + # Merge files + merge_path = os.path.join(self._output_dir, "merged") + self.info(f"Merging files to {merge_path}") + self._save_method.merge_files( + files=files_to_merge, + output_dir=merge_path, + ) diff --git a/src/graphnet/data/extractors/__init__.py b/src/graphnet/data/extractors/__init__.py index e1d4895bf..c6f4f325e 100644 --- a/src/graphnet/data/extractors/__init__.py +++ b/src/graphnet/data/extractors/__init__.py @@ -1,20 +1,2 @@ -"""Collection of I3Extractors, extracting pure-python data from I3Frames.""" - -from .i3extractor import I3Extractor, I3ExtractorCollection -from .i3featureextractor import ( - I3FeatureExtractor, - I3FeatureExtractorIceCube86, - I3FeatureExtractorIceCubeDeepCore, - I3FeatureExtractorIceCubeUpgrade, - I3PulseNoiseTruthFlagIceCubeUpgrade, -) -from .i3truthextractor import I3TruthExtractor -from .i3retroextractor import I3RetroExtractor -from .i3splinempeextractor import I3SplineMPEICExtractor -from .i3particleextractor import I3ParticleExtractor -from .i3tumextractor import I3TUMExtractor -from .i3hybridrecoextractor import I3GalacticPlaneHybridRecoExtractor -from .i3genericextractor import I3GenericExtractor -from .i3pisaextractor import I3PISAExtractor -from .i3ntmuonlabelsextractor import I3NTMuonLabelExtractor -from .i3quesoextractor import I3QUESOExtractor +"""Module containing data-specific extractor modules.""" +from .extractor import Extractor diff --git a/src/graphnet/data/extractors/extractor.py b/src/graphnet/data/extractors/extractor.py new file mode 100644 index 000000000..ce743f63d --- /dev/null +++ b/src/graphnet/data/extractors/extractor.py @@ -0,0 +1,46 @@ +"""Base I3Extractor class(es).""" +from typing import Any +from abc import ABC, abstractmethod + +from graphnet.utilities.logging import Logger + + +class Extractor(ABC, Logger): + """Base class for extracting information from data files. + + All classes inheriting from `Extractor` should implement the `__call__` + method, and should return a pure python dictionary on the form + + output = {'var1: .., + ... , + 'var_n': ..} + + Variables can be scalar or array-like of shape [n, 1], where n denotes the + number of elements in the array, and 1 the number of columns. + + An extractor is used in conjunction with a specific `FileReader`. + """ + + def __init__(self, extractor_name: str): + """Construct Extractor. + + Args: + extractor_name: Name of the `Extractor` instance. Used to keep track of the + provenance of different data, and to name tables to which this + data is saved. E.g. "mc_truth". + """ + # Member variable(s) + self._extractor_name: str = extractor_name + + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + @abstractmethod + def __call__(self, data: Any) -> dict: + """Extract information from data.""" + pass + + @property + def name(self) -> str: + """Get the name of the `Extractor` instance.""" + return self._extractor_name diff --git a/src/graphnet/data/extractors/i3extractor.py b/src/graphnet/data/extractors/i3extractor.py deleted file mode 100644 index 90a982387..000000000 --- a/src/graphnet/data/extractors/i3extractor.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Base I3Extractor class(es).""" - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, List, Optional - -from graphnet.utilities.imports import has_icecube_package -from graphnet.utilities.logging import Logger - -if has_icecube_package() or TYPE_CHECKING: - from icecube import icetray, dataio # pyright: reportMissingImports=false - - -class I3Extractor(ABC, Logger): - """Base class for extracting information from physics I3-frames. - - All classes inheriting from `I3Extractor` should implement the `__call__` - method, and can be applied directly on `icetray.I3Frame` objects to return - extracted, pure-python data. - """ - - def __init__(self, name: str): - """Construct I3Extractor. - - Args: - name: Name of the `I3Extractor` instance. Used to keep track of the - provenance of different data, and to name tables to which this - data is saved. - """ - # Member variable(s) - self._i3_file: str = "" - self._gcd_file: str = "" - self._gcd_dict: Dict[int, Any] = {} - self._calibration: Optional["icetray.I3Frame.Calibration"] = None - self._name: str = name - - # Base class constructor - super().__init__(name=__name__, class_name=self.__class__.__name__) - - def set_files(self, i3_file: str, gcd_file: str) -> None: - """Store references to the I3- and GCD-files being processed.""" - # @TODO: Is it necessary to set the `i3_file`? It is only used in one - # place in `I3TruthExtractor`, and there only in a way that might - # be solved another way. - self._i3_file = i3_file - self._gcd_file = gcd_file - self._load_gcd_data() - - def _load_gcd_data(self) -> None: - """Load the geospatial information contained in the GCD-file.""" - # If no GCD file is provided, search the I3 file for frames containing - # geometry (G) and calibration (C) information. - gcd_file = dataio.I3File(self._gcd_file or self._i3_file) - - try: - g_frame = gcd_file.pop_frame(icetray.I3Frame.Geometry) - except RuntimeError: - self.error( - "No GCD file was provided and no G-frame was found. Exiting." - ) - raise - else: - self._gcd_dict = g_frame["I3Geometry"].omgeo - - try: - c_frame = gcd_file.pop_frame(icetray.I3Frame.Calibration) - except RuntimeError: - self.warning("No GCD file was provided and no C-frame was found.") - else: - self._calibration = c_frame["I3Calibration"] - - @abstractmethod - def __call__(self, frame: "icetray.I3Frame") -> dict: - """Extract information from frame.""" - pass - - @property - def name(self) -> str: - """Get the name of the `I3Extractor` instance.""" - return self._name - - -class I3ExtractorCollection(list): - """Class to manage multiple I3Extractors.""" - - def __init__(self, *extractors: I3Extractor): - """Construct I3ExtractorCollection. - - Args: - *extractors: List of `I3Extractor`s to be treated as a single - collection. - """ - # Check(s) - for extractor in extractors: - assert isinstance(extractor, I3Extractor) - - # Base class constructor - super().__init__(extractors) - - def set_files(self, i3_file: str, gcd_file: str) -> None: - """Store references to the I3- and GCD-files being processed.""" - for extractor in self: - extractor.set_files(i3_file, gcd_file) - - def __call__(self, frame: "icetray.I3Frame") -> List[dict]: - """Extract information from frame for each member `I3Extractor`.""" - return [extractor(frame) for extractor in self] diff --git a/src/graphnet/data/extractors/i3particleextractor.py b/src/graphnet/data/extractors/i3particleextractor.py deleted file mode 100644 index bd37424d2..000000000 --- a/src/graphnet/data/extractors/i3particleextractor.py +++ /dev/null @@ -1,43 +0,0 @@ -"""I3Extractor class(es) for extracting I3Particle properties.""" - -from typing import TYPE_CHECKING, Dict - -from graphnet.data.extractors.i3extractor import I3Extractor - -if TYPE_CHECKING: - from icecube import icetray # pyright: reportMissingImports=false - - -class I3ParticleExtractor(I3Extractor): - """Class for extracting I3Particle properties. - - Can be used to extract predictions from other algorithms for comparisons - with GraphNeT. - """ - - def __init__(self, name: str): - """Construct I3ParticleExtractor.""" - # Base class constructor - super().__init__(name) - - def __call__(self, frame: "icetray.I3Frame") -> Dict[str, float]: - """Extract I3Particle properties from I3Particle in frame.""" - output = {} - if self._name in frame: - output.update( - { - "zenith_" + self._name: frame[self._name].dir.zenith, - "azimuth_" + self._name: frame[self._name].dir.azimuth, - "dir_x_" + self._name: frame[self._name].dir.x, - "dir_y_" + self._name: frame[self._name].dir.y, - "dir_z_" + self._name: frame[self._name].dir.z, - "pos_x_" + self._name: frame[self._name].pos.x, - "pos_y_" + self._name: frame[self._name].pos.y, - "pos_z_" + self._name: frame[self._name].pos.z, - "time_" + self._name: frame[self._name].time, - "speed_" + self._name: frame[self._name].speed, - "energy_" + self._name: frame[self._name].energy, - } - ) - - return output diff --git a/src/graphnet/data/extractors/icecube/__init__.py b/src/graphnet/data/extractors/icecube/__init__.py new file mode 100644 index 000000000..11befe581 --- /dev/null +++ b/src/graphnet/data/extractors/icecube/__init__.py @@ -0,0 +1,20 @@ +"""Collection of I3Extractors, extracting pure-python data from I3Frames.""" + +from .i3extractor import I3Extractor +from .i3featureextractor import ( + I3FeatureExtractor, + I3FeatureExtractorIceCube86, + I3FeatureExtractorIceCubeDeepCore, + I3FeatureExtractorIceCubeUpgrade, + I3PulseNoiseTruthFlagIceCubeUpgrade, +) +from .i3truthextractor import I3TruthExtractor +from .i3retroextractor import I3RetroExtractor +from .i3splinempeextractor import I3SplineMPEICExtractor +from .i3particleextractor import I3ParticleExtractor +from .i3tumextractor import I3TUMExtractor +from .i3hybridrecoextractor import I3GalacticPlaneHybridRecoExtractor +from .i3genericextractor import I3GenericExtractor +from .i3pisaextractor import I3PISAExtractor +from .i3ntmuonlabelsextractor import I3NTMuonLabelExtractor +from .i3quesoextractor import I3QUESOExtractor diff --git a/src/graphnet/data/extractors/icecube/i3extractor.py b/src/graphnet/data/extractors/icecube/i3extractor.py new file mode 100644 index 000000000..3f2fc92d2 --- /dev/null +++ b/src/graphnet/data/extractors/icecube/i3extractor.py @@ -0,0 +1,92 @@ +"""Base I3Extractor class(es).""" + +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Dict, Optional + +from graphnet.utilities.imports import has_icecube_package +from graphnet.data.extractors import Extractor + +if has_icecube_package() or TYPE_CHECKING: + from icecube import icetray, dataio # pyright: reportMissingImports=false + + +class I3Extractor(Extractor): + """Base class for extracting information from physics I3-frames. + + Contains functionality required to extract data from i3 files, used by + the IceCube Neutrino Observatory. + + All classes inheriting from `I3Extractor` should implement the `__call__` + method. + """ + + def __init__(self, extractor_name: str): + """Construct I3Extractor. + + Args: + extractor_name: Name of the `I3Extractor` instance. Used to keep track of the + provenance of different data, and to name tables to which this + data is saved. + """ + # Member variable(s) + self._i3_file: str = "" + self._gcd_file: str = "" + self._gcd_dict: Dict[int, Any] = {} + self._calibration: Optional["icetray.I3Frame.Calibration"] = None + + # Base class constructor + super().__init__(extractor_name=extractor_name) + + def set_gcd(self, i3_file: str, gcd_file: Optional[str] = None) -> None: + """Extract GFrame and CFrame from i3/gcd-file pair. + + Information from these frames will be set as member variables of + `I3Extractor.` + + Args: + i3_file: Path to i3 file that is being converted. + gcd_file: Path to GCD file. Defaults to None. If no GCD file is + given, the method will attempt to find C and G frames in + the i3 file instead. If either one of those are not + present, `RuntimeErrors` will be raised. + """ + if gcd_file is None: + # If no GCD file is provided, search the I3 file for frames + # containing geometry (GFrame) and calibration (CFrame) + gcd = dataio.I3File(i3_file) + else: + # Ideally ends here + gcd = dataio.I3File(gcd_file) + + # Get GFrame + try: + g_frame = gcd.pop_frame(icetray.I3Frame.Geometry) + # If the line above fails, it means that no gcd file was given + # and that the i3 file does not have a G-Frame in it. + except RuntimeError as e: + self.error( + "No GCD file was provided " + f"and no G-frame was found in {i3_file.split('/')[-1]}." + ) + raise e + + # Get CFrame + try: + c_frame = gcd.pop_frame(icetray.I3Frame.Calibration) + # If the line above fails, it means that no gcd file was given + # and that the i3 file does not have a C-Frame in it. + except RuntimeError as e: + self.warning( + "No GCD file was provided and no C-frame " + f"was found in {i3_file.split('/')[-1]}." + ) + raise e + + # Save information as member variables of I3Extractor + self._gcd_dict = g_frame["I3Geometry"].omgeo + self._calibration = c_frame["I3Calibration"] + + @abstractmethod + def __call__(self, frame: "icetray.I3Frame") -> dict: + """Extract information from frame.""" + pass diff --git a/src/graphnet/data/extractors/i3featureextractor.py b/src/graphnet/data/extractors/icecube/i3featureextractor.py similarity index 97% rename from src/graphnet/data/extractors/i3featureextractor.py rename to src/graphnet/data/extractors/icecube/i3featureextractor.py index f1f578453..258bb368c 100644 --- a/src/graphnet/data/extractors/i3featureextractor.py +++ b/src/graphnet/data/extractors/icecube/i3featureextractor.py @@ -1,17 +1,14 @@ """I3Extractor class(es) for extracting specific, reconstructed features.""" from typing import TYPE_CHECKING, Any, Dict, List -from graphnet.data.extractors.i3extractor import I3Extractor -from graphnet.data.extractors.utilities.frames import ( +from .i3extractor import I3Extractor +from graphnet.data.extractors.icecube.utilities.frames import ( get_om_keys_and_pulseseries, ) from graphnet.utilities.imports import has_icecube_package if has_icecube_package() or TYPE_CHECKING: - from icecube import ( - icetray, - dataclasses, - ) # pyright: reportMissingImports=false + from icecube import icetray # pyright: reportMissingImports=false class I3FeatureExtractor(I3Extractor): diff --git a/src/graphnet/data/extractors/i3genericextractor.py b/src/graphnet/data/extractors/icecube/i3genericextractor.py similarity index 98% rename from src/graphnet/data/extractors/i3genericextractor.py rename to src/graphnet/data/extractors/icecube/i3genericextractor.py index 6a86303e7..e907181d0 100644 --- a/src/graphnet/data/extractors/i3genericextractor.py +++ b/src/graphnet/data/extractors/icecube/i3genericextractor.py @@ -2,12 +2,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union -from graphnet.data.extractors.i3extractor import I3Extractor -from graphnet.data.extractors.utilities.types import ( +from graphnet.data.extractors.icecube import I3Extractor +from graphnet.data.extractors.icecube.utilities.types import ( cast_object_to_pure_python, cast_pulse_series_to_pure_python, ) -from graphnet.data.extractors.utilities.collections import ( +from graphnet.data.extractors.icecube.utilities.collections import ( transpose_list_of_dicts, serialise, flatten_nested_dictionary, diff --git a/src/graphnet/data/extractors/i3hybridrecoextractor.py b/src/graphnet/data/extractors/icecube/i3hybridrecoextractor.py similarity index 96% rename from src/graphnet/data/extractors/i3hybridrecoextractor.py rename to src/graphnet/data/extractors/icecube/i3hybridrecoextractor.py index 74f445120..90525bcab 100644 --- a/src/graphnet/data/extractors/i3hybridrecoextractor.py +++ b/src/graphnet/data/extractors/icecube/i3hybridrecoextractor.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Dict -from graphnet.data.extractors.i3extractor import I3Extractor +from graphnet.data.extractors.icecube import I3Extractor if TYPE_CHECKING: from icecube import icetray # pyright: reportMissingImports=false diff --git a/src/graphnet/data/extractors/i3ntmuonlabelsextractor.py b/src/graphnet/data/extractors/icecube/i3ntmuonlabelsextractor.py similarity index 96% rename from src/graphnet/data/extractors/i3ntmuonlabelsextractor.py rename to src/graphnet/data/extractors/icecube/i3ntmuonlabelsextractor.py index 1ca3e8bcb..039b13cfe 100644 --- a/src/graphnet/data/extractors/i3ntmuonlabelsextractor.py +++ b/src/graphnet/data/extractors/icecube/i3ntmuonlabelsextractor.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Dict -from graphnet.data.extractors.i3extractor import I3Extractor +from graphnet.data.extractors.icecube.i3extractor import I3Extractor if TYPE_CHECKING: from icecube import icetray # pyright: reportMissingImports=false diff --git a/src/graphnet/data/extractors/icecube/i3particleextractor.py b/src/graphnet/data/extractors/icecube/i3particleextractor.py new file mode 100644 index 000000000..a50c11d21 --- /dev/null +++ b/src/graphnet/data/extractors/icecube/i3particleextractor.py @@ -0,0 +1,44 @@ +"""I3Extractor class(es) for extracting I3Particle properties.""" + +from typing import TYPE_CHECKING, Dict + +from graphnet.data.extractors.icecube import I3Extractor + +if TYPE_CHECKING: + from icecube import icetray # pyright: reportMissingImports=false + + +class I3ParticleExtractor(I3Extractor): + """Class for extracting I3Particle properties. + + Can be used to extract predictions from other algorithms for comparisons + with GraphNeT. + """ + + def __init__(self, extractor_name: str): + """Construct I3ParticleExtractor.""" + # Base class constructor + super().__init__(extractor_name=extractor_name) + + def __call__(self, frame: "icetray.I3Frame") -> Dict[str, float]: + """Extract I3Particle properties from I3Particle in frame.""" + output = {} + name = self._extractor_name + if name in frame: + output.update( + { + "zenith_" + name: frame[name].dir.zenith, + "azimuth_" + name: frame[name].dir.azimuth, + "dir_x_" + name: frame[name].dir.x, + "dir_y_" + name: frame[name].dir.y, + "dir_z_" + name: frame[name].dir.z, + "pos_x_" + name: frame[name].pos.x, + "pos_y_" + name: frame[name].pos.y, + "pos_z_" + name: frame[name].pos.z, + "time_" + name: frame[name].time, + "speed_" + name: frame[name].speed, + "energy_" + name: frame[name].energy, + } + ) + + return output diff --git a/src/graphnet/data/extractors/i3pisaextractor.py b/src/graphnet/data/extractors/icecube/i3pisaextractor.py similarity index 94% rename from src/graphnet/data/extractors/i3pisaextractor.py rename to src/graphnet/data/extractors/icecube/i3pisaextractor.py index fd5a09583..f14a8046a 100644 --- a/src/graphnet/data/extractors/i3pisaextractor.py +++ b/src/graphnet/data/extractors/icecube/i3pisaextractor.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Dict -from graphnet.data.extractors.i3extractor import I3Extractor +from graphnet.data.extractors.icecube.i3extractor import I3Extractor if TYPE_CHECKING: from icecube import icetray # pyright: reportMissingImports=false diff --git a/src/graphnet/data/extractors/i3quesoextractor.py b/src/graphnet/data/extractors/icecube/i3quesoextractor.py similarity index 94% rename from src/graphnet/data/extractors/i3quesoextractor.py rename to src/graphnet/data/extractors/icecube/i3quesoextractor.py index b72b20046..e29c72a41 100644 --- a/src/graphnet/data/extractors/i3quesoextractor.py +++ b/src/graphnet/data/extractors/icecube/i3quesoextractor.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Dict -from graphnet.data.extractors.i3extractor import I3Extractor +from graphnet.data.extractors.icecube.i3extractor import I3Extractor if TYPE_CHECKING: from icecube import icetray # pyright: reportMissingImports=false diff --git a/src/graphnet/data/extractors/i3retroextractor.py b/src/graphnet/data/extractors/icecube/i3retroextractor.py similarity index 97% rename from src/graphnet/data/extractors/i3retroextractor.py rename to src/graphnet/data/extractors/icecube/i3retroextractor.py index cd55d01f4..aaeb773b4 100644 --- a/src/graphnet/data/extractors/i3retroextractor.py +++ b/src/graphnet/data/extractors/icecube/i3retroextractor.py @@ -2,8 +2,8 @@ from typing import TYPE_CHECKING, Any, Dict -from graphnet.data.extractors.i3extractor import I3Extractor -from graphnet.data.extractors.utilities.frames import ( +from graphnet.data.extractors.icecube import I3Extractor +from graphnet.data.extractors.icecube.utilities.frames import ( frame_is_montecarlo, frame_is_noise, ) diff --git a/src/graphnet/data/extractors/i3splinempeextractor.py b/src/graphnet/data/extractors/icecube/i3splinempeextractor.py similarity index 93% rename from src/graphnet/data/extractors/i3splinempeextractor.py rename to src/graphnet/data/extractors/icecube/i3splinempeextractor.py index e47b2e71d..1439ada51 100644 --- a/src/graphnet/data/extractors/i3splinempeextractor.py +++ b/src/graphnet/data/extractors/icecube/i3splinempeextractor.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Dict -from graphnet.data.extractors.i3extractor import I3Extractor +from graphnet.data.extractors.icecube import I3Extractor if TYPE_CHECKING: from icecube import icetray # pyright: reportMissingImports=false diff --git a/src/graphnet/data/extractors/i3truthextractor.py b/src/graphnet/data/extractors/icecube/i3truthextractor.py similarity index 99% rename from src/graphnet/data/extractors/i3truthextractor.py rename to src/graphnet/data/extractors/icecube/i3truthextractor.py index bcfe694c3..b715e57ab 100644 --- a/src/graphnet/data/extractors/i3truthextractor.py +++ b/src/graphnet/data/extractors/icecube/i3truthextractor.py @@ -4,8 +4,8 @@ import matplotlib.path as mpath from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple -from graphnet.data.extractors.i3extractor import I3Extractor -from graphnet.data.extractors.utilities.frames import ( +from .i3extractor import I3Extractor +from .utilities.frames import ( frame_is_montecarlo, frame_is_noise, ) diff --git a/src/graphnet/data/extractors/i3tumextractor.py b/src/graphnet/data/extractors/icecube/i3tumextractor.py similarity index 94% rename from src/graphnet/data/extractors/i3tumextractor.py rename to src/graphnet/data/extractors/icecube/i3tumextractor.py index 38cbca146..685b0a78e 100644 --- a/src/graphnet/data/extractors/i3tumextractor.py +++ b/src/graphnet/data/extractors/icecube/i3tumextractor.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Dict -from graphnet.data.extractors.i3extractor import I3Extractor +from graphnet.data.extractors.icecube import I3Extractor if TYPE_CHECKING: from icecube import icetray # pyright: reportMissingImports=false diff --git a/src/graphnet/data/extractors/utilities/__init__.py b/src/graphnet/data/extractors/icecube/utilities/__init__.py similarity index 100% rename from src/graphnet/data/extractors/utilities/__init__.py rename to src/graphnet/data/extractors/icecube/utilities/__init__.py diff --git a/src/graphnet/data/extractors/utilities/collections.py b/src/graphnet/data/extractors/icecube/utilities/collections.py similarity index 100% rename from src/graphnet/data/extractors/utilities/collections.py rename to src/graphnet/data/extractors/icecube/utilities/collections.py diff --git a/src/graphnet/data/extractors/utilities/frames.py b/src/graphnet/data/extractors/icecube/utilities/frames.py similarity index 100% rename from src/graphnet/data/extractors/utilities/frames.py rename to src/graphnet/data/extractors/icecube/utilities/frames.py diff --git a/src/graphnet/data/filters.py b/src/graphnet/data/extractors/icecube/utilities/i3_filters.py similarity index 100% rename from src/graphnet/data/filters.py rename to src/graphnet/data/extractors/icecube/utilities/i3_filters.py diff --git a/src/graphnet/data/extractors/utilities/types.py b/src/graphnet/data/extractors/icecube/utilities/types.py similarity index 98% rename from src/graphnet/data/extractors/utilities/types.py rename to src/graphnet/data/extractors/icecube/utilities/types.py index cf58e8357..32ecae0ff 100644 --- a/src/graphnet/data/extractors/utilities/types.py +++ b/src/graphnet/data/extractors/icecube/utilities/types.py @@ -4,11 +4,11 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from graphnet.data.extractors.utilities.collections import ( +from graphnet.data.extractors.icecube.utilities.collections import ( transpose_list_of_dicts, flatten_nested_dictionary, ) -from graphnet.data.extractors.utilities.frames import ( +from graphnet.data.extractors.icecube.utilities.frames import ( get_om_keys_and_pulseseries, ) from graphnet.utilities.imports import has_icecube_package diff --git a/src/graphnet/data/parquet/__init__.py b/src/graphnet/data/parquet/__init__.py index 616d89c16..2c41ca75d 100644 --- a/src/graphnet/data/parquet/__init__.py +++ b/src/graphnet/data/parquet/__init__.py @@ -1,2 +1,2 @@ -"""Parquet-specific implementation of data classes.""" -from .parquet_dataconverter import ParquetDataConverter +"""Module for deprecated parquet methods.""" +from .deprecated_methods import ParquetDataConverter diff --git a/src/graphnet/data/parquet/deprecated_methods.py b/src/graphnet/data/parquet/deprecated_methods.py new file mode 100644 index 000000000..423e1aa00 --- /dev/null +++ b/src/graphnet/data/parquet/deprecated_methods.py @@ -0,0 +1,60 @@ +"""Module containing deprecated data conversion code. + +This code will be removed in GraphNeT 2.0. +""" +from typing import List, Union, Type + +from graphnet.data.extractors.icecube import I3Extractor +from graphnet.data.extractors.icecube.utilities.i3_filters import ( + I3Filter, +) +from graphnet.data import I3ToParquetConverter + + +class ParquetDataConverter(I3ToParquetConverter): + """Method for converting i3 files to parquet files.""" + + def __init__( + self, + gcd_rescue: str, + extractors: List[I3Extractor], + outdir: str, + index_column: str = "event_no", + workers: int = 1, + i3_filters: Union[I3Filter, List[I3Filter]] = None, # type: ignore + ): + """Convert I3 files to Parquet. + + Args: + gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is + found in subfolder. `I3Reader` will recursively search + the input directory for I3-GCD file pairs. By IceCube + convention, a folder containing i3 files will have an + accompanying GCD file. However, in some cases, this + convention is broken. In cases where a folder contains + i3 files but no GCD file, the `gcd_rescue` is used + instead. + extractors: The `Extractor`(s) that will be applied to the input + files. + outdir: The directory to save the files in. + icetray_verbose: Set the level of verbosity of icetray. + Defaults to 0. + index_column: Name of the event id column added to the events. + Defaults to "event_no". + workers: The number of CPUs used for parallel processing. + Defaults to 1 (no multiprocessing). + i3_filters: Instances of `I3Filter` to filter PFrames. Defaults to + `NullSplitI3Filter`. + """ + super().__init__( + extractors=extractors, + num_workers=workers, + index_column=index_column, + i3_filters=i3_filters, + outdir=outdir, + gcd_rescue=gcd_rescue, + ) + self.warning( + f"{self.__class__.__name__} will be deprecated in " + "GraphNeT 2.0. Please use I3ToParquetConverter instead." + ) diff --git a/src/graphnet/data/parquet/parquet_dataconverter.py b/src/graphnet/data/parquet/parquet_dataconverter.py deleted file mode 100644 index 68531c8e2..000000000 --- a/src/graphnet/data/parquet/parquet_dataconverter.py +++ /dev/null @@ -1,52 +0,0 @@ -"""DataConverter for the Parquet backend.""" - -from collections import OrderedDict -import os -from typing import List, Optional - -import awkward - -from graphnet.data.dataconverter import DataConverter # type: ignore[attr-defined] - - -class ParquetDataConverter(DataConverter): - """Class for converting I3-files to Parquet format.""" - - # Class variables - file_suffix: str = "parquet" - - # Abstract method implementation(s) - def save_data(self, data: List[OrderedDict], output_file: str) -> None: - """Save data to parquet file.""" - # Check(s) - if os.path.exists(output_file): - self.warning( - f"Output file {output_file} already exists. Overwriting." - ) - - self.debug(f"Saving to {output_file}") - self.debug( - f"- Data has {len(data)} events and {len(data[0])} tables for each" - ) - - awkward.to_parquet(awkward.from_iter(data), output_file) - - self.debug("- Done saving") - self._output_files.append(output_file) - - def merge_files( - self, output_file: str, input_files: Optional[List[str]] = None - ) -> None: - """Parquet-specific method for merging output files. - - Args: - output_file: Name of the output file containing the merged results. - input_files: Intermediate files to be merged, according to the - specific implementation. Default to None, meaning that all - files output by the current instance are merged. - - Raises: - NotImplementedError: If the method has not been implemented for the - Parquet backend. - """ - raise NotImplementedError() diff --git a/src/graphnet/data/pipeline.py b/src/graphnet/data/pipeline.py index d97415bb0..9973c763f 100644 --- a/src/graphnet/data/pipeline.py +++ b/src/graphnet/data/pipeline.py @@ -13,7 +13,9 @@ import torch from torch.utils.data import DataLoader -from graphnet.data.sqlite.sqlite_utilities import create_table_and_save_to_sql +from graphnet.data.utilities.sqlite_utilities import ( + create_table_and_save_to_sql, +) from graphnet.training.utils import get_predictions, make_dataloader from graphnet.models.graphs import GraphDefinition diff --git a/src/graphnet/data/pre_configured/__init__.py b/src/graphnet/data/pre_configured/__init__.py new file mode 100644 index 000000000..f56f0de18 --- /dev/null +++ b/src/graphnet/data/pre_configured/__init__.py @@ -0,0 +1,2 @@ +"""Module for pre-configured converter modules.""" +from .dataconverters import I3ToParquetConverter, I3ToSQLiteConverter diff --git a/src/graphnet/data/pre_configured/dataconverters.py b/src/graphnet/data/pre_configured/dataconverters.py new file mode 100644 index 000000000..6db89c46e --- /dev/null +++ b/src/graphnet/data/pre_configured/dataconverters.py @@ -0,0 +1,99 @@ +"""Pre-configured combinations of writers and readers.""" + +from typing import List, Union, Type + +from graphnet.data import DataConverter +from graphnet.data.readers import I3Reader +from graphnet.data.writers import ParquetWriter, SQLiteWriter +from graphnet.data.extractors.icecube import I3Extractor +from graphnet.data.extractors.icecube.utilities.i3_filters import I3Filter + + +class I3ToParquetConverter(DataConverter): + """Preconfigured DataConverter for converting i3 files to parquet files.""" + + def __init__( + self, + gcd_rescue: str, + extractors: List[I3Extractor], + outdir: str, + index_column: str = "event_no", + num_workers: int = 1, + i3_filters: Union[I3Filter, List[I3Filter]] = None, # type: ignore + ): + """Convert I3 files to Parquet. + + Args: + gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is + found in subfolder. `I3Reader` will recursively search + the input directory for I3-GCD file pairs. By IceCube + convention, a folder containing i3 files will have an + accompanying GCD file. However, in some cases, this + convention is broken. In cases where a folder contains + i3 files but no GCD file, the `gcd_rescue` is used + instead. + extractors: The `Extractor`(s) that will be applied to the input + files. + outdir: The directory to save the files in. + icetray_verbose: Set the level of verbosity of icetray. + Defaults to 0. + index_column: Name of the event id column added to the events. + Defaults to "event_no". + num_workers: The number of CPUs used for parallel processing. + Defaults to 1 (no multiprocessing). + i3_filters: Instances of `I3Filter` to filter PFrames. Defaults to + `NullSplitI3Filter`. + """ + super().__init__( + file_reader=I3Reader(gcd_rescue=gcd_rescue, i3_filters=i3_filters), + save_method=ParquetWriter(), + extractors=extractors, + num_workers=num_workers, + index_column=index_column, + outdir=outdir, + ) + + +class I3ToSQLiteConverter(DataConverter): + """Preconfigured DataConverter for converting i3 files to SQLite files.""" + + def __init__( + self, + gcd_rescue: str, + extractors: List[I3Extractor], + outdir: str, + index_column: str = "event_no", + num_workers: int = 1, + i3_filters: Union[I3Filter, List[I3Filter]] = None, # type: ignore + ): + """Convert I3 files to SQLite. + + Args: + gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is + found in subfolder. `I3Reader` will recursively search + the input directory for I3-GCD file pairs. By IceCube + convention, a folder containing i3 files will have an + accompanying GCD file. However, in some cases, this + convention is broken. In cases where a folder contains + i3 files but no GCD file, the `gcd_rescue` is used + instead. + extractors: The `Extractor`(s) that will be applied to the input + files. + outdir: The directory to save the files in. + icetray_verbose: Set the level of verbosity of icetray. + Defaults to 0. + index_column: Name of the event id column added to the events. + Defaults to "event_no". + num_workers: The number of CPUs used for parallel processing. + Defaults to 1 (no multiprocessing). + i3_filters: Instances of `I3Filter` to filter PFrames. Defaults to + `NullSplitI3Filter`. + """ + super().__init__( + file_reader=I3Reader(gcd_rescue=gcd_rescue, i3_filters=i3_filters), + save_method=SQLiteWriter(), + extractors=extractors, + num_workers=num_workers, + index_column=index_column, + outdir=outdir, + ) diff --git a/src/graphnet/data/readers/__init__.py b/src/graphnet/data/readers/__init__.py new file mode 100644 index 000000000..0755bd35a --- /dev/null +++ b/src/graphnet/data/readers/__init__.py @@ -0,0 +1,3 @@ +"""Modules for reading experiment-specific data and applying Extractors.""" +from .graphnet_file_reader import GraphNeTFileReader +from .i3reader import I3Reader diff --git a/src/graphnet/data/readers/graphnet_file_reader.py b/src/graphnet/data/readers/graphnet_file_reader.py new file mode 100644 index 000000000..c590c6424 --- /dev/null +++ b/src/graphnet/data/readers/graphnet_file_reader.py @@ -0,0 +1,142 @@ +"""Module containing different FileReader classes in GraphNeT. + +These methods are used to open and apply `Extractors` to experiment-specific +file formats. +""" + +from typing import List, Union, OrderedDict, Any +from abc import abstractmethod, ABC +import glob +import os + +from graphnet.utilities.decorators import final +from graphnet.utilities.logging import Logger +from graphnet.data.dataclasses import I3FileSet +from graphnet.data.extractors.extractor import Extractor +from graphnet.data.extractors.icecube import I3Extractor + + +class GraphNeTFileReader(Logger, ABC): + """A generic base class for FileReaders in GraphNeT. + + Classes inheriting from `GraphNeTFileReader` must implement a + `__call__` method that opens a file, applies `Extractor`(s) and returns + a list of ordered dictionaries. + + In addition, Classes inheriting from `GraphNeTFileReader` must set + class properties `accepted_file_extensions` and `accepted_extractors`. + """ + + _accepted_file_extensions: List[str] = [] + _accepted_extractors: List[Any] = [] + + @abstractmethod + def __call__(self, file_path: Union[str, I3FileSet]) -> List[OrderedDict]: + """Open and apply extractors to a single file. + + The `output` must be a list of dictionaries, where the number of events + in the file `n_events` satisfies `len(output) = n_events`. I.e each + element in the list is a dictionary, and each field in the dictionary + is the output of a single extractor. + """ + + @property + def accepted_file_extensions(self) -> List[str]: + """Return list of accepted file extensions.""" + return self._accepted_file_extensions + + @property + def accepted_extractors(self) -> List[Extractor]: + """Return list of compatible `Extractor`(s).""" + return self._accepted_extractors + + @property + def extracor_names(self) -> List[str]: + """Return list of table names produced by extractors.""" + return [extractor.name for extractor in self._extractors] + + def find_files( + self, path: Union[str, List[str]] + ) -> Union[List[str], List[I3FileSet]]: + """Search directory for input files recursively. + + This method may be overwritten by custom implementations. + + Args: + path: path to directory. + + Returns: + List of files matching accepted file extensions. + """ + if isinstance(path, str): + path = [path] + files = [] + for dir in path: + for accepted_file_extension in self.accepted_file_extensions: + files.extend(glob.glob(dir + f"/*{accepted_file_extension}")) + + # Check that files are OK. + self.validate_files(files) + return files + + @final + def set_extractors( + self, extractors: Union[List[Extractor], List[I3Extractor]] + ) -> None: + """Set `Extractor`(s) as member variable. + + Args: + extractors: A list of `Extractor`(s) to set as member variable. + """ + if not isinstance(extractors, list): + extractors = [extractors] + self._validate_extractors(extractors) + self._extractors = extractors + + @final + def _validate_extractors( + self, extractors: Union[List[Extractor], List[I3Extractor]] + ) -> None: + for extractor in extractors: + try: + assert isinstance(extractor, tuple(self.accepted_extractors)) # type: ignore + except AssertionError as e: + self.error( + f"{extractor.__class__.__name__}" + f" is not supported by {self.__class__.__name__}" + ) + raise e + + @final + def validate_files( + self, input_files: Union[List[str], List[I3FileSet]] + ) -> None: + """Check that the input files are accepted by the reader. + + Args: + input_files: Path(s) to input file(s). + """ + for input_file in input_files: + # Handle filepath vs. FileSet cases + if isinstance(input_file, I3FileSet): + self._validate_file(input_file.i3_file) + self._validate_file(input_file.gcd_file) + else: + self._validate_file(input_file) + + @final + def _validate_file(self, file: str) -> None: + """Validate a single file path. + + Args: + file: path to file. + + Returns: + None + """ + try: + assert file.lower().endswith(tuple(self.accepted_file_extensions)) + except AssertionError: + self.error( + f'{self.__class__.__name__} accepts {self.accepted_file_extensions} but {file.split("/")[-1]} has extension {os.path.splitext(file)[1]}.' + ) diff --git a/src/graphnet/data/readers/i3reader.py b/src/graphnet/data/readers/i3reader.py new file mode 100644 index 000000000..ed5fd7c1f --- /dev/null +++ b/src/graphnet/data/readers/i3reader.py @@ -0,0 +1,137 @@ +"""Module containing different I3Reader.""" + +from typing import List, Union, OrderedDict, Type + +from graphnet.utilities.imports import has_icecube_package +from graphnet.data.extractors.icecube.utilities.i3_filters import ( + I3Filter, + NullSplitI3Filter, +) +from graphnet.data.extractors.icecube import I3Extractor +from graphnet.data.dataclasses import I3FileSet +from graphnet.utilities.filesys import find_i3_files +from .graphnet_file_reader import GraphNeTFileReader + + +if has_icecube_package(): + from icecube import icetray, dataio # pyright: reportMissingImports=false + + +class I3Reader(GraphNeTFileReader): + """A class for reading .i3 files from the IceCube Neutrino Observatory. + + Note that this class relies on IceCube-specific software, and therefore + must be run in a software environment that contains IceTray. + """ + + def __init__( + self, + gcd_rescue: str, + i3_filters: Union[I3Filter, List[I3Filter]] = None, + icetray_verbose: int = 0, + ): + """Initialize `I3Reader`. + + Args: + gcd_rescue: Path to a GCD file that will be used if no GCD file is + found in subfolder. `I3Reader` will recursively search + the input directory for I3-GCD file pairs. By IceCube + convention, a folder containing i3 files will have an + accompanying GCD file. However, in some cases, this + convention is broken. In cases where a folder contains + i3 files but no GCD file, the `gcd_rescue` is used + instead. + i3_filters: Instances of `I3Filter` to filter PFrames. Defaults to + `NullSplitI3Filter`. + icetray_verbose: Set the level of verbosity of icetray. + Defaults to 0. + """ + # Set verbosity + if icetray_verbose == 0: + icetray.I3Logger.global_logger = icetray.I3NullLogger() + + if i3_filters is None: + i3_filters = [NullSplitI3Filter()] + # Set Member Variables + self._accepted_file_extensions = [".bz2", ".zst", ".gz"] + self._accepted_extractors = [I3Extractor] + self._gcd_rescue = gcd_rescue + self._i3filters = ( + i3_filters if isinstance(i3_filters, list) else [i3_filters] + ) + + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + def __call__(self, file_path: I3FileSet) -> List[OrderedDict]: # type: ignore + """Extract data from single I3 file. + + Args: + fileset: Path to I3 file and corresponding GCD file. + + Returns: + Extracted data. + """ + # Set I3-GCD file pair in extractor + for extractor in self._extractors: + assert isinstance(extractor, I3Extractor) + extractor.set_gcd( + i3_file=file_path.i3_file, gcd_file=file_path.gcd_file + ) + + # Open I3 file + i3_file_io = dataio.I3File(file_path.i3_file, "r") + data = list() + while i3_file_io.more(): + try: + frame = i3_file_io.pop_physics() + except Exception as e: + if "I3" in str(e): + continue + # check if frame should be skipped + if self._skip_frame(frame): + continue + + # Try to extract data from I3Frame + results = [extractor(frame) for extractor in self._extractors] + + data_dict = OrderedDict(zip(self.extracor_names, results)) + + data.append(data_dict) + return data + + def find_files(self, path: Union[str, List[str]]) -> List[I3FileSet]: + """Recursively search directory for I3 and GCD file pairs. + + Args: + path: directory to search recursively. + + Returns: + List I3 and GCD file pairs as I3FileSets + """ + # Find all I3 and GCD files in the specified directories. + i3_files, gcd_files = find_i3_files( + path, + self._gcd_rescue, + ) + + # Pack as I3FileSets + filesets = [ + I3FileSet(i3_file, gcd_file) + for i3_file, gcd_file in zip(i3_files, gcd_files) + ] + return filesets + + def _skip_frame(self, frame: "icetray.I3Frame") -> bool: + """Check the user defined filters. + + Returns: + bool: True if frame should be skipped, False otherwise. + """ + if self._i3filters is None: + return False # No filters defined, so we keep the frame + + for filter in self._i3filters: + if not filter(frame): + return True # keep_frame call false, skip the frame. + return False # All filter keep_frame calls true, keep the frame. diff --git a/src/graphnet/data/sqlite/__init__.py b/src/graphnet/data/sqlite/__init__.py index e4ac554a7..436a86f2d 100644 --- a/src/graphnet/data/sqlite/__init__.py +++ b/src/graphnet/data/sqlite/__init__.py @@ -1,4 +1,2 @@ -"""SQLite-specific implementation of data classes.""" -from .sqlite_dataconverter import SQLiteDataConverter -from .sqlite_utilities import create_table_and_save_to_sql -from .sqlite_utilities import run_sql_code, save_to_sql +"""Module for deprecated methods using sqlite.""" +from .deprecated_methods import SQLiteDataConverter diff --git a/src/graphnet/data/sqlite/deprecated_methods.py b/src/graphnet/data/sqlite/deprecated_methods.py new file mode 100644 index 000000000..30b563c59 --- /dev/null +++ b/src/graphnet/data/sqlite/deprecated_methods.py @@ -0,0 +1,62 @@ +"""Module containing deprecated data conversion code. + +This code will be removed in GraphNeT 2.0. +""" + +from typing import List, Union, Type + +from graphnet.data.extractors.icecube import I3Extractor +from graphnet.data.extractors.icecube.utilities.i3_filters import ( + I3Filter, + NullSplitI3Filter, +) +from graphnet.data import I3ToSQLiteConverter + + +class SQLiteDataConverter(I3ToSQLiteConverter): + """Method for converting i3 files to SQLite files.""" + + def __init__( + self, + gcd_rescue: str, + extractors: List[I3Extractor], + outdir: str, + index_column: str = "event_no", + workers: int = 1, + i3_filters: Union[I3Filter, List[I3Filter]] = None, # type: ignore + ): + """Convert I3 files to Parquet. + + Args: + gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is + found in subfolder. `I3Reader` will recursively search + the input directory for I3-GCD file pairs. By IceCube + convention, a folder containing i3 files will have an + accompanying GCD file. However, in some cases, this + convention is broken. In cases where a folder contains + i3 files but no GCD file, the `gcd_rescue` is used + instead. + extractors: The `Extractor`(s) that will be applied to the input + files. + outdir: The directory to save the files in. + icetray_verbose: Set the level of verbosity of icetray. + Defaults to 0. + index_column: Name of the event id column added to the events. + Defaults to "event_no". + workers: The number of CPUs used for parallel processing. + Defaults to 1 (no multiprocessing). + i3_filters: Instances of `I3Filter` to filter PFrames. Defaults to + `NullSplitI3Filter`. + """ + super().__init__( + extractors=extractors, + num_workers=workers, + index_column=index_column, + i3_filters=i3_filters, + outdir=outdir, + gcd_rescue=gcd_rescue, + ) + self.warning( + f"{self.__class__.__name__} will be deprecated in " + "GraphNeT 2.0. Please use I3ToSQLiteConverter instead." + ) diff --git a/src/graphnet/data/sqlite/sqlite_dataconverter.py b/src/graphnet/data/sqlite/sqlite_dataconverter.py deleted file mode 100644 index 1750b7a33..000000000 --- a/src/graphnet/data/sqlite/sqlite_dataconverter.py +++ /dev/null @@ -1,349 +0,0 @@ -"""DataConverter for the SQLite backend.""" - -from collections import OrderedDict -import os -from typing import Any, Dict, List, Optional, Tuple, Union - -import pandas as pd -import sqlalchemy -import sqlite3 -from tqdm import tqdm - -from graphnet.data.dataconverter import DataConverter # type: ignore[attr-defined] -from graphnet.data.sqlite.sqlite_utilities import ( - create_table, - create_table_and_save_to_sql, -) - - -class SQLiteDataConverter(DataConverter): - """Class for converting I3-file(s) to SQLite format.""" - - # Class variables - file_suffix = "db" - - # Abstract method implementation(s) - def save_data(self, data: List[OrderedDict], output_file: str) -> None: - """Save data to SQLite database.""" - # Check(s) - if os.path.exists(output_file): - self.warning( - f"Output file {output_file} already exists. Appending." - ) - - # Concatenate data - if len(data) == 0: - self.warning( - "No data was extracted from the processed I3 file(s). " - f"No data saved to {output_file}" - ) - return - - saved_any = False - dataframe_list: OrderedDict = OrderedDict( - [(key, []) for key in data[0]] - ) - for data_dict in data: - for key, data_values in data_dict.items(): - df = construct_dataframe(data_values) - - if self.any_pulsemap_is_non_empty(data_dict) and len(df) > 0: - # only include data_dict in temp. databases if at least one pulsemap is non-empty, - # and the current extractor (df) is also non-empty (also since truth is always non-empty) - dataframe_list[key].append(df) - - dataframe = OrderedDict( - [ - ( - key, - pd.concat(dfs, ignore_index=True, sort=True) - if dfs - else pd.DataFrame(), - ) - for key, dfs in dataframe_list.items() - ] - ) - # Can delete dataframe_list here to free up memory. - - # Save each dataframe to SQLite database - self.debug(f"Saving to {output_file}") - for table, df in dataframe.items(): - if len(df) > 0: - create_table_and_save_to_sql( - df, - table, - output_file, - default_type="FLOAT", - integer_primary_key=not ( - is_pulse_map(table) or is_mc_tree(table) - ), - ) - saved_any = True - - if saved_any: - self.debug("- Done saving") - else: - self.warning(f"No data saved to {output_file}") - - def merge_files( - self, - output_file: str, - input_files: Optional[List[str]] = None, - max_table_size: Optional[int] = None, - ) -> None: - """SQLite-specific method for merging output files/databases. - - Args: - output_file: Name of the output file containing the merged results. - input_files: Intermediate files/databases to be merged, according - to the specific implementation. Default to None, meaning that - all files/databases output by the current instance are merged. - max_table_size: The maximum number of rows in any given table. - If any one table exceed this limit, a new database will be - created. - """ - if max_table_size: - self.warning( - f"Merging got max_table_size of {max_table_size}. Will attempt to create databases with a maximum row count of this size." - ) - self.max_table_size = max_table_size - self._partition_count = 1 - - if input_files is None: - self.info("Merging files output by current instance.") - self._input_files = self._output_files - else: - self._input_files = input_files - - if not output_file.endswith("." + self.file_suffix): - output_file = ".".join([output_file, self.file_suffix]) - - if os.path.exists(output_file): - self.warning( - f"Target path for merged database, {output_file}, already exists." - ) - - if len(self._input_files) > 0: - self.info(f"Merging {len(self._input_files)} database files") - # Create one empty database table for each extraction - self._merged_table_names = self._extract_table_names( - self._input_files - ) - if self.max_table_size: - output_file = self._adjust_output_file_name(output_file) - self._create_empty_tables(output_file) - self._row_counts = self._initialize_row_counts() - # Merge temporary databases into newly created one - self._merge_temporary_databases(output_file, self._input_files) - else: - self.warning("No temporary database files found!") - - # Internal methods - def _adjust_output_file_name(self, output_file: str) -> str: - if "_part_" in output_file: - root = ( - output_file.split("_part_")[0] - + output_file.split("_part_")[1][1:] - ) - else: - root = output_file - str_list = root.split(".db") - return str_list[0] + f"_part_{self._partition_count}" + ".db" - - def _update_row_counts( - self, results: "OrderedDict[str, pd.DataFrame]" - ) -> None: - for table_name, data in results.items(): - self._row_counts[table_name] += len(data) - return - - def _initialize_row_counts(self) -> Dict[str, int]: - """Build dictionary with row counts. Initialized with 0. - - Returns: - Dictionary where every field is a table name that contains - corresponding row counts. - """ - row_counts = {} - for table_name in self._merged_table_names: - row_counts[table_name] = 0 - return row_counts - - def _create_empty_tables(self, output_file: str) -> None: - """Create tables for output database. - - Args: - output_file: Path to database. - """ - for table_name in self._merged_table_names: - column_names = self._extract_column_names( - self._input_files, table_name - ) - if len(column_names) > 1: - create_table( - column_names, - table_name, - output_file, - default_type="FLOAT", - integer_primary_key=not ( - is_pulse_map(table_name) or is_mc_tree(table_name) - ), - ) - - def _get_tables_in_database(self, db: str) -> Tuple[str, ...]: - with sqlite3.connect(db) as conn: - table_names = tuple( - [ - p[0] - for p in ( - conn.execute( - "SELECT name FROM sqlite_master WHERE type='table';" - ).fetchall() - ) - ] - ) - return table_names - - def _extract_table_names( - self, db: Union[str, List[str]] - ) -> Tuple[str, ...]: - """Get the names of all tables in database `db`.""" - if isinstance(db, str): - db = [db] - results = [self._get_tables_in_database(path) for path in db] - # @TODO: Check... - if all([results[0] == r for r in results]): - return results[0] - else: - unique_tables = [] - for tables in results: - for table in tables: - if table not in unique_tables: - unique_tables.append(table) - return tuple(unique_tables) - - def _extract_column_names( - self, db_paths: List[str], table_name: str - ) -> List[str]: - for db_path in db_paths: - tables_in_database = self._get_tables_in_database(db_path) - if table_name in tables_in_database: - with sqlite3.connect(db_path) as con: - query = f"select * from {table_name} limit 1" - columns = pd.read_sql(query, con).columns - if len(columns): - return columns - return [] - - def any_pulsemap_is_non_empty(self, data_dict: Dict[str, Dict]) -> bool: - """Check whether there are non-empty pulsemaps extracted from P frame. - - Takes in the data extracted from the P frame, then retrieves the - values, if there are any, from the pulsemap key(s) (e.g - SplitInIcePulses). If at least one of the pulsemaps is non-empty then - return true. If no pulsemaps exist, i.e., if no `I3FeatureExtractor` is - called e.g. because `I3GenericExtractor` is used instead, always return - True. - """ - if len(self._pulsemaps) == 0: - return True - - pulsemap_dicts = [data_dict[pulsemap] for pulsemap in self._pulsemaps] - return any(d["dom_x"] for d in pulsemap_dicts) - - def _submit_to_database( - self, database: str, key: str, data: pd.DataFrame - ) -> None: - """Submit data to the database with specified key.""" - if len(data) == 0: - self.info(f"No data provided for {key}.") - return - engine = sqlalchemy.create_engine("sqlite:///" + database) - data.to_sql(key, engine, index=False, if_exists="append") - engine.dispose() - - def _extract_everything(self, db: str) -> "OrderedDict[str, pd.DataFrame]": - """Extract everything from the temporary database `db`. - - Args: - db: Path to temporary database. - - Returns: - Dictionary containing the data for each extracted table. - """ - results = OrderedDict() - table_names = self._extract_table_names(db) - with sqlite3.connect(db) as conn: - for table_name in table_names: - query = f"select * from {table_name}" - try: - data = pd.read_sql(query, conn) - except: # noqa: E722 - data = [] - results[table_name] = data - return results - - def _merge_temporary_databases( - self, - output_file: str, - input_files: List[str], - ) -> None: - """Merge the temporary databases. - - Args: - output_file: path to the final database - input_files: list of names of temporary databases - """ - file_count = 0 - for input_file in tqdm(input_files, colour="green"): - results = self._extract_everything(input_file) - for table_name, data in results.items(): - self._submit_to_database(output_file, table_name, data) - file_count += 1 - if (self.max_table_size is not None) & ( - file_count < len(input_files) - ): - self._update_row_counts(results) - maximum_row_count_reached = False - for table in self._row_counts.keys(): - assert self.max_table_size is not None - if self._row_counts[table] >= self.max_table_size: - maximum_row_count_reached = True - if maximum_row_count_reached: - self._partition_count += 1 - output_file = self._adjust_output_file_name(output_file) - self.info( - f"Maximum row count reached. Creating new partition at {output_file}" - ) - self._create_empty_tables(output_file) - self._row_counts = self._initialize_row_counts() - - -# Implementation-specific utility function(s) -def construct_dataframe(extraction: Dict[str, Any]) -> pd.DataFrame: - """Convert extraction to pandas.DataFrame. - - Args: - extraction: Dictionary with the extracted data. - - Returns: - Extraction as pandas.DataFrame. - """ - all_scalars = True - for value in extraction.values(): - if isinstance(value, (list, tuple, dict)): - all_scalars = False - break - - out = pd.DataFrame(extraction, index=[0] if all_scalars else None) - return out - - -def is_pulse_map(table_name: str) -> bool: - """Check whether `table_name` corresponds to a pulse map.""" - return "pulse" in table_name.lower() or "series" in table_name.lower() - - -def is_mc_tree(table_name: str) -> bool: - """Check whether `table_name` corresponds to an MC tree.""" - return "I3MCTree" in table_name diff --git a/src/graphnet/data/utilities/__init__.py b/src/graphnet/data/utilities/__init__.py index 0dd9e0600..ad4f0c7db 100644 --- a/src/graphnet/data/utilities/__init__.py +++ b/src/graphnet/data/utilities/__init__.py @@ -1 +1,4 @@ """Utilities for use across `graphnet.data`.""" +from .sqlite_utilities import create_table_and_save_to_sql +from .sqlite_utilities import get_primary_keys +from .sqlite_utilities import query_database diff --git a/src/graphnet/data/utilities/parquet_to_sqlite.py b/src/graphnet/data/utilities/parquet_to_sqlite.py index 146e69ce8..11114698e 100644 --- a/src/graphnet/data/utilities/parquet_to_sqlite.py +++ b/src/graphnet/data/utilities/parquet_to_sqlite.py @@ -9,7 +9,9 @@ import pandas as pd from tqdm.auto import trange -from graphnet.data.sqlite.sqlite_utilities import create_table_and_save_to_sql +from graphnet.data.utilities.sqlite_utilities import ( + create_table_and_save_to_sql, +) from graphnet.utilities.logging import Logger diff --git a/src/graphnet/data/sqlite/sqlite_utilities.py b/src/graphnet/data/utilities/sqlite_utilities.py similarity index 72% rename from src/graphnet/data/sqlite/sqlite_utilities.py rename to src/graphnet/data/utilities/sqlite_utilities.py index 23bae802d..cfa308ba2 100644 --- a/src/graphnet/data/sqlite/sqlite_utilities.py +++ b/src/graphnet/data/utilities/sqlite_utilities.py @@ -1,7 +1,7 @@ """SQLite-specific utility functions for use in `graphnet.data`.""" import os.path -from typing import List +from typing import List, Dict, Tuple import pandas as pd import sqlalchemy @@ -16,6 +16,58 @@ def database_exists(database_path: str) -> bool: return os.path.exists(database_path) +def query_database(database: str, query: str) -> pd.DataFrame: + """Execute query on database, and return result. + + Args: + database: path to database. + query: query to be executed. + + Returns: + DataFrame containing the result of the query. + """ + with sqlite3.connect(database) as conn: + return pd.read_sql(query, conn) + + +def get_primary_keys(database: str) -> Tuple[Dict[str, str], str]: + """Get name of primary key column for each table in database. + + Args: + database: path to database. + + Returns: + A dictionary containing the names of primary keys in each table of + `database`. E.g. {'truth': "event_no", + 'SplitInIcePulses': None} + Name of the primary key. + """ + with sqlite3.connect(database) as conn: + query = 'SELECT name FROM sqlite_master WHERE type == "table"' + table_names = [table[0] for table in conn.execute(query).fetchall()] + + integer_primary_key = {} + for table in table_names: + query = f"SELECT l.name FROM pragma_table_info('{table}') as l WHERE l.pk = 1;" + first_primary_key = [ + key[0] for key in conn.execute(query).fetchall() + ] + integer_primary_key[table] = ( + first_primary_key[0] if len(first_primary_key) else None + ) + + # Get the primary key column name + primary_key_candidates = [] + for val in set(integer_primary_key.values()): + if val is not None: + primary_key_candidates.append(val) + + # There should only be one primary key: + assert len(primary_key_candidates) == 1 + + return integer_primary_key, primary_key_candidates[0] + + def database_table_exists(database_path: str, table_name: str) -> bool: """Check whether `table_name` exists in database at `database_path`.""" if not database_exists(database_path): diff --git a/src/graphnet/data/writers/__init__.py b/src/graphnet/data/writers/__init__.py new file mode 100644 index 000000000..ad3e2748e --- /dev/null +++ b/src/graphnet/data/writers/__init__.py @@ -0,0 +1,4 @@ +"""Modules for saving interim dataformat to various data backends.""" +from .graphnet_writer import GraphNeTWriter +from .parquet_writer import ParquetWriter +from .sqlite_writer import SQLiteWriter diff --git a/src/graphnet/data/writers/graphnet_writer.py b/src/graphnet/data/writers/graphnet_writer.py new file mode 100644 index 000000000..f6ec03029 --- /dev/null +++ b/src/graphnet/data/writers/graphnet_writer.py @@ -0,0 +1,94 @@ +"""Module containing `GraphNeTFileSaveMethod`(s). + +These modules are used to save the interim data format from `DataConverter` to +a deep-learning friendly file format. +""" + +import os +from typing import Dict, List, Union +from abc import abstractmethod, ABC + +from graphnet.utilities.decorators import final +from graphnet.utilities.logging import Logger + +import pandas as pd + + +class GraphNeTWriter(Logger, ABC): + """Generic base class for saving interim data format in `DataConverter`. + + Classes inheriting from `GraphNeTFileSaveMethod` must implement the + `save_file` method, which recieves the interim data format from + from a single file. + + In addition, classes inheriting from `GraphNeTFileSaveMethod` must + set the `file_extension` property. + """ + + @abstractmethod + def _save_file( + self, + data: Union[Dict[str, pd.DataFrame], Dict[str, List[pd.DataFrame]]], + output_file_path: str, + n_events: int, + ) -> None: + """Save the interim data format from a single input file. + + Args: + data: the interim data from a single input file. + output_file_path: output file path. + n_events: Number of events container in `data`. + """ + raise NotImplementedError + + @abstractmethod + def merge_files( + self, + files: List[str], + output_dir: str, + ) -> None: + """Merge smaller files. + + Args: + files: Files to be merged. + output_dir: The directory to store the merged files in. + """ + raise NotImplementedError + + @final + def __call__( + self, + data: Union[Dict[str, pd.DataFrame], Dict[str, List[pd.DataFrame]]], + file_name: str, + output_dir: str, + n_events: int, + ) -> None: + """Save data. + + Args: + data: data to be saved. + file_name: name of input file. Will be used to generate output + file name. + output_dir: directory to save data to. + n_events: Number of events in `data`. + """ + # make dir + os.makedirs(output_dir, exist_ok=True) + output_file_path = ( + os.path.join(output_dir, file_name) + self.file_extension + ) + + self._save_file( + data=data, output_file_path=output_file_path, n_events=n_events + ) + return + + @property + def file_extension(self) -> str: + """Return file extension used to store the data.""" + return self._file_extension # type: ignore + + @property + def expects_merged_dataframes(self) -> bool: + """Return if writer expects input to be merged dataframes or not.""" + return self._merge_dataframes # type: ignore diff --git a/src/graphnet/data/writers/parquet_writer.py b/src/graphnet/data/writers/parquet_writer.py new file mode 100644 index 000000000..18e524ca9 --- /dev/null +++ b/src/graphnet/data/writers/parquet_writer.py @@ -0,0 +1,51 @@ +"""DataConverter for the Parquet backend.""" + +import os +from typing import List, Optional, Dict + +import awkward +import pandas as pd + +from .graphnet_writer import GraphNeTWriter + + +class ParquetWriter(GraphNeTWriter): + """Class for writing interim data format to Parquet.""" + + # Class variables + _file_extension = ".parquet" + _merge_dataframes = False + + # Abstract method implementation(s) + def _save_file( + self, + data: Dict[str, List[pd.DataFrame]], + output_file_path: str, + n_events: int, + ) -> None: + """Save data to parquet.""" + # Check(s) + + if n_events > 0: + events = [] + for k in range(n_events): + event = {} + for table in data.keys(): + event[table] = data[table][k].to_dict(orient="list") + + events.append(event) + + awkward.to_parquet(awkward.from_iter(events), output_file_path) + + def merge_files(self, files: List[str], output_dir: str) -> None: + """Merge parquet files. + + Args: + files: input files for merging. + output_dir: directory to store merged file(s) in. + + Raises: + NotImplementedError + """ + self.error(f"{self.__class__.__name__} does not have a merge method.") + raise NotImplementedError diff --git a/src/graphnet/data/writers/sqlite_writer.py b/src/graphnet/data/writers/sqlite_writer.py new file mode 100644 index 000000000..d7cc48297 --- /dev/null +++ b/src/graphnet/data/writers/sqlite_writer.py @@ -0,0 +1,225 @@ +"""Module containing `GraphNeTFileSaveMethod`(s). + +These modules are used to save the interim data format from `DataConverter` to +a deep-learning friendly file format. +""" + +import os +from tqdm import tqdm +from typing import List, Dict, Optional + +from graphnet.data.utilities import ( + create_table_and_save_to_sql, + get_primary_keys, + query_database, +) +import pandas as pd +from .graphnet_writer import GraphNeTWriter + + +class SQLiteWriter(GraphNeTWriter): + """A method for saving GraphNeT's interim dataformat to SQLite.""" + + def __init__( + self, + merged_database_name: str = "merged.db", + max_table_size: Optional[int] = None, + ) -> None: + """Initialize `SQLiteWriter`. + + Args: + merged_database_name: name of the database, not path, that files + will be merged into. Defaults to "merged.db". + max_table_size: The maximum number of rows in any given table. + If given, the merging proceedure splits the databases into + partitions each with a maximum table size of max_table_size. + Note that the size is approximate. This feature is useful if + you have many events, as tables exceeding + 400 million rows tend to be noticably slower to query. + Defaults to None (All events are put into a single database). + """ + # Member Variables + self._file_extension = ".db" + self._merge_dataframes = True + self._max_table_size = max_table_size + self._database_name = merged_database_name + + # Add file extension to database name if forgotten + if not self._database_name.endswith(self._file_extension): + self._database_name = self._database_name + self._file_extension + + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + def _save_file( + self, + data: Dict[str, pd.DataFrame], + output_file_path: str, + n_events: int, + ) -> None: + """Save data to SQLite database.""" + # Check(s) + if os.path.exists(output_file_path): + self.warning( + f"Output file {output_file_path} already exists. Appending." + ) + + # Concatenate data + if len(data) == 0: + self.warning( + "No data was extracted from the processed I3 file(s). " + f"No data saved to {output_file_path}" + ) + return + + saved_any = False + # Save each dataframe to SQLite database + self.debug(f"Saving to {output_file_path}") + for table, df in data.items(): + if len(df) > 0: + create_table_and_save_to_sql( + df, + table, + output_file_path, + default_type="FLOAT", + integer_primary_key=len(df) <= n_events, + ) + saved_any = True + + if saved_any: + self.debug("- Done saving") + else: + self.warning(f"No data saved to {output_file_path}") + + def merge_files( + self, + files: List[str], + output_dir: str, + ) -> None: + """SQLite-specific method for merging output files/databases. + + Args: + files: paths to SQLite databases that needs to be merged. + output_dir: path to store the merged database(s) in. + database_name: name, not path, of database. E.g. "my_database". + max_table_size: The maximum number of rows in any given table. + If given, the merging proceedure splits the databases into + partitions each with a maximum table size of max_table_size. + Note that the size is approximate. This feature is useful if + you have many events, as tables exceeding + 400 million rows tend to be noticably slower to query. + Defaults to None (All events are put into a single database.) + """ + # Warnings + if self._max_table_size: + self.warning( + f"Merging got max_table_size of {self._max_table_size}." + " Will attempt to create databases with a maximum row count of" + " this size." + ) + + # Set variables + self._partition_count = 1 + + # Construct full database path + database_path = os.path.join(output_dir, self._database_name) + print(database_path) + # Start merging if files are given + if len(files) > 0: + os.makedirs(output_dir, exist_ok=True) + self.info(f"Merging {len(files)} database files") + self._merge_databases(files=files, database_path=database_path) + else: + self.warning("No database files given! Exiting.") + + def _merge_databases( + self, + files: List[str], + database_path: str, + ) -> None: + """Merge the temporary databases. + + Args: + files: List of files to be merged. + database_path: Path to a database, can be an empty path, where the + databases listed in `files` will be merged into. If no database + exists at the given path, one will be created. + """ + if os.path.exists(database_path): + self.warning( + "Target path for merged database", + f"{database_path}, already exists.", + ) + + if self._max_table_size is not None: + database_path = self._adjust_output_path(database_path) + self._row_counts: Dict[str, int] = {} + self._largest_table = 0 + + # Merge temporary databases into newly created one + for file_count, input_file in tqdm(enumerate(files), colour="green"): + + # Extract table names and index column name in database + tables, primary_key = get_primary_keys(database=input_file) + + for table_name in tables.keys(): + # Extract all data in the table from the given database + df = query_database( + database=input_file, query=f"SELECT * FROM {table_name}" + ) + + # Infer whether the table was previously indexed with + # A primary key or not. len(tables[table]) = 0 if not. + integer_primary_key = ( + True if tables[table_name] is not None else False + ) + + # Submit to new database + create_table_and_save_to_sql( + df=df, + table_name=table_name, + database_path=database_path, + index_column=primary_key, + integer_primary_key=integer_primary_key, + ) + + # Update row counts if needed + if self._max_table_size is not None: + self._update_row_counts(df=df, table_name=table_name) + + if (self._max_table_size is not None) & (file_count < len(files)): + assert self._max_table_size is not None # mypy... + if self._largest_table >= self._max_table_size: + # Increment partition, reset counts, adjust output path + self._partition_count += 1 + self._row_counts = {} + self._largest_table = 0 + database_path = self._adjust_output_path(database_path) + self.info( + "Maximum row count reached." + f" Creating new partition at {database_path}" + ) + + # Internal methods + + def _adjust_output_path(self, output_file: str) -> str: + """Adjust the file path to reflect that it is a partition.""" + path_without_extension, extension = os.path.splitext(output_file) + if "_part_" in path_without_extension: + # if true, this is already a partition. + database_name = path_without_extension.split("_part_")[:-1][0] + else: + database_name = path_without_extension + # split into multiple lines to avoid one long + database_name = database_name + f"_part_{self._partition_count}" + database_name = database_name + extension + return database_name + + def _update_row_counts(self, df: pd.DataFrame, table_name: str) -> None: + if table_name in self._row_counts.keys(): + self._row_counts[table_name] += len(df) + else: + self._row_counts[table_name] = len(df) + + self._largest_table = max(self._row_counts.values()) + return diff --git a/src/graphnet/deployment/i3modules/graphnet_module.py b/src/graphnet/deployment/i3modules/graphnet_module.py index d3aa878e0..a385413b3 100644 --- a/src/graphnet/deployment/i3modules/graphnet_module.py +++ b/src/graphnet/deployment/i3modules/graphnet_module.py @@ -7,7 +7,7 @@ import torch from torch_geometric.data import Data, Batch -from graphnet.data.extractors import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractor, I3FeatureExtractorIceCubeUpgrade, ) @@ -70,7 +70,7 @@ def __init__( self._i3_extractors = [pulsemap_extractor] for i3_extractor in self._i3_extractors: - i3_extractor.set_files(i3_file="", gcd_file=self._gcd_file) + i3_extractor.set_gcd(i3_file="", gcd_file=self._gcd_file) @abstractmethod def __call__(self, frame: I3Frame) -> bool: diff --git a/src/graphnet/models/graphs/edges/minkowski.py b/src/graphnet/models/graphs/edges/minkowski.py index 5d1134ec5..2526de1cb 100644 --- a/src/graphnet/models/graphs/edges/minkowski.py +++ b/src/graphnet/models/graphs/edges/minkowski.py @@ -69,12 +69,13 @@ def _construct_edges(self, graph: Data) -> Data: row = [] col = [] for batch in range(x.shape[0]): + x_masked = x[batch][mask[batch]] distance_mat = compute_minkowski_distance_mat( - x_masked := x[batch][mask[batch]], - x_masked, - self.c, - self.space_coords, - self.time_coord, + x=x_masked, + y=x_masked, + c=self.c, + space_coords=self.space_coords, + time_coord=self.time_coord, ) num_points = x_masked.shape[0] num_edges = min(self.nb_nearest_neighbours, num_points) diff --git a/src/graphnet/pisa/fitting.py b/src/graphnet/pisa/fitting.py index dfcc20a37..5408f9bfc 100644 --- a/src/graphnet/pisa/fitting.py +++ b/src/graphnet/pisa/fitting.py @@ -23,7 +23,7 @@ from pisa.analysis.analysis import Analysis from pisa import ureg -from graphnet.data.sqlite import create_table_and_save_to_sql +from graphnet.data.utilities import create_table_and_save_to_sql mpl.use("pdf") plt.rc("font", family="serif") diff --git a/src/graphnet/training/weight_fitting.py b/src/graphnet/training/weight_fitting.py index a52c91b29..97411bbe5 100644 --- a/src/graphnet/training/weight_fitting.py +++ b/src/graphnet/training/weight_fitting.py @@ -7,7 +7,9 @@ import pandas as pd import sqlite3 -from graphnet.data.sqlite.sqlite_utilities import create_table_and_save_to_sql +from graphnet.data.utilities.sqlite_utilities import ( + create_table_and_save_to_sql, +) from graphnet.utilities.logging import Logger diff --git a/tests/data/test_dataconverters_and_datasets.py b/tests/data/test_dataconverters_and_datasets.py index 480f11d4d..e1d9e773b 100644 --- a/tests/data/test_dataconverters_and_datasets.py +++ b/tests/data/test_dataconverters_and_datasets.py @@ -11,7 +11,7 @@ from graphnet.constants import TEST_OUTPUT_DIR from graphnet.data.constants import FEATURES, TRUTH from graphnet.data.dataconverter import DataConverter -from graphnet.data.extractors import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractorIceCube86, I3TruthExtractor, I3RetroExtractor, @@ -19,7 +19,6 @@ from graphnet.data.parquet import ParquetDataConverter from graphnet.data.dataset import ParquetDataset, SQLiteDataset from graphnet.data.sqlite import SQLiteDataConverter -from graphnet.data.sqlite.sqlite_dataconverter import is_pulse_map from graphnet.data.utilities.parquet_to_sqlite import ParquetToSQLiteConverter from graphnet.utilities.imports import has_icecube_package from graphnet.models.graphs import KNNGraph @@ -52,17 +51,6 @@ def get_file_path(backend: str) -> str: return path -# Unit test(s) -def test_is_pulsemap_check() -> None: - """Test behaviour of `is_pulsemap_check`.""" - assert is_pulse_map("SplitInIcePulses") is True - assert is_pulse_map("SRTInIcePulses") is True - assert is_pulse_map("InIceDSTPulses") is True - assert is_pulse_map("RTTWOfflinePulses") is True - assert is_pulse_map("truth") is False - assert is_pulse_map("retro") is False - - @pytest.mark.order(1) @pytest.mark.parametrize("backend", ["sqlite", "parquet"]) def test_dataconverter( diff --git a/tests/data/test_i3extractor.py b/tests/data/test_i3extractor.py index 3fa19f078..ce40626c0 100644 --- a/tests/data/test_i3extractor.py +++ b/tests/data/test_i3extractor.py @@ -1,6 +1,6 @@ -"""Unit tests for I3Extractor class.""" +"""Unit tests for I3Extractor.""" -from graphnet.data.extractors import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractorIceCube86, I3TruthExtractor, I3RetroExtractor, diff --git a/tests/data/test_i3genericextractor.py b/tests/data/test_i3genericextractor.py deleted file mode 100644 index 314fa5f44..000000000 --- a/tests/data/test_i3genericextractor.py +++ /dev/null @@ -1,97 +0,0 @@ -"""Unit tests for I3GenericExtractor class.""" - -import os - -import numpy as np - -import graphnet.constants -from graphnet.data.extractors import ( - I3FeatureExtractorIceCube86, - I3TruthExtractor, - I3GenericExtractor, -) -from graphnet.utilities.imports import has_icecube_package - -if has_icecube_package(): - from icecube import dataio # pyright: reportMissingImports=false - -# Global variable(s) -TEST_DATA_DIR = os.path.join( - graphnet.constants.TEST_DATA_DIR, "i3", "oscNext_genie_level7_v02" -) -FILE_NAME = "oscNext_genie_level7_v02_first_5_frames" -GCD_FILE = ( - "GeoCalibDetectorStatus_AVG_55697-57531_PASS2_SPE_withScaledNoise.i3.gz" -) - - -# Unit test(s) -def test_i3genericextractor(test_data_dir: str = TEST_DATA_DIR) -> None: - """Test the implementation of `I3GenericExtractor`.""" - # Constants(s) - mc_tree = "I3MCTree" - pulse_series = "SRTInIcePulses" - - # Constructor I3Extractor instance(s) - generic_extractor = I3GenericExtractor(keys=[mc_tree, pulse_series]) - truth_extractor = I3TruthExtractor() - feature_extractor = I3FeatureExtractorIceCube86(pulse_series) - - i3_file = os.path.join(test_data_dir, FILE_NAME) + ".i3.gz" - gcd_file = os.path.join(test_data_dir, GCD_FILE) - - generic_extractor.set_files(i3_file, gcd_file) - truth_extractor.set_files(i3_file, gcd_file) - feature_extractor.set_files(i3_file, gcd_file) - - i3_file_io = dataio.I3File(i3_file, "r") - ix_test = 5 - while i3_file_io.more(): - try: - frame = i3_file_io.pop_physics() - except: # noqa: E722 - continue - - generic_data = generic_extractor(frame) - truth_data = truth_extractor(frame) - feature_data = feature_extractor(frame) - - if ix_test == 5: - print(list(generic_data[pulse_series].keys())) - print(list(truth_data.keys())) - print(list(feature_data.keys())) - - # Truth vs. generic - key_pairs = [ - ("energy", "energy"), - ("zenith", "dir__zenith"), - ("azimuth", "dir__azimuth"), - ("pid", "pdg_encoding"), - ] - - for truth_key, generic_key in key_pairs: - assert ( - truth_data[truth_key] - == generic_data[f"{mc_tree}__primaries"][generic_key][0] - ) - - # Reco vs. generic - key_pairs = [ - ("charge", "charge"), - ("dom_time", "time"), - ("dom_x", "position__x"), - ("dom_y", "position__y"), - ("dom_z", "position__z"), - ("width", "width"), - ("pmt_area", "area"), - ("rde", "relative_dom_eff"), - ] - - for reco_key, generic_key in key_pairs: - assert np.allclose( - feature_data[reco_key], generic_data[pulse_series][generic_key] - ) - - ix_test -= 1 - if ix_test == 0: - break diff --git a/tests/deployment/queso_test.py b/tests/deployment/queso_test.py index d1258ed89..5c0088f5d 100644 --- a/tests/deployment/queso_test.py +++ b/tests/deployment/queso_test.py @@ -8,7 +8,7 @@ import pytest from graphnet.data.constants import FEATURES -from graphnet.data.extractors.i3featureextractor import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractorIceCubeUpgrade, ) from graphnet.constants import (