-
Notifications
You must be signed in to change notification settings - Fork 0
/
layers.py
102 lines (85 loc) · 3.58 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from ogb.graphproppred.mol_encoder import BondEncoder
from dgl.nn.functional import edge_softmax
from modules import MLP, MessageNorm
class GENConv(nn.Module):
    r"""
    Description
    -----------
    Generalized Message Aggregator was introduced in "DeeperGCN: All You Need to Train Deeper GCNs <https://arxiv.org/abs/2006.07739>"

    Each edge message is ``h_src + BondEncoder(edge_feats)``. Messages are
    aggregated per destination node with either a softmax-weighted sum
    (``aggregator='softmax'``) or a generalized power mean
    (``aggregator='power'``). The result is optionally message-normalized,
    added to the input node features as a residual, and passed through an MLP.

    Parameters
    ----------
    in_dim: int
        Input size. Node features must already have this size, because the
        encoded edge features (also of size ``in_dim``) are added to them.
    out_dim: int
        Output size.
    aggregator: str
        Type of aggregation, ``'softmax'`` or ``'power'``. Default is 'softmax'.
        Any other value raises ``NotImplementedError`` when ``forward`` is called.
    beta: float
        A continuous variable called an inverse temperature. Default is 1.0.
        Only used by the softmax aggregator.
    learn_beta: bool
        Whether beta is a learnable variable or not. Only takes effect with the
        softmax aggregator. Default is False.
    p: float
        Initial power for power mean aggregation. Default is 1.0. Must be
        non-zero, since the aggregated value is raised to ``1 / p``.
    learn_p: bool
        Whether p is a learnable variable or not. Default is False.
    msg_norm: bool
        Whether message normalization is used. Default is False.
    learn_msg_scale: bool
        Whether s is a learnable scaling factor or not in message normalization. Default is False.
    mlp_layers: int
        The number of MLP layers. Default is 1.
    eps: float
        A small positive constant in message construction function. Default is 1e-7.
    """
    def __init__(self,
                 in_dim,
                 out_dim,
                 aggregator='softmax',
                 beta=1.0,
                 learn_beta=False,
                 p=1.0,
                 learn_p=False,
                 msg_norm=False,
                 learn_msg_scale=False,
                 mlp_layers=1,
                 eps=1e-7):
        super(GENConv, self).__init__()
        self.aggr = aggregator
        self.eps = eps

        # MLP widths: in_dim, then (mlp_layers - 1) hidden layers of width
        # 2 * in_dim, then out_dim.
        channels = [in_dim] + [in_dim * 2] * (mlp_layers - 1) + [out_dim]
        self.mlp = MLP(channels)
        self.msg_norm = MessageNorm(learn_msg_scale) if msg_norm else None

        # beta is only meaningful for softmax aggregation, so it is made
        # learnable only in that case; otherwise it stays a plain float.
        self.beta = nn.Parameter(torch.Tensor([beta]), requires_grad=True) if learn_beta and self.aggr == 'softmax' else beta
        self.p = nn.Parameter(torch.Tensor([p]), requires_grad=True) if learn_p else p
        # Projects edge features to in_dim so they can be added to node features.
        self.edge_encoder = BondEncoder(in_dim)

    def forward(self, g, node_feats, edge_feats):
        """Run one generalized message-aggregation step.

        Parameters
        ----------
        g : DGLGraph
            Input graph. Features are written only inside ``g.local_scope()``,
            so they are not left on the caller's graph.
        node_feats : torch.Tensor
            Node features whose feature size is ``in_dim``.
        edge_feats : torch.Tensor
            Edge features fed to ``BondEncoder``. Presumably raw OGB bond
            features; confirm against the caller.

        Returns
        -------
        torch.Tensor
            Updated node features produced by the MLP (size ``out_dim``).

        Raises
        ------
        NotImplementedError
            If the configured aggregator is neither 'softmax' nor 'power'.
        """
        with g.local_scope():
            # Node and edge feature size need to match.
            g.ndata['h'] = node_feats
            g.edata['h'] = self.edge_encoder(edge_feats)
            # Per-edge message: source node feature + encoded edge feature.
            g.apply_edges(fn.u_add_e('h', 'h', 'm'))

            if self.aggr == 'softmax':
                # ReLU + eps keeps every message strictly positive.
                g.edata['m'] = F.relu(g.edata['m']) + self.eps
                # Attention weights: softmax over each node's incoming edges,
                # sharpened or flattened by the inverse temperature beta.
                g.edata['a'] = edge_softmax(g, g.edata['m'] * self.beta)
                g.edata['x'] = g.edata['m'] * g.edata['a']
                g.update_all(fn.copy_e('x', 'x'), fn.sum('x', 'm'))
            elif self.aggr == 'power':
                minv, maxv = 1e-7, 1e1
                # Clamp to a positive, bounded range so fractional powers are
                # defined and do not overflow. Out-of-place clamp avoids
                # mutating tensors that autograd is tracking.
                g.edata['x'] = torch.pow(torch.clamp(g.edata['m'], minv, maxv), self.p)
                g.update_all(fn.copy_e('x', 'x'), fn.mean('x', 'm'))
                # Power mean: (mean of m^p) ** (1 / p). The original code raised
                # to p here, which is not a power mean unless p == 1.
                g.ndata['m'] = torch.pow(torch.clamp(g.ndata['m'], minv, maxv), 1 / self.p)
            else:
                raise NotImplementedError(f'Aggregator {self.aggr} is not supported.')

            if self.msg_norm is not None:
                g.ndata['m'] = self.msg_norm(node_feats, g.ndata['m'])

            # Residual connection before the MLP.
            feats = node_feats + g.ndata['m']
            return self.mlp(feats)