Skip to content

Commit

Permalink
add test_target to explainer_fn parameters, because most methods need…
Browse files Browse the repository at this point in the history
… an output neuron to explain
  • Loading branch information
gumityolcu committed Jun 13, 2024
1 parent 2b2ac6e commit e5682f0
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
12 changes: 10 additions & 2 deletions src/explainers/aggregators/self_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,13 @@ def get_self_influence_ranking(
size = len(training_data)
self_inf = torch.zeros((size,))
for i, (x, y) in enumerate(training_data):
self_inf[i] = explain_fn(model, model_id, cache_dir, training_data, i, **explain_fn_kwargs)
return self_inf.argsort()
self_inf[i] = explain_fn(
model=model,
model_id=model_id,
cache_dir=cache_dir,
test_tensor=x[None],
test_label=y[None],
train_dataset=training_data,
train_ids=[i],
**explain_fn_kwargs,
)
3 changes: 2 additions & 1 deletion src/utils/explain_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ def explain(
model_id: str,
cache_dir: str,
method: str,
test_tensor: torch.Tensor,
train_dataset: torch.utils.data.Dataset,
test_tensor: torch.Tensor,
test_target: Optional[torch.Tensor] = None,
train_ids: Optional[Union[List[int], torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
Expand Down

0 comments on commit e5682f0

Please sign in to comment.