
Commit

device fixes
dilyabareeva committed Aug 14, 2024
1 parent 952fc35 commit 3438bb3
Showing 10 changed files with 83 additions and 59 deletions.
2 changes: 1 addition & 1 deletion quanda/metrics/randomization/model_randomization.py
@@ -42,7 +42,7 @@ def __init__(
self.model_id = model_id
self.cache_dir = cache_dir

self.generator = torch.Generator(device=self.device)
self.generator = torch.Generator(device=self.model_device)
self.generator.manual_seed(self.seed)
self.rand_model = self._randomize_model(model)
self.rand_explainer = explainer_cls(
4 changes: 2 additions & 2 deletions quanda/metrics/unnamed/dataset_cleaning.py
@@ -128,7 +128,7 @@ def compute(self, *args, **kwargs):
clean_subset = torch.utils.data.Subset(self.train_dataset, clean_indices)

train_dl = torch.utils.data.DataLoader(self.train_dataset, batch_size=32, shuffle=True)
original_accuracy = class_accuracy(self.model, train_dl, self.device)
original_accuracy = class_accuracy(self.model, train_dl, self.model_device)

clean_dl = torch.utils.data.DataLoader(clean_subset, batch_size=32, shuffle=True)

@@ -155,6 +155,6 @@ def compute(self, *args, **kwargs):
else:
raise ValueError("Trainer should be a Lightning Trainer or a BaseTrainer")

clean_accuracy = class_accuracy(self.model, clean_dl, self.device)
clean_accuracy = class_accuracy(self.model, clean_dl, self.model_device)

return original_accuracy - clean_accuracy
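
For context, the dataset cleaning score returned above is the drop in training accuracy after the model is retrained on the cleaned subset, with both accuracy passes now running on the model's own device. Below is a minimal sketch of such an accuracy computation; `accuracy_on_device` is an illustrative stand-in, not quanda's `class_accuracy` implementation:

import torch


def accuracy_on_device(
    model: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    device: torch.device,
) -> float:
    """Fraction of correctly classified samples, computed on `device`."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            # Move the batch to the device the model lives on before the forward pass.
            inputs, labels = inputs.to(device), labels.to(device)
            preds = model(inputs).argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.numel()
    return correct / total

With a helper like this, the returned score corresponds to accuracy on `train_dl` minus accuracy on `clean_dl`, both evaluated on the model's device.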
19 changes: 19 additions & 0 deletions quanda/toy_benchmarks/base.py
@@ -1,4 +1,6 @@
from abc import ABC, abstractmethod
from typing import Optional, Union
import torch


class ToyBenchmark(ABC):
@@ -15,6 +17,8 @@ def __init__(self, *args, **kwargs):
:param args:
:param kwargs:
"""
self.model_device: Optional[Union[str, torch.device]]
self.device: Optional[Union[str, torch.device]]

@classmethod
@abstractmethod
@@ -58,3 +62,18 @@ def evaluate(
"""

raise NotImplementedError

def set_devices(
self,
model: torch.nn.Module,
device: Optional[Union[str, torch.device]] = None,
):
"""
This method should set the device for the model.
"""
if next(model.parameters(), None) is not None:
self.model_device = next(model.parameters()).device
else:
self.model_device = torch.device("cpu")

self.device = device or self.model_device
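
The device resolution added here can be exercised on its own. The sketch below mirrors the new `set_devices` logic in a standalone function; `resolve_devices` is a hypothetical name used only for illustration and is not part of the commit:

from typing import Optional, Tuple, Union

import torch


def resolve_devices(
    model: torch.nn.Module,
    device: Optional[Union[str, torch.device]] = None,
) -> Tuple[torch.device, Union[str, torch.device]]:
    # model_device: the device the model's parameters actually live on,
    # falling back to CPU for parameter-less models.
    param = next(model.parameters(), None)
    model_device = param.device if param is not None else torch.device("cpu")
    # device: the explicitly requested device, defaulting to model_device.
    return model_device, (device or model_device)


model = torch.nn.Linear(4, 2)                  # CPU model in this example
model_device, device = resolve_devices(model)  # -> (cpu, cpu)

Elsewhere in the diff, `self.model_device` is used to place input batches next to the model, while `self.device` is what gets handed to metrics and explainers.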
17 changes: 9 additions & 8 deletions quanda/toy_benchmarks/localization/class_detection.py
@@ -10,11 +10,10 @@
class ClassDetection(ToyBenchmark):
def __init__(
self,
device: Optional[Union[str, torch.device]] = None,
*args,
**kwargs,
):
super().__init__(device=device)
super().__init__()

self.model: torch.nn.Module
self.train_dataset: torch.utils.data.Dataset
@@ -32,11 +31,11 @@ def generate(
This method should generate all the benchmark components and persist them in the instance.
"""

obj = cls(device=device)
obj = cls()

obj.model = model
obj.train_dataset = train_dataset

obj.set_devices(model, device)
return obj

@property
@@ -67,10 +66,13 @@ def assemble(
"""
This method should assemble the benchmark components from arguments and persist them in the instance.
"""
obj = cls(device=device)

obj = cls()
obj.model = model
obj.train_dataset = train_dataset

obj.set_devices(model, device)

return obj

def save(self, path: str, *args, **kwargs):
@@ -88,7 +90,6 @@ def evaluate(
cache_dir: str = "./cache",
model_id: str = "default_model_id",
batch_size: int = 8,
device: Optional[Union[str, torch.device]] = None,
*args,
**kwargs,
):
@@ -99,15 +100,15 @@

expl_dl = torch.utils.data.DataLoader(expl_dataset, batch_size=batch_size)

metric = ClassDetectionMetric(model=self.model, train_dataset=self.train_dataset, device="cpu")
metric = ClassDetectionMetric(model=self.model, train_dataset=self.train_dataset, device=self.device)

pbar = tqdm(expl_dl)
n_batches = len(expl_dl)

for i, (input, labels) in enumerate(pbar):
pbar.set_description("Metric evaluation, batch %d/%d" % (i + 1, n_batches))

input, labels = input.to(device), labels.to(device)
input, labels = input.to(self.model_device), labels.to(self.model_device)

if use_predictions:
with torch.no_grad():
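
The evaluation loop changed above follows a pattern that repeats in the remaining benchmarks: batches are moved to `self.model_device`, and when `use_predictions` is set, explanation targets come from the model's own predictions. A self-contained sketch of that loop with a toy model and dataset (illustrative only, not the benchmark's full `evaluate`):

import torch

model = torch.nn.Linear(8, 2)
model_device = next(model.parameters()).device
use_predictions = True

loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.randn(16, 8), torch.randint(0, 2, (16,))),
    batch_size=4,
)

for inputs, labels in loader:
    # Batches live on the same device as the model.
    inputs, labels = inputs.to(model_device), labels.to(model_device)
    if use_predictions:
        with torch.no_grad():
            targets = model(inputs).argmax(dim=-1)  # explain the model's predictions
    else:
        targets = labels                            # explain the ground-truth labels
    # targets would then be passed on to the explainer / metric update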
21 changes: 11 additions & 10 deletions quanda/toy_benchmarks/localization/mislabeling_detection.py
@@ -18,11 +18,10 @@
class MislabelingDetection(ToyBenchmark):
def __init__(
self,
device: Optional[Union[str, torch.device]] = None,
*args,
**kwargs,
):
super().__init__(device=device)
super().__init__()

self.model: Union[torch.nn.Module, L.LightningModule]
self.train_dataset: torch.utils.data.Dataset
@@ -59,8 +58,8 @@ def generate(
This method should generate all the benchmark components and persist them in the instance.
"""

obj = cls(device=device)

obj = cls()
obj.set_devices(model, device)
obj._generate(
model=model,
train_dataset=train_dataset,
@@ -199,7 +198,7 @@ def assemble(
"""
This method should assemble the benchmark components from arguments and persist them in the instance.
"""
obj = cls(device=device)
obj = cls()
obj.model = model
obj.train_dataset = train_dataset
obj.p = p
@@ -220,6 +219,9 @@ def assemble(

obj.poisoned_train_dl = torch.utils.data.DataLoader(obj.poisoned_dataset, batch_size=batch_size)
obj.original_train_dl = torch.utils.data.DataLoader(obj.train_dataset, batch_size=batch_size)

obj.set_devices(model, device)

return obj

def save(self, path: str, *args, **kwargs):
@@ -235,12 +237,11 @@ def evaluate(
expl_kwargs: Optional[dict] = None,
use_predictions: bool = False,
batch_size: int = 8,
device: Optional[Union[str, torch.device]] = None,
*args,
**kwargs,
):
expl_kwargs = expl_kwargs or {}
explainer = explainer_cls(model=self.model, train_dataset=self.train_dataset, device=device, **expl_kwargs)
explainer = explainer_cls(model=self.model, train_dataset=self.train_dataset, device=self.device, **expl_kwargs)

poisoned_expl_ds = LabelFlippingDataset(
dataset=expl_dataset, dataset_transform=self.dataset_transform, n_classes=self.n_classes, p=0.0
@@ -251,7 +252,7 @@
model=self.model,
train_dataset=self.poisoned_dataset,
poisoned_indices=self.poisoned_indices,
device=device,
device=self.device,
aggregator_cls=self.global_method,
)

@@ -261,7 +262,7 @@
for i, (inputs, labels) in enumerate(pbar):
pbar.set_description("Metric evaluation, batch %d/%d" % (i + 1, n_batches))

inputs, labels = inputs.to(device), labels.to(device)
inputs, labels = inputs.to(self.model_device), labels.to(self.model_device)
if use_predictions:
with torch.no_grad():
targets = self.model(inputs).argmax(dim=-1)
@@ -274,7 +275,7 @@
model=self.model,
train_dataset=self.poisoned_dataset,
poisoned_indices=self.poisoned_indices,
device=device,
device=self.device,
explainer_cls=explainer_cls,
expl_kwargs=expl_kwargs,
)
17 changes: 9 additions & 8 deletions quanda/toy_benchmarks/localization/subclass_detection.py
@@ -17,11 +17,10 @@
class SubclassDetection(ToyBenchmark):
def __init__(
self,
device: Optional[Union[str, torch.device]] = None,
*args,
**kwargs,
):
super().__init__(device=device)
super().__init__()

self.model: Union[torch.nn.Module, L.LightningModule]
self.group_model: Union[torch.nn.Module, L.LightningModule]
@@ -56,8 +55,8 @@ def generate(
This method should generate all the benchmark components and persist them in the instance.
"""

obj = cls(device=device)

obj = cls()
obj.set_devices(model, device)
obj.model = model
obj._generate(
trainer=trainer,
@@ -181,7 +180,7 @@ def assemble(
"""
This method should assemble the benchmark components from arguments and persist them in the instance.
"""
obj = cls(device=device)
obj = cls()
obj.group_model = group_model
obj.train_dataset = train_dataset
obj.class_to_group = class_to_group
@@ -198,6 +197,9 @@
)
obj.grouped_train_dl = torch.utils.data.DataLoader(obj.grouped_dataset, batch_size=batch_size)
obj.original_train_dl = torch.utils.data.DataLoader(obj.train_dataset, batch_size=batch_size)

obj.set_devices(group_model, device)

return obj

def save(self, path: str, *args, **kwargs):
@@ -215,7 +217,6 @@ def evaluate(
cache_dir: str = "./cache",
model_id: str = "default_model_id",
batch_size: int = 8,
device: Optional[Union[str, torch.device]] = None,
*args,
**kwargs,
):
@@ -226,15 +227,15 @@

expl_dl = torch.utils.data.DataLoader(expl_dataset, batch_size=batch_size)

metric = ClassDetectionMetric(model=self.group_model, train_dataset=self.train_dataset, device=device)
metric = ClassDetectionMetric(model=self.group_model, train_dataset=self.train_dataset, device=self.device)

pbar = tqdm(expl_dl)
n_batches = len(expl_dl)

for i, (inputs, labels) in enumerate(pbar):
pbar.set_description("Metric evaluation, batch %d/%d" % (i + 1, n_batches))

inputs, labels = inputs.to(device), labels.to(device)
inputs, labels = inputs.to(self.model_device), labels.to(self.model_device)
grouped_labels = torch.tensor([self.class_to_group[i.item()] for i in labels], device=labels.device)
if use_predictions:
with torch.no_grad():
16 changes: 8 additions & 8 deletions quanda/toy_benchmarks/randomization/model_randomization.py
@@ -13,11 +13,10 @@
class ModelRandomization(ToyBenchmark):
def __init__(
self,
device: Optional[Union[str, torch.device]] = None,
*args,
**kwargs,
):
super().__init__(device=device)
super().__init__()

self.model: torch.nn.Module
self.train_dataset: torch.utils.data.Dataset
@@ -35,8 +34,8 @@ def generate(
This method should generate all the benchmark components and persist them in the instance.
"""

obj = cls(device=device)

obj = cls()
obj.set_devices(model, device)
obj.model = model
obj.train_dataset = train_dataset

@@ -70,10 +69,12 @@ def assemble(
"""
This method should assemble the benchmark components from arguments and persist them in the instance.
"""
obj = cls(device=device)
obj = cls()
obj.model = model
obj.train_dataset = train_dataset

obj.set_devices(model, device)

return obj

def save(self, path: str, *args, **kwargs):
@@ -93,7 +94,6 @@ def evaluate(
cache_dir: str = "./cache",
model_id: str = "default_model_id",
batch_size: int = 8,
device: Optional[Union[str, torch.device]] = None,
*args,
**kwargs,
):
@@ -109,15 +109,15 @@
seed=seed,
model_id=model_id,
cache_dir=cache_dir,
device=device,
device=self.device,
)
pbar = tqdm(expl_dl)
n_batches = len(expl_dl)

for i, (input, labels) in enumerate(pbar):
pbar.set_description("Metric evaluation, batch %d/%d" % (i + 1, n_batches))

input, labels = input.to(device), labels.to(device)
input, labels = input.to(self.model_device), labels.to(self.model_device)

if use_predictions:
with torch.no_grad():
