Skip to content

Commit

Permalink
make kwargs **kwargs again
Browse files Browse the repository at this point in the history
  • Loading branch information
dilyabareeva committed Jun 20, 2024
1 parent 0a477f5 commit 950b8b8
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 34 deletions.
2 changes: 1 addition & 1 deletion src/explainers/base_explainer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Union, Any
from typing import Any, List, Optional, Union

import torch

Expand Down
21 changes: 10 additions & 11 deletions src/explainers/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import torch

from src.explainers.base_explainer import BaseExplainer
from src.explainers.wrappers.captum_influence import CaptumSimilarity


Expand Down Expand Up @@ -34,15 +33,15 @@ def explain_fn_from_explainer(
init_kwargs: Optional[Dict] = {},
explain_kwargs: Optional[Dict] = {},
) -> torch.Tensor:
explainer = explainer_cls(
model=model,
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,
device=device,
explainer_kwargs=init_kwargs,
)
return explainer.explain(test=test_tensor, **explain_kwargs)
explainer = explainer_cls(
model=model,
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,
device=device,
**init_kwargs,
)
return explainer.explain(test=test_tensor, **explain_kwargs)


def explainer_self_influence_interface(
Expand All @@ -60,7 +59,7 @@ def explainer_self_influence_interface(
cache_dir=cache_dir,
train_dataset=train_dataset,
device=device,
explainer_kwargs=init_kwargs,
**init_kwargs,
)
return explainer.self_influence()

Expand Down
37 changes: 21 additions & 16 deletions src/explainers/wrappers/captum_influence.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Union, Any, Dict
from typing import Any, List, Optional, Union

import torch
from captum.influence import SimilarityInfluence
Expand All @@ -21,11 +21,14 @@ def __init__(
train_dataset: torch.utils.data.Dataset,
device: Union[str, torch.device],
explainer_cls: type,
explain_kwargs: Dict[str, Any],
**kwargs,
**explain_kwargs: Any,
):
super().__init__(
model=model, model_id=model_id, cache_dir=cache_dir, train_dataset=train_dataset, device=device,
model=model,
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,
device=device,
)
self.explainer_cls = explainer_cls
self.explain_kwargs = explain_kwargs
Expand Down Expand Up @@ -66,29 +69,31 @@ def __init__(
cache_dir: str,
train_dataset: torch.utils.data.Dataset,
device: Union[str, torch.device],
explainer_kwargs: Dict[str, Any],
**explainer_kwargs: Any,
):
# extract and validate layer from kwargs
self._layer: Union[List[str], str] = None
self._layer: Optional[Union[List[str], str]] = None
self.layer = explainer_kwargs.get("layers", [])

explainer_kwargs = {
"module": model,
"influence_src_dataset": train_dataset,
"activation_dir": cache_dir,
"model_id": model_id,
"similarity_direction": "max",
**explainer_kwargs,
}
# TODO: validate SimilarityInfluence kwargs
explainer_kwargs.update(
{
"module": model,
"influence_src_dataset": train_dataset,
"activation_dir": cache_dir,
"model_id": model_id,
"similarity_direction": "max",
**explainer_kwargs,
}
)

super().__init__(
model=model,
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,
device=device,
explainer_cls=SimilarityInfluence,
explain_kwargs=explainer_kwargs,
**explainer_kwargs,
)

@property
Expand Down
6 changes: 2 additions & 4 deletions src/utils/validation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import torch

"""
This is a Python module that contains helper functions for validating input arguments.
The plan is to collect them here and then create a universal validation decorator @validate_args
to check all the input arguments against the expected types specified e.g.
as class attributes.
This module contains utility functions for validation. The plan is to
move the validation logic into a validation decorator at a later point.
"""


Expand Down
2 changes: 1 addition & 1 deletion tests/explainers/test_explainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_explain_stateful(test_id, model, dataset, explanations, test_tensor, te
cache_dir=os.path.join("./cache", "test_id"),
train_dataset=dataset,
device="cpu",
explainer_kwargs=method_kwargs,
**method_kwargs,
)
explanations = explainer.explain(test_tensor)
assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected"
7 changes: 6 additions & 1 deletion tests/explainers/test_self_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,12 @@ def test_self_influence(test_id, init_kwargs, request):
)

explainer_obj = CaptumSimilarity(
model=model, model_id="1", cache_dir="temp_captum2", train_dataset=rand_dataset, device="cpu", explainer_kwargs=init_kwargs
model=model,
model_id="1",
cache_dir="temp_captum2",
train_dataset=rand_dataset,
device="cpu",
**init_kwargs,
)
self_influence_rank_stateful = explainer_obj.self_influence()

Expand Down

0 comments on commit 950b8b8

Please sign in to comment.