add DOM summarization
Aske-Rosted committed Aug 29, 2024
1 parent db34941 commit 019bb3f
Showing 1 changed file with 165 additions and 7 deletions.
172 changes: 165 additions & 7 deletions src/graphnet/models/graphs/nodes/nodes.py
@@ -233,6 +233,9 @@ def __init__(
time_column: str = "dom_time",
charge_column: str = "charge",
max_activations: Optional[int] = None,
percentiles: List[int] = [],
time_sums: List[int] = [],
counts_sum_std: List[bool] = [True, True, True],
) -> None:
"""Construct `NodeAsDOMTimeSeries`.
@@ -242,9 +245,12 @@ def __init__(
time_column: Name of time column.
charge_column: Name of charge column.
max_activations: Maximum number of activations to include in the time series.
percentiles: Charge percentiles at which to record the pulse time and cumulative charge when summarizing each DOM.
time_sums: Time windows, relative to the first pulse on each DOM, at which to record the elapsed time and accumulated charge.
counts_sum_std: Three booleans indicating whether to include the pulse count, the charge-weighted mean time and the standard deviation of the charge-weighted times in the summarization features.
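Example: A hypothetical configuration summarizing each DOM at three charge percentiles and two time windows (all other arguments keep their defaults; the values are purely illustrative):
NodeAsDOMTimeSeries(
percentiles=[20, 50, 80],
time_sums=[100, 500],
counts_sum_std=[True, True, True],
)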
"""
self._keys = keys
super().__init__(input_feature_names=self._keys)

self._id_columns = [self._keys.index(key) for key in id_columns]
self._time_index = self._keys.index(time_column)
try:
@@ -259,11 +265,52 @@ def __init__(
self._charge_index = None

self._max_activations = max_activations
self._percentiles = percentiles
self._counts_sum_std = counts_sum_std
# ensure that max_activations is not set if percentiles are set and vice versa
assert not (
max_activations is not None and len(percentiles) > 0 # type: ignore
), "Cannot set both max_activations and percentiles"

self._time_sums = time_sums
super().__init__(input_feature_names=self._keys)
self._time_features = [
"time" in key for key in self._output_feature_names
]
self._charge_features = [
"charge" in key for key in self._output_feature_names
]

def _define_output_feature_names(
self, input_feature_names: List[str]
) -> List[str]:
return input_feature_names + ["new_node_col"]
output_feature_names = deepcopy(input_feature_names)
if len(self._percentiles) > 0:

output_feature_names[self._time_index] = "first_time"
if self._charge_index is not None:
output_feature_names[self._charge_index] = "sum_charge"
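# Append the optional summary-feature names selected by the `counts_sum_std` mask.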
output_feature_names += list(
np.array(["counts", "mean_time", "std_time"])[
self._counts_sum_std
]
)
time = []
charge = []
for pct in self._percentiles:
time.append(f"time_pct{pct}")
charge.append(f"charge_pct{pct}")
output_feature_names += time
output_feature_names += charge
if len(self._time_sums) > 0:
time = []
charge = []
for time_sum in self._time_sums:
time.append(f"time_sum{time_sum}")
charge.append(f"charge_sum{time_sum}")
output_feature_names += time
output_feature_names += charge
return output_feature_names

def _construct_nodes(self, x: torch.Tensor) -> Data:
"""Construct nodes from raw node features ´x´."""
@@ -280,11 +327,6 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:

# Sort by time
x = x[x[:, self._time_index].argsort()]
# Undo log10 scaling so we can sum charges
x[:, charge_index] = np.power(10, x[:, charge_index])
# Shift time to start at 0
x[:, self._time_index] -= np.min(x[:, self._time_index])
# Group pulses on the same DOM
x = lex_sort(x, self._id_columns)

unique_sensors, counts = np.unique(
@@ -296,7 +338,117 @@
)
sort_this = lex_sort(x=sort_this, cluster_columns=self._id_columns)
unique_sensors = sort_this[:, 0 : unique_sensors.shape[1]]

counts = sort_this[:, unique_sensors.shape[1] :].flatten().astype(int)
counts_sum = counts.sum()

if len(self._percentiles) > 0:
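# Summarize each DOM into a single row laid out as
# [id columns..., first_time, sum_charge, optional (counts, mean_time, std_time)].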
new_x = np.zeros(
(
len(counts),
len(self._id_columns) + 2 + sum(self._counts_sum_std),
)
)
new_x[:, self._id_columns] = unique_sensors
split_x = torch.split(torch.tensor(x), counts.tolist())
split_x = torch.nn.utils.rnn.pad_sequence(
split_x, batch_first=True, padding_value=np.nan
).numpy()
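# Pad the per-DOM pulse series with NaNs so the nan-aware reductions below ignore the padding.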
weights = split_x[:, :, charge_index] / np.nansum(
split_x[:, :, charge_index], axis=1
).reshape(-1, 1)
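# `weights` is each pulse's fraction of its DOM's total charge; the charge column is then turned into a cumulative fraction, with the (log10) total stored as `sum_charge`.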
split_x[:, :, charge_index] = split_x[:, :, charge_index].cumsum(
axis=1
)
new_x[:, charge_index] = np.nanmax(
split_x[:, :, charge_index], axis=1
)
split_x[:, :, charge_index] = split_x[:, :, charge_index] / new_x[
:, charge_index
].reshape(-1, 1)
new_x[:, charge_index] = np.log10(new_x[:, charge_index])
new_x[:, self._time_index] = np.nanmin(
split_x[:, :, self._time_index], axis=1
)
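# Optional summary features: pulse count plus mean and standard deviation of the charge-weighted pulse times.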
if self._counts_sum_std[0]:
new_x[:, -sum(self._counts_sum_std)] = np.log10(counts)
if self._counts_sum_std[1]:
new_x[:, -sum(self._counts_sum_std[1:])] = np.nansum(
split_x[:, :, self._time_index] * weights, axis=1
)
if self._counts_sum_std[2]:
new_x[:, -1] = np.nanstd(
split_x[:, :, self._time_index] * weights, axis=1
)
# For each percentile, select the first pulse whose cumulative charge fraction reaches it and record that pulse's time and charge fraction.
selections = np.argmax(
split_x[:, :, charge_index][:, :, np.newaxis]
>= (np.array(self._percentiles) / 100),
axis=1,
)
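# Turn the per-row pulse indices into indices into the flattened (DOM, pulse) array.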
selections += (np.arange(len(counts)) * split_x.shape[1])[
:, np.newaxis
]
new_x = np.column_stack(
[new_x, split_x[:, :, self._time_index].flatten()[selections]]
)
new_x = np.column_stack(
[new_x, split_x[:, :, charge_index].flatten()[selections]]
)
if len(self._time_sums) > 0:
# For each time window, select the first pulse whose elapsed time since the DOM's first pulse reaches it and record that pulse's time and accumulated charge fraction.
selections = np.argmax(
(
split_x[:, :, self._time_index]
- split_x[:, 0, self._time_index][:, np.newaxis]
)[:, :, np.newaxis]
>= self._time_sums,
axis=1,
)
selections += (np.arange(len(counts)) * split_x.shape[1])[
:, np.newaxis
]
new_x = np.column_stack(
[
new_x,
split_x[:, :, self._time_index].flatten()[selections],
]
)
new_x = np.column_stack(
[new_x, split_x[:, :, charge_index].flatten()[selections]]
)
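# Normalize all time-like features to O(1); the 3.0e4 divisor assumes times in nanoseconds.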

new_x[:, self._time_features] = (
new_x[:, self._time_features] / 3.0e4
)
return Data(x=torch.tensor(new_x))
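# Without percentile summarization, DOMs with more than `max_activations` pulses are optionally split into several series of at most `max_activations` pulses each.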

if self._max_activations is not None:
counts_mask = np.argwhere(
counts >= self._max_activations
).flatten()
if counts_mask.size > 0:
counts_vals = counts[counts_mask]
indices = np.concatenate(
[
[ind] * repeats
for ind, repeats in zip(
counts_mask,
np.floor(
counts_vals / self._max_activations
).astype(int),
)
]
)
counts = np.insert(
counts % self._max_activations,
indices,
self._max_activations,
)
counts = counts[counts != 0]
assert (
counts.sum() == counts_sum
), f"Counts sum changed from {counts_sum} to {counts.sum()}"

new_node_col = np.zeros(x.shape[0])
new_node_col[counts.cumsum()[:-1]] = 1
@@ -305,6 +457,12 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:

return Data(x=torch.tensor(x))

def _rename_features(self, feature_names: List[str]) -> List[str]:
"""Rename features to include time series data."""
new_feature_names = deepcopy(feature_names)

return new_feature_names


class IceMixNodes(NodeDefinition):
"""Calculate ice properties and perform random sampling.
