Skip to content

Commit

Permalink
ExplanationsAggregator and get_self_influence_ranking
Browse files Browse the repository at this point in the history
  • Loading branch information
gumityolcu committed Jun 9, 2024
1 parent 7e34a3d commit 85f3b4b
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 47 deletions.
24 changes: 24 additions & 0 deletions src/utils/aggregators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from abc import ABC

import torch


class ExplanationsAggregator(ABC):
def __init__(self, training_size: int, *args, **kwargs):
self.scores = torch.zeros(training_size)

def update(self, explanations: torch.Tensor):
raise NotImplementedError

def get_global_ranking(self) -> torch.Tensor:
return self.scores.argsort()


class SumAggregator(ExplanationsAggregator):
def update(self, explanations: torch.Tensor) -> torch.Tensor:
self.scores += explanations.sum(dim=0)


class AbsSumAggregator(ExplanationsAggregator):
def update(self, explanations: torch.Tensor) -> torch.Tensor:
self.scores += explanations.abs().sum(dim=0)
21 changes: 20 additions & 1 deletion src/utils/common.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import functools
from functools import reduce
from typing import Any, Callable, Mapping
from typing import Any, Callable, Mapping, Optional

import torch
import torch.utils
import torch.utils.data

from utils.explain_wrapper import SelfInfluenceFunction


def _get_module_from_name(model: torch.nn.Module, layer_name: str) -> Any:
Expand All @@ -22,3 +26,18 @@ def make_func(func: Callable, func_kwargs: Mapping[str, ...] | None, **kwargs) -
func_kwargs = kwargs

return functools.partial(func, **func_kwargs)


def get_self_influence_ranking(
model: torch.nn.Module,
model_id: str,
cache_dir: Optional[str],
training_data: torch.utils.data.Dataset,
self_influence_fn: SelfInfluenceFunction,
self_influence_fn_kwargs: Optional[dict] = None,
) -> torch.Tensor:
size = len(training_data)
self_inf = torch.zeros((size,))
for i, (x, y) in enumerate(training_data):
self_inf[i] = self_influence_fn(model, model_id, cache_dir, training_data, i)
return self_inf.argsort()
12 changes: 12 additions & 0 deletions src/utils/explain_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,18 @@ def __call__(
pass


class SelfInfluenceFunction(Protocol):
def __call__(
self,
model: torch.nn.Module,
model_id: str,
cache_dir: Optional[str],
train_dataset: torch.utils.data.Dataset,
id: int,
) -> torch.Tensor:
pass


def explain(
model: torch.nn.Module,
model_id: str,
Expand Down
Empty file.
11 changes: 0 additions & 11 deletions src/utils/globalization/base.py

This file was deleted.

27 changes: 0 additions & 27 deletions src/utils/globalization/from_explainer.py

This file was deleted.

8 changes: 0 additions & 8 deletions src/utils/globalization/from_explanations.py

This file was deleted.

0 comments on commit 85f3b4b

Please sign in to comment.