From 3ece599188f42619848e7b51e1032c41852e1d8e Mon Sep 17 00:00:00 2001 From: aaarrti Date: Sat, 27 Apr 2024 16:17:56 +0200 Subject: [PATCH] * --- tests/functions/test_pytorch_model.py | 166 +++++++++++++------------- 1 file changed, 80 insertions(+), 86 deletions(-) diff --git a/tests/functions/test_pytorch_model.py b/tests/functions/test_pytorch_model.py index bef56ad8..1ee19f95 100644 --- a/tests/functions/test_pytorch_model.py +++ b/tests/functions/test_pytorch_model.py @@ -5,11 +5,10 @@ import numpy as np import pytest +import pytest_mock import torch from pytest_lazyfixture import lazy_fixture from scipy.special import softmax -from transformers import PreTrainedModel - from quantus.helpers.model.pytorch_model import PyTorchModel @@ -206,7 +205,6 @@ def test_get_random_layer_generator(load_mnist_model): model = PyTorchModel(load_mnist_model, channel_first=True) for layer_name, random_layer_model in model.get_random_layer_generator(): - layer = getattr(model.get_model(), layer_name).parameters() new_layer = getattr(random_layer_model, layer_name).parameters() @@ -247,86 +245,82 @@ def test_add_mean_shift_to_first_layer(load_mnist_model): assert torch.all(torch.isclose(a1, a2, atol=1e-04)) -#@pytest.mark.pytorch_model -#@pytest.mark.parametrize( -# "hf_model,data,softmax,model_kwargs,expected", -# [ -# ( -# lazy_fixture("load_hf_distilbert_sequence_classifier"), -# lazy_fixture("dummy_hf_tokenizer"), -# False, -# {}, -# nullcontext(np.array([[0.00424026, -0.03878461]])), -# ), -# ( -# lazy_fixture("load_hf_distilbert_sequence_classifier"), -# lazy_fixture("dummy_hf_tokenizer"), -# False, -# {"labels": torch.tensor([1]), "output_hidden_states": True}, -# nullcontext(np.array([[0.00424026, -0.03878461]])), -# ), -# ( -# lazy_fixture("load_hf_distilbert_sequence_classifier"), -# { -# "input_ids": torch.tensor( -# [[101, 1996, 4248, 2829, 4419, 14523, 2058, 1996, 13971, 3899, 102]] -# ), -# "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), -# }, -# False, -# {"labels": torch.tensor([1]), "output_hidden_states": True}, -# nullcontext(np.array([[0.00424026, -0.03878461]])), -# ), -# ( -# lazy_fixture("load_hf_distilbert_sequence_classifier"), -# lazy_fixture("dummy_hf_tokenizer"), -# True, -# {}, -# nullcontext(np.array([[0.51075452, 0.4892454]])), -# ), -# ( -# lazy_fixture("load_hf_distilbert_sequence_classifier"), -# np.array([1, 2, 3]), -# False, -# {}, -# pytest.raises(ValueError), -# ), -# ], -#) -#def test_huggingface_classifier_predict( -# hf_model, data, softmax, model_kwargs, expected -#): -# model = PyTorchModel( -# model=hf_model, softmax=softmax, model_predict_kwargs=model_kwargs -# ) -# with expected: -# out = model.predict(x=data) -# assert np.allclose(out, expected.enter_result), "Test failed." -# - -#@pytest.mark.pytorch_model -#@pytest.mark.parametrize( -# "transformers_installed,base_class,expected", -# [ -# (True, PreTrainedModel, nullcontext(np.array([[0.1, 0.9]], dtype=np.float32))), -# (False, None, pytest.raises(ValueError)), -# ], -#) -#def test_predict_transformers_installed( -# mocker, transformers_installed, base_class, expected -#): -# mocker.patch("importlib.util.find_spec", return_value=transformers_installed) -# from quantus.helpers.model import pytorch_model -# -# reload(pytorch_model) -# # Mock the model's behavior -# model_instance = PyTorchModel(model=mocker.MagicMock(spec=base_class)) -# model_instance.model.training = False -# model_instance.model.return_value.logits = torch.tensor([[0.1, 0.9]]) -# model_instance.softmax = False -# -# # Prepare input and call the predict method -# x = {"input_ids": np.array([1, 2, 3]), "attention_mask": np.array([1, 1, 1])} -# with expected: -# predictions = model_instance.predict(x) -# assert np.array_equal(predictions, expected.enter_result), "Test failed." +@pytest.mark.pytorch_model +@pytest.mark.parametrize( + "hf_model,data,softmax,model_kwargs,expected", + [ + ( + lazy_fixture("load_hf_distilbert_sequence_classifier"), + lazy_fixture("dummy_hf_tokenizer"), + False, + {}, + nullcontext(np.array([[0.00424026, -0.03878461]])), + ), + ( + lazy_fixture("load_hf_distilbert_sequence_classifier"), + lazy_fixture("dummy_hf_tokenizer"), + False, + {"labels": torch.tensor([1]), "output_hidden_states": True}, + nullcontext(np.array([[0.00424026, -0.03878461]])), + ), + ( + lazy_fixture("load_hf_distilbert_sequence_classifier"), + { + "input_ids": torch.tensor( + [[101, 1996, 4248, 2829, 4419, 14523, 2058, 1996, 13971, 3899, 102]] + ), + "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), + }, + False, + {"labels": torch.tensor([1]), "output_hidden_states": True}, + nullcontext(np.array([[0.00424026, -0.03878461]])), + ), + ( + lazy_fixture("load_hf_distilbert_sequence_classifier"), + lazy_fixture("dummy_hf_tokenizer"), + True, + {}, + nullcontext(np.array([[0.51075452, 0.4892454]])), + ), + ( + lazy_fixture("load_hf_distilbert_sequence_classifier"), + np.array([1, 2, 3]), + False, + {}, + pytest.raises(ValueError), + ), + ], +) +def test_huggingface_classifier_predict( + hf_model, data, softmax, model_kwargs, expected +): + model = PyTorchModel( + model=hf_model, softmax=softmax, model_predict_kwargs=model_kwargs + ) + with expected: + out = model.predict(x=data) + assert np.allclose(out, expected.enter_result), "Test failed." + + +@pytest.fixture +def mock_transformers_not_installed(mocker: pytest_mock.MockFixture): + mocker.patch("importlib.util.find_spec", return_value=None) + from quantus.helpers.model import pytorch_model + + reload(pytorch_model) + # Mock the model's behavior + model_instance = PyTorchModel(model=mocker.MagicMock(spec=None)) + # model_instance.model.training = False + # model_instance.model.return_value.logits = torch.tensor([[0.1, 0.9]]) + # model_instance.softmax = False + yield model_instance + mocker.resetall() + + +@pytest.mark.pytorch_model +def test_predict_transformers_installed(mock_transformers_not_installed): + model_instance = mock_transformers_not_installed + # Prepare input and call the predict method + x = {"input_ids": np.array([1, 2, 3]), "attention_mask": np.array([1, 1, 1])} + with pytest.raises(ValueError): + _ = model_instance.predict(x)