
PyTorch wrapper #5

Merged: 33 commits, Oct 28, 2020

Conversation

raimis
Contributor

@raimis raimis commented Sep 25, 2020

  • Autograd
  • PBC
    • Non-periodic
    • Periodic
  • Devices
    • CPU
    • GPU
  • Integration with TorchANI
    • Model serialization
    • Tests
  • Documentation
    • Installation
    • Usage
    • Docstrings

More discussions in #3

@peastman
Member

Looks great! How about the backward pass? That will be essential for MD.

I'm not familiar with how custom PyTorch ops handle types. You've written it so all the arguments are either double or int64. Much more typically they'll be float and int32. Can you make it accept those and process them directly without needing type conversions?

@raimis
Contributor Author

raimis commented Sep 28, 2020

I'm working on the backward pass.

Regarding the types: https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html

> The TorchScript compiler understands a fixed number of types. Only these types can be used as arguments to your custom operator. Currently these types are: torch::Tensor, torch::Scalar, double, int64_t and std::vectors of these types. Note that only double and not float, and only int64_t and not other integral types such as int, short or long are supported.

@peastman
Member

If we make the arguments have type torch::Tensor I believe that should do what we want. You can then cast the tensor if necessary and pull out a pointer to the data. Something like this.

tensor = tensor.to(torch::kFloat32);     // cast to float32 (a no-op if it already is)
float* data = tensor.data_ptr<float>();  // raw pointer to the tensor's storage

This will be important for the CUDA version, since it will give us a pointer to the data on the GPU, which saves having to copy it to the host and then back again.

@raimis
Contributor Author

raimis commented Sep 29, 2020

The atomic positions and gradients are now passed as torch::Tensor.

@raimis
Contributor Author

raimis commented Sep 30, 2020

I have a working ANISymmetryFunction operation in PyTorch.

Overall, the wrapper has three components:

  • ANISymmetryFunction function, which is exposed in PyTorch.
  • GradANISymmetryFunction class, which implements the autograd interface.
  • CustomANISymmetryFunctions class, which wraps the ANISymmetryFunctions class and passes it between the forward and backward passes.

The PyTorch API favours functional programming, which is at odds with the object-oriented design of ANISymmetryFunctions. At the moment, the ANISymmetryFunctions constructor is called each time PyTorch executes the operation, which isn't optimal.
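
To make the structure concrete, here is a minimal Python sketch of that pattern (the real code is a C++ op; the class name GradANISymmetryFunction matches the description above, but the stub object and its toy feature are purely illustrative). The object is constructed in forward and handed to backward through the autograd context.

```python
import torch

class SymmetryFunctionsStub:
    """Illustrative stand-in for the C++ ANISymmetryFunctions object."""

    def forward(self, positions):
        self.positions = positions.detach()
        return (self.positions ** 2).sum(dim=-1)  # toy per-atom "feature"

    def backward(self, grad_features):
        # Analytic gradient of the toy feature with respect to the positions.
        return 2.0 * self.positions * grad_features.unsqueeze(-1)

class GradANISymmetryFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, species, positions):
        # The real wrapper constructs a (Cuda)ANISymmetryFunctions object here on
        # every call, which is the overhead discussed later in this thread.
        ctx.holder = SymmetryFunctionsStub()
        return ctx.holder.forward(positions)

    @staticmethod
    def backward(ctx, grad_features):
        # The object created in forward is reused for the backward pass.
        return None, ctx.holder.backward(grad_features)  # no gradient w.r.t. species

# Usage: features = GradANISymmetryFunction.apply(species, positions)
```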

@raimis
Contributor Author

raimis commented Sep 30, 2020

Quick performance benchmark

Molecule: 46 atoms
GPU: GTX 1080 Ti
Execution time is averaged over 10000 consecutive calls

  • TorchANI 2.2 featurizer (pure PyTorch implementation): ~6.5 ms
  • ANISymmetryFunction operation (called via PyTorch): 2 ms
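
A timing loop of roughly this shape gives such an average (a hypothetical reconstruction, not the actual benchmark script; run is whatever callable evaluates the featurizer):

```python
import time
import torch

def average_time(run, n_calls=10000, n_warmup=10):
    """Average the wall-clock time of run() over consecutive calls."""
    for _ in range(n_warmup):   # warm-up so one-off CUDA setup doesn't skew the average
        run()
    torch.cuda.synchronize()    # make sure all queued GPU work is finished
    start = time.time()
    for _ in range(n_calls):
        run()
    torch.cuda.synchronize()    # flush the GPU queue before stopping the clock
    return (time.time() - start) / n_calls

# e.g. average_time(lambda: featurizer((species, positions)))  # TorchANI-style call
```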

@raimis
Contributor Author

raimis commented Sep 30, 2020

I hacked the wrapper to reuse the CudaANISymmetryFunctions object. It improves the performance, but breaks the serialization of PyTorch models.

  • ANISymmetryFunction operation (reusing CudaANISymmetryFunctions object): 1.5 ms
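
A sketch of that kind of hack (hypothetical Python; the actual change was inside the C++ wrapper): cache the expensive object outside the operation and reuse it on later calls. State hidden in a global like this is invisible to TorchScript, which is why model serialization breaks.

```python
# Hypothetical illustration of the caching hack: the expensive object (standing in
# for CudaANISymmetryFunctions) is built once and reused by every forward() call.
_cached_holder = None

def get_holder(factory):
    global _cached_holder
    if _cached_holder is None:
        _cached_holder = factory()   # expensive construction happens only once
    return _cached_holder
```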

@raimis
Contributor Author

raimis commented Sep 30, 2020

In addition, I have made a mock-up of the wrapper, which doesn't do any calculations.

  • ANISymmetryFunction operation (mock up): 1.3 ms

So out of the 2 ms, 1.3 ms is PyTorch overhead, 0.5 ms goes to the CudaANISymmetryFunctions constructor, and 0.2 ms is the actual calculation.

@peastman
Member

peastman commented Oct 6, 2020

The way this is implemented has very high overhead. You construct a new ANISymmetryFunctions object when forward() is called, delete it when backward() is called, and then need to create a new one from scratch on the next time step. We really want a Python ANISymmetryFunctions class that extends torch.nn.Module and does all memory allocation and initialization in its constructor. You should then be able to use it as many times as you want with no further overhead.

That class should also be independent of TorchANI. Having a separate class that provides easy integration with TorchANI is useful, but that necessarily adds overhead and long term I think our goal is to completely replace TorchANI. So we first want a minimum overhead wrapper that directly exposes this code to PyTorch in the simplest way, and TorchANI integration can be built on top of that.
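
A sketch of the proposed structure (hypothetical names; the buffer stands in for whatever the C++ object allocates): the expensive setup happens once in the constructor, and forward() only runs the kernels.

```python
import torch

class ANISymmetryFunctionsModule(torch.nn.Module):
    """Hypothetical sketch: allocate once in __init__, reuse on every call."""

    def __init__(self, num_features):
        super().__init__()
        # Stand-in for constructing the C++ ANISymmetryFunctions object and
        # allocating its buffers, done once when the module is created.
        self.register_buffer("workspace", torch.zeros(num_features))

    def forward(self, positions):
        # Stand-in for launching the symmetry-function kernels; no construction
        # or allocation happens here, so repeated calls add no extra overhead.
        return self.workspace + positions.pow(2).sum()
```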

@raimis
Contributor Author

raimis commented Oct 7, 2020

> The way this is implemented has very high overhead. You construct a new ANISymmetryFunctions object when forward() is called, delete it when backward() is called, and then need to create a new one from scratch on the next time step. We really want a Python ANISymmetryFunctions class that extends torch.nn.Module and does all memory allocation and initialization in its constructor. You should then be able to use it as many times as you want with no further overhead.

  • torch.nn.Module is just a high-level abstraction. The computational graph is implemented and executed in terms of PyTorch operations.
  • For a PyTorch operation to work with autograd, the implementation is limited to two functions (for the forward and backward passes) with a limited mechanism to pass data (forward -> backward).
  • There is no way to initialise a PyTorch operation in advance or to reuse objects, at least none that I have found. As mentioned in PyTorch wrapper #5 (comment), PyTorch favours functional programming. The computational graph is constructed and executed dynamically on the fly and discarded afterwards.

> That class should also be independent of TorchANI. Having a separate class that provides easy integration with TorchANI is useful, but that necessarily adds overhead and long term I think our goal is to completely replace TorchANI. So we first want a minimum overhead wrapper that directly exposes this code to PyTorch in the simplest way, and TorchANI integration can be built on top of that.

The functionality is already directly exposed in PyTorch via torch.ops.NNPOps.ANISymmetryFunctions (with as little overhead as PyTorch allows), and the TorchANI integration is built on top of that.
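
For illustration, reaching a registered custom op from Python looks roughly like this (the library filename and the op's argument list are assumptions, not the actual signature):

```python
import torch

# Load the shared library that registers the custom op with TorchScript
# (the filename here is an assumption).
torch.ops.load_library("libNNPOpsPyTorch.so")

# The op then appears under torch.ops.<namespace>.<name> and can be called from
# eager code or inside a scripted model; the arguments below are illustrative only.
# features = torch.ops.NNPOps.ANISymmetryFunctions(species, positions)
```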

@raimis
Contributor Author

raimis commented Oct 7, 2020

I have renamed the PyTorch module to TorchANISymmetryFunctions to make it clear that it is TorchANI specific. The rest of the components are general.

@peastman
Member

peastman commented Oct 8, 2020

Could you elaborate? Certainly you can implement new calculations using torch.autograd.Function objects, but that doesn't rule out implementing them with torch.nn.Module objects. Both classes define forward() and backward() methods. The documentation provides examples of writing both in C++. One is a functional API and the other is an object-oriented API, depending on which suits your needs better.

@raimis
Contributor Author

raimis commented Oct 9, 2020

> Both classes define forward() and backward() methods.

Where have you seen torch.nn.Module.backward? It isn't mentioned in either the Python API or the C++ API.

@proteneer

Modules in PyTorch are really just functors that allow you to perform RAII.

The links you provided are for the base class. It is up to you to implement the .backward() and .forward() calls. The backward() call is really just a vector-Jacobian product. Its signature is identical to the one you'd use to implement Function.backward().
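
A common way to combine the two, as a hedged sketch with toy math: keep the long-lived state in a torch.nn.Module and route the computation through a torch.autograd.Function, whose backward() is exactly the vector-Jacobian product described above.

```python
import torch

class _ScaledSquare(torch.autograd.Function):
    """Toy custom function: autograd only ever sees this forward/backward pair."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.save_for_backward(x)
        ctx.scale = scale
        return scale * x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Vector-Jacobian product: d(scale * x^2)/dx = 2 * scale * x.
        return grad_out * 2.0 * ctx.scale * x, None

class ScaledSquare(torch.nn.Module):
    """The module owns the reusable state; forward() defers to the Function."""

    def __init__(self, scale):
        super().__init__()
        self.scale = scale  # stand-in for an expensive, reusable C++ object

    def forward(self, x):
        return _ScaledSquare.apply(x, self.scale)
```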

@raimis
Contributor Author

raimis commented Oct 9, 2020

@proteneer In the autograd documentation, there is only an example with torch::autograd::Function (https://pytorch.org/tutorials/advanced/cpp_autograd.html#using-custom-autograd-function-in-c). Do you know an equivalent example with torch::nn::Module?

@raimis
Contributor Author

raimis commented Oct 9, 2020

End-to-end performance benchmarks of ANI-2x

Molecule: 46 atoms (pytorch/molecules/2iuz_ligand.mol2)
GPU: GTX 1080 Ti

Forward & backward passes with complete ANI-2x:

  • TorchANI with original featurizer: 90 ms
  • TorchANI with our featurizer: 81 ms

Just forward pass with complete ANI-2x:

  • TorchANI with original featurizer: 25 ms
  • TorchANI with our featurizer: 23 ms

Forward & backward passes with ANI-2x using just one set of the atomic NNs, not 8:

  • TorchANI with original featurizer: 11 ms
  • TorchANI with our featurizer: 6.8 ms

Just forward pass with ANI-2x using just one set of the atomic NNs, not 8:

  • TorchANI with original featurizer: 6.3 ms
  • TorchANI with our featurizer: 3.7 ms

@peastman
Member

peastman commented Oct 9, 2020

Looks like the neural net part is now the bottleneck. From the benchmarks in #6, doing both forward and backward passes through the features takes only 0.115 ms for a system of 60 atoms, and 1.04 ms for a system of 2269 atoms.

Do you have a sense of what makes the neural net part slow? Can we make it faster from within PyTorch, or do we need a custom kernel for that part too?

Also, in the above numbers, how much of the time is spent constructing and destructing CudaANISymmetryFunction objects, and how much is spent in the kernels?

@raimis
Contributor Author

raimis commented Oct 13, 2020

Let's move the discussion about the NN part to #11.

@raimis raimis mentioned this pull request Oct 15, 2020
@raimis raimis marked this pull request as ready for review October 22, 2020 10:15
@raimis
Contributor Author

raimis commented Oct 22, 2020

@peastman the first iteration of the PyTorch wrapper of NNPOps is done!

At the moment, it exposes just NNPOps.SymmetryFunctions.TorchANISymmetryFunctions, but it demonstrates how to make custom PyTorch operations that work with automatic differentiation and model serialisation.
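
For completeness, usage is expected to look roughly like the sketch below (the import path follows the module name above, but the constructor arguments and exact API should be checked against the code in this PR):

```python
import torch
import torchani
from NNPOps.SymmetryFunctions import TorchANISymmetryFunctions

device = torch.device("cuda")
model = torchani.models.ANI2x(periodic_table_index=True).to(device)

# Swap TorchANI's AEV computer for the accelerated symmetry functions
# (assumed here to be a drop-in replacement wrapping the original module).
model.aev_computer = TorchANISymmetryFunctions(model.aev_computer)

species = torch.tensor([[8, 1, 1]], device=device)  # e.g. a water molecule
positions = torch.tensor([[[0.00, 0.00, 0.00],
                           [0.00, 0.00, 0.96],
                           [0.93, 0.00, -0.24]]],
                         device=device, requires_grad=True)

energies = model((species, positions)).energies
forces = -torch.autograd.grad(energies.sum(), positions)[0]  # forces via autograd
```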

Remaining problems:

  • The wrapper isn't as efficient as it could be.
  • Only one molecule can be computed at a time, i.e. batched computation isn't supported.
  • Only 0D (non-periodic) or 3D periodic boundary conditions are supported.

@peastman
Member

Nice!

@proteneer do you have any ideas about how we could make the wrapping more efficient? Is there a way we could create the C++ object just once and use it repeatedly, instead of having to create a new one on every evaluation? This will be even more important for SchNet, since there we'll want to build a neighbor list structure once and use it repeatedly for all the layers within a single evaluation.

@raimis
Contributor Author

raimis commented Oct 28, 2020

I have tried a few ideas to make the wrapping more efficient, but nothing better came out of it, except that I found and fixed a bug in the device management.

This is ready for merging. Or do you have more comments, @peastman and @proteneer?

@peastman
Member

Great! I'll go ahead and merge it. It would still be good to try to restructure it following the pattern in #10 (comment), since I think that will substantially improve performance.

@peastman peastman merged commit 667a282 into openmm:master Oct 28, 2020