
Commit

catch duplicate file names
RasmusOrsoe committed Sep 16, 2024
1 parent c0e0f5a commit 6cac966
Showing 6 changed files with 101 additions and 16 deletions.
67 changes: 60 additions & 7 deletions src/graphnet/data/dataconverter.py
@@ -1,6 +1,17 @@
"""Contains `DataConverter`."""
from typing import List, Union, OrderedDict, Dict, Tuple, Any, Optional, Type
from typing import (
List,
Union,
OrderedDict,
Dict,
DefaultDict,
Tuple,
Any,
Optional,
Type,
)
from abc import abstractmethod, ABC
from collections import defaultdict

from tqdm import tqdm
import numpy as np
@@ -98,20 +109,54 @@ def __call__(self, input_dir: Union[str, List[str]]) -> None:
# Get the file reader to produce a list of input files
# in the directory
input_files = self._file_reader.find_files(path=input_dir)
self._launch_jobs(input_files=input_files)
self._output_files = [
candidate_file_names = [
os.path.join(
self._output_dir,
self._create_file_name(file)
+ self._save_method.file_extension,
)
for file in input_files
]
output_files = self._rename_duplicates(candidate_file_names)
# file_map = {input_files[k]: output_files[k] for k in range(len(input_files))}
self._launch_jobs(
input_files=input_files, output_file_paths=output_files
)

def _rename_duplicates(self, files: List[str]) -> List[str]:
"""Return `files` with duplicate basenames given a numeric suffix."""
# Dictionary to track occurrences of each file
file_count: DefaultDict[str, int] = defaultdict(int)

# List to store updated file names
renamed_files = []

for file in files:
# Split the file into name and extension
name, extension = file.rsplit(".", 1)
file_name = os.path.basename(name) + f".{extension}"

# If the file name has been seen before, rename it with a numeric suffix
if file_count[file_name] > 0:
new_name = os.path.join(
os.path.dirname(file),
f"{file_name}_{file_count[file_name]}.{extension}",
)
else:
new_name = file

# Increment the count for this file name
file_count[file_name] += 1

# Add the new name to the renamed_files list
renamed_files.append(new_name)

return renamed_files
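
Aside: a minimal standalone sketch (not part of the commit; names and inputs are illustrative) of how the renaming above behaves when two candidate output paths share a basename:

import os
from collections import defaultdict
from typing import DefaultDict, List

def rename_duplicates(files: List[str]) -> List[str]:
    # Count how many times each output basename has been seen so far.
    file_count: DefaultDict[str, int] = defaultdict(int)
    renamed: List[str] = []
    for file in files:
        name, extension = file.rsplit(".", 1)
        file_name = os.path.basename(name) + f".{extension}"
        if file_count[file_name] > 0:
            # Second and later occurrences get a numeric suffix.
            new_name = os.path.join(
                os.path.dirname(file),
                f"{os.path.basename(name)}_{file_count[file_name]}.{extension}",
            )
        else:
            new_name = file
        file_count[file_name] += 1
        renamed.append(new_name)
    return renamed

print(rename_duplicates(["out/run1.parquet", "out/run1.parquet"]))
# ['out/run1.parquet', 'out/run1_1.parquet']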

@final
def _launch_jobs(
self,
input_files: Union[List[str], List[I3FileSet]],
output_file_paths: List[str],
) -> None:
"""Multi Processing Logic.
@@ -128,20 +173,26 @@ def _launch_jobs(
# Iterate over files
for _ in map_fn(
self._process_file,
tqdm(input_files, unit=" file(s)", colour="green"),
tqdm(
zip(input_files, output_file_paths),
unit=" file(s)",
colour="green",
total=len(input_files),
),
):
self.debug("processing file.")
self._update_shared_variables(pool)

@final
def _process_file(self, file_path: Union[str, I3FileSet]) -> None:
def _process_file(self, args: Tuple[Union[str, I3FileSet], str]) -> None:
"""Process a single file.
Calls the file reader to receive the extracted output; event ids
are assigned to the extracted data before it is handed to the save method.
This function is called in parallel.
"""
file_path, output_file_path = args
# Read and apply extractors
data = self._file_reader(file_path=file_path)

@@ -169,12 +220,14 @@ def _process_file(self, file_path: Union[str, I3FileSet]) -> None:
del data

# Create output file name
output_file_name = self._create_file_name(input_file_path=file_path)

# output_file_name = self._output_files_map[file_path]
# output_file_name = self._create_file_name(input_file_path=file_path)

# Apply save method
self._save_method(
data=dataframes,
file_name=output_file_name,
file_name=output_file_path,
n_events=n_events,
output_dir=self._output_dir,
)
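
Aside: a rough sketch (illustrative names, not the class above) of the calling pattern this diff introduces, where each worker receives the input file and its pre-computed output path as one tuple:

from multiprocessing import Pool
from typing import List, Tuple

def process_file(args: Tuple[str, str]) -> None:
    # Unpack the (input file, output path) pair inside the worker.
    input_file, output_file_path = args
    print(f"reading {input_file} -> writing {output_file_path}")

def launch_jobs(input_files: List[str], output_file_paths: List[str]) -> None:
    # zip keeps each input paired with its (possibly de-duplicated) output name.
    with Pool(processes=2) as pool:
        for _ in pool.imap(process_file, zip(input_files, output_file_paths)):
            pass

if __name__ == "__main__":
    launch_jobs(
        ["a/run1.parquet", "b/run1.parquet"],
        ["out/run1.parquet", "out/run1_1.parquet"],
    )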
@@ -4,7 +4,7 @@
import numpy as np

from graphnet.data.extractors import Extractor
from .utilities import compute_visible_inelasticity
from .utilities import compute_visible_inelasticity, get_muon_direction


class PrometheusExtractor(Extractor):
@@ -85,14 +85,18 @@ def __call__(self, event: pd.DataFrame) -> pd.DataFrame:
"""Extract event-level truth information."""
# Extract data
visible_inelasticity = compute_visible_inelasticity(event)
muon_zenith, muon_azimuth = get_muon_direction(event)
res = super().__call__(event=event)
# transform azimuth from [-pi, pi] to [0, 2pi] if requested
if self._transform_az:
if len(res["initial_state_azimuth"]) > 0:
azimuth = np.asarray(res["initial_state_azimuth"]) + np.pi
azimuth = azimuth.tolist() # back to list
res["initial_state_azimuth"] = azimuth
muon_azimuth += np.pi
res["visible_inelasticity"] = [visible_inelasticity]
res["muon_azimuth"] = [muon_azimuth]
res["muon_zenith"] = [muon_zenith]
return res
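
Aside: a small check (assuming azimuth is given in radians in [-pi, pi], as the comment above states) that adding pi maps the interval onto [0, 2pi]:

import numpy as np

azimuth = np.array([-np.pi, -np.pi / 2, 0.0, np.pi])  # original range [-pi, pi]
shifted = azimuth + np.pi  # now in [0, 2pi]
print(shifted)  # [0.         1.57079633 3.14159265 6.28318531]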


6 changes: 3 additions & 3 deletions src/graphnet/data/extractors/prometheus/pulsemap_simulator.py
@@ -141,7 +141,7 @@ def __call__(self, event: pd.DataFrame) -> pd.DataFrame:
photons = super().__call__(event=event)

# Create empty variables - these will be returned if needed
features = self._columns + ["charge", "is_signal"]
features = self._columns + ["is_signal", "charge"]
pulses: Dict[str, List] = {feature: [] for feature in features}

# Return empty if not enough signal
@@ -195,7 +195,7 @@ def __call__(self, event: pd.DataFrame) -> pd.DataFrame:
x=pulses["charge"], std=self._charge_std
)
)
return pulses
return {key: pulses[key] for key in features}
else:
return self._make_empty_return()
else:
@@ -204,7 +204,7 @@ def __call__(self, event: pd.DataFrame) -> pd.DataFrame:
return self._make_empty_return()

def _make_empty_return(self) -> Dict[str, List]:
features = self._columns + ["charge", "is_signal"]
features = self._columns + ["is_signal", "charge"]
pulses: Dict[str, List] = {feature: [] for feature in features}
return pulses
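
Aside: one effect of returning the comprehension over features above (a sketch with illustrative values) is that the keys of the returned dict always follow the order of features, regardless of the order in which pulses was filled:

features = ["t", "is_signal", "charge"]
pulses = {"charge": [1.2], "t": [0.5], "is_signal": [1]}

# Rebuilding the dict in feature order fixes the column order of the output.
ordered = {key: pulses[key] for key in features}
print(list(ordered))  # ['t', 'is_signal', 'charge']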

22 changes: 21 additions & 1 deletion src/graphnet/data/extractors/prometheus/utilities.py
@@ -1,6 +1,6 @@
"""A series of utility functions for extraction of data from Prometheus."""

from typing import Dict, Any
from typing import Dict, Any, Tuple
import pandas as pd
from abc import abstractmethod
import numpy as np
@@ -60,6 +60,26 @@ def compute_visible_inelasticity(mc_truth: pd.DataFrame) -> float:
return visible_inelasticity


def get_muon_direction(mc_truth: pd.DataFrame) -> Tuple[float, float]:
"""Get angles of muon in nu_mu CC events."""
final_type_1, final_type_2 = abs(mc_truth["final_state_type"])
if mc_truth["interaction"] != 1:
muon_zenith = -1
muon_azimuth = -1
elif not (final_type_1 == 13 or final_type_2 == 13):
muon_zenith = -1
muon_azimuth = -1
else:
# CC only
muon_zenith = mc_truth["final_state_zenith"][
abs(mc_truth["final_state_type"]) == 13
][0]
muon_azimuth = mc_truth["final_state_azimuth"][
abs(mc_truth["final_state_type"]) == 13
][0]
return muon_zenith, muon_azimuth
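
Aside: a hedged usage sketch of get_muon_direction (the event record below is made up; a plain dict of array-valued fields stands in for the pandas row):

import numpy as np
from graphnet.data.extractors.prometheus.utilities import get_muon_direction

mc_truth = {
    "interaction": 1,  # charged current, per the "CC only" branch above
    "final_state_type": np.array([13, 2212]),  # muon + proton (illustrative PDG codes)
    "final_state_zenith": np.array([1.2, 0.4]),
    "final_state_azimuth": np.array([-2.7, 0.1]),
}
zenith, azimuth = get_muon_direction(mc_truth)  # type: ignore[arg-type]
print(zenith, azimuth)  # 1.2 -2.7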


class PrometheusFilter(Logger):
"""Generic Filter Class for PrometheusReader."""

2 changes: 2 additions & 0 deletions src/graphnet/data/readers/internal_parquet_reader.py
@@ -4,6 +4,7 @@
from glob import glob
import os
import pandas as pd
import random

from graphnet.data.extractors.internal import ParquetExtractor
from .graphnet_file_reader import GraphNeTFileReader
@@ -52,4 +53,5 @@ def find_files(self, path: Union[str, List[str]]) -> List[str]:
os.path.join(p, extractor._extractor_name, "*.parquet")
)
)
random.shuffle(files)
return files
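
Aside: random.shuffle here uses the module-level RNG, so the file order differs between runs; a seeded variant (a sketch, not part of the commit) would make it reproducible:

import random

files = ["f0.parquet", "f1.parquet", "f2.parquet", "f3.parquet"]
random.Random(42).shuffle(files)  # fixed seed -> same order on every run
print(files)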
14 changes: 10 additions & 4 deletions src/graphnet/data/writers/parquet_writer.py
@@ -52,13 +52,19 @@ def _save_file(
file_name = os.path.splitext(
os.path.basename(output_file_path)
)[0]

table_dir = os.path.join(save_path, f"{table}")
output_path_new = os.path.join(
table_dir, file_name + f"_{table}.parquet"
)
os.makedirs(table_dir, exist_ok=True)
df = data[table].set_index(self._index_column)
df.to_parquet(
os.path.join(table_dir, file_name + f"_{table}.parquet")
)
if os.path.isfile(output_path_new):
self.warning(
f"{os.path.basename(output_path_new)}"
"already exists! Will be overwritten!"
)

df.to_parquet(output_path_new)
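
Aside: a small sketch (paths and table name illustrative) of the per-table layout written above, <save_path>/<table>/<file_name>_<table>.parquet, and how one table could be read back:

import os
import pandas as pd

save_path = "out"
table = "truth"     # illustrative table name
file_name = "run1"  # stem of the converted file

path = os.path.join(save_path, table, f"{file_name}_{table}.parquet")
if os.path.isfile(path):
    df = pd.read_parquet(path)  # index column was set via set_index before writing
    print(df.head())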

def merge_files(
self,
