-
Notifications
You must be signed in to change notification settings - Fork 238
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added support for 3D boxes with function weighted_boxes_fusion_3d
- Loading branch information
Showing
5 changed files
with
370 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,222 @@ | ||
# coding: utf-8 | ||
__author__ = 'ZFTurbo: https://kaggle.com/zfturbo' | ||
|
||
|
||
import warnings | ||
import numpy as np | ||
from numba import jit | ||
|
||
|
||
@jit(nopython=True)
def bb_intersection_over_union_3d(A, B) -> float:
    """
    Intersection-over-union of two axis-aligned 3D boxes.

    :param A: first box, indexed as (x1, y1, z1, x2, y2, z2)
    :param B: second box, same layout as A
    :return: intersection volume divided by union volume; 0.0 when the boxes do not overlap
    """
    # Overlap length along each axis, floored at zero when the boxes are disjoint on that axis
    overlap_x = max(0, min(A[3], B[3]) - max(A[0], B[0]))
    overlap_y = max(0, min(A[4], B[4]) - max(A[1], B[1]))
    overlap_z = max(0, min(A[5], B[5]) - max(A[2], B[2]))

    shared_vol = overlap_x * overlap_y * overlap_z
    if shared_vol == 0:
        return 0.0

    # Volumes of the two input boxes
    vol_a = (A[3] - A[0]) * (A[4] - A[1]) * (A[5] - A[2])
    vol_b = (B[3] - B[0]) * (B[4] - B[1]) * (B[5] - B[2])

    # Union = sum of both volumes minus the part counted twice
    union_vol = float(vol_a + vol_b - shared_vol)
    return shared_vol / union_vol
|
||
|
||
def _clip_to_unit_interval(value, name):
    """Clamp one coordinate into [0, 1], warning (with the coordinate name) when it was out of range."""
    if value < 0:
        warnings.warn('{} < 0 in box. Set it to 0.'.format(name))
        return 0.0
    if value > 1:
        warnings.warn('{} > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.'.format(name))
        return 1.0
    return value


def prefilter_boxes(boxes, scores, labels, weights, thr):
    """
    Validate, clean and group the raw predictions of all models by label.

    :param boxes: per-model lists of boxes, each box indexed as (x1, y1, z1, x2, y2, z2)
    :param scores: per-model lists of confidence scores, parallel to boxes
    :param labels: per-model lists of labels, parallel to boxes
    :param weights: per-model weight; each score is multiplied by the weight of its model
    :param thr: boxes with score strictly below this value are dropped
    :return: dict mapping label -> numpy array with rows
             [label, weighted score, x1, y1, z1, x2, y2, z2], sorted by weighted score, highest first
    :raises ValueError: if a model has a different number of boxes than scores or labels
    """
    new_boxes = dict()

    for t, model_boxes in enumerate(boxes):
        model_scores = scores[t]
        model_labels = labels[t]

        # Raise instead of print + exit(): a library must not terminate the host process.
        if len(model_boxes) != len(model_scores):
            raise ValueError('Error. Length of boxes arrays not equal to length of scores array: {} != {}'.format(len(model_boxes), len(model_scores)))

        if len(model_boxes) != len(model_labels):
            raise ValueError('Error. Length of boxes arrays not equal to length of labels array: {} != {}'.format(len(model_boxes), len(model_labels)))

        for j in range(len(model_boxes)):
            score = model_scores[j]
            if score < thr:
                continue
            label = int(model_labels[j])
            box_part = model_boxes[j]
            x1, y1, z1, x2, y2, z2 = (float(box_part[k]) for k in range(6))

            # Box data checks: first make sure min corner <= max corner on every axis
            if x2 < x1:
                warnings.warn('X2 < X1 value in box. Swap them.')
                x1, x2 = x2, x1
            if y2 < y1:
                warnings.warn('Y2 < Y1 value in box. Swap them.')
                y1, y2 = y2, y1
            if z2 < z1:
                warnings.warn('Z2 < Z1 value in box. Swap them.')
                z1, z2 = z2, z1

            # ... then force every coordinate into the normalized [0, 1] range
            x1 = _clip_to_unit_interval(x1, 'X1')
            x2 = _clip_to_unit_interval(x2, 'X2')
            y1 = _clip_to_unit_interval(y1, 'Y1')
            y2 = _clip_to_unit_interval(y2, 'Y2')
            z1 = _clip_to_unit_interval(z1, 'Z1')
            z2 = _clip_to_unit_interval(z2, 'Z2')

            # Degenerate boxes cannot take part in IoU matching
            if (x2 - x1) * (y2 - y1) * (z2 - z1) == 0.0:
                warnings.warn("Zero volume box skipped: {}.".format(box_part))
                continue

            b = [int(label), float(score) * weights[t], x1, y1, z1, x2, y2, z2]
            new_boxes.setdefault(label, []).append(b)

    # Sort each list in dict by score (descending) and transform it to numpy array
    for k in new_boxes:
        current_boxes = np.array(new_boxes[k])
        new_boxes[k] = current_boxes[current_boxes[:, 1].argsort()[::-1]]

    return new_boxes
|
||
|
||
def get_weighted_box(boxes, conf_type='avg'):
    """
    Create weighted box for set of boxes

    :param boxes: set of boxes to fuse; each box is a numpy array indexed as
                  [label, score, x1, y1, z1, x2, y2, z2]
    :param conf_type: type of confidence one of 'avg' or 'max'
    :return: weighted box as a float32 array of 8 values in the same layout;
             coordinates are the score-weighted mean of the input coordinates
    """

    box = np.zeros(8, dtype=np.float32)
    conf = 0
    conf_list = []
    for b in boxes:
        # Accumulate coordinates weighted by the score of each box
        box[2:] += (b[1] * b[2:])
        conf += b[1]
        conf_list.append(b[1])
    # The fused box takes the label of the first box in the set
    box[0] = boxes[0][0]
    if conf_type == 'avg':
        box[1] = conf / len(boxes)
    elif conf_type == 'max':
        box[1] = np.array(conf_list).max()
    if conf != 0:
        box[2:] /= conf
    else:
        # All scores are zero: a score-weighted mean is undefined (0 / 0 -> NaN),
        # so fall back to the plain mean of the coordinates.
        box[2:] = np.mean([b[2:] for b in boxes], axis=0)
    return box
|
||
|
||
def find_matching_box(boxes_list, new_box, match_iou):
    """
    Find the box in boxes_list that overlaps new_box the most.

    :param boxes_list: candidate boxes, each indexed as [label, score, x1, y1, z1, x2, y2, z2]
    :param new_box: box to match, same layout
    :param match_iou: IoU a candidate must strictly exceed to count as a match
    :return: (index of the best candidate, its IoU), or (-1, match_iou) when nothing matches
    """
    winner = -1
    top_iou = match_iou
    for position, candidate in enumerate(boxes_list):
        # Only boxes carrying the same label may be matched
        if candidate[0] == new_box[0]:
            overlap = bb_intersection_over_union_3d(candidate[2:], new_box[2:])
            if overlap > top_iou:
                winner, top_iou = position, overlap

    return winner, top_iou
|
||
|
||
def weighted_boxes_fusion_3d(boxes_list, scores_list, labels_list, weights=None, iou_thr=0.55, skip_box_thr=0.0, conf_type='avg', allows_overflow=False):
    """
    Fuse the 3D box predictions of several models with Weighted Boxes Fusion.

    :param boxes_list: list of boxes predictions from each model, each box is 6 numbers.
    It has 3 dimensions (models_number, model_preds, 6)
    Order of boxes: x1, y1, z1, x2, y2, z2. We expect float normalized coordinates [0; 1]
    :param scores_list: list of scores for each model
    :param labels_list: list of labels for each model
    :param weights: list of weights for each model. Default: None, which means weight == 1 for each model
    :param iou_thr: IoU value for boxes to be a match
    :param skip_box_thr: exclude boxes with score lower than this variable
    :param conf_type: how to calculate confidence in weighted boxes. 'avg': average value, 'max': maximum value
    :param allows_overflow: false if we want confidence score not exceed 1.0
    :return: boxes: boxes coordinates (Order of boxes: x1, y1, z1, x2, y2, z2).
    :return: scores: confidence scores
    :return: labels: boxes labels
    """
    models_count = len(boxes_list)

    # Missing or malformed weights fall back to equal weighting
    if weights is None:
        weights = np.ones(models_count)
    if len(weights) != models_count:
        print('Warning: incorrect number of weights {}. Must be: {}. Set weights equal to 1.'.format(len(weights), models_count))
        weights = np.ones(models_count)
    weights = np.array(weights)

    if conf_type not in ('avg', 'max'):
        print('Error. Unknown conf_type: {}. Must be "avg" or "max". Use "avg"'.format(conf_type))
        conf_type = 'avg'

    boxes_by_label = prefilter_boxes(boxes_list, scores_list, labels_list, weights, skip_box_thr)
    if not boxes_by_label:
        return np.zeros((0, 6)), np.zeros((0,)), np.zeros((0,))

    weights_total = weights.sum()
    fused_per_label = []
    for label_boxes in boxes_by_label.values():
        clusters = []  # clusters[i]: every source box merged into fused[i]
        fused = []     # fused[i]: current weighted box representing clusters[i]

        # Clusterize boxes: each box joins the best matching fused box or starts a new cluster
        for candidate in label_boxes:
            index, _ = find_matching_box(fused, candidate, iou_thr)
            if index == -1:
                clusters.append([candidate.copy()])
                fused.append(candidate.copy())
            else:
                clusters[index].append(candidate)
                fused[index] = get_weighted_box(clusters[index], conf_type)

        # Rescale confidence based on number of models and boxes
        for fused_box, members in zip(fused, clusters):
            members_count = len(members)
            if allows_overflow:
                fused_box[1] = fused_box[1] * members_count / weights_total
            else:
                fused_box[1] = fused_box[1] * min(weights_total, members_count) / weights_total
        fused_per_label.append(np.array(fused))

    # Merge all labels and order the result by confidence, highest first
    merged = np.concatenate(fused_per_label, axis=0)
    merged = merged[merged[:, 1].argsort()[::-1]]
    return merged[:, 2:], merged[:, 1], merged[:, 0]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
# coding: utf-8 | ||
__author__ = 'ZFTurbo: https://kaggle.com/zfturbo' | ||
|
||
|
||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from mpl_toolkits.mplot3d.art3d import Poly3DCollection | ||
from ensemble_boxes import * | ||
|
||
|
||
def plot_cube(ax, cube_definition, lbl, thickness):
    """
    Draw one box on a 3D axes as a translucent polyhedron.

    :param ax: axes object the box is added to (add_collection3d and scatter are called on it)
    :param cube_definition: four corner points: an origin corner followed by three
                            corners whose offsets from the origin span the box
    :param lbl: box label; label 0 gets red edges, every other label blue edges
    :param thickness: edges are drawn with line width thickness + 1
    """
    anchors = [np.array(list(corner)) for corner in cube_definition]
    origin = anchors[0]

    # Edge vectors spanning the box from the origin corner
    side_a = anchors[1] - origin
    side_b = anchors[2] - origin
    side_c = anchors[3] - origin

    # The given corners first, then the remaining corners reconstructed from the edge vectors
    corners = list(anchors)
    corners.append(origin + side_a + side_b)
    corners.append(origin + side_a + side_c)
    corners.append(origin + side_b + side_c)
    corners.append(origin + side_a + side_b + side_c)
    points = np.array(corners)

    # Corner indices of the six polygons handed to Poly3DCollection
    face_corner_ids = (
        (0, 3, 5, 1),
        (1, 5, 7, 4),
        (4, 2, 6, 7),
        (2, 6, 3, 0),
        (0, 2, 4, 1),
        (3, 6, 7, 5),
    )
    edges = [[points[k] for k in ids] for ids in face_corner_ids]

    faces = Poly3DCollection(edges, linewidths=thickness + 1)
    edge_rgb = (1, 0, 0) if lbl == 0 else (0, 0, 1)
    faces.set_edgecolor(edge_rgb)
    faces.set_facecolor((0, 0, 1, 0.1))

    ax.add_collection3d(faces)
    # Zero-size markers at every corner; presumably so the axes limits cover the box - confirm
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=0)
|
||
|
||
def show_boxes(boxes_list, scores_list, labels_list, image_size=800):
    """
    Plot every box of every model in a single 3D figure and show it.

    :param boxes_list: per-model lists of boxes, each box indexed as (x1, y1, z1, x2, y2, z2)
    :param scores_list: per-model lists of scores; the score sets the edge thickness of a box
    :param labels_list: per-model lists of labels; the label sets the edge colour of a box
    :param image_size: factor each coordinate is multiplied by before truncation to int
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    for i, model_boxes in enumerate(boxes_list):
        for j, box in enumerate(model_boxes):
            # Scale the coordinates to integer plot units
            x1, y1, z1, x2, y2, z2 = (int(image_size * box[k]) for k in range(6))
            lbl = labels_list[i][j]
            # Origin corner plus the three corners reached by moving along y, x and z
            cube_definition = [
                (x1, y1, z1), (x1, y2, z1), (x2, y1, z1), (x1, y1, z2)
            ]
            plot_cube(ax, cube_definition, lbl, int(4 * scores_list[i][j]))

    plt.show()
|
||
|
||
def example_wbf_3d_2_models(iou_thr=0.55, draw_image=True):
    """
    Demonstrate how to ensemble the boxes of 2 models with the WBF_3D method.

    :param iou_thr: IoU threshold passed to weighted_boxes_fusion_3d
    :param draw_image: when true, plot the boxes before and after fusion
    :return:
    """
    first_model_boxes = [
        [0.00, 0.51, 0.41, 0.81, 0.91, 0.78],
        [0.10, 0.31, 0.45, 0.71, 0.61, 0.85],
        [0.01, 0.32, 0.55, 0.83, 0.93, 0.95],
        [0.02, 0.53, 0.11, 0.11, 0.94, 0.55],
        [0.03, 0.24, 0.34, 0.12, 0.35, 0.67],
    ]
    second_model_boxes = [
        [0.04, 0.56, 0.36, 0.84, 0.92, 0.82],
        [0.12, 0.33, 0.46, 0.72, 0.64, 0.75],
        [0.38, 0.66, 0.55, 0.79, 0.95, 0.90],
        [0.08, 0.49, 0.15, 0.21, 0.89, 0.67],
    ]
    all_boxes = [first_model_boxes, second_model_boxes]
    all_scores = [
        [0.9, 0.8, 0.2, 0.4, 0.7],
        [0.5, 0.8, 0.7, 0.3],
    ]
    all_labels = [
        [0, 1, 0, 1, 1],
        [1, 1, 1, 0],
    ]
    # The first model's scores count twice as much as the second model's
    model_weights = [2, 1]

    if draw_image:
        show_boxes(all_boxes, all_scores, all_labels)

    fused_boxes, fused_scores, fused_labels = weighted_boxes_fusion_3d(
        all_boxes,
        all_scores,
        all_labels,
        weights=model_weights,
        iou_thr=iou_thr,
        skip_box_thr=0.0,
    )

    if draw_image:
        # Wrap the fused result as a single "model" for show_boxes
        show_boxes([fused_boxes], [fused_scores], [fused_labels.astype(np.int32)])

    print(len(fused_boxes))
    print(fused_boxes)
print(boxes) | ||
|
||
|
||
if __name__ == '__main__':
    # Script entry point: run the two-model example with plotting switched on
    draw_image = True
    example_wbf_3d_2_models(draw_image=draw_image, iou_thr=0.2)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters