Skip to content

Commit

Permalink
Merge pull request #13 from raimis/batchedNN
Browse files Browse the repository at this point in the history
Batched NNs for TorchANI
  • Loading branch information
Raimondas Galvelis authored Oct 1, 2021
2 parents d4190d8 + 1b4acae commit 44e4282
Show file tree
Hide file tree
Showing 8 changed files with 401 additions and 36 deletions.
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ enable_testing()

add_library(${LIBRARY} SHARED src/ani/CpuANISymmetryFunctions.cpp
src/ani/CudaANISymmetryFunctions.cu
src/pytorch/BatchedNN.cpp
src/pytorch/SymmetryFunctions.cpp
src/schnet/CpuCFConv.cpp
src/schnet/CudaCFConv.cu)
Expand All @@ -29,8 +30,10 @@ foreach(TEST_PATH ${TEST_PATHS})
endforeach()

add_test(TestSymmetryFunctions pytest ${CMAKE_SOURCE_DIR}/src/pytorch/TestSymmetryFunctions.py)
add_test(TestBatchedNN pytest ${CMAKE_SOURCE_DIR}/src/pytorch/TestBatchedNN.py)

install(TARGETS ${LIBRARY} DESTINATION ${Python_SITEARCH}/${NAME})
install(FILES src/pytorch/__init__.py
src/pytorch/BatchedNN.py
src/pytorch/SymmetryFunctions.py
DESTINATION ${Python_SITEARCH}/${NAME})
36 changes: 36 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,40 @@ $ make install
- Run the tests
```bash
$ ctest
```

## Usage

Accelerated [*TorchANI*](https://aiqm.github.io/torchani/) operations:
- [`torchani.AEVComputer`](https://aiqm.github.io/torchani/api.html?highlight=speciesaev#torchani.AEVComputer)
- [`torchani.neurochem.NeuralNetwork`](https://aiqm.github.io/torchani/api.html#module-torchani.neurochem)

### Example

```python
import mdtraj
import torch
import torchani

from NNPOps.SymmetryFunctions import TorchANISymmetryFunctions
from NNPOps.BatchedNN import TorchANIBatchedNN

device = torch.device('cuda')

# Load a molecule
molecule = mdtraj.load('molecule.mol2')
species = torch.tensor([[atom.element.atomic_number for atom in molecule.top.atoms]], device=device)
positions = torch.tensor(molecule.xyz * 10, dtype=torch.float32, requires_grad=True, device=device)

# Construct ANI-2x and replace its operations with the optimized ones
nnp = torchani.models.ANI2x(periodic_table_index=True).to(device)
nnp.aev_computer = TorchANISymmetryFunctions(nnp.aev_computer).to(device)
nnp.neural_networks = TorchANIBatchedNN(nnp.species_converter, nnp.neural_networks, species).to(device)

# Compute energy and forces
energy = nnp((species, positions)).energies
energy.backward()
forces = -positions.grad.clone()

print(energy, forces)
```
50 changes: 50 additions & 0 deletions src/pytorch/BatchedNN.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/**
* Copyright (c) 2020 Acellera
* Authors: Raimondas Galvelis
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/

#include <torch/script.h>

using Context = torch::autograd::AutogradContext;
using Tensor = torch::Tensor;
using tensor_list = torch::autograd::tensor_list;

class BatchedLinearFunction : public torch::autograd::Function<BatchedLinearFunction> {
public:
static Tensor forward(Context* ctx, const Tensor& vectors, const Tensor& weights, const Tensor& biases) {
ctx->save_for_backward({weights});
return torch::matmul(weights, vectors) + biases;
};
static tensor_list backward(Context *ctx, const tensor_list& grads) {
const Tensor grad_in = grads[0].squeeze(-1).unsqueeze(-2);
const Tensor weights = ctx->get_saved_variables()[0];
const Tensor grad_out = torch::matmul(grad_in, weights).squeeze(-2).unsqueeze(-1);
return {grad_out, torch::Tensor(), torch::Tensor()};
};
};

static Tensor BatchedLinear(const Tensor& vector, const Tensor& weights, const Tensor& biases) {
return BatchedLinearFunction::apply(vector, weights, biases);
}

TORCH_LIBRARY(NNPOpsBatchedNN, m) {
m.def("BatchedLinear", BatchedLinear);
}
104 changes: 104 additions & 0 deletions src/pytorch/BatchedNN.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#
# Copyright (c) 2020 Acellera
# Authors: Raimondas Galvelis
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#

import os
import torch
from torch import nn
from torch import Tensor
from torch.nn import functional as F
import torchani
from torchani.nn import ANIModel, Ensemble, SpeciesConverter, SpeciesEnergies
from typing import List, Optional, Tuple, Union

torch.ops.load_library(os.path.join(os.path.dirname(__file__), 'libNNPOpsPyTorch.so'))
batchedLinear = torch.ops.NNPOpsBatchedNN.BatchedLinear


class TorchANIBatchedNN(torch.nn.Module):

def __init__(self, converter: SpeciesConverter, ensemble: Union[ANIModel, Ensemble], atomicNumbers: Tensor):

super().__init__()

# Convert atomic numbers to a list of species
species_list = converter((atomicNumbers, torch.empty(0))).species[0].tolist()

# Handle the case when the ensemble is just one model
ensemble = [ensemble] if type(ensemble) == ANIModel else ensemble

# Convert models to the list of linear layers
models = [list(model.values()) for model in ensemble]

# Extract the weihts and biases of the linear layers
for ilayer in [0, 2, 4, 6]:
layers = [[model[species][ilayer] for species in species_list] for model in models]
weights, biases = self.batchLinearLayers(layers)
self.register_parameter(f'layer{ilayer}_weights', weights)
self.register_parameter(f'layer{ilayer}_biases', biases)

# Disable autograd for the parameters
for parameter in self.parameters():
parameter.requires_grad = False

@staticmethod
def batchLinearLayers(layers: List[List[nn.Linear]]) -> Tuple[nn.Parameter, nn.Parameter]:

num_models = len(layers)
num_atoms = len(layers[0])

# Note: different elements have different size linear layers, so we just find maximum sizes
# and pad with zeros.
max_out = max(layer.out_features for layer in sum(layers, []))
max_in = max(layer.in_features for layer in sum(layers, []))

# Copy weights and biases
weights = torch.zeros((1, num_atoms, num_models, max_out, max_in), dtype=torch.float32)
biases = torch.zeros((1, num_atoms, num_models, max_out, 1), dtype=torch.float32)
for imodel, sublayers in enumerate(layers):
for iatom, layer in enumerate(sublayers):
num_out, num_in = layer.weight.shape
weights[0, iatom, imodel, :num_out, :num_in] = layer.weight
biases [0, iatom, imodel, :num_out, 0] = layer.bias

return nn.Parameter(weights), nn.Parameter(biases)

def forward(self, species_aev: Tuple[Tensor, Tensor]) -> SpeciesEnergies:

species, aev = species_aev

# Reshape: [num_mols, num_atoms, num_features] --> [num_mols, num_atoms, 1, num_features, 1]
vectors = aev.unsqueeze(-2).unsqueeze(-1)

vectors = batchedLinear(vectors, self.layer0_weights, self.layer0_biases) # Linear 0
vectors = F.celu(vectors, alpha=0.1) # CELU 1
vectors = batchedLinear(vectors, self.layer2_weights, self.layer2_biases) # Linear 2
vectors = F.celu(vectors, alpha=0.1) # CELU 3
vectors = batchedLinear(vectors, self.layer4_weights, self.layer4_biases) # Linear 4
vectors = F.celu(vectors, alpha=0.1) # CELU 5
vectors = batchedLinear(vectors, self.layer6_weights, self.layer6_biases) # Linear 6

# Sum: [num_mols, num_atoms, num_models, 1, 1] --> [num_mols, num_models]
# Mean: [num_mols, num_models] --> [num_mols]
energies = torch.mean(torch.sum(vectors, (1, 3, 4)), 1)

return SpeciesEnergies(species, energies)
99 changes: 99 additions & 0 deletions src/pytorch/BenchmarkBatchedNN.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#
# Copyright (c) 2020 Acellera
# Authors: Raimondas Galvelis
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#

import mdtraj
import time
import torch
import torchani

# from NNPOps.SymmetryFunctions import TorchANISymmetryFunctions
from NNPOps.BatchedNN import TorchANIBatchedNN

device = torch.device('cuda')

mol = mdtraj.load('molecules/2iuz_ligand.mol2')
species = torch.tensor([[atom.element.atomic_number for atom in mol.top.atoms]], device=device)
positions = torch.tensor(mol.xyz, dtype=torch.float32, requires_grad=True, device=device)

nnp = torchani.models.ANI2x(periodic_table_index=True, model_index=None).to(device)
print(nnp)

energy_ref = nnp((species, positions)).energies
energy_ref.backward()
grad_ref = positions.grad.clone()

N = 3000
start = time.time()
for _ in range(N):
energy_ref = nnp((species, positions)).energies
delta = time.time() - start
print(f'ANI-2x (forward pass)')
print(f' Duration: {delta} s')
print(f' Speed: {delta/N*1000} ms/it')

N = 1000
start = time.time()
for _ in range(N):
energy_ref = nnp((species, positions)).energies
positions.grad.zero_()
energy_ref.backward()
delta = time.time() - start
print(f'ANI-2x (forward & backward pass)')
print(f' Duration: {delta} s')
print(f' Speed: {delta/N*1000} ms/it')

# nnp.aev_computer = TorchANISymmetryFunctions(nnp.aev_computer).to(device)
nnp.neural_networks = TorchANIBatchedNN(nnp.species_converter, nnp.neural_networks, species).to(device)
print(nnp)

# nnp = torch.jit.script(nnp)
# nnp.save('nnp.pt')
# npp = torch.jit.load('nnp.pt').to(device)

energy = nnp((species, positions)).energies
positions.grad.zero_()
energy.backward()
grad = positions.grad.clone()

N = 15000
start = time.time()
for _ in range(N):
energy = nnp((species, positions)).energies
delta = time.time() - start
print(f'ANI-2x with BatchedNN (forward pass)')
print(f' Duration: {delta} s')
print(f' Speed: {delta/N*1000} ms/it')

N = 7500
start = time.time()
for _ in range(N):
energy = nnp((species, positions)).energies
positions.grad.zero_()
energy.backward()
delta = time.time() - start
print(f'ANI-2x with BatchedNN (forward & backward pass)')
print(f' Duration: {delta} s')
print(f' Speed: {delta/N*1000} ms/it')

# print(float(energy_ref), float(energy), float(energy_ref - energy))
# print(float(torch.max(torch.abs((grad - grad_ref)/grad_ref))))
35 changes: 0 additions & 35 deletions src/pytorch/README.md

This file was deleted.

Loading

0 comments on commit 44e4282

Please sign in to comment.