forked from FabianRei/neuro_detect
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path testNetwork_newData.py
39 lines (35 loc) · 1.59 KB
/
testNetwork_newData.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from src.data.preprocess import resizeTensors, normalizeByIndividualMean, getTensorList_general, cropBlockResize
from src.models.simple_net import SimpleNet
# --- Configuration -----------------------------------------------------------
# Directory with the new-data tensors to classify (absolute, machine-specific).
tensorDir = '/black/localhome/reith/Desktop/projects/Tensors/newData/nok/'
# Saved SimpleNet state dict, relative to the current working directory.
networkWeights = 'models/trained_simplenet.torch'
# Target shape passed to resizeTensors before cropping.
wantedShape = (41, 53, 38, 6)
# Crop window over the first three axes (24 voxels along each).
crop = (slice(4, 28), slice(20, 44), slice(7, 31))
resizeFactor = 2
# Flattened per-sample input size. Presumably the 24-voxel crop reduced by
# resizeFactor to 12 per spatial axis, times 6 along the last axis -- confirm
# against cropBlockResize if any of the values above change.
dimIn = 12*12*12*6
# Number of network outputs (classes).
dimOut = 4

# --- Model -------------------------------------------------------------------
net = SimpleNet(dimIn=dimIn, dimOut=dimOut)
# The script feeds CPU tensors (torch.from_numpy) and converts outputs with
# .numpy(), so always load the weights onto the CPU, even if they were saved
# from a GPU model.
net.load_state_dict(torch.load(networkWeights, map_location='cpu'))
# Inference mode: any dropout / batch-norm layers inside SimpleNet (not
# visible here) must not behave as during training.
net.eval()
# --- Data --------------------------------------------------------------------
# Load every tensor in tensorDir together with its name (used in the report).
tensors, names = getTensorList_general(tensorDir, giveNames=True)
if len(tensors) == 0:
    # np.stack would otherwise fail with an unhelpful "need at least one array".
    raise ValueError(f"No tensors found in {tensorDir!r}")
tensors = resizeTensors(tensors, wantedShape)
tensors = cropBlockResize(tensors, resizeFactor, crop)
tensors = normalizeByIndividualMean(tensors)
batch = np.stack(tensors)
# One row per sample: (num_samples, dimIn). Reshaping with the explicit sample
# count (instead of view(-1, dimIn)) makes a per-sample size mismatch raise
# instead of silently regrouping voxels across samples and misaligning rows
# with `names`. Plain tensors replace the deprecated autograd.Variable wrapper.
tensors = torch.from_numpy(batch).type(torch.float32).view(batch.shape[0], dimIn)
# --- Inference & report --------------------------------------------------------
# Label for each network output index 0-3, plus a final entry reported when
# the pseudo-certainty of the predicted class is below certaintyThreshold.
predictionStringArr = ["no axis is flipped", "the x axis is flipped", "the y axis is flipped", "the z axis is flipped", "it has no idea what's happening"]
unsureIndex = len(predictionStringArr) - 1
# Minimum softmax probability (in percent) required to report the prediction.
certaintyThreshold = 99

# Inference only: no_grad() skips recording the autograd graph, so results
# can be converted to NumPy directly.
with torch.no_grad():
    net_out = net(tensors)
    prediction = net_out.max(1)[1]
    # Softmax is monotonic, so each row's maximum probability is the
    # probability of its predicted (argmax) class. Computed once for the batch.
    certaintyPercent = F.softmax(net_out, dim=1).max(1)[0].cpu().numpy() * 100

for i, name in enumerate(names):
    pred = int(prediction[i])
    predCertainty = certaintyPercent[i]
    if predCertainty < certaintyThreshold:
        predIndex = unsureIndex
    else:
        predIndex = pred
    print(f"The neuro detector thinks that {predictionStringArr[predIndex]} for {name}.")
    print(f"[Pseudo certainty is at {predCertainty}% for {predictionStringArr[pred]}]")
print("done")