layers.py
import dgl.function as fn
import torch
import torch.nn as nn

# Small constant used to avoid division by zero.
EOS = 1e-10


class GCNConv_dense(nn.Module):
    """GCN layer that propagates features with a dense (or torch.sparse) normalized adjacency matrix."""

    def __init__(self, input_size, output_size):
        super(GCNConv_dense, self).__init__()
        self.linear = nn.Linear(input_size, output_size)

    def init_para(self):
        self.linear.reset_parameters()

    def forward(self, input, A, sparse=False):
        # Linear transform of node features, then propagation with the adjacency A.
        hidden = self.linear(input)
        if sparse:
            output = torch.sparse.mm(A, hidden)
        else:
            output = torch.matmul(A, hidden)
        return output


class GCNConv_dgl(nn.Module):
    """GCN layer operating on a DGL graph whose edge weights are stored in g.edata['w']."""

    def __init__(self, input_size, output_size):
        super(GCNConv_dgl, self).__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x, g):
        with g.local_scope():
            g.ndata['h'] = self.linear(x)
            # Message passing: scale each neighbor's features by the edge weight 'w', then sum.
            g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum(msg='m', out='h'))
            return g.ndata['h']


class Attentive(nn.Module):
    """Attentive scaling: multiplies each feature dimension by a learnable weight."""

    def __init__(self, isize):
        super(Attentive, self).__init__()
        self.w = nn.Parameter(torch.ones(isize))

    def forward(self, x):
        return x @ torch.diag(self.w)


class SparseDropout(nn.Module):
    """Dropout applied to the non-zero values of a sparse COO tensor."""

    def __init__(self, dprob=0.5):
        super(SparseDropout, self).__init__()
        # dprob is the dropout ratio; convert it to the keep probability.
        self.kprob = 1 - dprob

    def forward(self, x):
        # Keep each non-zero value with probability kprob; build the mask on the same device as x.
        mask = ((torch.rand(x._values().size(), device=x.device) + self.kprob).floor()).type(torch.bool)
        rc = x._indices()[:, mask]
        # Rescale the surviving values so the expected magnitude is preserved.
        val = x._values()[mask] * (1.0 / self.kprob)
        return torch.sparse_coo_tensor(rc, val, x.shape)
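

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): it builds a toy
# row-normalized adjacency and runs GCNConv_dense, Attentive, and
# SparseDropout on random features. The tensor sizes and the normalization
# scheme below are illustrative assumptions, not taken from the repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    num_nodes, in_dim, out_dim = 5, 8, 4
    features = torch.rand(num_nodes, in_dim)

    # Toy symmetric adjacency with self-loops, row-normalized so each row
    # sums to one (EOS guards against division by zero).
    A = torch.rand(num_nodes, num_nodes)
    A = ((A + A.t()) / 2 > 0.5).float() + torch.eye(num_nodes)
    A = A / (A.sum(dim=-1, keepdim=True) + EOS)

    conv = GCNConv_dense(in_dim, out_dim)
    att = Attentive(in_dim)
    out = conv(att(features), A)      # dense propagation path
    print(out.shape)                  # torch.Size([5, 4])

    # Sparse path: the same adjacency as a COO tensor, with dropout applied
    # to its non-zero entries before propagation.
    A_sparse = A.to_sparse()
    A_dropped = SparseDropout(dprob=0.5)(A_sparse)
    out_sparse = conv(features, A_dropped, sparse=True)
    print(out_sparse.shape)           # torch.Size([5, 4])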