diff --git a/src/graphnet/data/dataconverter.py b/src/graphnet/data/dataconverter.py
index bf57fac89..a7ce98964 100644
--- a/src/graphnet/data/dataconverter.py
+++ b/src/graphnet/data/dataconverter.py
@@ -1,6 +1,17 @@
 """Contains `DataConverter`."""
 
-from typing import List, Union, OrderedDict, Dict, Tuple, Any, Optional, Type
+from typing import (
+    List,
+    Union,
+    OrderedDict,
+    Dict,
+    DefaultDict,
+    Tuple,
+    Any,
+    Optional,
+    Type,
+)
 from abc import abstractmethod, ABC
+from collections import defaultdict
 from tqdm import tqdm
 import numpy as np
@@ -98,8 +109,7 @@ def __call__(self, input_dir: Union[str, List[str]]) -> None:
         # Get the file reader to produce a list of input files
         # in the directory
         input_files = self._file_reader.find_files(path=input_dir)
-        self._launch_jobs(input_files=input_files)
-        self._output_files = [
+        candidate_file_names = [
             os.path.join(
                 self._output_dir,
                 self._create_file_name(file)
@@ -107,11 +117,46 @@ def __call__(self, input_dir: Union[str, List[str]]) -> None:
             )
             for file in input_files
         ]
+        output_files = self._rename_duplicates(candidate_file_names)
+        # file_map = {input_files[k]: output_files[k] for k in range(len(input_files))}
+        self._launch_jobs(
+            input_files=input_files, output_file_paths=output_files
+        )
+
+    def _rename_duplicates(self, files: List[str]) -> List[str]:
+        # Dictionary to track occurrences of each file name
+        file_count: DefaultDict[str, int] = defaultdict(int)
+
+        # List to store updated file names
+        renamed_files = []
+
+        for file in files:
+            # Split the file into name and extension
+            name, extension = file.rsplit(".", 1)
+            file_name = os.path.basename(name) + f".{extension}"
+
+            # If this file name has been encountered before, rename it
+            if file_count[file_name] > 0:
+                new_name = os.path.join(
+                    os.path.dirname(file),
+                    f"{os.path.basename(name)}_{file_count[file_name]}.{extension}",
+                )
+            else:
+                new_name = file
+
+            # Increment the occurrence count for this file name
+            file_count[file_name] += 1
+
+            # Add the new name to the renamed_files list
+            renamed_files.append(new_name)
+
+        return renamed_files
 
     @final
     def _launch_jobs(
         self,
         input_files: Union[List[str], List[I3FileSet]],
+        output_file_paths: List[str],
     ) -> None:
         """Multi Processing Logic.
 
@@ -128,13 +173,18 @@
             # Iterate over files
             for _ in map_fn(
                 self._process_file,
-                tqdm(input_files, unit=" file(s)", colour="green"),
+                tqdm(
+                    zip(input_files, output_file_paths),
+                    unit=" file(s)",
+                    colour="green",
+                    total=len(input_files),
+                ),
             ):
                 self.debug("processing file.")
             self._update_shared_variables(pool)
 
     @final
-    def _process_file(self, file_path: Union[str, I3FileSet]) -> None:
+    def _process_file(self, args: Tuple[Union[str, I3FileSet], str]) -> None:
         """Process a single file.
 
         Calls file reader to recieve extracted output, event ids
@@ -142,6 +192,7 @@ def _process_file(self, file_path: Union[str, I3FileSet]) -> None:
         This function is called in parallel.
""" + file_path, output_file_path = args # Read and apply extractors data = self._file_reader(file_path=file_path) @@ -169,12 +220,14 @@ def _process_file(self, file_path: Union[str, I3FileSet]) -> None: del data # Create output file name - output_file_name = self._create_file_name(input_file_path=file_path) + + # output_file_name = self._output_files_map[file_path] + # output_file_name = self._create_file_name(input_file_path=file_path) # Apply save method self._save_method( data=dataframes, - file_name=output_file_name, + file_name=output_file_path, n_events=n_events, output_dir=self._output_dir, ) diff --git a/src/graphnet/data/extractors/prometheus/prometheus_extractor.py b/src/graphnet/data/extractors/prometheus/prometheus_extractor.py index 3d7fc3df1..d32d8eea0 100644 --- a/src/graphnet/data/extractors/prometheus/prometheus_extractor.py +++ b/src/graphnet/data/extractors/prometheus/prometheus_extractor.py @@ -4,7 +4,7 @@ import numpy as np from graphnet.data.extractors import Extractor -from .utilities import compute_visible_inelasticity +from .utilities import compute_visible_inelasticity, get_muon_direction class PrometheusExtractor(Extractor): @@ -85,6 +85,7 @@ def __call__(self, event: pd.DataFrame) -> pd.DataFrame: """Extract event-level truth information.""" # Extract data visible_inelasticity = compute_visible_inelasticity(event) + muon_zenith, muon_azimuth = get_muon_direction(event) res = super().__call__(event=event) # transform azimuth from [-pi, pi] to [0, 2pi] if wanted if self._transform_az: @@ -92,7 +93,10 @@ def __call__(self, event: pd.DataFrame) -> pd.DataFrame: azimuth = np.asarray(res["initial_state_azimuth"]) + np.pi azimuth = azimuth.tolist() # back to list res["initial_state_azimuth"] = azimuth + muon_azimuth += np.pi res["visible_inelasticity"] = [visible_inelasticity] + res["muon_azimuth"] = [muon_azimuth] + res["muon_zenith"] = [muon_zenith] return res diff --git a/src/graphnet/data/extractors/prometheus/pulsemap_simulator.py b/src/graphnet/data/extractors/prometheus/pulsemap_simulator.py index cda1ebd66..d3080a1b4 100644 --- a/src/graphnet/data/extractors/prometheus/pulsemap_simulator.py +++ b/src/graphnet/data/extractors/prometheus/pulsemap_simulator.py @@ -141,7 +141,7 @@ def __call__(self, event: pd.DataFrame) -> pd.DataFrame: photons = super().__call__(event=event) # Create empty variables - these will be returned if needed - features = self._columns + ["charge", "is_signal"] + features = self._columns + ["is_signal", "charge"] pulses: Dict[str, List] = {feature: [] for feature in features} # Return empty if not enough signal @@ -195,7 +195,7 @@ def __call__(self, event: pd.DataFrame) -> pd.DataFrame: x=pulses["charge"], std=self._charge_std ) ) - return pulses + return {key: pulses[key] for key in features} else: return self._make_empty_return() else: @@ -204,7 +204,7 @@ def __call__(self, event: pd.DataFrame) -> pd.DataFrame: return self._make_empty_return() def _make_empty_return(self) -> Dict[str, List]: - features = self._columns + ["charge", "is_signal"] + features = self._columns + ["is_signal", "charge"] pulses: Dict[str, List] = {feature: [] for feature in features} return pulses diff --git a/src/graphnet/data/extractors/prometheus/utilities.py b/src/graphnet/data/extractors/prometheus/utilities.py index 2a0324314..74cda9d42 100644 --- a/src/graphnet/data/extractors/prometheus/utilities.py +++ b/src/graphnet/data/extractors/prometheus/utilities.py @@ -1,6 +1,6 @@ """A series of utility functions for extraction of data from Prometheus.""" -from 
+from typing import Dict, Any, Tuple
 import pandas as pd
 from abc import abstractmethod
 import numpy as np
@@ -60,6 +60,26 @@ def compute_visible_inelasticity(mc_truth: pd.DataFrame) -> float:
     return visible_inelasticity
 
 
+def get_muon_direction(mc_truth: pd.DataFrame) -> Tuple[float, float]:
+    """Get angles of muon in nu_mu CC events."""
+    final_type_1, final_type_2 = abs(mc_truth["final_state_type"])
+    if mc_truth["interaction"] != 1:
+        muon_zenith = -1
+        muon_azimuth = -1
+    elif not (final_type_1 == 13 or final_type_2 == 13):
+        muon_zenith = -1
+        muon_azimuth = -1
+    else:
+        # CC only
+        muon_zenith = mc_truth["final_state_zenith"][
+            abs(mc_truth["final_state_type"]) == 13
+        ][0]
+        muon_azimuth = mc_truth["final_state_azimuth"][
+            abs(mc_truth["final_state_type"]) == 13
+        ][0]
+    return muon_zenith, muon_azimuth
+
+
 class PrometheusFilter(Logger):
     """Generic Filter Class for PrometheusReader."""
 
diff --git a/src/graphnet/data/readers/internal_parquet_reader.py b/src/graphnet/data/readers/internal_parquet_reader.py
index 27e08338d..dcd6335fc 100644
--- a/src/graphnet/data/readers/internal_parquet_reader.py
+++ b/src/graphnet/data/readers/internal_parquet_reader.py
@@ -4,6 +4,7 @@
 from glob import glob
 import os
 import pandas as pd
+import random
 
 from graphnet.data.extractors.internal import ParquetExtractor
 from .graphnet_file_reader import GraphNeTFileReader
@@ -52,4 +53,5 @@ def find_files(self, path: Union[str, List[str]]) -> List[str]:
                     os.path.join(p, extractor._extractor_name, "*.parquet")
                 )
             )
+        random.shuffle(files)
         return files
diff --git a/src/graphnet/data/writers/parquet_writer.py b/src/graphnet/data/writers/parquet_writer.py
index 721ea4dc2..df9e7b4ad 100644
--- a/src/graphnet/data/writers/parquet_writer.py
+++ b/src/graphnet/data/writers/parquet_writer.py
@@ -52,13 +52,20 @@ def _save_file(
             file_name = os.path.splitext(
                 os.path.basename(output_file_path)
             )[0]
             table_dir = os.path.join(save_path, f"{table}")
+            output_path_new = os.path.join(
+                table_dir, file_name + f"_{table}.parquet"
+            )
             os.makedirs(table_dir, exist_ok=True)
             df = data[table].set_index(self._index_column)
-            df.to_parquet(
-                os.path.join(table_dir, file_name + f"_{table}.parquet")
-            )
+            if os.path.isfile(output_path_new):
+                self.warning(
+                    f"{os.path.basename(output_path_new)}"
+                    " already exists! Will be overwritten!"
+                )
+
+            df.to_parquet(output_path_new)
 
     def merge_files(
         self,