From dace6dcd8d4874a4e356543e2206e36440899d17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=A0=D0=BE=D0=BC=D0=B0=D0=BD?= Date: Tue, 16 Jun 2020 13:20:56 +0300 Subject: [PATCH] Added numba @jit(nopython=True) for critical functions (NMS, NMW and WBF). Speed up around x2 times (tested on example_oid.py). --- ensemble_boxes/ensemble_boxes_nms.py | 6 ++++-- ensemble_boxes/ensemble_boxes_nmw.py | 2 ++ ensemble_boxes/ensemble_boxes_wbf.py | 4 +++- example_oid.py | 4 +++- setup.py | 3 ++- 5 files changed, 14 insertions(+), 5 deletions(-) diff --git a/ensemble_boxes/ensemble_boxes_nms.py b/ensemble_boxes/ensemble_boxes_nms.py index 3fe8b67..4d591d0 100644 --- a/ensemble_boxes/ensemble_boxes_nms.py +++ b/ensemble_boxes/ensemble_boxes_nms.py @@ -2,6 +2,7 @@ __author__ = 'ZFTurbo: https://kaggle.com/zfturbo' import numpy as np +from numba import jit def prepare_boxes(boxes, scores, labels): @@ -48,7 +49,7 @@ def cpu_soft_nms_float(dets, sc, Nt, sigma, thresh, method): :param sigma: :param thresh: :param method: 1 - linear soft-NMS, 2 - gaussian soft-NMS, 3 - standard NMS - :return: index of boxes to keep + :return: index of boxes to keep """ # indexes concatenate boxes with the last column @@ -120,12 +121,13 @@ def cpu_soft_nms_float(dets, sc, Nt, sigma, thresh, method): return keep +@jit(nopython=True) def nms_float_fast(dets, scores, thresh): """ # It's different from original nms because we have float coordinates on range [0; 1] :param dets: numpy array of boxes with shape: (N, 5). Order: x1, y1, x2, y2, score. All variables in range [0; 1] :param thresh: IoU value for boxes - :return: + :return: index of boxes to keep """ x1 = dets[:, 0] y1 = dets[:, 1] diff --git a/ensemble_boxes/ensemble_boxes_nmw.py b/ensemble_boxes/ensemble_boxes_nmw.py index a1805a9..6ed4a69 100644 --- a/ensemble_boxes/ensemble_boxes_nmw.py +++ b/ensemble_boxes/ensemble_boxes_nmw.py @@ -9,8 +9,10 @@ import warnings import numpy as np +from numba import jit +@jit(nopython=True) def bb_intersection_over_union(A, B): xA = max(A[0], B[0]) yA = max(A[1], B[1]) diff --git a/ensemble_boxes/ensemble_boxes_wbf.py b/ensemble_boxes/ensemble_boxes_wbf.py index f779832..63d7c3d 100644 --- a/ensemble_boxes/ensemble_boxes_wbf.py +++ b/ensemble_boxes/ensemble_boxes_wbf.py @@ -4,9 +4,11 @@ import warnings import numpy as np +from numba import jit -def bb_intersection_over_union(A, B): +@jit(nopython=True) +def bb_intersection_over_union(A, B) -> float: xA = max(A[0], B[0]) yA = max(A[1], B[1]) xB = min(A[2], B[2]) diff --git a/example_oid.py b/example_oid.py index 14ef49e..201c77e 100644 --- a/example_oid.py +++ b/example_oid.py @@ -314,8 +314,10 @@ def ensemble_predictions(pred_filenames, weights, params): mean_ap, average_precisions = mean_average_precision_for_boxes(ann, det, verbose=False) print("File: {} mAP: {:.6f}".format(os.path.basename(pred_list[i]), mean_ap)) + start_time = time.time() ensemble_preds = ensemble_predictions(pred_list, weights, params) - ensemble_preds.to_csv("test_data/debug.csv", index=False) + print("Overall ensemble time for method: {}: {:.2f} sec".format(params['run_type'], time.time() - start_time)) + ensemble_preds.to_csv("test_data/debug_{}.csv".format(params['run_type']), index=False) ensemble_preds = ensemble_preds[['ImageId', 'LabelName', 'Conf', 'XMin', 'XMax', 'YMin', 'YMax']].values mean_ap, average_precisions = mean_average_precision_for_boxes(ann, ensemble_preds, verbose=True) print("Ensemble [{}] Weights: {} Params: {} mAP: {:.6f}".format(len(weights), weights, params, mean_ap)) diff --git a/setup.py b/setup.py index 5aeb241..a497469 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='ensemble_boxes', - version='1.0.2', + version='1.0.3', author='Roman Solovyev (ZFTurbo)', packages=['ensemble_boxes', ], url='https://github.com/ZFTurbo/Weighted-Boxes-Fusion', @@ -16,5 +16,6 @@ install_requires=[ "numpy", "pandas", + "numba", ], )