diff --git a/src/metrics/randomization/model_randomization.py b/src/metrics/randomization/model_randomization.py index ad4fb6c0..dca5bb3b 100644 --- a/src/metrics/randomization/model_randomization.py +++ b/src/metrics/randomization/model_randomization.py @@ -4,7 +4,7 @@ import torch from metrics.base import Metric -from utils.common import _get_parent_module_from_name, make_func +from utils.common import _get_parent_module_from_name from utils.explain_wrapper import ExplainFunc from utils.functions.correlations import ( CorrelationFnLiterals, @@ -51,14 +51,8 @@ def __init__( # or do we want genuinely random models for each call of the metric (keeping seed in the constructor) self.generator = torch.Generator(device=device) self.rand_model = self._randomize_model(model) - self.explain_fn = make_func( - func=explain_fn, - func_kwargs=explain_fn_kwargs, - model=self.rand_model, - model_id=model_id, - cache_dir=cache_dir, - train_dataset=train_dataset, - ) + self.explain_fn = explain_fn + self.results = {"rank_correlations": []} if isinstance(correlation_fn, str) and correlation_fn in correlation_functions: @@ -76,7 +70,14 @@ def update( test_data: torch.Tensor, explanations: torch.Tensor, ): - rand_explanations = self.explain_fn(test_tensor=test_data) + rand_explanations = self.explain_fn( + model=self.rand_model, + model_id=self.model_id, + cache_dir=self.cache_dir, + train_dataset=self.train_dataset, + test_tensor=test_data, + **self.explain_fn_kwargs, + ) corrs = self.correlation_measure(explanations, rand_explanations) self.results["rank_correlations"].append(corrs)