Skip to content

Commit

Permalink
fix wrong metrics count missing in vision explorer of RAI Vision dash…
Browse files Browse the repository at this point in the history
…board for object detection (#2495)
  • Loading branch information
imatiach-msft authored Jan 19, 2024
1 parent bee87b5 commit 946115b
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import pickle
import shutil
import warnings
from collections import defaultdict
from enum import Enum
from pathlib import Path
from typing import Any, Optional
Expand All @@ -22,8 +21,6 @@
from ml_wrappers import wrap_model
from ml_wrappers.common.constants import Device
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from vision_explanation_methods.error_labeling.error_labeling import (
ErrorLabeling, ErrorLabelType)

from erroranalysis._internal.cohort_filter import FilterDataWithCohortFilters
from raiutils.data_processing import convert_to_list
Expand All @@ -47,7 +44,8 @@
from responsibleai_vision.utils.image_reader import (
get_base64_string_from_path, get_image_from_path, is_automl_image_model)
from responsibleai_vision.utils.image_utils import (
convert_images, get_images, transform_object_detection_labels)
convert_images, generate_od_error_labels, get_images,
transform_object_detection_labels)

IMAGE = ImageColumns.IMAGE.value
IMAGE_URL = ImageColumns.IMAGE_URL.value
Expand Down Expand Up @@ -85,10 +83,6 @@
_TIME_SERIES_ID_FEATURES = 'time_series_id_features'
_CATEGORICAL_FEATURES = 'categorical_features'
_DROPPED_FEATURES = 'dropped_features'
_INCORRECT = 'incorrect'
_CORRECT = 'correct'
_AGGREGATE_LABEL = 'aggregate'
_NOLABEL = '(none)'


def reshape_image(image):
Expand Down Expand Up @@ -701,76 +695,14 @@ def _get_dataset(self):
)

dashboard_dataset.object_detection_labels = \
self._generate_od_error_labels(
generate_od_error_labels(
dashboard_dataset.object_detection_true_y,
dashboard_dataset.object_detection_predicted_y,
class_names=dashboard_dataset.class_names
)

return dashboard_dataset

def _generate_od_error_labels(self, true_y, pred_y, class_names):
"""Utilized Error Labeling to generate labels
with correct and incorrect objects.
:param true_y: The true labels.
:type true_y: list
:param pred_y: The predicted labels.
:type pred_y: list
:param class_names: The class labels in the dataset.
:type class_names: list
:return: The aggregated labels.
:rtype: List[str]
"""
object_detection_labels = []
for image_idx in range(len(true_y)):
image_labels = defaultdict(lambda: defaultdict(int))
rendered_labels = {}
error_matrix = ErrorLabeling(
ModelTask.OBJECT_DETECTION,
pred_y[image_idx],
true_y[image_idx]
).compute_error_labels()

for label_idx in range(len(error_matrix)):
object_label = class_names[
int(true_y[image_idx][label_idx][0] - 1)]
if ErrorLabelType.MATCH in error_matrix[label_idx]:
image_labels[_CORRECT][object_label] += 1
else:
image_labels[_INCORRECT][object_label] += 1

duplicate_detections = np.count_nonzero(
error_matrix[label_idx] ==
ErrorLabelType.DUPLICATE_DETECTION)
if duplicate_detections > 0:
image_labels[_INCORRECT][object_label] += \
duplicate_detections

correct_labels = sorted(image_labels[_CORRECT].items(),
key=lambda x: class_names.index(x[0]))
incorrect_labels = sorted(image_labels[_INCORRECT].items(),
key=lambda x: class_names.index(x[0]))

rendered_labels[_CORRECT] = ', '.join(
f'{value} {key}' for key, value in
correct_labels)
if len(rendered_labels[_CORRECT]) == 0:
rendered_labels[_CORRECT] = _NOLABEL
rendered_labels[_INCORRECT] = ', '.join(
f'{value} {key}' for key, value in
incorrect_labels)
if len(rendered_labels[_INCORRECT]) == 0:
rendered_labels[_INCORRECT] = _NOLABEL
rendered_labels[_AGGREGATE_LABEL] = \
f"{sum(image_labels[_CORRECT].values())} {_CORRECT}, \
{sum(image_labels[_INCORRECT].values())} \
{_INCORRECT}"

object_detection_labels.append(rendered_labels)

return object_detection_labels

def _format_od_labels(self, y, class_names):
"""Formats the Object Detection label representation to
multi-label image classification to follow the UI format
Expand Down
69 changes: 68 additions & 1 deletion responsibleai_vision/responsibleai_vision/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@

"""Contains image handling utilities."""

from collections import defaultdict

import numpy as np
from vision_explanation_methods.error_labeling.error_labeling import (
ErrorLabeling, ErrorLabelType)

from responsibleai_vision.common.constants import ImageColumns
from responsibleai_vision.common.constants import ImageColumns, ModelTask
from responsibleai_vision.utils.image_reader import get_image_from_path

IMAGE = ImageColumns.IMAGE.value
Expand All @@ -19,6 +23,10 @@
BOTTOM_X = 'bottomX'
BOTTOM_Y = 'bottomY'
IS_CROWD = 'isCrowd'
_INCORRECT = 'incorrect'
_CORRECT = 'correct'
_AGGREGATE_LABEL = 'aggregate'
_NOLABEL = '(none)'


def convert_images(dataset, image_mode):
Expand Down Expand Up @@ -141,3 +149,62 @@ def transform_object_detection_labels(test, target_column, classes):
err = invalid_msg + 'Image details and label must be present'
raise ValueError(err)
return test


def generate_od_error_labels(true_y, pred_y, class_names):
"""Utilized Error Labeling to generate labels
with correct and incorrect objects.
:param true_y: The true labels.
:type true_y: list
:param pred_y: The predicted labels.
:type pred_y: list
:param class_names: The class labels in the dataset.
:type class_names: list
:return: The aggregated labels.
:rtype: List[str]
"""
object_detection_labels = []
for image_idx in range(len(true_y)):
image_labels = defaultdict(lambda: defaultdict(int))
rendered_labels = {}
error_matrix = ErrorLabeling(
ModelTask.OBJECT_DETECTION,
pred_y[image_idx],
true_y[image_idx]
).compute_error_labels()
for label_idx in range(len(error_matrix)):
object_label = class_names[
int(true_y[image_idx][label_idx][0] - 1)]
if ErrorLabelType.MATCH in error_matrix[label_idx]:
image_labels[_CORRECT][object_label] += 1
else:
image_labels[_INCORRECT][object_label] += 1

duplicate_detections = np.count_nonzero(
error_matrix[label_idx] ==
ErrorLabelType.DUPLICATE_DETECTION)
if duplicate_detections > 0:
image_labels[_INCORRECT][object_label] += \
duplicate_detections
correct_labels = sorted(image_labels[_CORRECT].items(),
key=lambda x: class_names.index(x[0]))
incorrect_labels = sorted(image_labels[_INCORRECT].items(),
key=lambda x: class_names.index(x[0]))
rendered_labels[_CORRECT] = ', '.join(
f'{value} {key}' for key, value in
correct_labels)
if len(rendered_labels[_CORRECT]) == 0:
rendered_labels[_CORRECT] = _NOLABEL
rendered_labels[_INCORRECT] = ', '.join(
f'{value} {key}' for key, value in
incorrect_labels)
if len(rendered_labels[_INCORRECT]) == 0:
rendered_labels[_INCORRECT] = _NOLABEL
num_correct = sum(image_labels[_CORRECT].values())
num_incorrect = sum(image_labels[_INCORRECT].values())
agg_label = f"{num_correct} {_CORRECT}, {num_incorrect} {_INCORRECT}"
rendered_labels[_AGGREGATE_LABEL] = agg_label
object_detection_labels.append(rendered_labels)

return object_detection_labels
35 changes: 34 additions & 1 deletion responsibleai_vision/tests/test_image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
_requests_sessions as image_reader_requests_sessions
from responsibleai_vision.utils.image_reader import get_all_exif_feature_names
from responsibleai_vision.utils.image_utils import (
BOTTOM_X, BOTTOM_Y, HEIGHT, IS_CROWD, TOP_X, TOP_Y, WIDTH, classes_to_dict,
_NOLABEL, BOTTOM_X, BOTTOM_Y, HEIGHT, IS_CROWD, TOP_X, TOP_Y, WIDTH,
classes_to_dict, generate_od_error_labels,
transform_object_detection_labels)

LABEL = ImageColumns.LABEL.value
Expand Down Expand Up @@ -99,3 +100,35 @@ def test_get_all_exif_feature_names(self):
set(['Orientation', 'ExifOffset', 'ImageWidth', 'GPSInfo',
'Model', 'DateTime', 'YCbCrPositioning', 'ImageLength',
'ResolutionUnit', 'Software', 'Make'])

def test_generate_od_error_labels(self):
true_y = np.array([[[3, 142, 257, 395, 463, 0]],
[[3, 107, 272, 240, 501, 0],
[1, 261, 274, 393, 449, 0]],
[[4, 139, 253, 339, 506, 0]],
[[2, 100, 173, 233, 521, 0]],
[[1, 175, 253, 355, 416, 0]],
[[2, 86, 102, 216, 439, 0],
[3, 150, 377, 445, 490, 0]],
[[3, 103, 272, 358, 475, 0]],
[[4, 65, 289, 436, 414, 0]],
[[1, 130, 271, 367, 467, 0]],
[[1, 144, 260, 318, 429, 0]]])
pred_y = np.array([[[3, 140, 260, 396, 469, 0]],
[[3, 108, 270, 237, 505, 0],
[1, 259, 271, 401, 450, 0]],
[[4, 131, 250, 330, 485, 0]],
[[2, 97, 170, 241, 516, 0]],
[[1, 175, 250, 354, 414, 0]],
[[2, 83, 98, 222, 445, 0],
[3, 130, 366, 438, 496, 0]],
[[3, 104, 265, 360, 468, 0]],
[[4, 58, 284, 483, 420, 0]],
[[1, 128, 265, 367, 471, 0]],
[[1, 137, 260, 325, 430, 0]]])
class_names = ["can", "carton", "milk_bottle", "water_bottle"]
error_labels = generate_od_error_labels(true_y, pred_y, class_names)
assert len(error_labels) == 10
assert error_labels[0]['aggregate'] == "1 correct, 0 incorrect"
assert error_labels[0]['correct'] == "1 milk_bottle"
assert error_labels[0]['incorrect'] == _NOLABEL

0 comments on commit 946115b

Please sign in to comment.