-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
38 lines (34 loc) · 1.28 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# @Filename: predict.py
# @Author: Ashutosh Tiwari
# @Email: [email protected]
# @Time: 5/2/22 12:16 PM
import sys
import torch
from modules.blind_net import BlindNet
from PIL import Image, ImageDraw, ImageFont
from data.data_loader import CocoDataset
from torch.utils.data import DataLoader
import tkinter
import matplotlib
import matplotlib.pyplot as plt
import warnings
# NOTE(review): this silences every Python warning for the whole process,
# including ones from torch/PIL that may flag real problems -- consider a
# narrower filter (by category or module) once the noisy warning is known.
warnings.filterwarnings("ignore")
# Select matplotlib's Tk-based interactive backend for plot windows
# (presumably why tkinter is imported above -- confirm).
matplotlib.use('TkAgg')
def predict_segmentation(net, image):
    """Run *net* on one un-batched sample and return its argmax over dim 1.

    Args:
        net: A ``torch.nn.Module``; it is switched to eval mode before the
            forward pass.
        image: A single input tensor without a batch axis; a batch axis of
            size 1 is prepended before calling *net*.

    Returns:
        ``torch.argmax`` of the network output over dim 1 (the batch axis of
        size 1 is kept). Presumably the output is (1, num_classes, H, W),
        giving a (1, H, W) label map -- confirm against BlindNet.
    """
    net.eval()
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        scores = net(image.unsqueeze(0))
    return torch.argmax(scores, dim=1)


if __name__ == '__main__':
    # Usage: predict.py MODEL_PATH [--show]
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} MODEL_PATH [--show]')
    model_path = sys.argv[1]
    show_image = '--show' in sys.argv[2:]

    image_size = 32
    dataset = CocoDataset(
        annotations='coco2017/annotations/instances_train2017.json',
        image_root_dir='coco2017',
        mask_root_dir='cat_id_masked_arrays',
        train=True,
        image_size=image_size,
        predict=True,
    )
    # One random sample from indices [1, 100); .item() gives a plain int index.
    sample = dataset[torch.randint(1, 100, (1,)).item()]
    # Assumed sample layout (inferred from how it is used here -- confirm
    # against CocoDataset): [0] model input, [1] ground-truth mask,
    # [3] path of the source RGB image.
    model_input, ground_truth, image_path = sample[0], sample[1], sample[3]

    if show_image:
        # Context manager closes the file handle PIL keeps open.
        with Image.open(image_path) as rgb_image:
            plt.imshow(rgb_image)
            plt.show()

    net = BlindNet(image_size=model_input.shape[1])
    # map_location='cpu': nothing here moves the model to a GPU, so also
    # load checkpoints that were saved from a GPU onto the CPU.
    net.load_state_dict(torch.load(model_path, map_location='cpu'))
    result = predict_segmentation(net, model_input)

    # torch.unique flattens its input, so no size-specific reshape is needed.
    print(torch.unique(ground_truth))
    print(torch.unique(result))