#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
import argparse
import json
import numpy as np
import os
from collections import defaultdict
import cv2
import tqdm
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import Boxes, BoxMode, Instances, Keypoints
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import setup_logger
# Repo-local helpers: a customized Visualizer and the cattle dataset registration.
from visualizer import Visualizer
import register_cattle_datasets  # noqa: F401  (imported for its side effect of registering the cattle datasets)


def create_instances(predictions, image_size):
    ret = Instances(image_size)
    if len(predictions) == 0:
        # No detections for this image: return an empty Instances object.
        return ret

    # Keep only the first prediction for the image (this inference variant
    # visualizes a single instance per image), then apply the score threshold.
    score = np.asarray([x["score"] for x in predictions])[[0]]
    chosen = (score > args.conf_threshold).nonzero()[0]
    score = score[chosen]

    bbox = np.asarray([predictions[i]["bbox"] for i in chosen]).reshape(-1, 4)
    bbox = BoxMode.convert(bbox, BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)

    # 13 keypoints per instance, stored as flattened (x, y, score) triples.
    keypoints = np.asarray([predictions[i]["keypoints"] for i in chosen])
    keypoints = Keypoints(keypoints.reshape(-1, 13, 3))

    labels = np.asarray([dataset_id_map(predictions[i]["category_id"]) for i in chosen])

    ret.scores = score
    ret.pred_boxes = Boxes(bbox)
    ret.pred_classes = labels
    ret.pred_keypoints = keypoints

    try:
        ret.pred_masks = [predictions[i]["segmentation"] for i in chosen]
    except KeyError:
        pass
    return ret
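

# A minimal sketch of the per-prediction record this script expects in the input
# JSON (standard COCO-style keypoint results). The concrete values below are
# illustrative only, not taken from a real run:
#
#   {
#       "image_id": 42,
#       "category_id": 1,
#       "bbox": [x, y, width, height],                   # XYWH_ABS, converted to XYXY above
#       "score": 0.97,
#       "keypoints": [x1, y1, s1, ..., x13, y13, s13],   # 13 keypoints as flattened (x, y, score) triples
#       "segmentation": {...},                           # optional; used for pred_masks when present
#   }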


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="A script that visualizes JSON predictions in COCO/LVIS result format on a registered dataset."
    )
    parser.add_argument("--input", required=True, help="JSON file produced by the model")
    parser.add_argument("--output", required=True, help="output directory")
    parser.add_argument("--dataset", help="name of the dataset", default="coco_2017_val")
    parser.add_argument("--conf-threshold", default=0.5, type=float, help="confidence threshold")
    args = parser.parse_args()

    logger = setup_logger()

    with PathManager.open(args.input, "r") as f:
        predictions = json.load(f)

    # Group the flat list of predictions by the image they belong to.
    pred_by_image = defaultdict(list)
    for p in predictions:
        pred_by_image[p["image_id"]].append(p)

    dicts = list(DatasetCatalog.get(args.dataset))
    metadata = MetadataCatalog.get(args.dataset)
    if hasattr(metadata, "thing_dataset_id_to_contiguous_id"):

        def dataset_id_map(ds_id):
            return metadata.thing_dataset_id_to_contiguous_id[ds_id]

    elif "lvis" in args.dataset:
        # LVIS results are in the same format as COCO results, but have a different
        # mapping from dataset category id to contiguous category id in [0, #categories - 1].
        def dataset_id_map(ds_id):
            return ds_id - 1

    else:
        raise ValueError("Unsupported dataset: {}".format(args.dataset))

    os.makedirs(args.output, exist_ok=True)

    for dic in tqdm.tqdm(dicts):
        # cv2 reads BGR; flip to RGB for the Visualizer.
        img = cv2.imread(dic["file_name"], cv2.IMREAD_COLOR)[:, :, ::-1]
        basename = os.path.basename(dic["file_name"])

        predictions = create_instances(pred_by_image[dic["image_id"]], img.shape[:2])
        vis = Visualizer(img, metadata)
        vis_pred = vis.draw_instance_predictions(predictions).get_image()

        # Flip back to BGR before writing with OpenCV.
        cv2.imwrite(os.path.join(args.output, basename), vis_pred[:, :, ::-1])
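

# A hedged usage sketch. The input path assumes the JSON was produced by a
# detectron2 COCO-style evaluator; the dataset name is a placeholder for
# whichever name register_cattle_datasets actually registers in this repo:
#
#   python visualize_json_results_infer.py \
#       --input output/inference/coco_instances_results.json \
#       --output vis_out/ \
#       --dataset cattle_keypoints_test \
#       --conf-threshold 0.5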