-
Notifications
You must be signed in to change notification settings - Fork 57
/
Copy pathtest_model.py
85 lines (58 loc) · 2.56 KB
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Copyright 2020 by Andrey Ignatov. All Rights Reserved.
from scipy import misc
import numpy as np
import sys
import os
from torch.utils.data import DataLoader
from torchvision import transforms
import torch
from load_data import LoadVisualData
from model import PyNET
import utils
# Converts an image tensor/ndarray into a PIL image (used when saving results).
to_image = transforms.Compose([
    transforms.ToPILImage(),
])

# Command-line configuration, parsed once when this module is loaded.
level, restore_epoch, dataset_dir, use_gpu, orig_model = utils.process_test_model_args(sys.argv)

# Scale factor passed to the data loader for this level: 1 / 2^(level - 1).
dslr_scale = 1.0 / 2 ** (level - 1)
def test_model():
    """Run a trained PyNET model over the full-resolution visual set and save PNGs.

    Configuration comes from the module-level values parsed from ``sys.argv``
    (``level``, ``restore_epoch``, ``dataset_dir``, ``use_gpu``, ``orig_model``).
    Every image yielded by ``LoadVisualData`` is run through the model under
    ``torch.no_grad()`` and written to ``results/full-resolution/`` as a PNG.
    """
    on_gpu = use_gpu == "true"

    if on_gpu:
        torch.backends.cudnn.deterministic = True
        device = torch.device("cuda")
    else:
        # Hide every GPU from this process so inference stays on the CPU.
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
        device = torch.device("cpu")

    # Creating dataset loaders
    # NOTE(review): the meaning of the positional `10` is defined by LoadVisualData — confirm.
    visual_dataset = LoadVisualData(dataset_dir, 10, dslr_scale, level, full_resolution=True)
    # Pinned host memory only helps host->GPU copies; request it only on the GPU path.
    visual_loader = DataLoader(dataset=visual_dataset, batch_size=1, shuffle=False, num_workers=0,
                               pin_memory=on_gpu, drop_last=False)

    # Creating and loading pre-trained PyNET model
    model = PyNET(level=level, instance_norm=True, instance_norm_level_1=True).to(device)
    model = torch.nn.DataParallel(model)

    if orig_model == "true":
        # NOTE(review): always the level-0 file regardless of `level`; presumably the
        # original release only ships level-0 weights — confirm.
        checkpoint_path = "models/original/pynet_level_0.pth"
    else:
        checkpoint_path = ("models/pynet_level_" + str(level) +
                           "_epoch_" + str(restore_epoch) + ".pth")

    # map_location remaps stored tensors onto the selected device. Without it,
    # a checkpoint holding CUDA tensors cannot be loaded on the CPU path, where
    # CUDA_VISIBLE_DEVICES=-1 makes CUDA unavailable.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device), strict=True)

    if on_gpu:
        model.half()
    model.eval()

    # Create the output directory up front so the first save cannot fail on a fresh checkout.
    output_dir = "results/full-resolution/"
    os.makedirs(output_dir, exist_ok=True)

    # Processing full-resolution RAW images
    with torch.no_grad():
        for j, raw_image in enumerate(visual_loader):
            print("Processing image " + str(j))
            if on_gpu:
                torch.cuda.empty_cache()
                # The model was converted to half precision above, so match its input dtype.
                raw_image = raw_image.to(device, dtype=torch.half)
            else:
                raw_image = raw_image.to(device)

            # Run inference
            enhanced = model(raw_image)
            # Cast back to float32 (the GPU path runs in fp16), move to CPU, drop the
            # batch dimension and convert to a PIL image.
            enhanced_image = to_image(torch.squeeze(enhanced.float().cpu()))

            # Save the results as .png images
            # (scipy.misc.imsave no longer exists in modern SciPy; save via PIL directly.)
            if orig_model == "true":
                out_path = output_dir + str(j) + "_level_" + str(level) + "_orig.png"
            else:
                out_path = (output_dir + str(j) + "_level_" + str(level) +
                            "_epoch_" + str(restore_epoch) + ".png")
            enhanced_image.save(out_path)
# Script entry point: run inference only when executed directly, not on import.
if __name__ == "__main__":
    test_model()