-
Notifications
You must be signed in to change notification settings - Fork 0
/
testmodel.py
109 lines (85 loc) · 2.82 KB
/
testmodel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""
Pass processed images through CNN to extract starshade position.
"""
import numpy as np
import h5py
import os
import torch
import torch.nn.functional as F
from torchvision import transforms
from cnn_andrew import CNN
import h5py
import atexit
import time
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# When True, results are written to an HDF5 file under save_dir; when False
# the script calls breakpoint() after inference instead of saving.
do_save = True

# Basename (without the '.h5' extension) of the processed-image file to test.
# Previously used values, kept for reference:
#   'run__6_01_21__data_1s_bin1__spiders__median'
#   'run__8_30_21__data_1s_bin1__spiders__median'
data_run = 'CNN__12_21_21__t15_wide__spiders'

# Basename of the checkpoint loaded from model_dir as '<model_name>.pt'.
model_name = 'Newest_andrew'

# Optional tag for the output filename ('' means no tag).
save_ext = ''

# Telescope aperture diameters [m] in the lab and in space. Their ratio is the
# factor used to scale lab-frame truth positions into space-frame units.
Dtel_lab = 2.201472e-3
Dtel_space = 2.4
lab2space = Dtel_space / Dtel_lab

# Input/output directories (relative to the current working directory).
data_dir = './lab_experiments/processing_data/Processed_Images'
model_dir = './models'
save_dir = 'Test_Results'
# ---------------------------------------------------------------------------
# Open the test data and load the trained model
# ---------------------------------------------------------------------------
# The file is kept open for the whole run because the three entries below are
# h5py datasets read lazily during the loop; atexit closes it at interpreter
# exit.
test_loader = h5py.File(os.path.join(data_dir, data_run + '.h5'), 'r')
atexit.register(test_loader.close)

# Per-sample inputs: image, fitted amplitude (the loop treats -1 as a failed
# fit), and truth position (lab units; scaled by lab2space before comparison).
images = test_loader['images']
amplitudes = test_loader['amplitudes']
positions = test_loader['positions']

# Build the network and load its weights. map_location='cpu' because the
# inference loop feeds CPU tensors and never moves the model to a GPU; it also
# lets a checkpoint that was saved from a CUDA model load on a CPU-only machine.
model = CNN()
mod_file = os.path.join(model_dir, model_name + '.pt')
model.load_state_dict(torch.load(mod_file, map_location='cpu'))
# Inference mode (matters if the CNN contains dropout/batch-norm layers).
model.eval()
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
# ToTensor converts an (H, W) or (H, W, C) ndarray into a (C, H, W) tensor;
# float arrays are passed through without rescaling.
transform = transforms.Compose([transforms.ToTensor()])

tik = time.perf_counter()
print(f'Testing {model_name} model...')

# Collect rows in lists and stack once at the end: calling np.concatenate in
# the loop copies the whole accumulated array on every iteration (O(n^2)).
pred_rows = []
diff_rows = []
with torch.no_grad():
    for img, amp, pos in zip(images, amplitudes, positions):
        if amp == -1:
            # Amplitude fit failed for this image. Record the sentinel
            # [-1, -1] as the prediction and NaN as the error so that both
            # output arrays stay row-aligned with `positions`.
            pred_rows.append([-1.0, -1.0])
            diff_rows.append([np.nan, np.nan])
            continue
        # Normalize by the fitted amplitude. Divide out-of-place so integer
        # images are promoted to float instead of failing an in-place cast.
        img = (img / amp).astype('float32')
        # Presumably each image is 2-D (H, W) -- TODO confirm. ToTensor then
        # gives (1, H, W); unsqueeze adds the batch axis.
        img = torch.unsqueeze(transform(img), 0)
        output = model(img).cpu().detach().numpy().squeeze().astype(float)
        # Compare to the truth position scaled from lab to space units.
        diff = output - pos * lab2space
        pred_rows.append(output)
        diff_rows.append(diff)

# reshape(len, 2) keeps a (0, 2) shape when there are no samples and raises
# if the model ever returns something other than an (x, y) pair.
predicted_position = np.array(pred_rows, dtype=float).reshape(len(pred_rows), 2)
difference = np.array(diff_rows, dtype=float).reshape(len(diff_rows), 2)
tok = time.perf_counter()
print(f'Done! in {tok-tik:.1f} s')
# ---------------------------------------------------------------------------
# Save results
# ---------------------------------------------------------------------------
if do_save:
    # exist_ok replaces the racy os.path.exists check followed by makedirs.
    os.makedirs(save_dir, exist_ok=True)
    # Prefix a non-empty tag with '_' so the name reads '<run>__<model>_<tag>'.
    if save_ext != '':
        save_ext = '_' + save_ext
    out_file = os.path.join(save_dir, f'{data_run}__{model_name}{save_ext}.h5')
    # Mode 'w' truncates any existing file with the same name.
    with h5py.File(out_file, 'w') as f:
        f.create_dataset('predicted_position', data=predicted_position)
        f.create_dataset('difference', data=difference)
        f.create_dataset('lab2space', data=lab2space)
        # Copy the (lab-unit) truth positions from the input file.
        f.create_dataset('positions', data=positions)
else:
    # Not saving: stop in the debugger to inspect the results interactively.
    breakpoint()