Skip to content

Commit

Permalink
Merge branch 'main' into resnet_gnn
Browse files Browse the repository at this point in the history
  • Loading branch information
allaffa authored Apr 14, 2023
2 parents f707c2a + c78e7ac commit 1a42e26
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 22 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ There are many options for HydraGNN; the dataset and model type are particularly
important:
- `["Verbosity"]["level"]`: `0`, `1`, `2`, `3`, `4`
- `["Dataset"]["name"]`: `CuAu_32atoms`, `FePt_32atoms`, `FeSi_1024atoms`
- `["NeuralNetwork"]["Architecture"]["model_type"]`: `PNA`, `MFC`, `GIN`, `GAT`, `CGCNN`
- `["NeuralNetwork"]["Architecture"]["model_type"]`: `PNA`, `MFC`, `GIN`, `GAT`, `CGCNN`, `SchNet`
### Citations
"HydraGNN: Distributed PyTorch implementation of multi-headed graph convolutional neural networks", Copyright ID#: 81929619
Expand Down
32 changes: 17 additions & 15 deletions hydragnn/models/Base.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,12 @@ def _init_conv(self):
self.convs.append(conv)
self.batch_norms.append(BatchNorm(self.hidden_dim))

def _conv_args(self, data):
conv_args = {"edge_index": data.edge_index}
if (data.edge_attr is not None) and (self.use_edge_attr):
conv_args.update({"edge_attr": data.edge_attr})
return conv_args

def _freeze_conv(self):
for module in [self.convs, self.batch_norms]:
for layer in module:
Expand Down Expand Up @@ -278,26 +284,20 @@ def _multihead(self):
self.heads_NN.append(head_NN)

def forward(self, data):
x, edge_index, batch = (
data.x,
data.edge_index,
data.batch,
)
use_edge_attr = False
edge_attr = None
if (data.edge_attr is not None) and (self.use_edge_attr):
use_edge_attr = True
edge_attr = data.edge_attr
x = data.x

### encoder part ####
x = self.conv_shared(x=x, edge_index=edge_index, edge_attr=edge_attr)
conv_args = self._conv_args(data)
for conv, batch_norm in zip(self.convs, self.batch_norms):
c = conv(x=x, **conv_args)
x = F.relu(batch_norm(c))

#### multi-head decoder part####
# shared dense layers for graph level output
if batch is None:
if data.batch is None:
x_graph = x.mean(dim=0, keepdim=True)
else:
x_graph = global_mean_pool(x, batch.to(x.device))
x_graph = global_mean_pool(x, data.batch.to(x.device))
outputs = []
for head_dim, headloc, type_head in zip(
self.head_dims, self.heads_NN, self.head_type
Expand All @@ -308,9 +308,11 @@ def forward(self, data):
else:
if self.node_NN_type == "conv":
for conv, batch_norm in zip(headloc[0::2], headloc[1::2]):
x_node = F.relu(batch_norm(conv(x=x, edge_index=edge_index)))
x_node = F.relu(
batch_norm(conv(x=x, edge_index=data.edge_index))
)
else:
x_node = headloc(x=x, batch=batch)
x_node = headloc(x=x, batch=data.batch)
outputs.append(x_node)
return outputs

Expand Down
79 changes: 79 additions & 0 deletions hydragnn/models/SCFStack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
##############################################################################
# Copyright (c) 2021, Oak Ridge National Laboratory #
# All rights reserved. #
# #
# This file is part of HydraGNN and is distributed under a BSD 3-clause #
# license. For the licensing terms see the LICENSE file in the top-level #
# directory. #
# #
# SPDX-License-Identifier: BSD-3-Clause #
##############################################################################

from typing import Optional

import torch
from torch.nn import Linear, Sequential
from torch_geometric.nn.models.schnet import (
CFConv,
GaussianSmearing,
RadiusInteractionGraph,
ShiftedSoftplus,
)

from .Base import Base


class SCFStack(Base):
def __init__(
self,
num_filters: int,
num_gaussians: list,
radius: float,
*args,
max_neighbours: Optional[int] = None,
**kwargs,
):
self.radius = radius
self.max_neighbours = max_neighbours
self.num_filters = num_filters
self.num_gaussians = num_gaussians

super().__init__(*args, **kwargs)

self.distance_expansion = GaussianSmearing(0.0, radius, num_gaussians)
self.interaction_graph = RadiusInteractionGraph(radius, max_neighbours)

pass

def get_conv(self, input_dim, output_dim):
mlp = Sequential(
Linear(self.num_gaussians, self.num_filters),
ShiftedSoftplus(),
Linear(self.num_filters, self.num_filters),
)

return CFConv(
in_channels=input_dim,
out_channels=output_dim,
nn=mlp,
num_filters=self.num_filters,
cutoff=self.radius,
)

def _conv_args(self, data):
if (data.edge_attr is not None) and (self.use_edge_attr):
edge_index = data.edge_index
edge_weight = data.edge_attr.norm(dim=-1)
else:
edge_index, edge_weight = self.interaction_graph(data.pos, data.batch)

conv_args = {
"edge_index": edge_index,
"edge_weight": edge_weight,
"edge_attr": self.distance_expansion(edge_weight),
}

return conv_args

def __str__(self):
return "SCFStack"
30 changes: 29 additions & 1 deletion hydragnn/models/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from hydragnn.models.MFCStack import MFCStack
from hydragnn.models.CGCNNStack import CGCNNStack
from hydragnn.models.SAGEStack import SAGEStack
from hydragnn.models.SCFStack import SCFStack

from hydragnn.utils.distributed import get_device
from hydragnn.utils.print_utils import print_distributed
Expand All @@ -30,7 +31,6 @@ def create_model_config(
verbosity: int = 0,
use_gpu: bool = True,
):

return create_model(
config["Architecture"]["model_type"],
config["Architecture"]["input_dim"],
Expand All @@ -48,6 +48,9 @@ def create_model_config(
config["Architecture"]["edge_dim"],
config["Architecture"]["pna_deg"],
config["Architecture"]["skip_connection"],
config["Architecture"]["num_gaussians"],
config["Architecture"]["num_filters"],
config["Architecture"]["radius"],
verbosity,
use_gpu,
)
Expand All @@ -71,6 +74,9 @@ def create_model(
edge_dim: int = None,
pna_deg: torch.tensor = None,
skip_connection: bool = False,
num_gaussians: int = None,
num_filters: int = None,
radius: float = None,
verbosity: int = 0,
use_gpu: bool = True,
):
Expand Down Expand Up @@ -185,6 +191,28 @@ def create_model(
skip_connection=skip_connection,
)

elif model_type == "SchNet":
assert num_gaussians is not None, "SchNet requires num_guassians input."
assert num_filters is not None, "SchNet requires num_filters input."
assert radius is not None, "SchNet requires radius input."
model = SCFStack(
num_gaussians,
num_filters,
radius,
input_dim,
hidden_dim,
output_dim,
output_type,
output_heads,
loss_function_type,
max_neighbours=max_neighbours,
loss_weights=task_weights,
freeze_conv=freeze_conv,
initial_bias=initial_bias,
num_conv_layers=num_conv_layers,
num_nodes=num_nodes,
)

else:
raise ValueError("Unknown model_type: {0}".format(model_type))

Expand Down
11 changes: 8 additions & 3 deletions hydragnn/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ def update_config(config, train_loader, val_loader, test_loader):
else:
config["NeuralNetwork"]["Architecture"]["pna_deg"] = None

if "radius" not in config["NeuralNetwork"]["Architecture"]:
config["NeuralNetwork"]["Architecture"]["radius"] = None
if "num_gaussians" not in config["NeuralNetwork"]["Architecture"]:
config["NeuralNetwork"]["Architecture"]["num_gaussians"] = None
if "num_filters" not in config["NeuralNetwork"]["Architecture"]:
config["NeuralNetwork"]["Architecture"]["num_filters"] = None

config["NeuralNetwork"]["Architecture"] = update_config_edge_dim(
config["NeuralNetwork"]["Architecture"]
)
Expand All @@ -68,9 +75,8 @@ def update_config(config, train_loader, val_loader, test_loader):


def update_config_edge_dim(config):

config["edge_dim"] = None
edge_models = ["PNA", "CGCNN"]
edge_models = ["PNA", "CGCNN", "SchNet"]
if "edge_features" in config and config["edge_features"]:
assert (
config["model_type"] in edge_models
Expand All @@ -84,7 +90,6 @@ def update_config_edge_dim(config):


def check_output_dim_consistent(data, config):

output_type = config["NeuralNetwork"]["Variables_of_interest"]["type"]
output_index = config["NeuralNetwork"]["Variables_of_interest"]["output_index"]
if hasattr(data, "y_loc"):
Expand Down
2 changes: 2 additions & 0 deletions tests/inputs/ci.json
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
"skip_connection": false,
"radius": 2.0,
"max_neighbours": 100,
"num_gaussians": 50,
"num_filters": 126,
"periodic_boundary_conditions": false,
"hidden_dim": 8,
"num_conv_layers": 2,
Expand Down
2 changes: 2 additions & 0 deletions tests/inputs/ci_multihead.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
"skip_connection": false,
"radius": 2.0,
"max_neighbours": 100,
"num_gaussians": 50,
"num_filters": 126,
"periodic_boundary_conditions": false,
"hidden_dim": 8,
"num_conv_layers": 2,
Expand Down
8 changes: 6 additions & 2 deletions tests/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import hydragnn, tests


# Main unit test function called by pytest wrappers.
def unittest_train_model(
model_type, skip_connection, ci_input, use_lengths, overwrite_data=False
Expand Down Expand Up @@ -133,6 +134,7 @@ def unittest_train_model(
"GIN": [0.25, 0.20],
"GAT": [0.60, 0.70],
"CGCNN": [0.50, 0.40],
"SchNet": [0.20, 0.20],
}
if use_lengths and ("vector" not in ci_input):
thresholds["CGCNN"] = [0.175, 0.175]
Expand Down Expand Up @@ -174,15 +176,17 @@ def unittest_train_model(


# Test across all models with both single/multihead
@pytest.mark.parametrize("model_type", ["SAGE", "GIN", "GAT", "MFC", "PNA", "CGCNN"])
@pytest.mark.parametrize(
"model_type", ["SAGE", "GIN", "GAT", "MFC", "PNA", "CGCNN", "SchNet"]
)
@pytest.mark.parametrize("skip_connection", [False, True])
@pytest.mark.parametrize("ci_input", ["ci.json", "ci_multihead.json"])
def pytest_train_model(model_type, skip_connection, ci_input, overwrite_data=False):
unittest_train_model(model_type, skip_connection, ci_input, False, overwrite_data)


# Test only models
@pytest.mark.parametrize("model_type", ["PNA", "CGCNN"])
@pytest.mark.parametrize("model_type", ["PNA", "CGCNN", "SchNet"])
@pytest.mark.parametrize("skip_connection", [False, True])
def pytest_train_model_lengths(model_type, skip_connection, overwrite_data=False):
unittest_train_model(model_type, skip_connection, "ci.json", True, overwrite_data)
Expand Down

0 comments on commit 1a42e26

Please sign in to comment.