Skip to content

Commit

Permalink
Added SchNet model as SCFStack (#171)
Browse files Browse the repository at this point in the history
  • Loading branch information
JustinBakerMath authored Mar 15, 2023
1 parent f960794 commit c78e7ac
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 8 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ There are many options for HydraGNN; the dataset and model type are particularly
important:
- `["Verbosity"]["level"]`: `0`, `1`, `2`, `3`, `4`
- `["Dataset"]["name"]`: `CuAu_32atoms`, `FePt_32atoms`, `FeSi_1024atoms`
- `["NeuralNetwork"]["Architecture"]["model_type"]`: `PNA`, `MFC`, `GIN`, `GAT`, `CGCNN`
- `["NeuralNetwork"]["Architecture"]["model_type"]`: `PNA`, `MFC`, `GIN`, `GAT`, `CGCNN`, `SchNet`
### Citations
"HydraGNN: Distributed PyTorch implementation of multi-headed graph convolutional neural networks", Copyright ID#: 81929619
Expand Down
79 changes: 79 additions & 0 deletions hydragnn/models/SCFStack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
##############################################################################
# Copyright (c) 2021, Oak Ridge National Laboratory #
# All rights reserved. #
# #
# This file is part of HydraGNN and is distributed under a BSD 3-clause #
# license. For the licensing terms see the LICENSE file in the top-level #
# directory. #
# #
# SPDX-License-Identifier: BSD-3-Clause #
##############################################################################

from typing import Optional

import torch
from torch.nn import Linear, Sequential
from torch_geometric.nn.models.schnet import (
CFConv,
GaussianSmearing,
RadiusInteractionGraph,
ShiftedSoftplus,
)

from .Base import Base


class SCFStack(Base):
def __init__(
self,
num_filters: int,
num_gaussians: list,
radius: float,
*args,
max_neighbours: Optional[int] = None,
**kwargs,
):
self.radius = radius
self.max_neighbours = max_neighbours
self.num_filters = num_filters
self.num_gaussians = num_gaussians

super().__init__(*args, **kwargs)

self.distance_expansion = GaussianSmearing(0.0, radius, num_gaussians)
self.interaction_graph = RadiusInteractionGraph(radius, max_neighbours)

pass

def get_conv(self, input_dim, output_dim):
mlp = Sequential(
Linear(self.num_gaussians, self.num_filters),
ShiftedSoftplus(),
Linear(self.num_filters, self.num_filters),
)

return CFConv(
in_channels=input_dim,
out_channels=output_dim,
nn=mlp,
num_filters=self.num_filters,
cutoff=self.radius,
)

def _conv_args(self, data):
if (data.edge_attr is not None) and (self.use_edge_attr):
edge_index = data.edge_index
edge_weight = data.edge_attr.norm(dim=-1)
else:
edge_index, edge_weight = self.interaction_graph(data.pos, data.batch)

conv_args = {
"edge_index": edge_index,
"edge_weight": edge_weight,
"edge_attr": self.distance_expansion(edge_weight),
}

return conv_args

def __str__(self):
return "SCFStack"
30 changes: 29 additions & 1 deletion hydragnn/models/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from hydragnn.models.MFCStack import MFCStack
from hydragnn.models.CGCNNStack import CGCNNStack
from hydragnn.models.SAGEStack import SAGEStack
from hydragnn.models.SCFStack import SCFStack

from hydragnn.utils.distributed import get_device
from hydragnn.utils.print_utils import print_distributed
Expand All @@ -30,7 +31,6 @@ def create_model_config(
verbosity: int = 0,
use_gpu: bool = True,
):

return create_model(
config["Architecture"]["model_type"],
config["Architecture"]["input_dim"],
Expand All @@ -47,6 +47,9 @@ def create_model_config(
config["Architecture"]["max_neighbours"],
config["Architecture"]["edge_dim"],
config["Architecture"]["pna_deg"],
config["Architecture"]["num_gaussians"],
config["Architecture"]["num_filters"],
config["Architecture"]["radius"],
verbosity,
use_gpu,
)
Expand All @@ -69,6 +72,9 @@ def create_model(
max_neighbours: int = None,
edge_dim: int = None,
pna_deg: torch.tensor = None,
num_gaussians: int = None,
num_filters: int = None,
radius: float = None,
verbosity: int = 0,
use_gpu: bool = True,
):
Expand Down Expand Up @@ -178,6 +184,28 @@ def create_model(
num_nodes=num_nodes,
)

elif model_type == "SchNet":
assert num_gaussians is not None, "SchNet requires num_guassians input."
assert num_filters is not None, "SchNet requires num_filters input."
assert radius is not None, "SchNet requires radius input."
model = SCFStack(
num_gaussians,
num_filters,
radius,
input_dim,
hidden_dim,
output_dim,
output_type,
output_heads,
loss_function_type,
max_neighbours=max_neighbours,
loss_weights=task_weights,
freeze_conv=freeze_conv,
initial_bias=initial_bias,
num_conv_layers=num_conv_layers,
num_nodes=num_nodes,
)

else:
raise ValueError("Unknown model_type: {0}".format(model_type))

Expand Down
11 changes: 8 additions & 3 deletions hydragnn/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ def update_config(config, train_loader, val_loader, test_loader):
else:
config["NeuralNetwork"]["Architecture"]["pna_deg"] = None

if "radius" not in config["NeuralNetwork"]["Architecture"]:
config["NeuralNetwork"]["Architecture"]["radius"] = None
if "num_gaussians" not in config["NeuralNetwork"]["Architecture"]:
config["NeuralNetwork"]["Architecture"]["num_gaussians"] = None
if "num_filters" not in config["NeuralNetwork"]["Architecture"]:
config["NeuralNetwork"]["Architecture"]["num_filters"] = None

config["NeuralNetwork"]["Architecture"] = update_config_edge_dim(
config["NeuralNetwork"]["Architecture"]
)
Expand All @@ -65,9 +72,8 @@ def update_config(config, train_loader, val_loader, test_loader):


def update_config_edge_dim(config):

config["edge_dim"] = None
edge_models = ["PNA", "CGCNN"]
edge_models = ["PNA", "CGCNN", "SchNet"]
if "edge_features" in config and config["edge_features"]:
assert (
config["model_type"] in edge_models
Expand All @@ -81,7 +87,6 @@ def update_config_edge_dim(config):


def check_output_dim_consistent(data, config):

output_type = config["NeuralNetwork"]["Variables_of_interest"]["type"]
output_index = config["NeuralNetwork"]["Variables_of_interest"]["output_index"]
if hasattr(data, "y_loc"):
Expand Down
2 changes: 2 additions & 0 deletions tests/inputs/ci.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
"model_type": "PNA",
"radius": 2.0,
"max_neighbours": 100,
"num_gaussians": 50,
"num_filters": 126,
"periodic_boundary_conditions": false,
"hidden_dim": 8,
"num_conv_layers": 2,
Expand Down
2 changes: 2 additions & 0 deletions tests/inputs/ci_multihead.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
"model_type": "PNA",
"radius": 2.0,
"max_neighbours": 100,
"num_gaussians": 50,
"num_filters": 126,
"periodic_boundary_conditions": false,
"hidden_dim": 8,
"num_conv_layers": 2,
Expand Down
9 changes: 6 additions & 3 deletions tests/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

import hydragnn, tests


# Main unit test function called by pytest wrappers.
def unittest_train_model(model_type, ci_input, use_lengths, overwrite_data=False):

world_size, rank = hydragnn.utils.get_comm_size_and_rank()

os.environ["SERIALIZED_DATA_PATH"] = os.getcwd()
Expand Down Expand Up @@ -130,6 +130,7 @@ def unittest_train_model(model_type, ci_input, use_lengths, overwrite_data=False
"GIN": [0.25, 0.20],
"GAT": [0.60, 0.70],
"CGCNN": [0.50, 0.40],
"SchNet": [0.20, 0.20],
}
if use_lengths and ("vector" not in ci_input):
thresholds["CGCNN"] = [0.175, 0.175]
Expand Down Expand Up @@ -171,14 +172,16 @@ def unittest_train_model(model_type, ci_input, use_lengths, overwrite_data=False


# Test across all models with both single/multihead
@pytest.mark.parametrize("model_type", ["SAGE", "GIN", "GAT", "MFC", "PNA", "CGCNN"])
@pytest.mark.parametrize(
"model_type", ["SAGE", "GIN", "GAT", "MFC", "PNA", "CGCNN", "SchNet"]
)
@pytest.mark.parametrize("ci_input", ["ci.json", "ci_multihead.json"])
def pytest_train_model(model_type, ci_input, overwrite_data=False):
unittest_train_model(model_type, ci_input, False, overwrite_data)


# Test only models
@pytest.mark.parametrize("model_type", ["PNA", "CGCNN"])
@pytest.mark.parametrize("model_type", ["PNA", "CGCNN", "SchNet"])
def pytest_train_model_lengths(model_type, overwrite_data=False):
unittest_train_model(model_type, "ci.json", True, overwrite_data)

Expand Down

0 comments on commit c78e7ac

Please sign in to comment.