
Support PyTorch Lightning: more carefully avoid problematic attribute accesses #31

Merged
johnmarktaylor91 merged 2 commits into johnmarktaylor91:main on Oct 16, 2024

Conversation

kalekundert
Contributor

This PR makes a small change that allows TorchLens to work on PyTorch Lightning modules. If you're not familiar, Lightning is a framework that tries to abstract away a lot of the training loop code that you'd otherwise have to write. The problem is that when a Lightning module is used outside the context of a Lightning training loop, accessing one particular attribute of the module raises an exception. TorchLens ends up triggering this error (in three different places) as it attempts to monkeypatch the module.

This PR attempts to solve the problem by simply skipping any "inaccessible" attributes that are encountered. Some minor comments about the implementation/tests:

  • Since I found three instances of this attribute access pattern in the code, I factored the skipping logic into its own function (a rough sketch of the idea follows this list).
  • While trying to run the test suite, I found that it had a lot of dependencies that weren't included in any of the requirements files. It also had some dependencies that I couldn't figure out how to install (specifically StyleTSS and Stable Diffusion). I added the dependencies I could figure out to requirements.test.txt, and I changed the way the others are imported so that the affected tests xfail but the rest of the test suite still runs.
  • On this topic, I think that pyproject.toml is a better way to specify extra dependencies than requirements.*.txt. The main advantage is that it's more standardized, so it's easier for third-party tools to interact with. If you want, I'd be happy to make a PR that transitions the project from setup.py to pyproject.toml. While I'm at it, I could also make it so that GitHub automatically runs the test suite on each PR.
  • In my hands, the test_varying_loop_noparam2 test fails both in the main branch and in this branch. I'm assuming that this is not related to the changes I made.
  • I'm submitting this PR even though the test suite is still running (it takes a long time). So far all of the tests but the one mentioned above have passed, and I'm pretty confident that my code works, but I'll comment if any more failures happen as the suite keeps running.
  • I didn't run black because it looked like it would change a lot of code I didn't write.
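
For concreteness, here's a rough sketch of the kind of helper I mean — the name and exact behavior are illustrative, not necessarily what's in the commit:

```python
def getattr_safe(obj, attr_name, default=None):
    """Return an attribute, treating any exception on access as "missing".

    Some attributes (e.g. on a Lightning module used outside a Trainer)
    raise when accessed; skipping them is enough for TorchLens's purposes.
    """
    try:
        return getattr(obj, attr_name, default)
    except Exception:
        return default
```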

@johnmarktaylor91
Owner

Hi Kale, thanks so much for this pull request, and for your careful development and testing here. All your points seem astute and reasonable:

  • test_varying_loop_noparam2 is a very niche edge case for which the TorchLens "loop-finding" algorithm fails. It's in the unit tests to keep it on my long-term radar, but will take a substantial refactor to fix, so unless it causes mischief in a real-life model I might not get to it right away.
  • Some of the extra dependencies you mention are only needed for the unit tests, so it would be good to separate these from the main requirements.

I suspect all your changes are great (and will happily credit you as a contributor on the main page--thanks so much), but can I circle back on this next week? I am in the middle of a grant deadline so that is eating up all my time at the moment.

Also, quick tangential question in the meantime--since you are one of the first (besides me) to do a substantial pull request for TorchLens, how hard has it been to work with the codebase? I am aware that pretty much everything is implemented in the massive model_history.py code brick (since class definitions can't be spread across multiple files to my knowledge), and it makes sense in my own head, but I am curious if it has been a challenge to navigate and figure out what to change.

@kalekundert
Contributor Author

As the test suite continued to run, tests 97–106 failed. I think these mostly have to do with torchvision. I checked 2 of the failing tests—test_fasterrcnn_mobilenet_train and test_fasterrcnn_mobilenet_eval—on the main branch, and they both failed there as well, so I don't think the problem is specific to my PR. After that, the test suite ran out of memory and got killed on the 116th test, which I think is one of the torchvision video tests. There are roughly 30 tests that haven't run yet, although at this point I'm satisfied that my code works.

test_varying_loop_noparam2 is a very niche edge case for which the TorchLens "loop-finding" algorithm fails. It's in the unit tests to keep it on my long-term radar, but will take a substantial refactor to fix...

I took the liberty of pushing another commit to mark this test as xfail. That will let other hypothetical contributors know which tests are supposed to pass, and is also a necessary first step towards setting up automated tests.
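
For reference, the change is essentially just a decorator on the test — the reason string here is mine:

```python
import pytest

@pytest.mark.xfail(reason="known limitation of the loop-finding algorithm")
def test_varying_loop_noparam2():
    ...
```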

Some of the extra dependencies you mention are only needed for the unit tests, so it would be good to separate these from the main requirements.

Yes, this is easy to do.

Can I circle back on this next week?

Of course, take your time!

Also, quick tangential question in the meantime--since you are one of the first (besides me) to do a substantial pull request for TorchLens, how hard has it been to work with the codebase?

This change wasn't too hard, but it also didn't really require understanding much about the code. My thought process was basically, "The code seems to be encountering an error while attempting to 'restore model attributes'. Probably by the time it's restoring things, it already has whatever data it needs, so I can just suppress the error and everything will be fine." So I added the try/except block, ran the code, got another similar stack trace, and repeated that process until it worked. I was actually a bit uneasy submitting a PR despite how little I understand the code, and that's one of the reasons I paid so much attention to the tests.

It was pretty easy for me to figure out which file to look in for what kind of information, although I guess part of that might just be that most of the information is in model_history.py.

I think my biggest challenge with the code base is that I so far haven't been able to build a mental map for how the code works. I know that everything happens in ModelHistory, and I'm able to read individual methods and more-or-less understand what they're doing, but I don't have a sense for the "intermediate" level of detail. In other words: What entities are there, what pieces of information do they share, what's the overall flowchart of the algorithm that ModelHistory implements? Those are the kind of things I feel like I'm missing.

Granted, I haven't put much effort into trying to understand the code. I'm just saying all of this because I know it can be helpful to get the perspective of someone who isn't as familiar with things.


Because I can't help myself, I hope you don't mind if I give some unsolicited opinions on object-oriented software architecture. Broadly, the ideas I want to put forth are that (i) functions are easier to understand than methods, and therefore (ii) the only reason to use a method is when access to private class data is required. So if you're looking to refactor things (which I'm not saying would be a good use of your time), I suspect that a lot of ModelHistory methods could be converted into functions, and that it would make the code easier to understand.

The reason I think functions are easier to understand than methods is that they usually have simpler inputs and outputs. Methods always have access to self, which tends to complicate things in two ways. First, self provides access to all of the data associated with the class, even if the method only needs access to one or two attributes. From the perspective of someone reading the code, this makes it harder to tell at a glance what kind of information the method actually uses. Second, methods very commonly make changes to self. This means that you can't just look for the return statements to figure out what a method's output is; you really have to read the whole thing—including any other methods it might call—to see if self ever gets mutated.

Without self, standalone functions are more often passed just the information they need, and are more likely to either (i) avoid mutating their arguments or (ii) have mutating their arguments be central to their whole purpose. Of course it's possible to write functions in the same "style" as methods, but I find that functions are generally much easier to understand.

This brings up the question of when methods should be used. As I wrote above, I think the answer is that they should be used when access to private class data is required. This relates to the idea that the purpose of a class is generally to hide some messy private data behind a clean public API. Only methods can do this. However, it's still true that methods are more complicated than functions, perhaps even more so when comparing methods that are responsible for maintaining the integrity of an object's private data with functions that are just accessing an object's public API. This leads to a few rules of thumb that I use when writing classes:

  • Try to minimize the number of methods any class has.
  • Every method should require access to private data. Methods that don't should be turned into functions.
  • Classes without private data (e.g. data classes) shouldn't have any methods.

Bringing this all back to ModelHistory, I think that it would be easier to see the "skeleton" of the code if more of the "fat" were moved into functions, and a clearer distinction were made between public and private data. For example, maybe ModelHistory should just be responsible for providing read/write access to the graph of tensor operations. In fact, maybe you could even use networkx.Graph for this and eliminate ModelHistory altogether. (Probably not, though.) The tasks of instantiating, validating, and visualizing that graph could be moved to standalone functions. I can imagine that instantiating the graph might require keeping track of some extra state (e.g. what current node are we on), so maybe it would be necessary to have another class for that. Of course, splitting the code into functions like this would also allow it to be split into multiple files.

If you made it this far, thanks for reading. I hope I didn't come across as too critical or arrogant.
I know that I might get annoyed if some rando came along to one of my projects and dumped an essay on me about how I should write my code, so I want to emphasize that I don't mean to pressure you into changing anything, or even responding to what I wrote. I just wanted to share some ideas that have been really helpful for me personally, in the hopes that maybe you'll find them helpful too.

@johnmarktaylor91
Owner

johnmarktaylor91 commented Oct 9, 2024

Hi Kale,

Apologies for the headaches with the unit tests; so far I've only been using them internally, so any snafus are probably errors on my end. Currently the tests err on the side of being exhaustive (checking as many kinds of architectures as I can think of), so they do end up taking a while. Thanks for all your great suggestions for how to streamline the tests and dependencies; when I have time I will do some tidying up on this front.

And, thanks for all the feedback on the code organization! A second pair of eyes is super helpful after so long working solo on this package:

I think my biggest challenge with the code base is that I so far haven't been able to build a mental map for how the code works. I know that everything happens in ModelHistory, and I'm able to read individual methods and more-or-less understand what they're doing, but I don't have a sense for the "intermediate" level of detail. In other words: What entities are there, what pieces of information do they share, what's the overall flowchart of the algorithm that ModelHistory implements? Those are the kind of things I feel like I'm missing.

Granted, I haven't put much effort into trying to understand the code. I'm just saying all of this because I know it can be helpful to get the perspective of someone who isn't as familiar with things.

All very good points, and it would probably be good to write a readme explaining this stuff. The general "flowchart" is:

  1. User feeds in the model and input
  2. TorchLens goes and sticks a special decorator on all the elementary PyTorch functions that operate on tensors (a rough sketch of this step follows the list).
  3. The forward pass gets executed, and every time a decorated PyTorch function is called, it logs a whole bunch of info about the function call. This is where the action is. I can see how this would be tricky to mentally track, because the code execution alternates between the actual model code and the TorchLens logging operations. It also doesn't help that there is simply a huge amount of info getting logged.
  4. After the forward pass is finished, there are a bunch of postprocessing operations that stitch together the computational graph of the forward pass, do stuff like find recurrent loops, and convert everything to its final format. The graph is also visualized if the user desires.
  5. Then all the torch functions get undecorated and returned to normal.

Because I can't help myself, I hope you don't mind if I give some unsolicited opinions on object-oriented software architecture.

Please don't apologize, I really appreciate the hot takes :] This was my first serious software project and I've actually been wanting a spot-check on my design choices. I totally agree with you that splitting up the code could do a lot for readability and that it could stand to be a lot more faithful to OOP best practices. I will think about all your specific suggestions.

My one mental roadblock is... what if we actually do want to make a lot of changes to self? For me, one of the goals of TorchLens is to be able to fetch any kind of information you want about the model or about a given layer (i.e., all the fields listed in constants.py) just by indexing a single attribute (from either ModelHistory or from a specific layer log). Once that design choice is made, all of that info has to be fields in a single class. And once you've decided that, why not just keep all that info stored in the class from the get-go? I guess one could keep all the same data attributes, but convert the methods to functions in which ModelHistory is passed in and tweaked as needed. If this is what you are proposing then I don't think it would be horribly hard to do when I have some time. This would at least allow for splitting the code across multiple files.

@kalekundert
Contributor Author

Currently the tests err on the side of being exhaustive (checking as many kinds of architectures as I can think of), so they do end up taking awhile.

You can get the best of both worlds by using @pytest.mark.slow to mark the tests that take a long time to run. That way, developers can easily run just the fast tests on a regular basis, and the slow tests can be run on a more as-needed basis.
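
A minimal setup might look like this (the marker name and registration follow the standard pytest convention; nothing like this exists in the repo yet):

```python
# conftest.py — register the marker so pytest doesn't warn about it
def pytest_configure(config):
    config.addinivalue_line("markers", "slow: tests that take a long time to run")
```

```python
# in a test module
import pytest

@pytest.mark.slow
def test_fasterrcnn_mobilenet_train():
    ...
```

Then running pytest -m "not slow" skips the slow ones, while a plain pytest still runs everything.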

For me, one of the goals of TorchLens is to be able to fetch any kind of information you want about the model or about a given layer (i.e., all the fields listed in constants.py) just by indexing a single attribute (from either ModelHistory or from a specific layer log).

I took some time to try to actually understand the code, so I could give some more informed feedback. I agree with the design goal of having everything available via attributes. Without changing that, I have a number of specific suggestions that might make the code easier to understand and/or modify:

Convert redundant attributes into properties.

Are you familiar with @property? It's a decorator that basically disguises a method as an attribute. In other words, to the user it seems like they're just accessing an attribute, but behind the scenes a method is being invoked.

ModelHistory and TensorLogEntry both have a bunch of attributes that seem to present the same information in different ways. These could be reduced to a single attribute with multiple properties, which would be easier to reason about. Some examples:

  • layer_list, layer_dict_main_keys, layer_dict_all_keys, layer_labels, layer_labels_w_pass, layer_labels_no_pass, layer_num_passes, input_layers, output_layers, buffer_layers, etc. If ModelHistory just held a single collection of all the TensorLogEntry objects it knew about, all of the above attributes could be calculated on the fly (possibly with some filtering based on attributes of the layers themselves).

  • parent_layers, has_parents, orig_ancestors, child_layers, has_children, sibling_layers, has_siblings, etc. All of these TensorLogEntry attributes relate to the entry's location in the ModelHistory graph. The only necessary information is the list of edges. Everything else can be derived from that.

  • tensor_fsize, tensor_fsize_nice: This is one example, but there are several number/pretty string pairs of attributes. They have literally the exact same information; there's no reason to store both.

Of all the suggestions I'm making here, I think this one is the best value for effort. It wouldn't be a big change, but it would simplify these classes a lot.
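
As a concrete (hypothetical) sketch, the size pair could look like this in a slimmed-down TensorLogEntry:

```python
class TensorLogEntry:
    def __init__(self, tensor_fsize: int):
        self.tensor_fsize = tensor_fsize  # size in bytes, stored once

    @property
    def tensor_fsize_nice(self) -> str:
        # Derived on demand instead of stored; the formatting here is illustrative.
        size = float(self.tensor_fsize)
        for unit in ("B", "KB", "MB", "GB", "TB"):
            if size < 1024 or unit == "TB":
                return f"{size:.1f} {unit}"
            size /= 1024
```

Users would still write entry.tensor_fsize_nice exactly as before; the only difference is that the value is computed rather than stored.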

Group related attributes into objects

ModelHistory and TensorLogEntry have around 100 attributes each. This is enough (even after accounting for the redundancy described above) that it's hard to make a mental picture of what each class knows. I suspect that simply dividing these same attributes into a number of smaller classes, e.g. 10 classes with 10 attributes each, would be easier to understand.

For example, all of the TensorLogEntry attributes relating to the label could be moved into a Label object. Likewise with Tensor, Gradient, Function, Params, and Module objects. That alone would eliminate most of the TensorLogEntry attributes. It'd be much easier to look at that group of attributes and understand that TensorLogEntry is a kind-of amalgamation of a function call and its output tensor.

I'll note that the attributes of both classes are already grouped via comments in their respective constructors. I'm basically just advocating for these same groups to be formalized into real data structures.
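
A rough sketch of the idea, with made-up groupings and field names:

```python
from dataclasses import dataclass

import torch

@dataclass
class Label:
    layer_label: str
    layer_type: str
    pass_num: int

@dataclass
class TensorInfo:
    shape: tuple
    dtype: torch.dtype
    fsize: int

@dataclass
class TensorLogEntry:
    label: Label
    tensor: TensorInfo
    # ...Gradient, Function, Params, and Module groups would follow the same pattern
```

Access would then look like entry.label.layer_label or entry.tensor.fsize, and each group could be documented (and tested) on its own.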

Separate user-facing and private attributes

Some of the ModelHistory attributes don't say anything about what the model did; they're just needed by TorchLens to work behind the scenes. I believe this includes at least the following attributes, and maybe more:

  • track_tensors
  • pause_logging
  • activation_postfunc
  • current_function_call_barcode
  • detach_saved_tensors
  • raw_to_final_layer_labels, final_to_raw_layer_labels
  • lookup_keys_to_tensor_num_dict, tensor_num_to_lookup_keys_dict

I think the code would be easier to understand if there was a more clear division between what's meant to be presented to the user, and what's just for internal use. For instance, one thing I'm still not sure about is whether "addresses" (which seem to be recursive data structures used to extract data from arbitrary python objects) exist for internal or external use.

A simple way to make this distinction would be to prefix internal attributes with an underscore (e.g. _pause_logging). However, as I'll explain in the next section, I think a better approach would be to make a separate object responsible for instantiating ModelHistory, and to have ModelHistory itself be a class that does nothing but make data available to the user.

Make a ModelHistory factory class

Right now, the ModelHistory class is responsible both for listening as decorated PyTorch functions are called and presenting information to the end user. I think that splitting these two responsibilities between two separate classes would be helpful.

ModelHistory would still have a lot of attributes (or better: attributes and properties), but it wouldn't need any complex methods. The whole class would basically be a large but completely passive data structure. Overall, I think the mental model of "this class holds all the information about what happened as the model executed, but doesn't do anything" would be pretty easy to work with.

ModelHistoryFactory (just to make up a name) would be responsible for keeping track of things between calls to decorated PyTorch functions. This class would be more conceptually complicated, because to understand how it works, you'd also have to understand how the decorated functions work. But instead of this complexity being contained in a huge class with lots of other responsibilities, it'd be in its own small unit. Plus, this class would never be returned to the user. So as a developer, you would know that everything it does is necessary for its function, and isn't just to collect data that might possibly be of interest to some fraction of users.

Decouple decorating PyTorch functions from running the model

Right now, the code to decorate all the PyTorch functions runs right before the model itself is executed. But this leads to the problem discussed in #18. Being able to decorate the PyTorch functions in advance would solve that problem, and would also help break the code into smaller and more understandable pieces.

In order for this to work, we need a way for the decorated functions to communicate with the ModelHistoryFactory, even though the latter won't exist when the former are created. This can be done using an intermediate object, which I'll call Broadcaster. The decorated functions would each be given a reference to the broadcaster. Later, when a ModelHistoryFactory was created, it would also get a reference to the same broadcaster. When a decorated function is called, all the relevant details (input arguments, return values, etc.) would be provided to the broadcaster by the function, and then the broadcaster would relay those same details to the ModelHistoryFactory.

Because this architecture only allows the decorated functions and the ModelHistoryFactory object to communicate through a relatively simple interface, neither would have to know anything about how the other works under the hood. In fact, this architecture might even make it relatively straightforward to adapt TorchLens to other ML frameworks, like JAX. All you'd need to do is write a different set of decorators; everything else would be unaffected.
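
A sketch of what I have in mind — every name here is made up:

```python
import functools

class Broadcaster:
    """Relays call events from decorated functions to whoever is currently listening."""

    def __init__(self):
        self.listener = None  # attached later, e.g. by a ModelHistoryFactory

    def emit(self, event: dict):
        if self.listener is not None:
            self.listener.handle_call(event)

broadcaster = Broadcaster()  # created once, at decoration time

def decorate(fn):
    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        out = fn(*args, **kwargs)
        broadcaster.emit({"fn": fn.__name__, "args": args, "kwargs": kwargs, "out": out})
        return out
    return wrapped

class ModelHistoryFactory:
    def __init__(self, broadcaster: Broadcaster):
        broadcaster.listener = self  # start receiving events, possibly long after decoration
        self.calls = []

    def handle_call(self, event: dict):
        self.calls.append(event)
```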

Split TensorLogEntry into finer-grained objects

As I understand it, ModelHistory is fundamentally a graph, where the nodes are TensorLogEntry objects. Each TensorLogEntry represents a function call and its resulting output tensor. However, I don't think this data structure matches the actual structure of the execution graph, for reasons I'll describe below. I think a slightly more fine-grained data structure would match better, and would consequently be more flexible and more intuitive for both users and developers.

The problem with the TensorLogEntry graph is that the actual execution graph has more than one kind of node. There are functions and tensors, which is important because a function can return multiple tensors, so it's not quite right for a single entity to represent both. There are also modules and parameters. Modules should reference the functions they contain. Functions should reference their input tensors, input parameters, and output tensors.

Function, Tensor, Module, and Parameter classes would all be easy to understand: they would simply contain as much information as possible about their respective entity. The links between them would also be intuitive, and would better match the reality of how the model is executed. The graph itself should be a DAG, but the nodes should have enough information (e.g. object ids) to piece together where "cycles" occurred.
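
Hypothetical node classes, just to illustrate the shape of the data:

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class ParameterNode:
    name: str
    shape: tuple

@dataclass
class TensorNode:
    obj_id: int  # e.g. id() of the underlying tensor, to reconstruct "cycles"
    shape: tuple
    produced_by: Optional["FunctionNode"] = None

@dataclass
class FunctionNode:
    name: str
    input_tensors: List[TensorNode] = field(default_factory=list)
    params: List[ParameterNode] = field(default_factory=list)
    output_tensors: List[TensorNode] = field(default_factory=list)

@dataclass
class ModuleNode:
    name: str
    functions: List[FunctionNode] = field(default_factory=list)
```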

On this topic, I have a suspicion that the RolledTensorLogEntry class shouldn't exist. I'll start by saying that I still don't understand what exactly this class does, and I know that it's generally a bad idea to call for getting rid of something without knowing why it's there in the first place. But as far as I can tell, RolledTensorLogEntry is just TensorLogEntry with less information. What can the former possibly do that the latter can't? I get the impression that rolled entries are used to visualize graphs where certain modules have been "collapsed", but surely the visualization code could just ignore the information it doesn't need?

Move visualization code into stand-alone functions

Because I expounded at length on the benefit of moving code into functions in my last post, I figured I should at least briefly mention one example of that. First, let me admit up front that I haven't read the visualization code yet, so I don't know how it works. But I do know that ModelHistory is supposed to provide a complete accounting of what happened when the model was evaluated, so the visualization code should not need access to anything beyond that.

Moving the visualization code out of ModelHistory would make that class easier to understand. Right now, because I haven't read the visualization methods, I can't be sure that they don't somehow play a role in the process of building the ModelHistory object. I assume they probably don't, but I can't be sure. However, if the visualization code lived in standalone functions in a separate module, I would not have to worry about it at all.

Miscellaneous ideas

  • I didn't quite understand how the code to keep track of modules works. But it doesn't seem to make use of nn.Module.register_forward_hook(), which is what I was expecting. Anecdotally, I've also seen the module assignments be wrong in some rare cases (but only in really big models; I haven't tried to make a reproducible example). Perhaps it would be better to use hooks?

  • Random string barcodes seem like a suboptimal approach. I get that collisions are unlikely, but if they happen, wouldn't they mess up the whole graph? You can use id() to get a unique identifier for any python object, so I'm skeptical that this is a necessary risk to take.


Having written all of this, I should add the caveat that I still don't understand all of the code, so it's definitely possible that some of these ideas are bad/impractical. Specifically, I only carefully read through the process of running the model, minus the post-processing steps. And even the parts I read seemed to have lots of fiddly details that I didn't necessarily fully appreciate.

@johnmarktaylor91
Owner

This is all fantastic feedback, thanks so much for this deep dive. I think many of these ideas are no-brainers and would greatly improve the code. I have a K99 grant to push out the door in the next 48 hours, but will respond in detail after ;)

@johnmarktaylor91
Owner

johnmarktaylor91 commented Oct 16, 2024

Grant submitted, finally time to respond properly to your great comments :)

You can get the best of both worlds by using @pytest.mark.slow to mark the tests that take a long time to run. That way, developers can easily run just the fast tests on a regular basis, and the slow tests can be run on a more as-needed basis.

Fantastic idea and a no-brainer. This needs to be done.

Are you familiar with @property? It's a decorator that basically disguises a method as an attribute. In other words, to the user it seems like they're just accessing an attribute, but behind the scenes a method is being invoked.

ModelHistory and TensorLogEntry both have a bunch of attributes that seem to present the same information in different ways. These could be reduced to a single attribute with multiple properties, which would be easier to reason about. Some examples:

Also a great suggestion and no-brainer. This will clean up a lot of the "plumbing" code.

ModelHistory and TensorLogEntry have around 100 attributes each. This is enough (even after accounting for the redundancy described above) that it's hard to make a mental picture of what each class knows. I suspect that simply dividing these same attributes into a number of smaller classes, e.g. 10 classes with 10 attributes each, would be easier to understand.

I have thought about more limited versions of this--at the very least, I think having classes for describing the modules and for describing the parameters would be sensible. One thing to think about is the usability tradeoff of adding extra levels of nesting. It saves some keystrokes to only have one level of nesting when indexing ModelHistory or TensorLogEntry, but on the other hand it really is a big list of attributes and this could be clunky to navigate (and it's easy to overlook attributes that a user might find helpful). Here I am inclined to default to whatever is easier from the user standpoint.

I think the code would be easier to understand if there was a more clear division between what's meant to be presented to the user, and what's just for internal use. For instance, one thing I'm still not sure about is whether "addresses" (which seem to be recursive data structures used to extract data from arbitrary python objects) exist for internal or external use.

100% agreed. This would be an easy fix and help to further tidy up the classes.

Right now, the ModelHistory class is responsible both for listening as decorated PyTorch functions are called and presenting information to the end user. I think that splitting these two responsibilities between two separate classes would be helpful.

Indeed, this could be useful. In fact, in my original version of TorchLens I did exactly this, with separate data structures for the object that logs the forward pass and the user-facing object, and maybe it would be cleaner to go back to keeping those separate. I think this is very sensible.

Right now, the code to decorate all the PyTorch functions runs right before the model itself is executed. But this leads to the problem discussed in #18. Being able to decorate the PyTorch functions in advance would solve that problem, and would also help break the code into smaller and more understandable pieces.

I agree with this. One thought I had was to have torchlens decorate PyTorch upon being imported, but have there be a "toggle" such that the logging functionality is only executed in the context of log_forward_pass (so that PyTorch behaves normally outside of it). Your idea of the Broadcaster is very nice and sensible and I think would be a good way of handling this.
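
Very roughly what I'm picturing (pure sketch; none of this exists yet, and the names are placeholders):

```python
import contextlib

_LOGGING_ENABLED = False  # module-level toggle checked by the decorated functions

@contextlib.contextmanager
def logging_enabled():
    # Decoration would happen at import time, but logging only runs inside this
    # block, so PyTorch behaves normally everywhere else.
    global _LOGGING_ENABLED
    _LOGGING_ENABLED = True
    try:
        yield
    finally:
        _LOGGING_ENABLED = False

def decorate(fn):
    def wrapped(*args, **kwargs):
        out = fn(*args, **kwargs)
        if _LOGGING_ENABLED:
            pass  # log the call (e.g. hand it off to the Broadcaster)
        return out
    return wrapped
```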

The problem with the TensorLogEntry graph is that the actual execution graph has more than one kind of node. There are functions and tensors, which is important because a function can return multiple tensors, so it's not quite right for a single entity to represent both. There are also modules and parameters. Modules should reference the functions they contain. Functions should reference their input tensors, input parameters, and output tensors.

It's a good point: there is not a one-to-one mapping between function calls and tensors. I made the choice to "conflate" them for a few reasons:

  1. At least to me, the "intuitive" organizational unit for a DNN is a layer: an operation that takes in an input tensor (or tensors), and returns an output tensor (or tensors). For usability, from the get-go I wanted the ability to easily index and pull out any layer you want, e.g. model_history['relu_1']. I very much wanted something that could be understood in ten seconds, versus the user having to first figure out some ontology of data structures for tensors, layers, modules, etc. To achieve this a choice had to be made about the "default" organizational unit when indexing ModelHistory. I'm not opposed to adding extra abstractions "under the hood" but I'm pretty committed to preserving the "learn it in ten seconds" user experience.
  2. There is almost a one-to-one correspondence between PyTorch function calls and output tensors. It is really a pretty small number of PyTorch functions that return multiple tensors. So I wasn't sure whether it was worth adding an extra abstraction that's only relevant in rare occasions. The solution I arrived at was that for functions that return multiple tensors, there's one node per tensor, with a field called iterable_output_index that specifies where in the output the tensor is from (e.g., its position in a list of tensors).

But, these were subjective judgment calls that I'm not 100% wedded to :]

I do think that making modules their own dedicated data structure could be useful (e.g., a ModuleLog data structure). Currently, TorchLens treats modules as the "boxes" in which operations happen and honestly it's clunky to fetch information about them. I think tidying them into their own data structure would be helpful. Parameters should have their own data structure too.

Because I expounded at length on the benefit of moving code into functions in my last post, I figured I should at least briefly mention one example of that. First, let me admit up front that I haven't read the visualization code yet, so I don't know how it works. But I do know that ModelHistory is supposed to provide a complete accounting of what happened during when the model was evaluated, so the visualization code should not need access to anything beyond that.

Moving the visualization code out of ModelHistory would make that class easier to understand. Right now, because I haven't read the visualization methods, I can't be sure that they don't somehow play a role in the process of building the ModelHistory object. I assume they probably don't, but I can't be sure. However, if the visualization code lived in standalone functions in a separate module, I would not have to worry about it at all.

I agree with this, and to answer your question, the visualization code has no role in building ModelHistory. The one thing is that I do want render_graph to be a method of the ModelHistory object. This method would presumably run a visualization function that's imported from another file. But there would be some plumbing required to avoid circular imports (i.e., the visualization code couldn't import the ModelHistory class for type-hinting, etc.). I think this is an easy call to make for dividing the code into separate files, just some (easily solvable) issues to figure out.

On this topic, I have a suspicion that the RolledTensorLogEntry class shouldn't exist. I'll start by saying that I still don't understand what exactly this class does, and I know that it's generally a bad idea to call for getting rid of something without knowing why it's there in the first place. But as far as I can tell, RolledTensorLogEntry is just TensorLogEntry with less information. What can the former possibly do that the latter can't? I get the impression that rolled entries are used to visualize graphs where certain modules have been "collapsed", but surely the visualization code could just ignore the information it doesn't need?

It exists solely for visualization purposes. The "rolled" graph option corresponds not to collapsing the modules (that's separate), but to reformatting the graph to show any recurrent loops. The issue is this: the visualization code uses the info from each TensorLogEntry to visualize the graph. When you visualize the rolled-up graph, each node now corresponds to all passes of a layer, rather than one node per pass. This requires some different fields (i.e., tracking the input and output nodes for each pass). I wanted to be able to re-use the visualization code for both the rolled and unrolled graphs. The solution I arrived at was to create the RolledTensorLogEntry class, with the visualization code taking either TensorLogEntry or RolledTensorLogEntry. If there's an easier way of handling this I am all ears.

I didn't quite understand how the code to keep track of modules works. But it doesn't seem to make use of nn.Module.register_forward_hook(), which is what I was expecting. Anecdotally, I've also seen the module assignments be wrong in some rare cases (but only in really big models; I haven't tried to make a reproducible example). Perhaps it would be better to use hooks?

I actually did this initially until I ran into one vexing design choice: PyTorch forward hooks can't see any input tensors that were fed in as kwargs to the module's forward function (this is rare, but some models do this). I think the good folks at PyTorch are working on fixing this, but until then, I had to take the approach of making a decorator that gets applied to each module's forward function, since decorators can see the kwargs too.
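
i.e., roughly this pattern (schematic, not the actual TorchLens code):

```python
import functools

import torch
import torch.nn as nn

module_log = []  # stand-in for the real per-module bookkeeping

def wrap_module_forward(module: nn.Module, name: str):
    orig_forward = module.forward

    @functools.wraps(orig_forward)
    def forward_with_logging(*args, **kwargs):
        # Unlike register_forward_hook, a plain wrapper sees kwarg inputs too.
        module_log.append((name, "enter", len(args), sorted(kwargs)))
        out = orig_forward(*args, **kwargs)
        module_log.append((name, "exit"))
        return out

    module.forward = forward_with_logging

layer = nn.Linear(4, 2)
wrap_module_forward(layer, "linear_1")
_ = layer(torch.randn(1, 4))  # both entries get logged, kwargs included if present
```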

This shouldn't be the source of the mistaken module assignments (since they're just different ways of triggering code to be executed when the forward function is run). It's probably something to do with the logic of how tensors are tracked as they enter and leave modules (there are some subtleties involved when a tensor is fed into multiple modules, tracking internally generated tensors from torch.ones or torch.rand, etc.). If you happen to remember which models had the wrong module assignments I can check them out and see where the bug arose.

Random string barcodes seem like a suboptimal approach. I get that collisions are unlikely, but if they happen, wouldn't they mess up the whole graph? You can use id() to get a unique identifier for any python object, so I'm skeptical that this is a necessary risk to take.

The barcodes are used for a few different purposes, not all of them for labeling objects. They are used for individuating parameters, though, which is where your suggestion is helpful. I will say that the barcodes (8 chars long, any alphanumeric character) have 200 trillion+ possible values, so the odds of a collision are tiny... but you're right that there's nothing to lose from being kosher here.

@johnmarktaylor91
Owner

Also, reviewing the pull request today. It looks great so I think I'll be able to merge it shortly. I will credit you as a contributor on the readme :]

@johnmarktaylor91 self-requested a review October 16, 2024

@johnmarktaylor91 (Owner) left a comment

Changes look great, thanks so much for doing this.

@johnmarktaylor91 merged commit 3fd8a31 into johnmarktaylor91:main on Oct 16, 2024