PyTorch wrapper #5
Conversation
Looks great! How about the backward pass? That will be essential for MD. I'm not familiar with how custom PyTorch ops handle types. You've written it so all the arguments are either double or int64. Much more typically they'll be float and int32. Can you make it accept those and process them directly without needing type conversions? |
I'm working on the backward pass. Regarding the types: https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html
|
If we make the arguments have type torch::Tensor, the op can convert them and access the data directly: tensor = tensor.type(torch::kFloat32); float* data = tensor.data_ptr<float>(); This will be important for the CUDA version, since it will give us a pointer to the data on the GPU, which saves having to copy it to the host and then back again. |
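On the Python side, a minimal sketch of the same idea (the function name and dtype choices here are assumptions, not the wrapper's actual API): Tensor.to returns its input unchanged when the dtype already matches, so callers that already pass float32/int32 pay no conversion cost.

```python
import torch

def to_kernel_dtypes(positions: torch.Tensor, species: torch.Tensor):
    # Cast to the dtypes the kernel is assumed to expect (float32/int32).
    # Tensor.to returns its input unchanged when the dtype already matches,
    # so callers that already use these types incur no copy.
    return positions.to(torch.float32), species.to(torch.int32)

pos, spec = to_kernel_dtypes(torch.rand(46, 3, dtype=torch.float64),
                             torch.zeros(46, dtype=torch.int64))
print(pos.dtype, spec.dtype)  # torch.float32 torch.int32
```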
The atomic positions and gradients are passed as tensors. |
I have a working wrapper. Overall, the wrapper has three components.
The PyTorch API favours functional programming, which is at odds with the object-oriented design of the C++ implementation. |
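One common way to reconcile the two styles, sketched here with a pure-Python stand-in for the C++ calculator (all class and method names below are hypothetical, not the wrapper's actual API): a torch.nn.Module owns the stateful object and reuses it on every forward() call, so the call itself stays functional.

```python
import torch

class FakeCalculator:
    """Pure-Python stand-in for the stateful C++ calculator (hypothetical)."""
    def __init__(self, num_atoms):
        self.num_atoms = num_atoms  # state that is expensive to set up in the real object

    def compute(self, positions):
        return positions.pow(2).sum()  # placeholder for the real kernels

class SymmetryFunctionsModule(torch.nn.Module):
    """Object-oriented wrapper: the calculator is created once in __init__
    and reused by every call to forward()."""
    def __init__(self, num_atoms):
        super().__init__()
        self.calculator = FakeCalculator(num_atoms)

    def forward(self, positions):
        return self.calculator.compute(positions)

module = SymmetryFunctionsModule(num_atoms=46)
energy = module(torch.rand(46, 3))
```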
Quick performance benchmark. Molecule: 46 atoms.
|
I hacked the wrapper to reuse the C++ object between calls.
|
In addition, I have made a mock-up of the wrapper, which doesn't do any calculations.
So, out of 2 ms, 1.3 ms is PyTorch overhead and 0.5 ms is spent in the actual computation. |
The way this is implemented has very high overhead: you construct a new C++ object on every evaluation. That class should also be independent of TorchANI. Having a separate class that provides easy integration with TorchANI is useful, but that necessarily adds overhead, and long term I think our goal is to completely replace TorchANI. So we first want a minimum-overhead wrapper that directly exposes this code to PyTorch in the simplest way, and TorchANI integration can be built on top of that. |
The functionality is already directly exposed in PyTorch. |
I have renamed the PyTorch module. |
Could you elaborate? Certainly you can implement new calculations that way. |
Where have you seen that? |
Modules in PyTorch are really just functors that allow you to perform RAII. The links you provided are for the base class. It is up to you to implement the .backward() and .forward() calls. The backward() call is really just a vector-Jacobian product. Its signature is identical to what you'd use to implement Function.backward(). |
@proteneer In the autograd documentation, there is only an example with torch.autograd.Function. |
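For reference, a minimal self-contained example of the pattern those docs describe (the toy computation is illustrative only): backward() receives the incoming gradient and returns the vector-Jacobian product for each input.

```python
import torch

class SquareSum(torch.autograd.Function):
    """Toy op y = sum(x**2) with an explicit backward (a vector-Jacobian product)."""
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return (x * x).sum()

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # dy/dx = 2*x, scaled by the incoming gradient (a scalar here).
        return 2.0 * x * grad_output

x = torch.rand(5, requires_grad=True)
SquareSum.apply(x).backward()
print(torch.allclose(x.grad, 2 * x))  # True
```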
End-to-end performance benchmarks of ANI-2x. Molecule: 46 atoms.
Forward & backward passes with complete ANI-2x:
Just forward pass with complete ANI-2x:
Forward & backward passes with ANI-2x using just one set of the atomic NNs, not 8:
Just forward pass with ANI-2x using just one set of the atomic NNs, not 8:
|
Looks like the neural net part is now the bottleneck. From the benchmarks in #6, doing both forward and backward passes through the features for a system of 60 atoms is only 0.115 ms, and for a system of 2269 atoms is 1.04 ms. Do you have a sense of what makes the neural net part slow? Can we make it faster from within PyTorch, or do we need a custom kernel for that part too? Also, in the above numbers, how much of the time is spent constructing and destructing CudaANISymmetryFunction objects, and how much is spent in the kernels? |
Let's move the discussion about the NN part to #11. |
@peastman the first iteration of the PyTorch wrapper of NNPOps is done! At the moment, it exposes just the ANI symmetry functions. Remaining problems:
|
Nice! @proteneer do you have any ideas about how we could make the wrapping more efficient? Is there a way we could create the C++ object just once and use it repeatedly, instead of having to create a new one on every evaluation? This will be even more important for SchNet, since there we'll want to build a neighbor list structure once and use it repeatedly for all the layers within a single evaluation. |
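One possible pattern, sketched with a pure-Python stand-in for the C++ object (all names below are hypothetical): memoize construction on the parameters that define the object, so repeated evaluations reuse a single instance instead of building a new one each time.

```python
import functools
import torch

class FakeSymmetryFunctions:
    """Pure-Python stand-in for the expensive C++/CUDA object (hypothetical)."""
    def __init__(self, num_atoms):
        print(f"constructing for {num_atoms} atoms")  # printed only once per size
        self.num_atoms = num_atoms

    def compute(self, positions):
        return positions.pow(2).sum()  # placeholder for the real kernels

@functools.lru_cache(maxsize=None)
def get_symmetry_functions(num_atoms):
    # Construction is memoized, so repeated evaluations reuse one instance.
    return FakeSymmetryFunctions(num_atoms)

positions = torch.rand(46, 3)
for _ in range(3):
    energy = get_symmetry_functions(positions.shape[0]).compute(positions)
```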
I have tried a few ideas to make the wrapping more efficient, but none of them worked out better; I did, however, find and fix a bug in the device management. This is ready for merging. Or do you have more comments, @peastman and @proteneer? |
Great! I'll go ahead and merge it. It would still be good to try to restructure it following the pattern in #10 (comment), since I think that will substantially improve performance. |
More discussion in #3