OD model type support in raiwidgets for EA #2229

Closed
wants to merge 10 commits
1 change: 1 addition & 0 deletions raiwidgets/raiwidgets/constants.py
@@ -162,6 +162,7 @@ class ModelTask(str, Enum):

CLASSIFICATION = 'classification'
REGRESSION = 'regression'
OBJECT_DETECTION = 'object_detection'
UNKNOWN = 'unknown'


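For orientation (not part of the diff), a minimal sketch of how the new task value could be used. It assumes ModelTask is importable from raiwidgets.constants and that ErrorAnalysisDashboard accepts a model_task argument; model, X_test, y_test and classes stand in for the objects built in the test added later in this PR.

from raiwidgets import ErrorAnalysisDashboard
from raiwidgets.constants import ModelTask

# ModelTask subclasses str, so the member compares equal to its string value.
assert ModelTask.OBJECT_DETECTION == 'object_detection'

# Hypothetical call mirroring the new test below.
ErrorAnalysisDashboard(model=model, dataset=X_test, true_y=y_test,
                       classes=classes, model_task=ModelTask.OBJECT_DETECTION)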
2 changes: 2 additions & 0 deletions raiwidgets/raiwidgets/error_analysis_dashboard_input.py
@@ -346,6 +346,8 @@ def setup_local(self, explanation, model, dataset, true_y, classes,
metric = Metrics.ERROR_RATE
else:
metric = self._error_analyzer.metric
elif self._error_analyzer.model_task == ModelTask.OBJECT_DETECTION:
metric = Metrics.ERROR_RATE
else:
if self._error_analyzer.metric is None:
metric = Metrics.MEAN_SQUARED_ERROR
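For reviewers, a hedged paraphrase of how the metric-selection logic reads after this change; the enclosing classification branch is not shown in the hunk and is assumed from the surrounding raiwidgets source:

if self._error_analyzer.model_task == ModelTask.CLASSIFICATION:
    if self._error_analyzer.metric is None:
        metric = Metrics.ERROR_RATE
    else:
        metric = self._error_analyzer.metric
elif self._error_analyzer.model_task == ModelTask.OBJECT_DETECTION:
    # Object detection defaults to error rate, mirroring classification.
    metric = Metrics.ERROR_RATE
else:
    # Regression (and unknown) tasks fall back to mean squared error
    # unless the analyzer already carries a metric.
    if self._error_analyzer.metric is None:
        metric = Metrics.MEAN_SQUARED_ERROR
    else:
        metric = self._error_analyzer.metric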
141 changes: 141 additions & 0 deletions raiwidgets/tests/test_error_analysis_dashboard.py
@@ -1,8 +1,12 @@
# Copyright (c) Microsoft Corporation
# Licensed under the MIT License.

import os
from zipfile import ZipFile

import numpy as np
import pandas as pd
import pytest
import shap
import sklearn
from interpret.ext.blackbox import MimicExplainer
@@ -11,6 +15,17 @@
from sklearn.datasets import load_iris, make_classification
from sklearn.model_selection import train_test_split

try:
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
torch_installed = True
except ImportError:
torch_installed = False

import urllib.request as request_file
import xml.etree.ElementTree as ET

from erroranalysis._internal.constants import Metrics, metric_to_display_name
from erroranalysis._internal.surrogate_error_tree import (
DEFAULT_MAX_DEPTH, DEFAULT_MIN_CHILD_SAMPLES, DEFAULT_NUM_LEAVES)
@@ -158,6 +173,132 @@ def test_error_analysis_adult_census_numeric_feature_names(self):

run_error_analysis_adult_census(X, y, cat_idxs)

@pytest.mark.skipif(not torch_installed,
reason="requires torch & torchvision")
def test_error_analysis_fridge_object_detection(self):
model = get_object_detection_model()
dataset = load_fridge_object_detection_dataset()
classes = np.array(['can', 'carton', 'milk_bottle', 'water_bottle'])

X_test = dataset[["image"]]
y_test = dataset[["label"]]
ErrorAnalysisDashboard(model=model, dataset=X_test,
true_y=y_test, classes=classes)


def get_object_detection_model():
Contributor: Seems these methods below are duplicated from https://github.com/microsoft/responsible-ai-toolbox/blob/main/responsibleai_vision/tests/common_vision_utils.py#L562 - perhaps we can move them to a more common place like the rai_test_utils package?

Collaborator Author: Sure, made a PR for it - #2246

# download fine-tuned recycling model from url
def download_assets(filepath, force=False):
if force or not os.path.exists(filepath):
request_file.urlretrieve(
"https://publictestdatasets.blob.core.windows.net\
/models/fastrcnn.pt",
os.path.join(filepath))
else:
print('Found ' + filepath)

return filepath

def get_instance_segmentation_model(num_classes):
# load an instance segmentation model pre-trained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
pretrained=True
)
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(
in_features,
num_classes
)
return model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

num_classes = 5
model = get_instance_segmentation_model(num_classes)
_ = download_assets('Recycling_finetuned_FastRCNN.pt')
model.load_state_dict(
torch.load('Recycling_finetuned_FastRCNN.pt',
map_location=device
)
)

model.to(device)

return model
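As a hedged sanity check that is not part of the diff: torchvision detection models in eval mode take a list of CHW image tensors and return one dict of 'boxes', 'labels' and 'scores' per image, so the loaded detector can be smoke-tested on a random tensor:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
detector = get_object_detection_model()
detector.eval()
with torch.no_grad():
    # Random 3x480x640 image in [0, 1); real use would pass the fridge images.
    predictions = detector([torch.rand(3, 480, 640).to(device)])
print(predictions[0]['boxes'].shape)  # (num_detections, 4)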


def load_fridge_object_detection_dataset_labels():

src_images = "./data/odFridgeObjects/"

# Path to the annotations
annotations_folder = os.path.join(src_images, "annotations")

labels = []
label_dict = {'can': 1, 'carton': 2, 'milk_bottle': 3, 'water_bottle': 4}

# Read each annotation
for _, filename in enumerate(os.listdir(annotations_folder)):
if filename.endswith(".xml"):
print("Parsing " + os.path.join(src_images, filename))

root = ET.parse(
os.path.join(annotations_folder, filename)
).getroot()

# use if needed
# width = int(root.find("size/width").text)
# height = int(root.find("size/height").text)

image_labels = []
for object in root.findall("object"):
name = object.find("name").text
xmin = object.find("bndbox/xmin").text
ymin = object.find("bndbox/ymin").text
xmax = object.find("bndbox/xmax").text
ymax = object.find("bndbox/ymax").text
isCrowd = int(object.find("difficult").text)
image_labels.append([
label_dict[name], # label
float(xmin), # topX. To normalize, divide by width.
float(ymin), # topY. To normalize, divide by height.
float(xmax), # bottomX. To normalize, divide by width
float(ymax), # bottomY. To normalize, divide by height
int(isCrowd)
])
labels.append(image_labels)

return labels
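For illustration, a hypothetical example (values invented) of one entry in the returned list for an image containing a can and a carton; each object is encoded as [class_id, xmin, ymin, xmax, ymax, is_crowd]:

example_image_labels = [
    [1, 34.0, 29.0, 209.0, 383.0, 0],   # 'can'
    [2, 204.0, 98.0, 351.0, 324.0, 0],  # 'carton'
]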


def load_fridge_object_detection_dataset():
# create data folder if it doesn't exist.
os.makedirs("data", exist_ok=True)

# download data
download_url = ("https://cvbp-secondary.z19.web.core.windows.net/" +
"datasets/object_detection/odFridgeObjects.zip")
data_file = "./odFridgeObjects.zip"
request_file.urlretrieve(download_url, filename=data_file)

# extract files
with ZipFile(data_file, "r") as zip:
print("extracting files...")
zip.extractall(path="./data")
print("done")
# delete zip file
os.remove(data_file)

labels = load_fridge_object_detection_dataset_labels()

# get all file names into a pandas dataframe with the labels
data = pd.DataFrame(columns=["image", "label"])
for i, file in enumerate(os.listdir("./data/odFridgeObjects/" + "images")):
image_path = "./data/odFridgeObjects/" + "images" + "/" + file
data = data.append({"image": image_path,
"label": labels[i]}, # folder
ignore_index=True)

return data
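A side note for reviewers: DataFrame.append was deprecated in pandas 1.4 and removed in pandas 2.0, so a sketch of an equivalent construction that avoids it (same columns and label alignment as above) may be worth considering:

image_dir = "./data/odFridgeObjects/images"
rows = [{"image": os.path.join(image_dir, file), "label": labels[i]}
        for i, file in enumerate(os.listdir(image_dir))]
data = pd.DataFrame(rows, columns=["image", "label"])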


def run_error_analysis_adult_census(X, y, categorical_features):
X, y = sklearn.utils.resample(