From a63995fb243dfbdc95ee971c03d055062c2173ca Mon Sep 17 00:00:00 2001 From: Dilyara Bareeva Date: Tue, 4 Jun 2024 15:38:15 +0200 Subject: [PATCH] fix identical class --- src/metrics/functional.py | 2 +- src/metrics/localization/identical_class.py | 8 +++----- src/utils/training/training.py | 4 +--- tests/metrics/test_localization_metrics.py | 2 +- 4 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/metrics/functional.py b/src/metrics/functional.py index f80b0ff9..c6a30995 100644 --- a/src/metrics/functional.py +++ b/src/metrics/functional.py @@ -8,11 +8,11 @@ import torch +from src.utils.cache import ExplanationsCache as EC from src.utils.explanations import ( BatchedCachedExplanations, TensorExplanations, ) -from src.utils.cache import ExplanationsCache as EC def function_example( diff --git a/src/metrics/localization/identical_class.py b/src/metrics/localization/identical_class.py index 336f66a4..ae599c48 100644 --- a/src/metrics/localization/identical_class.py +++ b/src/metrics/localization/identical_class.py @@ -19,6 +19,7 @@ def update( self, test_labels: torch.Tensor, explanations: torch.Tensor, + train_dataset: torch.utils.data.Dataset, ): """ Used to implement metric-specific logic. @@ -29,12 +30,9 @@ def update( ), f"Number of explanations ({explanations.shape[0]}) exceeds the number of test labels ({test_labels.shape[0]})." top_one_xpl_indices = explanations.argmax(dim=1) - top_one_xpl_samples = torch.stack([self.train_dataset[i][0] for i in top_one_xpl_indices]) - - top_one_xpl_output = self.model(top_one_xpl_samples.to(self.device)) - top_one_xpl_pred = top_one_xpl_output.argmax(dim=1) + top_one_xpl_targets = torch.stack([train_dataset[i][1] for i in top_one_xpl_indices]) - score = (test_labels == top_one_xpl_pred) * 1.0 + score = (test_labels == top_one_xpl_targets) * 1.0 self.scores.append(score) def compute(self): diff --git a/src/utils/training/training.py b/src/utils/training/training.py index d2c39bea..6054f692 100644 --- a/src/utils/training/training.py +++ b/src/utils/training/training.py @@ -1,12 +1,10 @@ from typing import Callable, Optional +import lightning as L import torch from lightning import Trainer -import lightning as L - - class BasicLightningModule(L.LightningModule): def __init__( self, diff --git a/tests/metrics/test_localization_metrics.py b/tests/metrics/test_localization_metrics.py index 445baaf6..fcd7e485 100644 --- a/tests/metrics/test_localization_metrics.py +++ b/tests/metrics/test_localization_metrics.py @@ -26,7 +26,7 @@ def test_identical_class_metrics( dataset = request.getfixturevalue(dataset) tda = request.getfixturevalue(explanations) metric = IdenticalClass(model=model, train_dataset=dataset, device="cpu") - metric.update(test_labels=test_labels, explanations=tda) + metric.update(test_labels=test_labels, explanations=tda, train_dataset=dataset) score = metric.compute() # TODO: introduce a more meaningfull test, where the score is not zero assert score == expected_score