From bdbbf3e812dda04d6dd23b665548f762b4a0d7ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fernando=20P=C3=A9rez-Garc=C3=ADa?= Date: Thu, 4 Aug 2022 11:22:54 +0200 Subject: [PATCH] ENH: Add hubconf to load models without installing (#543) * Add hubconf file * Refactor to minimise hubconf dependencies * Pin hubconf dependencies * Revert "Pin hubconf dependencies" This reverts commit bc904a963ea1650ba904a95e1cdbfb80764ece2a as it didn't seem to work. * Add support for newer versions of torch * Add only the model folder to path * Remove unnecessary try-except block * Avoid duplicate definition of Hugging Face strings * Import from a more appropriate module * Add test to compare package and PyTorch Hub models * Add version number to package __init__ * Remove branch name from PyTorch Hub repo string * Check only fields from package model * Remove unnecessary zip wrap Co-authored-by: Shruthi42 <13177030+Shruthi42@users.noreply.github.com> Co-authored-by: Shruthi42 <13177030+Shruthi42@users.noreply.github.com> --- .../src/health_multimodal/__init__.py | 18 +------- .../src/health_multimodal/image/__init__.py | 2 + .../health_multimodal/image/model/__init__.py | 7 +++- .../health_multimodal/image/model/model.py | 41 +++++++++++++++++++ .../health_multimodal/image/model/resnet.py | 3 +- .../src/health_multimodal/image/utils.py | 38 +---------------- .../src/health_multimodal/text/utils.py | 5 +-- .../test_multimodal/image/model/test_model.py | 26 +++++++++++- hubconf.py | 21 ++++++++++ 9 files changed, 100 insertions(+), 61 deletions(-) create mode 100644 hubconf.py diff --git a/hi-ml-multimodal/src/health_multimodal/__init__.py b/hi-ml-multimodal/src/health_multimodal/__init__.py index 4443e68a0..0c2ddc501 100644 --- a/hi-ml-multimodal/src/health_multimodal/__init__.py +++ b/hi-ml-multimodal/src/health_multimodal/__init__.py @@ -3,20 +3,4 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------- -BIOMED_VLP_CXR_BERT_SPECIALIZED = "microsoft/BiomedVLP-CXR-BERT-specialized" -REPO_URL = f"https://huggingface.co/{BIOMED_VLP_CXR_BERT_SPECIALIZED}" -CXR_BERT_COMMIT_TAG = "v1.1" - -BIOVIL_IMAGE_WEIGHTS_NAME = "biovil_image_resnet50_proj_size_128.pt" -BIOVIL_IMAGE_WEIGHTS_URL = f"{REPO_URL}/resolve/{CXR_BERT_COMMIT_TAG}/{BIOVIL_IMAGE_WEIGHTS_NAME}" -BIOVIL_IMAGE_WEIGHTS_MD5 = "02ce6ee460f72efd599295f440dbb453" - - -__all__ = [ - "BIOMED_VLP_CXR_BERT_SPECIALIZED", - "REPO_URL", - "CXR_BERT_COMMIT_TAG", - "BIOVIL_IMAGE_WEIGHTS_NAME", - "BIOVIL_IMAGE_WEIGHTS_URL", - "BIOVIL_IMAGE_WEIGHTS_MD5", -] +__version__ = "0.1.0" diff --git a/hi-ml-multimodal/src/health_multimodal/image/__init__.py b/hi-ml-multimodal/src/health_multimodal/image/__init__.py index bf14c7e00..cca3b3a32 100644 --- a/hi-ml-multimodal/src/health_multimodal/image/__init__.py +++ b/hi-ml-multimodal/src/health_multimodal/image/__init__.py @@ -36,6 +36,7 @@ from .model import ImageModel from .model import ResnetType +from .model import get_biovil_resnet from .inference_engine import ImageInferenceEngine from .utils import get_biovil_resnet_inference @@ -44,5 +45,6 @@ "ImageModel", "ResnetType", "ImageInferenceEngine", + "get_biovil_resnet", "get_biovil_resnet_inference", ] diff --git a/hi-ml-multimodal/src/health_multimodal/image/model/__init__.py b/hi-ml-multimodal/src/health_multimodal/image/model/__init__.py index c45eaff26..fbd4d7be9 100644 --- a/hi-ml-multimodal/src/health_multimodal/image/model/__init__.py +++ b/hi-ml-multimodal/src/health_multimodal/image/model/__init__.py @@ -5,9 +5,14 @@ from .model import ImageModel from .model import ResnetType - +from .model import get_biovil_resnet +from .model import CXR_BERT_COMMIT_TAG +from .model import BIOMED_VLP_CXR_BERT_SPECIALIZED __all__ = [ "ImageModel", "ResnetType", + "get_biovil_resnet", + "CXR_BERT_COMMIT_TAG", + "BIOMED_VLP_CXR_BERT_SPECIALIZED", ] diff --git a/hi-ml-multimodal/src/health_multimodal/image/model/model.py b/hi-ml-multimodal/src/health_multimodal/image/model/model.py index 151d295c6..77c55d5ca 100644 --- a/hi-ml-multimodal/src/health_multimodal/image/model/model.py +++ b/hi-ml-multimodal/src/health_multimodal/image/model/model.py @@ -3,7 +3,10 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------- +from __future__ import annotations + import enum +import tempfile from pathlib import Path from dataclasses import dataclass from typing import Any, Optional, Tuple, Union, Sequence @@ -11,11 +14,49 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torchvision.datasets.utils import download_url from .resnet import resnet18, resnet50 from .modules import MLP, MultiTaskModel TypeImageEncoder = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] +MODEL_TYPE = "resnet50" +JOINT_FEATURE_SIZE = 128 + +BIOMED_VLP_CXR_BERT_SPECIALIZED = "microsoft/BiomedVLP-CXR-BERT-specialized" +REPO_URL = f"https://huggingface.co/{BIOMED_VLP_CXR_BERT_SPECIALIZED}" +CXR_BERT_COMMIT_TAG = "v1.1" + +BIOVIL_IMAGE_WEIGHTS_NAME = "biovil_image_resnet50_proj_size_128.pt" +BIOVIL_IMAGE_WEIGHTS_URL = f"{REPO_URL}/resolve/{CXR_BERT_COMMIT_TAG}/{BIOVIL_IMAGE_WEIGHTS_NAME}" +BIOVIL_IMAGE_WEIGHTS_MD5 = "02ce6ee460f72efd599295f440dbb453" + + +def _download_biovil_image_model_weights() -> Path: + """Download image model weights from Hugging Face. + + More information available at https://huggingface.co/microsoft/BiomedVLP-CXR-BERT-specialized. + """ + root_dir = tempfile.gettempdir() + download_url( + BIOVIL_IMAGE_WEIGHTS_URL, + root=root_dir, + filename=BIOVIL_IMAGE_WEIGHTS_NAME, + md5=BIOVIL_IMAGE_WEIGHTS_MD5, + ) + return Path(root_dir, BIOVIL_IMAGE_WEIGHTS_NAME) + + +def get_biovil_resnet(pretrained: bool = True) -> ImageModel: + """Download weights from Hugging Face and instantiate the image model.""" + resnet_checkpoint_path = _download_biovil_image_model_weights() if pretrained else None + + image_model = ImageModel( + img_model_type=MODEL_TYPE, + joint_feature_size=JOINT_FEATURE_SIZE, + pretrained_model_path=resnet_checkpoint_path, + ) + return image_model @enum.unique diff --git a/hi-ml-multimodal/src/health_multimodal/image/model/resnet.py b/hi-ml-multimodal/src/health_multimodal/image/model/resnet.py index 0d4d60fe7..6bc6d0ce2 100644 --- a/hi-ml-multimodal/src/health_multimodal/image/model/resnet.py +++ b/hi-ml-multimodal/src/health_multimodal/image/model/resnet.py @@ -6,9 +6,8 @@ from typing import Any, List, Type, Union import torch - +from torch.hub import load_state_dict_from_url from torchvision.models.resnet import model_urls, ResNet, BasicBlock, Bottleneck -from torchvision.models.utils import load_state_dict_from_url class ResNetHIML(ResNet): diff --git a/hi-ml-multimodal/src/health_multimodal/image/utils.py b/hi-ml-multimodal/src/health_multimodal/image/utils.py index 916246bdf..9504b8e1e 100644 --- a/hi-ml-multimodal/src/health_multimodal/image/utils.py +++ b/hi-ml-multimodal/src/health_multimodal/image/utils.py @@ -3,51 +3,15 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------- -import tempfile -from pathlib import Path - -from torchvision.datasets.utils import download_url - -from .. import BIOVIL_IMAGE_WEIGHTS_NAME -from .. import BIOVIL_IMAGE_WEIGHTS_URL -from .. import BIOVIL_IMAGE_WEIGHTS_MD5 -from .model import ImageModel from .inference_engine import ImageInferenceEngine from .data.transforms import create_chest_xray_transform_for_inference +from .model import get_biovil_resnet -MODEL_TYPE = "resnet50" -JOINT_FEATURE_SIZE = 128 TRANSFORM_RESIZE = 512 TRANSFORM_CENTER_CROP_SIZE = 480 -def _download_biovil_image_model_weights() -> Path: - """Download image model weights from Hugging Face. - - More information available at https://huggingface.co/microsoft/BiomedVLP-CXR-BERT-specialized. - """ - root_dir = tempfile.gettempdir() - download_url( - BIOVIL_IMAGE_WEIGHTS_URL, - root=root_dir, - filename=BIOVIL_IMAGE_WEIGHTS_NAME, - md5=BIOVIL_IMAGE_WEIGHTS_MD5, - ) - return Path(root_dir, BIOVIL_IMAGE_WEIGHTS_NAME) - - -def get_biovil_resnet() -> ImageModel: - """Download weights from Hugging Face and instantiate the image model.""" - resnet_checkpoint_path = _download_biovil_image_model_weights() - image_model = ImageModel( - img_model_type=MODEL_TYPE, - joint_feature_size=JOINT_FEATURE_SIZE, - pretrained_model_path=resnet_checkpoint_path, - ) - return image_model - - def get_biovil_resnet_inference() -> ImageInferenceEngine: """Create a :class:`ImageInferenceEngine` for the image model. diff --git a/hi-ml-multimodal/src/health_multimodal/text/utils.py b/hi-ml-multimodal/src/health_multimodal/text/utils.py index b5a5f278e..ff8207320 100644 --- a/hi-ml-multimodal/src/health_multimodal/text/utils.py +++ b/hi-ml-multimodal/src/health_multimodal/text/utils.py @@ -6,9 +6,8 @@ from typing import Tuple -from .. import BIOMED_VLP_CXR_BERT_SPECIALIZED -from .. import CXR_BERT_COMMIT_TAG - +from ..image.model import CXR_BERT_COMMIT_TAG +from ..image.model import BIOMED_VLP_CXR_BERT_SPECIALIZED from .inference_engine import TextInferenceEngine from .model import CXRBertModel from .model import CXRBertTokenizer diff --git a/hi-ml-multimodal/test_multimodal/image/model/test_model.py b/hi-ml-multimodal/test_multimodal/image/model/test_model.py index 3281841be..36f17eb33 100644 --- a/hi-ml-multimodal/test_multimodal/image/model/test_model.py +++ b/hi-ml-multimodal/test_multimodal/image/model/test_model.py @@ -3,11 +3,15 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------- +from dataclasses import fields import pytest import torch -from health_multimodal.image.model.model import ImageEncoder, ImageModel +from health_multimodal.image.model.model import ImageModel +from health_multimodal.image.model.model import ImageEncoder +from health_multimodal.image.model.model import ImageModelOutput +from health_multimodal.image.model.model import get_biovil_resnet from health_multimodal.image.model.modules import MultiTaskModel from health_multimodal.image.model.resnet import resnet50 @@ -132,3 +136,23 @@ def test_reload_resnet_with_dilation() -> None: with torch.no_grad(): expected_output = expected_model(image) assert torch.allclose(outputs_dilation, expected_output) + + +@torch.no_grad() +def test_hubconf() -> None: + """Test that instantiating the image model using the PyTorch Hub is consistent with older methods.""" + image = torch.rand(1, 3, 480, 480) + + github = 'microsoft/hi-ml' + model_hub = torch.hub.load(github, 'biovil_resnet', pretrained=True) + model_himl = get_biovil_resnet() + + output_hub: ImageModelOutput = model_hub(image) + output_himl: ImageModelOutput = model_himl(image) + + for field_himl in fields(output_himl): + value_hub = getattr(output_hub, field_himl.name) + value_himl = getattr(output_himl, field_himl.name) + if value_hub is None and value_himl is None: # for example, class_logits + continue + assert torch.allclose(value_hub, value_himl) diff --git a/hubconf.py b/hubconf.py new file mode 100644 index 000000000..b027bf042 --- /dev/null +++ b/hubconf.py @@ -0,0 +1,21 @@ +# autopep8: off +dependencies = ["torch", "torchvision"] + +import sys +from pathlib import Path +repo_dir = Path(__file__).parent +multimodal_src_dir = repo_dir / "hi-ml-multimodal" / "src" / "health_multimodal" / "image" +sys.path.append(str(multimodal_src_dir)) + +from model import ImageModel +from model import get_biovil_resnet as _biovil_resnet +# autopep8: on + + +def biovil_resnet(pretrained: bool = False) -> ImageModel: + """Get BioViL image encoder. + + :param pretrained: Load pretrained weights from a checkpoint. + """ + model = _biovil_resnet(pretrained=pretrained) + return model