Skip to content

Commit

Permalink
Merge pull request #105 from dilyabareeva/minor-fixes
Browse files Browse the repository at this point in the history
Making model_id and cache_dir optional
  • Loading branch information
gumityolcu authored Aug 12, 2024
2 parents 5b92924 + 821f4e4 commit 3c7baad
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 45 deletions.
2 changes: 1 addition & 1 deletion quanda/explainers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ class BaseExplainer(ABC):
def __init__(
self,
model: torch.nn.Module,
model_id: str,
cache_dir: Optional[str],
train_dataset: torch.utils.data.Dataset,
device: Union[str, torch.device],
model_id: Optional[str] = None,
**kwargs,
):
self.model = model
Expand Down
8 changes: 4 additions & 4 deletions quanda/explainers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ def _init_explainer(explainer_cls, model, model_id, cache_dir, train_dataset, de
def explain_fn_from_explainer(
explainer_cls: type,
model: torch.nn.Module,
model_id: str,
cache_dir: Optional[str],
test_tensor: torch.Tensor,
train_dataset: torch.utils.data.Dataset,
device: Union[str, torch.device],
targets: Optional[Union[List[int], torch.Tensor]] = None,
cache_dir: Optional[str] = None,
model_id: Optional[str] = None,
**kwargs: Any,
) -> torch.Tensor:
explainer = _init_explainer(
Expand All @@ -42,11 +42,11 @@ def explain_fn_from_explainer(
def self_influence_fn_from_explainer(
explainer_cls: type,
model: torch.nn.Module,
model_id: str,
cache_dir: Optional[str],
train_dataset: torch.utils.data.Dataset,
device: Union[str, torch.device],
self_influence_kwargs: dict,
cache_dir: Optional[str] = None,
model_id: Optional[str] = None,
**kwargs: Any,
) -> torch.Tensor:
explainer = _init_explainer(
Expand Down
38 changes: 16 additions & 22 deletions quanda/explainers/wrappers/captum_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ class CaptumInfluence(BaseExplainer, ABC):
def __init__(
self,
model: torch.nn.Module,
model_id: str,
cache_dir: Optional[str],
train_dataset: torch.utils.data.Dataset,
device: Union[str, torch.device],
explainer_cls: type,
explain_kwargs: Any,
model_id: Optional[str] = None,
cache_dir: Optional[str] = None,
):
super().__init__(
model=model,
Expand Down Expand Up @@ -200,10 +200,8 @@ class CaptumArnoldi(CaptumInfluence):
def __init__(
self,
model: torch.nn.Module,
model_id: str, # TODO Make optional
train_dataset: torch.utils.data.Dataset,
checkpoint: str,
cache_dir: str, # TODO Make optional
loss_fn: Union[torch.nn.Module, Callable] = torch.nn.CrossEntropyLoss(),
checkpoints_load_func: Optional[Callable[..., Any]] = None,
layers: Optional[List[str]] = None,
Expand All @@ -220,6 +218,8 @@ def __init__(
projection_on_cpu: bool = True,
show_progress: bool = False,
device: Union[str, torch.device] = "cpu", # TODO Check if gpu works
model_id: Optional[str] = None,
cache_dir: Optional[str] = None,
**explainer_kwargs: Any,
):
if checkpoints_load_func is None:
Expand Down Expand Up @@ -287,13 +287,12 @@ def self_influence(self, **kwargs: Any) -> torch.Tensor:

def captum_arnoldi_explain(
model: torch.nn.Module,
model_id: str,
cache_dir: str,
test_tensor: torch.Tensor,
train_dataset: torch.utils.data.Dataset,
loss_fn: Union[torch.nn.Module, Callable],
device: Union[str, torch.device],
explanation_targets: Optional[Union[List[int], torch.Tensor]] = None,
model_id: Optional[str] = None,
cache_dir: Optional[str] = None,
**kwargs: Any,
) -> torch.Tensor:
return explain_fn_from_explainer(
Expand All @@ -304,20 +303,18 @@ def captum_arnoldi_explain(
test_tensor=test_tensor,
targets=explanation_targets,
train_dataset=train_dataset,
loss_fn=loss_fn,
device=device,
**kwargs,
)


def captum_arnoldi_self_influence(
model: torch.nn.Module,
model_id: str,
cache_dir: str,
train_dataset: torch.utils.data.Dataset,
loss_fn: Union[torch.nn.Module, Callable],
device: Union[str, torch.device],
inputs_dataset: Optional[Union[Tuple[Any, ...], torch.utils.data.DataLoader]] = None,
model_id: Optional[str] = None,
cache_dir: Optional[str] = None,
**kwargs: Any,
) -> torch.Tensor:
self_influence_kwargs = {
Expand All @@ -329,7 +326,6 @@ def captum_arnoldi_self_influence(
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,
loss_fn=loss_fn,
device=device,
self_influence_kwargs=self_influence_kwargs,
**kwargs,
Expand All @@ -340,16 +336,17 @@ class CaptumTracInCP(CaptumInfluence):
def __init__(
self,
model: torch.nn.Module,
model_id: str,
train_dataset: torch.utils.data.Dataset,
checkpoints: Union[str, List[str], Iterator],
cache_dir: Optional[str],
checkpoints_load_func: Optional[Callable[..., Any]] = None,
layers: Optional[List[str]] = None,
loss_fn: Optional[Union[torch.nn.Module, Callable]] = None,
batch_size: int = 1,
test_loss_fn: Optional[Union[torch.nn.Module, Callable]] = None,
sample_wise_grads_per_batch: bool = False,
device: Union[str, torch.device] = "cpu",
model_id: Optional[str] = None,
cache_dir: Optional[str] = None,
**explainer_kwargs: Any,
):
if checkpoints_load_func is None:
Expand All @@ -369,6 +366,7 @@ def __init__(
"train_dataset": train_dataset,
"checkpoints": checkpoints,
"checkpoints_load_func": checkpoints_load_func,
"layers": layers,
"loss_fn": loss_fn,
"batch_size": batch_size,
"test_loss_fn": test_loss_fn,
Expand Down Expand Up @@ -410,13 +408,12 @@ def self_influence(self, **kwargs: Any) -> torch.Tensor:

def captum_tracincp_explain(
model: torch.nn.Module,
model_id: str,
cache_dir: Optional[str],
test_tensor: torch.Tensor,
train_dataset: torch.utils.data.Dataset,
checkpoints: Union[str, List[str], Iterator],
device: Union[str, torch.device],
explanation_targets: Optional[Union[List[int], torch.Tensor]] = None,
model_id: Optional[str] = None,
cache_dir: Optional[str] = None,
**kwargs: Any,
) -> torch.Tensor:
return explain_fn_from_explainer(
Expand All @@ -427,21 +424,19 @@ def captum_tracincp_explain(
test_tensor=test_tensor,
targets=explanation_targets,
train_dataset=train_dataset,
checkpoints=checkpoints,
device=device,
**kwargs,
)


def captum_tracincp_self_influence(
model: torch.nn.Module,
model_id: str,
cache_dir: Optional[str],
train_dataset: torch.utils.data.Dataset,
checkpoints: Union[str, List[str], Iterator],
device: Union[str, torch.device],
inputs: Optional[Union[Tuple[Any, ...], torch.utils.data.DataLoader]] = None,
outer_loop_by_checkpoints: bool = False,
model_id: Optional[str] = None,
cache_dir: Optional[str] = None,
**kwargs: Any,
) -> torch.Tensor:
self_influence_kwargs = {"inputs": inputs, "outer_loop_by_checkpoints": outer_loop_by_checkpoints}
Expand All @@ -451,7 +446,6 @@ def captum_tracincp_self_influence(
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,
checkpoints=checkpoints,
device=device,
self_influence_kwargs=self_influence_kwargs,
**kwargs,
Expand Down
22 changes: 4 additions & 18 deletions tests/explainers/wrappers/test_captum_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def test_captum_influence_explain_functional(
],
)
def test_captum_arnoldi(
test_id, model, dataset, test_tensor, test_labels, method_kwargs_simple, method_kwargs_complex, request, tmp_path
test_id, model, dataset, test_tensor, test_labels, method_kwargs_simple, method_kwargs_complex, request
):
model = request.getfixturevalue(model)
dataset = request.getfixturevalue(dataset)
Expand All @@ -188,8 +188,6 @@ def test_captum_arnoldi(

explainer_simple = CaptumArnoldi(
model=model,
model_id="test_id",
cache_dir=str(tmp_path),
train_dataset=dataset,
checkpoint="tests/assets/mnist",
device="cpu",
Expand All @@ -211,8 +209,6 @@ def test_captum_arnoldi(

explainer_complex = CaptumArnoldi(
model=model,
model_id="test_id",
cache_dir=str(tmp_path),
train_dataset=dataset,
checkpoint="tests/assets/mnist",
device="cpu",
Expand Down Expand Up @@ -314,8 +310,6 @@ def test_captum_arnoldi_explain_functional(

explanations_complex = captum_arnoldi_explain(
model=model,
model_id="test_id",
cache_dir=str(tmp_path),
test_tensor=test_tensor,
train_dataset=dataset,
explanation_targets=test_labels,
Expand Down Expand Up @@ -351,7 +345,7 @@ def test_captum_arnoldi_explain_functional(
),
],
)
def test_captum_arnoldi_self_influence(test_id, model, dataset, method_kwargs, request, tmp_path):
def test_captum_arnoldi_self_influence(test_id, model, dataset, method_kwargs, request):
model = request.getfixturevalue(model)
dataset = request.getfixturevalue(dataset)

Expand All @@ -367,8 +361,6 @@ def test_captum_arnoldi_self_influence(test_id, model, dataset, method_kwargs, r

explanations = captum_arnoldi_self_influence(
model=model,
model_id="test_id",
cache_dir=str(tmp_path),
train_dataset=dataset,
device="cpu",
checkpoint="tests/assets/mnist",
Expand Down Expand Up @@ -396,7 +388,7 @@ def test_captum_arnoldi_self_influence(test_id, model, dataset, method_kwargs, r
),
],
)
def test_captum_tracincp(test_id, model, dataset, test_tensor, checkpoints, method_kwargs, request, tmp_path):
def test_captum_tracincp(test_id, model, dataset, test_tensor, checkpoints, method_kwargs, request):
model = request.getfixturevalue(model)
dataset = request.getfixturevalue(dataset)
test_tensor = request.getfixturevalue(test_tensor)
Expand All @@ -413,8 +405,6 @@ def test_captum_tracincp(test_id, model, dataset, test_tensor, checkpoints, meth

explainer = CaptumTracInCP(
model=model,
model_id="test_id",
cache_dir=str(tmp_path),
train_dataset=dataset,
checkpoints=checkpoints,
checkpoints_load_func=get_load_state_dict_func("cpu"),
Expand Down Expand Up @@ -466,8 +456,6 @@ def test_captum_tracincp_explain_functional(

explanations_simple = captum_tracincp_explain(
model=model,
model_id="test_id",
cache_dir=str(tmp_path),
train_dataset=dataset,
checkpoints=checkpoints,
test_tensor=test_tensor,
Expand Down Expand Up @@ -519,7 +507,7 @@ def test_captum_tracincp_explain_functional(
),
],
)
def test_captum_tracincp_self_influence(test_id, model, dataset, checkpoints, method_kwargs, request, tmp_path):
def test_captum_tracincp_self_influence(test_id, model, dataset, checkpoints, method_kwargs, request):
model = request.getfixturevalue(model)
dataset = request.getfixturevalue(dataset)
checkpoints = request.getfixturevalue(checkpoints)
Expand All @@ -535,8 +523,6 @@ def test_captum_tracincp_self_influence(test_id, model, dataset, checkpoints, me

explanations = captum_tracincp_self_influence(
model=model,
model_id="test_id",
cache_dir=str(tmp_path),
train_dataset=dataset,
checkpoints=checkpoints,
checkpoints_load_func=get_load_state_dict_func("cpu"),
Expand Down

0 comments on commit 3c7baad

Please sign in to comment.