Skip to content

Commit

Permalink
dataset reorganization
Browse files Browse the repository at this point in the history
  • Loading branch information
dilyabareeva committed Apr 30, 2024
1 parent 4479fe0 commit 6991d7c
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 43 deletions.
2 changes: 1 addition & 1 deletion src/utils/datasets/corrupt_label_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch.utils.data.dataset import Dataset

from utils.cache import IndicesCache as IC
from utils.cache import TensorCache as IC


class CorruptLabelDataset(Dataset):
Expand Down
1 change: 0 additions & 1 deletion src/utils/datasets/dataset_wrapper.py

This file was deleted.

32 changes: 32 additions & 0 deletions src/utils/datasets/label_transform_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from torch.utils.data.dataset import Dataset

# Default grouping of the ten original labels (0-9) into five pairs.
CLASS_GROUP_BY = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
# Default label -> group-id mapping; each label maps to the position of its
# pair in CLASS_GROUP_BY (0 and 1 -> 0, 2 and 3 -> 1, ...).
TRANSFORM_DICT = {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2, 6: 3, 7: 3, 8: 4, 9: 4}


class LabelTransformDataset(Dataset):
    """
    Wrap a dataset and attach a transformed label to every sample.

    Meant to replace: GroupLabelDataset, CorruptLabelDataset

    Indexing returns ``(x, (g, y))`` where ``(x, y)`` is the item of the
    wrapped dataset and ``g = transform_dict[y]``.

    :param dataset: wrapped dataset yielding ``(x, y)`` pairs. It must expose
        an ``inverse_transform`` attribute, which is re-exported unchanged.
    :param transform_dict: mapping from original label to transformed label.
        When ``None``, the module-level ``TRANSFORM_DICT`` is used.
    """

    def __init__(self, dataset, transform_dict=None):
        self.dataset = dataset
        self.inverse_transform = dataset.inverse_transform
        if transform_dict is None:
            transform_dict = TRANSFORM_DICT
        self.transform_dict = transform_dict
        # Transformed label -> list of original labels mapped onto it.
        self.inv_transform_dict = self.invert_labels_dict(transform_dict)
        self.class_labels = list(self.inv_transform_dict.keys())

    def __getitem__(self, index):
        """Return ``(x, (transformed_label, original_label))`` for ``index``."""
        x, y = self.dataset[index]
        # Tensors hash by identity, so a 0-dim tensor label would never match
        # an int key. Look up by the plain Python scalar, but hand the label
        # back to the caller exactly as the wrapped dataset produced it.
        key = y.item() if hasattr(y, "item") else y
        g = self.transform_dict[key]
        return x, (g, y)

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    @staticmethod
    def invert_labels_dict(labels_dict):
        """
        Invert a label mapping.

        :param labels_dict: mapping ``original_label -> transformed_label``.
        :return: dict ``transformed_label -> [original_label, ...]``; every
            list keeps the iteration order of ``labels_dict``.
        """
        # Seed the keys from the value set (same key order as before), then
        # fill all lists in one pass instead of rescanning per value.
        inverted = {value: [] for value in set(labels_dict.values())}
        for key, value in labels_dict.items():
            inverted[value].append(key)
        return inverted
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import torch
from torch.utils.data.dataset import Dataset

from utils.cache import IndicesCache as IC
from utils.cache import TensorCache as IC
from utils.transforms import mark_image_contour_and_square


class MarkDataset(Dataset):
class SampleTransformDataset(Dataset):
def __init__(
self,
dataset: torch.utils.data.Dataset,
Expand All @@ -30,7 +31,7 @@ def __init__(
if mark_fn is not None:
self.mark_image = mark_fn
else:
self.mark_image = self.mark_image_contour_and_square
self.mark_image = mark_image_contour_and_square

if IC.exists(path=cache_path, file_id=f"{dataset_id}_mark_ids"):
self.mark_indices = IC.load(path="./datasets", file_id=f"{dataset_id}_mark_ids")
Expand All @@ -55,41 +56,3 @@ def get_mark_sample_ids(self):
corrupt = torch.rand(len(self.dataset))
indices = torch.where((corrupt < self.mark_prob) & (in_cls))[0]
return torch.tensor(indices, dtype=torch.int)

# Marks a 2-pixel frame along all four image edges: inside the frame the
# first channel is set to 1 and every remaining channel is set to 0.
# NOTE(review): the result is assigned back into `x`, so the tensor passed in
# by the caller is modified as well.
# Assumes `x` is a channel-first (C, H, W) tensor -- TODO confirm with callers.
@staticmethod
def mark_image_contour(x):
# TODO: make contour, middle square and combined masks a constant somewhere else
mask = torch.zeros_like(x[0])
mask[:2, :] = 1.0
mask[-2:, :] = 1.0
mask[:, -2:] = 1.0
mask[:, :2] = 1.0
# First channel: 1 where the mask is set, original value elsewhere.
x[0] = torch.ones_like(x[0]) * mask + x[0] * (1 - mask)
if x.shape[0] > 1:
# Remaining channels: 0 where the mask is set, original value elsewhere.
x[1:] = torch.zeros_like(x[1:]) * mask + x[1:] * (1 - mask)

# Returned as a NumPy array with axes reordered to (1, 2, 0).
return x.numpy().transpose(1, 2, 0)

# Marks an 8x8 square around the midpoint: inside the square the first channel
# is set to 1 and every remaining channel is set to 0.
# NOTE(review): the midpoint is derived from the last dimension only and is
# used for both axes; the result is assigned back into `x` in place.
# Assumes `x` is a channel-first (C, H, W) tensor -- TODO confirm with callers.
@staticmethod
def mark_image_middle_square(x):
mask = torch.zeros_like(x[0])
mid = int(x.shape[-1] / 2)
mask[(mid - 4) : (mid + 4), (mid - 4) : (mid + 4)] = 1.0
# First channel: 1 where the mask is set, original value elsewhere.
x[0] = torch.ones_like(x[0]) * mask + x[0] * (1 - mask)
if x.shape[0] > 1:
# Remaining channels: 0 where the mask is set, original value elsewhere.
x[1:] = torch.zeros_like(x[1:]) * mask + x[1:] * (1 - mask)
# Returned as a NumPy array with axes reordered to (1, 2, 0).
return x.numpy().transpose(1, 2, 0)

# Marks both a 2-pixel frame along all four edges and a 6x6 square around the
# midpoint: in the marked area the first channel is set to 1 and every
# remaining channel is set to 0.
# NOTE(review): the midpoint is derived from the last dimension only and is
# used for both axes; the result is assigned back into `x` in place.
# Assumes `x` is a channel-first (C, H, W) tensor -- TODO confirm with callers.
@staticmethod
def mark_image_contour_and_square(x):
mask = torch.zeros_like(x[0])
mid = int(x.shape[-1] / 2)
mask[mid - 3 : mid + 3, mid - 3 : mid + 3] = 1.0
mask[:2, :] = 1.0
mask[-2:, :] = 1.0
mask[:, -2:] = 1.0
mask[:, :2] = 1.0
# First channel: 1 where the mask is set, original value elsewhere.
x[0] = torch.ones_like(x[0]) * mask + x[0] * (1 - mask)
if x.shape[0] > 1:
# Remaining channels: 0 where the mask is set, original value elsewhere.
x[1:] = torch.zeros_like(x[1:]) * mask + x[1:] * (1 - mask)
# Returned as a NumPy array with axes reordered to (1, 2, 0).
return x.numpy().transpose(1, 2, 0)
39 changes: 39 additions & 0 deletions src/utils/transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import torch


def mark_image_contour_and_square(x):
    """
    Mark an image with a 2-pixel border frame and a 6x6 centre square.

    Inside the marked area the first channel is set to 1 and every other
    channel is set to 0; pixels outside the marked area are left untouched.

    :param x: image tensor, indexed channel-first as ``(C, H, W)``.
    :return: marked image as a NumPy array with axes reordered to
        ``(H, W, C)``.
    """
    # Work on a copy: the original wrote the mark into the caller's tensor
    # (and into any storage that tensor is a view of).
    x = x.clone()
    height, width = x.shape[-2:]
    # Separate midpoints keep the square centred on non-square images.
    row_mid, col_mid = height // 2, width // 2

    marked = torch.zeros((height, width), dtype=torch.bool, device=x.device)
    # Clamp the start at 0 so a negative index cannot wrap on tiny images.
    marked[max(row_mid - 3, 0) : row_mid + 3, max(col_mid - 3, 0) : col_mid + 3] = True
    marked[:2, :] = True
    marked[-2:, :] = True
    marked[:, :2] = True
    marked[:, -2:] = True

    x[0][marked] = 1.0
    if x.shape[0] > 1:
        x[1:][:, marked] = 0.0
    return x.numpy().transpose(1, 2, 0)


def mark_image_middle_square(x):
    """
    Mark an image with an 8x8 square at its centre.

    Inside the square the first channel is set to 1 and every other channel
    is set to 0; pixels outside the square are left untouched.

    :param x: image tensor, indexed channel-first as ``(C, H, W)``.
    :return: marked image as a NumPy array with axes reordered to
        ``(H, W, C)``.
    """
    # Work on a copy: the original wrote the mark into the caller's tensor
    # (and into any storage that tensor is a view of).
    x = x.clone()
    height, width = x.shape[-2:]
    # Separate midpoints keep the square centred on non-square images.
    row_mid, col_mid = height // 2, width // 2

    marked = torch.zeros((height, width), dtype=torch.bool, device=x.device)
    # Clamp the start at 0 so a negative index cannot wrap on tiny images.
    marked[max(row_mid - 4, 0) : row_mid + 4, max(col_mid - 4, 0) : col_mid + 4] = True

    x[0][marked] = 1.0
    if x.shape[0] > 1:
        x[1:][:, marked] = 0.0
    return x.numpy().transpose(1, 2, 0)


def mark_image_contour(x):
    """
    Mark an image with a 2-pixel frame along all four edges.

    Inside the frame the first channel is set to 1 and every other channel
    is set to 0; pixels outside the frame are left untouched.

    :param x: image tensor, indexed channel-first as ``(C, H, W)``.
    :return: marked image as a NumPy array with axes reordered to
        ``(H, W, C)``.
    """
    # TODO: make contour, middle square and combined masks a constant somewhere else
    # Work on a copy: the original wrote the mark into the caller's tensor
    # (and into any storage that tensor is a view of).
    x = x.clone()

    frame = torch.zeros(tuple(x.shape[-2:]), dtype=torch.bool, device=x.device)
    frame[:2, :] = True
    frame[-2:, :] = True
    frame[:, :2] = True
    frame[:, -2:] = True

    x[0][frame] = 1.0
    if x.shape[0] > 1:
        x[1:][:, frame] = 0.0

    return x.numpy().transpose(1, 2, 0)

0 comments on commit 6991d7c

Please sign in to comment.