forked from cmdupuis3/gnn-workspace
-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
76 lines (48 loc) · 2.41 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import numpy as np
import xarray as xr
import xbatcher as xb
import torch
import torch.nn as nn
from torch_geometric.utils.convert import to_networkx, from_networkx
import matplotlib.pyplot as plt
from batching import rolling_batcher, batch_generator
from surface_currents_prep import *
from scenario import Scenario, sc5
from models import MsgModelDiff
from xr_to_networkx import xr_to_graphs, graphs_to_xr
def ez_plot(var, min, max, filename):
    """Save a 100-level filled-contour plot of ``var`` to ``filename``.

    Args:
        var: Field handed to ``xr.plot.contourf`` (presumably a 2-D
            ``xr.DataArray``; confirm against callers).
        min: Lower bound of the colour scale (passed as ``vmin``). The name
            shadows the builtin ``min`` inside this function; it is kept
            unchanged so keyword callers keep working.
        max: Upper bound of the colour scale (passed as ``vmax``); also
            shadows a builtin, kept for the same reason.
        filename: Output path handed to ``plt.savefig``.
    """
    fig = plt.figure(figsize=(16, 10))
    try:
        # xarray draws on the current axes, i.e. the figure created above.
        xr.plot.contourf(var, levels=100, vmin=min, vmax=max)
        plt.savefig(filename)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed;
        # without this, repeated calls accumulate open figures and memory.
        plt.close(fig)
def _infer_uv(model, ds_predict, batch_size):
    """Run ``model`` over every graph produced from ``ds_predict``; return (U, V) numpy grids."""
    # NOTE(review): the meaning of the two 7s (window/stride?) lives in
    # batching.rolling_batcher; values kept exactly as in the original.
    predict_batch = rolling_batcher(ds_predict, 7, 7)

    # Cells never covered by a prediction stay NaN.
    U_pred = np.full(ds_predict['U'].shape, np.nan)
    V_pred = np.full(ds_predict['V'].shape, np.nan)

    loss_fn = nn.MSELoss()

    was_training = model.training
    # Inference: disable dropout/batch-norm updates and autograd bookkeeping.
    model.eval()
    try:
        with torch.no_grad():
            for c, f, t, co in batch_generator(predict_batch, batch_size):
                losses = []
                for convs, features, targets, coords in zip(c, f, t, co):
                    predictions_graph = model(convs.x.float(), features.x.float(),
                                              features.edge_index, features.weight)
                    losses.append(loss_fn(predictions_graph, targets.x).item())

                    # Convert once per graph instead of once per node value.
                    node_values = predictions_graph.cpu().numpy()
                    for ct, node in enumerate(node_values):
                        # coords[ct] gives the (nlat, nlon) grid cell of node ct.
                        nlat, nlon = coords[ct]
                        U_pred[nlat, nlon] = node[0]
                        V_pred[nlat, nlon] = node[1]

                # Report the mean over all graphs in the batch, not just the last one.
                if losses:
                    print(f'Batch Loss: {sum(losses) / len(losses)}')
    finally:
        # Leave the caller's model in the mode it was handed to us in.
        model.train(was_training)

    return U_pred, V_pred


def _plot_uv(ds_predict, U_pred, V_pred, out_dir, prefix):
    """Write truth, prediction and (prediction - truth) contour plots for U and V."""
    U_diff = U_pred - ds_predict['U']
    V_diff = V_pred - ds_predict['V']

    base = f'{out_dir}/{prefix}'
    ez_plot(ds_predict['U'], -20, 20, f'{base}_U.png')
    ez_plot(ds_predict['V'], -20, 20, f'{base}_V.png')
    ez_plot(U_pred, -20, 20, f'{base}_U_pred.png')
    ez_plot(V_pred, -20, 20, f'{base}_V_pred.png')
    ez_plot(U_diff, -100, 100, f'{base}_U_diff.png')
    ez_plot(V_diff, -100, 100, f'{base}_V_diff.png')


def predict(model, ds_predict, batch_size=32,
            out_dir='C:/Users/cdupu/Documents', prefix='model3'):
    """Predict U/V with ``model`` over ``ds_predict`` and save comparison plots.

    The model is run in eval mode with gradients disabled. Its training mode
    is restored afterwards. The MSE loss against ``targets.x`` is printed as
    the mean over the graphs in each batch.

    Args:
        model: Callable invoked as ``model(conv_x, feat_x, edge_index, weight)``.
            Its output rows hold (U, V) in columns 0 and 1.
        ds_predict: Dataset with ``'U'`` and ``'V'`` variables, indexed here
            as 2-D ``[nlat, nlon]`` grids.
        batch_size: Number of graphs per batch from ``batch_generator``.
        out_dir: Directory for the PNG files. The default reproduces the
            original hard-coded location.
        prefix: File-name prefix for the PNG files.

    Returns:
        ``(U_pred, V_pred)`` as ``xr.DataArray`` with dims ``('nlat', 'nlon')``.
        Cells without a prediction are NaN.
    """
    U_np, V_np = _infer_uv(model, ds_predict, batch_size)

    U_pred = xr.DataArray(U_np, dims=['nlat', 'nlon'])
    V_pred = xr.DataArray(V_np, dims=['nlat', 'nlon'])

    _plot_uv(ds_predict, U_pred, V_pred, out_dir, prefix)
    return U_pred, V_pred
if __name__ == '__main__':
    load_path = "C:/Users/cdupu/Documents/gnn_model3.pt"

    model = MsgModelDiff(5, [40, 20, 10], 2, num_conv=2,
                         num_conv_channels=40, message_multiplier=2)
    # The model is built without .to(device), so its parameters sit on the
    # default device. map_location='cpu' lets a checkpoint saved from a GPU
    # run load on a machine without CUDA.
    model.load_state_dict(torch.load(load_path, map_location='cpu'))

    ds_predict = load_predict_data(sc5)
    ds_predict = just_the_data(ds_predict)
    # Optional extra selection step, currently disabled:
    # ds_predict = select_from(ds_predict)

    predict(model, ds_predict, batch_size=512)