Batched metrics #351
Conversation
quantus/helpers/utils.py
Outdated
@@ -1015,6 +1034,8 @@ def calculate_auc(values: np.array, dx: int = 1):
    np.ndarray
        Definite integral of values.
    """
    if batched:
        return np.trapz(values, dx=dx, axis=1)
    return np.trapz(np.array(values), dx=dx)
I believe this could be simplified to something like:
axis = 1 if batched else None
return np.trapz(np.asarray(values), dx=dx, axis=axis)
Simplified in the latest commit.
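For reference, a minimal sketch of the simplified helper, assuming the `batched` keyword and signature shown in the diff; note that `np.trapz` expects an integer axis, so the sketch falls back to the default last axis for the unbatched 1-D case:

```python
import numpy as np

def calculate_auc(values: np.ndarray, dx: int = 1, batched: bool = False) -> np.ndarray:
    # Sketch only: integrate along the per-sample axis when batched,
    # otherwise along the last (and only) axis of a single 1-D array of scores.
    axis = 1 if batched else -1
    return np.trapz(np.asarray(values), dx=dx, axis=axis)
```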
# Indices
indices = params["indices"]

if isinstance(expected, (int, float)):
The `expected` value is provided by the `pytest.mark.parametrize`, and its type is known beforehand. Why do we need this check?
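To illustrate the reviewer's point, here is a small, purely hypothetical parametrized test (names and values invented): each parametrized case fixes the type of `expected`, so the branch taken by the `isinstance` check is already determined by the test data.

```python
import numpy as np
import pytest

@pytest.mark.parametrize("expected", [0.5, np.array([0.25, 0.75])])
def test_example(expected):
    # The type of `expected` is known per test case from the parametrize decorator.
    if isinstance(expected, (int, float)):
        assert expected == pytest.approx(0.5)
    else:
        assert expected.shape == (2,)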
@@ -30,6 +30,11 @@ def input_zeros_2d_3ch_flattened():
    return np.zeros(shape=(3, 224, 224)).flatten()


@pytest.fixture
Is this fixture used only in one place?
If that's the case, please inline it.
x_batch_shape = x_batch.shape
for perturbation_step_index in range(n_perturbations):
    # Perturb input by indices of attributions.
    a_ix = a_indices[
`a_ix` is an array with shape `(batch_size, n_features*n_perturbations)`, right? I'd suggest we create a view with shape `(batch_size, n_features, n_perturbations)`. Then we can index each step with `[..., perturbation_step_index]` instead of manually calculating offsets into the array.
`a_indices` is an array with shape `(batch_size, n_features)` and the resulting `a_ix`s are of shape `(batch_size, self.features_in_step)`. I believe calculating the offsets manually here is the only option.
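A hypothetical standalone helper illustrating the offset slicing being discussed (names and shapes taken from the comments above, not the actual Quantus code):

```python
import numpy as np

def step_indices(a_indices: np.ndarray, step: int, features_in_step: int) -> np.ndarray:
    """Return the feature indices perturbed at a given step.

    a_indices has shape (batch_size, n_features), columns ordered by attribution;
    the result has shape (batch_size, features_in_step).
    """
    return a_indices[:, step * features_in_step : (step + 1) * features_in_step]
```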
@@ -118,6 +118,58 @@ def baseline_replacement_by_indices(
    return arr_perturbed


def batch_baseline_replacement_by_indices(
import numpy.typing as npt


def batch_baseline_replacement_by_indices(
    arr: np.ndarray,
    indices: np.ndarray,
    perturb_baseline: npt.ArrayLike,
    **kwargs,
) -> np.ndarray:
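For illustration only, a minimal sketch of what a batched baseline replacement along the last axis could look like with this signature; this is an assumed behavior, not the actual (much longer) Quantus implementation:

```python
import numpy as np
import numpy.typing as npt

def batch_baseline_replacement_by_indices(
    arr: np.ndarray,          # (batch_size, n_features)
    indices: np.ndarray,      # (batch_size, k) integer positions to replace, per sample
    perturb_baseline: npt.ArrayLike,
    **kwargs,
) -> np.ndarray:
    # Sketch: copy the batch and write the (broadcast) baseline value
    # into the indexed positions of every sample at once.
    arr_perturbed = arr.copy()
    baseline = np.broadcast_to(np.asarray(perturb_baseline), indices.shape)
    np.put_along_axis(arr_perturbed, indices, baseline, axis=-1)
    return arr_perturbed
```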
||
# Predict on input. | ||
x_input = model.shape_input( | ||
x_batch, x_batch.shape, channel_first=True, batched=True |
AFAIK `channel_first` is a model parameter, so we should not hardcode it. @annahedstroem could you please help us on that one 🙃
This was hardcoded in the original implementation as well. Is that a bug?
# Randomly mask by subset size.
a_ix = np.stack(
    [
        np.random.choice(n_features, self.subset_size, replace=False)
Should we maybe add a fixed PRNG seed for reproducibility?
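If a fixed seed were added, one option (variable names and the seed value below are illustrative only) would be to thread a seeded `Generator` through the sampling:

```python
import numpy as np

batch_size, n_features, subset_size = 8, 3 * 224 * 224, 224
rng = np.random.default_rng(42)  # arbitrary fixed seed for reproducibility
a_ix = np.stack(
    [rng.choice(n_features, size=subset_size, replace=False) for _ in range(batch_size)]
)  # shape: (batch_size, subset_size)
```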
pred_deltas = np.stack(pred_deltas, axis=1)
att_sums = np.stack(att_sums, axis=1)

similarity = self.similarity_func(a=att_sums, b=pred_deltas, batched=True)
Isn't `batch_baseline_replacement_by_indices` always batched? Why do we need the `batched=True` argument? @annahedstroem have you ever used a different `similarity_func` here?
Here, the `batched=True` argument goes into a similarity function (for example `correlation_pearson`), not `batch_baseline_replacement_by_indices`. Similarity functions can be batched or not batched (at the moment at least), so this argument is needed here.
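As a sketch of the pattern being described (not the exact Quantus code), a similarity function with such a `batched` flag might look like this:

```python
import numpy as np
import scipy.stats

def correlation_pearson(a: np.ndarray, b: np.ndarray, batched: bool = False, **kwargs):
    if batched:
        # a, b have shape (batch_size, n); return one correlation per sample.
        return np.array([scipy.stats.pearsonr(a_i, b_i)[0] for a_i, b_i in zip(a, b)])
    return scipy.stats.pearsonr(a, b)[0]
```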
return_shape=(
    batch_size,
    n_features,
),  # TODO. Double-check this over using = (1,).
This `TODO` would need a bit more detail.
This is a relic of the past implementation; I accidentally left it in the new one as well. I have deleted the TODO in the latest commit.
if batched:
    assert len(a.shape) == 2 and len(b.shape) == 2, "Batched arrays must be 2D"
    # No support for axis currently, so just iterating over the batch
    return np.array([scipy.stats.kendalltau(a_i, b_i)[0] for a_i, b_i in zip(a, b)])
Maybe we could use np.vectorize (https://numpy.org/doc/stable/reference/generated/numpy.vectorize.html) for this one?
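A hedged sketch of the `np.vectorize` variant; note that `np.vectorize` is a convenience wrapper around a Python-level loop, so it would not be faster than the list comprehension above:

```python
import numpy as np
import scipy.stats

# Generalized vectorization over the batch dimension via a core signature.
kendall_tau_batched = np.vectorize(
    lambda a_i, b_i: scipy.stats.kendalltau(a_i, b_i)[0],
    signature="(n),(n)->()",
)
# For a, b of shape (batch_size, n_features) this returns shape (batch_size,).
```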
Could be used, but I also like the simplicity of @davor10105's suggestion!
Codecov Report

Attention: Patch coverage is

❗ Your organization needs to install the Codecov GitHub app to enable full functionality.

Additional details and impacted files

@@            Coverage Diff             @@
##             main     #351      +/-   ##
==========================================
+ Coverage   91.15%   91.29%   +0.13%
==========================================
  Files          66       66
  Lines        3925     4010      +85
==========================================
+ Hits         3578     3661      +83
- Misses        347      349       +2

☔ View full report in Codecov by Sentry.
quantus/functions/similarity_func.py
Outdated
@@ -14,7 +14,9 @@
import skimage


-def correlation_spearman(a: np.array, b: np.array, **kwargs) -> float:
+def correlation_spearman(
super!
Really great work @davor10105, looking forward to our chat.
@@ -139,7 +139,7 @@ def __init__(

    # Save metric-specific attributes.
    if perturb_func is None:
        perturb_func = baseline_replacement_by_indices
Let's discuss: where in the code should we make it explicit for the user that they can no longer use any perturb function other than `batch_baseline_replacement_by_indices`?
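One possible place, sketched here as a hypothetical helper (the import path and the exact warning behavior are assumptions, not the merged code), would be a guard right where the default is resolved:

```python
import warnings

def resolve_perturb_func(perturb_func=None):
    # Hypothetical guard: fall back to the batched default, and warn loudly
    # if a user passes any other (non-batched) perturbation function.
    from quantus.functions.perturb_func import batch_baseline_replacement_by_indices

    if perturb_func is None:
        return batch_baseline_replacement_by_indices
    if perturb_func is not batch_baseline_replacement_by_indices:
        warnings.warn(
            "This batched metric currently assumes batch_baseline_replacement_by_indices "
            "as its perturb_func; other perturbation functions may not behave as expected.",
            UserWarning,
        )
    return perturb_func
```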
Do we know why most of the Python checks are failing? Thanks
It seems that the installed versions of
@leanderweber Hey Leander, I have been working on implementing batched versions of all the metrics present in Quantus and have encountered two questions that @annahedstroem said you might be able to answer:
I appreciate your time and look forward to your clarification on these points.
Hi @davor10105,
Regarding your questions above:
I hope I could clarify these points :)
Hey @leanderweber, I understand your concerns, but I don't see a significant issue with using a fixed grid. While it's true that the "highest attribution" patch may not perfectly align with the grid, the concept of a patch is something we've defined arbitrarily. In my view, maintaining a consistent patch structure across the dataset and methods is essential to ensure equal testing conditions for all attribution methods.

If we adopt option (1) or (2), each perturbation step could potentially affect a different number of pixels, making some steps less impactful than others. Additionally, if a user specifies a certain number of patches, they won't know what percentage of the input will ultimately be perturbed. This could lead to significant variations in the final perturbed area across images, complicating result comparison.

I do agree with your point about option (3); it wouldn't make much of a difference. Looking at the paper again, there is mention of a predefined grid being used. Looking forward to your opinion on this!
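To make the "fixed grid" notion concrete, here is a small illustrative helper (function name and the example patch size are assumptions, not Quantus code): a predefined, non-overlapping grid is laid over the input once, and every attribution method is evaluated against the same patches.

```python
from itertools import product

def fixed_patch_grid(height: int, width: int, patch_size: int):
    """Top-left coordinates of non-overlapping patch_size x patch_size patches."""
    ys = range(0, height - patch_size + 1, patch_size)
    xs = range(0, width - patch_size + 1, patch_size)
    return list(product(ys, xs))

# Example: a 224x224 input with 16x16 patches gives a fixed 14x14 grid of 196 patches.
coords = fixed_patch_grid(224, 224, 16)
assert len(coords) == 196
```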
@davor10105 I think your solution is sound and would be very happy with your suggested update. @leanderweber let me know if you object! @leanderweber if you have any time today or tomorrow for a general view of this PR, I would be grateful to have your second pair of eyes! Otherwise, I'll try to go for a merge tomorrow :)
Ready for merge.
Just a final question @davor10105: where are the test results for the remaining faithfulness metrics (they are not included in the `METRICS_PAIRS` in `testing_utils/batched_tests/batch_implementation_verification.py`)? I don't find them in the `results.pickle` file, let me know!
N/A
Hi @davor10105 @annahedstroem, sorry for the late reply!

Region Perturbation implicitly puts more emphasis on the patches that are removed earlier. I.e., for a faithful attribution method and MORF order, there is this assumption (1) that the first removed patches will lead to the largest change in model output, since measured attribution faithfulness is related to the resulting curve. At the same time, we also assume (2) that more relevance = more change in model output, so if the attribution is faithful, a patch with a larger sum of relevance should lead to a larger change.

However, as you stated, there are several drawbacks to the current implementation as well. After thinking about it, the "correct" way to implement this may be to not remove overlapping patches, and instead consider all possible patches. Potentially recomputing attribution sums after each patch removal? Not sure about that last one.

Maybe we could evaluate how much variability across patch sizes and datasets is introduced in practice when using a grid? E.g., using MetaQuantus? We can also set up a meeting to discuss this in the discord, if you want.
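For readers unfamiliar with the MORF intuition described above, here is a toy sketch (all names, the `predict` callable, and the patch representation are assumptions, not the Quantus Region Perturbation implementation): patches are removed in decreasing order of summed relevance and the model output is recorded after each removal, so a faithful attribution produces a curve that drops steeply at the start.

```python
import numpy as np

def morf_curve(predict, x, patch_coords, relevance_per_patch, baseline=0.0):
    """Toy MORF curve: remove patches most-relevant-first and track the model output."""
    order = np.argsort(relevance_per_patch)[::-1]      # most relevant patch first
    x_perturbed = x.copy()
    scores = [predict(x_perturbed)]                    # score before any removal
    for p in order:
        y0, x0, size = patch_coords[p]
        x_perturbed[..., y0 : y0 + size, x0 : x0 + size] = baseline
        scores.append(predict(x_perturbed))
    return np.asarray(scores)                          # steeper early drop = more faithful (MORF)
```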
Hey @annahedstroem, sorry for the delay, here is the updated speed-up visualization (average speed-up being approximately 25x):

Moving on to the actual results, here is the output of the implementation validation procedure:
`pixel_flipping`, `monotonicity`, `faithfulness_estimate`, `sensitivity_n` and `sufficiency` are all valid. `region_perturbation`, `selectivity` and `infidelity` were skipped during the validation process. `region_perturbation` and `selectivity` were skipped due to the dynamic / fixed grid discussion, and therefore the results are expected to differ from the old implementation.

Regarding `infidelity`, I forgot to mention it on our last call, but I have found a potential issue with the old implementation. Looking at the paper, infidelity is defined as (focusing on the difference to baseline and explanation multiplication, notice that both are 1D arrays):

Regarding the remaining metrics, both `monotonicity_correlation` and `faithfulness_correlation` produced results similar to the initial validation run, despite increasing the number of samples from 10 to 20. However, I have still not found any reason to believe that something is wrong with the implementation.

For the `road` metric, around 90% of the means are consistent with the previous implementation. However, `irof` shows some variation, which is surprising as it does not involve stochasticity, yet the results differ from the old implementation. I verified that identical segmentation masks are being used in both implementations, so the source of the discrepancy is unclear. While the results for the 32 examples are generally close between the two implementations, there are still noticeable differences:
I would appreciate any guidance you have on this. If you have any additional questions, please ask!

Finally, @leanderweber, I agree that the arbitrary choice of a particular fixed grid changes the resulting curve, but it does so for every evaluated attribution method. In my opinion, the goal of the metric is not to provide a particular "optimal" score, but a score that is obtained on a level playing field between different attribution methods so that they can be fairly compared. In other words, the score of a single attribution method does not really matter; what matters is the ranking between the scores of a set of attribution methods. For the scores to be comparable, the metric has to perform consistent steps, and that seems tricky to do if the grid is not fixed. Furthermore, a better attribution map should outperform a worse one, no matter the actual underlying grid.
Thanks @davor10105 for this elaborate discussion on the remaining results!

First, the discrepancy in the

Second, is there a possibility to keep both alternatives for patching logic for

(Thanks also @leanderweber for all your input on the reasoning/thought behind the implementation! V appreciated.)

Also, thank you for highlighting the bug in the

As a final request before merging, can you make a short list of changes, separated by:
So that we can add it to the release notes and thus track back any discrepancies to that? Million thanks again @davor10105, really awesome work!
@annahedstroem Thank you for the feedback! I agree that giving users the option to choose between different patching procedures is a solid approach, and I'll make sure to incorporate it in an upcoming commit.

Here's the requested list of changes:

Batch

Introduced batched processing to the following metrics:
Bug fixes
Misc
Merged commit 53f97c2 into understandable-machine-intelligence-lab:main
Description
Implemented changes
- `evaluate_instance` method in the `PixelFlipping`, `Monotonicity`, `MonotonicityCorrelation`, `FaithfulnessCorrelation` and `FaithfulnessEstimate` classes: replaced the existing `evaluate_batch` methods with their "true" batch implementation
- Added a `batched` parameter to the `correlation_spearman`, `correlation_pearson` and `correlation_kendall_tau` similarity functions to support batch processing
- Added a `batched` parameter to `get_baseline_dict` to support batched baseline creation, and similarly added the same parameter to `calculate_auc`
Implementation validity
- An `np.allclose` check was made and `PixelFlipping`, `Monotonicity` and `FaithfulnessEstimate` were verified as valid. `MonotonicityCorrelation` and `FaithfulnessCorrelation` did not pass this test, as they include stochastic elements in their calculations. To verify their validity, a two-way t-test was utilized over the 30 runs for each sample of the respective implementations. The resulting p-values can be seen below:
- The `batched_tests` directory contains the `batch_implementation_verification.py` script, which runs the validation tests utilizing a copy of the repo (in the `quantus` directory also contained within the zip file) that has both the batched and the old implementation versions. Results of the runs mentioned above can be found in `results.pickle`. This file is used by `test_visualization.py` to display the box visualization and check the validity of the batch implementation as described above.

Minimum acceptance criteria
@annahedstroem