diff --git a/src/metrics/unnamed/top_k_overlap.py b/src/metrics/unnamed/top_k_overlap.py index 7c47aabd..95fcb91c 100644 --- a/src/metrics/unnamed/top_k_overlap.py +++ b/src/metrics/unnamed/top_k_overlap.py @@ -23,7 +23,7 @@ def update( **kwargs, ): top_k_indices = torch.topk(explanations, self.top_k).indices - self.all_top_k_examples += top_k_indices + self.all_top_k_examples.append(top_k_indices) def compute(self, *args, **kwargs): return len(set(self.all_top_k_examples))