forked from cmdupuis3/gnn-workspace
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
78 lines (58 loc) · 2.68 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import numpy as np
import xarray as xr
import xbatcher as xb
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from surface_currents_prep import *
from scenario import Scenario, sc5
from models import MsgModelDiff, ModelLikeAnirbans
from batching import rolling_batcher, batch_generator
def _train_epoch(model, training_batch, batch_size, loss_fn, optimizer):
    """Run one training pass, taking an optimizer step per yielded item.

    Every (convs, features, targets) triple produced by ``batch_generator``
    is fed through *model* individually, and the optimizer steps after each
    one.
    """
    for c, f, t, _ in batch_generator(training_batch, batch_size):
        for convs, features, targets in zip(c, f, t):
            optimizer.zero_grad()
            outs = model(convs.x.float(), features.x.float(),
                         features.edge_index, features.weight)
            # NOTE(review): targets.x is not cast to float like the inputs;
            # confirm its dtype matches the model output.
            loss = loss_fn(outs, targets.x)
            loss.backward()
            optimizer.step()


def _evaluate(model, testing_batch, batch_size, loss_fn):
    """Return the mean loss over every item yielded from *testing_batch*.

    Evaluation runs under ``torch.no_grad()`` with the model in eval mode.
    The model's previous train/eval mode is restored afterwards.

    Raises:
        ValueError: if the testing batcher yields no items, so no mean
            can be formed.
    """
    was_training = model.training
    model.eval()
    total = 0.0
    count = 0
    try:
        # no_grad: without it every loss tensor keeps its autograd graph
        # alive, and summing them holds the whole test pass in memory.
        with torch.no_grad():
            for c, f, t, _ in batch_generator(testing_batch, batch_size):
                for convs, features, targets in zip(c, f, t):
                    outs = model(convs.x.float(), features.x.float(),
                                 features.edge_index, features.weight)
                    total += loss_fn(outs, targets.x).item()
                    count += 1
    finally:
        model.train(was_training)
    if count == 0:
        raise ValueError("testing data produced no batches; "
                         "cannot compute the epoch loss")
    return total / count


def _plot_loss_curve(testing_loss, plot_path):
    """Save a log-scale plot of the per-epoch testing loss to *plot_path*."""
    if not testing_loss:
        return  # no epochs were run, so there is nothing to plot
    last_epoch = len(testing_loss) - 1
    plt.figure(figsize=(18, 5))
    plt.plot(range(len(testing_loss)), testing_loss, color='#ff6347', label="loss")
    # Highlight the most recent epoch's value.
    plt.plot(last_epoch, testing_loss[-1], marker='o', markersize=10, color='#ff6347')
    plt.legend()
    plt.xlabel(r'Epoch')
    plt.ylabel('Loss')
    # NOTE(review): the lower y-limit of 10 is hard-coded; losses below 10
    # fall outside the visible range.
    plt.ylim([10., (1.3 * testing_loss[0])])
    plt.yscale("log")
    plt.savefig(plot_path)
    plt.close()  # release the figure so repeated calls do not accumulate them


def train(model, ds_training, ds_testing,
          num_epochs=1, batch_size=32, plot_loss=False,
          lr=0.0005, plot_path='C:/Users/cdupu/Documents/gnn_training_loss.png'):
    """Train *model* with Adam + MSE loss and report a testing loss per epoch.

    Both datasets are wrapped with ``rolling_batcher(ds, 7, 7)``. Each epoch
    runs one training pass, then computes the mean loss over the testing
    data and prints it.

    Args:
        model: a torch module called as
            ``model(convs.x, features.x, features.edge_index, features.weight)``.
        ds_training: dataset passed to ``rolling_batcher`` for training.
        ds_testing: dataset passed to ``rolling_batcher`` for evaluation.
        num_epochs: number of train/evaluate cycles.
        batch_size: batch size passed to ``batch_generator``.
        plot_loss: if true, save a loss-curve figure to *plot_path*.
        lr: Adam learning rate.
        plot_path: destination file for the loss-curve figure.

    Returns:
        list[float]: the mean testing loss of each epoch, in order.

    Raises:
        ValueError: if the testing data yields no items in an epoch.
    """
    training_batch = rolling_batcher(ds_training, 7, 7)
    testing_batch = rolling_batcher(ds_testing, 7, 7)
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    testing_loss = []
    for _ in range(num_epochs):
        _train_epoch(model, training_batch, batch_size, loss_fn, optimizer)
        epoch_loss = _evaluate(model, testing_batch, batch_size, loss_fn)
        print(f'[\tEpoch Loss: {epoch_loss}')
        testing_loss.append(epoch_loss)
    if plot_loss:
        _plot_loss_curve(testing_loss, plot_path)
    return testing_loss
if __name__ == '__main__':
    # Script entry point: build the training and testing splits for
    # scenario sc5, train a ModelLikeAnirbans on them, and save its weights.
    # Each split goes through the same helpers in the same order:
    # load -> just_the_data -> select_from.
    prepared = []
    for loader in (load_training_data, load_test_data):
        split = loader(sc5)
        split = just_the_data(split)
        split = select_from(split)
        prepared.append(split)
    ds_training, ds_testing = prepared

    model = ModelLikeAnirbans(
        5, [40, 20, 10], 2,
        num_conv=2,
        num_conv_channels=40,
        message_multiplier=2,
    )
    train(model, ds_training, ds_testing,
          num_epochs=30, batch_size=64, plot_loss=True)

    weights_file = "C:/Users/cdupu/Documents/gnn_model4.pt"
    torch.save(model.state_dict(), weights_file)