Commit d9891e6
downstream base + identical subclass
1 parent 684e26c
Showing 6 changed files with 160 additions and 23 deletions.
@@ -0,0 +1,34 @@
from abc import ABC, abstractmethod

import torch


class DownstreamTaskEval(ABC):
    def __init__(self, device: str = "cpu", *args, **kwargs):
        """
        I think it would be nice to pass a general recipe for the downstream task construction here.
        For example, we could pass
        - a dataset constructor that generates the training dataset from the original
          dataset (either by modifying the labels, the data, or removing some samples);
        - a metric that generates the final score: it could be either a Metric object from our library,
          or maybe an accuracy comparison.
        :param device: device on which the evaluation runs
        :param args:
        :param kwargs:
        """
        self.device = device

    @abstractmethod
    def evaluate(
        self,
        model: torch.nn.Module,
        dataset: torch.utils.data.Dataset,
        *args,
        **kwargs,
    ):
        """
        Evaluate the given model on the downstream task defined by this class.
        """

        raise NotImplementedError
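To make the intended interface concrete, here is a minimal sketch of a subclass, assuming plain classification accuracy as the downstream metric and a standard DataLoader loop; the class name AccuracyEval and the batch_size parameter are illustrative assumptions, not part of this commit.

```python
# Hypothetical subclass for illustration only; accuracy as the downstream
# metric and the DataLoader settings are assumptions, not part of the commit.
import torch
from torch.utils.data import DataLoader, Dataset


class AccuracyEval(DownstreamTaskEval):
    def evaluate(self, model: torch.nn.Module, dataset: Dataset, batch_size: int = 64, *args, **kwargs):
        """Return the classification accuracy of `model` on `dataset`."""
        model.to(self.device).eval()
        loader = DataLoader(dataset, batch_size=batch_size)
        correct, total = 0, 0
        with torch.no_grad():
            for x, y in loader:
                preds = model(x.to(self.device)).argmax(dim=1)
                correct += (preds == y.to(self.device)).sum().item()
                total += y.numel()
        return correct / total
```

An ABC of this shape keeps device handling in the base class, so each downstream task only has to define how its score is produced.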
Empty file.
@@ -1,32 +1,50 @@
-from torch.utils.data.dataset import Dataset
+import random
+from typing import Dict, Literal, Optional, Union
 
-CLASS_GROUP_BY = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
+import torch
+from torch.utils.data import Dataset
+
+ClassToGroupLiterals = Literal["random"]
 
 
-class GroupLabelDataset(Dataset):
-    def __init__(self, dataset, class_groups=None):
+class GroupLabelDataset:
+    def __init__(
+        self,
+        dataset: Dataset,
+        n_classes: int = 10,
+        n_groups: int = 2,
+        class_to_group: Union[ClassToGroupLiterals, Dict[int, int]] = "random",
+        seed: Optional[int] = 27,
+        device: str = "cpu",
+    ):
         self.dataset = dataset
-        self.class_labels = [i for i in range(len(class_groups))]
-        self.inverse_transform = dataset.inverse_transform
-        if class_groups is None:
-            class_groups = CLASS_GROUP_BY
-        self.class_groups = class_groups
-        self.inverted_class_groups = self.invert_class_groups(class_groups)
+        self.n_classes = n_classes
+        self.classes = list(range(n_classes))
+        self.n_groups = n_groups
+        self.groups = list(range(n_groups))
+        self.generator = torch.Generator(device=device)
+        if class_to_group == "random":
+            # create a dictionary that assigns each class to a group
+            random.seed(seed)
+            self.class_to_group = {i: random.choice(self.groups) for i in range(n_classes)}
+        elif isinstance(class_to_group, dict):
+            self.validate_class_to_group(class_to_group)
+            self.class_to_group = class_to_group
+        else:
+            raise ValueError(f"Invalid class_to_group value: {class_to_group}")
+
+    def validate_class_to_group(self, class_to_group):
+        assert len(class_to_group) == self.n_classes
+        assert all(g in self.groups for g in class_to_group.values())
 
     def __getitem__(self, index):
         x, y = self.dataset[index]
-        g = self.inverted_class_groups[y]
-        return x, (g, y)
+        g = self.class_to_group[y]
+        return x, g
 
     def get_subclass_label(self, index):
         _, y = self.dataset[index]
         return y
 
     def __len__(self):
         return len(self.dataset)
-
-    @staticmethod
-    def invert_class_groups(groups):
-        inverted_class_groups = {}
-        for g, group in enumerate(groups):
-            intersection = inverted_class_groups.keys() & group
-            if len(intersection) > 0:
-                raise ValueError("Class indices %s are present in multiple groups." % (str(intersection)))
-            inverted_class_groups.update({cls: g for cls in group})
-        return inverted_class_groups
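A hedged usage sketch of the new API follows; torchvision's MNIST and the ToTensor transform are assumptions used purely for illustration, since any classification Dataset with integer labels would work the same way.

```python
# Illustrative usage only; the MNIST dataset is an assumption, not part of the commit.
from torchvision import datasets, transforms

base = datasets.MNIST(root="./data", train=True, download=True,
                      transform=transforms.ToTensor())
grouped = GroupLabelDataset(base, n_classes=10, n_groups=2, class_to_group="random", seed=27)

x, g = grouped[0]                      # g is the coarse group label in {0, 1}
y = grouped.get_subclass_label(0)      # the original class label stays recoverable
print(grouped.class_to_group)          # e.g. {0: 1, 1: 0, ..., 9: 1}
```

Since the random mapping is drawn via random.seed(seed), the same seed reproduces the same class-to-group assignment across runs.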