From 7a983f70665425bf4a62550edfe361ace2abc56f Mon Sep 17 00:00:00 2001
From: Dilyara Bareeva
Date: Wed, 26 Jun 2024 16:20:02 +0200
Subject: [PATCH] toy benchmark introduction

---
 src/downstream_tasks/base.py              |  34 ---
 .../subclass_identification.py            | 133 ----------
 .../__init__.py                           |   0
 src/toy_benchmarks/base.py                |  61 +++++
 src/toy_benchmarks/subclass_detection.py  | 215 ++++++++++++++++++
 src/utils/datasets/group_label_dataset.py |  41 +++-
 .../__init__.py                           |   0
 .../test_subclass_identification.py       |  28 +--
 8 files changed, 320 insertions(+), 192 deletions(-)
 delete mode 100644 src/downstream_tasks/base.py
 delete mode 100644 src/downstream_tasks/subclass_identification.py
 rename src/{downstream_tasks => toy_benchmarks}/__init__.py (100%)
 create mode 100644 src/toy_benchmarks/base.py
 create mode 100644 src/toy_benchmarks/subclass_detection.py
 rename tests/{downstream_tasks => toy_benchmarks}/__init__.py (100%)
 rename tests/{downstream_tasks => toy_benchmarks}/test_subclass_identification.py (84%)

diff --git a/src/downstream_tasks/base.py b/src/downstream_tasks/base.py
deleted file mode 100644
index 5b07718a..00000000
--- a/src/downstream_tasks/base.py
+++ /dev/null
@@ -1,34 +0,0 @@
-from abc import ABC, abstractmethod
-
-import torch
-
-
-class DownstreamTaskEval(ABC):
-    def __init__(self, device: str = "cpu", *args, **kwargs):
-        """
-        I think here it would be nice to pass a general receipt for the downstream task construction.
-        For example, we could pass
-        - a dataset constructor that generates the dataset for training from the original
-        dataset (either by modifying the labels, the data, or removing some samples);
-        - a metric that generates the final score: it could be either a Metric object from our library, or maybe
-        accuracy comparison.
-
-        :param device:
-        :param args:
-        :param kwargs:
-        """
-        self.device = device
-
-    @abstractmethod
-    def evaluate(
-        self,
-        model: torch.nn.Module,
-        dataset: torch.utils.data.Dataset,
-        *args,
-        **kwargs,
-    ):
-        """
-        Used to update the metric with new data.
- """ - - raise NotImplementedError diff --git a/src/downstream_tasks/subclass_identification.py b/src/downstream_tasks/subclass_identification.py deleted file mode 100644 index c2d542a2..00000000 --- a/src/downstream_tasks/subclass_identification.py +++ /dev/null @@ -1,133 +0,0 @@ -import os -from typing import Callable, Dict, Optional, Union - -import lightning as L -import torch - -from src.explainers.functional import ExplainFunc -from src.explainers.wrappers.captum_influence import captum_similarity_explain -from src.metrics.localization.identical_class import IdenticalClass -from src.utils.datasets.group_label_dataset import ( - ClassToGroupLiterals, - GroupLabelDataset, -) -from src.utils.training.trainer import BaseTrainer, Trainer - - -class SubclassIdentification: - def __init__( - self, - model: torch.nn.Module, - optimizer: Callable, - lr: float, - criterion: torch.nn.modules.loss._Loss, - scheduler: Optional[Callable] = None, - optimizer_kwargs: Optional[dict] = None, - scheduler_kwargs: Optional[dict] = None, - device: str = "cpu", - *args, - **kwargs, - ): - self.device = device - self.trainer: Optional[BaseTrainer] = Trainer.from_arguments( - model=model, - optimizer=optimizer, - lr=lr, - scheduler=scheduler, - criterion=criterion, - optimizer_kwargs=optimizer_kwargs, - scheduler_kwargs=scheduler_kwargs, - ) - - @classmethod - def from_pl_module(cls, model: torch.nn.Module, pl_module: L.LightningModule, device: str = "cpu", *args, **kwargs): - obj = cls.__new__(cls) - super(SubclassIdentification, obj).__init__() - obj.device = device - obj.trainer = Trainer.from_lightning_module(model, pl_module) - return obj - - @classmethod - def from_trainer(cls, trainer: BaseTrainer, device: str = "cpu", *args, **kwargs): - obj = cls.__new__(cls) - super(SubclassIdentification, obj).__init__() - if isinstance(trainer, BaseTrainer): - obj.trainer = trainer - obj.device = device - else: - raise ValueError("trainer must be an instance of BaseTrainer") - return obj - - def evaluate( - self, - train_dataset: torch.utils.data.Dataset, - val_dataset: Optional[torch.utils.data.Dataset] = None, - n_classes: int = 10, - n_groups: int = 2, - class_to_group: Union[ClassToGroupLiterals, Dict[int, int]] = "random", - explain_fn: ExplainFunc = captum_similarity_explain, - explain_kwargs: Optional[dict] = None, - trainer_kwargs: Optional[dict] = None, - cache_dir: str = "./cache", - model_id: str = "default_model_id", - run_id: str = "default_subclass_identification", - seed: Optional[int] = 27, - batch_size: int = 8, - device: str = "cpu", - *args, - **kwargs, - ): - if self.trainer is None: - raise ValueError( - "Trainer not initialized. 
-                "init_trainer_from_train_arguments"
-            )
-        if explain_kwargs is None:
-            explain_kwargs = {}
-        if trainer_kwargs is None:
-            trainer_kwargs = {}
-
-        grouped_dataset = GroupLabelDataset(
-            dataset=train_dataset,
-            n_classes=n_classes,
-            n_groups=n_groups,
-            class_to_group=class_to_group,
-            seed=seed,
-        )
-        grouped_train_loader = torch.utils.data.DataLoader(grouped_dataset, batch_size=batch_size)
-        original_train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
-        if val_dataset:
-            grouped_val_dataset = GroupLabelDataset(
-                dataset=train_dataset,
-                n_classes=n_classes,
-                n_groups=n_groups,
-                class_to_group=grouped_dataset.class_to_group,
-                seed=seed,
-            )
-            val_loader: Optional[torch.utils.data.DataLoader] = torch.utils.data.DataLoader(
-                grouped_val_dataset, batch_size=batch_size
-            )
-        else:
-            val_loader = None
-
-        model = self.trainer.fit(
-            train_loader=grouped_train_loader,
-            val_loader=val_loader,
-            trainer_kwargs=trainer_kwargs,
-        )
-        metric = IdenticalClass(model=model, train_dataset=train_dataset, device="cpu")
-
-        for input, labels in original_train_loader:
-            input, labels = input.to(device), labels.to(device)
-            explanations = explain_fn(
-                model=model,
-                model_id=model_id,
-                cache_dir=os.path.join(cache_dir, run_id),
-                train_dataset=train_dataset,
-                test_tensor=input,
-                device=device,
-                **explain_kwargs,
-            )
-            metric.update(labels, explanations)
-
-        return metric.compute()
diff --git a/src/downstream_tasks/__init__.py b/src/toy_benchmarks/__init__.py
similarity index 100%
rename from src/downstream_tasks/__init__.py
rename to src/toy_benchmarks/__init__.py
diff --git a/src/toy_benchmarks/base.py b/src/toy_benchmarks/base.py
new file mode 100644
index 00000000..b3d7b530
--- /dev/null
+++ b/src/toy_benchmarks/base.py
@@ -0,0 +1,61 @@
+from abc import ABC, abstractmethod
+
+
+class ToyBenchmark(ABC):
+    def __init__(self, device: str = "cpu", *args, **kwargs):
+        """
+        I think here it would be nice to pass a general recipe for the benchmark construction.
+        For example, we could pass
+        - a dataset constructor that generates the dataset for training from the original
+        dataset (either by modifying the labels, the data, or removing some samples);
+        - a metric that generates the final score: it could be either a Metric object from our library, or maybe
+        accuracy comparison.
+
+        :param device:
+        :param args:
+        :param kwargs:
+        """
+        self.device = device
+
+    @classmethod
+    @abstractmethod
+    def generate(cls, *args, **kwargs):
+        """
+        This method should generate all the benchmark components and persist them in the instance.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    @abstractmethod
+    def load(cls, path: str, *args, **kwargs):
+        """
+        This method should load the benchmark components from a file and persist them in the instance.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    @abstractmethod
+    def assemble(cls, *args, **kwargs):
+        """
+        This method should assemble the benchmark components from arguments and persist them in the instance.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def save(self, *args, **kwargs):
+        """
+        This method should save the benchmark components to a file/folder.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def evaluate(
+        self,
+        *args,
+        **kwargs,
+    ):
+        """
+        This method should run the benchmark evaluation and return the resulting score.
+ """ + + raise NotImplementedError diff --git a/src/toy_benchmarks/subclass_detection.py b/src/toy_benchmarks/subclass_detection.py new file mode 100644 index 00000000..66ce9a08 --- /dev/null +++ b/src/toy_benchmarks/subclass_detection.py @@ -0,0 +1,215 @@ +import os +from typing import Callable, Dict, Optional, Union + +import torch + +from src.explainers.functional import ExplainFunc +from src.explainers.wrappers.captum_influence import captum_similarity_explain +from src.metrics.localization.identical_class import IdenticalClass +from src.toy_benchmarks.base import ToyBenchmark +from src.utils.datasets.group_label_dataset import ( + ClassToGroupLiterals, + GroupLabelDataset, +) +from src.utils.training.trainer import Trainer + + +class SubclassDetection(ToyBenchmark): + + def __init__( + self, + device: str = "cpu", + *args, + **kwargs, + ): + super().__init__(device=device) + self.trainer = None + self.model = None + self.train_dataset = None + self.dataset_transform = None + self.grouped_train_dl = None + self.original_train_dl = None + self.bench_state = None + + @classmethod + def generate( + cls, + model: torch.nn.Module, + train_dataset: torch.utils.data.Dataset, + optimizer: Callable, + lr: float, + criterion: torch.nn.modules.loss._Loss, + scheduler: Optional[Callable] = None, + optimizer_kwargs: Optional[dict] = None, + scheduler_kwargs: Optional[dict] = None, + val_dataset: Optional[torch.utils.data.Dataset] = None, + n_classes: int = 10, + n_groups: int = 2, + class_to_group: Union[ClassToGroupLiterals, Dict[int, int]] = "random", + trainer_kwargs: Optional[dict] = None, + seed: Optional[int] = 27, + batch_size: int = 8, + device: str = "cpu", + *args, + **kwargs, + ): + """ + This method should generate all the benchmark components and persist them in the instance. + """ + + obj = cls(device=device) + + obj.trainer = Trainer.from_arguments( + model=model, + optimizer=optimizer, + lr=lr, + scheduler=scheduler, + criterion=criterion, + optimizer_kwargs=optimizer_kwargs, + scheduler_kwargs=scheduler_kwargs, + ) + + if obj.trainer is None: + raise ValueError( + "Trainer not initialized. Please initialize trainer using init_trainer_from_lightning_module or " + "init_trainer_from_train_arguments" + ) + + grouped_dataset = GroupLabelDataset( + dataset=train_dataset, + n_classes=n_classes, + n_groups=n_groups, + class_to_group=class_to_group, + seed=seed, + ) + obj.class_to_group = grouped_dataset.class_to_group + obj.grouped_train_dl = torch.utils.data.DataLoader(grouped_dataset, batch_size=batch_size) + obj.original_train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size) + if val_dataset: + grouped_val_dataset = GroupLabelDataset( + dataset=train_dataset, + class_to_group=obj.class_to_group, + ) + obj.val_loader = torch.utils.data.DataLoader(grouped_val_dataset, batch_size=batch_size) + else: + obj.val_loader = None + + obj.model = obj.trainer.fit( + train_loader=obj.grouped_train_dl, + val_loader=obj.val_loader, + trainer_kwargs=trainer_kwargs, + ) + + obj.bench_state = { + "model": obj.model, + "train_dataset": obj.train_dataset, # ok this probably won't work, but that's the idea + "class_to_group": class_to_group, + } + + return obj + + @classmethod + def load(cls, path: str, device: str = "cpu", batch_size: int = 8, *args, **kwargs): + """ + This method should load the benchmark components from a file and persist them in the instance. 
+ """ + obj = cls(device=device) + obj.bench_state = torch.load(path) + obj.model = obj.bench_state["model"] + obj.train_dataset = obj.bench_state["train_dataset"] + obj.class_to_group = obj.bench_state["class_to_group"] + + grouped_dataset = GroupLabelDataset( + dataset=obj.train_dataset, + class_to_group=obj.class_to_group, + ) + obj.grouped_train_dl = torch.utils.data.DataLoader(grouped_dataset, batch_size=batch_size) + obj.original_train_dl = torch.utils.data.DataLoader(obj.train_dataset, batch_size=batch_size) + + @classmethod + def assemble( + cls, + model: torch.nn.Module, + train_dataset: torch.utils.data.Dataset, + class_to_group: Dict[int, int], # TODO: type specification + batch_size: int = 8, + device: str = "cpu", + *args, + **kwargs, + ): + """ + This method should assemble the benchmark components from arguments and persist them in the instance. + """ + obj = cls(device=device) + obj.model = model + obj.train_dataset = train_dataset + obj.class_to_group = class_to_group + + grouped_dataset = GroupLabelDataset( + dataset=train_dataset, + class_to_group=class_to_group, + ) + obj.grouped_train_dl = torch.utils.data.DataLoader(grouped_dataset, batch_size=batch_size) + obj.original_train_dl = torch.utils.data.DataLoader(obj.train_dataset, batch_size=batch_size) + + def save(self, path: str, *args, **kwargs): + """ + This method should save the benchmark components to a file/folder. + """ + torch.save(self.bench_state, path) + + """ + @classmethod + def generate_from_pl(cls, model: torch.nn.Module, pl_module: L.LightningModule, device: str = "cpu", *args, **kwargs): + obj = cls.__new__(cls) + super(SubclassDetection, obj).__init__() + obj.device = device + obj.trainer = Trainer.from_lightning_module(model, pl_module) + return obj + + @classmethod + def generate_from_trainer(cls, trainer: BaseTrainer, device: str = "cpu", *args, **kwargs): + obj = cls.__new__(cls) + super(SubclassDetection, obj).__init__() + if isinstance(trainer, BaseTrainer): + obj.trainer = trainer + obj.device = device + else: + raise ValueError("trainer must be an instance of BaseTrainer") + return obj + """ + + def evaluate( + self, + expl_dataset: torch.utils.data.Dataset, + explain_fn: ExplainFunc = captum_similarity_explain, + explain_kwargs: Optional[dict] = None, + cache_dir: str = "./cache", + model_id: str = "default_model_id", + batch_size: int = 8, + device: str = "cpu", + *args, + **kwargs, + ): + grouped_expl_ds = GroupLabelDataset( + dataset=expl_dataset, + class_to_group=self.class_to_group, + ) # TODO: change to class_to_group + expl_dl = torch.utils.data.DataLoader(grouped_expl_ds, batch_size=batch_size) + + metric = IdenticalClass(model=self.model, train_dataset=self.train_dataset, device="cpu") + + for input, labels in expl_dl: + input, labels = input.to(device), labels.to(device) + explanations = explain_fn( + model=self.model, + model_id=model_id, + cache_dir=os.path.join(cache_dir), + train_dataset=self.train_dataset, + test_tensor=input, + device=device, + **explain_kwargs, + ) + metric.update(labels, explanations) + + return metric.compute() diff --git a/src/utils/datasets/group_label_dataset.py b/src/utils/datasets/group_label_dataset.py index d29fef9c..8597e2c8 100644 --- a/src/utils/datasets/group_label_dataset.py +++ b/src/utils/datasets/group_label_dataset.py @@ -1,4 +1,5 @@ import random +import warnings from typing import Dict, Literal, Optional, Sized, Union import torch @@ -8,34 +9,56 @@ class GroupLabelDataset(Dataset): + def __init__( self, dataset: Dataset, - n_classes: int = 
-        n_groups: int = 2,
+        n_classes: Optional[int] = None,
+        n_groups: Optional[int] = None,
         class_to_group: Union[ClassToGroupLiterals, Dict[int, int]] = "random",
         seed: Optional[int] = 27,
         device: str = "cpu",
     ):
         self.dataset = dataset
         self.n_classes = n_classes
-        self.classes = list(range(n_classes))
         self.n_groups = n_groups
-        self.groups = list(range(n_groups))
         self.generator = torch.Generator(device=device)
+
         if class_to_group == "random":
+
+            if (n_classes is None) or (n_groups is None):
+                raise ValueError("n_classes and n_groups must be specified when class_to_group is 'random'")
+
             # create a dictionary of class groups that assigns each class to a group
             random.seed(seed)
-            self.class_to_group = {i: random.choice(self.groups) for i in range(n_classes)}
+
+            self.class_to_group = {i: random.randrange(self.n_groups) for i in range(self.n_classes)}
+
         elif isinstance(class_to_group, dict):
-            self.validate_class_to_group(class_to_group)
+
+            if (n_classes is not None) or (n_groups is not None):
+                warnings.warn("Class-to-group assignment is used. n_classes or n_groups parameters are ignored.")
+
             self.class_to_group = class_to_group
+            self.n_classes = len(self.class_to_group)
+            self.n_groups = len(set(self.class_to_group.values()))
+
         else:
+
             raise ValueError(f"Invalid class_to_group value: {class_to_group}")
 
-    def validate_class_to_group(self, class_to_group):
-        assert len(class_to_group) == self.n_classes
-        assert all([g in self.groups for g in self.class_to_group.values()])
+        self.classes = list(range(self.n_classes))
+        self.groups = list(range(self.n_groups))
+        self.validate_class_to_group()
+
+    def validate_class_to_group(self):
+        if not len(self.class_to_group) == self.n_classes:
+            raise ValueError(
+                f"Length of class_to_group dictionary ({len(self.class_to_group)}) "
+                f"does not match number of classes ({self.n_classes})"
+            )
+        if not all([g in self.groups for g in self.class_to_group.values()]):
+            raise ValueError(f"Invalid group assignment in class_to_group: {self.class_to_group.values()}")
 
     def __getitem__(self, index):
         x, y = self.dataset[index]
diff --git a/tests/downstream_tasks/__init__.py b/tests/toy_benchmarks/__init__.py
similarity index 100%
rename from tests/downstream_tasks/__init__.py
rename to tests/toy_benchmarks/__init__.py
diff --git a/tests/downstream_tasks/test_subclass_identification.py b/tests/toy_benchmarks/test_subclass_identification.py
similarity index 84%
rename from tests/downstream_tasks/test_subclass_identification.py
rename to tests/toy_benchmarks/test_subclass_identification.py
index cf045da9..7d57396d 100644
--- a/tests/downstream_tasks/test_subclass_identification.py
+++ b/tests/toy_benchmarks/test_subclass_identification.py
@@ -1,14 +1,11 @@
-import os
-import shutil
-
 import pytest
 
-from src.downstream_tasks.subclass_identification import SubclassIdentification
 from src.explainers.wrappers.captum_influence import captum_similarity_explain
+from src.toy_benchmarks.subclass_detection import SubclassDetection
 from src.utils.functions.similarities import cosine_similarity
 
 
-@pytest.mark.downstream_tasks
+@pytest.mark.toy_benchmarks
 @pytest.mark.parametrize(
     "test_id, model, optimizer, lr, criterion, max_epochs, dataset, n_classes, n_groups, seed, test_labels, "
     "batch_size, explain, explain_kwargs, expected_score",
@@ -59,30 +56,29 @@ def test_identical_subclass_metrics(
     criterion = request.getfixturevalue(criterion)
     dataset = request.getfixturevalue(dataset)
 
-    dst_eval = SubclassIdentification(
+    dst_eval = SubclassDetection.generate(
         model=model,
+        train_dataset=dataset,
         optimizer=optimizer,
-        lr=lr,
         criterion=criterion,
-    )
-    score = dst_eval.evaluate(
-        train_dataset=dataset,
+        lr=lr,
         val_dataset=None,
         n_classes=n_classes,
         n_groups=n_groups,
         class_to_group="random",
+        trainer_kwargs={"max_epochs": max_epochs},
+        seed=seed,
+        batch_size=batch_size,
+        device="cpu",
+    )
+    score = dst_eval.evaluate(
+        expl_dataset=dataset,
         explain_fn=explain,
         explain_kwargs=explain_kwargs,
-        trainer_kwargs={"max_epochs": max_epochs},
         cache_dir=str(tmp_path),
         model_id="default_model_id",
-        run_id="default_subclass_identification",
-        seed=seed,
         batch_size=batch_size,
         device="cpu",
     )
 
-    # remove cache directory if it exists
-    if os.path.exists("./test_cache"):
-        shutil.rmtree("./test_cache")
     assert score == expected_score