Merge pull request #349 from understandable-machine-intelligence-lab/i-dont-know-how-to-name-it

Test transformers installed V2
annahedstroem authored May 5, 2024
2 parents 6857561 + 4157a5e commit 8ad1076
Showing 3 changed files with 157 additions and 35 deletions.
142 changes: 114 additions & 28 deletions quantus/helpers/model/pytorch_model.py
@@ -11,23 +11,34 @@
from contextlib import suppress
from copy import deepcopy
from functools import lru_cache
from importlib import util
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
from typing import (
Any,
Dict,
Generator,
Iterable,
List,
Mapping,
Optional,
Tuple,
Union,
TypedDict,
)


import numpy as np
import numpy.typing as npt
import torch
import sys
from torch import nn

from quantus.helpers import utils
from quantus.helpers.model.model_interface import ModelInterface

if util.find_spec("transformers"):
from transformers import PreTrainedModel
from transformers.tokenization_utils import BatchEncoding

if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
PreTrainedModel = None
BatchEncoding = None
from typing_extensions import TypeGuard


class PyTorchModel(ModelInterface[nn.Module]):
@@ -79,11 +90,8 @@ def _get_last_softmax_layer_index(self) -> Optional[int]:
return i
return None

last_layer = list(self.model.children())[-1]
return isinstance(last_layer, torch.nn.Softmax)

@lru_cache(maxsize=None)
def _get_model_with_linear_top(self) -> torch.nn:
def _get_model_with_linear_top(self) -> torch.nn.Module:
"""
In case the model has a softmax module, the last torch.nn.Softmax module in the self.model.modules() list is
replaced with torch.nn.Identity().
@@ -106,30 +114,34 @@ def _get_model_with_linear_top(self) -> torch.nn:

return linear_model

def _obtain_predictions(self, x, model_predict_kwargs):
pred = None
if PreTrainedModel is not None and isinstance(self.model, PreTrainedModel):
# BatchEncoding is the default output from Tokenizers which contains
# necessary keys such as `input_ids` and `attention_mask`.
# It is also possible to pass a Dict with those keys.
if not (
isinstance(x, BatchEncoding)
or (
isinstance(x, dict) and ("input_ids" in x and "attention_mask" in x)
)
):
def _obtain_predictions(
self,
x: Union[
torch.Tensor,
npt.ArrayLike,
Mapping[str, Union[torch.Tensor, npt.ArrayLike]],
],
model_predict_kwargs: Dict[str, Any],
) -> torch.Tensor:
if safe_isinstance(self.model, "transformers.modeling_utils.PreTrainedModel"):

if not is_batch_encoding_like(x):
raise ValueError(
"When using HuggingFace pretrained models, please use Tokenizers output for `x` "
"or make sure you're passing a dict with input_ids and attention_mask as keys"
)

x = {k: torch.as_tensor(v, device=self.device) for k, v in x.items()}
pred = self.model(**x, **model_predict_kwargs).logits
if self.softmax:
return torch.softmax(pred, dim=-1)
return pred

elif isinstance(self.model, nn.Module):
pred_model = self.get_softmax_arg_model()
return pred_model(torch.Tensor(x).to(self.device), **model_predict_kwargs)
raise ValueError("Predictions cant be null")
else:
raise ValueError("Predictions cant be null")

def get_softmax_arg_model(self) -> torch.nn.Module:
"""
@@ -230,11 +242,11 @@ def predict(

def shape_input(
self,
x: np.array,
x: np.ndarray,
shape: Tuple[int, ...],
channel_first: Optional[bool] = None,
batched: bool = False,
) -> np.array:
) -> np.ndarray:
"""
Reshape input into model expected input.
@@ -267,7 +279,7 @@ def shape_input(
return utils.make_channel_first(x, channel_first)
raise ValueError("Channel first order expected for a torch model.")

def get_model(self) -> torch.nn:
def get_model(self) -> torch.nn.Module:
"""
Get the original torch model.
"""
@@ -323,7 +335,7 @@ def sample(
mean: float,
std: float,
noise_type: str = "multiplicative",
) -> torch.nn:
) -> torch.nn.Module:
"""
Sample a model by means of adding normally distributed noise.
@@ -504,3 +516,77 @@ def random_layer_generator_length(self) -> int:
if (hasattr(i[1], "reset_parameters"))
]
)


def safe_isinstance(obj: Any, class_path_str: Union[Iterable[str], str]) -> bool:
"""Acts as a safe version of isinstance without having to explicitly
import packages which may not exist in the user's environment.
Checks if obj is an instance of type specified by class_path_str.
Parameters
----------
obj: Any
Some object you want to test against
class_path_str: str or list
A string or list of strings specifying full class paths
Example: `sklearn.ensemble.RandomForestRegressor`
Returns
-------
bool: True if isinstance is true and the package exists, False otherwise
"""
# Taken from https://github.com/shap/shap/blob/dffc346f323ff8cf55f39f71c613ebd00e1c88f8/shap/utils/_general.py#L197

if isinstance(class_path_str, str):
class_path_str = [class_path_str]

# try each module path in order
for class_path_str in class_path_str:
if "." not in class_path_str:
raise ValueError(
"class_path_str must be a string or list of strings specifying a full \
module path to a class. Eg, 'sklearn.ensemble.RandomForestRegressor'"
)

# Splits on last occurrence of "."
module_name, class_name = class_path_str.rsplit(".", 1)

# here we don't check further if the module is not imported, since we shouldn't have
# an object of that type passed to us if the module the type is from has never been
# imported (and we don't want to import lots of new modules for no reason)
if module_name not in sys.modules:
continue

module = sys.modules[module_name]

# Get class
_class = getattr(module, class_name, None)

if _class is None:
continue

if isinstance(obj, _class):
return True

return False


class BatchEncodingLike(TypedDict):
input_ids: Union[torch.Tensor, npt.ArrayLike]
attention_mask: Union[torch.Tensor, npt.ArrayLike]


def is_batch_encoding_like(x: Any) -> TypeGuard[BatchEncodingLike]:
# BatchEncoding is the default output from Tokenizers which contains
# necessary keys such as `input_ids` and `attention_mask`.
# It is also possible to pass a Dict with those keys.
if safe_isinstance(x, "transformers.tokenization_utils_base.BatchEncoding"):
return True

elif isinstance(x, Mapping) and "input_ids" in x and "attention_mask" in x:
return True

else:
return False
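
Editor's note: the two helpers above are the core of this change. safe_isinstance checks a fully qualified class name against sys.modules without ever importing the named package, and is_batch_encoding_like accepts either a real transformers BatchEncoding or any mapping carrying input_ids and attention_mask. A minimal sketch of how they behave (not part of the diff; assumes only numpy and quantus are imported):

    import numpy as np

    from quantus.helpers.model.pytorch_model import is_batch_encoding_like, safe_isinstance

    # A plain dict with the two required keys satisfies the TypeGuard,
    # so callers never need transformers installed on their side.
    x = {
        "input_ids": np.array([[101, 2023, 102]]),
        "attention_mask": np.array([[1, 1, 1]]),
    }
    assert is_batch_encoding_like(x)

    # A mapping missing `attention_mask` is rejected.
    assert not is_batch_encoding_like({"input_ids": np.array([[101]])})

    # safe_isinstance never imports transformers: if the module is absent
    # from sys.modules the lookup short-circuits, and a plain dict is not
    # a BatchEncoding in any case, so this is False either way.
    assert not safe_isinstance(x, "transformers.tokenization_utils_base.BatchEncoding")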
48 changes: 42 additions & 6 deletions tests/functions/test_pytorch_model.py
@@ -1,13 +1,15 @@
from collections import OrderedDict
from contextlib import nullcontext
from typing import Union
import sys

import numpy as np
import pytest
import pytest_mock
import torch
from pytest_lazyfixture import lazy_fixture
from quantus.helpers.model.pytorch_model import PyTorchModel
from scipy.special import softmax
from quantus.helpers.model.pytorch_model import PyTorchModel


@pytest.fixture
@@ -203,7 +205,6 @@ def test_get_random_layer_generator(load_mnist_model):
model = PyTorchModel(load_mnist_model, channel_first=True)

for layer_name, random_layer_model in model.get_random_layer_generator():

layer = getattr(model.get_model(), layer_name).parameters()
new_layer = getattr(random_layer_model, layer_name).parameters()

@@ -264,8 +265,12 @@ def test_add_mean_shift_to_first_layer(load_mnist_model):
),
(
lazy_fixture("load_hf_distilbert_sequence_classifier"),
{'input_ids': torch.tensor([[ 101, 1996, 4248, 2829, 4419, 14523, 2058, 1996, 13971, 3899,
102]]), 'attention_mask': torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])},
{
"input_ids": torch.tensor(
[[101, 1996, 4248, 2829, 4419, 14523, 2058, 1996, 13971, 3899, 102]]
),
"attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),
},
False,
{"labels": torch.tensor([1]), "output_hidden_states": True},
nullcontext(np.array([[0.00424026, -0.03878461]])),
@@ -286,8 +291,39 @@
),
],
)
def test_huggingface_classifier_predict(hf_model, data, softmax, model_kwargs, expected):
model = PyTorchModel(model=hf_model, softmax=softmax, model_predict_kwargs=model_kwargs)
def test_huggingface_classifier_predict(
hf_model, data, softmax, model_kwargs, expected
):
model = PyTorchModel(
model=hf_model, softmax=softmax, model_predict_kwargs=model_kwargs
)
with expected:
out = model.predict(x=data)
assert np.allclose(out, expected.enter_result), "Test failed."


@pytest.fixture
def mock_transformers_not_installed(mocker: pytest_mock.MockerFixture):
mock_dict = {k: v for k, v in sys.modules.items() if "transformers" not in k}
mocker.patch.dict("sys.modules", mock_dict)
model = mocker.MagicMock(spec=None)
model.training = False
yield model
mocker.resetall(return_value=True, side_effect=True)


@pytest.mark.pytorch_model
def test_predict_transformers_not_installed(mock_transformers_not_installed):
model = PyTorchModel(model=mock_transformers_not_installed, softmax=True)
x = {"input_ids": np.array([1, 2, 3]), "attention_mask": np.array([1, 1, 1])}
with pytest.raises(ValueError):
model.predict(x)


@pytest.mark.pytorch_model
def test_predict_invalid_input(load_hf_distilbert_sequence_classifier):
model = PyTorchModel(load_hf_distilbert_sequence_classifier)
# Prepare input and call the predict method
x = torch.tensor([1, 2, 3, 4])
with pytest.raises(ValueError):
model.predict(x)
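
Editor's note: the mock_transformers_not_installed fixture simulates a missing dependency by patching sys.modules so that transformers entries are absent while the test runs. A condensed, self-contained sketch of the same technique (not part of the diff; note clear=True, which removes every entry not present in the replacement mapping):

    import sys

    import pytest_mock

    from quantus.helpers.model.pytorch_model import safe_isinstance


    def test_safe_isinstance_without_transformers(mocker: pytest_mock.MockerFixture):
        # Copy sys.modules minus all transformers entries and swap it in;
        # the original module table is restored when the test exits.
        stripped = {k: v for k, v in sys.modules.items() if "transformers" not in k}
        mocker.patch.dict("sys.modules", stripped, clear=True)

        # With transformers absent from sys.modules, safe_isinstance cannot
        # resolve the class and returns False instead of importing it.
        assert not safe_isinstance(object(), "transformers.modeling_utils.PreTrainedModel")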
2 changes: 1 addition & 1 deletion tests/metrics/conftest.py
@@ -22,4 +22,4 @@ def mock_predict(self, x_batch, *args):
)
yield
# Restore original behaviour after test finished execution.
mocker.resetall()
mocker.resetall(side_effect=True, return_value=True)
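
Editor's note: this one-line conftest.py change matters because mocker.resetall() only clears call records by default; a configured return_value or side_effect (such as the mock_predict stub patched in above) would otherwise leak into later tests. The flags mirror Mock.reset_mock and require a reasonably recent pytest-mock. A small illustration of the difference using unittest.mock directly (not part of the diff):

    from unittest.mock import MagicMock

    m = MagicMock(return_value=42)
    m()

    # A plain reset clears call history but keeps the configured behaviour.
    m.reset_mock()
    assert m.call_count == 0
    assert m() == 42  # return_value survived the reset

    # Passing the flags also wipes return_value and side_effect, which is
    # what mocker.resetall(return_value=True, side_effect=True) does for
    # every mock the fixture created.
    m.reset_mock(return_value=True, side_effect=True)
    assert isinstance(m(), MagicMock)  # back to an auto-generated return value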
