Feature request: callback to control which nodes are expanded/collapsed #24
Thanks for this great suggestion, and for describing it in such detail! I hope TorchLens has been helpful for you. Responses:

> I think a better API would be to modify `ModelHistory.render_graph()` to accept a function that will be called for each node, and return a boolean value indicating whether or not that node should be collapsed or expanded.

This sounds like just the right interface and I like it a lot. There is just one complication: currently in TorchLens the “first-class objects” are tensor operations, not modules; modules are just the “containers” where the operations happen. So, right now the information about modules is not as easy to fetch as the information about tensor operations. It has been on my to-do list to log the module information more nicely, but with your comment I’ll bump it to the top of my list.

> A more aggressive version of this API might be to instead have the callback control all aspects of node formatting, e.g.:

This already exists, actually; it was just added a few weeks ago. The Colab has some examples, but you can pass in a dictionary of functions to override the visualization options according to different aspects of the model (e.g., colorizing based on the runtime or storage of the model). This doesn’t include specifying how to collapse nodes as you suggest above, though, which would be a great addition. Thanks again for the feedback; it helps a lot to know people’s use cases so I know what to focus on.
Thanks for such a quick reply!
Yeah, I'm not surprised that there are some practical considerations that I didn't appreciate. Let me know if you'd be open to a PR for this, or if you'd rather do it yourself along with some more extensive refactoring. I'm not guaranteeing that I'll be able to get to a PR any time soon, or at all, but this project has been very helpful to me and I like the idea of giving something back.
That's definitely good to know. It's interesting to me how the current implementation is similar to, but different from, my "aggressive" proposal. The former is a dictionary whose values may be functions, and the latter is a function that returns a dictionary. Basically, the difference is just whether the dictionary or the function is the "outermost" entity. This ship has probably already sailed, but I have to at least comment that the function-returns-dict API seems slightly better to me. I think it would make it easier to calculate multiple related attributes. For example, you might want to control whether the text color is white or black based on the fill color. (Here's some code I wrote to do exactly this in one of my projects. It's a standard algorithm, but I find it rather interesting.) It'd be easiest to write a single function that calculates both colors. With the current API, you'd basically have to calculate the fill color twice. That said, probably the best way to integrate the expand/collapse function into the current API would be to add an "expand" key to the
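The code linked in this comment wasn't preserved. As a stand-in, here is a minimal sketch of the point being made, assuming the WCAG 2.x relative-luminance and contrast-ratio formulas as the rule for choosing black or white text (one common approach, not necessarily the one the commenter used). `heat_fill`, `runtime_colors`, and `dict_of_functions` are hypothetical names, and neither style is meant to reproduce TorchLens's actual option format.

```python
# Sketch: "function that returns a dict" vs. "dict of functions" for node colors.
# Uses the WCAG 2.x relative-luminance and contrast-ratio definitions to choose
# black or white text. heat_fill, runtime_colors and dict_of_functions are
# hypothetical names, not TorchLens APIs.

def _linearize(c: int) -> float:
    """Convert an 8-bit sRGB channel to linear light (WCAG 2.x formula)."""
    s = c / 255
    return s / 12.92 if s <= 0.03928 else ((s + 0.055) / 1.055) ** 2.4

def relative_luminance(rgb: tuple[int, int, int]) -> float:
    r, g, b = (_linearize(c) for c in rgb)
    return 0.2126 * r + 0.7152 * g + 0.0722 * b

def contrast_ratio(l1: float, l2: float) -> float:
    lighter, darker = max(l1, l2), min(l1, l2)
    return (lighter + 0.05) / (darker + 0.05)

def readable_text_color(fill: tuple[int, int, int]) -> str:
    """Return whichever of black (luminance 0) or white (luminance 1) contrasts more."""
    lum = relative_luminance(fill)
    return "black" if contrast_ratio(lum, 0.0) >= contrast_ratio(lum, 1.0) else "white"

def to_hex(rgb: tuple[int, int, int]) -> str:
    return "#{:02x}{:02x}{:02x}".format(*rgb)

def heat_fill(runtime_fraction: float) -> tuple[int, int, int]:
    """Hypothetical color map: white for fast nodes, dark blue (#00008b) for slow ones."""
    t = min(max(runtime_fraction, 0.0), 1.0)
    return (round(255 * (1 - t)), round(255 * (1 - t)), round(255 - 116 * t))

# Function-returns-dict style: the fill is computed once and reused.
def runtime_colors(runtime_fraction: float) -> dict[str, str]:
    fill = heat_fill(runtime_fraction)
    return {"fillcolor": to_hex(fill), "fontcolor": readable_text_color(fill)}

# Dict-of-functions style: each attribute has its own function, so the
# font-color function must recompute the fill on its own.
dict_of_functions = {
    "fillcolor": lambda frac: to_hex(heat_fill(frac)),
    "fontcolor": lambda frac: readable_text_color(heat_fill(frac)),
}
```

Both styles produce the same colors; the difference is that `heat_fill` runs once per node in the function-returns-dict version, but once per dependent attribute in the dict-of-functions version.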
That’s a very good point regarding “dict of functions” vs. “function that returns a dict.” You’re totally right about the potential interdependence of options. I just added this feature recently, so I don’t think it’s too entrenched to fix. Lemme think about this. And thanks so much for the offer to do a pull request. So glad to hear TorchLens has helped you out! I think this refactor might make more sense for me to handle, since it’s going to involve restructuring how modules are treated in the code (I’ll probably make a ModuleLog class or some such thing with all the module info). But I’ll bump it to the top of the to-do list.
When using torchlens to visualize big models, I often wish there was an easier way to hide all of the elementary operations for certain layers. Later on I'll propose an API that would allow this, but I want to start by giving a motivating example. I was just working with one of the U-Net models from https://github.com/lucidrains/denoising-diffusion-pytorch, and my goal was to see how the size of the latent representation changes throughout the model.
Here's the visualization that torchlens produces for this model. You can see that it doesn't make it easy to track the size of the latent representation. Most of the complexity comes from the fact that the model contains lots of `ResnetBlock` and `LinearAttention` blocks, each of which contains a lot of internal complexity. However, neither of these blocks changes the size of its input, so for the purpose of tracking sizes, whatever happens within them is unimportant. If these blocks were each represented as a single node, the whole graph would be much easier to understand.

I'm aware of the existing `vis_nesting_depth` option, but I don't think it satisfactorily addresses this issue. First, not all of the `ResnetBlock` and `LinearAttention` blocks are necessarily at the same depth, so this setting isn't always capable of collapsing only the nodes I want. Second, it's not easy to know what the depth of each block is, especially in a big model with lots of residual connections. To find a good visualization, you basically have to guess-and-check different depth cutoffs (and be careful to check that nothing important was collapsed).

I think a better API would be to modify `ModelHistory.render_graph()` to accept a function that will be called for each node and return a boolean value indicating whether or not that node should be collapsed or expanded. The signature of this function might look something like this:

This would allow the user to determine which nodes to collapse based on any property of those nodes. I'm of course interested in doing this based on the name/class of the corresponding layer, but I could imagine this also being useful for expanding only nodes that use lots of memory or take a long time to evaluate. This API could also replace the existing API, since it also allows for collapsing based on depth (assuming that the nodes have some sort of depth attribute), but it's probably not worth breaking backwards compatibility over.
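To make the proposal concrete, here is a minimal sketch of such a callback together with a few example predicates. Everything in it is hypothetical: `NodeInfo` and its fields (`module_class`, `module_depth`, `memory_bytes`) stand in for whatever per-node information TorchLens would expose, and `expanded_classes` only mimics the part of a renderer that decides which nodes to expand. The depth convention used by `depth_cutoff` is also an assumption and may not match `vis_nesting_depth` exactly.

```python
from dataclasses import dataclass
from typing import Callable, Iterable

@dataclass
class NodeInfo:
    """Hypothetical per-node info; field names are illustrative, not TorchLens attributes."""
    module_class: str   # class name of the module the node represents
    module_depth: int   # nesting depth of that module (convention assumed)
    memory_bytes: int   # memory used by the module's outputs

# Proposed callback: True means "expand this node", False means "collapse it".
ExpandFn = Callable[[NodeInfo], bool]

def collapse_blocks(node: NodeInfo) -> bool:
    """Collapse every ResnetBlock and LinearAttention, expand everything else."""
    return node.module_class not in {"ResnetBlock", "LinearAttention"}

def depth_cutoff(max_depth: int) -> ExpandFn:
    """Depth-based rule in the spirit of vis_nesting_depth (exact semantics assumed)."""
    return lambda node: node.module_depth < max_depth

def expand_if_large(min_bytes: int) -> ExpandFn:
    """Expand only nodes whose outputs use at least min_bytes of memory."""
    return lambda node: node.memory_bytes >= min_bytes

def expanded_classes(nodes: Iterable[NodeInfo], should_expand: ExpandFn) -> list[str]:
    """Mimic the renderer's decision: call the callback once per node."""
    return [n.module_class for n in nodes if should_expand(n)]

# Made-up nodes loosely modeled on the U-Net example: the two blocks sit at
# different depths, so no single depth cutoff isolates them.
nodes = [
    NodeInfo("ResnetBlock", module_depth=2, memory_bytes=4_000_000),
    NodeInfo("LinearAttention", module_depth=3, memory_bytes=1_000_000),
    NodeInfo("Downsample", module_depth=2, memory_bytes=2_000_000),
]

print(expanded_classes(nodes, collapse_blocks))             # ['Downsample']
print(expanded_classes(nodes, depth_cutoff(3)))             # ['ResnetBlock', 'Downsample']
print(expanded_classes(nodes, depth_cutoff(2)))             # []
print(expanded_classes(nodes, expand_if_large(2_000_000)))  # ['ResnetBlock', 'Downsample']
```

With these made-up depths, no depth cutoff collapses both blocks while keeping `Downsample` expanded: a cutoff of 3 still expands `ResnetBlock`, and a cutoff of 2 collapses everything. The class-based predicate collapses exactly the two blocks, and the depth rule is just one more predicate of the same shape.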
A more aggressive version of this API might be to instead have the callback control all aspects of node formatting, e.g.:
The dictionary returned by this function would describe how to format the node. Certain keys like `expand` would be specially extracted and interpreted by torchlens. Any others would just be passed along to graphviz (e.g., `color`, `style`, `shape`, etc.). This seems very elegant to me, but it's definitely a bigger change, and there could be practicalities that I'm overlooking that would complicate things.