fix: mixed datasets benchmark indexing
dilyabareeva committed Dec 23, 2024
1 parent 818e5f0 commit e746691
Showing 4 changed files with 39 additions and 33 deletions.
quanda/benchmarks/heuristics/mixed_datasets.py (18 changes: 15 additions & 3 deletions)
@@ -90,6 +90,7 @@ def generate(
eval_dataset: torch.utils.data.Dataset,
adversarial_dir: str,
adversarial_label: int,
adv_train_indices: List[int],
trainer: Union[L.Trainer, BaseTrainer],
cache_dir: str,
data_transform: Optional[Callable] = None,
@@ -127,6 +128,8 @@ class as the samples in the adversarial dataset.
class).
adversarial_label: int
The label to be used for the adversarial dataset.
adv_train_indices: List[int]
List of indices of the adversarial dataset used for training.
trainer: Union[L.Trainer, BaseTrainer]
Trainer to be used for training the model. Can be a Lightning
Trainer or a `BaseTrainer`.
@@ -199,7 +202,7 @@ class as the samples in the adversarial dataset.
root=adversarial_dir,
label=adversarial_label,
transform=adversarial_transform,
train=True,
indices=adv_train_indices,
)

obj.mixed_dataset = torch.utils.data.ConcatDataset(
@@ -320,14 +323,19 @@ def download(cls, name: str, cache_dir: str, device: str, *args, **kwargs):
adversarial_transform = sample_transforms[
bench_state["adversarial_transform"]
]
adv_test_indices = bench_state["adv_indices_test"]
eval_from_test_indices = bench_state["eval_test_indices"]
eval_indices = [adv_test_indices[i] for i in eval_from_test_indices]

eval_dataset = SingleClassImageDataset(
root=adversarial_dir,
label=bench_state["adversarial_label"],
transform=adversarial_transform,
train=False,
indices=eval_indices,
)

adv_train_indices = bench_state["adv_indices_train"]

return obj.assemble(
model=module,
checkpoints=bench_state["checkpoints_binary"],
@@ -338,6 +346,7 @@ def download(cls, name: str, cache_dir: str, device: str, *args, **kwargs):
adversarial_dir=adversarial_dir,
adversarial_label=bench_state["adversarial_label"],
adversarial_transform=adversarial_transform,
adv_train_indices=adv_train_indices,
data_transform=dataset_transform,
checkpoint_paths=checkpoint_paths,
)
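
The two download() hunks above contain the indexing fix named in the commit title: eval_test_indices indexes into the list of adversarial test indices, not into the adversarial image directory itself, so the two lookups have to be chained before building the evaluation dataset. A small worked example with invented values:

# Worked example of the chained lookup in download(); all index values are invented.
adv_test_indices = [7, 11, 13, 20]        # adversarial files held out from training
eval_from_test_indices = [0, 2]           # which of those held-out files are evaluated
eval_indices = [adv_test_indices[i] for i in eval_from_test_indices]
assert eval_indices == [7, 13]            # positions in the sorted adversarial file list
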
@@ -394,6 +403,7 @@ def assemble(
base_dataset: torch.utils.data.Dataset,
adversarial_dir: str,
adversarial_label: int,
adv_train_indices: List[int],
checkpoints: Optional[Union[str, List[str]]] = None,
checkpoints_load_func: Optional[Callable[..., Any]] = None,
data_transform: Optional[Callable] = None,
@@ -422,6 +432,8 @@ class as the samples in the adversarial dataset.
Path to the adversarial dataset of a single class.
adversarial_label: int
The label to be used for the adversarial dataset.
adv_train_indices: List[int]
List of indices of the adversarial dataset used for training.
checkpoints : Optional[Union[str, List[str]]], optional
Path to the model checkpoint file(s), defaults to None.
checkpoints_load_func : Optional[Callable[..., Any]], optional
@@ -472,7 +484,7 @@ class as the samples in the adversarial dataset.
root=adversarial_dir,
label=adversarial_label,
transform=adversarial_transform,
train=True,
indices=adv_train_indices,
)

obj.mixed_dataset = torch.utils.data.ConcatDataset(
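
Note for users of this module: assemble (and generate) now require the adversarial training indices to be passed explicitly instead of relying on the dataset's removed train flag. A minimal usage sketch, assuming the benchmark class exported here is MixedDatasets; every object, path, and value below is a placeholder, only the keyword names come from this commit:

# Hypothetical call against the updated assemble() signature; the model,
# datasets, and transforms are assumed to be prepared elsewhere.
benchmark = MixedDatasets.assemble(
    model=model,
    base_dataset=base_dataset,
    eval_dataset=eval_dataset,
    adversarial_dir="data/adversarial/",
    adversarial_label=3,
    adv_train_indices=list(range(100)),   # new in this commit: adversarial samples mixed into training
    adversarial_transform=adversarial_transform,
    data_transform=data_transform,
)
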
quanda/benchmarks/resources/benchmark_urls.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
benchmark_urls: dict = {
"mnist_top_k_cardinality": "https://datacloud.hhi.fraunhofer.de/s/32DKgtDZYbo75Xs/download/mnist_top_k_cardinality.pth",
"mnist_subclass_detection": "https://datacloud.hhi.fraunhofer.de/s/ABY846RC2XyDCLa/download/mnist_subclass_detection.pth",
"mnist_mixed_datasets": "https://datacloud.hhi.fraunhofer.de/s/gXrmw4ten3wq75T/download/mnist_mixed_datasets.pth",
"mnist_mixed_datasets": "https://datacloud.hhi.fraunhofer.de/s/BNpJpL5YJ8L5yGg/download/mnist_mixed_datasets.pth",
"mnist_mislabeling_detection": "https://datacloud.hhi.fraunhofer.de/s/nwfP4ojrSfHq69e/download/mnist_shortcut_detection.pth",
"mnist_shortcut_detection": "https://datacloud.hhi.fraunhofer.de/s/CLYLTbb4Zb74tHA/download/mnist_mislabeling_detection.pth",
"mnist_class_detection": "https://datacloud.hhi.fraunhofer.de/s/r68RfiGnQSrd3RT/download/mnist_class_detection.pth",
quanda/utils/datasets/image_datasets.py (33 changes: 9 additions & 24 deletions)
@@ -2,6 +2,7 @@

import glob
import os
from typing import Optional, List

import torch
from PIL import Image # type: ignore
@@ -14,49 +15,33 @@ class SingleClassImageDataset(Dataset):
def __init__(
self,
root: str,
train: bool,
label: int,
indices: Optional[List[int]] = None,
transform=None,
*args,
**kwargs,
):
"""Construct the SingleClassImageDataset."""
self.root = root
self.label = label
self.transform = transform
self.train = train
self.indices = indices

# find all images in the root directory
filenames = []
for extension in ["*.JPEG", "*.jpeg", "*.jpg", "*.png"]:
filenames += glob.glob(os.path.join(root, extension))

self.filenames = filenames

filenames = sorted(filenames)

if os.path.exists(os.path.join(root, "train_indices")):
train_indices = torch.load(os.path.join(root, "train_indices"))
test_indices = torch.load(os.path.join(root, "test_indices"))
else:
randrank = torch.randperm(len(filenames))
size = int(len(filenames) / 2)
train_indices = randrank[:size]
test_indices = randrank[size:]
torch.save(train_indices, os.path.join(root, "train_indices"))
torch.save(test_indices, os.path.join(root, "test_indices"))

if self.train:
self.filenames = [filenames[i] for i in train_indices]
else:
self.filenames = [filenames[i] for i in test_indices]
self.filenames = sorted(filenames)

def __len__(self):
"""Get dataset length."""
return len(self.filenames)
if self.indices is None:
return len(self.filenames)
return len(self.indices)

def __getitem__(self, idx):
"""Get a sample by index."""
if self.indices is not None:
idx = self.indices[idx]
img_path = self.filenames[idx]
image = Image.open(img_path).convert("RGB")
if self.transform:
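
With this rewrite the cached train_indices/test_indices files are gone: the dataset always indexes into the sorted list of image files under root, and an optional indices list selects a subset of them. A short behaviour sketch (directory path and index values are made up):

# Behaviour sketch for the reworked SingleClassImageDataset; paths and indices are invented.
full_ds = SingleClassImageDataset(root="data/adversarial/", label=3)
# len(full_ds) == number of *.JPEG / *.jpeg / *.jpg / *.png files under root

subset_ds = SingleClassImageDataset(
    root="data/adversarial/",
    label=3,
    indices=[0, 2, 5],   # item i resolves to the file at sorted position indices[i]
)
# len(subset_ds) == 3; subset_ds[1] loads the file at sorted position 2
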
tests/benchmarks/heuristics/test_mixed_datasets.py (19 changes: 14 additions & 5 deletions)
@@ -15,7 +15,7 @@
@pytest.mark.benchmarks
@pytest.mark.parametrize(
"test_id, init_method, model, checkpoint, optimizer, lr, criterion, max_epochs, dataset, adversarial_path,"
"adversarial_label, adversarial_transforms, batch_size, explainer_cls, expl_kwargs,"
"adversarial_label, adversarial_transforms, adv_train_indices, adv_eval_indices,batch_size, explainer_cls, expl_kwargs,"
"expected_score",
[
(
@@ -31,6 +31,8 @@
"load_fashion_mnist_path",
3,
"load_fashion_mnist_to_mnist_transform",
None,
None,
8,
CaptumSimilarity,
{
@@ -53,6 +55,8 @@
"load_fashion_mnist_path",
4,
"load_fashion_mnist_to_mnist_transform",
None,
None,
8,
CaptumSimilarity,
{
@@ -77,6 +81,8 @@ def test_mixed_datasets(
adversarial_path,
adversarial_label,
adversarial_transforms,
adv_train_indices,
adv_eval_indices,
batch_size,
explainer_cls,
expl_kwargs,
@@ -91,11 +97,12 @@
dataset = request.getfixturevalue(dataset)
adversarial_transforms = request.getfixturevalue(adversarial_transforms)
adversarial_path = request.getfixturevalue(adversarial_path)

eval_dataset = SingleClassImageDataset(
root=adversarial_path,
label=adversarial_label,
transform=adversarial_transforms,
train=False,
indices=adv_eval_indices,
)

if init_method == "generate":
@@ -113,6 +120,7 @@
eval_dataset=eval_dataset,
adversarial_label=adversarial_label,
adversarial_dir=adversarial_path,
adv_train_indices=adv_train_indices,
adversarial_transform=adversarial_transforms,
trainer_fit_kwargs={},
cache_dir=str(tmp_path),
@@ -128,6 +136,7 @@
adversarial_label=adversarial_label,
adversarial_dir=adversarial_path,
adversarial_transform=adversarial_transforms,
adv_train_indices=adv_train_indices,
)
else:
raise ValueError(f"Invalid init_method: {init_method}")
@@ -231,7 +240,7 @@ def test_mixed_dataset_download(
dst_eval.mixed_dataset, list(range(16))
)
dst_eval.eval_dataset = torch.utils.data.Subset(
dst_eval.eval_dataset, list(range(16))
dst_eval.eval_dataset, list(range(8))
)

dst_eval.adversarial_indices = dst_eval.adversarial_indices[:16]
@@ -256,7 +265,7 @@ def hook(model, input, output):
dst_eval.mixed_dataset, batch_size=16, shuffle=False
)
test_ld = torch.utils.data.DataLoader(
dst_eval.eval_dataset, batch_size=16, shuffle=False
dst_eval.eval_dataset, batch_size=8, shuffle=False
)
for x, y in iter(train_ld):
x = x.to(dst_eval.device)
@@ -266,7 +275,7 @@
for x, y in iter(test_ld):
x = x.to(dst_eval.device)
y_preds = dst_eval.model(x).argmax(dim=-1)
select_idx = torch.tensor([True] * 16)
select_idx = torch.tensor([True] * 8)
if filter_by_prediction:
select_idx *= y_preds == dst_eval.adversarial_label
dst_eval.model(x)