Skip to content

Commit

Permalink
Added support for 3D boxes with function weighted_boxes_fusion_3d
Browse files Browse the repository at this point in the history
  • Loading branch information
IDMIPPM committed Jun 16, 2020
1 parent 3d40923 commit c789f05
Show file tree
Hide file tree
Showing 5 changed files with 370 additions and 1 deletion.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ boxes, scores, labels = weighted_boxes_fusion([boxes_list], [scores_list], [labe

More examples can be found in [example.py](./example.py)

#### 3D version

The WBF method also supports 3D boxes through the `weighted_boxes_fusion_3d` function. See [example_3d.py](./example_3d.py) for a usage example.

## Accuracy and speed comparison

The comparison was made for an ensemble of predictions from 5 different object detection models trained on the [Open Images Dataset](https://storage.googleapis.com/openimages/web/index.html) (500 classes).
Expand Down
1 change: 1 addition & 0 deletions ensemble_boxes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
from .ensemble_boxes_nms import nms_method
from .ensemble_boxes_nms import nms
from .ensemble_boxes_nms import soft_nms
from .ensemble_boxes_wbf_3d import weighted_boxes_fusion_3d
222 changes: 222 additions & 0 deletions ensemble_boxes/ensemble_boxes_wbf_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
# coding: utf-8
__author__ = 'ZFTurbo: https://kaggle.com/zfturbo'


import warnings
import numpy as np
from numba import jit


@jit(nopython=True)
def bb_intersection_over_union_3d(A, B) -> float:
    """
    Intersection-over-union of two axis-aligned 3D boxes.

    :param A: first box as (x1, y1, z1, x2, y2, z2)
    :param B: second box in the same layout
    :return: intersection volume divided by union volume, 0.0 when the boxes do not overlap
    """
    # Corners of the overlap region: the larger of the two minimums and
    # the smaller of the two maximums along every axis.
    lo_x = max(A[0], B[0])
    lo_y = max(A[1], B[1])
    lo_z = max(A[2], B[2])
    hi_x = min(A[3], B[3])
    hi_y = min(A[4], B[4])
    hi_z = min(A[5], B[5])

    # A negative extent on any axis means there is no overlap on that axis.
    overlap = max(0, hi_x - lo_x) * max(0, hi_y - lo_y) * max(0, hi_z - lo_z)
    if overlap == 0:
        return 0.0

    vol_a = (A[3] - A[0]) * (A[4] - A[1]) * (A[5] - A[2])
    vol_b = (B[3] - B[0]) * (B[4] - B[1]) * (B[5] - B[2])

    # Union = sum of both volumes minus the part counted twice.
    return overlap / float(vol_a + vol_b - overlap)


def _order_axis(low, high, axis):
    """Return (low, high) in ascending order, warning when a swap was needed."""
    if high < low:
        warnings.warn('{0}2 < {0}1 value in box. Swap them.'.format(axis))
        return high, low
    return low, high


def _clamp_unit(value, name):
    """Clamp a single coordinate into [0, 1], warning when it was out of range."""
    if value < 0:
        warnings.warn('{} < 0 in box. Set it to 0.'.format(name))
        return 0
    if value > 1:
        warnings.warn('{} > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.'.format(name))
        return 1
    return value


def prefilter_boxes(boxes, scores, labels, weights, thr):
    """
    Validate, clean and group the raw predictions of all models by label.

    Boxes scoring below ``thr`` are dropped, swapped corners are reordered,
    coordinates are clamped to [0, 1] and zero-volume boxes are skipped
    (each correction emits a warning). Scores are multiplied by the weight
    of the model they come from.

    :param boxes: per-model lists of boxes, each box is (x1, y1, z1, x2, y2, z2)
    :param scores: per-model lists of confidence scores
    :param labels: per-model lists of labels
    :param weights: per-model weights, indexed like ``boxes``
    :param thr: boxes with a raw score lower than this value are excluded
    :return: dict mapping label -> numpy array of rows
             (label, weighted score, x1, y1, z1, x2, y2, z2), sorted by
             weighted score in descending order
    :raises ValueError: if a model has a different number of boxes than
                        scores or labels
    """
    new_boxes = dict()

    for t, model_boxes in enumerate(boxes):
        model_scores = scores[t]
        model_labels = labels[t]

        # Raise instead of print + exit(): a library must not terminate the
        # interpreter of the program that imported it.
        if len(model_boxes) != len(model_scores):
            raise ValueError('Error. Length of boxes arrays not equal to length of scores array: {} != {}'.format(len(model_boxes), len(model_scores)))

        if len(model_boxes) != len(model_labels):
            raise ValueError('Error. Length of boxes arrays not equal to length of labels array: {} != {}'.format(len(model_boxes), len(model_labels)))

        for box_part, score, raw_label in zip(model_boxes, model_scores, model_labels):
            if score < thr:
                continue
            label = int(raw_label)
            x1, y1, z1, x2, y2, z2 = (float(box_part[k]) for k in range(6))

            # Box data checks: first make every axis ascending, then clamp.
            x1, x2 = _order_axis(x1, x2, 'X')
            y1, y2 = _order_axis(y1, y2, 'Y')
            z1, z2 = _order_axis(z1, z2, 'Z')
            x1 = _clamp_unit(x1, 'X1')
            x2 = _clamp_unit(x2, 'X2')
            y1 = _clamp_unit(y1, 'Y1')
            y2 = _clamp_unit(y2, 'Y2')
            z1 = _clamp_unit(z1, 'Z1')
            z2 = _clamp_unit(z2, 'Z2')

            if (x2 - x1) * (y2 - y1) * (z2 - z1) == 0.0:
                warnings.warn("Zero volume box skipped: {}.".format(box_part))
                continue

            row = [label, float(score) * weights[t], x1, y1, z1, x2, y2, z2]
            new_boxes.setdefault(label, []).append(row)

    # Sort each list in dict by score and transform it to numpy array
    for k in new_boxes:
        current_boxes = np.array(new_boxes[k])
        new_boxes[k] = current_boxes[current_boxes[:, 1].argsort()[::-1]]

    return new_boxes


def get_weighted_box(boxes, conf_type='avg'):
    """
    Fuse a cluster of boxes into a single confidence-weighted box.

    :param boxes: non-empty sequence of rows (label, score, x1, y1, z1, x2, y2, z2)
    :param conf_type: 'avg' stores the mean score of the cluster, 'max' stores the largest score
    :return: float32 array of 8 values: label of the first box, fused score,
             and coordinates averaged with the scores as weights
    """
    fused = np.zeros(8, dtype=np.float32)
    confidences = []
    total_conf = 0
    for member in boxes:
        member_conf = member[1]
        # Coordinates contribute proportionally to the box confidence.
        fused[2:] += (member_conf * member[2:])
        total_conf += member_conf
        confidences.append(member_conf)

    # All boxes of a cluster share one label, so take it from the first one.
    fused[0] = boxes[0][0]
    if conf_type == 'max':
        fused[1] = np.array(confidences).max()
    elif conf_type == 'avg':
        fused[1] = total_conf / len(boxes)

    # Normalise the weighted coordinate sums back to coordinates.
    fused[2:] /= total_conf
    return fused


def find_matching_box(boxes_list, new_box, match_iou):
    """
    Find the box in ``boxes_list`` that overlaps ``new_box`` the most.

    Only boxes with the same label (element 0) are considered, and a box
    qualifies only when its IoU is strictly greater than ``match_iou``.

    :param boxes_list: candidate rows (label, score, x1, y1, z1, x2, y2, z2)
    :param new_box: row in the same layout to match against the candidates
    :param match_iou: minimal IoU that must be exceeded for a match
    :return: (index of the best candidate or -1, its IoU or ``match_iou`` when nothing matched)
    """
    winner = -1
    top_iou = match_iou
    for idx, candidate in enumerate(boxes_list):
        if candidate[0] == new_box[0]:
            overlap = bb_intersection_over_union_3d(candidate[2:], new_box[2:])
            if overlap > top_iou:
                winner, top_iou = idx, overlap

    return winner, top_iou


def weighted_boxes_fusion_3d(boxes_list, scores_list, labels_list, weights=None, iou_thr=0.55, skip_box_thr=0.0, conf_type='avg', allows_overflow=False):
    '''
    Ensemble 3D box predictions of several models with Weighted Boxes Fusion.

    :param boxes_list: list of boxes predictions from each model, each box is 6 numbers.
    It has 3 dimensions (models_number, model_preds, 6)
    Order of boxes: x1, y1, z1, x2, y2, z2. We expect float normalized coordinates [0; 1]
    :param scores_list: list of scores for each model
    :param labels_list: list of labels for each model
    :param weights: list of weights for each model. Default: None, which means weight == 1 for each model
    :param iou_thr: IoU value for boxes to be a match
    :param skip_box_thr: exclude boxes with score lower than this variable
    :param conf_type: how to calculate confidence in weighted boxes. 'avg': average value, 'max': maximum value
    :param allows_overflow: false if we want confidence score not exceed 1.0
    :return: boxes: boxes coordinates (Order of boxes: x1, y1, z1, x2, y2, z2).
    :return: scores: confidence scores
    :return: labels: boxes labels
    '''

    n_models = len(boxes_list)
    if weights is None:
        weights = np.ones(n_models)
    if len(weights) != n_models:
        print('Warning: incorrect number of weights {}. Must be: {}. Set weights equal to 1.'.format(len(weights), n_models))
        weights = np.ones(n_models)
    weights = np.array(weights)

    if conf_type not in ['avg', 'max']:
        print('Error. Unknown conf_type: {}. Must be "avg" or "max". Use "avg"'.format(conf_type))
        conf_type = 'avg'

    per_label = prefilter_boxes(boxes_list, scores_list, labels_list, weights, skip_box_thr)
    if not per_label:
        return np.zeros((0, 6)), np.zeros((0,)), np.zeros((0,))

    weight_total = weights.sum()
    fused_per_label = []
    for candidates in per_label.values():
        clusters = []  # clusters[i] holds the source boxes merged into fused[i]
        fused = []

        # Greedy clustering: candidates arrive sorted by score, each one joins
        # the best-overlapping fused box or starts a new cluster.
        for candidate in candidates:
            index, _ = find_matching_box(fused, candidate, iou_thr)
            if index == -1:
                clusters.append([candidate.copy()])
                fused.append(candidate.copy())
                continue
            clusters[index].append(candidate)
            fused[index] = get_weighted_box(clusters[index], conf_type)

        # Rescale confidence based on number of models and boxes
        for cluster, box in zip(clusters, fused):
            support = len(cluster)
            if not allows_overflow:
                support = min(weight_total, support)
            box[1] = box[1] * support / weight_total
        fused_per_label.append(np.array(fused))

    merged = np.concatenate(fused_per_label, axis=0)
    merged = merged[merged[:, 1].argsort()[::-1]]
    return merged[:, 2:], merged[:, 1], merged[:, 0]
142 changes: 142 additions & 0 deletions example_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# coding: utf-8
__author__ = 'ZFTurbo: https://kaggle.com/zfturbo'


import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from ensemble_boxes import *


def plot_cube(ax, cube_definition, lbl, thickness):
    """
    Draw one box on a 3D matplotlib axis.

    :param ax: 3D axis to draw on
    :param cube_definition: four corner points: an origin corner followed by
                            the three corners adjacent to it
    :param lbl: label of the box; label 0 gets red edges, any other label blue
    :param thickness: edge line width is ``thickness + 1``
    """
    corners = [np.array(list(item)) for item in cube_definition]
    origin = corners[0]
    # Edge vectors spanning the box from the origin corner.
    v0, v1, v2 = (corners[k] - origin for k in (1, 2, 3))

    # The remaining four vertices are combinations of the edge vectors.
    points = np.array(corners + [
        origin + v0 + v1,
        origin + v0 + v2,
        origin + v1 + v2,
        origin + v0 + v1 + v2,
    ])

    # Each tuple lists the vertex indices of one of the six faces.
    face_indices = [
        (0, 3, 5, 1),
        (1, 5, 7, 4),
        (4, 2, 6, 7),
        (2, 6, 3, 0),
        (0, 2, 4, 1),
        (3, 6, 7, 5),
    ]
    polygons = [[points[i] for i in quad] for quad in face_indices]

    faces = Poly3DCollection(polygons, linewidths=thickness + 1)
    faces.set_edgecolor((1, 0, 0) if lbl == 0 else (0, 0, 1))
    faces.set_facecolor((0, 0, 1, 0.1))

    ax.add_collection3d(faces)
    # Zero-size markers at the vertices, kept from the original code.
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=0)


def show_boxes(boxes_list, scores_list, labels_list, image_size=800):
    """
    Plot the boxes of every model in a single 3D figure and show it.

    :param boxes_list: per-model lists of boxes (x1, y1, z1, x2, y2, z2)
    :param scores_list: per-model lists of scores; a higher score gives a thicker edge
    :param labels_list: per-model lists of labels, passed to ``plot_cube`` for colouring
    :param image_size: factor the coordinates are multiplied by before plotting
    """
    # The unused white image buffer of the original was removed: it was
    # allocated and filled but never read.
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    for i, model_boxes in enumerate(boxes_list):
        for j, box in enumerate(model_boxes):
            x1, y1, z1, x2, y2, z2 = (int(image_size * box[k]) for k in range(6))
            lbl = labels_list[i][j]
            # Origin corner plus its three neighbours fully define the box.
            cube_definition = [
                (x1, y1, z1), (x1, y2, z1), (x2, y1, z1), (x1, y1, z2)
            ]
            plot_cube(ax, cube_definition, lbl, int(4 * scores_list[i][j]))

    plt.show()


def example_wbf_3d_2_models(iou_thr=0.55, draw_image=True):
    """
    Demonstrate fusing the 3D predictions of two models with WBF_3D.

    Prints the number of fused boxes and the boxes themselves.

    :param iou_thr: IoU threshold passed to ``weighted_boxes_fusion_3d``
    :param draw_image: when true, show the boxes before and after fusion
    :return:
    """

    # Predictions of the first model (5 boxes).
    model_a_boxes = [
        [0.00, 0.51, 0.41, 0.81, 0.91, 0.78],
        [0.10, 0.31, 0.45, 0.71, 0.61, 0.85],
        [0.01, 0.32, 0.55, 0.83, 0.93, 0.95],
        [0.02, 0.53, 0.11, 0.11, 0.94, 0.55],
        [0.03, 0.24, 0.34, 0.12, 0.35, 0.67],
    ]
    model_a_scores = [0.9, 0.8, 0.2, 0.4, 0.7]
    model_a_labels = [0, 1, 0, 1, 1]

    # Predictions of the second model (4 boxes).
    model_b_boxes = [
        [0.04, 0.56, 0.36, 0.84, 0.92, 0.82],
        [0.12, 0.33, 0.46, 0.72, 0.64, 0.75],
        [0.38, 0.66, 0.55, 0.79, 0.95, 0.90],
        [0.08, 0.49, 0.15, 0.21, 0.89, 0.67],
    ]
    model_b_scores = [0.5, 0.8, 0.7, 0.3]
    model_b_labels = [1, 1, 1, 0]

    boxes_list = [model_a_boxes, model_b_boxes]
    scores_list = [model_a_scores, model_b_scores]
    labels_list = [model_a_labels, model_b_labels]
    weights = [2, 1]

    if draw_image:
        show_boxes(boxes_list, scores_list, labels_list)

    boxes, scores, labels = weighted_boxes_fusion_3d(
        boxes_list,
        scores_list,
        labels_list,
        weights=weights,
        iou_thr=iou_thr,
        skip_box_thr=0.0,
    )

    if draw_image:
        show_boxes([boxes], [scores], [labels.astype(np.int32)])

    print(len(boxes))
    print(boxes)


if __name__ == '__main__':
    # Script entry point: run the two-model demo with plotting enabled.
    draw_image = True
    example_wbf_3d_2_models(draw_image=draw_image, iou_thr=0.2)

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='ensemble_boxes',
version='1.0.3',
version='1.0.4',
author='Roman Solovyev (ZFTurbo)',
packages=['ensemble_boxes', ],
url='https://github.com/ZFTurbo/Weighted-Boxes-Fusion',
Expand Down

0 comments on commit c789f05

Please sign in to comment.