From e5682f0a55b7f60ca47cca12b310cd0c083d56ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Thu, 13 Jun 2024 19:08:35 +0200 Subject: [PATCH] add test_target to explainer_fn parameters, because most methods need an output neuron to explain --- src/explainers/aggregators/self_influence.py | 12 ++++++++++-- src/utils/explain_wrapper.py | 3 ++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/explainers/aggregators/self_influence.py b/src/explainers/aggregators/self_influence.py index 45146dc5..b0f83a14 100644 --- a/src/explainers/aggregators/self_influence.py +++ b/src/explainers/aggregators/self_influence.py @@ -19,5 +19,13 @@ def get_self_influence_ranking( size = len(training_data) self_inf = torch.zeros((size,)) for i, (x, y) in enumerate(training_data): - self_inf[i] = explain_fn(model, model_id, cache_dir, training_data, i, **explain_fn_kwargs) - return self_inf.argsort() + self_inf[i] = explain_fn( + model=model, + model_id=model_id, + cache_dir=cache_dir, + test_tensor=x[None], + test_label=y[None], + train_dataset=training_data, + train_ids=[i], + **explain_fn_kwargs, + ) diff --git a/src/utils/explain_wrapper.py b/src/utils/explain_wrapper.py index ebe49ec4..43607047 100644 --- a/src/utils/explain_wrapper.py +++ b/src/utils/explain_wrapper.py @@ -26,8 +26,9 @@ def explain( model_id: str, cache_dir: str, method: str, - test_tensor: torch.Tensor, train_dataset: torch.utils.data.Dataset, + test_tensor: torch.Tensor, + test_target: Optional[torch.Tensor] = None, train_ids: Optional[Union[List[int], torch.Tensor]] = None, **kwargs, ) -> torch.Tensor: