diff --git a/tumourkit/classification/models/gin.py b/tumourkit/classification/models/gin.py new file mode 100644 index 0000000..094bd22 --- /dev/null +++ b/tumourkit/classification/models/gin.py @@ -0,0 +1,63 @@ + +""" +How Powerful are Graph Neural Networks? + +References +---------- +Paper: https://arxiv.org/abs/1810.00826 + +Copyright (C) 2023 Jose PĂ©rez Cano + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +Contact information: joseperez2000@hotmail.es +""" +import torch.nn as nn +import torch.nn.functional as F +from dgl.nn import GINConv +from .norm import Norm + +class GIN(nn.Module): + def __init__(self, in_feats, h_feats, num_classes, num_layers, drop_rate, norm_type, enable_background=False): + super(GIN, self).__init__() + self.conv_layers = nn.ModuleList() + self.conv_layers.append(GINConv(nn.Linear(in_feats, h_feats), 'max', activation=F.elu)) + self.conv_layers.append(nn.Dropout(drop_rate)) # Feature map dropout + self.conv_layers.append(Norm(norm_type=norm_type, hidden_dim=h_feats)) + for l in range(1,num_layers): + self.conv_layers.append(GINConv(nn.Linear(h_feats, h_feats), 'max', activation=F.elu)) + self.conv_layers.append(nn.Dropout(drop_rate)) + self.conv_layers.append(Norm(norm_type=norm_type, hidden_dim=h_feats)) + self.conv_layers.append(GINConv(nn.Linear(h_feats, num_classes), 'max',)) + + self.enable_background = enable_background + if enable_background: + self.bkgr_head = GINConv(nn.Linear(h_feats, 2), 'max',) + + + def forward(self, g, in_feat): + h = in_feat + for i, layer in enumerate(self.conv_layers): + if i == len(self.conv_layers) - 1: + if self.enable_background: # Last layer + h_bkgr = self.bkgr_head(g, h) + h = layer(g, h) + return h, h_bkgr + h = layer(g, h) + else: + if i % 3 == 1: + h = layer(h) # Dropout + else: + h = layer(g, h) # Other layers + return h \ No newline at end of file