trivial_gcn.py
import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class AtomEncoder(torch.nn.Module):
    """Embeds up to 9 categorical node features (values 0-99) and sums the embeddings."""

    def __init__(self, hidden_channels):
        super().__init__()
        self.embeddings = torch.nn.ModuleList()
        for i in range(9):
            self.embeddings.append(torch.nn.Embedding(100, hidden_channels))

    def reset_parameters(self):
        for embedding in self.embeddings:
            embedding.reset_parameters()

    def forward(self, x):
        if x.dim() == 1:
            x = x.unsqueeze(1)
        out = 0
        for i in range(x.size(1)):
            out += self.embeddings[i](x[:, i])
        return out
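
# Shape check for AtomEncoder (illustrative, not part of the original script):
# each integer node feature is embedded separately and the embeddings are
# summed into a single vector per node.
# enc = AtomEncoder(hidden_channels=64)
# x = torch.randint(0, 100, (5, 9))  # 5 nodes, 9 categorical features each
# enc(x).shape                       # -> torch.Size([5, 64])
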
class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        torch.manual_seed(12345)  # For reproducible weight initialization.
        # The encoder's output width must match conv1's input width.
        self.emb = AtomEncoder(train_dataset.num_node_features)
        self.conv1 = GCNConv(train_dataset.num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, train_dataset.num_classes)

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings.
        x = self.emb(x)
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)
        # 2. Readout layer.
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]
        # 3. Apply a final classifier.
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        return x
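
# This script assumes `train_dataset`, `train_loader`, and `eval_loader` are
# defined elsewhere. A minimal sketch of one possible setup (the dataset
# choice is an assumption, not confirmed by this file; any molecular
# regression dataset with integer node features would fit):
#
# from torch_geometric.datasets import ZINC
# from torch_geometric.loader import DataLoader
#
# train_dataset = ZINC(root='data/ZINC', subset=True, split='train')
# eval_dataset = ZINC(root='data/ZINC', subset=True, split='val')
# train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# eval_loader = DataLoader(eval_dataset, batch_size=64, shuffle=False)
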
model = GCN(hidden_channels=64)
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.L1Loss()  # Mean absolute error (MAE).
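
# Quick check of what L1Loss computes (illustrative, not part of the original
# script): the mean absolute difference between input and target.
# torch.nn.L1Loss()(torch.tensor([2.0]), torch.tensor([5.0]))  # -> tensor(3.)
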
def train():
    model.train()
    loss_sum = 0
    for data in train_loader:  # Iterate in batches over the training dataset.
        optimizer.zero_grad()  # Clear gradients.
        out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
        loss = criterion(out, data.y)  # Compute the loss (mean absolute error).
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        loss_sum += loss.item()  # Accumulate detached batch losses.
    return loss_sum / len(train_loader)  # Average loss over minibatches.
def test(loader):
    model.eval()
    loss_sum = 0
    with torch.no_grad():  # No gradients needed for evaluation.
        for data in loader:  # Iterate in batches over the evaluation dataset.
            out = model(data.x, data.edge_index, data.batch)
            loss = criterion(out, data.y)
            loss_sum += loss.item()
    return loss_sum / len(loader)  # Average loss over minibatches.
for epoch in range(1, 201):
    train()
    train_loss = test(train_loader)
    test_loss = test(eval_loader)
    print(f'Epoch: {epoch:03d}, Train MAE: {train_loss:.4f}, Test MAE: {test_loss:.4f}')
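
# Inference sketch (illustrative, assuming `eval_loader` is defined as above):
# batch = next(iter(eval_loader))
# model.eval()
# with torch.no_grad():
#     pred = model(batch.x, batch.edge_index, batch.batch)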