Merge pull request #54 from dilyabareeva/small_explain_refactoring
Merging explainer kwargs + small subclass_identification.py changes
gumityolcu authored Jun 24, 2024
2 parents d6e549a + 92dc874 commit 3b484fe
Showing 10 changed files with 137 additions and 51 deletions.
2 changes: 2 additions & 0 deletions Makefile
@@ -7,6 +7,8 @@ style:
black .
flake8 . --pytest-parametrize-names-type=csv
python -m isort .
+ python -m flake8 .
+ python -m mypy src --check-untyped-defs
rm -f .coverage
rm -f .coverage.*
find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
23 changes: 21 additions & 2 deletions src/downstream_tasks/subclass_identification.py
@@ -21,14 +21,22 @@ def __init__(
optimizer: Callable,
lr: float,
criterion: torch.nn.modules.loss._Loss,
+ scheduler: Optional[Callable] = None,
optimizer_kwargs: Optional[dict] = None,
+ scheduler_kwargs: Optional[dict] = None,
device: str = "cpu",
*args,
**kwargs,
):
self.device = device
self.trainer: Optional[BaseTrainer] = Trainer.from_arguments(
- model=model, optimizer=optimizer, lr=lr, criterion=criterion, optimizer_kwargs=optimizer_kwargs
+ model=model,
+ optimizer=optimizer,
+ lr=lr,
+ scheduler=scheduler,
+ criterion=criterion,
+ optimizer_kwargs=optimizer_kwargs,
+ scheduler_kwargs=scheduler_kwargs,
)

@classmethod
@@ -39,6 +47,17 @@ def from_pl_module(cls, model: torch.nn.Module, pl_module: L.LightningModule, de
obj.trainer = Trainer.from_lightning_module(model, pl_module)
return obj

+ @classmethod
+ def from_trainer(cls, trainer: BaseTrainer, device: str = "cpu", *args, **kwargs):
+ obj = cls.__new__(cls)
+ super(SubclassIdentification, obj).__init__()
+ if isinstance(trainer, BaseTrainer):
+ obj.trainer = trainer
+ obj.device = device
+ else:
+ raise ValueError("trainer must be an instance of BaseTrainer")
+ return obj

def evaluate(
self,
train_dataset: torch.utils.data.Dataset,
@@ -106,8 +125,8 @@ def evaluate(
cache_dir=os.path.join(cache_dir, run_id),
train_dataset=train_dataset,
test_tensor=input,
- init_kwargs=explain_kwargs,
device=device,
+ **explain_kwargs,
)
metric.update(labels, explanations)

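For illustration (not part of the diff): a minimal sketch of how the refactored constructors might be used. The model, optimizer, and scheduler below are placeholders, and the leading parameters of __init__ (self, model) are hidden above the hunk, so the exact signature is assumed from the from_arguments call.

```python
import torch

from src.downstream_tasks.subclass_identification import SubclassIdentification

model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(784, 10))  # placeholder model

# Default path: the task builds its own Trainer, now forwarding the optional
# scheduler and scheduler_kwargs through to Trainer.from_arguments.
task = SubclassIdentification(
    model=model,
    optimizer=torch.optim.Adam,
    lr=1e-3,
    criterion=torch.nn.CrossEntropyLoss(),
    scheduler=torch.optim.lr_scheduler.StepLR,
    scheduler_kwargs={"step_size": 10, "gamma": 0.1},
)

# New in this PR: reuse an already configured BaseTrainer instance instead.
# task = SubclassIdentification.from_trainer(trainer=my_trainer, device="cpu")
```

Note that from_trainer mirrors from_pl_module: both bypass __init__ via cls.__new__ and attach a ready-made trainer.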
15 changes: 12 additions & 3 deletions src/explainers/functional.py
@@ -1,4 +1,4 @@
- from typing import Dict, List, Optional, Protocol, Union
+ from typing import Any, List, Optional, Protocol, Union

import torch

@@ -13,7 +13,16 @@ def __call__(
train_dataset: torch.utils.data.Dataset,
device: Union[str, torch.device],
explanation_targets: Optional[Union[List[int], torch.Tensor]] = None,
- init_kwargs: Optional[Dict] = None,
- explain_kwargs: Optional[Dict] = None,
+ **kwargs: Any,
) -> torch.Tensor:
pass


+ class ExplainFuncMini(Protocol):
+ def __call__(
+ self,
+ test_tensor: torch.Tensor,
+ explanation_targets: Optional[Union[List[int], torch.Tensor]] = None,
+ **kwargs: Any,
+ ) -> torch.Tensor:
+ pass
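To make the two protocols concrete, a hypothetical sketch (the dummy explainer is an assumption, and the parameters hidden above the hunk are inferred): a full ExplainFunc takes the model, dataset, and caching arguments, while ExplainFuncMini is what remains once those have been pre-bound — which is how make_func is used in ModelRandomizationMetric further down.

```python
from functools import partial
from typing import Any, List, Optional, Union

import torch
from torch.utils.data import Dataset, TensorDataset


def dummy_explain(
    model: torch.nn.Module,
    model_id: str,
    cache_dir: Optional[str],
    test_tensor: torch.Tensor,
    train_dataset: Dataset,
    device: Union[str, torch.device],
    explanation_targets: Optional[Union[List[int], torch.Tensor]] = None,
    **kwargs: Any,
) -> torch.Tensor:
    # One row per test example, one column per training example.
    return torch.zeros(len(test_tensor), len(train_dataset))


# Pre-binding everything except the test inputs yields a callable that
# satisfies the reduced ExplainFuncMini protocol.
explain_mini = partial(
    dummy_explain,
    model=torch.nn.Identity(),
    model_id="demo",
    cache_dir=None,
    train_dataset=TensorDataset(torch.rand(8, 2)),
    device="cpu",
)
explanations = explain_mini(test_tensor=torch.rand(4, 2))  # shape (4, 8)
```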
52 changes: 36 additions & 16 deletions src/explainers/utils.py
@@ -1,8 +1,24 @@
- from typing import Dict, List, Optional, Union
+ from inspect import signature
+ from typing import Any, List, Optional, Union

import torch


+ def _init_explainer(explainer_cls, model, model_id, cache_dir, train_dataset, device, **kwargs):
+ # Python get explainer_cls expected init keyword arguments
+ exp_init_kwargs = signature(explainer_cls.__init__)
+ init_kwargs = {k: v for k, v in kwargs.items() if k in exp_init_kwargs.parameters}
+ explainer = explainer_cls(
+ model=model,
+ model_id=model_id,
+ cache_dir=cache_dir,
+ train_dataset=train_dataset,
+ device=device,
+ **init_kwargs,
+ )
+ return explainer


def explain_fn_from_explainer(
explainer_cls: type,
model: torch.nn.Module,
@@ -12,20 +28,22 @@ def explain_fn_from_explainer(
train_dataset: torch.utils.data.Dataset,
device: Union[str, torch.device],
targets: Optional[Union[List[int], torch.Tensor]] = None,
- init_kwargs: Optional[Dict] = None,
- explain_kwargs: Optional[Dict] = None,
+ **kwargs: Any,
) -> torch.Tensor:
- init_kwargs = init_kwargs or {}
- explain_kwargs = explain_kwargs or {}
-
- explainer = explainer_cls(
+ explainer = _init_explainer(
+ explainer_cls=explainer_cls,
model=model,
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,
device=device,
- **init_kwargs,
+ **kwargs,
)

+ # Python get explainer_cls expected explain keyword arguments
+ exp_explain_kwargs = signature(explainer.explain)
+ explain_kwargs = {k: v for k, v in kwargs.items() if k in exp_explain_kwargs.parameters}

return explainer.explain(test=test_tensor, targets=targets, **explain_kwargs)


@@ -37,18 +55,20 @@ def self_influence_fn_from_explainer(
train_dataset: torch.utils.data.Dataset,
device: Union[str, torch.device],
batch_size: Optional[int] = 32,
- init_kwargs: Optional[Dict] = None,
- explain_kwargs: Optional[Dict] = None,
+ **kwargs: Any,
) -> torch.Tensor:
- init_kwargs = init_kwargs or {}
- explain_kwargs = explain_kwargs or {}
-
- explainer = explainer_cls(
+ explainer = _init_explainer(
+ explainer_cls=explainer_cls,
model=model,
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,
device=device,
- **init_kwargs,
+ **kwargs,
)
- return explainer.self_influence(batch_size=batch_size, **explain_kwargs)

+ # Python get explainer_cls expected explain keyword arguments
+ exp_si_kwargs = signature(explainer.self_influence)
+ si_kwargs = {k: v for k, v in kwargs.items() if k in exp_si_kwargs.parameters}

+ return explainer.self_influence(batch_size=batch_size, **si_kwargs)
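The pattern above replaces the explicit init_kwargs/explain_kwargs dictionaries: a single flat **kwargs is split by inspecting which parameter names each callable accepts. A self-contained sketch of the routing idea, using a hypothetical toy class:

```python
from inspect import signature
from typing import Any, Dict


class ToyExplainer:
    """Hypothetical stand-in for an explainer class."""

    def __init__(self, model: str, layers: str = "fc"):
        self.model, self.layers = model, layers

    def explain(self, test: str, top_k: int = 3) -> str:
        return f"explain({test}, top_k={top_k}) using {self.layers}"


def route_kwargs(callable_obj, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Keep only the keys the callable's signature actually accepts.
    params = signature(callable_obj).parameters
    return {k: v for k, v in kwargs.items() if k in params}


kwargs = {"layers": "conv1", "top_k": 5, "unused": 0}

explainer = ToyExplainer(model="m", **route_kwargs(ToyExplainer.__init__, kwargs))
print(explainer.explain(test="x", **route_kwargs(explainer.explain, kwargs)))
# -> explain(x, top_k=5) using conv1
```

One trade-off of this design: keys that match neither signature are silently dropped rather than raising a TypeError.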
46 changes: 29 additions & 17 deletions src/explainers/wrappers/captum_influence.py
@@ -1,4 +1,4 @@
- from typing import Any, Dict, List, Optional, Union
+ from typing import Any, Callable, List, Optional, Union

import torch
from captum.influence import SimilarityInfluence # type: ignore
@@ -8,6 +8,7 @@
explain_fn_from_explainer,
self_influence_fn_from_explainer,
)
+ from src.utils.functions.similarities import cosine_similarity
from src.utils.validation import validate_1d_tensor_or_int_list


@@ -25,7 +26,7 @@ def __init__(
train_dataset: torch.utils.data.Dataset,
device: Union[str, torch.device],
explainer_cls: type,
- **explain_kwargs: Any,
+ explain_kwargs: Any,
):
super().__init__(
model=model,
@@ -64,18 +65,31 @@ def explain(


class CaptumSimilarity(CaptumInfluence):
- # TODO: incorporate SimilarityInfluence kwargs into init_kwargs
+ """
+ init_kwargs = signature(SimilarityInfluence.__init__).parameters.items()
+ init_kwargs.append("replace_nan")
+ explain_kwargs = signature(SimilarityInfluence.influence)
+ si_kwargs = signature(SimilarityInfluence.selinfluence)
+ """

def __init__(
self,
model: torch.nn.Module,
model_id: str,
cache_dir: str,
train_dataset: torch.utils.data.Dataset,
- device: Union[str, torch.device],
+ layers: Union[str, List[str]],
+ similarity_metric: Callable = cosine_similarity,
+ similarity_direction: str = "max",
+ batch_size: int = 1,
+ replace_nan: bool = False,
+ device: Union[str, torch.device] = "cpu",
**explainer_kwargs: Any,
):
# extract and validate layer from kwargs
self._layer: Optional[Union[List[str], str]] = None
- self.layer = explainer_kwargs.get("layers", [])
+ self.layer = layers

# TODO: validate SimilarityInfluence kwargs
explainer_kwargs.update(
@@ -84,18 +98,23 @@ def __init__(
"influence_src_dataset": train_dataset,
"activation_dir": cache_dir,
"model_id": model_id,
- "similarity_direction": "max",
"layers": self.layer,
+ "similarity_direction": similarity_direction,
+ "similarity_metric": similarity_metric,
+ "batch_size": batch_size,
+ "replace_nan": replace_nan,
+ **explainer_kwargs,
}
)

super().__init__(
model=model,
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,
device=device,
explainer_cls=SimilarityInfluence,
- **explainer_kwargs,
+ explain_kwargs=explainer_kwargs,
)

@property
@@ -135,12 +154,8 @@ def captum_similarity_explain(
train_dataset: torch.utils.data.Dataset,
device: Union[str, torch.device],
explanation_targets: Optional[Union[List[int], torch.Tensor]] = None,
- init_kwargs: Optional[Dict] = None,
- explain_kwargs: Optional[Dict] = None,
+ **kwargs: Any,
) -> torch.Tensor:
- init_kwargs = init_kwargs or {}
- explain_kwargs = explain_kwargs or {}

return explain_fn_from_explainer(
explainer_cls=CaptumSimilarity,
model=model,
@@ -150,8 +165,7 @@ def captum_similarity_explain(
targets=explanation_targets,
train_dataset=train_dataset,
device=device,
- init_kwargs=init_kwargs,
- explain_kwargs=explain_kwargs,
+ **kwargs,
)


@@ -160,17 +174,15 @@ def captum_similarity_self_influence(
model_id: str,
cache_dir: Optional[str],
train_dataset: torch.utils.data.Dataset,
- init_kwargs: Dict,
device: Union[str, torch.device],
+ **kwargs: Any,
) -> torch.Tensor:
- init_kwargs = init_kwargs or {}

return self_influence_fn_from_explainer(
explainer_cls=CaptumSimilarity,
model=model,
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,
device=device,
- init_kwargs=init_kwargs,
+ **kwargs,
)
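A hedged usage sketch of the merged interface (the toy model, data, and layer name are illustrative, not from the PR): init-time options such as layers and similarity_direction now travel in the same flat keyword set as explain-time options, and the signature-based routing in _init_explainer dispatches each key to the call that accepts it.

```python
import torch
from torch.utils.data import TensorDataset

from src.explainers.wrappers.captum_influence import captum_similarity_explain

model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(16, 2))
train_dataset = TensorDataset(torch.rand(32, 4, 4), torch.randint(0, 2, (32,)))

# One flat set of kwargs; no more init_kwargs/explain_kwargs dictionaries.
explanations = captum_similarity_explain(
    model=model,
    model_id="toy_model",
    cache_dir="./cache",
    test_tensor=torch.rand(8, 4, 4),
    train_dataset=train_dataset,
    device="cpu",
    layers="1",  # module name of the Linear layer inside the Sequential
    similarity_direction="max",
)
```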
12 changes: 12 additions & 0 deletions src/metrics/base.py
@@ -29,6 +29,18 @@ def update(

raise NotImplementedError

+ def explain_update(
+ self,
+ *args,
+ **kwargs,
+ ):
+ """
+ Used to update the metric with new data.
+ """
+ if hasattr(self, "explain_fn"):
+ raise NotImplementedError
+ raise RuntimeError("Explain function not found in explainer.")

@abstractmethod
def compute(self, *args, **kwargs):
"""
26 changes: 19 additions & 7 deletions src/metrics/randomization/model_randomization.py
@@ -5,7 +5,7 @@

from src.explainers.functional import ExplainFunc
from src.metrics.base import Metric
- from src.utils.common import _get_parent_module_from_name, make_func
+ from src.utils.common import get_parent_module_from_name, make_func
from src.utils.functions.correlations import (
CorrelationFnLiterals,
correlation_functions,
@@ -22,7 +22,6 @@ def __init__(
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,
explain_fn: ExplainFunc,
- explain_init_kwargs: Optional[dict] = None,
explain_fn_kwargs: Optional[dict] = None,
correlation_fn: Union[Callable, CorrelationFnLiterals] = "spearman",
seed: int = 42,
@@ -40,7 +39,6 @@ def __init__(
self.model = model
self.train_dataset = train_dataset
self.explain_fn_kwargs = explain_fn_kwargs or {}
- self.explain_init_kwargs = explain_init_kwargs or {}
self.seed = seed
self.model_id = model_id
self.cache_dir = cache_dir
@@ -57,11 +55,10 @@ def __init__(

self.explain_fn = make_func(
func=explain_fn,
- init_kwargs=explain_init_kwargs,
- explain_kwargs=explain_fn_kwargs,
model_id=self.model_id,
cache_dir=self.cache_dir,
train_dataset=self.train_dataset,
+ **self.explain_fn_kwargs,
)

self.results: Dict[str, List] = {"scores": []}
@@ -81,14 +78,29 @@ def update(
self,
test_data: torch.Tensor,
explanations: torch.Tensor,
- explanation_targets: torch.Tensor,
+ explanation_targets: Optional[torch.Tensor] = None,
):
rand_explanations = self.explain_fn(
model=self.rand_model, test_tensor=test_data, explanation_targets=explanation_targets, device=self.device
)
corrs = self.corr_measure(explanations, rand_explanations)
self.results["scores"].append(corrs)

+ def explain_update(
+ self,
+ test_data: torch.Tensor,
+ explanation_targets: Optional[torch.Tensor] = None,
+ ):
+ # TODO: add a test
+ explanations = self.explain_fn(
+ model=self.model, test_tensor=test_data, explanation_targets=explanation_targets, device=self.device
+ )
+ rand_explanations = self.explain_fn(
+ model=self.rand_model, test_tensor=test_data, explanation_targets=explanation_targets, device=self.device
+ )
+ corrs = self.corr_measure(explanations, rand_explanations)
+ self.results["scores"].append(corrs)

def compute(self):
return torch.cat(self.results["scores"]).mean()

@@ -121,6 +133,6 @@ def _randomize_model(self, model: torch.nn.Module):
rand_model = copy.deepcopy(model)
for name, param in list(rand_model.named_parameters()):
random_param_tensor = torch.empty_like(param).normal_(generator=self.generator)
- parent = _get_parent_module_from_name(rand_model, name)
+ parent = get_parent_module_from_name(rand_model, name)
parent.__setattr__(name.split(".")[-1], torch.nn.Parameter(random_param_tensor))
return rand_model
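Putting the pieces together, a sketch of the metric after this refactoring (the model, data, and kwargs are placeholders, and parts of the constructor signature are hidden above the hunks): explain_init_kwargs is gone, everything rides in explain_fn_kwargs, and the new explain_update generates explanations internally instead of requiring precomputed ones.

```python
import torch
from torch.utils.data import TensorDataset

from src.explainers.wrappers.captum_influence import captum_similarity_explain
from src.metrics.randomization.model_randomization import ModelRandomizationMetric

model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(16, 2))
train_dataset = TensorDataset(torch.rand(32, 4, 4), torch.randint(0, 2, (32,)))

metric = ModelRandomizationMetric(
    model=model,
    train_dataset=train_dataset,
    explain_fn=captum_similarity_explain,
    # formerly split across explain_init_kwargs and explain_fn_kwargs:
    explain_fn_kwargs={"layers": "1", "similarity_direction": "max"},
    correlation_fn="spearman",
    seed=42,
    model_id="toy_model",
    cache_dir="./cache",
)

# No precomputed explanations needed: explain_update explains the test batch
# with both the original and the randomized model, then correlates the two.
metric.explain_update(test_data=torch.rand(8, 4, 4))
print(metric.compute())
```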
(Diffs for the remaining three changed files did not load on this page.)