
Commit

Merge pull request graphnet-team#646 from RasmusOrsoe/rename-_gnn-to-_architecture

Change argument `gnn` to `architecture`
RasmusOrsoe authored Dec 15, 2023
2 parents d06882b + 9871fb8 commit 62c5340
Showing 25 changed files with 91 additions and 33 deletions.
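For orientation, the user-facing effect of this rename (per the `StandardModel` changes further down) is a keyword swap in the constructor call. A minimal before/after sketch, where `graph_definition`, `backbone`, and `task` are assumed to be built as in the examples below:

```python
# Before: the architecture module was passed via `gnn`.
model = StandardModel(
    graph_definition=graph_definition,
    gnn=backbone,  # still accepted for now, but triggers a deprecation warning
    tasks=[task],
)

# After: the same module is passed via the new keyword.
model = StandardModel(
    graph_definition=graph_definition,
    backbone=backbone,
    tasks=[task],
)
```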
8 changes: 4 additions & 4 deletions GETTING_STARTED.md
@@ -427,20 +427,20 @@ graph_definition = KNNGraph(
    node_definition=NodesAsPulses(),
    nb_nearest_neighbours=8,
)
-gnn = DynEdge(
+architecture = DynEdge(
    nb_inputs=detector.nb_outputs,
    global_pooling_schemes=["min", "max", "mean"],
)
task = ZenithReconstructionWithKappa(
-    hidden_size=gnn.nb_outputs,
+    hidden_size=architecture.nb_outputs,
    target_labels="injection_zenith",
    loss_function=VonMisesFisher2DLoss(),
)
# Construct the Model
model = StandardModel(
    graph_definition=graph_definition,
-    gnn=gnn,
+    architecture=architecture,
    tasks=[task],
)
```
@@ -479,7 +479,7 @@ You can find several pre-defined `ModelConfig`'s under `graphnet/configs/models`

```yml
arguments:
-  gnn:
+  architecture:
    ModelConfig:
      arguments:
        add_global_variables_after_pooling: false
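For context on the config files touched below: such a YAML `ModelConfig` can be turned back into a model roughly as follows. This is a sketch assuming the `ModelConfig.load` / `Model.from_config` helpers of graphnet's configuration system; the exact call signature may differ.

```python
from graphnet.models import Model
from graphnet.utilities.config import ModelConfig

# Rebuild a model from one of the updated YAML configs.
config = ModelConfig.load("configs/models/dynedge_PID_classification_example.yml")
model = Model.from_config(config, trust=True)  # trust=True permits code objects (e.g. lambdas) from the config
```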
2 changes: 1 addition & 1 deletion configs/models/dynedge_PID_classification_example.yml
@@ -1,5 +1,5 @@
arguments:
-  gnn:
+  backbone:
    ModelConfig:
      arguments:
        add_global_variables_after_pooling: false
2 changes: 1 addition & 1 deletion configs/models/dynedge_position_custom_scaling_example.yml
@@ -19,7 +19,7 @@ arguments:
            class_name: NodesAsPulses
        input_feature_names: null
      class_name: KNNGraph
-  gnn:
+  backbone:
    ModelConfig:
      arguments:
        add_global_variables_after_pooling: false
2 changes: 1 addition & 1 deletion configs/models/example_direction_reconstruction_model.yml
@@ -15,7 +15,7 @@ arguments:
            class_name: NodesAsPulses
        input_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
      class_name: KNNGraph
-  gnn:
+  backbone:
    ModelConfig:
      arguments:
        add_global_variables_after_pooling: false
2 changes: 1 addition & 1 deletion configs/models/example_energy_reconstruction_model.yml
@@ -1,5 +1,5 @@
arguments:
-  gnn:
+  backbone:
    ModelConfig:
      arguments:
        add_global_variables_after_pooling: false
@@ -1,5 +1,5 @@
arguments:
-  gnn:
+  backbone:
    ModelConfig:
      arguments:
        add_global_variables_after_pooling: false
6 changes: 3 additions & 3 deletions examples/04_training/01_train_dynedge.py
@@ -96,20 +96,20 @@ def main(

    # Building model

-    gnn = DynEdge(
+    backbone = DynEdge(
        nb_inputs=graph_definition.nb_outputs,
        global_pooling_schemes=["min", "max", "mean", "sum"],
    )
    task = EnergyReconstruction(
-        hidden_size=gnn.nb_outputs,
+        hidden_size=backbone.nb_outputs,
        target_labels=config["target"],
        loss_function=LogCoshLoss(),
        transform_prediction_and_target=lambda x: torch.log10(x),
        transform_inference=lambda x: torch.pow(10, x),
    )
    model = StandardModel(
        graph_definition=graph_definition,
-        gnn=gnn,
+        backbone=backbone,
        tasks=[task],
        optimizer_class=Adam,
        optimizer_kwargs={"lr": 1e-03, "eps": 1e-03},
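One note on the task configuration above: `transform_prediction_and_target` moves training into log10(energy) space, and `transform_inference` maps network outputs back to energy, so the two are exact inverses. A tiny sketch with made-up values:

```python
import torch

energy = torch.tensor([1.0e3, 1.0e5])    # illustrative target energies
log_energy = torch.log10(energy)         # space the loss is computed in
recovered = torch.pow(10, log_energy)    # what inference returns
assert torch.allclose(recovered, energy)
```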
6 changes: 3 additions & 3 deletions examples/04_training/02_train_tito_model.py
@@ -101,7 +101,7 @@ def main(
    )

    # Building model
-    gnn = DynEdgeTITO(
+    backbone = DynEdgeTITO(
        nb_inputs=graph_definition.nb_outputs,
        features_subset=[0, 1, 2, 3],
        dyntrans_layer_sizes=[(256, 256), (256, 256), (256, 256), (256, 256)],
@@ -110,13 +110,13 @@ def main(
        use_post_processing_layers=True,
    )
    task = DirectionReconstructionWithKappa(
-        hidden_size=gnn.nb_outputs,
+        hidden_size=backbone.nb_outputs,
        target_labels=config["target"],
        loss_function=VonMisesFisher3DLoss(),
    )
    model = StandardModel(
        graph_definition=graph_definition,
-        gnn=gnn,
+        backbone=backbone,
        tasks=[task],
        optimizer_class=Adam,
        optimizer_kwargs={"lr": 1e-03, "eps": 1e-03},
2 changes: 1 addition & 1 deletion src/graphnet/deployment/i3modules/graphnet_module.py
@@ -416,7 +416,7 @@ def _construct_prediction_map(
        Arguments:
            frame: I3Frame (physics)
-            predictions: predictions from GNN
+            predictions: predictions from Model.
        Returns:
            predictions_map: a pulsemap from predictions
11 changes: 11 additions & 0 deletions src/graphnet/models/model.py
@@ -9,6 +9,7 @@
import torch
from torch_geometric.data import Data

+from graphnet.utilities.deprecation_tools import rename_state_dict_entries
from graphnet.utilities.logging import Logger
from graphnet.utilities.config import (
    Configurable,
@@ -63,6 +64,16 @@ def load_state_dict(
            state_dict = torch.load(path)
        else:
            state_dict = path
+
+        # DEPRECATION UTILITY: REMOVE AT 2.0 LAUNCH
+        # See https://github.com/graphnet-team/graphnet/issues/647
+        state_dict, state_dict_altered = rename_state_dict_entries(
+            state_dict=state_dict, old_phrase="_gnn", new_phrase="backbone"
+        )
+        if state_dict_altered:
+            self.warning(
+                "DeprecationWarning: State dicts with `_gnn` entries will be deprecated in GraphNeT 2.0"
+            )
        return super().load_state_dict(state_dict, **kargs)

@classmethod
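The practical upshot of the `load_state_dict` change above: checkpoints written before this rename, whose keys start with `_gnn`, can still be loaded transparently. A sketch with a hypothetical checkpoint path; the model is assumed to be assembled as in the training examples above:

```python
model = StandardModel(
    graph_definition=graph_definition,
    backbone=backbone,
    tasks=[task],
)

# Hypothetical file saved by a pre-rename model, i.e. containing `_gnn.*` keys.
model.load_state_dict("old_model_state_dict.pth")
# Keys are renamed `_gnn` -> `backbone` on the fly, and a deprecation
# message is logged via `self.warning(...)`.
```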
@@ -1,5 +1,5 @@
arguments:
-  gnn:
+  backbone:
    ModelConfig:
      arguments: {add_global_variables_after_pooling: false, dynedge_layer_sizes: null,
        features_subset: null, global_pooling_schemes: null, nb_inputs: 14, nb_neighbours: 8,
Binary file not shown.
@@ -1,5 +1,5 @@
arguments:
-  gnn:
+  backbone:
    ModelConfig:
      arguments:
        add_global_variables_after_pooling: false
Binary file not shown.
@@ -1,5 +1,5 @@
arguments:
-  gnn:
+  backbone:
    ModelConfig:
      arguments:
        add_global_variables_after_pooling: false
Binary file not shown.
@@ -1,5 +1,5 @@
arguments:
-  gnn:
+  backbone:
    ModelConfig:
      arguments:
        add_global_variables_after_pooling: false
Binary file not shown.
@@ -1,5 +1,5 @@
arguments:
-  gnn:
+  backbone:
    ModelConfig:
      arguments:
        add_global_variables_after_pooling: false
Binary file not shown.
@@ -1,5 +1,5 @@
arguments:
-  gnn:
+  backbone:
    ModelConfig:
      arguments:
        add_global_variables_after_pooling: false
Binary file not shown.
28 changes: 21 additions & 7 deletions src/graphnet/models/standard_model.py
@@ -25,14 +25,15 @@ class StandardModel(Model):
"""Main class for standard models in graphnet.
This class chains together the different elements of a complete GNN- based
model (detector read-in, GNN architecture, and task-specific read-outs).
model (detector read-in, GNN backbone, and task-specific read-outs).
"""

    def __init__(
        self,
        *,
        graph_definition: GraphDefinition,
-        gnn: GNN,
+        backbone: GNN = None,
+        gnn: Optional[GNN] = None,
        tasks: Union[StandardLearnedTask, List[StandardLearnedTask]],
        optimizer_class: Type[torch.optim.Optimizer] = Adam,
        optimizer_kwargs: Optional[Dict] = None,
@@ -50,11 +51,24 @@ def __init__(
        assert isinstance(tasks, (list, tuple))
        assert all(isinstance(task, StandardLearnedTask) for task in tasks)
        assert isinstance(graph_definition, GraphDefinition)
-        assert isinstance(gnn, GNN)
+
+        # deprecation warnings
+        if (backbone is None) & (gnn is not None):
+            backbone = gnn
+            # Code continues after warning
+            self.warning(
+                """DeprecationWarning: Argument `gnn` will be deprecated in GraphNeT 2.0. Please use `backbone` instead."""
+            )
+        elif (backbone is None) & (gnn is None):
+            # Code stops
+            raise TypeError(
+                "__init__() missing 1 required keyword-only argument: 'backbone'"
+            )
+        assert isinstance(backbone, GNN)

        # Member variable(s)
        self._graph_definition = graph_definition
-        self._gnn = gnn
+        self.backbone = backbone
        self._tasks = ModuleList(tasks)
        self._optimizer_class = optimizer_class
        self._optimizer_kwargs = optimizer_kwargs or dict()
@@ -63,7 +77,7 @@ def __init__(
        self._scheduler_config = scheduler_config or dict()

        # set dtype of GNN from graph_definition
-        self._gnn.type(self._graph_definition._dtype)
+        self.backbone.type(self._graph_definition._dtype)

@staticmethod
def _construct_trainer(
@@ -226,7 +240,7 @@ def forward(
            data = [data]
        x_list = []
        for d in data:
-            x = self._gnn(d)
+            x = self.backbone(d)
            x_list.append(x)
        x = torch.cat(x_list, dim=0)

@@ -467,7 +481,7 @@ def _create_default_callbacks(
                    save_top_k=1,
                    monitor="val_loss",
                    mode="min",
-                    filename=f"{self._gnn.__class__.__name__}"
+                    filename=f"{self.backbone.__class__.__name__}"
                    + "-{epoch}-{val_loss:.2f}-{train_loss:.2f}",
                )
            )
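To spell out the new constructor logic shown above: `backbone` is the supported keyword, passing `gnn` still works but logs a deprecation warning, and passing neither raises a `TypeError`. A sketch, reusing objects built as in the training examples earlier in this diff:

```python
# Preferred going forward:
model = StandardModel(graph_definition=graph_definition, backbone=backbone, tasks=[task])

# Deprecated but still functional; logs the DeprecationWarning message above:
model = StandardModel(graph_definition=graph_definition, gnn=backbone, tasks=[task])

# Neither keyword given -> TypeError, mimicking a missing required keyword-only argument:
# StandardModel(graph_definition=graph_definition, tasks=[task])
```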
33 changes: 33 additions & 0 deletions src/graphnet/utilities/deprecation_tools.py
@@ -0,0 +1,33 @@
"""Utility functions for handling deprecation transitions."""
from typing import Dict, Tuple
from copy import deepcopy
from torch import Tensor


def rename_state_dict_entries(
state_dict: Dict[str, Tensor], old_phrase: str, new_phrase: str
) -> Tuple[Dict[str, Tensor], bool]:
"""Replace `old_phrase` in state dict fields with `new_phrase`.
Returned state dict is a deepcopy of the input.
Args:
state_dict: The state dict whos fields need renaming.
old_phrase: Phrase in state dict field that needs to be replaced.
new_phrase: Phrase to add in place of `old_phrase` in state dict.
"""
assert isinstance(old_phrase, str)
assert isinstance(new_phrase, str)

# Make a carbon-copy
new_state_dict = deepcopy(state_dict)

# Replace old entries in copy
state_dict_altered = False
for key in state_dict.keys():
if old_phrase in key:
new_key = key.replace(old_phrase, new_phrase)
new_state_dict[new_key] = new_state_dict.pop(key)
state_dict_altered = True

return new_state_dict, state_dict_altered
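A quick illustration of the new helper on a toy state dict (the keys are made up for the example):

```python
import torch

from graphnet.utilities.deprecation_tools import rename_state_dict_entries

old_state_dict = {
    "_gnn._conv_layers.0.weight": torch.zeros(3, 3),
    "_tasks.0._affine.weight": torch.zeros(1, 3),
}
new_state_dict, altered = rename_state_dict_entries(
    state_dict=old_state_dict, old_phrase="_gnn", new_phrase="backbone"
)

assert altered is True
assert "backbone._conv_layers.0.weight" in new_state_dict
assert "_tasks.0._affine.weight" in new_state_dict  # unrelated keys are left untouched
```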
8 changes: 4 additions & 4 deletions tests/utilities/test_model_config.py
@@ -51,19 +51,19 @@ def test_complete_model_config(path: str = "/tmp/complete_model.yml") -> None:
        nb_nearest_neighbours=8,
        input_feature_names=FEATURES.DEEPCORE,
    )
-    gnn = DynEdge(
+    backbone = DynEdge(
        nb_inputs=graph_definition.nb_outputs,
        global_pooling_schemes=["min", "max", "mean", "sum"],
    )
    task = EnergyReconstruction(
-        hidden_size=gnn.nb_outputs,
+        hidden_size=backbone.nb_outputs,
        target_labels="energy",
        loss_function=LogCoshLoss(),
        transform_prediction_and_target=lambda x: torch.log10(x),
    )
    model = StandardModel(
        graph_definition=graph_definition,
-        gnn=gnn,
+        backbone=backbone,
        tasks=[task],
        optimizer_class=Adam,
        optimizer_kwargs={"lr": 1e-03, "eps": 1e-03},
@@ -102,7 +102,7 @@ def test_complete_model_config(path: str = "/tmp/complete_model.yml") -> None:
"scheduler_class",
"scheduler_kwargs",
"scheduler_config",
"gnn",
"backbone",
]:
assert (
constructed_model.config.arguments[key]
Expand Down
