diff --git a/ensemble_boxes/ensemble_boxes_nmw.py b/ensemble_boxes/ensemble_boxes_nmw.py index 6ed4a69..e3c4092 100644 --- a/ensemble_boxes/ensemble_boxes_nmw.py +++ b/ensemble_boxes/ensemble_boxes_nmw.py @@ -164,9 +164,9 @@ def non_maximum_weighted(boxes_list, scores_list, labels_list, weights=None, iou if len(weights) != len(boxes_list): print('Warning: incorrect number of weights {}. Must be: {}. Set weights equal to 1.'.format(len(weights), len(boxes_list))) weights = np.ones(len(boxes_list)) - weights = np.array(weights) - for i in range(len(weights)): - scores_list[i] = (np.array(scores_list[i]) * weights[i]) / weights.sum() + weights = np.array(weights) / max(weights) + # for i in range(len(weights)): + # scores_list[i] = (np.array(scores_list[i]) * weights[i]) filtered_boxes = prefilter_boxes(boxes_list, scores_list, labels_list, weights, skip_box_thr) if len(filtered_boxes) == 0: