diff --git a/src/graphnet/models/graphs/nodes/__init__.py b/src/graphnet/models/graphs/nodes/__init__.py
index 2eebef4be..540752e64 100644
--- a/src/graphnet/models/graphs/nodes/__init__.py
+++ b/src/graphnet/models/graphs/nodes/__init__.py
@@ -10,4 +10,5 @@
     NodesAsPulses,
     NodesAsDOMTimeSeries,
     PercentileClusters,
+    NodesAsRefoldedDOM,
 )
diff --git a/src/graphnet/models/graphs/nodes/nodes.py b/src/graphnet/models/graphs/nodes/nodes.py
index e5a226168..30ec2fa04 100644
--- a/src/graphnet/models/graphs/nodes/nodes.py
+++ b/src/graphnet/models/graphs/nodes/nodes.py
@@ -13,8 +13,12 @@
 from graphnet.models.graphs.utils import (
     cluster_summarize_with_percentiles,
     identify_indices,
+    lex_sort,
+    spe_atwd_old,
+    pulseseries_to_wf,
 )
 from copy import deepcopy
+import numpy as np
 
 
 class NodeDefinition(Model):  # pylint: disable=too-few-public-methods
@@ -145,6 +150,7 @@ def __init__(
         id_columns: List[str] = ["dom_x", "dom_y", "dom_z"],
         time_index: str = "dom_time",
         charge_index: str = "charge",
+        max_activations: Optional[int] = None,
     ) -> None:
         """Construct nodes as DOMs with time series of pulses.
 
@@ -153,6 +159,7 @@
             id_columns: List of columns that uniquely identify a DOM.
             time_index: Name of the column that contains the time index.
             charge_index: Name of the column that contains the charge.
+            max_activations: Maximum number of activations to keep per DOM.
         """
         assert isinstance(keys, type(id_columns))
 
@@ -160,9 +167,15 @@
         self._id_columns = [self._keys.index(key) for key in id_columns]
         self._time_index = self._keys.index(time_index)
         self._charge_index = self._keys.index(charge_index)
-
+        self._max_activations = max_activations
         super().__init__()
 
+    def _define_output_feature_names(
+        self, input_feature_names: List[str]
+    ) -> List[str]:
+        """Return the output feature names (unchanged by this definition)."""
+        return input_feature_names
+
     def _sort_by_n_pulses(
         self, time_series: List[torch.Tensor]
     ) -> torch.Tensor:
@@ -177,42 +189,60 @@
 
     def _construct_nodes(self, x: torch.Tensor) -> Data:
         """Construct nodes from raw node features ´x´."""
+        # cast to numpy
+        x = x.numpy()
         # sort by time
-        x = x[x[:, self._time_index].sort().indices]
+        x = x[np.argsort(x[:, self._time_index])]
         # undo log10 scaling since we want to sum up charge
-        x[:, self._charge_index] = torch.pow(10, x[:, self._charge_index])
+        x[:, self._charge_index] = np.power(10, x[:, self._charge_index])
         # shift time to positive values with a small offset
-        x[:, self._time_index] += -min(x[:, self._time_index])
+        x[:, self._time_index] += -np.min(x[:, self._time_index])
         # Group pulses on the same DOM
-        dom_index = _group_identical(x[:, self._id_columns])
+        x = lex_sort(x, self._id_columns)
 
-        val, ind = dom_index.sort(stable=True)
-        counts = torch.concat(
-            [torch.tensor([0]), val.bincount().cumsum(-1)[:-1]]
+        unique_sensors, counts = np.unique(
+            x[:, self._id_columns], return_counts=True, axis=0
+        )
+        # sort DOMs and pulse counts together so they stay aligned
+        sort_this = np.concatenate(
+            [unique_sensors, counts.reshape(-1, 1)], axis=1
+        )
+        sort_this = lex_sort(x=sort_this, cluster_columns=self._id_columns)
+        unique_sensors = sort_this[:, 0 : unique_sensors.shape[1]]
+        counts = sort_this[:, unique_sensors.shape[1] :].flatten().astype(int)
+
+        time_series = np.split(
+            x[:, [self._charge_index, self._time_index]], counts.cumsum()[:-1]
         )
-        unique_doms = x[:, self._id_columns + [self._time_index]][ind][counts]
-        time_series = [
-            x[dom_index == index_key][
-                :, [self._charge_index, self._time_index]
-            ]
-            for index_key in dom_index.unique()
-        ]
-        # add total charge to unique dom features and apply log10 scaling
-        charge = torch.stack(
-            [torch.log10(image[:, 0].sum()) for image in time_series]
+        # add first-pulse time and total charge (arcsinh-scaled) to the
+        # unique DOM features
+        time_charge = np.stack(
+            [
+                (image[0, 1], np.arcsinh(5 * image[:, 0].sum()) / 5)
+                for image in time_series
+            ]
         )
-        x = torch.column_stack([unique_doms, charge])
+        x = np.column_stack([unique_sensors, time_charge])
+
+        if self._max_activations is not None:
+            counts[counts > self._max_activations] = self._max_activations
+            time_series = [
+                image[: self._max_activations] for image in time_series
+            ]
+        time_series = np.concatenate(time_series)
+        # apply the inverse hyperbolic sine to the charge values
+        # (unlike log scaling, this handles zero charge)
+        time_series[:, 0] = np.arcsinh(5 * time_series[:, 0]) / 5
 
-        time_series, sort_ind = self._sort_by_n_pulses(time_series)
-        cutter = torch.tensor([len(ts) for ts in time_series])
-        x = x[sort_ind]
-        time_series = torch.concat(time_series)
-        return Data(x=x, time_series=time_series, cutter=cutter, n_doms=len(x))
+        return Data(
+            x=torch.tensor(x),
+            time_series=torch.tensor(time_series),
+            cutter=torch.tensor(counts),
+            n_doms=len(x),
+        )
 
 
 @torch.jit.script
 def log10powsum(tensor: torch.Tensor) -> torch.Tensor:
+    """Sum log10-scaled values in linear space and return the log10 of the sum."""
     return torch.log10(torch.sum(torch.pow(10, tensor)))
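The grouping logic above sorts pulses so that rows from the same DOM are contiguous, counts pulses per DOM with `np.unique`, and slices the block apart with `np.split`. A minimal self-contained sketch of the same pattern, using `np.lexsort` directly instead of graphnet's `lex_sort` helper (the toy column layout is assumed purely for illustration):

```python
import numpy as np

# Toy pulse array: columns are (dom_id, time, charge); layout assumed
# purely for illustration.
pulses = np.array(
    [
        [2.0, 10.0, 0.5],
        [1.0, 12.0, 1.0],
        [2.0, 11.0, 0.2],
        [1.0, 15.0, 0.3],
    ]
)

# Sort so rows of the same DOM are contiguous (the last lexsort key is
# the primary key, here the DOM id).
pulses = pulses[np.lexsort((pulses[:, 1], pulses[:, 0]))]

# One row per DOM plus its pulse count; np.unique returns rows sorted,
# matching the lexsort order.
unique_doms, counts = np.unique(pulses[:, :1], return_counts=True, axis=0)

# Slice the contiguous block into one (charge, time) series per DOM.
series = np.split(pulses[:, [2, 1]], counts.cumsum()[:-1])
for dom, s in zip(unique_doms.flatten(), series):
    print(int(dom), s)
```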
@@ -236,6 +270,15 @@ def pad_charge(tensor: torch.Tensor, time_range: torch.Tensor) -> torch.Tensor:
 
 @torch.jit.script
 def return_closest(tt: torch.Tensor, tr: torch.Tensor) -> torch.Tensor:
+    """Return the index of the closest time-range bin for each pulse time.
+
+    Args:
+        tt: pulse times of shape (num_pulses, 1).
+        tr: time range of shape (len(time_range), 1).
+
+    Returns:
+        tensor: binned time indices of shape (num_pulses, 1).
+    """
     return torch.argmin(abs(tt - tr.unsqueeze(1)), dim=0)
 
 
@@ -243,7 +286,7 @@ def return_closest(tt: torch.Tensor, tr: torch.Tensor) -> torch.Tensor:
 def sum_charge(
     x: torch.Tensor,
     dom_index: torch.Tensor,
-    id_columns: list[int],
+    id_columns: List[int],
     time_index: int,
     charge_index: int,
 ) -> torch.Tensor:
@@ -273,11 +316,11 @@ def sum_charge(
 def create_time_series(
     x: torch.Tensor,
     dom_index: torch.Tensor,
-    id_columns: list[int],
+    id_columns: List[int],
     time_index: int,
     charge_index: int,
     time_range: torch.Tensor,
-) -> list[torch.Tensor]:
+) -> List[torch.Tensor]:
     """Create time series.
 
     Args:
@@ -313,33 +356,16 @@
     return [unique_doms, time_series]
 
 
-@torch.jit.script
-def get_unique_dom_features(
-    x: torch.Tensor,
-    dom_index: torch.Tensor,
-    id_columns: list[int],
-    time_index: int,
-    charge_index: int,
-    time_range: torch.Tensor,
-) -> torch.Tensor:
-    unique_doms = []
-    for index_key in torch.unique(dom_index):
-        unique_doms.append(
-            torch.hstack(
-                [
-                    x[dom_index == index_key][0][id_columns],
-                    x[dom_index == index_key][0][time_index],
-                    x[dom_index == index_key][:, charge_index].max(),
-                ]
-            )
-        )
-    unique_doms = torch.vstack(unique_doms)
-    return unique_doms
-
-
 def create_sparse_charge_series(
     dom_index: torch.Tensor, time_index: torch.Tensor, values: torch.Tensor
 ) -> torch.Tensor:
+    """Create a sparse charge series.
+
+    Args:
+        dom_index: DOM indices used for grouping.
+        time_index: time-bin indices used for grouping.
+        values: charge values to be summed per (DOM, time-bin) pair.
+    """
     i = torch.vstack([dom_index, time_index])
     v = values
     s = torch.sparse_coo_tensor(
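`create_sparse_charge_series` relies on the fact that a COO sparse tensor sums values sharing a coordinate when it is coalesced or densified, which bins charge per (DOM, time-bin) pair without an explicit loop. A small sketch of that behaviour with made-up indices:

```python
import torch

# Two pulses land in the same (dom, time-bin) coordinate; densifying the
# COO tensor sums their charges automatically.
dom_index = torch.tensor([0, 0, 1])
time_bin = torch.tensor([3, 3, 1])
charge = torch.tensor([0.5, 0.2, 1.0])

i = torch.vstack([dom_index, time_bin])
s = torch.sparse_coo_tensor(i, charge, size=(2, 5))
print(s.to_dense())
# tensor([[0.0000, 0.0000, 0.0000, 0.7000, 0.0000],
#         [0.0000, 1.0000, 0.0000, 0.0000, 0.0000]])
```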
@@ -386,6 +440,7 @@ def __init__(
             id_columns: List of columns that uniquely identify a DOM.
             time_index: Name of the column that contains the time index.
             charge_index: Name of the column that contains the charge.
+            granularity: Number of time bins to use.
         """
         assert isinstance(keys, type(id_columns))
 
@@ -546,3 +601,139 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
             raise AttributeError
 
         return Data(x=torch.tensor(array))
+
+
+class NodesAsRefoldedDOM(NodeDefinition):
+    """Represent each node as a DOM with a refolded pulse time series."""
+
+    def __init__(
+        self,
+        keys: List[str] = [
+            "dom_x",
+            "dom_y",
+            "dom_z",
+            "dom_time",
+            "charge",
+            "atwd",
+            "width",
+        ],
+        id_columns: List[str] = ["dom_x", "dom_y", "dom_z"],
+        time_index: str = "dom_time",
+        charge_index: str = "charge",
+        atwd_index: str = "atwd",
+        width_index: str = "width",
+        max_activations: Optional[int] = None,
+    ) -> None:
+        """Construct nodes as DOMs with refolded time series of pulses.
+
+        Args:
+            keys: List of node feature names.
+            id_columns: List of columns that uniquely identify a DOM.
+            time_index: Name of the column that contains the time index.
+            charge_index: Name of the column that contains the charge.
+            atwd_index: Name of the column that contains the ATWD flag.
+            width_index: Name of the column that contains the pulse width.
+            max_activations: Maximum number of activations to keep per DOM.
+        """
+        assert isinstance(keys, type(id_columns))
+
+        self._keys = keys
+        self._id_columns = [self._keys.index(key) for key in id_columns]
+        self._time_index = self._keys.index(time_index)
+        self._charge_index = self._keys.index(charge_index)
+        self._atwd_index = self._keys.index(atwd_index)
+        self._width_index = self._keys.index(width_index)
+        self._max_activations = max_activations
+        super().__init__()
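A hedged usage sketch of the new node definition, relying only on the constructor above; the `max_activations` value is arbitrary and this assumes the branch of graphnet containing this change is installed:

```python
from graphnet.models.graphs.nodes import NodesAsRefoldedDOM

# Default keys/id_columns match the constructor defaults; cap each DOM
# at 128 pulse activations (an arbitrary illustrative value).
node_definition = NodesAsRefoldedDOM(max_activations=128)
```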
+ """ + assert isinstance(keys, type(id_columns)) + + self._keys = keys + self._id_columns = [self._keys.index(key) for key in id_columns] + self._time_index = self._keys.index(time_index) + self._charge_index = self._keys.index(charge_index) + self._atwd_index = self._keys.index(atwd_index) + self._width_index = self._keys.index(width_index) + self._max_activations = max_activations + super().__init__() + + def _define_output_feature_names( + self, input_feature_names: List[str] + ) -> List[str]: + return input_feature_names + ["n_pulses"] + + def _sort_by_n_pulses( + self, time_series: List[torch.Tensor] + ) -> torch.Tensor: + """Sort time series by number of pulses.""" + sort_index = ( + torch.tensor([len(ts) for ts in time_series]) + .sort(descending=True) + .indices + ) + sorted_time_series = [time_series[i] for i in sort_index] + return sorted_time_series, sort_index + + def _construct_nodes(self, x: torch.Tensor) -> Data: + """Construct nodes from raw node features ´x´.""" + # cast to numpy + x = x.numpy() + # sort by time + x = x[np.argsort(x[:, self._time_index])] + # undo log10 scaling since we want to sum up charge + # x[:, self._charge_index] = np.power(10, x[:, self._charge_index]) + # shift time to positive values with a small offset + x[:, self._time_index] += -np.min(x[:, self._time_index]) + # find largest charge bin in time + _counts, edges_ = np.histogram( + x[:, self._time_index], bins=1000, weights=x[:, self._charge_index] + ) + largest_charge_time = ( + 0.5 * (edges_[:-1] + edges_[1:])[np.argmax(_counts)] + ) + + # weigthed_mean_time = np.average(x[:, self._time_index],weights=x[:, self._charge_index]) + # Group pulses on the same DOM + x = lex_sort(x, self._id_columns) + + unique_sensors, counts = np.unique( + x[:, self._id_columns], return_counts=True, axis=0 + ) + # sort DOMs and pulse-counts + sort_this = np.concatenate( + [unique_sensors, counts.reshape(-1, 1)], axis=1 + ) + sort_this = lex_sort(x=sort_this, cluster_columns=self._id_columns) + unique_sensors = sort_this[:, 0 : unique_sensors.shape[1]] + counts = sort_this[:, unique_sensors.shape[1] :].flatten().astype(int) + + time_series = np.split( + x[ + :, + [ + self._charge_index, + self._time_index, + self._atwd_index, + self._width_index, + ], + ], + counts.cumsum()[:-1], + ) + + # add total charge to unique dom features and apply log10 scaling + time_charge_n = np.stack( + [ + (image[0, 1], np.log10(image[:, 0].sum()), len(image)) + for image in time_series + ] + ) + x = np.column_stack([unique_sensors, time_charge_n]) + + # perform refolding + ts = [] + time_range = np.linspace( + largest_charge_time - 1600, largest_charge_time + 3200, 1000 + ) + for ind, time_series in enumerate(time_series): + charge_series = pulseseries_to_wf( + time_series, spe_atwd_old, time_range, 1 + ).astype(np.float16) + ts.append(charge_series) + ts = np.array(ts) + # time_series = np.concatenate(time_series) + x = np.column_stack([x, ts]) + + return Data( + x=torch.tensor(x), + time_series=torch.tensor(ts), + time_range=time_range, + cutter=torch.tensor(counts), + n_doms=len(x), + ) diff --git a/src/graphnet/models/graphs/utils.py b/src/graphnet/models/graphs/utils.py index ccd861783..114a35a3d 100644 --- a/src/graphnet/models/graphs/utils.py +++ b/src/graphnet/models/graphs/utils.py @@ -1,6 +1,6 @@ """Utility functions for construction of graphs.""" -from typing import List, Tuple +from typing import List, Tuple, Callable import numpy as np @@ -158,3 +158,50 @@ def cluster_summarize_with_percentiles( ) return array + + +class 
diff --git a/src/graphnet/models/graphs/utils.py b/src/graphnet/models/graphs/utils.py
index ccd861783..114a35a3d 100644
--- a/src/graphnet/models/graphs/utils.py
+++ b/src/graphnet/models/graphs/utils.py
@@ -1,6 +1,6 @@
 """Utility functions for construction of graphs."""
 
-from typing import List, Tuple
+from typing import List, Tuple, Callable
 
 import numpy as np
 
@@ -158,3 +158,57 @@ def cluster_summarize_with_percentiles(
         )
 
     return array
+
+
+class pulse_template:
+    """Parametrized SPE pulse shape for DOMs."""
+
+    def __init__(self, args: List[float]):
+        """Initialize the SPE shape.
+
+        Args:
+            args: List of parameters for the SPE shape (c, x0, b1, b2).
+        """
+        self.args = args
+
+    def __call__(self, time: np.ndarray) -> np.ndarray:
+        """Evaluate the SPE shape at the given times."""
+        c, x0, b1, b2 = self.args
+        t = time - 11.5
+        func = c * (np.exp(-(t - x0) / b1) + np.exp((t - x0) / b2)) ** -8
+        return func
+
+
+def spe_atwd_old(time: np.ndarray) -> np.ndarray:
+    """SPE shape for the old ATWD digitizer."""
+    t = time - 11.5  # causality
+    c = 15.47 / 13.292860653948139
+    x0 = -3.929 - 5
+    b1 = 4.7
+    b2 = 39.0
+    return c * (np.exp(-(abs(t - x0)) / b1) + np.exp((abs(t - x0)) / b2)) ** -8
+
+
+def pulseseries_to_wf(
+    pulses: np.ndarray,
+    template: Callable,
+    times: np.ndarray,
+    norm: float = 1.0,
+) -> np.ndarray:
+    """Refold a pulse series into a waveform using an SPE template.
+
+    Args:
+        pulses: Array with charge in column 0 and time in column 1.
+        template: Callable evaluating the SPE shape on time offsets.
+        times: Sample times of the output waveform.
+        norm: Overall normalisation factor.
+    """
+    wf = (
+        np.sum(
+            pulses[:, 0]
+            * template(times[..., None] - pulses[:, 1][None, ...]),
+            axis=1,
+        )
+        * norm
+    )
+    return wf
diff --git a/src/graphnet/models/rnn/__init__.py b/src/graphnet/models/rnn/__init__.py
index 63e29d3b2..830fbd56a 100644
--- a/src/graphnet/models/rnn/__init__.py
+++ b/src/graphnet/models/rnn/__init__.py
@@ -2,3 +2,4 @@
 
 from .node_rnn import Node_RNN
 from .dom_window_rnn import Dom_Window_RNN
+from .node_transformer import Node_Transformer
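End to end, each pulse contributes a charge-scaled, time-shifted copy of the SPE shape, and `pulseseries_to_wf` sums those copies on the sample grid. A sketch with made-up pulses, assuming this branch is installed so the new utilities are importable:

```python
import numpy as np
from graphnet.models.graphs.utils import pulseseries_to_wf, spe_atwd_old

# Columns: charge, time; extra columns (e.g. atwd, width) would be
# ignored, since pulseseries_to_wf only reads columns 0 and 1.
pulses = np.array([[1.0, 20.0], [0.5, 35.0], [2.0, 80.0]])

# Sample grid for the refolded waveform.
times = np.linspace(0.0, 200.0, 400)

waveform = pulseseries_to_wf(pulses, spe_atwd_old, times)
print(waveform.shape)  # (400,)
```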