Integration with probabilistic_model #277

Open
tomsch420 opened this issue Oct 8, 2024 · 4 comments
Labels
question Further information is requested

Comments

@tomsch420
Collaborator

Well met!
I am a researcher at the Institute for Artificial Intelligence at the University of Bremen. My area of study is tractable probabilistic cognitive robot plans. For this purpose I have written the package probabilistic_model (https://github.com/tomsch420/probabilistic_model), which contains probabilistic graphical models (PGMs) in networkx and probabilistic circuits (PCs) in networkx, JAX and torch.
I talked with Antonio about my progress and a potential integration with your framework.
I would be very happy to pursue this, since we could drastically reduce the amount of duplicated work while combining our knowledge to build a better framework.
My architecture is inspired by the one Anji Liu presented (https://arxiv.org/pdf/2406.00766) but bets on sparse tensors. So far, my work indicates that JAX has the best support for them. While I do not have an extensive benchmark yet, I tested with ~90k parameters and 200k samples: one forward pass took approximately 50 ms with JAX, with extremely sparse sum layers.
Furthermore, my architecture is completely object-oriented. In your code I saw a lot of functional design and factories, which Python is not built for (as also indicated by #276).
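
As a rough illustration of what betting on sparse tensors can mean for a sum layer, here is a minimal sketch that stores only the nonzero sum-node edges as a COO-style edge list and evaluates the layer in log space. It uses NumPy scatter operations (`np.maximum.at`, `np.add.at`) so it runs without a JAX install; in JAX the same gather/scatter pattern maps onto `jax.ops.segment_max` and `jax.ops.segment_sum`. The function name and argument layout are hypothetical, not the probabilistic_model API, and the sketch assumes every sum node has at least one edge and that all child log-likelihoods are finite.

```python
import numpy as np


def sparse_sum_layer(child_ll, parent_idx, child_idx, log_w, num_parents):
    """Evaluate a log-space sum layer whose weights are stored as an edge list.

    child_ll:    (batch, num_children) log-likelihoods of the child nodes
    parent_idx:  (num_edges,) sum node that each edge feeds into
    child_idx:   (num_edges,) child node that each edge reads from
    log_w:       (num_edges,) log-weight of each edge
    Returns a (batch, num_parents) array of sum-node log-likelihoods.
    Assumes every sum node has at least one edge and all inputs are finite.
    """
    # One term per stored edge: log w_e + log p(child_e). Edges go on the
    # leading axis so the scatter operations below act per edge.
    terms = (child_ll[:, child_idx] + log_w).T  # (num_edges, batch)

    # Per-parent maximum, used to shift the log-sum-exp for numerical stability.
    shift = np.full((num_parents, child_ll.shape[0]), -np.inf)
    np.maximum.at(shift, parent_idx, terms)

    # Scatter-add the shifted exponentials into their parents, then undo the shift.
    total = np.zeros_like(shift)
    np.add.at(total, parent_idx, np.exp(terms - shift[parent_idx]))
    return (np.log(total) + shift).T  # (batch, num_parents)
```

Only the edges that exist are stored and touched, yet the result should match the dense form `log(exp(child_ll) @ W.T)` up to floating-point precision, which is a handy sanity check when comparing sparse and dense layers.
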
I would be happy to discuss the architectures in detail and assist you in building a framework that is fit for all kinds of users.

Greetings,
Tom

@tomsch420 tomsch420 added the question Further information is requested label Oct 8, 2024
@loreloc
Member

loreloc commented Oct 22, 2024

Hi Tom,
I took a look at your library and I think it is very cool work!

One interesting integration with cirkit could be rewriting your code so that it converts a tree-shaped PGM into our symbolic circuit representation. This could also help us understand whether we need to refine or extend our symbolic representation.

I also really like how you deal with random variables, e.g., how you can name them. Right now we only have a scope data structure in which variables are non-negative integers, which I guess is too low-level a representation and would benefit from your approach.
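
To make the contrast concrete, here is a minimal sketch of one way the two views could coexist: named variables carry the user-facing metadata, while the compiled circuit only ever sees integer column indices. `Variable`, `to_scope` and `from_scope` are hypothetical names for illustration; they are not the API of either cirkit or probabilistic_model.

```python
from dataclasses import dataclass
from typing import FrozenSet, List, Sequence


@dataclass(frozen=True)
class Variable:
    """Hypothetical user-facing random variable; metadata lives here, not in the circuit."""
    name: str
    domain: str  # e.g. "continuous", "integer" or "symbolic"


def to_scope(variables: Sequence[Variable], columns: Sequence[Variable]) -> FrozenSet[int]:
    """Translate named variables into the integer scope a compiled circuit works with.

    `columns` fixes the column order of the data matrix, so a variable's scope
    entry is simply its column index.
    """
    column_of = {v: i for i, v in enumerate(columns)}
    return frozenset(column_of[v] for v in variables)


def from_scope(scope: FrozenSet[int], columns: Sequence[Variable]) -> List[Variable]:
    """Recover the named variables behind an integer scope, ordered by column."""
    return [columns[i] for i in sorted(scope)]


columns = (Variable("x", "continuous"), Variable("y", "continuous"), Variable("color", "symbolic"))
scope = to_scope([columns[2], columns[0]], columns)  # frozenset({0, 2})
```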

Regarding the functional vs. object-oriented design choice: I think our library benefits from a functional design since, in my view, it is a language and a compiler, and since it can be quite hard to implement, we care about having code whose correctness we can check more easily. In our case, for instance, this is done by having immutable objects that are always consistent once initialized.
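
A small Python sketch of the "immutable and consistent from initialization" idea: a frozen dataclass validates itself in `__post_init__`, and an update goes through `dataclasses.replace`, which builds (and re-validates) a new object instead of mutating the old one. `SumNode` and its invariants are hypothetical and only illustrate the pattern, not cirkit's actual classes.

```python
import math
from dataclasses import dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class SumNode:
    """Hypothetical immutable sum node: validated once, never mutated afterwards."""
    children: Tuple[int, ...]
    weights: Tuple[float, ...]

    def __post_init__(self):
        # Runs on every construction (including via replace), so any SumNode
        # that exists satisfies these invariants.
        if len(self.children) != len(self.weights):
            raise ValueError("one weight per child is required")
        if any(w < 0 for w in self.weights) or not math.isclose(sum(self.weights), 1.0):
            raise ValueError("weights must be non-negative and sum to 1")


node = SumNode(children=(0, 1), weights=(0.25, 0.75))
# "Updating" creates a new, re-validated node; `node` itself is left untouched.
updated = replace(node, weights=(0.5, 0.5))
```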

Thanks

@tomsch420
Collaborator Author

Hey,
thanks for the nice answer!
I just went through some major refactoring and now it is even nicer than before 😎

The PGMs in my package get converted to networkx circuits for inference anyway. The nx circuits can then be compiled to JAX circuits for a 4000x speed-up in log-likelihood evaluation.
However, I would also like to convert them with cirkit as a backend. Is there a guide on how to build a cirkit model? Then I think it could be done quickly, and I could also test the speed of different architectures.

From your torch implementation I gathered that you also have a networkx-like structure for edges; am I wrong?
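
For readers unfamiliar with the graph view discussed here, a minimal sketch of a circuit stored as nodes plus incoming edges and evaluated bottom-up in topological order. It uses the standard-library `graphlib.TopologicalSorter` (Python 3.9+) instead of networkx; the node encoding and the tiny example circuit are hypothetical and do not reflect either library's internal representation.

```python
import math
from graphlib import TopologicalSorter

# Hypothetical circuit over two binary variables, stored as node -> (kind, payload):
#   "leaf":    (column index, (P(X=0), P(X=1)))
#   "product": tuple of child node ids
#   "sum":     tuple of (weight, child node id) pairs
NODES = {
    "x0": ("leaf", (0, (0.2, 0.8))),
    "x1": ("leaf", (1, (0.6, 0.4))),
    "y0": ("leaf", (0, (0.9, 0.1))),
    "y1": ("leaf", (1, (0.3, 0.7))),
    "p1": ("product", ("x0", "x1")),
    "p2": ("product", ("y0", "y1")),
    "root": ("sum", ((0.4, "p1"), (0.6, "p2"))),
}


def inputs_of(node):
    """Return the ids of the nodes that feed into `node` (its incoming edges)."""
    kind, payload = NODES[node]
    if kind == "leaf":
        return ()
    if kind == "product":
        return payload
    return tuple(child for _, child in payload)


def log_likelihood(sample, root="root"):
    """Evaluate the circuit bottom-up on one sample given as a tuple of column values."""
    # TopologicalSorter expects node -> predecessors, so inputs are yielded before outputs.
    order = TopologicalSorter({n: inputs_of(n) for n in NODES}).static_order()
    values = {}
    for node in order:
        kind, payload = NODES[node]
        if kind == "leaf":
            column, table = payload
            values[node] = math.log(table[sample[column]])
        elif kind == "product":
            values[node] = sum(values[c] for c in payload)
        else:  # sum node: numerically stable log-sum-exp of weighted children
            terms = [math.log(w) + values[c] for w, c in payload]
            m = max(terms)
            values[node] = m + math.log(sum(math.exp(t - m) for t in terms))
    return values[root]
```

For the sample `(1, 0)` this gives `0.4 * 0.8 * 0.6 + 0.6 * 0.1 * 0.3 = 0.21` in probability space.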

Regarding the scope data structure: I like your approach there. For the computational graph it is not really necessary to know a variable's metadata; the column index is enough. I actually copied your approach for this data structure (https://probabilistic-model.readthedocs.io/en/latest/autoapi/probabilistic_model/probabilistic_circuit/jax/inner_layer/index.html#probabilistic_model.probabilistic_circuit.jax.inner_layer.InputLayer).

Regarding OOP vs. functional: JAX actually follows the same paradigm there. However, since that is not handy for Python developers, the community wrote a lot of wrappers around it (equinox, flax, etc.), which also do not allow mutation of objects in the classical sense. I strongly recommend giving them a look.

Hope to hear from you soon!

Best,
Tom

@tomsch420
Collaborator Author

Merry Christmas,
I have continued working on my implementation and, up to some precision errors, I was able to recreate the functionality of RAT-SPN using JAX sparse layers. While some numerical stability issues remain, sparse and dense layers run at approximately the same speed, so perhaps this is also worth considering for cirkit?
Furthermore, I found myself replicating functionality that sounds similar to what your package does, although I cannot tell for sure.

I wanted to re-check since you updated the API docs. Perhaps a good scheme would be to merge these projects altogether?
A good architecture could be a converter from the compiled circuits of cirkit to nx circuits as they are implemented in my package. Then you could focus on learning, and you would get reasoning about any event of the product algebra for free wherever the converter is supported. This would lift the burden of maintaining and describing inference and different circuit representations, like BNs or LV-trees, from you, and the burden of maintaining parameter estimation from me.
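
To illustrate why product-algebra events are tractable in circuits, here is a sketch for the simplest case only: a weighted mixture of fully factorized Gaussians. The probability of an axis-aligned box factorizes into a product of 1-D CDF differences per component, and a union of pairwise disjoint boxes (one simple kind of product-algebra event) just adds up. The mixture, function names and the disjointness assumption are hypothetical; this is not how probabilistic_model implements its event inference.

```python
import math


def normal_cdf(x, mean, std):
    """CDF of a 1-D Gaussian; handles x = +/-inf via math.erf."""
    return 0.5 * (1.0 + math.erf((x - mean) / (std * math.sqrt(2.0))))


# Hypothetical mixture of fully factorized Gaussians:
# (weight, [(mean, std) for each variable])
MIXTURE = [
    (0.3, [(0.0, 1.0), (2.0, 0.5)]),
    (0.7, [(1.0, 2.0), (-1.0, 1.0)]),
]


def box_probability(box):
    """P(X in box) for box = [(low, high) per variable].

    Each product component factorizes over variables, so the box probability is a
    product of 1-D CDF differences; the sum node then mixes them with its weights.
    """
    total = 0.0
    for weight, components in MIXTURE:
        p = weight
        for (low, high), (mean, std) in zip(box, components):
            p *= normal_cdf(high, mean, std) - normal_cdf(low, mean, std)
        total += p
    return total


def event_probability(boxes):
    """Probability of a union of boxes, assuming the boxes are pairwise disjoint."""
    return sum(box_probability(b) for b in boxes)
```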

I checked the API docs again; however, I am still unsure how to interpret the data structures that you provide.

I ask for your help in parsing your compiled circuits, and I offer to implement, maintain and test the parser myself.

I am looking forward to this integration and am very happy to see the circuit community grow through this project.

Best,
Tom

@loreloc
Member

loreloc commented Jan 14, 2025

Hi Tom, sorry for the late reply.

It would be very nice to see whether a JAX backend would simplify and automate the optimizations we currently do in the PyTorch compiler (e.g., folding and einsum optimizations). That's definitely something worth looking into.
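
For context on "folding": roughly, as we understand the term, layers of the same type and shape are stacked along an extra "fold" axis so that a single tensor operation evaluates all of them at once. The sketch below shows the idea for plain linear-space sum layers with NumPy `einsum`; the sizes are hypothetical and this is not cirkit's compiler code. In JAX the same batching could also be expressed with `jax.vmap` over the fold axis.

```python
import numpy as np

rng = np.random.default_rng(0)
F, B, K_IN, K_OUT = 3, 5, 4, 2  # folds, batch size, input units, output units (hypothetical)

# Three independent sum layers with identical shapes
inputs = [rng.random((B, K_IN)) for _ in range(F)]
weights = [rng.random((K_OUT, K_IN)) for _ in range(F)]

# Unfolded: one matrix product per layer, evaluated in a Python loop
unfolded = np.stack([x @ w.T for x, w in zip(inputs, weights)])  # (F, B, K_OUT)

# Folded: stack inputs and parameters along a new fold axis, then use one einsum
X = np.stack(inputs)   # (F, B, K_IN)
W = np.stack(weights)  # (F, K_OUT, K_IN)
folded = np.einsum("fbi,foi->fbo", X, W)  # (F, B, K_OUT)
```

Both paths compute the same numbers; the folded version replaces F small kernel launches with one larger one, which is where the speed-up on accelerators comes from.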

I personally do not have the bandwidth to implement a parser of compiled circuits between the two libraries. However, I am interested in understanding what makes implementing your inference routines in cirkit hard (e.g., a poor data structure design that we can improve on our side?).
