Skip to content

Commit

Permalink
Fix PBCs (#83)
Browse files Browse the repository at this point in the history
periofic boundary flag correctly set in SymmetryFunctions
added tests with a water box with PBCs
  • Loading branch information
sef43 authored Feb 28, 2023
1 parent 3c96f5b commit 16543f9
Show file tree
Hide file tree
Showing 4 changed files with 388 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/pytorch/SymmetryFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,15 @@ class Holder : public torch::CustomClassHolder {
for (const float thetas: ShfZ)
angularFunctions.push_back({eta, rs, zeta, thetas});

bool periodic = cellPtr != nullptr;

if (device.is_cpu()) {
impl = std::make_shared<CpuANISymmetryFunctions>(numAtoms, numSpecies, Rcr, Rca, false, atomSpecies_, radialFunctions, angularFunctions, true);
impl = std::make_shared<CpuANISymmetryFunctions>(numAtoms, numSpecies, Rcr, Rca, periodic, atomSpecies_, radialFunctions, angularFunctions, true);
#ifdef ENABLE_CUDA
} else if (device.is_cuda()) {
// PyTorch allow to chose GPU with "torch.device", but it doesn't set as the default one.
CHECK_CUDA_RESULT(cudaSetDevice(device.index()));
impl = std::make_shared<CudaANISymmetryFunctions>(numAtoms, numSpecies, Rcr, Rca, false, atomSpecies_, radialFunctions, angularFunctions, true);
impl = std::make_shared<CudaANISymmetryFunctions>(numAtoms, numSpecies, Rcr, Rca, periodic, atomSpecies_, radialFunctions, angularFunctions, true);
#endif
} else
throw std::runtime_error("Unsupported device: " + device.str());
Expand Down
34 changes: 34 additions & 0 deletions src/pytorch/TestOptimizedTorchANI.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,40 @@ def test_compare_with_native(deviceString, molFile):
else:
assert grad_error < 5e-3

@pytest.mark.parametrize('deviceString', ['cpu', 'cuda'])
def test_compare_waterbox_pbc_with_native(deviceString):

if deviceString == 'cuda' and not torch.cuda.is_available():
pytest.skip('CUDA is not available')

from NNPOps import OptimizedTorchANI

device = torch.device(deviceString)

mol = mdtraj.load(os.path.join(molecules, 'water.pdb'))
atomicNumbers = torch.tensor([[atom.element.atomic_number for atom in mol.top.atoms]], device=device)
atomicPositions = torch.tensor(mol.xyz * 10, dtype=torch.float32, requires_grad=True, device=device)
cell = mol.unitcell_vectors[0]
cell = torch.tensor(cell, dtype=torch.float32, device=device)*10.0
pbc = torch.tensor([True, True, True], dtype=torch.bool, device=device)

nnp = torchani.models.ANI2x(periodic_table_index=True).to(device)
energy_ref = nnp((atomicNumbers, atomicPositions), cell=cell, pbc=pbc).energies
energy_ref.backward()
grad_ref = atomicPositions.grad.clone()

nnp = OptimizedTorchANI(nnp, atomicNumbers).to(device)
energy = nnp((atomicNumbers, atomicPositions), cell=cell, pbc=pbc).energies
atomicPositions.grad.zero_()
energy.backward()
grad = atomicPositions.grad.clone()

energy_error = torch.abs((energy - energy_ref)/energy_ref)
grad_error = torch.max(torch.abs((grad - grad_ref)/grad_ref))

assert energy_error < 5e-7
assert grad_error < 7e-3

@pytest.mark.parametrize('deviceString', ['cpu', 'cuda'])
@pytest.mark.parametrize('molFile', ['1hvj', '1hvk', '2iuz', '3hkw', '3hky', '3lka', '3o99'])
def test_model_serialization(deviceString, molFile):
Expand Down
38 changes: 38 additions & 0 deletions src/pytorch/TestSymmetryFunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,44 @@ def test_compare_with_native(deviceString, molFile):
else:
assert grad_error < 5e-3


@pytest.mark.parametrize('deviceString', ['cpu', 'cuda'])
def test_compare_waterbox_pbc_with_native(deviceString):

if deviceString == 'cuda' and not torch.cuda.is_available():
pytest.skip('CUDA is not available')

from NNPOps.SymmetryFunctions import TorchANISymmetryFunctions

device = torch.device(deviceString)

mol = mdtraj.load(os.path.join(molecules, 'water.pdb'))
atomicNumbers = torch.tensor([[atom.element.atomic_number for atom in mol.top.atoms]], device=device)
atomicPositions = torch.tensor(mol.xyz * 10, dtype=torch.float32, requires_grad=True, device=device)
cell = mol.unitcell_vectors[0]
cell = torch.tensor(cell, dtype=torch.float32, device=device)*10.0
pbc = torch.tensor([True, True, True], dtype=torch.bool, device=device)

nnp = torchani.models.ANI2x(periodic_table_index=True).to(device)
energy_ref = nnp((atomicNumbers, atomicPositions), cell=cell, pbc=pbc).energies
energy_ref.backward()
grad_ref = atomicPositions.grad.clone()

nnp.aev_computer = TorchANISymmetryFunctions(nnp.species_converter, nnp.aev_computer, atomicNumbers)
energy = nnp((atomicNumbers, atomicPositions), cell=cell, pbc=pbc).energies
atomicPositions.grad.zero_()
energy.backward()
grad = atomicPositions.grad.clone()

energy_error = torch.abs((energy - energy_ref)/energy_ref)
grad_error = torch.max(torch.abs((grad - grad_ref)/grad_ref))

assert energy_error < 5e-7
assert grad_error < 7e-3




@pytest.mark.parametrize('deviceString', ['cpu', 'cuda'])
@pytest.mark.parametrize('molFile', ['1hvj', '1hvk', '2iuz', '3hkw', '3hky', '3lka', '3o99'])
def test_model_serialization(deviceString, molFile):
Expand Down
Loading

0 comments on commit 16543f9

Please sign in to comment.