Skip to content

Commit

Permalink
get rid of functools.partial because it is very hard to use just to …
Browse files Browse the repository at this point in the history
…save one line in each metric
  • Loading branch information
gumityolcu committed Jun 12, 2024
1 parent 6e63659 commit c5fb56b
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions src/metrics/randomization/model_randomization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch

from metrics.base import Metric
from utils.common import _get_parent_module_from_name, make_func
from utils.common import _get_parent_module_from_name
from utils.explain_wrapper import ExplainFunc
from utils.functions.correlations import (
CorrelationFnLiterals,
Expand Down Expand Up @@ -51,14 +51,8 @@ def __init__(
# or do we want genuinely random models for each call of the metric (keeping seed in the constructor)
self.generator = torch.Generator(device=device)
self.rand_model = self._randomize_model(model)
self.explain_fn = make_func(
func=explain_fn,
func_kwargs=explain_fn_kwargs,
model=self.rand_model,
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,
)
self.explain_fn = explain_fn

self.results = {"rank_correlations": []}

if isinstance(correlation_fn, str) and correlation_fn in correlation_functions:
Expand All @@ -76,7 +70,14 @@ def update(
test_data: torch.Tensor,
explanations: torch.Tensor,
):
rand_explanations = self.explain_fn(test_tensor=test_data)
rand_explanations = self.explain_fn(
model=self.rand_model,
model_id=self.model_id,
cache_dir=self.cache_dir,
train_dataset=self.train_dataset,
test_tensor=test_data,
**self.explain_fn_kwargs,
)
corrs = self.correlation_measure(explanations, rand_explanations)
self.results["rank_correlations"].append(corrs)

Expand Down

0 comments on commit c5fb56b

Please sign in to comment.