-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtest.py
31 lines (21 loc) · 784 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from dataloader import *
from TRAINING_CONFIG import *
import torchvision
def eval(test_model):
    """Run ``test_model`` over the test set and save every generated output.

    Builds the test dataloader via ``data_prep(testing=True)``, switches the
    model to evaluation mode, and writes each model output into
    ``output_images_path`` with ``torchvision.utils.save_image``.

    Note: this function's name shadows the built-in ``eval`` inside this
    module; the name is kept so existing callers continue to work.

    Args:
        test_model: Model to evaluate. It must provide ``eval()`` and be
            callable on an image batch.
    """
    test_dataloader = data_prep(testing=True)
    test_model.eval()

    # makedirs also creates missing parent directories and, with
    # exist_ok=True, avoids the check-then-create race of exists() + mkdir().
    os.makedirs(output_images_path, exist_ok=True)

    # Inference only: disable gradient tracking once for the whole loop.
    with torch.no_grad():
        for img, _, name in test_dataloader:
            # If an alpha channel exists in the test images, remove it.
            if img.size()[1] == 4:
                img = img[:, :3, :, :]
            img = img.to(device)
            generate_img = test_model(img)
            # The file name is the first entry of ``name``. The path is built
            # by plain concatenation, so output_images_path presumably ends
            # with a path separator -- TODO confirm in TRAINING_CONFIG.
            torchvision.utils.save_image(
                generate_img, output_images_path + name[0]
            )

    print("Evaluation of Given Test Images Completed!")
def run_testing():
    """Load the saved model from ``test_model_path`` and evaluate it.

    The checkpoint is loaded with ``map_location=device`` so that a model
    saved on one device (for example a GPU) can still be loaded when that
    device is not available on the current machine. The loaded model is then
    moved to ``device`` and passed to ``eval``.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code -- only load checkpoints that come from a trusted source.
    model_test = torch.load(test_model_path, map_location=device).to(device)
    eval(model_test)
### START TESTING ###
#run_testing()