-
Notifications
You must be signed in to change notification settings - Fork 2
/
draw_image.py
68 lines (55 loc) · 2.67 KB
/
draw_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 31 01:57:34 2020
@author: malrawi
"""
# https://www.learnopencv.com/mask-r-cnn-instance-segmentation-with-pytorch/
import numpy as np
from PIL import Image, ImageChops
import torchvision.transforms as T
import random
from datasets import get_clothCoParse_class_names
import matplotlib.pyplot as plt
import cv2
from misc_utils import get_transforms
import torch
# Class-name lookup table, built once at import time by the project's dataset
# helper. get_prediction indexes it with the integer labels the model returns.
INSTANCE_CATEGORY_NAMES = get_clothCoParse_class_names()
def save_masks_as_images(img_name, masks, path, file_name, labels):
    """Save each segmented garment as its own PNG file.

    For every mask, the source image is multiplied by the mask rendered as a
    white-on-black RGB image, so pixels outside the mask become black.

    Output files are named ``path + file_name + label + '.png'`` using plain
    string concatenation. If ``path`` names a directory, it must already end
    with a separator.

    Args:
        img_name: path of the source image.
        masks: sequence of 2-D binary masks (bool or 0/1), each the same
            size as the image.
        path: output prefix (typically a directory ending in a separator).
        file_name: file-name stem placed before each label.
        labels: one label string per mask, used in the output file name.

    Raises:
        IndexError: if ``labels`` has fewer entries than ``masks``.
    """
    # Load and detach the pixels so the file handle is closed right away.
    # multiply() requires both operands in the same mode, and the mask
    # below is RGB.
    with Image.open(img_name) as src:
        img = src.convert('RGB')

    for i, mask in enumerate(masks):
        # 255 * a bool array takes the platform default int dtype (int64 on
        # Linux), which Image.fromarray does not support. Cast to uint8
        # explicitly so this works on every platform.
        mask_u8 = (255 * np.asarray(mask)).astype(np.uint8)
        mask_rgb = Image.fromarray(mask_u8).convert('RGB')
        segmented = ImageChops.multiply(img, mask_rgb)
        segmented.save(f'{path}{file_name}{labels[i]}.png')
def get_prediction(model, img, threshold, device):
    """Run the detection model on one image and keep the confident detections.

    Args:
        model: detection model (presumably torchvision Mask R-CNN; confirm).
            It is called on a one-element list of image tensors and must
            return a list of dicts with 'scores', 'masks', 'labels' and
            'boxes' entries.
        img: input image. It is passed through the test transforms from
            ``misc_utils.get_transforms`` (assumed to be a PIL image; confirm).
        threshold: a detection is kept only if its score is strictly
            greater than this value.
        device: torch device the transformed image is moved to.

    Returns:
        A tuple ``(masks, pred_boxes, pred_class)`` for the K kept detections:

        * ``masks``: boolean numpy array of shape (K, H, W), thresholded at
          mask probability > 0.5.
        * ``pred_boxes``: list of K ``[(x1, y1), (x2, y2)]`` corner pairs.
        * ``pred_class``: list of K class names from
          ``INSTANCE_CATEGORY_NAMES``.

        K may be 0 when no detection passes the threshold.
    """
    _, transforms_test, _ = get_transforms()
    batch = [T.Compose(transforms_test)(img).to(device)]
    with torch.no_grad():
        pred = model(batch)[0]

    scores = pred['scores'].cpu().numpy()
    # Select by boolean mask. The previous "last list.index() above the
    # threshold" approach raised IndexError when nothing passed and
    # under-counted detections with tied scores.
    keep = scores > threshold

    # torchvision returns soft masks shaped (N, 1, H, W). Squeeze only the
    # channel axis: a bare squeeze() turned a single detection into (H, W),
    # and the slicing that followed then cut image rows instead of
    # detections.
    masks = (pred['masks'] > 0.5).squeeze(1).cpu().numpy()[keep]
    labels = pred['labels'].cpu().numpy()[keep]
    boxes = pred['boxes'].cpu().numpy()[keep]

    pred_class = [INSTANCE_CATEGORY_NAMES[label] for label in labels]
    pred_boxes = [[(box[0], box[1]), (box[2], box[3])] for box in boxes]
    return masks, pred_boxes, pred_class
def random_colour_masks(image):
    """Paint a binary mask with one colour chosen at random from a palette.

    Args:
        image: 2-D mask array. Pixels equal to 1 (or True) are coloured.

    Returns:
        uint8 array of shape ``image.shape + (3,)``. Masked pixels hold the
        chosen RGB colour and all other pixels are 0.
    """
    colours = [[0, 255, 0], [0, 0, 255], [255, 0, 0], [0, 255, 255],
               [255, 255, 0], [255, 0, 255], [80, 70, 180], [250, 80, 190],
               [245, 145, 50], [70, 150, 250], [50, 190, 190]]
    # random.choice draws from the whole palette. The previous
    # randrange(0, 10) could never pick the last (11th) colour.
    colour = random.choice(colours)

    mask = np.asarray(image) == 1
    coloured_mask = np.zeros(mask.shape + (3,), dtype=np.uint8)
    coloured_mask[mask] = colour
    return coloured_mask
def instance_segmentation_api(model, img_name, device, threshold=0.5, rect_th=3, text_size=1, text_th=3):
    """Segment an image and display the detections drawn over it.

    The function first shows the RGB input. It then shows a copy with each
    kept detection drawn on top: a randomly coloured mask overlay, a green
    bounding box, and the class name.

    Args:
        model: detection model passed through to ``get_prediction``.
        img_name: path of the image to segment.
        device: torch device passed through to ``get_prediction``.
        threshold: minimum detection score (exclusive) to draw.
        rect_th: bounding-box line thickness in pixels.
        text_size: OpenCV font scale for the class label.
        text_th: label stroke thickness in pixels.
    """
    # Detach the pixels so the file handle is closed. Force RGB so the
    # array matches the 3-channel colour overlay (RGBA, L or P inputs
    # would otherwise break cv2.addWeighted).
    with Image.open(img_name) as src:
        img = src.convert('RGB')
    img.show()

    masks, boxes, pred_cls = get_prediction(model, img, threshold, device)
    canvas = np.array(img)
    for mask, box, label in zip(masks, boxes, pred_cls):
        canvas = cv2.addWeighted(canvas, 1, random_colour_masks(mask), 0.5, 0)
        # The model returns float coordinates, but OpenCV drawing functions
        # require integer pixel points.
        top_left = tuple(int(round(float(c))) for c in box[0])
        bottom_right = tuple(int(round(float(c))) for c in box[1])
        cv2.rectangle(canvas, top_left, bottom_right, color=(0, 255, 0), thickness=rect_th)
        cv2.putText(canvas, label, top_left, cv2.FONT_HERSHEY_SIMPLEX, text_size, (0, 255, 0), thickness=text_th)
    Image.fromarray(canvas).show()