From 3438bb36fa25b1ac96bfcb15ae5805d1dfcbeb06 Mon Sep 17 00:00:00 2001 From: dilyabareeva Date: Wed, 14 Aug 2024 18:13:14 +0200 Subject: [PATCH] device fixes --- .../randomization/model_randomization.py | 2 +- quanda/metrics/unnamed/dataset_cleaning.py | 4 ++-- quanda/toy_benchmarks/base.py | 19 +++++++++++++++++ .../localization/class_detection.py | 17 ++++++++------- .../localization/mislabeling_detection.py | 21 ++++++++++--------- .../localization/subclass_detection.py | 17 ++++++++------- .../randomization/model_randomization.py | 16 +++++++------- .../unnamed/dataset_cleaning.py | 18 ++++++++-------- .../toy_benchmarks/unnamed/top_k_overlap.py | 14 ++++++------- tutorials/usage_testing.py | 14 +++++++------ 10 files changed, 83 insertions(+), 59 deletions(-) diff --git a/quanda/metrics/randomization/model_randomization.py b/quanda/metrics/randomization/model_randomization.py index 94108fbc..f92ea903 100644 --- a/quanda/metrics/randomization/model_randomization.py +++ b/quanda/metrics/randomization/model_randomization.py @@ -42,7 +42,7 @@ def __init__( self.model_id = model_id self.cache_dir = cache_dir - self.generator = torch.Generator(device=self.device) + self.generator = torch.Generator(device=self.model_device) self.generator.manual_seed(self.seed) self.rand_model = self._randomize_model(model) self.rand_explainer = explainer_cls( diff --git a/quanda/metrics/unnamed/dataset_cleaning.py b/quanda/metrics/unnamed/dataset_cleaning.py index dcab7cb8..d09296e9 100644 --- a/quanda/metrics/unnamed/dataset_cleaning.py +++ b/quanda/metrics/unnamed/dataset_cleaning.py @@ -128,7 +128,7 @@ def compute(self, *args, **kwargs): clean_subset = torch.utils.data.Subset(self.train_dataset, clean_indices) train_dl = torch.utils.data.DataLoader(self.train_dataset, batch_size=32, shuffle=True) - original_accuracy = class_accuracy(self.model, train_dl, self.device) + original_accuracy = class_accuracy(self.model, train_dl, self.model_device) clean_dl = torch.utils.data.DataLoader(clean_subset, batch_size=32, shuffle=True) @@ -155,6 +155,6 @@ def compute(self, *args, **kwargs): else: raise ValueError("Trainer should be a Lightning Trainer or a BaseTrainer") - clean_accuracy = class_accuracy(self.model, clean_dl, self.device) + clean_accuracy = class_accuracy(self.model, clean_dl, self.model_device) return original_accuracy - clean_accuracy diff --git a/quanda/toy_benchmarks/base.py b/quanda/toy_benchmarks/base.py index 62e68115..e34a2b11 100644 --- a/quanda/toy_benchmarks/base.py +++ b/quanda/toy_benchmarks/base.py @@ -1,4 +1,6 @@ from abc import ABC, abstractmethod +from typing import Optional, Union +import torch class ToyBenchmark(ABC): @@ -15,6 +17,8 @@ def __init__(self, *args, **kwargs): :param args: :param kwargs: """ + self.model_device: Optional[Union[str, torch.device]] + self.device: Optional[Union[str, torch.device]] @classmethod @abstractmethod @@ -58,3 +62,18 @@ def evaluate( """ raise NotImplementedError + + def set_devices( + self, + model: torch.nn.Module, + device: Optional[Union[str, torch.device]] = None, + ): + """ + This method should set the device for the model. + """ + if next(model.parameters(), None) is not None: + self.model_device = next(model.parameters()).device + else: + self.model_device = torch.device("cpu") + + self.device = device or self.model_device diff --git a/quanda/toy_benchmarks/localization/class_detection.py b/quanda/toy_benchmarks/localization/class_detection.py index ef0da38e..3e6ee6b8 100644 --- a/quanda/toy_benchmarks/localization/class_detection.py +++ b/quanda/toy_benchmarks/localization/class_detection.py @@ -10,11 +10,10 @@ class ClassDetection(ToyBenchmark): def __init__( self, - device: Optional[Union[str, torch.device]] = None, *args, **kwargs, ): - super().__init__(device=device) + super().__init__() self.model: torch.nn.Module self.train_dataset: torch.utils.data.Dataset @@ -32,11 +31,11 @@ def generate( This method should generate all the benchmark components and persist them in the instance. """ - obj = cls(device=device) + obj = cls() obj.model = model obj.train_dataset = train_dataset - + obj.set_devices(model, device) return obj @property @@ -67,10 +66,13 @@ def assemble( """ This method should assemble the benchmark components from arguments and persist them in the instance. """ - obj = cls(device=device) + + obj = cls() obj.model = model obj.train_dataset = train_dataset + obj.set_devices(model, device) + return obj def save(self, path: str, *args, **kwargs): @@ -88,7 +90,6 @@ def evaluate( cache_dir: str = "./cache", model_id: str = "default_model_id", batch_size: int = 8, - device: Optional[Union[str, torch.device]] = None, *args, **kwargs, ): @@ -99,7 +100,7 @@ def evaluate( expl_dl = torch.utils.data.DataLoader(expl_dataset, batch_size=batch_size) - metric = ClassDetectionMetric(model=self.model, train_dataset=self.train_dataset, device="cpu") + metric = ClassDetectionMetric(model=self.model, train_dataset=self.train_dataset, device=self.device) pbar = tqdm(expl_dl) n_batches = len(expl_dl) @@ -107,7 +108,7 @@ def evaluate( for i, (input, labels) in enumerate(pbar): pbar.set_description("Metric evaluation, batch %d/%d" % (i + 1, n_batches)) - input, labels = input.to(device), labels.to(device) + input, labels = input.to(self.model_device), labels.to(self.model_device) if use_predictions: with torch.no_grad(): diff --git a/quanda/toy_benchmarks/localization/mislabeling_detection.py b/quanda/toy_benchmarks/localization/mislabeling_detection.py index d8296145..ae7fe8a0 100644 --- a/quanda/toy_benchmarks/localization/mislabeling_detection.py +++ b/quanda/toy_benchmarks/localization/mislabeling_detection.py @@ -18,11 +18,10 @@ class MislabelingDetection(ToyBenchmark): def __init__( self, - device: Optional[Union[str, torch.device]] = None, *args, **kwargs, ): - super().__init__(device=device) + super().__init__() self.model: Union[torch.nn.Module, L.LightningModule] self.train_dataset: torch.utils.data.Dataset @@ -59,8 +58,8 @@ def generate( This method should generate all the benchmark components and persist them in the instance. """ - obj = cls(device=device) - + obj = cls() + obj.set_devices(model, device) obj._generate( model=model, train_dataset=train_dataset, @@ -199,7 +198,7 @@ def assemble( """ This method should assemble the benchmark components from arguments and persist them in the instance. """ - obj = cls(device=device) + obj = cls() obj.model = model obj.train_dataset = train_dataset obj.p = p @@ -220,6 +219,9 @@ def assemble( obj.poisoned_train_dl = torch.utils.data.DataLoader(obj.poisoned_dataset, batch_size=batch_size) obj.original_train_dl = torch.utils.data.DataLoader(obj.train_dataset, batch_size=batch_size) + + obj.set_devices(model, device) + return obj def save(self, path: str, *args, **kwargs): @@ -235,12 +237,11 @@ def evaluate( expl_kwargs: Optional[dict] = None, use_predictions: bool = False, batch_size: int = 8, - device: Optional[Union[str, torch.device]] = None, *args, **kwargs, ): expl_kwargs = expl_kwargs or {} - explainer = explainer_cls(model=self.model, train_dataset=self.train_dataset, device=device, **expl_kwargs) + explainer = explainer_cls(model=self.model, train_dataset=self.train_dataset, device=self.device, **expl_kwargs) poisoned_expl_ds = LabelFlippingDataset( dataset=expl_dataset, dataset_transform=self.dataset_transform, n_classes=self.n_classes, p=0.0 @@ -251,7 +252,7 @@ def evaluate( model=self.model, train_dataset=self.poisoned_dataset, poisoned_indices=self.poisoned_indices, - device=device, + device=self.device, aggregator_cls=self.global_method, ) @@ -261,7 +262,7 @@ def evaluate( for i, (inputs, labels) in enumerate(pbar): pbar.set_description("Metric evaluation, batch %d/%d" % (i + 1, n_batches)) - inputs, labels = inputs.to(device), labels.to(device) + inputs, labels = inputs.to(self.model_device), labels.to(self.model_device) if use_predictions: with torch.no_grad(): targets = self.model(inputs).argmax(dim=-1) @@ -274,7 +275,7 @@ def evaluate( model=self.model, train_dataset=self.poisoned_dataset, poisoned_indices=self.poisoned_indices, - device=device, + device=self.device, explainer_cls=explainer_cls, expl_kwargs=expl_kwargs, ) diff --git a/quanda/toy_benchmarks/localization/subclass_detection.py b/quanda/toy_benchmarks/localization/subclass_detection.py index 5aceaa35..78fb5dc3 100644 --- a/quanda/toy_benchmarks/localization/subclass_detection.py +++ b/quanda/toy_benchmarks/localization/subclass_detection.py @@ -17,11 +17,10 @@ class SubclassDetection(ToyBenchmark): def __init__( self, - device: Optional[Union[str, torch.device]] = None, *args, **kwargs, ): - super().__init__(device=device) + super().__init__() self.model: Union[torch.nn.Module, L.LightningModule] self.group_model: Union[torch.nn.Module, L.LightningModule] @@ -56,8 +55,8 @@ def generate( This method should generate all the benchmark components and persist them in the instance. """ - obj = cls(device=device) - + obj = cls() + obj.set_devices(model, device) obj.model = model obj._generate( trainer=trainer, @@ -181,7 +180,7 @@ def assemble( """ This method should assemble the benchmark components from arguments and persist them in the instance. """ - obj = cls(device=device) + obj = cls() obj.group_model = group_model obj.train_dataset = train_dataset obj.class_to_group = class_to_group @@ -198,6 +197,9 @@ def assemble( ) obj.grouped_train_dl = torch.utils.data.DataLoader(obj.grouped_dataset, batch_size=batch_size) obj.original_train_dl = torch.utils.data.DataLoader(obj.train_dataset, batch_size=batch_size) + + obj.set_devices(group_model, device) + return obj def save(self, path: str, *args, **kwargs): @@ -215,7 +217,6 @@ def evaluate( cache_dir: str = "./cache", model_id: str = "default_model_id", batch_size: int = 8, - device: Optional[Union[str, torch.device]] = None, *args, **kwargs, ): @@ -226,7 +227,7 @@ def evaluate( expl_dl = torch.utils.data.DataLoader(expl_dataset, batch_size=batch_size) - metric = ClassDetectionMetric(model=self.group_model, train_dataset=self.train_dataset, device=device) + metric = ClassDetectionMetric(model=self.group_model, train_dataset=self.train_dataset, device=self.device) pbar = tqdm(expl_dl) n_batches = len(expl_dl) @@ -234,7 +235,7 @@ def evaluate( for i, (inputs, labels) in enumerate(pbar): pbar.set_description("Metric evaluation, batch %d/%d" % (i + 1, n_batches)) - inputs, labels = inputs.to(device), labels.to(device) + inputs, labels = inputs.to(self.model_device), labels.to(self.model_device) grouped_labels = torch.tensor([self.class_to_group[i.item()] for i in labels], device=labels.device) if use_predictions: with torch.no_grad(): diff --git a/quanda/toy_benchmarks/randomization/model_randomization.py b/quanda/toy_benchmarks/randomization/model_randomization.py index 168f2ea8..35f104fd 100644 --- a/quanda/toy_benchmarks/randomization/model_randomization.py +++ b/quanda/toy_benchmarks/randomization/model_randomization.py @@ -13,11 +13,10 @@ class ModelRandomization(ToyBenchmark): def __init__( self, - device: Optional[Union[str, torch.device]] = None, *args, **kwargs, ): - super().__init__(device=device) + super().__init__() self.model: torch.nn.Module self.train_dataset: torch.utils.data.Dataset @@ -35,8 +34,8 @@ def generate( This method should generate all the benchmark components and persist them in the instance. """ - obj = cls(device=device) - + obj = cls() + obj.set_devices(model, device) obj.model = model obj.train_dataset = train_dataset @@ -70,10 +69,12 @@ def assemble( """ This method should assemble the benchmark components from arguments and persist them in the instance. """ - obj = cls(device=device) + obj = cls() obj.model = model obj.train_dataset = train_dataset + obj.set_devices(model, device) + return obj def save(self, path: str, *args, **kwargs): @@ -93,7 +94,6 @@ def evaluate( cache_dir: str = "./cache", model_id: str = "default_model_id", batch_size: int = 8, - device: Optional[Union[str, torch.device]] = None, *args, **kwargs, ): @@ -109,7 +109,7 @@ def evaluate( seed=seed, model_id=model_id, cache_dir=cache_dir, - device=device, + device=self.device, ) pbar = tqdm(expl_dl) n_batches = len(expl_dl) @@ -117,7 +117,7 @@ def evaluate( for i, (input, labels) in enumerate(pbar): pbar.set_description("Metric evaluation, batch %d/%d" % (i + 1, n_batches)) - input, labels = input.to(device), labels.to(device) + input, labels = input.to(self.model_device), labels.to(self.model_device) if use_predictions: with torch.no_grad(): diff --git a/quanda/toy_benchmarks/unnamed/dataset_cleaning.py b/quanda/toy_benchmarks/unnamed/dataset_cleaning.py index 46164d72..c007cc14 100644 --- a/quanda/toy_benchmarks/unnamed/dataset_cleaning.py +++ b/quanda/toy_benchmarks/unnamed/dataset_cleaning.py @@ -13,11 +13,10 @@ class DatasetCleaning(ToyBenchmark): def __init__( self, - device: Optional[Union[str, torch.device]] = None, *args, **kwargs, ): - super().__init__(device=device) + super().__init__() self.model: torch.nn.Module self.train_dataset: torch.utils.data.Dataset @@ -35,8 +34,8 @@ def generate( This method should generate all the benchmark components and persist them in the instance. """ - obj = cls(device=device) - + obj = cls() + obj.set_devices(model, device) obj.model = model obj.train_dataset = train_dataset @@ -69,10 +68,12 @@ def assemble( """ This method should assemble the benchmark components from arguments and persist them in the instance. """ - obj = cls(device=device) + obj = cls() obj.model = model obj.train_dataset = train_dataset + obj.set_devices(model, device) + return obj def save(self, path: str, *args, **kwargs): @@ -93,7 +94,6 @@ def evaluate( cache_dir: str = "./cache", model_id: str = "default_model_id", batch_size: int = 8, - device: Optional[Union[str, torch.device]] = None, global_method: Union[str, type] = "self-influence", top_k: int = 50, *args, @@ -116,7 +116,7 @@ def evaluate( trainer=trainer, trainer_fit_kwargs=trainer_fit_kwargs, top_k=top_k, - device=device, + device=self.device, ) pbar = tqdm(expl_dl) n_batches = len(expl_dl) @@ -124,7 +124,7 @@ def evaluate( for i, (inputs, labels) in enumerate(pbar): pbar.set_description("Metric evaluation, batch %d/%d" % (i + 1, n_batches)) - inputs, labels = inputs.to(device), labels.to(device) + inputs, labels = inputs.to(self.model_device), labels.to(self.model_device) if use_predictions: with torch.no_grad(): @@ -149,7 +149,7 @@ def evaluate( explainer_cls=explainer_cls, expl_kwargs=expl_kwargs, top_k=top_k, - device=device, + device=self.device, ) return metric.compute() diff --git a/quanda/toy_benchmarks/unnamed/top_k_overlap.py b/quanda/toy_benchmarks/unnamed/top_k_overlap.py index 9af99dc6..c7369c37 100644 --- a/quanda/toy_benchmarks/unnamed/top_k_overlap.py +++ b/quanda/toy_benchmarks/unnamed/top_k_overlap.py @@ -10,11 +10,10 @@ class TopKOverlap(ToyBenchmark): def __init__( self, - device: Optional[Union[str, torch.device]] = None, *args, **kwargs, ): - super().__init__(device=device) + super().__init__() self.model: torch.nn.Module self.train_dataset: torch.utils.data.Dataset @@ -32,7 +31,7 @@ def generate( This method should generate all the benchmark components and persist them in the instance. """ - obj = cls(device=device) + obj = cls() obj.model = model obj.train_dataset = train_dataset @@ -66,9 +65,11 @@ def assemble( """ This method should assemble the benchmark components from arguments and persist them in the instance. """ - obj = cls(device=device) + obj = cls() + obj.set_devices(model, device) obj.model = model obj.train_dataset = train_dataset + obj.set_devices(model, device) return obj @@ -88,7 +89,6 @@ def evaluate( model_id: str = "default_model_id", batch_size: int = 8, top_k: int = 1, - device: Optional[Union[str, torch.device]] = None, *args, **kwargs, ): @@ -99,7 +99,7 @@ def evaluate( expl_dl = torch.utils.data.DataLoader(expl_dataset, batch_size=batch_size) - metric = TopKOverlapMetric(model=self.model, train_dataset=self.train_dataset, top_k=top_k, device=device) + metric = TopKOverlapMetric(model=self.model, train_dataset=self.train_dataset, top_k=top_k, device=self.device) pbar = tqdm(expl_dl) n_batches = len(expl_dl) @@ -107,7 +107,7 @@ def evaluate( for i, (input, labels) in enumerate(pbar): pbar.set_description("Metric evaluation, batch %d/%d" % (i + 1, n_batches)) - input, labels = input.to(device), labels.to(device) + input, labels = input.to(self.model_device), labels.to(self.model_device) if use_predictions: with torch.no_grad(): diff --git a/tutorials/usage_testing.py b/tutorials/usage_testing.py index bd185cea..52e75de6 100644 --- a/tutorials/usage_testing.py +++ b/tutorials/usage_testing.py @@ -72,6 +72,8 @@ def main(): init_model = resnet18(weights=None, num_classes=10) model.load_state_dict(weights_pretrained) model.to(DEVICE) + + device = "cpu" model.eval() # a temporary data loader without normalization, just to show the images @@ -121,12 +123,12 @@ def accuracy(net, loader): cache_dir=cache_dir, correlation_fn="spearman", seed=42, - device=DEVICE, + device=device, ) - id_class = ClassDetectionMetric(model=model, train_dataset=train_set, device=DEVICE) + id_class = ClassDetectionMetric(model=model, train_dataset=train_set, device=device) - top_k = TopKOverlapMetric(model=model, train_dataset=train_set, top_k=1, device=DEVICE) + top_k = TopKOverlapMetric(model=model, train_dataset=train_set, top_k=1, device=device) # dataset cleaning max_epochs = 1 @@ -166,7 +168,7 @@ def accuracy(net, loader): global_method="sum_abs", trainer=trainer, top_k=50, - device=DEVICE, + device=device, ) # iterate over test set and feed tensor batches first to explain, then to metric @@ -178,7 +180,7 @@ def accuracy(net, loader): cache_dir=cache_dir, test_tensor=data, train_dataset=train_set, - device=DEVICE, + device=device, **explain_fn_kwargs, ) model_rand.update(data, tda) @@ -211,7 +213,7 @@ def accuracy(net, loader): class_to_group="random", seed=42, batch_size=100, - device=DEVICE, + device=device, ) score = bench.evaluate(