# Folder the odFridgeObjects archive is extracted into.
_FRIDGE_DATA_DIR = "./data/odFridgeObjects/"

# Integer class id assigned to each object name found in the annotations.
_FRIDGE_LABEL_IDS = {
    'can': 1,
    'carton': 2,
    'milk_bottle': 3,
    'water_bottle': 4,
}


def _parse_fridge_annotation(annotation_path):
    """Parse one annotation XML file into a list of bounding-box records.

    :param annotation_path: Path to the annotation XML file.
    :type annotation_path: str
    :return: One ``[label_id, xmin, ymin, xmax, ymax, is_crowd]`` list per
        ``object`` element in the file. Coordinates are returned as read
        (not normalized); divide x values by the image width and y values
        by the image height to normalize.
    :rtype: list[list]
    :raises KeyError: If an object name is not a known fridge class.
    """
    root = ET.parse(annotation_path).getroot()
    return [
        [
            _FRIDGE_LABEL_IDS[obj.find("name").text],  # label
            float(obj.find("bndbox/xmin").text),  # topX
            float(obj.find("bndbox/ymin").text),  # topY
            float(obj.find("bndbox/xmax").text),  # bottomX
            float(obj.find("bndbox/ymax").text),  # bottomY
            # the "difficult" flag is what gets reported as isCrowd
            int(obj.find("difficult").text),
        ]
        for obj in root.findall("object")
    ]


def _load_fridge_labels_by_stem(annotations_folder):
    """Parse every ``.xml`` file in a folder, keyed by extension-less name.

    Files are visited in sorted order so the result is deterministic;
    ``os.listdir`` alone returns entries in arbitrary order.

    :param annotations_folder: Folder containing the annotation files.
    :type annotations_folder: str
    :return: Mapping of file stem to the parsed labels of that file.
    :rtype: dict[str, list[list]]
    """
    labels_by_stem = {}
    for filename in sorted(os.listdir(annotations_folder)):
        if not filename.endswith(".xml"):
            continue
        annotation_path = os.path.join(annotations_folder, filename)
        print("Parsing " + annotation_path)
        stem = os.path.splitext(filename)[0]
        labels_by_stem[stem] = _parse_fridge_annotation(annotation_path)
    return labels_by_stem


def load_fridge_object_detection_dataset_labels():
    """Load the labels of the extracted fridge object detection dataset.

    Reads every ``.xml`` file under
    ``./data/odFridgeObjects/annotations``.

    :return: One entry per annotation file, ordered by file name. Each
        entry is a list of
        ``[label_id, xmin, ymin, xmax, ymax, is_crowd]`` lists.
    :rtype: list[list[list]]
    """
    annotations_folder = os.path.join(_FRIDGE_DATA_DIR, "annotations")
    return list(_load_fridge_labels_by_stem(annotations_folder).values())


def load_fridge_object_detection_dataset():
    """Download the fridge object detection dataset and load it.

    Downloads ``odFridgeObjects.zip`` into the current directory, extracts
    it under ``./data`` and deletes the archive again.

    :return: DataFrame with an ``image`` column (path of the image file)
        and a ``label`` column (list of
        ``[label_id, xmin, ymin, xmax, ymax, is_crowd]`` lists).
    :rtype: pandas.DataFrame
    :raises ValueError: If an image has no annotation file with the same
        base name.
    """
    # create data folder if it doesn't exist.
    os.makedirs("data", exist_ok=True)

    # download data
    download_url = ("https://cvbp-secondary.z19.web.core.windows.net/" +
                    "datasets/object_detection/odFridgeObjects.zip")
    data_file = "./odFridgeObjects.zip"
    request_file.urlretrieve(download_url, filename=data_file)

    # extract files; always delete the zip, even when extraction fails
    try:
        with ZipFile(data_file, "r") as archive:
            print("extracting files...")
            archive.extractall(path="./data")
            print("done")
    finally:
        os.remove(data_file)

    labels_by_stem = _load_fridge_labels_by_stem(
        os.path.join(_FRIDGE_DATA_DIR, "annotations"))

    # Pair each image with its own annotation by base name instead of by
    # position in two separate directory listings.
    # NOTE(review): assumes "<name>.<ext>" images are annotated by
    # "<name>.xml" -- confirm against the extracted archive.
    images_folder = _FRIDGE_DATA_DIR + "images"
    rows = []
    for file in sorted(os.listdir(images_folder)):
        stem = os.path.splitext(file)[0]
        if stem not in labels_by_stem:
            raise ValueError(
                "No annotation found for image " + file)
        rows.append({"image": images_folder + "/" + file,
                     "label": labels_by_stem[stem]})

    # Build the frame once; DataFrame.append was removed in pandas 2.0.
    return pd.DataFrame(rows, columns=["image", "label"])
# Location of the fine-tuned weights. Adjacent string literals are used
# instead of a backslash continuation inside the literal, which would embed
# the next line's indentation in the URL.
_FRIDGE_MODEL_URL = (
    "https://publictestdatasets.blob.core.windows.net"
    "/models/fastrcnn.pt"
)

# Local file name the fine-tuned weights are stored under.
_FRIDGE_MODEL_FILE = 'Recycling_finetuned_FastRCNN.pt'


def download_assets(filepath, force=False):
    """Download the fine-tuned recycling model weights if needed.

    :param filepath: Local path the weights file is stored at.
    :type filepath: str
    :param force: Download even when ``filepath`` already exists.
    :type force: bool
    :return: ``filepath``, unchanged.
    :rtype: str
    """
    if force or not os.path.exists(filepath):
        request_file.urlretrieve(_FRIDGE_MODEL_URL, filepath)
    else:
        print('Found ' + filepath)

    return filepath


def get_instance_segmentation_model(num_classes):
    """Build a Faster R-CNN detector with a fresh box-prediction head.

    Despite the function name, the model created here is torchvision's
    ``fasterrcnn_resnet50_fpn`` object detection model; only its box
    predictor is replaced.

    :param num_classes: Number of outputs of the new box predictor.
    :type num_classes: int
    :return: The model with the replaced head.
    """
    # load a detection model with pre-trained weights
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        pretrained=True
    )
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(
        in_features,
        num_classes
    )
    return model


def get_object_detection_fridge_model():
    """Return the fine-tuned fridge object detection model.

    Downloads the weights to ``Recycling_finetuned_FastRCNN.pt`` in the
    current directory when that file does not exist yet, loads them into a
    Faster R-CNN model and moves the model to CUDA when available,
    otherwise to the CPU.

    :return: The model with the fine-tuned weights loaded.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # presumably the four fridge classes plus background -- TODO confirm
    num_classes = 5
    model = get_instance_segmentation_model(num_classes)
    weights_path = download_assets(_FRIDGE_MODEL_FILE)
    # NOTE: torch.load unpickles the downloaded file, so this is only safe
    # while the download URL points at a trusted source.
    model.load_state_dict(
        torch.load(weights_path, map_location=device)
    )

    model.to(device)
    return model
b/raiwidgets/raiwidgets/constants.py @@ -162,6 +162,7 @@ class ModelTask(str, Enum): CLASSIFICATION = 'classification' REGRESSION = 'regression' + OBJECT_DETECTION = 'object_detection' UNKNOWN = 'unknown' diff --git a/raiwidgets/raiwidgets/error_analysis_dashboard_input.py b/raiwidgets/raiwidgets/error_analysis_dashboard_input.py index 745d4df90d..5aa2e6d4d9 100644 --- a/raiwidgets/raiwidgets/error_analysis_dashboard_input.py +++ b/raiwidgets/raiwidgets/error_analysis_dashboard_input.py @@ -346,6 +346,8 @@ def setup_local(self, explanation, model, dataset, true_y, classes, metric = Metrics.ERROR_RATE else: metric = self._error_analyzer.metric + elif self._error_analyzer.model_task == ModelTask.OBJECT_DETECTION: + metric = Metrics.ERROR_RATE else: if self._error_analyzer.metric is None: metric = Metrics.MEAN_SQUARED_ERROR diff --git a/raiwidgets/requirements.txt b/raiwidgets/requirements.txt index aa464b107f..f2f6bd8ea9 100644 --- a/raiwidgets/requirements.txt +++ b/raiwidgets/requirements.txt @@ -8,3 +8,4 @@ lightgbm>=2.0.11 erroranalysis>=0.4.4 fairlearn==0.7.0 raiutils>=0.4.0 +rai-test-utils diff --git a/raiwidgets/tests/test_error_analysis_dashboard.py b/raiwidgets/tests/test_error_analysis_dashboard.py index 33dda1c929..fb46b185b7 100644 --- a/raiwidgets/tests/test_error_analysis_dashboard.py +++ b/raiwidgets/tests/test_error_analysis_dashboard.py @@ -1,8 +1,12 @@ # Copyright (c) Microsoft Corporation # Licensed under the MIT License. 
+import os +from zipfile import ZipFile + import numpy as np import pandas as pd +import pytest import shap import sklearn from interpret.ext.blackbox import MimicExplainer @@ -11,6 +15,20 @@ from sklearn.datasets import load_iris, make_classification from sklearn.model_selection import train_test_split +from rai_test_utils.datasets.vision.object_detection_data_utils import load_fridge_object_detection_dataset +from rai_test_utils.models.torch.torch_model_utils import get_object_detection_fridge_model + +try: + import torch + import torchvision + from torchvision.models.detection.faster_rcnn import FastRCNNPredictor + torch_installed = True +except ImportError: + torch_installed = False + +import urllib.request as request_file +import xml.etree.ElementTree as ET + from erroranalysis._internal.constants import Metrics, metric_to_display_name from erroranalysis._internal.surrogate_error_tree import ( DEFAULT_MAX_DEPTH, DEFAULT_MIN_CHILD_SAMPLES, DEFAULT_NUM_LEAVES) @@ -158,6 +176,18 @@ def test_error_analysis_adult_census_numeric_feature_names(self): run_error_analysis_adult_census(X, y, cat_idxs) + @pytest.mark.skipif(not torch_installed, + reason="requires torch & torchvision") + def test_error_analysis_fridge_object_detection(self): + model = get_object_detection_fridge_model() + dataset = load_fridge_object_detection_dataset() + classes = np.array(['can', 'carton', 'milk_bottle', 'water_bottle']) + + X_test = dataset[["image"]] + y_test = dataset[["label"]] + ErrorAnalysisDashboard(model=model, dataset=X_test, + true_y=y_test, classes=classes) + def run_error_analysis_adult_census(X, y, categorical_features): X, y = sklearn.utils.resample(