Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PercentileCluster #616

Merged
merged 22 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/graphnet/models/graphs/graph_definition.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes me a little uneasy that, in the `forward` call, the local variable `node_feature_names` might differ from the `self._node_feature_names` set at class instantiation once the `_node_definition` call on line 147 has run. While I do believe this is intended, it might be quite confusing when revisiting the code later; maybe consider renaming the variable.

Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ def __init__(
node_feature_names = list(self._detector.feature_map().keys()) # type: ignore
self._node_feature_names = node_feature_names

# Set input data column names for node definition
self._node_definition.set_output_feature_names(
self._node_feature_names
)

# Set data type
self.to(dtype)

Expand Down Expand Up @@ -138,8 +143,11 @@ def forward( # type: ignore
# Standardize / Scale node features
node_features = self._detector(node_features, node_feature_names)

# Create graph
graph = self._node_definition(node_features)
# Create graph & get new node feature names
graph, node_feature_names = self._node_definition(node_features)

# Enforce dtype
graph.x = graph.x.type(self.dtype)

# Attach number of pulses as static attribute.
graph.n_pulses = torch.tensor(len(node_features), dtype=torch.int32)
Expand Down
2 changes: 1 addition & 1 deletion src/graphnet/models/graphs/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
and their features.
"""

from .nodes import NodeDefinition, NodesAsPulses
from .nodes import NodeDefinition, NodesAsPulses, PercentileClusters
159 changes: 151 additions & 8 deletions src/graphnet/models/graphs/nodes/nodes.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,68 @@
"""Class(es) for building/connecting graphs."""

from typing import List
from typing import List, Tuple, Optional
from abc import abstractmethod

import torch
from torch_geometric.data import Data

from graphnet.utilities.decorators import final
from graphnet.models import Model
from graphnet.models.graphs.utils import (
cluster_summarize_with_percentiles,
identify_indices,
)
from copy import deepcopy


class NodeDefinition(Model): # pylint: disable=too-few-public-methods
"""Base class for graph building."""

def __init__(
    self, input_feature_names: Optional[List[str]] = None
) -> None:
    """Construct `NodeDefinition`.

    Args:
        input_feature_names: (Optional) column names of the input data.
            When given, the output feature names are set immediately via
            `set_output_feature_names`. When omitted, they must be set
            before the first forward call.
    """
    # Base class constructor
    super().__init__(name=__name__, class_name=self.__class__.__name__)
    if input_feature_names is not None:
        self.set_output_feature_names(
            input_feature_names=input_feature_names
        )

@final
def forward(self, x: torch.tensor) -> Tuple[Data, List[str]]:
    """Construct nodes from raw node features.

    Args:
        x: standardized node features with shape `[num_pulses, d]`,
            where `d` is the number of node features.

    Returns:
        graph: a graph without edges.
        output_feature_names: List of names of the output node features.

    Raises:
        AttributeError: If the output feature names were never set,
            neither at instantiation nor through
            `set_output_feature_names`.
    """
    # Look the names up before building nodes, so a misconfigured
    # instance fails without doing any work first.
    try:
        output_feature_names = self._output_feature_names
    except AttributeError:
        self.error(
            f"{self.__class__.__name__} was instantiated without "
            "`input_feature_names` and it was not set prior to this "
            "forward call. If you are using this class outside a "
            "`GraphDefinition`, please instantiate with "
            "`input_feature_names`."
        )
        raise
    graph = self._construct_nodes(x=x)
    return graph, output_feature_names

@property
def nb_outputs(self) -> int:
    """Return number of output features.

    This is the default, but may be overridden by specific inheriting
    classes. The count is the length of the stored output feature names,
    so those names must have been set before this property is read.
    """
    return len(self._output_feature_names)

@final
def set_number_of_inputs(self, node_feature_names: List[str]) -> None:
Expand All @@ -50,21 +74,140 @@ def set_number_of_inputs(self, node_feature_names: List[str]) -> None:
assert isinstance(node_feature_names, list)
self.nb_inputs = len(node_feature_names)

@final
def set_output_feature_names(self, input_feature_names: List[str]) -> None:
    """Derive the output feature names and store them on the instance.

    Args:
        input_feature_names: Names of the columns in the data that is
            passed to the node definition.
    """
    output_feature_names = self._define_output_feature_names(
        input_feature_names
    )
    self._output_feature_names = output_feature_names

@abstractmethod
def _define_output_feature_names(
    self, input_feature_names: List[str]
) -> List[str]:
    """Build the column names of the node definition output.

    Args:
        input_feature_names: Column names of the input data.

    Returns:
        One name per column in the output of the node definition.
    """

@abstractmethod
def _construct_nodes(self, x: torch.tensor) -> Data:
    """Construct nodes from raw node features `x`.

    The feature names are not returned from here; `forward` pairs the
    graph with the stored output feature names.

    Args:
        x: standardized node features with shape `[num_pulses, d]`,
            where `d` is the number of node features.

    Returns:
        graph: graph without edges.
    """


class NodesAsPulses(NodeDefinition):
    """Represent each measured pulse of Cherenkov Radiation as a node."""

    def _define_output_feature_names(
        self, input_feature_names: List[str]
    ) -> List[str]:
        """Return the input feature names unchanged.

        Args:
            input_feature_names: List of column names for the input data.

        Returns:
            The same list that was passed in.
        """
        return input_feature_names

    def _construct_nodes(self, x: torch.Tensor) -> Data:
        """Wrap the node features `x` in a `Data` object as-is.

        Args:
            x: node features passed on from the forward call.

        Returns:
            graph: a `Data` object whose `x` is the input tensor.
        """
        return Data(x=x)


class PercentileClusters(NodeDefinition):
    """Represent nodes as clusters with percentile summary node features.

    If `cluster_on` is set to the xyz coordinates of DOMs
    e.g. `cluster_on = ['dom_x', 'dom_y', 'dom_z']`, each node will be a
    unique DOM and the pulse information (charge, time) is summarized using
    percentiles.
    """

    def __init__(
        self,
        cluster_on: List[str],
        percentiles: List[int],
        add_counts: bool = True,
        input_feature_names: Optional[List[str]] = None,
    ) -> None:
        """Construct `PercentileClusters`.

        Args:
            cluster_on: Names of features to create clusters from.
            percentiles: List of percentiles. E.g. `[10, 50, 90]`.
            add_counts: If True, number of duplicates is added to output array.
            input_feature_names: (Optional) column names for input features.
        """
        # Set these before calling the base class constructor: when
        # `input_feature_names` is given, that constructor triggers
        # `_define_output_feature_names`, which reads all three.
        self._cluster_on = cluster_on
        self._percentiles = percentiles
        self._add_counts = add_counts
        # Base class constructor
        super().__init__(input_feature_names=input_feature_names)

    def _define_output_feature_names(
        self, input_feature_names: List[str]
    ) -> List[str]:
        """Store the column indices and return the output feature names.

        Args:
            input_feature_names: List of column names for the input data.

        Returns:
            Names of the columns in the clustered output.
        """
        (
            cluster_idx,
            summ_idx,
            new_feature_names,
        ) = self._get_indices_and_feature_names(
            input_feature_names, self._add_counts
        )
        # Kept on the instance for later use in `_construct_nodes`.
        self._cluster_indices = cluster_idx
        self._summarization_indices = summ_idx
        return new_feature_names

    def _get_indices_and_feature_names(
        self,
        feature_names: List[str],
        add_counts: bool,
    ) -> Tuple[List[int], List[int], List[str]]:
        """Identify column indices and build the output feature names.

        Args:
            feature_names: List of column names for the input data.
            add_counts: If True, "counts" is appended as the last name.

        Returns:
            The cluster indices and summarization indices reported by
            `identify_indices`, and the output feature names: the
            `cluster_on` names, then one `<feature>_pct<percentile>` name
            per summarized feature and percentile, then optionally
            "counts".
        """
        cluster_idx, summ_idx, summ_names = identify_indices(
            feature_names, self._cluster_on
        )
        # Copy so the `cluster_on` list held by the instance is not mutated.
        new_feature_names = list(self._cluster_on)
        new_feature_names.extend(
            f"{feature}_pct{pct}"
            for feature in summ_names
            for pct in self._percentiles
        )
        if add_counts:
            # add "counts" as the last feature
            new_feature_names.append("counts")
        return cluster_idx, summ_idx, new_feature_names

    def _construct_nodes(self, x: torch.Tensor) -> Data:
        """Summarize `x` into clusters and wrap the result in `Data`.

        Args:
            x: node features passed on from the forward call.

        Returns:
            graph: a `Data` object holding the percentile-summarized
                cluster features.

        Raises:
            AttributeError: If the feature names, and with them the
                column indices, have not been set yet.
        """
        # Check before converting, so a misconfigured instance fails
        # with an explanatory message rather than an empty exception.
        if not hasattr(self, "_summarization_indices"):
            message = (
                f"{self.__class__.__name__} was not instantiated with "
                "`input_feature_names` and has not been set later. "
                "Please instantiate this class with `input_feature_names` "
                "if you're using it outside `GraphDefinition`."
            )
            self.error(message)
            raise AttributeError(message)

        # Cast to Numpy
        x = x.numpy()
        # Construct clusters with percentile-summarized features
        array = cluster_summarize_with_percentiles(
            x=x,
            summarization_indices=self._summarization_indices,
            cluster_indices=self._cluster_indices,
            percentiles=self._percentiles,
            add_counts=self._add_counts,
        )
        return Data(x=torch.tensor(array))
Loading