downstream base + identical subclass
dilyabareeva committed Jun 12, 2024
1 parent 684e26c commit d9891e6
Showing 6 changed files with 160 additions and 23 deletions.
34 changes: 34 additions & 0 deletions src/downstream_tasks/base.py
@@ -0,0 +1,34 @@
from abc import ABC, abstractmethod

import torch


class DownstreamTaskEval(ABC):
def __init__(self, device: str = "cpu", *args, **kwargs):
"""
        I think it would be nice to pass a general recipe for constructing the downstream task here.
        For example, we could pass
        - a dataset constructor that generates the training dataset from the original
          dataset (either by modifying the labels, modifying the data, or removing some samples);
        - a metric that produces the final score: either a Metric object from our library or
          an accuracy comparison.

        :param device: Device on which the evaluation is run.
        :param args: Additional positional arguments.
        :param kwargs: Additional keyword arguments.
"""
self.device = device

@abstractmethod
def evaluate(
self,
model: torch.nn.Module,
        dataset: torch.utils.data.Dataset,
*args,
**kwargs,
):
"""
        Evaluate the downstream task for the given model and dataset.
"""

raise NotImplementedError
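
Not part of this commit: a minimal sketch of how a concrete task could fill in this interface, following the recipe idea from the constructor docstring. The class name SubclassDetectionEval, the dataset_transform constructor, and the metric_fn callable are hypothetical; the sketch assumes the imports and base class above.

class SubclassDetectionEval(DownstreamTaskEval):
    """Hypothetical downstream task following the recipe sketched above."""

    def __init__(self, dataset_transform, metric_fn, device: str = "cpu", **kwargs):
        # dataset_transform: callable that turns the original dataset into the task's
        # training dataset (e.g. by regrouping labels).
        # metric_fn: callable mapping (model, dataset) -> float, e.g. accuracy.
        super().__init__(device=device, **kwargs)
        self.dataset_transform = dataset_transform
        self.metric_fn = metric_fn

    def evaluate(self, model: torch.nn.Module, dataset: torch.utils.data.Dataset, *args, **kwargs):
        task_dataset = self.dataset_transform(dataset)
        model.to(self.device)
        model.eval()
        with torch.no_grad():
            return self.metric_fn(model, task_dataset)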
Empty file.
33 changes: 33 additions & 0 deletions src/metrics/localization/identical_class.py
@@ -53,3 +53,36 @@ def state_dict(self, *args, **kwargs):
Used to return the metric state.
"""
return {"scores": self.scores}


class IdenticalSubclass(IdenticalClass):
def __init__(
self,
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,
subclass_labels: torch.Tensor,
device,
*args,
**kwargs,
):
assert len(subclass_labels) == len(train_dataset), (
f"Number of subclass labels ({len(subclass_labels)}) "
f"does not match the number of train dataset samples ({len(train_dataset)})."
)
super().__init__(model, train_dataset, device, *args, **kwargs)
self.subclass_labels = subclass_labels

def update(self, test_labels: torch.Tensor, explanations: torch.Tensor):
"""
        Update the metric state with a batch of test labels and explanations.
"""

        assert (
            test_labels.shape[0] == explanations.shape[0]
        ), f"Number of explanations ({explanations.shape[0]}) does not match the number of test labels ({test_labels.shape[0]})."

top_one_xpl_indices = explanations.argmax(dim=1)
top_one_xpl_targets = torch.stack([self.subclass_labels[i] for i in top_one_xpl_indices])

score = (test_labels == top_one_xpl_targets) * 1.0
self.scores.append(score)
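
Not part of this commit: a rough usage sketch of IdenticalSubclass. The model, the grouped_train_dataset, and the random stand-in explanations below are hypothetical; a real run would use training-data attributions of shape (n_test, n_train) produced by an explainer.

# Illustrative only: random scores stand in for real training-data attributions.
n_test, n_train = 8, len(grouped_train_dataset)
explanations = torch.rand(n_test, n_train)  # one row of train-sample scores per test sample
test_labels = torch.randint(0, 10, (n_test,))
subclass_labels = torch.tensor(
    [grouped_train_dataset.get_subclass_label(i) for i in range(n_train)]
)

metric = IdenticalSubclass(
    model=model,
    train_dataset=grouped_train_dataset,
    subclass_labels=subclass_labels,
    device="cpu",
)
metric.update(test_labels=test_labels, explanations=explanations)
score = metric.compute()  # fraction of test samples whose top-attributed train sample shares their label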
62 changes: 40 additions & 22 deletions src/utils/datasets/group_label_dataset.py
@@ -1,32 +1,50 @@
from torch.utils.data.dataset import Dataset
import random
from typing import Dict, Literal, Optional, Union

CLASS_GROUP_BY = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
import torch
from torch.utils.data import Dataset

ClassToGroupLiterals = Literal["random"]

class GroupLabelDataset(Dataset):
def __init__(self, dataset, class_groups=None):

class GroupLabelDataset:
def __init__(
self,
dataset: Dataset,
n_classes: int = 10,
n_groups: int = 2,
class_to_group: Union[ClassToGroupLiterals, Dict[int, int]] = "random",
seed: Optional[int] = 27,
        device: str = "cpu",
):
self.dataset = dataset
self.class_labels = [i for i in range(len(class_groups))]
self.inverse_transform = dataset.inverse_transform
if class_groups is None:
class_groups = CLASS_GROUP_BY
self.class_groups = class_groups
self.inverted_class_groups = self.invert_class_groups(class_groups)
self.n_classes = n_classes
self.classes = list(range(n_classes))
self.n_groups = n_groups
self.groups = list(range(n_groups))
self.generator = torch.Generator(device=device)
if class_to_group == "random":
# create a dictionary of class groups that assigns each class to a group
random.seed(seed)
self.class_to_group = {i: random.choice(self.groups) for i in range(n_classes)}
elif isinstance(class_to_group, dict):
self.validate_class_to_group(class_to_group)
self.class_to_group = class_to_group
else:
raise ValueError(f"Invalid class_to_group value: {class_to_group}")

def validate_class_to_group(self, class_to_group):
assert len(class_to_group) == self.n_classes
        assert all([g in self.groups for g in class_to_group.values()])

def __getitem__(self, index):
x, y = self.dataset[index]
g = self.inverted_class_groups[y]
return x, (g, y)
g = self.class_to_group[y]
return x, g

def get_subclass_label(self, index):
_, y = self.dataset[index]
return y

def __len__(self):
return len(self.dataset)

@staticmethod
def invert_class_groups(groups):
inverted_class_groups = {}
for g, group in enumerate(groups):
intersection = inverted_class_groups.keys() & group
if len(intersection) > 0:
raise ValueError("Class indices %s are present in multiple groups." % (str(intersection)))
inverted_class_groups.update({cls: g for cls in group})
return inverted_class_groups
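
Not part of this commit: a small usage sketch of the rewritten GroupLabelDataset. ToyDataset is hypothetical and returns plain int labels, since the class_to_group lookup in __getitem__ expects hashable label keys; it relies on the imports in the module above.

# Illustrative only: wrap a labelled dataset so __getitem__ yields group labels.
class ToyDataset(Dataset):
    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return torch.randn(1, 28, 28), idx % 10  # (input, original class label)


grouped = GroupLabelDataset(ToyDataset(), n_classes=10, n_groups=2, class_to_group="random", seed=27)
x, g = grouped[0]                  # g is the group id, here in {0, 1}
y = grouped.get_subclass_label(0)  # the original class label stays accessible
assert g == grouped.class_to_group[y]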
19 changes: 19 additions & 0 deletions tests/conftest.py
@@ -6,6 +6,7 @@
from torch.utils.data import TensorDataset

from tests.models import LeNet
from src.utils.datasets.group_label_dataset import GroupLabelDataset

MNIST_IMAGE_SIZE = 28
BATCH_SIZE = 124
@@ -57,6 +58,24 @@ def load_mnist_dataset():
return dataset


@pytest.fixture()
def load_mnist_labels():
y_batch = np.loadtxt("tests/assets/mnist_test_suite_1/mnist_y").astype(int)[:MINI_BATCH_SIZE]
return torch.tensor(y_batch).long()


@pytest.fixture()
def load_grouped_mnist_dataset():
x_batch = (
np.loadtxt("tests/assets/mnist_test_suite_1/mnist_x")
.astype(float)
.reshape((BATCH_SIZE, 1, MNIST_IMAGE_SIZE, MNIST_IMAGE_SIZE))
)[:MINI_BATCH_SIZE]
y_batch = np.loadtxt("tests/assets/mnist_test_suite_1/mnist_y").astype(int)[:MINI_BATCH_SIZE]
dataset = TensorDataset(torch.tensor(x_batch).float(), torch.tensor(y_batch).long())
return GroupLabelDataset(dataset, n_classes=10, n_groups=2, class_to_group="random", seed=27, device="cpu")


@pytest.fixture()
def load_mnist_dataloader():
"""Load a batch of MNIST digits: inputs and outputs to use for testing."""
35 changes: 34 additions & 1 deletion tests/metrics/test_localization_metrics.py
@@ -1,6 +1,9 @@
import pytest

from src.metrics.localization.identical_class import IdenticalClass
from src.metrics.localization.identical_class import (
IdenticalClass,
IdenticalSubclass,
)


@pytest.mark.localization_metrics
@@ -34,3 +37,33 @@ def test_identical_class_metrics(
# With a big test dataset, the probability of failing a truly random test
# should diminish.
assert score == expected_score


@pytest.mark.localization_metrics
@pytest.mark.parametrize(
"test_id, model, dataset, subclass_labels, test_labels, batch_size, explanations, expected_score",
[
(
"mnist",
"load_mnist_model",
"load_grouped_mnist_dataset",
"load_mnist_labels",
"load_mnist_test_labels_1",
8,
"load_mnist_explanations_1",
0.1,
),
],
)
def test_identical_subclass_metrics(
test_id, model, dataset, subclass_labels, test_labels, batch_size, explanations, expected_score, request
):
model = request.getfixturevalue(model)
test_labels = request.getfixturevalue(test_labels)
subclass_labels = request.getfixturevalue(subclass_labels)
dataset = request.getfixturevalue(dataset)
tda = request.getfixturevalue(explanations)
metric = IdenticalSubclass(model=model, train_dataset=dataset, subclass_labels=subclass_labels, device="cpu")
metric.update(test_labels=test_labels, explanations=tda)
score = metric.compute()
assert score == expected_score
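
Since the new test reuses the existing localization_metrics marker, it can be selected together with the other localization tests via `pytest -m localization_metrics` (assuming the marker is registered in the project's pytest configuration).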
