Use torchlens with ESCNN model #18

Open
kalekundert opened this issue Nov 16, 2023 · 11 comments

@kalekundert
Contributor

I just tried using torchlens on a model built with a library called ESCNN, and I ran into some errors that I'm hoping you can help me with. ESCNN is a relatively niche library for geometric deep learning; think CNNs where the filters are matched in all possible orientations, in addition to all possible locations.

Here's a script that creates a single convolutional layer and tries to visualize it with torchlens:

import torch
import torchlens as tl

from escnn.nn import FieldType, R2Conv, GeometricTensor
from escnn.gspaces import rot2dOnR2

# The group of the four 90° rotations acting on the 2D plane.
gspace = rot2dOnR2(4)
so2 = gspace.fibergroup

# One trivial-representation channel in, one regular-representation channel out.
in_type = FieldType(gspace, [so2.trivial_representation])
out_type = FieldType(gspace, [so2.regular_representation])

conv = R2Conv(in_type, out_type, kernel_size=3)
conv.eval()

# The input is a plain tensor wrapped together with its field type.
x = GeometricTensor(
    tensor=torch.randn(1, 1, 5, 5),
    type=in_type,
)

log = tl.log_forward_pass(conv, [x])
print(log)

Just to explain the example a little bit: this is a 2D convolution where any 90° rotation of the filters should be matched. ESCNN requires a more sophisticated notion of "channels" than a normal CNN, and that's what the in_type and out_type variables establish. The input to the convolutional layer is not a plain tensor, but a tensor wrapped in a GeometricTensor object that also keeps track of the associated "channels".

I think it might be helpful to include (and briefly describe) the source code for the R2Conv.forward() method:

def forward(self, input: GeometricTensor):
    assert input.type == self.in_type

    if not self.training:
        _filter = self.filter
        _bias = self.expanded_bias
    else:
        # retrieve the filter and the bias
        _filter, _bias = self.expand_parameters()

    if self.padding_mode == 'zeros':
        output = conv2d(input.tensor, _filter,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        groups=self.groups,
                        bias=_bias)
    else:
        output = conv2d(pad(input.tensor, self._reversed_padding_repeated_twice, self.padding_mode),
                        _filter,
                        stride=self.stride,
                        dilation=self.dilation,
                        groups=self.groups,
                        bias=_bias)

    return GeometricTensor(output, self.out_type, coords=None)

If the model is in training mode, some complicated calculations are performed to get the necessary filter and bias tensors. But if the model is in evaluation mode, as it is in the above example, then these calculations have already happened. Either way, the actual convolution is just done using torch.nn.functional.conv2d. There's nothing really fancy going on under the hood.


When I run the above script, it gets stuck on the following line:

Traceback (most recent call last):
  File "/home/kale/hacking/bugs/torchlens_escnn/torchlens_escnn.py", line 21, in <module>
    log = tl.log_forward_pass(conv, x)
  File "/home/kale/hacking/forks/torchlens/torchlens/user_funcs.py", line 98, in log_forward_pass
    model_history = run_model_and_save_specified_activations(
  File "/home/kale/hacking/forks/torchlens/torchlens/model_history.py", line 7001, in run_model_and_save_specified_activations
    model_history.run_and_log_inputs_through_model(
  File "/home/kale/hacking/forks/torchlens/torchlens/model_history.py", line 1762, in run_and_log_inputs_through_model
    input_args = [copy.deepcopy(arg) for arg in input_args]

This line never completes; it keeps allocating memory until my machine runs out (>16 GB). It seems that deepcopy() is getting caught in an infinite loop, probably while trying to copy in_type (which is a relatively complicated object). My first thought was that there might be a reference cycle, but deepcopy() handles reference cycles automatically, so that's probably not it. The obvious solution would be to somehow modify the ESCNN objects so that they can be deep-copied, but that might not be a trivial change, and it seems to me that torchlens shouldn't require accommodations from libraries like this if at all possible.

I tried side-stepping this problem by just removing the deepcopy() call. That results in the following stack trace, which I wasn't able to make any sense of. I don't know if this is just a consequence of removing the deep copy, or indicative of some other problem:

Traceback (most recent call last):
  File "/home/kale/hacking/bugs/torchlens_escnn/torchlens_escnn.py", line 21, in <module>
    log = tl.log_forward_pass(conv, [x])
  File "/home/kale/hacking/forks/torchlens/torchlens/user_funcs.py", line 98, in log_forward_pass
    model_history = run_model_and_save_specified_activations(
  File "/home/kale/hacking/forks/torchlens/torchlens/model_history.py", line 7001, in run_model_and_save_specified_activations
    model_history.run_and_log_inputs_through_model(
  File "/home/kale/hacking/forks/torchlens/torchlens/model_history.py", line 1818, in run_and_log_inputs_through_model
    raise e
  File "/home/kale/hacking/forks/torchlens/torchlens/model_history.py", line 1802, in run_and_log_inputs_through_model
    self.output_layers.append(t.tl_tensor_label_raw)
AttributeError: 'Tensor' object has no attribute 'tl_tensor_label_raw'

I'd really like to be able to use torchlens, so I'd appreciate any help you can offer.

@kalekundert
Contributor Author

kalekundert commented Nov 16, 2023

Also, if you want to install ESCNN to try running this example yourself, there's a gotcha to be aware of: for modern versions of Python, a dependency called lie_learn has to be installed from GitHub instead of PyPI. I think the following commands should work, but let me know if they don't:

$ pip install git+https://github.com/AMLab-Amsterdam/lie_learn
$ pip install escnn

@kalekundert
Contributor Author

I figured out the infinite loop. It was my fault; there was a typo in the script I was running. I was calling tl.log_forward_pass(conv, x) instead of tl.log_forward_pass(conv, [x]). You can even see this by looking at the stack traces I posted. It turns out that GeometricTensor objects are considered iterable because they implement __getitem__(), but they do so in such a way that the iteration never ends. Sorry for the confusion. I was playing with multiple versions of my test script, and got things mixed up a bit.
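
For anyone curious, here's a minimal sketch (not the actual GeometricTensor implementation) of how Python's legacy iteration protocol treats any object with a __getitem__() method as iterable, and how the resulting iteration never ends if __getitem__() never raises IndexError:

class EndlessGetItem:
    def __getitem__(self, index):
        # Always succeeds, so iteration asks for index 0, 1, 2, ... forever.
        return index

# No __iter__() is defined, so Python falls back to calling __getitem__()
# with successive integers until IndexError is raised, which never happens.
for item in EndlessGetItem():
    pass  # never terminates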

So there's nothing wrong with the deep copy, but I'm still running into the 'Tensor' object has no attribute 'tl_tensor_label_raw' stack trace.

@johnmarktaylor91
Owner

Thanks so much for describing this so carefully! Looks like an interesting case I didn't think of from the armchair; I'll check it out pronto.

@johnmarktaylor91
Owner

Okay, I think I found the issue (bear with me): in r2convolution.py, there is the following seemingly innocuous line:

from torch.nn.functional import conv2d, pad

Later on in that file, it calls conv2d and pad under those names, instead of calling them as torch.nn.functional.conv2d. This ends up mattering because TorchLens works by replacing all the functions in the PyTorch namespace with modified versions of themselves, such that they log their results whenever they get called. But if you just import a function as (e.g.) conv2d, TorchLens isn't able to modify that reference to the function (since it's "dangling", no longer looked up through the torch namespace), so it doesn't log its results as it should.
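
Here's a minimal, self-contained sketch of the problem; logged_conv2d is a hypothetical stand-in for the logging wrapper TorchLens attaches:

import torch
import torch.nn.functional

from torch.nn.functional import conv2d  # local, "dangling" reference

original = torch.nn.functional.conv2d

def logged_conv2d(*args, **kwargs):
    print("conv2d called")  # stand-in for TorchLens's actual logging
    return original(*args, **kwargs)

# Rebind the attribute on the module, the way TorchLens patches functions.
torch.nn.functional.conv2d = logged_conv2d

x = torch.randn(1, 1, 5, 5)
w = torch.randn(1, 1, 3, 3)

torch.nn.functional.conv2d(x, w)  # prints "conv2d called"
conv2d(x, w)                      # silent: still points at the original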

This is a subtle issue that will require a bit of refactoring to fix. In the meantime, I think the easiest fix would be to remove from torch.nn.functional import conv2d, pad and replace it with import torch.nn.functional as f, then later on call f.conv2d and f.pad.

Apologies for the headache over something so silly. I'll bump this up my priority list to fix more thoroughly, and let me know if the stopgap solution doesn't work.

@kalekundert
Contributor Author

Thanks for looking into this so quickly! I'm traveling today, but I'll check if the quick fix works for me as soon as I get the chance.

kalekundert added a commit to kalekundert/escnn that referenced this issue Nov 19, 2023
…unctional as F`

These two expressions are almost identical, but it turns out that the
latter is (currently) necessary to visualize models using torchlens.
See johnmarktaylor91/torchlens#18.  While this is a pretty minor
benefit, it's also a pretty minor change.

Signed-off-by: Kale Kundert <[email protected]>
kalekundert added a commit to kalekundert/escnn that referenced this issue Nov 21, 2023
kalekundert added a commit to kalekundert/escnn that referenced this issue Nov 21, 2023
@kalekundert
Contributor Author

Sorry it took me a few days to get back to you on this, but I can confirm that the quick fix you suggested works for me. Thanks again for such a quick response! I'll leave the issue open in case you're planning to do some refactoring to accommodate from torch.nn.functional import ... imports (which seems like it would be really difficult to me), but feel free to close the issue if you're not.

@johnmarktaylor91
Owner

Delighted to hear that it worked :) I haven’t figured out a totally general fix yet, and indeed it looks like it’s not going to be easy, but I’ll brainstorm some more…

@johnmarktaylor91
Owner

I think this issue is impossible to solve in the totally general case, so I'm closing it unless someone has an idea.

@kalekundert
Contributor Author

I can imagine a few ways to wrap the pytorch functions regardless of how they're imported. I'm not familiar with torchlens's internals, so I don't know whether any of these approaches would really solve the problem, but I figure they're worth mentioning:

  • Provide a function that the user must call before importing any libraries that would import torch. This function would import torch itself and replace the necessary pytorch functions with wrappers that can be further manipulated later. The nice thing about this approach is that it is explicit, but it requires the user to understand the need to call such a function.

  • Set up an import hook to perform the necessary monkeypatching as soon as torch is first imported. See here for some ways to do this. (I think the MetaPathFinder approach is better than the __import__ approach, for what it's worth; a rough sketch follows this list.) This is more "magical" than having the user call a function, but it just requires torchlens to be imported before torch, and it would even be possible to issue a warning if the user imports the libraries in the wrong order.

  • The nuclear option: use gc.get_referrers() to get the list of all objects that hold a reference to each pytorch function, and then replace each of those references. This would be very hard to do reliably, because different kinds of references (e.g. dicts, lists, sets, etc.) would have to be replaced differently. But it might not be too hard to cover the most common use cases. There's an out-of-date library called pyjack that tries to do this; it doesn't support Python 3, but you could look at its source to see how this kind of thing is done more specifically. At the very least, it might be worth using gc.get_referrers() to warn the user if it looks like there are copies of the pytorch functions that can't be wrapped.
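
For concreteness, here's a rough sketch of the import-hook option. TorchImportWatcher and patch_torch are hypothetical names, and this is just one way to structure it:

import importlib.abc
import sys

def patch_torch(torch_module):
    # Hypothetical callback: this is where the torch functions would be
    # wrapped with their logging versions.
    print(f"patching {torch_module.__name__}")

class TorchImportWatcher(importlib.abc.MetaPathFinder):
    def find_spec(self, fullname, path, target=None):
        if fullname != "torch":
            return None
        # Delegate to the remaining finders to locate the real spec.
        for finder in sys.meta_path:
            if finder is self or not hasattr(finder, "find_spec"):
                continue
            spec = finder.find_spec(fullname, path, target)
            if spec is not None:
                break
        else:
            return None
        # Wrap the loader so patch_torch() runs right after torch executes.
        real_exec_module = spec.loader.exec_module
        def exec_module(module):
            real_exec_module(module)
            patch_torch(module)
        spec.loader.exec_module = exec_module
        return spec

# Must be installed before torch is first imported:
sys.meta_path.insert(0, TorchImportWatcher())
import torch  # prints "patching torch"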

@johnmarktaylor91
Owner

Wait, you are a genius :] I hadn't thought of any of these options, thanks so much for laying them out and describing them so carefully. For reference, the way TorchLens works is by attaching a rather elaborate decorator to all the functions in the PyTorch namespace such that their results get logged every time they're called. Currently, this decoration happens when a TorchLens function (e.g., log_forward_pass) is called, not when TorchLens is imported, so that the functions are in their "clean" original states outside of the TorchLens function calls.

The question is which of these options causes the least confusion from the user's perspective without slowing down performance too much. What about something like this: have the decoration occur when TorchLens is imported, and do it in an "exhaustive" way (potentially involving gc.get_referrers; e.g., get the references and check for any outside of the torch namespace), but make the decorator "silent" unless some attribute attached to the function is toggled on? This way, the decoration step only has to occur once (on import), it should work no matter the order of the imports, and all the torch functions should behave normally outside of the TorchLens function calls. If torch hasn't been imported yet, all the functions will get decorated, so all subsequent imports will get the decorated versions; if torch has already been imported, it'll catch all the references.
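
In sketch form (hypothetical names; a print stands in for the actual logging):

import functools
import torch

LOGGING_ENABLED = False  # flipped on only inside TorchLens calls

def silent_decorator(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        if LOGGING_ENABLED:  # one flag check per call when "silent"
            print(f"{func.__name__} called")  # stand-in for real logging
        return result
    return wrapper

# Decorate once, at import time; all later references (even ones made via
# `from torch import relu`) then point at the decorated version.
torch.relu = silent_decorator(torch.relu)

x = torch.randn(3)
torch.relu(x)        # silent: behaves like the original

LOGGING_ENABLED = True
torch.relu(x)        # prints "relu called"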

If this sounds reasonable I'll make this a priority in the next TorchLens update. Thanks again for the great suggestions, there's no way I would have thought of these on my own.

@kalekundert
Contributor Author

Sorry for the slow reply; I kinda let this issue fall off my radar. The big question (as I see it) is whether to do the decoration automatically or to require the user to call a function that does it. Either way, I imagine that the decorator would work pretty much like you described. Some pros/cons:

Do it automatically:

  • Doesn't require the user to know about any esoteric internal details.
  • I can't think of any case where this wouldn't "just work", although it might be good to warn the user if torch is imported before torchlens.

Do it manually:

  • Monkeypatching third-party packages is definitely a dark art, and has the potential to cause weird issues. Making it more visible to end-users might be worth the extra hassle, just so they have some idea what's happening if any weirdness crops up.
  • It might be possible to display a warning telling the user when they need to do this, which reduces how much end-users need to understand the whole import issue. (A sketch of such a warning is below.)

I'm not sure which approach I prefer.
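
In either case, the warning could be as simple as checking sys.modules when torchlens is imported. A rough sketch (assumed behavior, not actual torchlens code):

import sys
import warnings

# At torchlens import time: if torch was already imported, some modules may
# hold "dangling" references (from `from torch... import ...`) that were
# bound before any patching could happen.
if "torch" in sys.modules:
    warnings.warn(
        "torch was imported before torchlens; functions imported via "
        "'from torch... import ...' before this point may not be logged"
    )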
