Use torchlens with ESCNN model #18
Also, if you want to install ESCNN to try running this example yourself, there's a gotcha to be aware of. A dependency called
I figured out the infinite loop. It was my fault; there was a typo in the script I was running. I was calling So there's nothing wrong with the deep copy, but I'm still running into the
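This is consistent with how `copy.deepcopy()` is documented to behave. It keeps a memo of the objects it has already copied, so a reference cycle on its own can't send it into an infinite loop. Here is a minimal stdlib-only check, where `Node` is a hypothetical stand-in for a complicated object such as `in_type`:

```python
import copy


class Node:
    """Hypothetical stand-in for a 'relatively complicated object'."""

    def __init__(self, name):
        self.name = name
        self.other = None


# Build a two-node reference cycle: a -> b -> a.
a = Node("a")
b = Node("b")
a.other = b
b.other = a

# deepcopy records every object it copies in a memo dict, so when it meets
# `a` again (via b.other) it reuses the new copy instead of recursing forever.
a2 = copy.deepcopy(a)

print(a2 is not a)           # True: a genuinely new object
print(a2.other is not b)     # True: b was copied as well
print(a2.other.other is a2)  # True: the cycle is preserved in the copy
```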
Thanks so much for describing this so carefully! It looks like an interesting case I didn't think of from the armchair; I'll check it out pronto.
Okay, I think I found the issue (bear with me): in `r2convolution.py`, there is the following seemingly innocuous line:
Later on in that file, it calls This is a subtle issue that will require a bit of refactoring to fix. In the meantime, I think the easiest fix would be to remove Apologies for the headache over something so silly; I'll bump this up my priority list to fix more thoroughly. Let me know if the stopgap solution doesn't work.
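The exact lines from `r2convolution.py` aren't reproduced here, so the following is a hedged, stdlib-only sketch of what is assumed to be the mechanism. Decorating a function by replacing a module attribute only affects code that looks the function up through the module at call time (`F.conv(...)`). It does not affect code that copied the function object into its own namespace at import time (`from ... import conv`). `fakelib`, `conv`, and `logged` are hypothetical stand-ins, not torch or torchlens names:

```python
import sys
import types

# Hypothetical stand-in for torch.nn.functional: a module with one function.
fakelib = types.ModuleType("fakelib")


def conv(x):
    return x * 2


fakelib.conv = conv
sys.modules["fakelib"] = fakelib  # make it importable

# Style 1: copy the function object into this namespace at import time.
from fakelib import conv as bound_conv

# Style 2: keep the module and look the function up on every call.
import fakelib as F

calls = []


def logged(fn):
    """Hypothetical logging decorator standing in for a torchlens-style wrapper."""

    def wrapper(*args, **kwargs):
        calls.append(fn.__name__)
        return fn(*args, **kwargs)

    return wrapper


# Decorate *after* both imports by patching the module attribute.
fakelib.conv = logged(fakelib.conv)

F.conv(3)      # attribute lookup finds the patched wrapper, so it is logged
bound_conv(3)  # still the original function object, so it is not logged
print(calls)   # ['conv']
```

This is why switching a file from the first style to the second can be enough for call-time decoration to "see" the calls.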
Thanks for looking into this so quickly! I'm traveling today, but I'll check if the quick fix works for me as soon as I get the chance.
…unctional as F` These two expressions are almost identical, but it turns out that the latter is (currently) necessary to visualize models using torchlens. See johnmarktaylor91/torchlens#18. While this is a pretty minor benefit, it's also a pretty minor change.
Signed-off-by: Kale Kundert
Sorry it took me a few days to get back to you on this, but I can confirm that the quick fix you suggested works for me. Thanks again for such a quick response! I'll leave the issue open in case you're planning to do some refactoring to accommodate
Delighted to hear that it worked :) I haven't figured out a totally general fix yet, and indeed it looks like it's not going to be easy, but I'll brainstorm some more…
I think this issue is impossible to solve in the totally general case, so I'm closing it unless someone has an idea.
I can imagine a few ways to wrap the pytorch functions regardless of how they're imported. I don't know how torchlens works at all, so I don't know if any of these approaches would really solve the problem, but I figure they're worth mentioning:
Wait, you are a genius :] I hadn't thought of any of these options; thanks so much for laying them out and describing them so carefully. For reference, the way TorchLens works is by attaching a rather elaborate decorator to all the functions in the PyTorch namespace so that their results get logged every time they're called. Currently, this decoration is done when TorchLens function (e.g., The question is, which of these options causes the least confusion from the user's perspective without slowing down performance too much?

What about something like this: have the decoration occur when TorchLens is imported, and do it in an "exhaustive" way (potentially involving `gc.get_referrers`; e.g., get the references, and check for any references outside of the torch namespace), but make the decorator "silent" unless some attribute attached to the function is toggled on? This way, the decoration step only has to occur once (on import), it should work no matter the order of the imports, and all the torch functions should behave normally outside of the TorchLens function calls. If torch hasn't been imported yet, then all the functions will get decorated, so all subsequent imports will get the decorated version; if torch has already been imported, the sweep will catch all the existing references.

If this sounds reasonable, I'll make this a priority in the next TorchLens update. Thanks again for the great suggestions; there's no way I would have thought of these on my own.
Sorry for the slow reply; I kinda let this issue fall off my radar. The big question (as I see it) is whether to do the decoration automatically or to require the user to call a function that does it. Either way, I imagine that the decorator would work pretty much like you described. Some pros/cons:

Do it automatically:
Do it manually:
I'm not sure which approach I prefer.
I just tried using torchlens on a model built using a library called ESCNN, and I ran into some errors that I'm hoping you can help me with. ESCNN is a relatively niche library for geometric deep learning; think CNNs where the filters are matched in all possible orientations, in addition to all possible locations.
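To make "matched in all possible orientations" concrete for the 90° case used here, the sketch below is a naive pure-Python illustration. It correlates the input with each of the four rotated copies of one filter. It assumes a square, single-channel input and kernel. ESCNN itself builds its filters from a steerable basis, and nothing below uses its API:

```python
def rot90(a):
    """Rotate a 2D list 90 degrees counterclockwise (same convention as numpy.rot90)."""
    return [list(row) for row in zip(*a)][::-1]


def correlate_valid(image, kernel):
    """Plain 'valid' cross-correlation, as an ordinary CNN layer computes it."""
    n, k = len(image), len(kernel)
    m = n - k + 1
    return [
        [
            sum(image[i + a][j + b] * kernel[a][b] for a in range(k) for b in range(k))
            for j in range(m)
        ]
        for i in range(m)
    ]


def c4_lifting_correlation(image, kernel):
    """Match the kernel in all four 90-degree orientations.

    Returns one output map per kernel rotation, i.e. an extra
    'orientation' axis of size 4 in front of the spatial axes.
    """
    outputs = []
    rotated = kernel
    for _ in range(4):
        outputs.append(correlate_valid(image, rotated))
        rotated = rot90(rotated)
    return outputs


image = [[(3 * i + 5 * j) % 7 for j in range(5)] for i in range(5)]
kernel = [[1, 0, -1], [2, 0, 0], [0, 1, 0]]

out = c4_lifting_correlation(image, kernel)
out_rot = c4_lifting_correlation(rot90(image), kernel)

# Equivariance: rotating the input rotates every output map and cyclically
# shifts the orientation axis by one step.
for r in range(4):
    assert out_rot[r] == rot90(out[(r - 1) % 4])
```

The final loop shows why such layers are useful. A rotated input produces the same responses, just rotated and moved to a different orientation channel.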
Here's a script that creates a single convolutional layer and tries to visualize it with torchlens:
Just to explain the example a little bit, this is a 2D convolution where any 90° rotation of the filters should be matched. ESCNN requires a more sophisticated concept of "channels" than a normal CNN, and that's what the `in_type` and `out_type` variables establish. The input to the convolutional layer is not a normal tensor, but a tensor wrapped in a `GeometricTensor` object that also keeps track of the associated "channels".

I think it might be helpful to include (and briefly describe) the source code for the `R2Conv.forward()` method:

If the model is in training mode, some complicated calculations are performed to get the necessary filter and bias tensors. But if the model is in evaluation mode, as it is in the above example, then these calculations have already happened. Either way, the actual convolution is just done using `torch.nn.functional.conv2d`. There's nothing really fancy going on under the hood.

When I run the above script, it gets stuck on the following line:
This line never completes, and keeps allocating memory until my machine runs out (>16 GB). It seems that `deepcopy()` is getting caught in an infinite loop, probably while trying to copy `in_type` (which is a relatively complicated object). My first thought was that there might be a reference cycle, but it seems that `deepcopy()` automatically handles reference cycles, so that's probably not it. The obvious solution would be to somehow modify the ESCNN objects so that they can be deep-copied, but that might not be a trivial change, and it seems to me that torchlens shouldn't require accommodations from libraries like this if at all possible.

I tried side-stepping this problem by just removing the `deepcopy()` call. That results in the following stack trace, which I wasn't able to make any sense of. I don't know if this is just the consequence of removing the deep copy, or indicative of some other problem:

I'd really like to be able to use torchlens, so I'd appreciate any help you can offer.