
Extract the most relevant parameters from each layer #9

Open
mattiacampana opened this issue Nov 20, 2024 · 0 comments


mattiacampana commented Nov 20, 2024

I want to use TorchLRP to extract the most relevant parameters from each layer of a given model.
To do so, I’m extending explain_mnist.py and extracting the layer-wise relevance maps with the following snippet:

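```python
import lrp  # TorchLRP

lrp.trace.enable_and_clean()  # start recording relevance at each traced layer
y_hat.backward()  # y_hat: scalar model output to explain; LRP runs in this backward pass
layer_wise_relevance = lrp.trace.collect_and_disable()  # recorded relevances, indexed per layer
```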

The goal is to convert layer_wise_relevance into a binary mask to extract the most relevant parameters of each layer. This approach works well for nn.Linear layers since the relevance map produced by lrp.trace.collect_and_disable() matches the shape of the layer's weight tensor.
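A minimal sketch of the mask step, assuming the relevance tensor already has the weight's shape (as described above for nn.Linear). `relevance_to_mask` and `keep_ratio` are hypothetical names, not part of TorchLRP. Ranking by absolute relevance is a design choice: drop `.abs()` to keep only positively relevant entries.

```python
import torch


def relevance_to_mask(relevance: torch.Tensor, keep_ratio: float = 0.1) -> torch.Tensor:
    """Hypothetical helper: boolean mask marking the top `keep_ratio` fraction of entries.

    Entries are ranked by absolute relevance, so strongly negative relevance
    also counts as relevant.
    """
    if not 0.0 < keep_ratio <= 1.0:
        raise ValueError("keep_ratio must be in (0, 1]")
    flat = relevance.detach().abs().flatten()
    k = max(1, int(round(keep_ratio * flat.numel())))
    # Indices of the k largest |relevance| values.
    top_idx = torch.topk(flat, k).indices
    mask = torch.zeros_like(flat, dtype=torch.bool)
    mask[top_idx] = True
    return mask.view_as(relevance)


# Usage with a random stand-in relevance map shaped like an nn.Linear weight (out, in).
relevance = torch.randn(10, 784)
mask = relevance_to_mask(relevance, keep_ratio=0.05)
weight = torch.randn(10, 784)
pruned_weight = weight * mask  # zero out all but the most relevant entries
```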

However, I’m running into problems with nn.Conv2d layers. For instance, consider the following layer:

```python
conv2l = torch.nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
```

Its weight tensor has the shape:

```python
conv2l.weight.shape
# Output: torch.Size([64, 32, 3, 3])
```

But the corresponding relevance map extracted from lrp.trace.collect_and_disable() has a different shape:

```python
layer_wise_relevance[2][0].shape
# Output: torch.Size([32, 28, 28])
```
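A quick shape check in plain PyTorch suggests one plausible reading (an interpretation, not confirmed by TorchLRP's documentation). Assuming the MNIST model feeds 28x28 feature maps into this layer and the trailing `[0]` selects the first sample of the batch, `[32, 28, 28]` is the shape of one sample of the layer's *input* activation. LRP assigns relevance to neurons rather than to weights, and each 3x3 kernel weight is reused at every spatial position, so the two shapes are not expected to line up directly. The input below is random stand-in data.

```python
import torch

conv2l = torch.nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

# Random stand-in for a batch of activations entering conv2l (4 samples, 32 channels, 28x28).
x = torch.randn(4, 32, 28, 28)
out = conv2l(x)

weight_shape = tuple(conv2l.weight.shape)  # one entry per kernel weight, shared across positions
input_shape = tuple(x[0].shape)            # one entry per input neuron of sample 0
output_shape = tuple(out[0].shape)         # padding=1 keeps the 28x28 spatial size

print(weight_shape, input_shape, output_shape)
```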

How can I interpret or transform the relevance map of an nn.Conv2d layer so that it aligns with the shape of the layer's weight tensor? Any suggestions or resources would be greatly appreciated!
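One possible direction, sketched under loose assumptions and not taken from TorchLRP: the LRP-epsilon rule splits the relevance R_j of an output neuron over its incoming connections as a_i * w_ij / (z_j + eps * sign(z_j)) * R_j. Summing that per-connection share over every spatial position where a kernel weight is reused yields a tensor with exactly the weight's shape. `weight_relevance` is a hypothetical helper. It assumes you already have the layer's input activations `a` and the relevance `r_out` arriving at the layer's output (which trace entry holds that is an assumption you would need to verify), and it recomputes z without the bias, so all of `r_out` is assigned to the weights.

```python
import torch
import torch.nn.functional as F


def weight_relevance(conv: torch.nn.Conv2d, a: torch.Tensor, r_out: torch.Tensor,
                     eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical per-weight relevance for a Conv2d layer (epsilon rule, bias ignored).

    a:     input activations, shape (N, C_in, H, W)
    r_out: relevance at the layer output, shape (N, C_out, H_out, W_out)
    Returns a tensor shaped like conv.weight, summed over batch and positions.
    """
    w = conv.weight.detach().requires_grad_(True)
    z = F.conv2d(a.detach(), w, None, conv.stride, conv.padding, conv.dilation, conv.groups)
    # Stabilised ratio s_j = R_j / (z_j + eps * sign(z_j)), treated as a constant.
    sign = (z >= 0).to(z.dtype) * 2 - 1
    s = (r_out / (z + eps * sign)).detach()
    # d/dw of sum_j z_j * s_j = sum over positions of a_i * s_j.
    (grad_w,) = torch.autograd.grad((z * s).sum(), w)
    # Multiplying by w gives sum over positions of a_i * w_ij * s_j.
    return w.detach() * grad_w


# Usage with random stand-in tensors shaped like the layer in question.
conv2l = torch.nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
a = torch.relu(torch.randn(1, 32, 28, 28))  # input activations
r_out = torch.rand(1, 64, 28, 28)            # relevance at the conv output
r_w = weight_relevance(conv2l, a, r_out)
print(r_w.shape)  # torch.Size([64, 32, 3, 3])
```

The gradient trick avoids materialising the per-position, per-connection tensor: the autograd pass over `(z * s).sum()` performs the summation over shared positions directly.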

Additionally, I propose extending the trace.collect_and_disable() function to return metadata, such as the name or ID of the layer associated with each relevance map. This enhancement would make it easier to map relevance data to specific model parameters, streamlining the analysis process.
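To illustrate the kind of metadata proposed, here is a minimal plain-PyTorch sketch (not TorchLRP code) that records a per-layer tensor keyed by module name using `register_full_backward_hook`. The recorded quantity is the ordinary gradient with respect to each layer's input, standing in for relevance, and `collect_by_name` is a hypothetical helper. A name-keyed dict like this is one possible shape for what an extended `collect_and_disable()` could return.

```python
import torch
from torch import nn


def collect_by_name(model: nn.Module, x: torch.Tensor, layer_types=(nn.Conv2d, nn.Linear)):
    """Hypothetical helper: run forward/backward, return {layer name: grad w.r.t. layer input}."""
    records = {}
    handles = []

    def make_hook(name):
        def hook(module, grad_input, grad_output):
            # grad_input[0] is the gradient w.r.t. the module's first positional input.
            if grad_input[0] is not None:
                records[name] = grad_input[0].detach()
        return hook

    for name, module in model.named_modules():
        if isinstance(module, layer_types):
            handles.append(module.register_full_backward_hook(make_hook(name)))
    try:
        x = x.detach().requires_grad_(True)  # so the first layer also receives a grad_input
        y = model(x)
        y.max(dim=1).values.sum().backward()  # explain the top-scoring class of each sample
    finally:
        for h in handles:
            h.remove()
    return records


model = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
    nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
    nn.Flatten(), nn.Linear(64 * 28 * 28, 10),
)
relevance = collect_by_name(model, torch.randn(2, 1, 28, 28))
for name, r in relevance.items():
    print(name, tuple(r.shape), tuple(model.get_submodule(name).weight.shape))
```

Keying by the names from `named_modules()` means each entry maps straight back to its parameters via `model.get_submodule(name).weight`, without relying on the order in which layers were traced.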
