Commit

datasets small refactorings

dilyabareeva committed Apr 8, 2024
1 parent c68616c commit 945f752

Showing 13 changed files with 216 additions and 202 deletions.
46 changes: 0 additions & 46 deletions src/datasets/corrupt_label_dataset.py

This file was deleted.

35 changes: 0 additions & 35 deletions src/datasets/group_label_dataset.py

This file was deleted.

101 changes: 0 additions & 101 deletions src/datasets/mark_dataset.py

This file was deleted.

3 changes: 1 addition & 2 deletions src/explainers/feature_kernel_explainer.py
@@ -1,9 +1,8 @@
 import os
-from typing import List, Union
+from typing import Union

 import torch

-from datasets.activation_dataset import ActivationDataset
 from src.explainers.base import Explainer
 from utils.cache import ActivationsCache as AC

31 changes: 27 additions & 4 deletions src/utils/cache.py
@@ -8,8 +8,8 @@
 from torch import Tensor
 from torch.utils.data import DataLoader

-from src.datasets.activation_dataset import ActivationDataset
 from utils.common import _get_module_from_name
+from utils.datasets.activation_dataset import ActivationDataset


 class Cache:
@@ -20,16 +20,39 @@ class Cache:
     def __init__(self):
         pass

-    def save(self, **kwargs) -> None:
+    @staticmethod
+    def save(**kwargs) -> None:
         raise NotImplementedError

-    def load(self, **kwargs) -> Any:
+    @staticmethod
+    def load(**kwargs) -> Any:
         raise NotImplementedError

-    def exists(self, **kwargs) -> bool:
+    @staticmethod
+    def exists(**kwargs) -> bool:
         raise NotImplementedError


+class IndicesCache(Cache):
+    def __init__(self):
+        super().__init__()
+
+    @staticmethod
+    def save(path, file_id, indices) -> None:
+        file_path = os.path.join(path, file_id)
+        return torch.save(indices, file_path)
+
+    @staticmethod
+    def load(path, file_id) -> Tensor:
+        file_path = os.path.join(path, file_id)
+        return torch.load(file_path)
+
+    @staticmethod
+    def exists(path, file_id) -> bool:
+        file_path = os.path.join(path, file_id)
+        return os.path.isfile(file_path)
+
+
 class ActivationsCache(Cache):
     """
     Inspired by https://github.com/pytorch/captum/blob/master/captum/_utils/av.py.
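
For context, a minimal sketch of how the new IndicesCache could be used; the cache directory and file id below are hypothetical, mirroring the pattern CorruptLabelDataset uses later in this commit:

    import os

    import torch

    from utils.cache import IndicesCache as IC

    # Hypothetical cache directory; torch.save needs it to exist.
    os.makedirs("./datasets", exist_ok=True)

    indices = torch.tensor([3, 17, 42])
    if not IC.exists(path="./datasets", file_id="demo_ids"):
        IC.save(path="./datasets", file_id="demo_ids", indices=indices)
    loaded = IC.load(path="./datasets", file_id="demo_ids")  # tensor([ 3, 17, 42])
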
File renamed without changes.
File renamed without changes.
56 changes: 56 additions & 0 deletions src/utils/datasets/corrupt_label_dataset.py
@@ -0,0 +1,56 @@
import random

import torch
from torch.utils.data.dataset import Dataset

from utils.cache import IndicesCache as IC


class CorruptLabelDataset(Dataset):
    def __init__(
        self,
        dataset: Dataset,
        dataset_id: str,
        class_labels,
        classes,
        inverse_transform,
        cache_path="./datasets",
        p=0.3,
    ):
        super().__init__()
        self.dataset = dataset
        self.class_labels = class_labels
        self.classes = classes
        self.inverse_transform = inverse_transform
        self.p = p

        if IC.exists(path=cache_path, file_id=f"{dataset_id}_corrupt_ids"):
            self.corrupt_indices = IC.load(path=cache_path, file_id=f"{dataset_id}_corrupt_ids")
        else:
            self.corrupt_indices = self.get_corrupt_sample_ids()
            IC.save(path=cache_path, file_id=f"{dataset_id}_corrupt_ids", indices=self.corrupt_indices)

        self.corrupt_labels = [self.corrupt_label(self.dataset[i][1]) for i in self.corrupt_indices]
        IC.save(path=cache_path, file_id=f"{dataset_id}_corrupt_labels", indices=self.corrupt_labels)

    def get_corrupt_sample_ids(self):
        corrupt = torch.rand(len(self.dataset))
        return torch.where(corrupt < self.p)[0]

    def __getitem__(self, item):
        x, y_true = self.dataset[item]
        y = y_true
        if item in self.corrupt_indices:
            y = int(
                self.corrupt_labels[torch.squeeze((self.corrupt_indices == item).nonzero())]
            )  # TODO: not the most elegant solution
        return x, (y, y_true)

    def __len__(self):
        return len(self.dataset)

    def corrupt_label(self, y):
        classes = [cls for cls in self.classes if cls != y]
        random.seed(27)
        corrupted_class = random.choice(classes)
        return corrupted_class
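
A minimal sketch of wrapping a dataset with CorruptLabelDataset; the CIFAR-10 base dataset, identity inverse_transform, and dataset id are hypothetical stand-ins:

    import os

    from torchvision import datasets, transforms

    os.makedirs("./datasets", exist_ok=True)  # default cache_path must exist

    base = datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transforms.ToTensor()
    )
    corrupted = CorruptLabelDataset(
        dataset=base,
        dataset_id="cifar10_train",
        class_labels=list(range(10)),
        classes=list(range(10)),
        inverse_transform=lambda x: x,  # identity stand-in
        p=0.3,
    )
    x, (y, y_true) = corrupted[0]  # y differs from y_true only if sample 0 was corrupted

Note that corrupt_label reseeds the RNG with a fixed value on every call, so for a given true label the replacement label is the same across all corrupted samples.
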
32 changes: 32 additions & 0 deletions src/utils/datasets/group_label_dataset.py
@@ -0,0 +1,32 @@
from torch.utils.data.dataset import Dataset

CLASS_GROUP_BY = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]


class GroupLabelDataset(Dataset):
    def __init__(self, dataset, class_groups=None):
        self.dataset = dataset
        self.class_labels = [i for i in range(len(class_groups))]
        self.inverse_transform = dataset.inverse_transform
        if class_groups is None:
            class_groups = CLASS_GROUP_BY
        self.class_groups = class_groups
        self.inverted_class_groups = self.invert_class_groups(class_groups)

    def __getitem__(self, index):
        x, y = self.dataset[index]
        g = self.inverted_class_groups[y]
        return x, (g, y)

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def invert_class_groups(groups):
        inverted_class_groups = {}
        for g, group in enumerate(groups):
            intersection = inverted_class_groups.keys() & group
            if len(intersection) > 0:
                raise ValueError("Class indices %s are present in multiple groups." % (str(intersection)))
            inverted_class_groups.update({cls: g for cls in group})
        return inverted_class_groups
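
A minimal sketch of GroupLabelDataset in use. Two things follow from the committed code: __init__ reads len(class_groups) before the None default is applied, so passing class_groups explicitly avoids a TypeError, and the wrapped dataset must expose inverse_transform. The adapter below is a hypothetical stand-in:

    from torch.utils.data.dataset import Dataset


    class WithInverseTransform(Dataset):
        # Hypothetical adapter supplying the inverse_transform attribute
        # GroupLabelDataset expects on the wrapped dataset.
        def __init__(self, dataset):
            self.dataset = dataset
            self.inverse_transform = lambda x: x  # identity stand-in

        def __getitem__(self, index):
            return self.dataset[index]

        def __len__(self):
            return len(self.dataset)


    grouped = GroupLabelDataset(
        WithInverseTransform(base),  # base: any (x, y) dataset with labels 0-9
        class_groups=[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]],
    )
    x, (g, y) = grouped[0]  # g == y // 2 under this pairwise grouping
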
@@ -1,15 +1,12 @@
+import torch
 from torch.utils.data.dataset import Dataset


-class RestrictedDataset(Dataset):
-    def __init__(self, dataset, indices, return_indices=False):
+class IndexedSubset(Dataset):
+    def __init__(self, dataset: torch.utils.data.Dataset, indices, return_indices=False):
         self.dataset = dataset
         self.indices = indices
         self.return_indices = return_indices
-        if hasattr(dataset, "name"):
-            self.name = dataset.name
-        else:
-            self.name = dataset.dataset.name

     def __len__(self):
         return len(self.indices)
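
Finally, a minimal sketch of the renamed IndexedSubset. Since __getitem__ falls below the collapsed part of this hunk, the sketch only exercises construction and __len__; the toy tensors are hypothetical:

    import torch
    from torch.utils.data import TensorDataset

    xs = torch.randn(10, 3)
    ys = torch.arange(10)
    subset = IndexedSubset(TensorDataset(xs, ys), indices=[0, 4, 7], return_indices=True)
    print(len(subset))  # 3
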