Implement class to crop pulsemaps to maximum length #648
Comments
Hey @AMHermansen! I think it is a great idea to allow for such functionality in `GraphDefinition`, for example by passing the sampling module as an argument:

    from graphnet.models.graphs import GraphDefinition

    graph_definition = GraphDefinition(detector = detector,
                                       node_definition = node_definition,
                                       edge_definition = edge_definition,
                                       sampler = sampler)

and then, inside the `GraphDefinition`, applying it along these lines:

    if self.sampler is not None:
        subsample_idx = self.sampler(input_features = input_features,
                                     input_feature_names = input_feature_names)
        input_features = input_features[subsample_idx,:]

That would mean that the sampling would be independent of what users would like to do with the pulses. Here's a quick take on what the sampling module could look like:

    from typing import List
    from abc import abstractmethod

    import numpy as np

    from graphnet.models import Model
    from graphnet.utilities.decorators import final


    class Sampler(Model):
        """Base class for sub-sampling rows in single events."""

        def __init__(self) -> None:
            """Construct `Sampler`."""
            # Base class constructor
            super().__init__(name=__name__, class_name=self.__class__.__name__)

        @final
        def forward(self,
                    input_features: np.ndarray,
                    input_feature_names: List[str]) -> List[int]:
            """Produce subsampling indices."""
            indices = self._create_subsample_indices(input_features = input_features,
                                                     input_feature_names = input_feature_names)
            self._validate_indices(indices = indices,
                                   input_features = input_features)
            return indices

        def _validate_indices(self,
                              indices: List[int],
                              input_features: np.ndarray) -> None:
            """Check that the output of the custom sampling method meets requirements."""
            try:
                assert isinstance(indices, list)
            except AssertionError as e:
                self.error("Subsampling indices must be a list of ints. "
                           f"Got {type(indices)}.")
                raise e
            try:
                # Every index must point at an existing row.
                assert all(0 <= i < len(input_features) for i in indices)
            except AssertionError as e:
                self.error("Subsampling method returned a row index that is out of range.")
                raise e

        @abstractmethod
        def _create_subsample_indices(self,
                                      input_features: np.ndarray,
                                      input_feature_names: List[str]) -> List[int]:
            """Create a list of integers that defines which rows in `input_features` are kept.

            Example:
                input_features = [[1,2,3],
                                  [5,5,5],
                                  [0,0,1],]
                input_feature_names = ['dom_x', 'dom_y', 'dom_z']

                Suppose we wrote logic that produced the following

                indices = [0,1]

                This would mean that the corresponding subsampled rows would be:

                input_features = [[1,2,3],
                                  [5,5,5]]
            """
            raise NotImplementedError

So a random sampler for events exceeding a maximum length could look like:

    class RandomMaxSampler(Sampler):
        """Randomly sample events exceeding a maximum length."""

        def __init__(self,
                     max_event_size: int,
                     seed: int = 42):
            """Randomly sample available pulses if event is larger than `max_event_size`.

            Args:
                max_event_size: The maximum number of pulses in the event.
                    Events with more pulses than this will be randomly sampled.
                seed: seed used for random sampling. Defaults to 42.
            """
            # Base class constructor
            super().__init__()
            self._max_size = max_event_size
            # Create the generator once, so that successive events do not
            # all receive the same random draw.
            self._rng = np.random.default_rng(seed)

        def _create_subsample_indices(self,
                                      input_features: np.ndarray,
                                      input_feature_names: List[str]) -> List[int]:
            n_rows = input_features.shape[0]
            if n_rows > self._max_size:
                # Draw `max_size` distinct row indices.
                indices = self._rng.choice(n_rows, self._max_size, replace=False)
            else:
                indices = np.arange(0, n_rows)
            return indices.tolist()
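As a rough, self-contained illustration of the sampling step described above, here is a sketch in plain NumPy that does not depend on graphnet. The function names `random_max_indices` and `apply_sampler` are made up for this example, and sorting the drawn indices (to keep the original pulse order) is an assumption rather than something the proposal requires.

```python
from typing import Callable, List, Optional

import numpy as np


def random_max_indices(input_features: np.ndarray,
                       max_event_size: int,
                       rng: np.random.Generator) -> List[int]:
    """Return row indices, randomly subsampled if the event is too long."""
    n_rows = input_features.shape[0]
    if n_rows <= max_event_size:
        return list(range(n_rows))
    # Draw distinct rows, then sort so the original pulse order is preserved.
    chosen = rng.choice(n_rows, size=max_event_size, replace=False)
    return sorted(chosen.tolist())


def apply_sampler(input_features: np.ndarray,
                  sampler: Optional[Callable[[np.ndarray], List[int]]] = None,
                  ) -> np.ndarray:
    """Mimic the `if self.sampler is not None` step on a single event."""
    if sampler is not None:
        subsample_idx = sampler(input_features)
        input_features = input_features[subsample_idx, :]
    return input_features


rng = np.random.default_rng(42)
event = rng.normal(size=(500, 3))  # 500 fake pulses with 3 features each
cropped = apply_sampler(event, lambda x: random_max_indices(x, 128, rng))
short = apply_sampler(event[:10], lambda x: random_max_indices(x, 128, rng))
print(cropped.shape, short.shape)  # (128, 3) (10, 3)
```

Events at or below the maximum length pass through unchanged, so the sampler only affects the long events that drive up memory use.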
Is your feature request related to a problem? Please describe.
With #558 we now have better control over how a pulsemap is processed. From the Kaggle competition it became apparent that many of the top-scoring models simply cropped the number of pulses to some fixed number, to reduce the impact of the $n^2$ term from self-attention components. While their primary way to select pulses was simply to take the first n pulses, I believe it might be interesting to look into other methods of selecting pulses (randomly, sorted by charge, sorted by probability of being real signal, farthest point sampling, etc.).
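To make a few of the selection strategies listed above concrete, here is a minimal NumPy sketch. Everything in it is an assumption for illustration: the function names are hypothetical, charge and position are passed as separate arrays, and none of it is graphnet API. Each function returns the row indices of the pulses to keep.

```python
import numpy as np


def first_n_indices(n_pulses: int, n: int) -> np.ndarray:
    """Keep the first `n` pulses (the simple baseline)."""
    return np.arange(min(n_pulses, n))


def top_charge_indices(charge: np.ndarray, n: int) -> np.ndarray:
    """Keep the `n` highest-charge pulses, returned in original order."""
    top = np.argsort(charge)[::-1][:n]
    return np.sort(top)


def farthest_point_indices(xyz: np.ndarray, n: int, start: int = 0) -> np.ndarray:
    """Greedy farthest point sampling of `n` >= 1 rows from `xyz`."""
    n_points = xyz.shape[0]
    if n_points <= n:
        return np.arange(n_points)
    selected = np.empty(n, dtype=int)
    selected[0] = start
    # Distance from every point to its nearest already-selected point.
    min_dist = np.linalg.norm(xyz - xyz[start], axis=1)
    for i in range(1, n):
        nxt = int(np.argmax(min_dist))
        selected[i] = nxt
        min_dist = np.minimum(min_dist, np.linalg.norm(xyz - xyz[nxt], axis=1))
    return selected


charge = np.array([0.1, 5.0, 2.0, 3.0])
xyz = np.array([[0.0, 0, 0], [1.0, 0, 0], [2.0, 0, 0], [10.0, 0, 0]])
print(first_n_indices(len(charge), 2))  # [0 1]
print(top_charge_indices(charge, 2))    # [1 3]
print(farthest_point_indices(xyz, 3))   # [0 3 2]
```

Because all three return indices, they could be swapped in and out behind a common interface without the rest of the pipeline changing.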
Describe the solution you'd like
To avoid having to implement many node definitions, I think it might make sense to make a common class for all cropped nodes.
Such a structure would also make it easier to reuse the cropping methods in other node definitions. (Maybe you want to crop after calculating summary nodes per DOM, to make sure you do not get an event which triggered 5k DOMs.)
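One possible shape for such a common class is sketched below, purely as an illustration. It is plain Python and NumPy rather than graphnet code: the `CroppedNodes` wrapper, the `crop_after` flag, and the toy `first_n` and `per_dom_mean` helpers are all hypothetical, and column 0 is assumed to hold a DOM id.

```python
from typing import Callable

import numpy as np

# Hypothetical type aliases for this sketch (not graphnet's API).
Cropper = Callable[[np.ndarray], np.ndarray]         # rows -> kept row indices
NodeDefinition = Callable[[np.ndarray], np.ndarray]  # pulses -> node rows


class CroppedNodes:
    """Wrap any node definition with a reusable cropping step."""

    def __init__(self,
                 node_definition: NodeDefinition,
                 cropper: Cropper,
                 crop_after: bool = False) -> None:
        self._node_definition = node_definition
        self._cropper = cropper
        self._crop_after = crop_after

    def __call__(self, pulses: np.ndarray) -> np.ndarray:
        if not self._crop_after:
            # Crop the raw pulses before building nodes.
            pulses = pulses[self._cropper(pulses)]
        nodes = self._node_definition(pulses)
        if self._crop_after:
            # Crop the nodes instead, e.g. after per-DOM summarisation.
            nodes = nodes[self._cropper(nodes)]
        return nodes


def first_n(n: int) -> Cropper:
    """Toy cropper that keeps the first `n` rows."""
    return lambda rows: np.arange(min(len(rows), n))


def per_dom_mean(pulses: np.ndarray) -> np.ndarray:
    """Toy summary nodes: mean features of pulses sharing a DOM id (column 0)."""
    dom_ids = np.unique(pulses[:, 0])
    return np.stack([pulses[pulses[:, 0] == d].mean(axis=0) for d in dom_ids])


# Six pulses spread over four DOMs; columns are (dom_id, charge).
pulses = np.array([[0, 1.0], [0, 3.0], [1, 2.0], [2, 4.0], [2, 6.0], [3, 8.0]])
before = CroppedNodes(per_dom_mean, first_n(3), crop_after=False)(pulses)
after = CroppedNodes(per_dom_mean, first_n(3), crop_after=True)(pulses)
print(before.shape, after.shape)  # (2, 2) (3, 2)
```

Cropping before the node definition bounds the number of pulses, while cropping after bounds the number of nodes, which is the case that matters for events that trigger very many DOMs.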
Describe alternatives you've considered
We could of course just implement each cropping algorithm as a subclass of a common `CroppedNodes` class and have the logic restricted to each subclass. But I think the cropping logic is general enough that there is merit in having it as a separate component.