Skip to content

Commit

Permalink
fix identical class
Browse files Browse the repository at this point in the history
  • Loading branch information
dilyabareeva committed Jun 4, 2024
1 parent 1e91255 commit a63995f
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/metrics/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@

import torch

from src.utils.cache import ExplanationsCache as EC
from src.utils.explanations import (
BatchedCachedExplanations,
TensorExplanations,
)
from src.utils.cache import ExplanationsCache as EC


def function_example(
Expand Down
8 changes: 3 additions & 5 deletions src/metrics/localization/identical_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def update(
self,
test_labels: torch.Tensor,
explanations: torch.Tensor,
train_dataset: torch.utils.data.Dataset,
):
"""
Used to implement metric-specific logic.
Expand All @@ -29,12 +30,9 @@ def update(
), f"Number of explanations ({explanations.shape[0]}) exceeds the number of test labels ({test_labels.shape[0]})."

top_one_xpl_indices = explanations.argmax(dim=1)
top_one_xpl_samples = torch.stack([self.train_dataset[i][0] for i in top_one_xpl_indices])

top_one_xpl_output = self.model(top_one_xpl_samples.to(self.device))
top_one_xpl_pred = top_one_xpl_output.argmax(dim=1)
top_one_xpl_targets = torch.stack([train_dataset[i][1] for i in top_one_xpl_indices])

score = (test_labels == top_one_xpl_pred) * 1.0
score = (test_labels == top_one_xpl_targets) * 1.0
self.scores.append(score)

def compute(self):
Expand Down
4 changes: 1 addition & 3 deletions src/utils/training/training.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from typing import Callable, Optional

import lightning as L
import torch
from lightning import Trainer


import lightning as L


class BasicLightningModule(L.LightningModule):
def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion tests/metrics/test_localization_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_identical_class_metrics(
dataset = request.getfixturevalue(dataset)
tda = request.getfixturevalue(explanations)
metric = IdenticalClass(model=model, train_dataset=dataset, device="cpu")
metric.update(test_labels=test_labels, explanations=tda)
metric.update(test_labels=test_labels, explanations=tda, train_dataset=dataset)
score = metric.compute()
# TODO: introduce a more meaningfull test, where the score is not zero
assert score == expected_score

0 comments on commit a63995f

Please sign in to comment.