Skip to content

Commit

Permalink
SAGEStack added (#109)
Browse files Browse the repository at this point in the history
* SAGEStack added

* Update SAGEStack and test accuracy

* Simplify sage inputs

Co-authored-by: Massimiliano Lupo Pasini <[email protected]>
Co-authored-by: Sam Reeve <[email protected]>
  • Loading branch information
3 people authored Mar 24, 2022
1 parent 2977ad2 commit 5931f04
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 1 deletion.
32 changes: 32 additions & 0 deletions hydragnn/models/SAGEStack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
##############################################################################
# Copyright (c) 2021, Oak Ridge National Laboratory #
# All rights reserved. #
# #
# This file is part of HydraGNN and is distributed under a BSD 3-clause #
# license. For the licensing terms see the LICENSE file in the top-level #
# directory. #
# #
# SPDX-License-Identifier: BSD-3-Clause #
##############################################################################

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import ModuleList
from torch_geometric.nn import SAGEConv, BatchNorm

from .Base import Base


class SAGEStack(Base):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def get_conv(self, input_dim, output_dim):
return SAGEConv(
in_channels=input_dim,
out_channels=output_dim,
)

def __str__(self):
return "SAGEStack"
14 changes: 14 additions & 0 deletions hydragnn/models/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from hydragnn.models.GATStack import GATStack
from hydragnn.models.MFCStack import MFCStack
from hydragnn.models.CGCNNStack import CGCNNStack
from hydragnn.models.SAGEStack import SAGEStack

from hydragnn.utils.distributed import get_device
from hydragnn.utils.print_utils import print_distributed
Expand Down Expand Up @@ -149,6 +150,19 @@ def create_model(
num_nodes=num_nodes,
)

elif model_type == "SAGE":
model = SAGEStack(
input_dim,
hidden_dim,
output_dim,
output_type,
output_heads,
loss_weights=task_weights,
freeze_conv=freeze_conv,
num_conv_layers=num_conv_layers,
num_nodes=num_nodes,
)

else:
raise ValueError("Unknown model_type: {0}".format(model_type))

Expand Down
3 changes: 2 additions & 1 deletion tests/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def unittest_train_model(model_type, ci_input, use_lengths, overwrite_data=False

# Set RMSE and sample MAE/max error thresholds
thresholds = {
"SAGE": [0.20, 0.20, 0.75],
"PNA": [0.20, 0.20, 0.75],
"MFC": [0.20, 0.20, 1.5],
"GIN": [0.25, 0.20, 0.75],
Expand Down Expand Up @@ -179,7 +180,7 @@ def unittest_train_model(model_type, ci_input, use_lengths, overwrite_data=False


# Test across all models with both single/multihead
@pytest.mark.parametrize("model_type", ["GIN", "GAT", "MFC", "PNA", "CGCNN"])
@pytest.mark.parametrize("model_type", ["SAGE", "GIN", "GAT", "MFC", "PNA", "CGCNN"])
@pytest.mark.parametrize("ci_input", ["ci.json", "ci_multihead.json"])
def pytest_train_model(model_type, ci_input, overwrite_data=False):
unittest_train_model(model_type, ci_input, False, overwrite_data)
Expand Down

0 comments on commit 5931f04

Please sign in to comment.