Commit

remove device completely

dilyabareeva committed Aug 20, 2024
1 parent c6b712b commit 82745b5
Showing 19 changed files with 18 additions and 60 deletions.
2 changes: 1 addition & 1 deletion quanda/explainers/base.py
@@ -13,10 +13,10 @@ def __init__(
model: torch.nn.Module,
cache_dir: Optional[str],
train_dataset: torch.utils.data.Dataset,

model_id: Optional[str] = None,
**kwargs,
):
self.device: Union[str, torch.device]
self.model = model

# if model has device attribute, use it, otherwise use the default device
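The removed constructor arguments hinge on the pattern described in the comment above: instead of receiving a device explicitly, the explainer derives it from the model. A minimal sketch of that resolution logic, assuming a hypothetical helper name infer_device and a CPU fallback (illustration only, not quanda's actual code):

import torch

def infer_device(model: torch.nn.Module) -> torch.device:
    # Prefer an explicit `device` attribute if the model exposes one.
    if hasattr(model, "device"):
        return torch.device(model.device)
    # Otherwise fall back to the device of the model's first parameter;
    # a parameter-less model defaults to CPU.
    try:
        return next(model.parameters()).device
    except StopIteration:
        return torch.device("cpu")

Resolving the device once from the model is what lets every downstream signature in this diff drop its device parameter.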
3 changes: 0 additions & 3 deletions quanda/explainers/utils.py
@@ -9,7 +9,6 @@ def _init_explainer(explainer_cls, model, model_id, cache_dir, train_dataset, **
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,

**kwargs,
)
return explainer
@@ -31,7 +30,6 @@ def explain_fn_from_explainer(
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,

**kwargs,
)

@@ -53,7 +51,6 @@ def self_influence_fn_from_explainer(
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,

**kwargs,
)

4 changes: 0 additions & 4 deletions quanda/explainers/wrappers/captum_influence.py
@@ -30,14 +30,12 @@ def __init__(
explain_kwargs: Any,
model_id: Optional[str] = None,
cache_dir: Optional[str] = None,

):
super().__init__(
model=model,
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,

)
self.explainer_cls = explainer_cls
self.explain_kwargs = explain_kwargs
@@ -298,7 +296,6 @@ def captum_arnoldi_explain(
test_tensor=test_tensor,
targets=explanation_targets,
train_dataset=train_dataset,

**kwargs,
)

@@ -321,7 +318,6 @@ def captum_arnoldi_self_influence(
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,

self_influence_kwargs=self_influence_kwargs,
**kwargs,
)
3 changes: 1 addition & 2 deletions quanda/metrics/base.py
@@ -34,7 +34,7 @@ def __init__(
**kwargs: Any
Additional keyword arguments.
"""

self.device: Union[str, torch.device]
self.model: torch.nn.Module = model
self.train_dataset: torch.utils.data.Dataset = train_dataset

@@ -197,7 +197,6 @@ def __init__(
global_method: Union[str, type] = "self-influence",
explainer_cls: Optional[type] = None,
expl_kwargs: Optional[dict] = None,

*args,
**kwargs,
):
3 changes: 1 addition & 2 deletions quanda/metrics/localization/class_detection.py
@@ -1,4 +1,4 @@
from typing import List, Optional, Union
from typing import List

import torch

@@ -10,7 +10,6 @@ def __init__(
self,
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,

*args,
**kwargs,
):
6 changes: 0 additions & 6 deletions quanda/metrics/localization/mislabeling_detection.py
@@ -14,7 +14,6 @@ def __init__(
global_method: Union[str, type] = "self-influence",
explainer_cls: Optional[type] = None,
expl_kwargs: Optional[dict] = None,

*args: Any,
**kwargs: Any,
):
@@ -25,7 +24,6 @@ def __init__(
explainer_cls=explainer_cls,
expl_kwargs=expl_kwargs,
model_id="test",

)
self.poisoned_indices = poisoned_indices

@@ -37,7 +35,6 @@ def self_influence_based(
explainer_cls: type,
poisoned_indices: List[int],
expl_kwargs: Optional[dict] = None,

*args: Any,
**kwargs: Any,
):
@@ -48,7 +45,6 @@ def self_influence_based(
global_method="self-influence",
explainer_cls=explainer_cls,
expl_kwargs=expl_kwargs,

)

@classmethod
@@ -58,7 +54,6 @@ def aggr_based(
train_dataset: torch.utils.data.Dataset,
poisoned_indices: List[int],
aggregator_cls: Union[str, type],

*args,
**kwargs,
):
@@ -67,7 +62,6 @@ def aggr_based(
global_method=aggregator_cls,
poisoned_indices=poisoned_indices,
train_dataset=train_dataset,

)

def update(
3 changes: 0 additions & 3 deletions quanda/metrics/localization/subclass_detection.py
@@ -1,5 +1,3 @@
from typing import Optional, Union

import torch

from quanda.metrics.localization import ClassDetectionMetric
@@ -11,7 +9,6 @@ def __init__(
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,
subclass_labels: torch.Tensor,

*args,
**kwargs,
):
2 changes: 0 additions & 2 deletions quanda/metrics/randomization/model_randomization.py
@@ -23,14 +23,12 @@ def __init__(
seed: int = 42,
model_id: str = "0",
cache_dir: str = "./cache",

*args,
**kwargs,
):
super().__init__(
model=model,
train_dataset=train_dataset,

)
self.model = model
self.train_dataset = train_dataset
6 changes: 0 additions & 6 deletions quanda/metrics/unnamed/dataset_cleaning.py
@@ -33,7 +33,6 @@ def __init__(
expl_kwargs: Optional[dict] = None,
model_id: str = "0",
cache_dir: str = "./cache",

*args,
**kwargs,
):
@@ -45,7 +44,6 @@ def __init__(
global_method=global_method,
explainer_cls=explainer_cls,
expl_kwargs={**expl_kwargs, "model_id": model_id, "cache_dir": cache_dir},

)
self.top_k = min(top_k, self.dataset_length - 1)
self.trainer = trainer
@@ -64,7 +62,6 @@ def self_influence_based(
expl_kwargs: Optional[dict] = None,
top_k: int = 50,
trainer_fit_kwargs: Optional[dict] = None,

*args,
**kwargs,
):
@@ -78,7 +75,6 @@ def self_influence_based(
top_k=top_k,
explainer_cls=explainer_cls,
expl_kwargs=expl_kwargs,

)

@classmethod
@@ -91,7 +87,6 @@ def aggr_based(
init_model: Optional[torch.nn.Module] = None,
top_k: int = 50,
trainer_fit_kwargs: Optional[dict] = None,

*args,
**kwargs,
):
@@ -103,7 +98,6 @@ def aggr_based(
trainer_fit_kwargs=trainer_fit_kwargs,
global_method=aggregator_cls,
top_k=top_k,

)

def update(
3 changes: 0 additions & 3 deletions quanda/metrics/unnamed/top_k_overlap.py
@@ -1,5 +1,3 @@
from typing import Optional, Union

import torch

from quanda.metrics.base import Metric
@@ -11,7 +9,6 @@ def __init__(
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,
top_k: int = 1,

*args,
**kwargs,
):
1 change: 1 addition & 0 deletions quanda/toy_benchmarks/base.py
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import Optional, Union

import torch


6 changes: 2 additions & 4 deletions quanda/toy_benchmarks/localization/class_detection.py
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional

import torch
from tqdm import tqdm
@@ -23,7 +23,6 @@ def generate(
cls,
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,

*args,
**kwargs,
):
@@ -46,7 +45,7 @@ def bench_state(self):
}

@classmethod
def load(cls, path: str, batch_size: int = 8, *args, **kwargs):
"""
This method should load the benchmark components from a file and persist them in the instance.
"""
@@ -59,7 +58,6 @@ def assemble(
cls,
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,

*args,
**kwargs,
):
5 changes: 1 addition & 4 deletions quanda/toy_benchmarks/localization/mislabeling_detection.py
@@ -50,7 +50,6 @@ def generate(
trainer_fit_kwargs: Optional[dict] = None,
seed: int = 27,
batch_size: int = 8,

*args,
**kwargs,
):
@@ -160,7 +159,7 @@ def bench_state(self):
}

@classmethod
def load(cls, path: str, batch_size: int = 8, *args, **kwargs):
"""
This method should load the benchmark components from a file and persist them in the instance.
"""
@@ -176,7 +175,6 @@ def load(cls, path: str, batch_size: int = 8, *args, **kwargs):
p=bench_state["p"],
global_method=bench_state["global_method"],
batch_size=batch_size,

)

@classmethod
@@ -191,7 +189,6 @@ def assemble(
p: float = 0.3, # TODO: type specification
global_method: Union[str, type] = "self-influence",
batch_size: int = 8,

*args,
**kwargs,
):
5 changes: 1 addition & 4 deletions quanda/toy_benchmarks/localization/subclass_detection.py
@@ -47,7 +47,6 @@ def generate(
trainer_fit_kwargs: Optional[dict] = None,
seed: int = 27,
batch_size: int = 8,

*args,
**kwargs,
):
@@ -146,7 +145,7 @@ def _generate(
raise ValueError("Trainer should be a Lightning Trainer or a BaseTrainer")

@classmethod
def load(cls, path: str, batch_size: int = 8, *args, **kwargs):
"""
This method should load the benchmark components from a file and persist them in the instance.
"""
@@ -160,7 +159,6 @@ def load(cls, path: str, batch_size: int = 8, *args, **kwargs):
class_to_group=bench_state["class_to_group"],
dataset_transform=bench_state["dataset_transform"],
batch_size=batch_size,

)

@classmethod
@@ -173,7 +171,6 @@ def assemble(
class_to_group: Dict[int, int], # TODO: type specification
dataset_transform: Optional[Callable] = None,
batch_size: int = 8,

*args,
**kwargs,
):
4 changes: 1 addition & 3 deletions quanda/toy_benchmarks/randomization/model_randomization.py
@@ -26,7 +26,6 @@ def generate(
cls,
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,

*args,
**kwargs,
):
@@ -49,7 +48,7 @@ def bench_state(self):
}

@classmethod
def load(cls, path: str, batch_size: int = 8, *args, **kwargs):
"""
This method should load the benchmark components from a file and persist them in the instance.
"""
@@ -62,7 +61,6 @@ def assemble(
cls,
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,

*args,
**kwargs,
):
4 changes: 1 addition & 3 deletions quanda/toy_benchmarks/unnamed/dataset_cleaning.py
@@ -26,7 +26,6 @@ def generate(
cls,
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,

*args,
**kwargs,
):
@@ -49,7 +48,7 @@ def bench_state(self):
}

@classmethod
def load(cls, path: str, batch_size: int = 8, *args, **kwargs):
"""
This method should load the benchmark components from a file and persist them in the instance.
"""
@@ -61,7 +60,6 @@ def assemble(
cls,
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,

*args,
**kwargs,
):