From e746691c2e6cc340d32a4993243ca5a6da96733f Mon Sep 17 00:00:00 2001
From: dilyabareeva
Date: Mon, 23 Dec 2024 18:19:55 +0100
Subject: [PATCH] fix: mixed datasets benchmark indexing

---
 .../benchmarks/heuristics/mixed_datasets.py   | 18 ++++++++--
 quanda/benchmarks/resources/benchmark_urls.py |  2 +-
 quanda/utils/datasets/image_datasets.py       | 33 +++++--------------
 .../heuristics/test_mixed_datasets.py         | 19 ++++++++---
 4 files changed, 39 insertions(+), 33 deletions(-)

diff --git a/quanda/benchmarks/heuristics/mixed_datasets.py b/quanda/benchmarks/heuristics/mixed_datasets.py
index a0d9d29..226393d 100644
--- a/quanda/benchmarks/heuristics/mixed_datasets.py
+++ b/quanda/benchmarks/heuristics/mixed_datasets.py
@@ -90,6 +90,7 @@ def generate(
         eval_dataset: torch.utils.data.Dataset,
         adversarial_dir: str,
         adversarial_label: int,
+        adv_train_indices: List[int],
         trainer: Union[L.Trainer, BaseTrainer],
         cache_dir: str,
         data_transform: Optional[Callable] = None,
@@ -127,6 +128,8 @@ class as the samples in the adversarial dataset.
             class).
         adversarial_label: int
             The label to be used for the adversarial dataset.
+        adv_train_indices: List[int]
+            List of indices of the adversarial dataset used for training.
         trainer: Union[L.Trainer, BaseTrainer]
             Trainer to be used for training the model. Can be a Lightning
             Trainer or a `BaseTrainer`.
@@ -199,7 +202,7 @@ class as the samples in the adversarial dataset.
             root=adversarial_dir,
             label=adversarial_label,
             transform=adversarial_transform,
-            train=True,
+            indices=adv_train_indices,
         )
 
         obj.mixed_dataset = torch.utils.data.ConcatDataset(
@@ -320,14 +323,19 @@ def download(cls, name: str, cache_dir: str, device: str, *args, **kwargs):
         adversarial_transform = sample_transforms[
             bench_state["adversarial_transform"]
         ]
+        adv_test_indices = bench_state["adv_indices_test"]
+        eval_from_test_indices = bench_state["eval_test_indices"]
+        eval_indices = [adv_test_indices[i] for i in eval_from_test_indices]
 
         eval_dataset = SingleClassImageDataset(
             root=adversarial_dir,
             label=bench_state["adversarial_label"],
             transform=adversarial_transform,
-            train=False,
+            indices=eval_indices,
         )
 
+        adv_train_indices = bench_state["adv_indices_train"]
+
         return obj.assemble(
             model=module,
             checkpoints=bench_state["checkpoints_binary"],
@@ -338,6 +346,7 @@ def download(cls, name: str, cache_dir: str, device: str, *args, **kwargs):
             adversarial_dir=adversarial_dir,
             adversarial_label=bench_state["adversarial_label"],
             adversarial_transform=adversarial_transform,
+            adv_train_indices=adv_train_indices,
             data_transform=dataset_transform,
             checkpoint_paths=checkpoint_paths,
         )
@@ -394,6 +403,7 @@ def assemble(
         base_dataset: torch.utils.data.Dataset,
         adversarial_dir: str,
         adversarial_label: int,
+        adv_train_indices: List[int],
         checkpoints: Optional[Union[str, List[str]]] = None,
         checkpoints_load_func: Optional[Callable[..., Any]] = None,
         data_transform: Optional[Callable] = None,
@@ -422,6 +432,8 @@ class as the samples in the adversarial dataset.
             Path to the adversarial dataset of a single class.
         adversarial_label: int
             The label to be used for the adversarial dataset.
+        adv_train_indices: List[int]
+            List of indices of the adversarial dataset used for training.
         checkpoints : Optional[Union[str, List[str]]], optional
             Path to the model checkpoint file(s), defaults to None.
         checkpoints_load_func : Optional[Callable[..., Any]], optional
@@ -472,7 +484,7 @@ class as the samples in the adversarial dataset.
             root=adversarial_dir,
             label=adversarial_label,
             transform=adversarial_transform,
-            train=True,
+            indices=adv_train_indices,
         )
 
         obj.mixed_dataset = torch.utils.data.ConcatDataset(
diff --git a/quanda/benchmarks/resources/benchmark_urls.py b/quanda/benchmarks/resources/benchmark_urls.py
index 0350986..ae1db38 100644
--- a/quanda/benchmarks/resources/benchmark_urls.py
+++ b/quanda/benchmarks/resources/benchmark_urls.py
@@ -3,7 +3,7 @@
 benchmark_urls: dict = {
     "mnist_top_k_cardinality": "https://datacloud.hhi.fraunhofer.de/s/32DKgtDZYbo75Xs/download/mnist_top_k_cardinality.pth",
     "mnist_subclass_detection": "https://datacloud.hhi.fraunhofer.de/s/ABY846RC2XyDCLa/download/mnist_subclass_detection.pth",
-    "mnist_mixed_datasets": "https://datacloud.hhi.fraunhofer.de/s/gXrmw4ten3wq75T/download/mnist_mixed_datasets.pth",
+    "mnist_mixed_datasets": "https://datacloud.hhi.fraunhofer.de/s/BNpJpL5YJ8L5yGg/download/mnist_mixed_datasets.pth",
     "mnist_mislabeling_detection": "https://datacloud.hhi.fraunhofer.de/s/nwfP4ojrSfHq69e/download/mnist_shortcut_detection.pth",
     "mnist_shortcut_detection": "https://datacloud.hhi.fraunhofer.de/s/CLYLTbb4Zb74tHA/download/mnist_mislabeling_detection.pth",
     "mnist_class_detection": "https://datacloud.hhi.fraunhofer.de/s/r68RfiGnQSrd3RT/download/mnist_class_detection.pth",
diff --git a/quanda/utils/datasets/image_datasets.py b/quanda/utils/datasets/image_datasets.py
index 110ca44..f0bfd03 100644
--- a/quanda/utils/datasets/image_datasets.py
+++ b/quanda/utils/datasets/image_datasets.py
@@ -2,6 +2,7 @@
 
 import glob
 import os
+from typing import Optional, List
 
 import torch
 from PIL import Image  # type: ignore
@@ -14,49 +15,33 @@ class SingleClassImageDataset(Dataset):
     def __init__(
         self,
         root: str,
-        train: bool,
         label: int,
+        indices: Optional[List[int]] = None,
         transform=None,
-        *args,
-        **kwargs,
     ):
         """Construct the SingleClassImageDataset."""
         self.root = root
         self.label = label
         self.transform = transform
-        self.train = train
+        self.indices = indices
 
         # find all images in the root directory
         filenames = []
         for extension in ["*.JPEG", "*.jpeg", "*.jpg", "*.png"]:
             filenames += glob.glob(os.path.join(root, extension))
 
-        self.filenames = filenames
-
-        filenames = sorted(filenames)
-
-        if os.path.exists(os.path.join(root, "train_indices")):
-            train_indices = torch.load(os.path.join(root, "train_indices"))
-            test_indices = torch.load(os.path.join(root, "test_indices"))
-        else:
-            randrank = torch.randperm(len(filenames))
-            size = int(len(filenames) / 2)
-            train_indices = randrank[:size]
-            test_indices = randrank[size:]
-            torch.save(train_indices, os.path.join(root, "train_indices"))
-            torch.save(test_indices, os.path.join(root, "test_indices"))
-
-        if self.train:
-            self.filenames = [filenames[i] for i in train_indices]
-        else:
-            self.filenames = [filenames[i] for i in test_indices]
+        self.filenames = sorted(filenames)
 
     def __len__(self):
         """Get dataset length."""
-        return len(self.filenames)
+        if self.indices is None:
+            return len(self.filenames)
+        return len(self.indices)
 
     def __getitem__(self, idx):
         """Get a sample by index."""
+        if self.indices is not None:
+            idx = self.indices[idx]
         img_path = self.filenames[idx]
         image = Image.open(img_path).convert("RGB")
         if self.transform:
diff --git a/tests/benchmarks/heuristics/test_mixed_datasets.py b/tests/benchmarks/heuristics/test_mixed_datasets.py
index bd8caed..45d1cf5 100644
--- a/tests/benchmarks/heuristics/test_mixed_datasets.py
+++ b/tests/benchmarks/heuristics/test_mixed_datasets.py
@@ -15,7 +15,7 @@
 @pytest.mark.benchmarks
 @pytest.mark.parametrize(
     "test_id, init_method, model, checkpoint, optimizer, lr, criterion, max_epochs, dataset, adversarial_path,"
-    "adversarial_label, adversarial_transforms, batch_size, explainer_cls, expl_kwargs,"
+    "adversarial_label, adversarial_transforms, adv_train_indices, adv_eval_indices, batch_size, explainer_cls, expl_kwargs,"
     "expected_score",
     [
         (
@@ -31,6 +31,8 @@
             "load_fashion_mnist_path",
             3,
             "load_fashion_mnist_to_mnist_transform",
+            None,
+            None,
             8,
             CaptumSimilarity,
             {
@@ -53,6 +55,8 @@
             "load_fashion_mnist_path",
             4,
             "load_fashion_mnist_to_mnist_transform",
+            None,
+            None,
             8,
             CaptumSimilarity,
             {
@@ -77,6 +81,8 @@ def test_mixed_datasets(
     adversarial_path,
     adversarial_label,
     adversarial_transforms,
+    adv_train_indices,
+    adv_eval_indices,
     batch_size,
     explainer_cls,
     expl_kwargs,
@@ -91,11 +97,12 @@ def test_mixed_datasets(
     dataset = request.getfixturevalue(dataset)
     adversarial_transforms = request.getfixturevalue(adversarial_transforms)
     adversarial_path = request.getfixturevalue(adversarial_path)
+
     eval_dataset = SingleClassImageDataset(
         root=adversarial_path,
         label=adversarial_label,
         transform=adversarial_transforms,
-        train=False,
+        indices=adv_eval_indices,
     )
 
     if init_method == "generate":
@@ -113,6 +120,7 @@ def test_mixed_datasets(
             eval_dataset=eval_dataset,
             adversarial_label=adversarial_label,
             adversarial_dir=adversarial_path,
+            adv_train_indices=adv_train_indices,
             adversarial_transform=adversarial_transforms,
             trainer_fit_kwargs={},
             cache_dir=str(tmp_path),
@@ -128,6 +136,7 @@ def test_mixed_datasets(
             adversarial_label=adversarial_label,
             adversarial_dir=adversarial_path,
             adversarial_transform=adversarial_transforms,
+            adv_train_indices=adv_train_indices,
         )
     else:
         raise ValueError(f"Invalid init_method: {init_method}")
@@ -231,7 +240,7 @@ def test_mixed_dataset_download(
         dst_eval.mixed_dataset, list(range(16))
     )
     dst_eval.eval_dataset = torch.utils.data.Subset(
-        dst_eval.eval_dataset, list(range(16))
+        dst_eval.eval_dataset, list(range(8))
     )
 
     dst_eval.adversarial_indices = dst_eval.adversarial_indices[:16]
@@ -256,7 +265,7 @@ def hook(model, input, output):
         dst_eval.mixed_dataset, batch_size=16, shuffle=False
     )
     test_ld = torch.utils.data.DataLoader(
-        dst_eval.eval_dataset, batch_size=16, shuffle=False
+        dst_eval.eval_dataset, batch_size=8, shuffle=False
    )
     for x, y in iter(train_ld):
         x = x.to(dst_eval.device)
@@ -266,7 +275,7 @@ def hook(model, input, output):
     for x, y in iter(test_ld):
         x = x.to(dst_eval.device)
         y_preds = dst_eval.model(x).argmax(dim=-1)
-        select_idx = torch.tensor([True] * 16)
+        select_idx = torch.tensor([True] * 8)
         if filter_by_prediction:
             select_idx *= y_preds == dst_eval.adversarial_label
         dst_eval.model(x)
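
Usage sketch (not part of the patch): with the `train` flag removed, SingleClassImageDataset no longer splits the adversarial directory itself; callers pass explicit index lists per split via the new `indices` argument. The directory path, label, and index values below are illustrative assumptions, not values from the benchmark state files.

    # Illustrative only: build disjoint train/eval views over one directory
    # of single-class adversarial images using the new `indices` argument.
    import torch

    from quanda.utils.datasets.image_datasets import SingleClassImageDataset

    adversarial_dir = "tmp/adversarial_images"  # hypothetical directory
    adversarial_label = 3                       # label assigned to every image

    # Hypothetical 50/50 split over an assumed 100 images in adversarial_dir.
    n_files = 100
    perm = torch.randperm(n_files).tolist()
    adv_train_indices = perm[: n_files // 2]
    adv_eval_indices = perm[n_files // 2:]

    train_ds = SingleClassImageDataset(
        root=adversarial_dir,
        label=adversarial_label,
        indices=adv_train_indices,
    )
    eval_ds = SingleClassImageDataset(
        root=adversarial_dir,
        label=adversarial_label,
        indices=adv_eval_indices,
    )

    # With indices=None the dataset exposes every image found in root,
    # in sorted filename order.
    full_ds = SingleClassImageDataset(
        root=adversarial_dir,
        label=adversarial_label,
    )

The split itself now lives in the benchmark metadata (the adv_indices_train and adv_indices_test entries read in download) rather than in train_indices/test_indices files written next to the images, which is what the deleted branch in image_datasets.py used to do.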