PercentileCluster #616

Merged: 22 commits, Oct 19, 2023

Changes from all commits:
@@ -10,7 +10,7 @@ graph_definition:
node_definition:
arguments: {}
class_name: NodesAsPulses
node_feature_names: [dom_x, dom_y, dom_z, dom_time, charge, rde, pmt_area]
input_feature_names: [dom_x, dom_y, dom_z, dom_time, charge, rde, pmt_area]
class_name: KNNGraph
pulsemaps:
- SRTTWOfflinePulsesDC
2 changes: 1 addition & 1 deletion configs/datasets/test_data_sqlite.yml
@@ -10,7 +10,7 @@ graph_definition:
node_definition:
arguments: {}
class_name: NodesAsPulses
node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
input_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
class_name: KNNGraph
index_column: event_no
loss_weight_column: null
@@ -10,7 +10,7 @@ graph_definition:
node_definition:
arguments: {}
class_name: NodesAsPulses
node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
input_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
class_name: KNNGraph
pulsemaps:
- total
2 changes: 1 addition & 1 deletion configs/datasets/training_example_data_parquet.yml
@@ -10,7 +10,7 @@ graph_definition:
node_definition:
arguments: {}
class_name: NodesAsPulses
node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
input_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
class_name: KNNGraph
pulsemaps:
- total
2 changes: 1 addition & 1 deletion configs/datasets/training_example_data_sqlite.yml
@@ -10,7 +10,7 @@ graph_definition:
node_definition:
arguments: {}
class_name: NodesAsPulses
node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
input_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
class_name: KNNGraph
pulsemaps:
- total
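The same rename applies when a graph definition is built directly in Python rather than loaded from one of these YAML configs. A minimal sketch, mirroring the repository's own examples; the detector and feature names are illustrative:

```python
from graphnet.models.detector.icecube import IceCubeDeepCore
from graphnet.models.graphs import KNNGraph
from graphnet.models.graphs.nodes import NodesAsPulses

# Keyword renamed in this PR: `input_feature_names` (was `node_feature_names`).
graph_definition = KNNGraph(
    detector=IceCubeDeepCore(),
    node_definition=NodesAsPulses(),
    nb_nearest_neighbours=8,
    input_feature_names=[
        "dom_x", "dom_y", "dom_z", "dom_time", "charge", "rde", "pmt_area"
    ],
)
```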
2 changes: 1 addition & 1 deletion configs/models/dynedge_PID_classification_example.yml
@@ -25,7 +25,7 @@ arguments:
ModelConfig:
arguments: {}
class_name: NodesAsPulses
node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
input_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
class_name: KNNGraph
optimizer_class: '!class torch.optim.adam Adam'
optimizer_kwargs: {eps: 0.001, lr: 0.001}
2 changes: 1 addition & 1 deletion configs/models/dynedge_position_custom_scaling_example.yml
@@ -17,7 +17,7 @@ arguments:
ModelConfig:
arguments: {}
class_name: NodesAsPulses
node_feature_names: null
input_feature_names: null
class_name: KNNGraph
gnn:
ModelConfig:
44 changes: 0 additions & 44 deletions configs/models/dynedge_position_example.yml

This file was deleted.

2 changes: 1 addition & 1 deletion configs/models/example_direction_reconstruction_model.yml
@@ -13,7 +13,7 @@ arguments:
ModelConfig:
arguments: {}
class_name: NodesAsPulses
node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
input_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
class_name: KNNGraph
gnn:
ModelConfig:
2 changes: 1 addition & 1 deletion configs/models/example_energy_reconstruction_model.yml
@@ -25,7 +25,7 @@ arguments:
ModelConfig:
arguments: {}
class_name: NodesAsPulses
node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
input_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
class_name: KNNGraph
optimizer_class: '!class torch.optim.adam Adam'
optimizer_kwargs: {eps: 0.001, lr: 0.001}
@@ -25,7 +25,7 @@ arguments:
ModelConfig:
arguments: {}
class_name: NodesAsPulses
node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
input_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
class_name: KNNGraph
optimizer_class: '!class torch.optim.adam Adam'
optimizer_kwargs: {eps: 0.001, lr: 0.001}
2 changes: 1 addition & 1 deletion examples/02_data/04_ensemble_dataset.py
@@ -24,7 +24,7 @@
detector=IceCubeDeepCore(),
node_definition=NodesAsPulses(),
nb_nearest_neighbours=8,
node_feature_names=features,
input_feature_names=features,
)


2 changes: 1 addition & 1 deletion examples/05_pisa/02_make_pipeline_database.py
@@ -65,7 +65,7 @@ def main() -> None:
detector=IceCubeDeepCore(),
node_definition=NodesAsPulses(),
nb_nearest_neighbours=8,
node_feature_names=FEATURES.DEEPCORE,
input_feature_names=FEATURES.DEEPCORE,
)

# Remove `interaction_time` if it exists
4 changes: 2 additions & 2 deletions src/graphnet/data/dataset/dataset.py
@@ -629,8 +629,8 @@ def _create_graph(
# Construct graph data object
assert self._graph_definition is not None
graph = self._graph_definition(
node_features=node_features,
node_feature_names=self._features[
input_features=node_features,
input_feature_names=self._features[
1:
], # first entry is index column
truth_dicts=truth_dicts,
8 changes: 4 additions & 4 deletions src/graphnet/deployment/i3modules/graphnet_module.py
@@ -84,12 +84,12 @@ def _make_graph(
) -> Data:  # pylint: disable=invalid-name
"""Process Physics I3Frame into graph."""
# Extract features
node_features = self._extract_feature_array_from_frame(frame)
input_features = self._extract_feature_array_from_frame(frame)
# Prepare graph data
if len(node_features) > 0:
if len(input_features) > 0:
data = self._graph_definition(
node_features=node_features,
node_feature_names=self._features,
input_features=input_features,
input_feature_names=self._features,
)
return Batch.from_data_list([data])
else:
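These call sites show the other half of the rename: `GraphDefinition` is now invoked with `input_features` / `input_feature_names` instead of `node_features` / `node_feature_names`. A self-contained sketch of that call; the pulse array is synthetic and purely illustrative:

```python
import numpy as np

from graphnet.models.detector.icecube import IceCubeDeepCore
from graphnet.models.graphs import KNNGraph
from graphnet.models.graphs.nodes import NodesAsPulses

features = ["dom_x", "dom_y", "dom_z", "dom_time", "charge", "rde", "pmt_area"]
graph_definition = KNNGraph(
    detector=IceCubeDeepCore(),
    node_definition=NodesAsPulses(),
    nb_nearest_neighbours=8,
    input_feature_names=features,
)

# Synthetic pulse array: one row per pulse, one column per feature.
input_features = np.random.rand(128, len(features))

# Renamed keyword arguments of `GraphDefinition.forward`.
data = graph_definition(
    input_features=input_features,
    input_feature_names=features,
)
```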
78 changes: 45 additions & 33 deletions src/graphnet/models/graphs/graph_definition.py
Review comment (Collaborator):

It makes me a little uneasy that in the forward call the function variable node_feature_names might differ from the class-instantiated self._node_feature_names after the _node_definition call on line 147. While I do believe this is as intended, it might be quite confusing when revisiting the code later; maybe consider a renaming.

@@ -26,7 +26,7 @@ def __init__(
detector: Detector,
node_definition: NodeDefinition = NodesAsPulses(),
edge_definition: Optional[EdgeDefinition] = None,
node_feature_names: Optional[List[str]] = None,
input_feature_names: Optional[List[str]] = None,
dtype: Optional[torch.dtype] = torch.float,
perturbation_dict: Optional[Dict[str, float]] = None,
seed: Optional[Union[int, Generator]] = None,
@@ -44,7 +44,10 @@ def __init__(
detector: The corresponding ´Detector´ representing the data.
node_definition: Definition of nodes. Defaults to NodesAsPulses.
edge_definition: Definition of edges. Defaults to None.
node_feature_names: Names of node feature columns. Defaults to None
input_feature_names: Names of each column in expected input data
that will be built into a graph. If not provided,
it is automatically assumed that all features in `Detector` is
used.
dtype: data type used for node features. e.g. ´torch.float´
perturbation_dict: Dictionary mapping a feature name to a standard
deviation according to which the values for this
@@ -62,25 +65,30 @@ def __init__(
self._node_definition = node_definition
self._perturbation_dict = perturbation_dict

if node_feature_names is None:
if input_feature_names is None:
# Assume all features in Detector is used.
node_feature_names = list(self._detector.feature_map().keys()) # type: ignore
self._node_feature_names = node_feature_names
input_feature_names = list(self._detector.feature_map().keys()) # type: ignore
self._input_feature_names = input_feature_names

# Set input data column names for node definition
self._node_definition.set_output_feature_names(
self._input_feature_names
)

# Set data type
self.to(dtype)

# Set Input / Output dimensions
self._node_definition.set_number_of_inputs(
node_feature_names=node_feature_names
input_feature_names=input_feature_names
)
self.nb_inputs = len(self._node_feature_names)
self.nb_inputs = len(self._input_feature_names)
self.nb_outputs = self._node_definition.nb_outputs

# Set perturbation_cols if needed
if isinstance(self._perturbation_dict, dict):
self._perturbation_cols = [
self._node_feature_names.index(key)
self._input_feature_names.index(key)
for key in self._perturbation_dict.keys()
]
if seed is not None:
@@ -97,8 +105,8 @@ def __init__(

def forward( # type: ignore
self,
node_features: np.ndarray,
node_feature_names: List[str],
input_features: np.ndarray,
input_feature_names: List[str],
truth_dicts: Optional[List[Dict[str, Any]]] = None,
custom_label_functions: Optional[Dict[str, Callable[..., Any]]] = None,
loss_weight_column: Optional[str] = None,
@@ -109,8 +117,8 @@ def forward( # type: ignore
"""Construct graph as ´Data´ object.

Args:
node_features: node features for graph. Shape ´[num_nodes, d]´
node_feature_names: name of each column. Shape ´[,d]´.
input_features: Input features for graph construction. Shape ´[num_rows, d]´
input_feature_names: name of each column. Shape ´[,d]´.
truth_dicts: Dictionary containing truth labels.
custom_label_functions: Custom label functions. See https://github.com/graphnet-team/graphnet/blob/main/GETTING_STARTED.md#adding-custom-truth-labels.
loss_weight_column: Name of column that holds loss weight.
@@ -126,23 +134,27 @@ def forward( # type: ignore
"""
# Checks
self._validate_input(
node_features=node_features, node_feature_names=node_feature_names
input_features=input_features,
input_feature_names=input_feature_names,
)

# Gaussian perturbation of each column if perturbation dict is given
node_features = self._perturb_input(node_features)
input_features = self._perturb_input(input_features)

# Transform to pytorch tensor
node_features = torch.tensor(node_features, dtype=self.dtype)
input_features = torch.tensor(input_features, dtype=self.dtype)

# Standardize / Scale node features
node_features = self._detector(node_features, node_feature_names)
input_features = self._detector(input_features, input_feature_names)

# Create graph & get new node feature names
graph, node_feature_names = self._node_definition(input_features)

# Create graph
graph = self._node_definition(node_features)
# Enforce dtype
graph.x = graph.x.type(self.dtype)

# Attach number of pulses as static attribute.
graph.n_pulses = torch.tensor(len(node_features), dtype=torch.int32)
graph.n_pulses = torch.tensor(len(input_features), dtype=torch.int32)

# Assign edges
if self._edge_definition is not None:
@@ -186,40 +198,40 @@ def forward( # type: ignore
return graph

def _validate_input(
self, node_features: np.array, node_feature_names: List[str]
self, input_features: np.array, input_feature_names: List[str]
) -> None:
# node feature matrix dimension check
assert node_features.shape[1] == len(node_feature_names)
assert input_features.shape[1] == len(input_feature_names)

# check that provided features for input is the same that the ´Graph´
# was instantiated with.
assert len(node_feature_names) == len(
self._node_feature_names
), f"""Input features ({node_feature_names}) is not what
assert len(input_feature_names) == len(
self._input_feature_names
), f"""Input features ({input_feature_names}) is not what
{self.__class__.__name__} was instatiated
with ({self._node_feature_names})""" # noqa
for idx in range(len(node_feature_names)):
with ({self._input_feature_names})""" # noqa
for idx in range(len(input_feature_names)):
assert (
node_feature_names[idx] == self._node_feature_names[idx]
input_feature_names[idx] == self._input_feature_names[idx]
), f""" Order of node features in data
are not the same as expected. Got {node_feature_names}
vs. {self._node_feature_names}""" # noqa
are not the same as expected. Got {input_feature_names}
vs. {self._input_feature_names}""" # noqa

def _perturb_input(self, node_features: np.ndarray) -> np.ndarray:
def _perturb_input(self, input_features: np.ndarray) -> np.ndarray:
if isinstance(self._perturbation_dict, dict):
self.warning_once(
f"""Will randomly perturb
{list(self._perturbation_dict.keys())}
using stds {self._perturbation_dict.values()}""" # noqa
)
perturbed_features = self.rng.normal(
loc=node_features[:, self._perturbation_cols],
loc=input_features[:, self._perturbation_cols],
scale=np.array(
list(self._perturbation_dict.values()), dtype=float
),
)
node_features[:, self._perturbation_cols] = perturbed_features
return node_features
input_features[:, self._perturbation_cols] = perturbed_features
return input_features

def _add_loss_weights(
self,
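A behavioural change in `forward` worth noting: the node definition now returns both the graph and the names of the node features it produced, rather than the graph alone, so definitions that change the feature set (such as the new `PercentileClusters`) can report their output columns. A minimal sketch of that contract, assuming the two setters are called as in `GraphDefinition.__init__` above:

```python
import torch

from graphnet.models.graphs.nodes import NodesAsPulses

input_feature_names = ["dom_x", "dom_y", "dom_z", "dom_time"]

# Mirror the calls made in `GraphDefinition.__init__`.
node_definition = NodesAsPulses()
node_definition.set_output_feature_names(input_feature_names)
node_definition.set_number_of_inputs(input_feature_names=input_feature_names)

# Ten synthetic pulses with four features each.
x = torch.rand(10, len(input_feature_names))

# New contract: the node definition returns the graph *and* the node feature
# names (for `NodesAsPulses` these should simply mirror the input names).
graph, node_feature_names = node_definition(x)
```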
6 changes: 3 additions & 3 deletions src/graphnet/models/graphs/graphs.py
@@ -17,7 +17,7 @@ def __init__(
self,
detector: Detector,
node_definition: NodeDefinition = NodesAsPulses(),
node_feature_names: Optional[List[str]] = None,
input_feature_names: Optional[List[str]] = None,
dtype: Optional[torch.dtype] = torch.float,
perturbation_dict: Optional[Dict[str, float]] = None,
seed: Optional[Union[int, Generator]] = None,
@@ -29,7 +29,7 @@
Args:
detector: Detector that represents your data.
node_definition: Definition of nodes in the graph.
node_feature_names: Name of node features.
input_feature_names: Name of input feature columns.
dtype: data type for node features.
perturbation_dict: Dictionary mapping a feature name to a standard
deviation according to which the values for this
@@ -50,7 +50,7 @@
columns=columns,
),
dtype=dtype,
node_feature_names=node_feature_names,
input_feature_names=input_feature_names,
perturbation_dict=perturbation_dict,
seed=seed,
)
2 changes: 1 addition & 1 deletion src/graphnet/models/graphs/nodes/__init__.py
@@ -5,4 +5,4 @@
and their features.
"""

from .nodes import NodeDefinition, NodesAsPulses
from .nodes import NodeDefinition, NodesAsPulses, PercentileClusters
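Finally, the export that gives the PR its name: `PercentileClusters` becomes importable alongside `NodesAsPulses`. A hedged usage sketch follows; the constructor arguments shown (`cluster_on`, `percentiles`) are assumptions not visible in this excerpt, so check the class docstring for the exact signature. The idea is to cluster pulses on the spatial coordinates and summarise the remaining features per cluster by percentiles:

```python
from graphnet.models.detector.icecube import IceCubeDeepCore
from graphnet.models.graphs import KNNGraph
from graphnet.models.graphs.nodes import PercentileClusters

features = ["dom_x", "dom_y", "dom_z", "dom_time", "charge"]

# Assumed arguments: cluster pulses per (dom_x, dom_y, dom_z) position and
# summarise `dom_time` and `charge` by the given percentiles within each cluster.
node_definition = PercentileClusters(
    cluster_on=["dom_x", "dom_y", "dom_z"],
    percentiles=[10, 50, 90],
    input_feature_names=features,
)

graph_definition = KNNGraph(
    detector=IceCubeDeepCore(),
    node_definition=node_definition,
    nb_nearest_neighbours=8,
    input_feature_names=features,
)
```

With a node definition like this, the clustered output columns generally differ from the raw input columns, which is exactly what the two-value return from the node definition in `GraphDefinition.forward` accommodates.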