From d9891e69e8a4bb12018ce767e3b50a8216ab5926 Mon Sep 17 00:00:00 2001
From: Dilyara Bareeva
Date: Wed, 12 Jun 2024 11:12:04 +0200
Subject: [PATCH] downstream base + identical subclass

---
 src/downstream_tasks/base.py                | 34 ++++++++++
 .../subclass_identification.py              |  0
 src/metrics/localization/identical_class.py | 33 ++++++++++
 src/utils/datasets/group_label_dataset.py   | 62 ++++++++++++-------
 tests/conftest.py                           | 19 ++++++
 tests/metrics/test_localization_metrics.py  | 35 ++++++++++-
 6 files changed, 160 insertions(+), 23 deletions(-)
 create mode 100644 src/downstream_tasks/base.py
 create mode 100644 src/downstream_tasks/subclass_identification.py

diff --git a/src/downstream_tasks/base.py b/src/downstream_tasks/base.py
new file mode 100644
index 00000000..bba84cd8
--- /dev/null
+++ b/src/downstream_tasks/base.py
@@ -0,0 +1,34 @@
+from abc import ABC, abstractmethod
+
+import torch
+
+
+class DownstreamTaskEval(ABC):
+    def __init__(self, device: str = "cpu", *args, **kwargs):
+        """
+        I think here it would be nice to pass a general recipe for the downstream task construction.
+        For example, we could pass:
+        - a dataset constructor that generates the training dataset from the original
+          dataset (by modifying the labels, modifying the data, or removing some samples);
+        - a metric that produces the final score: either a Metric object from our library
+          or, for example, an accuracy comparison.
+
+        :param device: device on which the evaluation is run.
+        :param args: additional positional arguments.
+        :param kwargs: additional keyword arguments.
+        """
+        self.device = device
+
+    @abstractmethod
+    def evaluate(
+        self,
+        model: torch.nn.Module,
+        dataset: torch.utils.data.Dataset,
+        *args,
+        **kwargs,
+    ):
+        """
+        Evaluate the downstream task for the given model and training dataset.
+        """
+
+        raise NotImplementedError
diff --git a/src/downstream_tasks/subclass_identification.py b/src/downstream_tasks/subclass_identification.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/metrics/localization/identical_class.py b/src/metrics/localization/identical_class.py
index 63e06638..84ebac8c 100644
--- a/src/metrics/localization/identical_class.py
+++ b/src/metrics/localization/identical_class.py
@@ -53,3 +53,36 @@ def state_dict(self, *args, **kwargs):
         Used to return the metric state.
         """
         return {"scores": self.scores}
+
+
+class IdenticalSubclass(IdenticalClass):
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        train_dataset: torch.utils.data.Dataset,
+        subclass_labels: torch.Tensor,
+        device: str,
+        *args,
+        **kwargs,
+    ):
+        assert len(subclass_labels) == len(train_dataset), (
+            f"Number of subclass labels ({len(subclass_labels)}) "
+            f"does not match the number of train dataset samples ({len(train_dataset)})."
+        )
+        super().__init__(model, train_dataset, device, *args, **kwargs)
+        self.subclass_labels = subclass_labels
+
+    def update(self, test_labels: torch.Tensor, explanations: torch.Tensor):
+        """
+        Update the metric state with a new batch of test labels and explanations.
+        """
+
+        assert (
+            test_labels.shape[0] == explanations.shape[0]
+        ), f"Number of explanations ({explanations.shape[0]}) does not match the number of test labels ({test_labels.shape[0]})."
+
+        top_one_xpl_indices = explanations.argmax(dim=1)
+        top_one_xpl_targets = torch.stack([self.subclass_labels[i] for i in top_one_xpl_indices])
+
+        score = (test_labels == top_one_xpl_targets) * 1.0
+        self.scores.append(score)
diff --git a/src/utils/datasets/group_label_dataset.py b/src/utils/datasets/group_label_dataset.py
index ea71d9f3..2901fadd 100644
--- a/src/utils/datasets/group_label_dataset.py
+++ b/src/utils/datasets/group_label_dataset.py
@@ -1,32 +1,50 @@
-from torch.utils.data.dataset import Dataset
+import random
+from typing import Dict, Literal, Optional, Union
 
-CLASS_GROUP_BY = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
+import torch
+from torch.utils.data import Dataset
 
+ClassToGroupLiterals = Literal["random"]
 
-class GroupLabelDataset(Dataset):
-    def __init__(self, dataset, class_groups=None):
+
+class GroupLabelDataset:
+    def __init__(
+        self,
+        dataset: Dataset,
+        n_classes: int = 10,
+        n_groups: int = 2,
+        class_to_group: Union[ClassToGroupLiterals, Dict[int, int]] = "random",
+        seed: Optional[int] = 27,
+        device: str = "cpu",
+    ):
         self.dataset = dataset
-        self.class_labels = [i for i in range(len(class_groups))]
-        self.inverse_transform = dataset.inverse_transform
-        if class_groups is None:
-            class_groups = CLASS_GROUP_BY
-        self.class_groups = class_groups
-        self.inverted_class_groups = self.invert_class_groups(class_groups)
+        self.n_classes = n_classes
+        self.classes = list(range(n_classes))
+        self.n_groups = n_groups
+        self.groups = list(range(n_groups))
+        self.generator = torch.Generator(device=device)
+        if class_to_group == "random":
+            # randomly assign each class to one of the groups
+            random.seed(seed)
+            self.class_to_group = {i: random.choice(self.groups) for i in range(n_classes)}
+        elif isinstance(class_to_group, dict):
+            self.validate_class_to_group(class_to_group)
+            self.class_to_group = class_to_group
+        else:
+            raise ValueError(f"Invalid class_to_group value: {class_to_group}")
+
+    def validate_class_to_group(self, class_to_group):
+        assert len(class_to_group) == self.n_classes
+        assert all([g in self.groups for g in class_to_group.values()])
 
     def __getitem__(self, index):
         x, y = self.dataset[index]
-        g = self.inverted_class_groups[y]
-        return x, (g, y)
+        g = self.class_to_group[y]
+        return x, g
+
+    def get_subclass_label(self, index):
+        _, y = self.dataset[index]
+        return y
 
     def __len__(self):
         return len(self.dataset)
-
-    @staticmethod
-    def invert_class_groups(groups):
-        inverted_class_groups = {}
-        for g, group in enumerate(groups):
-            intersection = inverted_class_groups.keys() & group
-            if len(intersection) > 0:
-                raise ValueError("Class indices %s are present in multiple groups." % (str(intersection)))
-            inverted_class_groups.update({cls: g for cls in group})
-        return inverted_class_groups
diff --git a/tests/conftest.py b/tests/conftest.py
index 15e4033b..bf21469d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -6,6 +6,7 @@
 from torch.utils.data import TensorDataset
 
 from tests.models import LeNet
+from src.utils.datasets.group_label_dataset import GroupLabelDataset
 
 MNIST_IMAGE_SIZE = 28
 BATCH_SIZE = 124
@@ -57,6 +58,24 @@ def load_mnist_dataset():
     return dataset
 
 
+@pytest.fixture()
+def load_mnist_labels():
+    y_batch = np.loadtxt("tests/assets/mnist_test_suite_1/mnist_y").astype(int)[:MINI_BATCH_SIZE]
+    return torch.tensor(y_batch).long()
+
+
+@pytest.fixture()
+def load_grouped_mnist_dataset():
+    x_batch = (
+        np.loadtxt("tests/assets/mnist_test_suite_1/mnist_x")
+        .astype(float)
+        .reshape((BATCH_SIZE, 1, MNIST_IMAGE_SIZE, MNIST_IMAGE_SIZE))
+    )[:MINI_BATCH_SIZE]
+    y_batch = np.loadtxt("tests/assets/mnist_test_suite_1/mnist_y").astype(int)[:MINI_BATCH_SIZE]
+    dataset = TensorDataset(torch.tensor(x_batch).float(), torch.tensor(y_batch).long())
+    return GroupLabelDataset(dataset, n_classes=10, n_groups=2, class_to_group="random", seed=27, device="cpu")
+
+
 @pytest.fixture()
 def load_mnist_dataloader():
     """Load a batch of MNIST digits: inputs and outputs to use for testing."""
diff --git a/tests/metrics/test_localization_metrics.py b/tests/metrics/test_localization_metrics.py
index 92618cca..bef4839f 100644
--- a/tests/metrics/test_localization_metrics.py
+++ b/tests/metrics/test_localization_metrics.py
@@ -1,6 +1,9 @@
 import pytest
 
-from src.metrics.localization.identical_class import IdenticalClass
+from src.metrics.localization.identical_class import (
+    IdenticalClass,
+    IdenticalSubclass,
+)
 
 
 @pytest.mark.localization_metrics
@@ -34,3 +37,33 @@ def test_identical_class_metrics(
     # With a big test dataset, the probability of failing a truly random test
     # should diminish.
     assert score == expected_score
+
+
+@pytest.mark.localization_metrics
+@pytest.mark.parametrize(
+    "test_id, model, dataset, subclass_labels, test_labels, batch_size, explanations, expected_score",
+    [
+        (
+            "mnist",
+            "load_mnist_model",
+            "load_grouped_mnist_dataset",
+            "load_mnist_labels",
+            "load_mnist_test_labels_1",
+            8,
+            "load_mnist_explanations_1",
+            0.1,
+        ),
+    ],
+)
+def test_identical_subclass_metrics(
+    test_id, model, dataset, subclass_labels, test_labels, batch_size, explanations, expected_score, request
+):
+    model = request.getfixturevalue(model)
+    test_labels = request.getfixturevalue(test_labels)
+    subclass_labels = request.getfixturevalue(subclass_labels)
+    dataset = request.getfixturevalue(dataset)
+    tda = request.getfixturevalue(explanations)
+    metric = IdenticalSubclass(model=model, train_dataset=dataset, subclass_labels=subclass_labels, device="cpu")
+    metric.update(test_labels=test_labels, explanations=tda)
+    score = metric.compute()
+    assert score == expected_score
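
Illustrative sketch (not part of the patch): the DownstreamTaskEval docstring proposes passing a general recipe for the downstream task, i.e. a dataset constructor that derives the training dataset plus a metric that produces the final score. The names RecipeDownstreamTaskEval, dataset_transform and scoring_fn below are hypothetical and only illustrate that idea under stated assumptions.

    from typing import Callable

    import torch
    from torch.utils.data import Dataset


    class RecipeDownstreamTaskEval:
        # Hypothetical illustration of the "recipe" idea: the task is defined by
        # (1) a dataset constructor that derives the training dataset and
        # (2) a scoring function (e.g. a Metric from the library or an accuracy comparison).
        def __init__(
            self,
            dataset_transform: Callable[[Dataset], Dataset],
            scoring_fn: Callable[[torch.nn.Module, Dataset], torch.Tensor],
            device: str = "cpu",
        ):
            self.dataset_transform = dataset_transform
            self.scoring_fn = scoring_fn
            self.device = device

        def evaluate(self, model: torch.nn.Module, dataset: Dataset, *args, **kwargs):
            # build the task-specific training dataset, then score the model on it
            train_dataset = self.dataset_transform(dataset)
            return self.scoring_fn(model, train_dataset)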
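
Usage sketch (not part of the patch): how GroupLabelDataset and IdenticalSubclass are meant to fit together, mirroring test_identical_subclass_metrics. The random tensors and the tiny linear model are placeholders standing in for the MNIST fixtures; it is assumed that IdenticalClass provides the scores list, compute(), and the (model, train_dataset, device) constructor used in the test.

    import torch
    from torch.utils.data import TensorDataset

    from src.metrics.localization.identical_class import IdenticalSubclass
    from src.utils.datasets.group_label_dataset import GroupLabelDataset

    # toy training data: 100 samples with 10 original classes, grouped into 2 superclasses
    x_train = torch.randn(100, 1, 28, 28)
    y_train = torch.randint(0, 10, (100,))
    grouped = GroupLabelDataset(
        TensorDataset(x_train, y_train), n_classes=10, n_groups=2, class_to_group="random", seed=27
    )

    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 2))  # placeholder model

    # placeholder attributions: one row per test sample, one column per training sample
    explanations = torch.randn(8, 100)
    test_labels = torch.randint(0, 10, (8,))  # original (sub)class labels of the test samples

    metric = IdenticalSubclass(model=model, train_dataset=grouped, subclass_labels=y_train, device="cpu")
    metric.update(test_labels=test_labels, explanations=explanations)
    # fraction of test samples whose top-attributed training sample shares their subclass
    print(metric.compute())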