From c5465b3f8c24aae8cc40ee75d263cd8b43a48175 Mon Sep 17 00:00:00 2001 From: dilyabareeva Date: Mon, 24 Jun 2024 19:09:00 +0200 Subject: [PATCH 1/2] explainer kwargs + small subclass_identification.py changes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Galip Ümit Yolcu --- Makefile | 2 + .../subclass_identification.py | 23 +++++++- src/explainers/functional.py | 15 ++++-- src/explainers/utils.py | 52 +++++++++++++------ src/explainers/wrappers/captum_influence.py | 41 +++++++++------ src/metrics/base.py | 12 +++++ .../randomization/model_randomization.py | 26 +++++++--- src/utils/common.py | 2 +- .../wrappers/test_captum_influence.py | 4 +- tests/metrics/test_randomization_metrics.py | 6 +-- 10 files changed, 134 insertions(+), 49 deletions(-) diff --git a/Makefile b/Makefile index 1688b5f0..dd5c6b4b 100644 --- a/Makefile +++ b/Makefile @@ -7,6 +7,8 @@ style: black . flake8 . --pytest-parametrize-names-type=csv python -m isort . + python -m flake8 . + python -m mypy src --check-untyped-defs rm -f .coverage rm -f .coverage.* find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf diff --git a/src/downstream_tasks/subclass_identification.py b/src/downstream_tasks/subclass_identification.py index fe1cc074..c2d542a2 100644 --- a/src/downstream_tasks/subclass_identification.py +++ b/src/downstream_tasks/subclass_identification.py @@ -21,14 +21,22 @@ def __init__( optimizer: Callable, lr: float, criterion: torch.nn.modules.loss._Loss, + scheduler: Optional[Callable] = None, optimizer_kwargs: Optional[dict] = None, + scheduler_kwargs: Optional[dict] = None, device: str = "cpu", *args, **kwargs, ): self.device = device self.trainer: Optional[BaseTrainer] = Trainer.from_arguments( - model=model, optimizer=optimizer, lr=lr, criterion=criterion, optimizer_kwargs=optimizer_kwargs + model=model, + optimizer=optimizer, + lr=lr, + scheduler=scheduler, + criterion=criterion, + optimizer_kwargs=optimizer_kwargs, + scheduler_kwargs=scheduler_kwargs, ) @classmethod @@ -39,6 +47,17 @@ def from_pl_module(cls, model: torch.nn.Module, pl_module: L.LightningModule, de obj.trainer = Trainer.from_lightning_module(model, pl_module) return obj + @classmethod + def from_trainer(cls, trainer: BaseTrainer, device: str = "cpu", *args, **kwargs): + obj = cls.__new__(cls) + super(SubclassIdentification, obj).__init__() + if isinstance(trainer, BaseTrainer): + obj.trainer = trainer + obj.device = device + else: + raise ValueError("trainer must be an instance of BaseTrainer") + return obj + def evaluate( self, train_dataset: torch.utils.data.Dataset, @@ -106,8 +125,8 @@ def evaluate( cache_dir=os.path.join(cache_dir, run_id), train_dataset=train_dataset, test_tensor=input, - init_kwargs=explain_kwargs, device=device, + **explain_kwargs, ) metric.update(labels, explanations) diff --git a/src/explainers/functional.py b/src/explainers/functional.py index 5b2cb0ae..e00569de 100644 --- a/src/explainers/functional.py +++ b/src/explainers/functional.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Protocol, Union +from typing import Any, List, Optional, Protocol, Union import torch @@ -13,7 +13,16 @@ def __call__( train_dataset: torch.utils.data.Dataset, device: Union[str, torch.device], explanation_targets: Optional[Union[List[int], torch.Tensor]] = None, - init_kwargs: Optional[Dict] = None, - explain_kwargs: Optional[Dict] = None, + **kwargs: Any, + ) -> torch.Tensor: + pass + + +class ExplainFuncMini(Protocol): + def __call__( + self, + test_tensor: torch.Tensor, + explanation_targets: Optional[Union[List[int], torch.Tensor]] = None, + **kwargs: Any, ) -> torch.Tensor: pass diff --git a/src/explainers/utils.py b/src/explainers/utils.py index e9a6b3a8..a6cdd867 100644 --- a/src/explainers/utils.py +++ b/src/explainers/utils.py @@ -1,8 +1,24 @@ -from typing import Dict, List, Optional, Union +from inspect import signature +from typing import Any, List, Optional, Union import torch +def _init_explainer(explainer_cls, model, model_id, cache_dir, train_dataset, device, **kwargs): + # Python get explainer_cls expected init keyword arguments + exp_init_kwargs = signature(explainer_cls.__init__) + init_kwargs = {k: v for k, v in kwargs.items() if k in exp_init_kwargs.parameters} + explainer = explainer_cls( + model=model, + model_id=model_id, + cache_dir=cache_dir, + train_dataset=train_dataset, + device=device, + **init_kwargs, + ) + return explainer + + def explain_fn_from_explainer( explainer_cls: type, model: torch.nn.Module, @@ -12,20 +28,22 @@ def explain_fn_from_explainer( train_dataset: torch.utils.data.Dataset, device: Union[str, torch.device], targets: Optional[Union[List[int], torch.Tensor]] = None, - init_kwargs: Optional[Dict] = None, - explain_kwargs: Optional[Dict] = None, + **kwargs: Any, ) -> torch.Tensor: - init_kwargs = init_kwargs or {} - explain_kwargs = explain_kwargs or {} - - explainer = explainer_cls( + explainer = _init_explainer( + explainer_cls=explainer_cls, model=model, model_id=model_id, cache_dir=cache_dir, train_dataset=train_dataset, device=device, - **init_kwargs, + **kwargs, ) + + # Python get explainer_cls expected explain keyword arguments + exp_explain_kwargs = signature(explainer.explain) + explain_kwargs = {k: v for k, v in kwargs.items() if k in exp_explain_kwargs.parameters} + return explainer.explain(test=test_tensor, targets=targets, **explain_kwargs) @@ -37,18 +55,20 @@ def self_influence_fn_from_explainer( train_dataset: torch.utils.data.Dataset, device: Union[str, torch.device], batch_size: Optional[int] = 32, - init_kwargs: Optional[Dict] = None, - explain_kwargs: Optional[Dict] = None, + **kwargs: Any, ) -> torch.Tensor: - init_kwargs = init_kwargs or {} - explain_kwargs = explain_kwargs or {} - - explainer = explainer_cls( + explainer = _init_explainer( + explainer_cls=explainer_cls, model=model, model_id=model_id, cache_dir=cache_dir, train_dataset=train_dataset, device=device, - **init_kwargs, + **kwargs, ) - return explainer.self_influence(batch_size=batch_size, **explain_kwargs) + + # Python get explainer_cls expected explain keyword arguments + exp_si_kwargs = signature(explainer.self_influence) + si_kwargs = {k: v for k, v in kwargs.items() if k in exp_si_kwargs.parameters} + + return explainer.self_influence(batch_size=batch_size, **si_kwargs) diff --git a/src/explainers/wrappers/captum_influence.py b/src/explainers/wrappers/captum_influence.py index 58b41e0b..d0dbb0a6 100644 --- a/src/explainers/wrappers/captum_influence.py +++ b/src/explainers/wrappers/captum_influence.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, List, Optional, Union import torch from captum.influence import SimilarityInfluence # type: ignore @@ -8,6 +8,7 @@ explain_fn_from_explainer, self_influence_fn_from_explainer, ) +from src.utils.functions.similarities import cosine_similarity from src.utils.validation import validate_1d_tensor_or_int_list @@ -64,18 +65,31 @@ def explain( class CaptumSimilarity(CaptumInfluence): + # TODO: incorporate SimilarityInfluence kwargs into init_kwargs + """ + init_kwargs = signature(SimilarityInfluence.__init__).parameters.items() + init_kwargs.append("replace_nan") + explain_kwargs = signature(SimilarityInfluence.influence) + si_kwargs = signature(SimilarityInfluence.selinfluence) + """ + def __init__( self, model: torch.nn.Module, model_id: str, cache_dir: str, train_dataset: torch.utils.data.Dataset, - device: Union[str, torch.device], + layers: Union[str, List[str]], + similarity_metric: Callable = cosine_similarity, + similarity_direction: str = "max", + batch_size: int = 1, + replace_nan: bool = False, + device: Union[str, torch.device] = "cpu", **explainer_kwargs: Any, ): # extract and validate layer from kwargs self._layer: Optional[Union[List[str], str]] = None - self.layer = explainer_kwargs.get("layers", []) + self.layer = layers # TODO: validate SimilarityInfluence kwargs explainer_kwargs.update( @@ -84,7 +98,11 @@ def __init__( "influence_src_dataset": train_dataset, "activation_dir": cache_dir, "model_id": model_id, - "similarity_direction": "max", + "layers": self.layer, + "similarity_direction": similarity_direction, + "similarity_metric": similarity_metric, + "batch_size": batch_size, + "replace_nan": replace_nan, **explainer_kwargs, } ) @@ -135,12 +153,8 @@ def captum_similarity_explain( train_dataset: torch.utils.data.Dataset, device: Union[str, torch.device], explanation_targets: Optional[Union[List[int], torch.Tensor]] = None, - init_kwargs: Optional[Dict] = None, - explain_kwargs: Optional[Dict] = None, + **kwargs: Any, ) -> torch.Tensor: - init_kwargs = init_kwargs or {} - explain_kwargs = explain_kwargs or {} - return explain_fn_from_explainer( explainer_cls=CaptumSimilarity, model=model, @@ -150,8 +164,7 @@ def captum_similarity_explain( targets=explanation_targets, train_dataset=train_dataset, device=device, - init_kwargs=init_kwargs, - explain_kwargs=explain_kwargs, + **kwargs, ) @@ -160,11 +173,9 @@ def captum_similarity_self_influence( model_id: str, cache_dir: Optional[str], train_dataset: torch.utils.data.Dataset, - init_kwargs: Dict, device: Union[str, torch.device], + **kwargs: Any, ) -> torch.Tensor: - init_kwargs = init_kwargs or {} - return self_influence_fn_from_explainer( explainer_cls=CaptumSimilarity, model=model, @@ -172,5 +183,5 @@ def captum_similarity_self_influence( cache_dir=cache_dir, train_dataset=train_dataset, device=device, - init_kwargs=init_kwargs, + **kwargs, ) diff --git a/src/metrics/base.py b/src/metrics/base.py index 8e1f5b1d..837857a7 100644 --- a/src/metrics/base.py +++ b/src/metrics/base.py @@ -29,6 +29,18 @@ def update( raise NotImplementedError + def explain_update( + self, + *args, + **kwargs, + ): + """ + Used to update the metric with new data. + """ + if hasattr(self, "explain_fn"): + raise NotImplementedError + raise RuntimeError("Explain function not found in explainer.") + @abstractmethod def compute(self, *args, **kwargs): """ diff --git a/src/metrics/randomization/model_randomization.py b/src/metrics/randomization/model_randomization.py index 49ab460f..918e974c 100644 --- a/src/metrics/randomization/model_randomization.py +++ b/src/metrics/randomization/model_randomization.py @@ -5,7 +5,7 @@ from src.explainers.functional import ExplainFunc from src.metrics.base import Metric -from src.utils.common import _get_parent_module_from_name, make_func +from src.utils.common import get_parent_module_from_name, make_func from src.utils.functions.correlations import ( CorrelationFnLiterals, correlation_functions, @@ -22,7 +22,6 @@ def __init__( model: torch.nn.Module, train_dataset: torch.utils.data.Dataset, explain_fn: ExplainFunc, - explain_init_kwargs: Optional[dict] = None, explain_fn_kwargs: Optional[dict] = None, correlation_fn: Union[Callable, CorrelationFnLiterals] = "spearman", seed: int = 42, @@ -40,7 +39,6 @@ def __init__( self.model = model self.train_dataset = train_dataset self.explain_fn_kwargs = explain_fn_kwargs or {} - self.explain_init_kwargs = explain_init_kwargs or {} self.seed = seed self.model_id = model_id self.cache_dir = cache_dir @@ -57,11 +55,10 @@ def __init__( self.explain_fn = make_func( func=explain_fn, - init_kwargs=explain_init_kwargs, - explain_kwargs=explain_fn_kwargs, model_id=self.model_id, cache_dir=self.cache_dir, train_dataset=self.train_dataset, + **self.explain_fn_kwargs, ) self.results: Dict[str, List] = {"scores": []} @@ -81,7 +78,7 @@ def update( self, test_data: torch.Tensor, explanations: torch.Tensor, - explanation_targets: torch.Tensor, + explanation_targets: Optional[torch.Tensor] = None, ): rand_explanations = self.explain_fn( model=self.rand_model, test_tensor=test_data, explanation_targets=explanation_targets, device=self.device @@ -89,6 +86,21 @@ def update( corrs = self.corr_measure(explanations, rand_explanations) self.results["scores"].append(corrs) + def explain_update( + self, + test_data: torch.Tensor, + explanation_targets: Optional[torch.Tensor] = None, + ): + # TODO: add a test + explanations = self.explain_fn( + model=self.model, test_tensor=test_data, explanation_targets=explanation_targets, device=self.device + ) + rand_explanations = self.explain_fn( + model=self.rand_model, test_tensor=test_data, explanation_targets=explanation_targets, device=self.device + ) + corrs = self.corr_measure(explanations, rand_explanations) + self.results["scores"].append(corrs) + def compute(self): return torch.cat(self.results["scores"]).mean() @@ -121,6 +133,6 @@ def _randomize_model(self, model: torch.nn.Module): rand_model = copy.deepcopy(model) for name, param in list(rand_model.named_parameters()): random_param_tensor = torch.empty_like(param).normal_(generator=self.generator) - parent = _get_parent_module_from_name(rand_model, name) + parent = get_parent_module_from_name(rand_model, name) parent.__setattr__(name.split(".")[-1], torch.nn.Parameter(random_param_tensor)) return rand_model diff --git a/src/utils/common.py b/src/utils/common.py index b07849ac..05931099 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -11,7 +11,7 @@ def _get_module_from_name(model: torch.nn.Module, layer_name: str) -> Any: return reduce(getattr, layer_name.split("."), model) -def _get_parent_module_from_name(model: torch.nn.Module, layer_name: str) -> Any: +def get_parent_module_from_name(model: torch.nn.Module, layer_name: str) -> Any: return reduce(getattr, layer_name.split(".")[:-1], model) diff --git a/tests/explainers/wrappers/test_captum_influence.py b/tests/explainers/wrappers/test_captum_influence.py index 9a7167b0..1acfb922 100644 --- a/tests/explainers/wrappers/test_captum_influence.py +++ b/tests/explainers/wrappers/test_captum_influence.py @@ -43,8 +43,8 @@ def test_self_influence(test_id, init_kwargs, tmp_path): model_id="0", cache_dir=str(tmp_path), train_dataset=rand_dataset, - init_kwargs=init_kwargs, device="cpu", + **init_kwargs, ).argsort() # TODO: ...this is test 2, unless we want to compare that the outputs are the same. @@ -136,7 +136,7 @@ def test_explain_functional(test_id, model, dataset, test_tensor, test_labels, m test_tensor=test_tensor, explanation_targets=test_labels, train_dataset=dataset, - init_kwargs=method_kwargs, device="cpu", + **method_kwargs, ) assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected" diff --git a/tests/metrics/test_randomization_metrics.py b/tests/metrics/test_randomization_metrics.py index 7feb2f8a..0457b421 100644 --- a/tests/metrics/test_randomization_metrics.py +++ b/tests/metrics/test_randomization_metrics.py @@ -10,7 +10,7 @@ @pytest.mark.randomization_metrics @pytest.mark.parametrize( - "test_id, model, dataset, test_data, batch_size, explain, explain_init_kwargs, explanations, test_labels", + "test_id, model, dataset, test_data, batch_size, explain, explain_fn_kwargs, explanations, test_labels", [ ( "mnist", @@ -29,7 +29,7 @@ ], ) def test_randomization_metric( - test_id, model, dataset, test_data, batch_size, explain, explain_init_kwargs, explanations, test_labels, request + test_id, model, dataset, test_data, batch_size, explain, explain_fn_kwargs, explanations, test_labels, request ): model = request.getfixturevalue(model) test_data = request.getfixturevalue(test_data) @@ -40,7 +40,7 @@ def test_randomization_metric( model=model, train_dataset=dataset, explain_fn=explain, - explain_init_kwargs=explain_init_kwargs, + explain_fn_kwargs=explain_fn_kwargs, correlation_fn="spearman", seed=42, device="cpu", From 92dc874aca0e3d3451e406e3083049055d345174 Mon Sep 17 00:00:00 2001 From: dilyabareeva Date: Mon, 24 Jun 2024 20:51:28 +0200 Subject: [PATCH 2/2] fix a bug --- src/explainers/wrappers/captum_influence.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/explainers/wrappers/captum_influence.py b/src/explainers/wrappers/captum_influence.py index d0dbb0a6..77de9c3e 100644 --- a/src/explainers/wrappers/captum_influence.py +++ b/src/explainers/wrappers/captum_influence.py @@ -26,7 +26,7 @@ def __init__( train_dataset: torch.utils.data.Dataset, device: Union[str, torch.device], explainer_cls: type, - **explain_kwargs: Any, + explain_kwargs: Any, ): super().__init__( model=model, @@ -109,11 +109,12 @@ def __init__( super().__init__( model=model, + model_id=model_id, cache_dir=cache_dir, train_dataset=train_dataset, device=device, explainer_cls=SimilarityInfluence, - **explainer_kwargs, + explain_kwargs=explainer_kwargs, ) @property