From 821f4e494622694f5ec84d632e6b8f4ca4dafb3d Mon Sep 17 00:00:00 2001
From: aski02
Date: Sun, 11 Aug 2024 23:41:12 +0200
Subject: [PATCH] Making model_id and cache_dir optional

---
 quanda/explainers/base.py                     |  2 +-
 quanda/explainers/utils.py                    |  8 ++--
 .../explainers/wrappers/captum_influence.py   | 38 ++++++++-----------
 .../wrappers/test_captum_influence.py         | 22 ++---------
 4 files changed, 25 insertions(+), 45 deletions(-)

diff --git a/quanda/explainers/base.py b/quanda/explainers/base.py
index ddc3c888..eabc776a 100644
--- a/quanda/explainers/base.py
+++ b/quanda/explainers/base.py
@@ -11,10 +11,10 @@ class BaseExplainer(ABC):
     def __init__(
         self,
         model: torch.nn.Module,
-        model_id: str,
         cache_dir: Optional[str],
         train_dataset: torch.utils.data.Dataset,
         device: Union[str, torch.device],
+        model_id: Optional[str] = None,
         **kwargs,
     ):
         self.model = model
diff --git a/quanda/explainers/utils.py b/quanda/explainers/utils.py
index 135e8152..14112aa9 100644
--- a/quanda/explainers/utils.py
+++ b/quanda/explainers/utils.py
@@ -18,12 +18,12 @@ def _init_explainer(explainer_cls, model, model_id, cache_dir, train_dataset, de
 def explain_fn_from_explainer(
     explainer_cls: type,
     model: torch.nn.Module,
-    model_id: str,
-    cache_dir: Optional[str],
     test_tensor: torch.Tensor,
     train_dataset: torch.utils.data.Dataset,
     device: Union[str, torch.device],
     targets: Optional[Union[List[int], torch.Tensor]] = None,
+    cache_dir: Optional[str] = None,
+    model_id: Optional[str] = None,
     **kwargs: Any,
 ) -> torch.Tensor:
     explainer = _init_explainer(
@@ -42,11 +42,11 @@ def explain_fn_from_explainer(
 def self_influence_fn_from_explainer(
     explainer_cls: type,
     model: torch.nn.Module,
-    model_id: str,
-    cache_dir: Optional[str],
     train_dataset: torch.utils.data.Dataset,
     device: Union[str, torch.device],
     self_influence_kwargs: dict,
+    cache_dir: Optional[str] = None,
+    model_id: Optional[str] = None,
     **kwargs: Any,
 ) -> torch.Tensor:
     explainer = _init_explainer(
diff --git a/quanda/explainers/wrappers/captum_influence.py b/quanda/explainers/wrappers/captum_influence.py
index e1bdda55..d7223181 100644
--- a/quanda/explainers/wrappers/captum_influence.py
+++ b/quanda/explainers/wrappers/captum_influence.py
@@ -25,12 +25,12 @@ class CaptumInfluence(BaseExplainer, ABC):
     def __init__(
         self,
         model: torch.nn.Module,
-        model_id: str,
-        cache_dir: Optional[str],
         train_dataset: torch.utils.data.Dataset,
         device: Union[str, torch.device],
         explainer_cls: type,
         explain_kwargs: Any,
+        model_id: Optional[str] = None,
+        cache_dir: Optional[str] = None,
     ):
         super().__init__(
             model=model,
@@ -200,10 +200,8 @@ class CaptumArnoldi(CaptumInfluence):
     def __init__(
         self,
         model: torch.nn.Module,
-        model_id: str,  # TODO Make optional
         train_dataset: torch.utils.data.Dataset,
         checkpoint: str,
-        cache_dir: str,  # TODO Make optional
         loss_fn: Union[torch.nn.Module, Callable] = torch.nn.CrossEntropyLoss(),
         checkpoints_load_func: Optional[Callable[..., Any]] = None,
         layers: Optional[List[str]] = None,
@@ -220,6 +218,8 @@ def __init__(
         projection_on_cpu: bool = True,
         show_progress: bool = False,
         device: Union[str, torch.device] = "cpu",  # TODO Check if gpu works
+        model_id: Optional[str] = None,
+        cache_dir: Optional[str] = None,
         **explainer_kwargs: Any,
     ):
         if checkpoints_load_func is None:
@@ -287,13 +287,12 @@ def self_influence(self, **kwargs: Any) -> torch.Tensor:
 
 def captum_arnoldi_explain(
     model: torch.nn.Module,
-    model_id: str,
-    cache_dir: str,
     test_tensor: torch.Tensor,
     train_dataset: torch.utils.data.Dataset,
-    loss_fn: Union[torch.nn.Module, Callable],
     device: Union[str, torch.device],
     explanation_targets: Optional[Union[List[int], torch.Tensor]] = None,
+    model_id: Optional[str] = None,
+    cache_dir: Optional[str] = None,
     **kwargs: Any,
 ) -> torch.Tensor:
     return explain_fn_from_explainer(
@@ -304,7 +303,6 @@ def captum_arnoldi_explain(
         test_tensor=test_tensor,
         targets=explanation_targets,
         train_dataset=train_dataset,
-        loss_fn=loss_fn,
         device=device,
         **kwargs,
     )
@@ -312,12 +310,11 @@ def captum_arnoldi_explain(
 
 def captum_arnoldi_self_influence(
     model: torch.nn.Module,
-    model_id: str,
-    cache_dir: str,
     train_dataset: torch.utils.data.Dataset,
-    loss_fn: Union[torch.nn.Module, Callable],
     device: Union[str, torch.device],
     inputs_dataset: Optional[Union[Tuple[Any, ...], torch.utils.data.DataLoader]] = None,
+    model_id: Optional[str] = None,
+    cache_dir: Optional[str] = None,
     **kwargs: Any,
 ) -> torch.Tensor:
     self_influence_kwargs = {
@@ -329,7 +326,6 @@ def captum_arnoldi_self_influence(
         model_id=model_id,
         cache_dir=cache_dir,
         train_dataset=train_dataset,
-        loss_fn=loss_fn,
         device=device,
         self_influence_kwargs=self_influence_kwargs,
         **kwargs,
     )
@@ -340,16 +336,17 @@ class CaptumTracInCP(CaptumInfluence):
     def __init__(
         self,
         model: torch.nn.Module,
-        model_id: str,
         train_dataset: torch.utils.data.Dataset,
         checkpoints: Union[str, List[str], Iterator],
-        cache_dir: Optional[str],
         checkpoints_load_func: Optional[Callable[..., Any]] = None,
+        layers: Optional[List[str]] = None,
         loss_fn: Optional[Union[torch.nn.Module, Callable]] = None,
         batch_size: int = 1,
         test_loss_fn: Optional[Union[torch.nn.Module, Callable]] = None,
         sample_wise_grads_per_batch: bool = False,
         device: Union[str, torch.device] = "cpu",
+        model_id: Optional[str] = None,
+        cache_dir: Optional[str] = None,
         **explainer_kwargs: Any,
     ):
         if checkpoints_load_func is None:
@@ -369,6 +366,7 @@ def __init__(
             "train_dataset": train_dataset,
             "checkpoints": checkpoints,
             "checkpoints_load_func": checkpoints_load_func,
+            "layers": layers,
             "loss_fn": loss_fn,
             "batch_size": batch_size,
             "test_loss_fn": test_loss_fn,
@@ -410,13 +408,12 @@ def self_influence(self, **kwargs: Any) -> torch.Tensor:
 
 def captum_tracincp_explain(
     model: torch.nn.Module,
-    model_id: str,
-    cache_dir: Optional[str],
     test_tensor: torch.Tensor,
     train_dataset: torch.utils.data.Dataset,
-    checkpoints: Union[str, List[str], Iterator],
     device: Union[str, torch.device],
     explanation_targets: Optional[Union[List[int], torch.Tensor]] = None,
+    model_id: Optional[str] = None,
+    cache_dir: Optional[str] = None,
     **kwargs: Any,
 ) -> torch.Tensor:
     return explain_fn_from_explainer(
@@ -427,7 +424,6 @@ def captum_tracincp_explain(
         test_tensor=test_tensor,
         targets=explanation_targets,
         train_dataset=train_dataset,
-        checkpoints=checkpoints,
         device=device,
         **kwargs,
     )
@@ -435,13 +431,12 @@ def captum_tracincp_explain(
 
 def captum_tracincp_self_influence(
     model: torch.nn.Module,
-    model_id: str,
-    cache_dir: Optional[str],
    train_dataset: torch.utils.data.Dataset,
-    checkpoints: Union[str, List[str], Iterator],
     device: Union[str, torch.device],
     inputs: Optional[Union[Tuple[Any, ...], torch.utils.data.DataLoader]] = None,
     outer_loop_by_checkpoints: bool = False,
+    model_id: Optional[str] = None,
+    cache_dir: Optional[str] = None,
     **kwargs: Any,
 ) -> torch.Tensor:
     self_influence_kwargs = {"inputs": inputs, "outer_loop_by_checkpoints": outer_loop_by_checkpoints}
@@ -451,7 +446,6 @@ def captum_tracincp_self_influence(
         model_id=model_id,
         cache_dir=cache_dir,
         train_dataset=train_dataset,
-        checkpoints=checkpoints,
         device=device,
         self_influence_kwargs=self_influence_kwargs,
         **kwargs,
diff --git a/tests/explainers/wrappers/test_captum_influence.py b/tests/explainers/wrappers/test_captum_influence.py
index 79e70f53..549275d3 100644
--- a/tests/explainers/wrappers/test_captum_influence.py
+++ b/tests/explainers/wrappers/test_captum_influence.py
@@ -179,7 +179,7 @@ def test_captum_influence_explain_functional(
     ],
 )
 def test_captum_arnoldi(
-    test_id, model, dataset, test_tensor, test_labels, method_kwargs_simple, method_kwargs_complex, request, tmp_path
+    test_id, model, dataset, test_tensor, test_labels, method_kwargs_simple, method_kwargs_complex, request
 ):
     model = request.getfixturevalue(model)
     dataset = request.getfixturevalue(dataset)
@@ -188,8 +188,6 @@ def test_captum_arnoldi(
 
     explainer_simple = CaptumArnoldi(
         model=model,
-        model_id="test_id",
-        cache_dir=str(tmp_path),
         train_dataset=dataset,
         checkpoint="tests/assets/mnist",
         device="cpu",
@@ -211,8 +209,6 @@ def test_captum_arnoldi(
 
     explainer_complex = CaptumArnoldi(
         model=model,
-        model_id="test_id",
-        cache_dir=str(tmp_path),
         train_dataset=dataset,
         checkpoint="tests/assets/mnist",
         device="cpu",
@@ -314,8 +310,6 @@ def test_captum_arnoldi_explain_functional(
 
     explanations_complex = captum_arnoldi_explain(
         model=model,
-        model_id="test_id",
-        cache_dir=str(tmp_path),
         test_tensor=test_tensor,
         train_dataset=dataset,
         explanation_targets=test_labels,
@@ -351,7 +345,7 @@ def test_captum_arnoldi_explain_functional(
         ),
     ],
 )
-def test_captum_arnoldi_self_influence(test_id, model, dataset, method_kwargs, request, tmp_path):
+def test_captum_arnoldi_self_influence(test_id, model, dataset, method_kwargs, request):
     model = request.getfixturevalue(model)
     dataset = request.getfixturevalue(dataset)
 
@@ -367,8 +361,6 @@ def test_captum_arnoldi_self_influence(test_id, model, dataset, method_kwargs, r
 
     explanations = captum_arnoldi_self_influence(
         model=model,
-        model_id="test_id",
-        cache_dir=str(tmp_path),
         train_dataset=dataset,
         device="cpu",
         checkpoint="tests/assets/mnist",
@@ -396,7 +388,7 @@ def test_captum_arnoldi_self_influence(test_id, model, dataset, method_kwargs, r
         ),
     ],
 )
-def test_captum_tracincp(test_id, model, dataset, test_tensor, checkpoints, method_kwargs, request, tmp_path):
+def test_captum_tracincp(test_id, model, dataset, test_tensor, checkpoints, method_kwargs, request):
     model = request.getfixturevalue(model)
     dataset = request.getfixturevalue(dataset)
     test_tensor = request.getfixturevalue(test_tensor)
@@ -413,8 +405,6 @@ def test_captum_tracincp(test_id, model, dataset, test_tensor, checkpoints, meth
 
     explainer = CaptumTracInCP(
         model=model,
-        model_id="test_id",
-        cache_dir=str(tmp_path),
         train_dataset=dataset,
         checkpoints=checkpoints,
         checkpoints_load_func=get_load_state_dict_func("cpu"),
@@ -466,8 +456,6 @@ def test_captum_tracincp_explain_functional(
 
     explanations_simple = captum_tracincp_explain(
         model=model,
-        model_id="test_id",
-        cache_dir=str(tmp_path),
         train_dataset=dataset,
         checkpoints=checkpoints,
         test_tensor=test_tensor,
@@ -519,7 +507,7 @@ def test_captum_tracincp_explain_functional(
         ),
     ],
 )
-def test_captum_tracincp_self_influence(test_id, model, dataset, checkpoints, method_kwargs, request, tmp_path):
+def test_captum_tracincp_self_influence(test_id, model, dataset, checkpoints, method_kwargs, request):
     model = request.getfixturevalue(model)
     dataset = request.getfixturevalue(dataset)
     checkpoints = request.getfixturevalue(checkpoints)
@@ -535,8 +523,6 @@ def test_captum_tracincp_self_influence(test_id, model, dataset, checkpoints, me
 
     explanations = captum_tracincp_self_influence(
         model=model,
-        model_id="test_id",
-        cache_dir=str(tmp_path),
         train_dataset=dataset,
         checkpoints=checkpoints,
         checkpoints_load_func=get_load_state_dict_func("cpu"),
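
Usage sketch (illustrative only): with this change, model_id and cache_dir become trailing keyword arguments that default to None, so explainers can be constructed without them, and the functional wrappers forward any remaining keyword arguments (such as checkpoint) to the explainer class. The import path, the toy model, the random dataset, and the checkpoint handling below are assumptions made for demonstration, not part of this patch.

import torch
from quanda.explainers.wrappers import CaptumArnoldi, captum_arnoldi_explain  # assumed import path

# Assumed toy setup: a small model, a state-dict checkpoint on disk that matches it,
# and a tiny random dataset standing in for real training data.
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
torch.save(model.state_dict(), "checkpoint.pt")
train_dataset = torch.utils.data.TensorDataset(
    torch.randn(16, 1, 28, 28), torch.randint(0, 10, (16,))
)
test_tensor = torch.randn(2, 1, 28, 28)

# Previously required, model_id and cache_dir can now simply be omitted:
explainer = CaptumArnoldi(
    model=model,
    train_dataset=train_dataset,
    checkpoint="checkpoint.pt",
    device="cpu",
)

# ...or still be supplied as keywords when an identifier and a cache location are wanted:
explanations = captum_arnoldi_explain(
    model=model,
    test_tensor=test_tensor,
    train_dataset=train_dataset,
    explanation_targets=torch.tensor([0, 1]),
    device="cpu",
    model_id="toy_mlp",          # optional identifier (hypothetical value)
    cache_dir="./quanda_cache",  # optional cache directory (hypothetical value)
    checkpoint="checkpoint.pt",  # forwarded to CaptumArnoldi via **kwargs
)

Because the two parameters moved to the end of each signature, call sites that passed them positionally have to switch to keyword arguments; the call sites updated in this patch already pass every argument by keyword.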