Skip to content

Commit

Permalink
added gin network
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerry-Master committed Jan 4, 2024
1 parent c17595a commit 25d66a7
Showing 1 changed file with 63 additions and 0 deletions.
63 changes: 63 additions & 0 deletions tumourkit/classification/models/gin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@

"""
How Powerful are Graph Neural Networks?
References
----------
Paper: https://arxiv.org/abs/1810.00826
Copyright (C) 2023 Jose Pérez Cano
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Contact information: [email protected]
"""
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GINConv
from .norm import Norm

class GIN(nn.Module):
def __init__(self, in_feats, h_feats, num_classes, num_layers, drop_rate, norm_type, enable_background=False):
super(GIN, self).__init__()
self.conv_layers = nn.ModuleList()
self.conv_layers.append(GINConv(nn.Linear(in_feats, h_feats), 'max', activation=F.elu))
self.conv_layers.append(nn.Dropout(drop_rate)) # Feature map dropout
self.conv_layers.append(Norm(norm_type=norm_type, hidden_dim=h_feats))
for l in range(1,num_layers):
self.conv_layers.append(GINConv(nn.Linear(h_feats, h_feats), 'max', activation=F.elu))
self.conv_layers.append(nn.Dropout(drop_rate))
self.conv_layers.append(Norm(norm_type=norm_type, hidden_dim=h_feats))
self.conv_layers.append(GINConv(nn.Linear(h_feats, num_classes), 'max',))

self.enable_background = enable_background
if enable_background:
self.bkgr_head = GINConv(nn.Linear(h_feats, 2), 'max',)


def forward(self, g, in_feat):
h = in_feat
for i, layer in enumerate(self.conv_layers):
if i == len(self.conv_layers) - 1:
if self.enable_background: # Last layer
h_bkgr = self.bkgr_head(g, h)
h = layer(g, h)
return h, h_bkgr
h = layer(g, h)
else:
if i % 3 == 1:
h = layer(h) # Dropout
else:
h = layer(g, h) # Other layers
return h

0 comments on commit 25d66a7

Please sign in to comment.