Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: Enabling training on MCPE pulses #29

Open
wants to merge 29 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
e7a5ed9
Simplify code
fschlueter Mar 29, 2024
cb05b28
Remove hard-coded label_key LabelsDeepLearning. Allow to train on MCPEs.
fschlueter Mar 29, 2024
009d351
Fixed setting charge index. Slightly refactor code to avoid future wa…
fschlueter Apr 1, 2024
bd3a139
Add new config file for training on MCPEs
fschlueter Apr 4, 2024
0dc9634
Keep training and testing data independent
fschlueter Apr 4, 2024
871c5c8
Allow to specify parameter names in config
fschlueter Apr 4, 2024
cc2a8da
Propagate new argument correctly
fschlueter Apr 4, 2024
37abda5
Update config for MCPE training
fschlueter Apr 9, 2024
10fadf1
Add missing config
fschlueter Apr 9, 2024
4f4bc0c
Set number of iteration to 1M
fschlueter Apr 9, 2024
1d2f440
Update MCPE config
fschlueter Apr 11, 2024
382fd55
Do not run validation in the first iteration
fschlueter Apr 11, 2024
8d51035
Add function which evaluates pdf only for one DOM
fschlueter Apr 15, 2024
e0863ac
Add module to evaluate mcpe eg models
fschlueter Apr 15, 2024
3229861
Clean up. Add correct time call to photonics function
fschlueter Apr 16, 2024
252a91f
Add cdf_per_dom function and get_probability_quantiles_per_dom
fschlueter Apr 16, 2024
b85d0ee
Little clean up
fschlueter Apr 16, 2024
8d62758
Store total charge per dom as well. Improve filename for figures
fschlueter Apr 17, 2024
16f6cbc
Add function which only calculates total charge per dom
fschlueter Apr 18, 2024
ea00a1d
fix ws
fschlueter May 29, 2024
f1fb9e8
Remove hard-coded label_key LabelsDeepLearning. Allow to train on MCPEs.
fschlueter Mar 29, 2024
0c71958
Fixed setting charge index. Slightly refactor code to avoid future wa…
fschlueter Apr 1, 2024
bc591ae
Add new config file for training on MCPEs
fschlueter Apr 4, 2024
84a74bd
Export to hdf5 rather than i3
fschlueter Jun 3, 2024
22392b8
Merge branch 'master' into enable_mcpe_training
fschlueter Jun 3, 2024
d837934
Allow evaluation in eager mode
mhuen Jun 3, 2024
de144f8
Code refactoring
fschlueter Jun 3, 2024
c7f9187
Merge branch 'FixMultiLearningRateScheduler' into enable_mcpe_training
fschlueter Jun 3, 2024
dec3eb0
Improve code. access charge from hdf files by name and not index
Jun 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
426 changes: 426 additions & 0 deletions configs/cascade_7param_MCPE.yaml

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions egenerator/data/handler/modular.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,10 @@ def _get_data(self, file_or_frame, method, *args, **kwargs):
assert isinstance(file_or_frame, str), "Expected file path string"

num_data, data = self.data_module.get_data_from_hdf(
file_or_frame, *args, **kwargs
)
file_or_frame, *args,
label_key=self.label_module.configuration.config["label_key"],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this needs to remain in the kwargs. Modules are not forced to specify/use a label_key. But I also have a slightly different implementation in the CollectBreakingChanges branch.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not quite sure that I fully understand why it has to remain in kwargs.

**kwargs)

num_labels, labels = self.label_module.get_data_from_hdf(
file_or_frame, *args, **kwargs
)
Expand Down
32 changes: 20 additions & 12 deletions egenerator/data/modules/data/pulse_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def _configure(
float_precision,
add_charge_quantiles,
discard_pulses_from_excluded_doms,
pulse_is_mcpe=False,
):
"""Configure Module Class
This is an abstract method and must be implemented by derived class.
Expand Down Expand Up @@ -72,6 +73,9 @@ def _configure(
discard_pulses_from_excluded_doms : bool, optional
If True, pulses on excluded DOMs are discarded. The pulses are
discarded after the charge at the DOM is collected.
pulse_is_mcpe : bool, optional
If True, train on MCPE pulses instead of RecoPulses. This setting
changes how the charge is read out of the hdf5 files. (Default: False)

Returns
-------
Expand Down Expand Up @@ -214,6 +218,7 @@ def _configure(
),
mutable_settings=dict(
pulse_key=pulse_key,
pulse_is_mcpe=pulse_is_mcpe,
dom_exclusions_key=dom_exclusions_key,
time_exclusions_key=time_exclusions_key,
discard_pulses_from_excluded_doms=(
Expand Down Expand Up @@ -247,12 +252,13 @@ def get_data_from_hdf(self, file, *args, **kwargs):
if not self.is_configured:
raise ValueError("Module not configured yet!")

# open file
f = pd.HDFStore(file, "r")
charge_str = "npe" if self.configuration.config["pulse_is_mcpe"] else "charge"

# open file
f = pd.HDFStore(file, 'r')
try:
pulses = f[self.configuration.config["pulse_key"]]
_labels = f["LabelsDeepLearning"]
_labels = f[kwargs["label_key"]]
if self.data["dom_exclusions_exist"]:
try:
dom_exclusions = f[
Expand Down Expand Up @@ -293,11 +299,13 @@ def get_data_from_hdf(self, file, *args, **kwargs):

# create Dictionary with event IDs
size = len(_labels["Event"])

if not size:
raise ValueError("Label length is 0.")

event_dict = {}
for idx, row in _labels.iterrows():
event_dict[
(row.iloc[0], row.iloc[1], row.iloc[2], row.iloc[3])
] = idx
for row in _labels.itertuples():
event_dict[(row[1:5])] = row[0]

# create empty array for DOM charges
x_dom_charge = np.zeros(
Expand Down Expand Up @@ -351,21 +359,22 @@ def get_data_from_hdf(self, file, *args, **kwargs):
"skipping pulse: {} {}".format(string, dom)
)
continue

index = event_dict[(row[1:5])]

# accumulate charge in DOMs
x_dom_charge[index, string - 1, dom - 1, 0] += row.charge
x_dom_charge[index, string - 1, dom - 1, 0] += getattr(row, charge_str)

# gather pulses
if add_charge_quantiles:

# (charge, time, quantile)
cum_charge = float(x_dom_charge[index, string - 1, dom - 1, 0])
x_pulses[pulse_index] = [row.charge, row.time, cum_charge]
x_pulses[pulse_index] = [getattr(row, charge_str), row.time, cum_charge]

else:
# (charge, time)
x_pulses[pulse_index] = [row.charge, row.time]
x_pulses[pulse_index] = [getattr(row, charge_str), row.time]

# gather pulse ids (batch index, string, dom)
x_pulses_ids[pulse_index] = [index, string - 1, dom - 1]
Expand Down Expand Up @@ -408,7 +417,7 @@ def get_data_from_hdf(self, file, *args, **kwargs):
continue
index = event_dict[(row[1:5])]

# t_start (pulse time): row[10], t_end (pulse width): row[11]
# t_start (pulse time): row.time, t_end (pulse width): row[11]

# (t_start, t_end)
x_time_exclusions[tw_index] = [row.time, row.width]
Expand Down Expand Up @@ -577,7 +586,6 @@ def get_data_from_frame(self, frame, *args, **kwargs):
for pulse in pulse_list:
index = 0

# pulse charge: row[12], time: row[10]
# accumulate charge in DOMs
x_dom_charge[index, string - 1, dom - 1, 0] += pulse.charge

Expand Down
34 changes: 12 additions & 22 deletions egenerator/data/modules/labels/cascades.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def _configure(
trafo_log,
float_precision,
label_key="LabelsDeepLearning",
parameter_names=["cascade_x", "cascade_y", "cascade_z",
"cascade_zenith", "cascade_azimuth",
"cascade_energy", "cascade_t"],
):
"""Configure Module Class
This is an abstract method and must be implemented by derived class.
Expand All @@ -51,10 +54,12 @@ def _configure(
If a single bool is given, this applies to all labels. Otherwise
a list of bools corresponds to the labels in the order:
x, y, z, zenith, azimuth, energy, time
label_key : str, optional
The name of the key under which the labels are saved.
float_precision : str
The float precision as a str.
label_key : str, optional
The name of the key under which the labels are saved.
parameter_names : list of str, optional
Name of the parameters (e.g, the columns in the hdf5 Dataset `label_key`)

Returns
-------
Expand Down Expand Up @@ -129,6 +134,7 @@ def _configure(
trafo_log=trafo_log,
float_precision=float_precision,
label_key=label_key,
parameter_names=parameter_names,
),
)
return configuration, data, {}
Expand Down Expand Up @@ -163,16 +169,8 @@ def get_data_from_hdf(self, file, *args, **kwargs):
cascade_parameters = []
try:
_labels = f[self.configuration.config["label_key"]]
for label in [
"cascade_x",
"cascade_y",
"cascade_z",
"cascade_zenith",
"cascade_azimuth",
"cascade_energy",
"cascade_t",
]:
cascade_parameters.append(_labels[label])
for par in self.configuration.config["parameter_names"]:
cascade_parameters.append(_labels[par])

except Exception as e:
self._logger.warning(e)
Expand Down Expand Up @@ -223,16 +221,8 @@ def get_data_from_frame(self, frame, *args, **kwargs):
cascade_parameters = []
try:
_labels = frame[self.configuration.config["label_key"]]
for label in [
"cascade_x",
"cascade_y",
"cascade_z",
"cascade_zenith",
"cascade_azimuth",
"cascade_energy",
"cascade_t",
]:
cascade_parameters.append(np.atleast_1d(_labels[label]))
for par in self.configuration.config["parameter_names"]:
cascade_parameters.append(np.atleast_1d(_labels[par]))

except Exception as e:
self._logger.warning(e)
Expand Down
2 changes: 2 additions & 0 deletions egenerator/ic3/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from egenerator.ic3.reconstruction import EventGeneratorReconstruction
from egenerator.ic3.simulation import EventGeneratorSimulation
from egenerator.ic3.evaluate_mcpe import CalculateLikelihood

__all__ = [
"EventGeneratorReconstruction",
"EventGeneratorSimulation",
"CalculateLikelihood",
]
Loading
Loading