From f8e2ac258726b908d4a248a459ddd9f70072764c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Wed, 22 May 2024 20:06:23 +0200 Subject: [PATCH] add torchmetrics to pyproject.toml to attempt to pass tests --- pyproject.toml | 1 + tests/metrics/test_randomization_metrics.py | 1 + 2 files changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index f616f400..12491e00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ dependencies = [ "numpy>=1.19.5", "torch>=1.13.1", "captum>=0.6.0", + "torchmetrics>=1.4.0" ] dynamic = ["version"] diff --git a/tests/metrics/test_randomization_metrics.py b/tests/metrics/test_randomization_metrics.py index 97fc138c..45c9e5e9 100644 --- a/tests/metrics/test_randomization_metrics.py +++ b/tests/metrics/test_randomization_metrics.py @@ -30,6 +30,7 @@ def reproducibility_test(): gen.manual_seed(42) assert torch.all(torch.rand(5, generator=gen) == torch.Tensor([0.8823, 0.9150, 0.3829, 0.9593, 0.3904])) + @pytest.mark.randomization @pytest.mark.parametrize( "model",