Skip to content

Commit

Permalink
ENH: Add hubconf to load models without installing (#543)
Browse files Browse the repository at this point in the history
* Add hubconf file

* Refactor to minimise hubconf dependencies

* Pin hubconf dependencies

* Revert "Pin hubconf dependencies"

This reverts commit bc904a9 as it
didn't seem to work.

* Add support for newer versions of torch

* Add only the model folder to path

* Remove unnecessary try-except block

* Avoid duplicate definition of Hugging Face strings

* Import from a more appropriate module

* Add test to compare package and PyTorch Hub models

* Add version number to package __init__

* Remove branch name from PyTorch Hub repo string

* Check only fields from package model

* Remove unnecessary zip wrap

Co-authored-by: Shruthi42 <[email protected]>

Co-authored-by: Shruthi42 <[email protected]>
  • Loading branch information
fepegar and Shruthi42 authored Aug 4, 2022
1 parent 556c395 commit bdbbf3e
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 61 deletions.
18 changes: 1 addition & 17 deletions hi-ml-multimodal/src/health_multimodal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,4 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

BIOMED_VLP_CXR_BERT_SPECIALIZED = "microsoft/BiomedVLP-CXR-BERT-specialized"
REPO_URL = f"https://huggingface.co/{BIOMED_VLP_CXR_BERT_SPECIALIZED}"
CXR_BERT_COMMIT_TAG = "v1.1"

BIOVIL_IMAGE_WEIGHTS_NAME = "biovil_image_resnet50_proj_size_128.pt"
BIOVIL_IMAGE_WEIGHTS_URL = f"{REPO_URL}/resolve/{CXR_BERT_COMMIT_TAG}/{BIOVIL_IMAGE_WEIGHTS_NAME}"
BIOVIL_IMAGE_WEIGHTS_MD5 = "02ce6ee460f72efd599295f440dbb453"


__all__ = [
"BIOMED_VLP_CXR_BERT_SPECIALIZED",
"REPO_URL",
"CXR_BERT_COMMIT_TAG",
"BIOVIL_IMAGE_WEIGHTS_NAME",
"BIOVIL_IMAGE_WEIGHTS_URL",
"BIOVIL_IMAGE_WEIGHTS_MD5",
]
__version__ = "0.1.0"
2 changes: 2 additions & 0 deletions hi-ml-multimodal/src/health_multimodal/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

from .model import ImageModel
from .model import ResnetType
from .model import get_biovil_resnet
from .inference_engine import ImageInferenceEngine
from .utils import get_biovil_resnet_inference

Expand All @@ -44,5 +45,6 @@
"ImageModel",
"ResnetType",
"ImageInferenceEngine",
"get_biovil_resnet",
"get_biovil_resnet_inference",
]
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@

from .model import ImageModel
from .model import ResnetType

from .model import get_biovil_resnet
from .model import CXR_BERT_COMMIT_TAG
from .model import BIOMED_VLP_CXR_BERT_SPECIALIZED

__all__ = [
"ImageModel",
"ResnetType",
"get_biovil_resnet",
"CXR_BERT_COMMIT_TAG",
"BIOMED_VLP_CXR_BERT_SPECIALIZED",
]
41 changes: 41 additions & 0 deletions hi-ml-multimodal/src/health_multimodal/image/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,60 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

from __future__ import annotations

import enum
import tempfile
from pathlib import Path
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets.utils import download_url

from .resnet import resnet18, resnet50
from .modules import MLP, MultiTaskModel

TypeImageEncoder = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
MODEL_TYPE = "resnet50"
JOINT_FEATURE_SIZE = 128

BIOMED_VLP_CXR_BERT_SPECIALIZED = "microsoft/BiomedVLP-CXR-BERT-specialized"
REPO_URL = f"https://huggingface.co/{BIOMED_VLP_CXR_BERT_SPECIALIZED}"
CXR_BERT_COMMIT_TAG = "v1.1"

BIOVIL_IMAGE_WEIGHTS_NAME = "biovil_image_resnet50_proj_size_128.pt"
BIOVIL_IMAGE_WEIGHTS_URL = f"{REPO_URL}/resolve/{CXR_BERT_COMMIT_TAG}/{BIOVIL_IMAGE_WEIGHTS_NAME}"
BIOVIL_IMAGE_WEIGHTS_MD5 = "02ce6ee460f72efd599295f440dbb453"


def _download_biovil_image_model_weights() -> Path:
"""Download image model weights from Hugging Face.
More information available at https://huggingface.co/microsoft/BiomedVLP-CXR-BERT-specialized.
"""
root_dir = tempfile.gettempdir()
download_url(
BIOVIL_IMAGE_WEIGHTS_URL,
root=root_dir,
filename=BIOVIL_IMAGE_WEIGHTS_NAME,
md5=BIOVIL_IMAGE_WEIGHTS_MD5,
)
return Path(root_dir, BIOVIL_IMAGE_WEIGHTS_NAME)


def get_biovil_resnet(pretrained: bool = True) -> ImageModel:
"""Download weights from Hugging Face and instantiate the image model."""
resnet_checkpoint_path = _download_biovil_image_model_weights() if pretrained else None

image_model = ImageModel(
img_model_type=MODEL_TYPE,
joint_feature_size=JOINT_FEATURE_SIZE,
pretrained_model_path=resnet_checkpoint_path,
)
return image_model


@enum.unique
Expand Down
3 changes: 1 addition & 2 deletions hi-ml-multimodal/src/health_multimodal/image/model/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from typing import Any, List, Type, Union

import torch

from torch.hub import load_state_dict_from_url
from torchvision.models.resnet import model_urls, ResNet, BasicBlock, Bottleneck
from torchvision.models.utils import load_state_dict_from_url


class ResNetHIML(ResNet):
Expand Down
38 changes: 1 addition & 37 deletions hi-ml-multimodal/src/health_multimodal/image/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,51 +3,15 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

import tempfile
from pathlib import Path

from torchvision.datasets.utils import download_url

from .. import BIOVIL_IMAGE_WEIGHTS_NAME
from .. import BIOVIL_IMAGE_WEIGHTS_URL
from .. import BIOVIL_IMAGE_WEIGHTS_MD5
from .model import ImageModel
from .inference_engine import ImageInferenceEngine
from .data.transforms import create_chest_xray_transform_for_inference
from .model import get_biovil_resnet


MODEL_TYPE = "resnet50"
JOINT_FEATURE_SIZE = 128
TRANSFORM_RESIZE = 512
TRANSFORM_CENTER_CROP_SIZE = 480


def _download_biovil_image_model_weights() -> Path:
"""Download image model weights from Hugging Face.
More information available at https://huggingface.co/microsoft/BiomedVLP-CXR-BERT-specialized.
"""
root_dir = tempfile.gettempdir()
download_url(
BIOVIL_IMAGE_WEIGHTS_URL,
root=root_dir,
filename=BIOVIL_IMAGE_WEIGHTS_NAME,
md5=BIOVIL_IMAGE_WEIGHTS_MD5,
)
return Path(root_dir, BIOVIL_IMAGE_WEIGHTS_NAME)


def get_biovil_resnet() -> ImageModel:
"""Download weights from Hugging Face and instantiate the image model."""
resnet_checkpoint_path = _download_biovil_image_model_weights()
image_model = ImageModel(
img_model_type=MODEL_TYPE,
joint_feature_size=JOINT_FEATURE_SIZE,
pretrained_model_path=resnet_checkpoint_path,
)
return image_model


def get_biovil_resnet_inference() -> ImageInferenceEngine:
"""Create a :class:`ImageInferenceEngine` for the image model.
Expand Down
5 changes: 2 additions & 3 deletions hi-ml-multimodal/src/health_multimodal/text/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@

from typing import Tuple

from .. import BIOMED_VLP_CXR_BERT_SPECIALIZED
from .. import CXR_BERT_COMMIT_TAG

from ..image.model import CXR_BERT_COMMIT_TAG
from ..image.model import BIOMED_VLP_CXR_BERT_SPECIALIZED
from .inference_engine import TextInferenceEngine
from .model import CXRBertModel
from .model import CXRBertTokenizer
Expand Down
26 changes: 25 additions & 1 deletion hi-ml-multimodal/test_multimodal/image/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

from dataclasses import fields

import pytest
import torch

from health_multimodal.image.model.model import ImageEncoder, ImageModel
from health_multimodal.image.model.model import ImageModel
from health_multimodal.image.model.model import ImageEncoder
from health_multimodal.image.model.model import ImageModelOutput
from health_multimodal.image.model.model import get_biovil_resnet
from health_multimodal.image.model.modules import MultiTaskModel
from health_multimodal.image.model.resnet import resnet50

Expand Down Expand Up @@ -132,3 +136,23 @@ def test_reload_resnet_with_dilation() -> None:
with torch.no_grad():
expected_output = expected_model(image)
assert torch.allclose(outputs_dilation, expected_output)


@torch.no_grad()
def test_hubconf() -> None:
"""Test that instantiating the image model using the PyTorch Hub is consistent with older methods."""
image = torch.rand(1, 3, 480, 480)

github = 'microsoft/hi-ml'
model_hub = torch.hub.load(github, 'biovil_resnet', pretrained=True)
model_himl = get_biovil_resnet()

output_hub: ImageModelOutput = model_hub(image)
output_himl: ImageModelOutput = model_himl(image)

for field_himl in fields(output_himl):
value_hub = getattr(output_hub, field_himl.name)
value_himl = getattr(output_himl, field_himl.name)
if value_hub is None and value_himl is None: # for example, class_logits
continue
assert torch.allclose(value_hub, value_himl)
21 changes: 21 additions & 0 deletions hubconf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# autopep8: off
dependencies = ["torch", "torchvision"]

import sys
from pathlib import Path
repo_dir = Path(__file__).parent
multimodal_src_dir = repo_dir / "hi-ml-multimodal" / "src" / "health_multimodal" / "image"
sys.path.append(str(multimodal_src_dir))

from model import ImageModel
from model import get_biovil_resnet as _biovil_resnet
# autopep8: on


def biovil_resnet(pretrained: bool = False) -> ImageModel:
"""Get BioViL image encoder.
:param pretrained: Load pretrained weights from a checkpoint.
"""
model = _biovil_resnet(pretrained=pretrained)
return model

0 comments on commit bdbbf3e

Please sign in to comment.