Skip to content

Commit

Permalink
Added numba @jit(nopython=True) for critical functions (NMS, NMW and …
Browse files Browse the repository at this point in the history
…WBF). Speed up around x2 times (tested on example_oid.py).
  • Loading branch information
IDMIPPM committed Jun 16, 2020
1 parent 86e0717 commit dace6dc
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 5 deletions.
6 changes: 4 additions & 2 deletions ensemble_boxes/ensemble_boxes_nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
__author__ = 'ZFTurbo: https://kaggle.com/zfturbo'

import numpy as np
from numba import jit


def prepare_boxes(boxes, scores, labels):
Expand Down Expand Up @@ -48,7 +49,7 @@ def cpu_soft_nms_float(dets, sc, Nt, sigma, thresh, method):
:param sigma:
:param thresh:
:param method: 1 - linear soft-NMS, 2 - gaussian soft-NMS, 3 - standard NMS
:return: index of boxes to keep
:return: index of boxes to keep
"""

# indexes concatenate boxes with the last column
Expand Down Expand Up @@ -120,12 +121,13 @@ def cpu_soft_nms_float(dets, sc, Nt, sigma, thresh, method):
return keep


@jit(nopython=True)
def nms_float_fast(dets, scores, thresh):
"""
# It's different from original nms because we have float coordinates on range [0; 1]
:param dets: numpy array of boxes with shape: (N, 5). Order: x1, y1, x2, y2, score. All variables in range [0; 1]
:param thresh: IoU value for boxes
:return:
:return: index of boxes to keep
"""
x1 = dets[:, 0]
y1 = dets[:, 1]
Expand Down
2 changes: 2 additions & 0 deletions ensemble_boxes/ensemble_boxes_nmw.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@

import warnings
import numpy as np
from numba import jit


@jit(nopython=True)
def bb_intersection_over_union(A, B):
xA = max(A[0], B[0])
yA = max(A[1], B[1])
Expand Down
4 changes: 3 additions & 1 deletion ensemble_boxes/ensemble_boxes_wbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

import warnings
import numpy as np
from numba import jit


def bb_intersection_over_union(A, B):
@jit(nopython=True)
def bb_intersection_over_union(A, B) -> float:
xA = max(A[0], B[0])
yA = max(A[1], B[1])
xB = min(A[2], B[2])
Expand Down
4 changes: 3 additions & 1 deletion example_oid.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,10 @@ def ensemble_predictions(pred_filenames, weights, params):
mean_ap, average_precisions = mean_average_precision_for_boxes(ann, det, verbose=False)
print("File: {} mAP: {:.6f}".format(os.path.basename(pred_list[i]), mean_ap))

start_time = time.time()
ensemble_preds = ensemble_predictions(pred_list, weights, params)
ensemble_preds.to_csv("test_data/debug.csv", index=False)
print("Overall ensemble time for method: {}: {:.2f} sec".format(params['run_type'], time.time() - start_time))
ensemble_preds.to_csv("test_data/debug_{}.csv".format(params['run_type']), index=False)
ensemble_preds = ensemble_preds[['ImageId', 'LabelName', 'Conf', 'XMin', 'XMax', 'YMin', 'YMax']].values
mean_ap, average_precisions = mean_average_precision_for_boxes(ann, ensemble_preds, verbose=True)
print("Ensemble [{}] Weights: {} Params: {} mAP: {:.6f}".format(len(weights), weights, params, mean_ap))
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='ensemble_boxes',
version='1.0.2',
version='1.0.3',
author='Roman Solovyev (ZFTurbo)',
packages=['ensemble_boxes', ],
url='https://github.com/ZFTurbo/Weighted-Boxes-Fusion',
Expand All @@ -16,5 +16,6 @@
install_requires=[
"numpy",
"pandas",
"numba",
],
)

0 comments on commit dace6dc

Please sign in to comment.