Skip to content

Commit

Permalink
Optimize SpeciesConverter of TorchANI (#39)
Browse files Browse the repository at this point in the history
* Implement TorchANISpeciesConverter

* Add test for TorchANISpeciesConverter

* Update README.md

* Remove the TorchANI import from the top level (#44)
  • Loading branch information
Raimondas Galvelis authored Dec 20, 2021
1 parent cacb4f8 commit 7041172
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@ foreach(TEST_PATH ${TEST_PATHS})
endforeach()

add_test(TestBatchedNN pytest -v ${CMAKE_SOURCE_DIR}/src/pytorch/TestBatchedNN.py)
add_test(TestSpeciesConverter pytest -v ${CMAKE_SOURCE_DIR}/src/pytorch/TestSpeciesConverter.py)
add_test(TestEnergyShifter pytest -v ${CMAKE_SOURCE_DIR}/src/pytorch/TestEnergyShifter.py)
add_test(TestSymmetryFunctions pytest -v ${CMAKE_SOURCE_DIR}/src/pytorch/TestSymmetryFunctions.py)

install(TARGETS ${LIBRARY} DESTINATION ${Python_SITEARCH}/${NAME})
install(FILES src/pytorch/__init__.py
src/pytorch/BatchedNN.py
src/pytorch/SpeciesConverter.py
src/pytorch/EnergyShifter.py
src/pytorch/SymmetryFunctions.py
DESTINATION ${Python_SITEARCH}/${NAME})
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ import mdtraj
import torch
import torchani

from NNPOps.SpeciesConverter import TorchANISpeciesConverter
from NNPOps.SymmetryFunctions import TorchANISymmetryFunctions
from NNPOps.BatchedNN import TorchANIBatchedNN
from NNPOps.EnergyShifter import TorchANIEnergyShifter
Expand All @@ -88,6 +89,7 @@ positions = torch.tensor(molecule.xyz * 10, dtype=torch.float32, requires_grad=T

# Construct ANI-2x and replace its operations with the optimized ones
nnp = torchani.models.ANI2x(periodic_table_index=True).to(device)
nnp.species_converter = TorchANISpeciesConverter(nnp.species_converter, species).to(device)
nnp.aev_computer = TorchANISymmetryFunctions(nnp.aev_computer).to(device)
nnp.neural_networks = TorchANIBatchedNN(nnp.species_converter, nnp.neural_networks, species).to(device)
nnp.energy_shifter = TorchANIEnergyShifter(nnp.species_converter, nnp.energy_shifter, species).to(device)
Expand Down
40 changes: 40 additions & 0 deletions src/pytorch/SpeciesConverter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#

import torch
from torch import Tensor
from typing import Optional, Tuple

class TorchANISpeciesConverter(torch.nn.Module):

from torchani.nn import SpeciesConverter

def __init__(self, converter: SpeciesConverter, atomicNumbers: Tensor) -> None:

super().__init__()

# Convert atomic numbers to a list of species
species = converter((atomicNumbers, torch.empty(0))).species
self.register_buffer('species', species)

self.conv_tensor = converter.conv_tensor # Just to make TorchScript happy :)

def forward(self, species_coordinates: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:

_, coordinates = species_coordinates

return self.species, coordinates
105 changes: 105 additions & 0 deletions src/pytorch/TestSpeciesConverter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#
# Copyright (c) 2020-2021 Acellera
# Authors: Raimondas Galvelis
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#

import mdtraj
import os
import pytest
import tempfile
import torch
import torchani

molecules = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'molecules')

def test_import():
import NNPOps
import NNPOps.SpeciesConverter

@pytest.mark.parametrize('deviceString', ['cpu', 'cuda'])
@pytest.mark.parametrize('molFile', ['1hvj', '1hvk', '2iuz', '3hkw', '3hky', '3lka', '3o99'])
def test_compare_with_native(deviceString, molFile):

if deviceString == 'cuda' and not torch.cuda.is_available():
pytest.skip('CUDA is not available')

from NNPOps.SpeciesConverter import TorchANISpeciesConverter

device = torch.device(deviceString)

mol = mdtraj.load(os.path.join(molecules, f'{molFile}_ligand.mol2'))
atomicNumbers = torch.tensor([[atom.element.atomic_number for atom in mol.top.atoms]], device=device)
atomicPositions = torch.tensor(mol.xyz * 10, dtype=torch.float32, requires_grad=True, device=device)

nnp = torchani.models.ANI2x(periodic_table_index=True).to(device)
energy_ref = nnp((atomicNumbers, atomicPositions)).energies
energy_ref.backward()
grad_ref = atomicPositions.grad.clone()

nnp.species_converter = TorchANISpeciesConverter(nnp.species_converter, atomicNumbers).to(device)
energy = nnp((atomicNumbers, atomicPositions)).energies
atomicPositions.grad.zero_()
energy.backward()
grad = atomicPositions.grad.clone()

energy_error = torch.abs((energy - energy_ref)/energy_ref)
grad_error = torch.max(torch.abs((grad - grad_ref)/grad_ref))

assert energy_error < 5e-7
assert grad_error < 5e-3

@pytest.mark.parametrize('deviceString', ['cpu', 'cuda'])
@pytest.mark.parametrize('molFile', ['1hvj', '1hvk', '2iuz', '3hkw', '3hky', '3lka', '3o99'])
def test_model_serialization(deviceString, molFile):

if deviceString == 'cuda' and not torch.cuda.is_available():
pytest.skip('CUDA is not available')

from NNPOps.SpeciesConverter import TorchANISpeciesConverter

device = torch.device(deviceString)

mol = mdtraj.load(os.path.join(molecules, f'{molFile}_ligand.mol2'))
atomicNumbers = torch.tensor([[atom.element.atomic_number for atom in mol.top.atoms]], device=device)
atomicPositions = torch.tensor(mol.xyz * 10, dtype=torch.float32, requires_grad=True, device=device)

nnp_ref = torchani.models.ANI2x(periodic_table_index=True).to(device)
nnp_ref.species_converter = TorchANISpeciesConverter(nnp_ref.species_converter, atomicNumbers).to(device)

energy_ref = nnp_ref((atomicNumbers, atomicPositions)).energies
energy_ref.backward()
grad_ref = atomicPositions.grad.clone()

with tempfile.NamedTemporaryFile() as fd:

torch.jit.script(nnp_ref).save(fd.name)
nnp = torch.jit.load(fd.name)

energy = nnp((atomicNumbers, atomicPositions)).energies
atomicPositions.grad.zero_()
energy.backward()
grad = atomicPositions.grad.clone()

energy_error = torch.abs((energy - energy_ref)/energy_ref)
grad_error = torch.max(torch.abs((grad - grad_ref)/grad_ref))

assert energy_error < 5e-7
assert grad_error < 5e-3

0 comments on commit 7041172

Please sign in to comment.