diff --git a/responsibleai_vision/responsibleai_vision/managers/error_analysis_manager.py b/responsibleai_vision/responsibleai_vision/managers/error_analysis_manager.py index 771d0a64eb..4a3365695d 100644 --- a/responsibleai_vision/responsibleai_vision/managers/error_analysis_manager.py +++ b/responsibleai_vision/responsibleai_vision/managers/error_analysis_manager.py @@ -316,13 +316,21 @@ def _load(path, rai_insights): feature_names = list(dataset.columns) inst.__dict__['_feature_names'] = feature_names task_type = rai_insights.task_type - wrapped_model = wrap_model(rai_insights.model, dataset, - task_type, - classes=rai_insights._classes, - device=rai_insights.device) + classes = rai_insights._classes + device = rai_insights.device + + test = rai_insights.test + image_mode = rai_insights.image_mode + transformations = rai_insights._transformations + sample = test.iloc[0:2] + sample = get_images(sample, image_mode, transformations) + wrapped_model = wrap_model( + rai_insights.model, sample, task_type, classes=classes, + device=device) + inst.__dict__['_task_type'] = task_type - index_classes = rai_insights._classes - index_dataset = rai_insights.test + index_classes = classes + index_dataset = test if isinstance(target_column, list): # create copy of dataset as we will make modifications to it index_dataset = index_dataset.copy() diff --git a/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py b/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py index 34c14687ae..0787f37f92 100644 --- a/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py +++ b/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py @@ -1095,7 +1095,9 @@ def load(path): # load current state RAIBaseInsights._load( path, inst, manager_map, RAIVisionInsights._load_metadata) - inst._wrapped_model = wrap_model(inst.model, inst.test, inst.task_type, + sample = inst.test.iloc[0:2] + sample = get_images(sample, inst.image_mode, inst._transformations) + inst._wrapped_model = wrap_model(inst.model, sample, inst.task_type, classes=inst._classes, device=inst.device) inst.automl_image_model = is_automl_image_model(inst._wrapped_model) diff --git a/responsibleai_vision/tests/common_vision_utils.py b/responsibleai_vision/tests/common_vision_utils.py index f132853b37..63239c5b7a 100644 --- a/responsibleai_vision/tests/common_vision_utils.py +++ b/responsibleai_vision/tests/common_vision_utils.py @@ -164,6 +164,15 @@ def create_dummy_model(df): return DummyFlowersClassifier() +def create_raw_torchvision_classification_model(): + """Creates a dummy torchvision model for testing purposes. + + :return: dummy torchvision model + :rtype: torchvision.models.resnet.ResNet + """ + return torchvision_models.vgg16(pretrained=False, num_classes=2) + + def retrieve_unzip_file(download_url, data_file): fetch_dataset(download_url, data_file) # extract files @@ -486,6 +495,14 @@ def _get_model_path(self, path): return os.path.join(path, 'image-classification-model') +class TorchvisionDummyPipelineSerializer(object): + def save(self, model, path): + pass + + def load(self, path): + return create_raw_torchvision_classification_model() + + class ObjectDetectionPipelineSerializer(object): def save(self, model, path): pass diff --git a/responsibleai_vision/tests/test_rai_vision_insights_save_and_load_scenarios.py b/responsibleai_vision/tests/test_rai_vision_insights_save_and_load_scenarios.py index 2cdf7b2e44..8ec08515e4 100644 --- a/responsibleai_vision/tests/test_rai_vision_insights_save_and_load_scenarios.py +++ b/responsibleai_vision/tests/test_rai_vision_insights_save_and_load_scenarios.py @@ -12,8 +12,10 @@ from common_vision_utils import (DummyFlowersPipelineSerializer, ImageClassificationPipelineSerializer, ObjectDetectionPipelineSerializer, + TorchvisionDummyPipelineSerializer, create_dummy_model, create_image_classification_pipeline, + create_raw_torchvision_classification_model, load_flowers_dataset, load_fridge_object_detection_dataset, load_imagenet_dataset, load_imagenet_labels, @@ -49,6 +51,22 @@ def test_rai_insights_empty_save_load_save(self): run_and_validate_serialization( pred, test, task_type, class_names, label, serializer) + def test_rai_insights_pytorch_empty_save_load_save(self): + data = load_flowers_dataset(upscale=False) + data = data[0:1] + # stack two of the same image since we need same + # image sizes for pytorch model + data = data.append(data).reset_index(drop=True) + pred = create_raw_torchvision_classification_model() + test = data + class_names = data[ImageColumns.LABEL.value].unique() + task_type = ModelTask.IMAGE_CLASSIFICATION + label = ImageColumns.LABEL + serializer = TorchvisionDummyPipelineSerializer() + + run_and_validate_serialization( + pred, test, task_type, class_names, label, serializer) + @pytest.mark.skip("Insufficient memory on test machines to load images") def test_rai_insights_large_images_save_load_save(self): PIL.Image.MAX_IMAGE_PIXELS = None