Merge pull request #134 from peastman/zbl

Implement ZBL potential

peastman authored Nov 9, 2022
2 parents da617ea + 9066935 commit 9e0fef3
Showing 8 changed files with 254 additions and 41 deletions.
56 changes: 56 additions & 0 deletions tests/priors.yaml
@@ -0,0 +1,56 @@
activation: silu
aggr: add
atom_filter: -1
attn_activation: silu
batch_size: 128
coord_files: null
cutoff_lower: 0.0
cutoff_upper: 5.0
derivative: false
distance_influence: both
early_stopping_patience: 150
ema_alpha_neg_dy: 1.0
ema_alpha_y: 1.0
embed_files: null
embedding_dimension: 256
energy_files: null
y_weight: 1.0
force_files: null
neg_dy_weight: 1.0
inference_batch_size: 128
load_model: null
lr: 0.0004
lr_factor: 0.8
lr_min: 1.0e-07
lr_patience: 15
lr_warmup_steps: 10000
max_num_neighbors: 64
max_z: 100
model: equivariant-transformer
neighbor_embedding: true
ngpus: -1
num_epochs: 3000
num_heads: 8
num_layers: 8
num_nodes: 1
num_rbf: 64
num_workers: 6
output_model: Scalar
precision: 32
prior_model:
  - ZBL:
      cutoff_distance: 4.0
      max_num_neighbors: 50
  - Atomref
rbf_type: expnorm
redirect: false
reduce_op: add
save_interval: 10
splits: null
standardize: false
test_interval: 10
test_size: null
train_size: 110000
trainable_rbf: false
val_size: 10000
weight_decay: 0.0
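For reference, a minimal sketch of what the prior_model section above resolves to, assuming the tests' DummyDataset (which supplies the atomic_number, distance_scale, energy_scale, and atomref attributes the priors read from the dataset):

from torchmdnet.priors import Atomref, ZBL
from utils import DummyDataset  # tests/utils.py

dataset = DummyDataset(has_atomref=True)
priors = [
    ZBL(cutoff_distance=4.0, max_num_neighbors=50, dataset=dataset),
    Atomref(dataset=dataset),
]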
67 changes: 65 additions & 2 deletions tests/test_priors.py
@@ -3,10 +3,13 @@
import torch
import pytorch_lightning as pl
from torchmdnet import models
-from torchmdnet.models.model import create_model
-from torchmdnet.priors import Atomref
+from torchmdnet.models.model import create_model, create_prior_models
+from torchmdnet.module import LNNP
+from torchmdnet.priors import Atomref, ZBL
from torch_scatter import scatter
from utils import load_example_args, create_example_batch, DummyDataset
+from os.path import dirname, join
+import tempfile


@mark.parametrize("model_name", models.__all__)
@@ -31,3 +34,63 @@ def test_atomref(model_name):
    # check if the output of both models differs by the expected atomref contribution
    expected_offset = scatter(dataset.get_atomref().squeeze()[z], batch).unsqueeze(1)
    torch.testing.assert_allclose(x_atomref, x_no_atomref + expected_offset)

def test_zbl():
    pos = torch.tensor([[1.0, 0.0, 0.0], [2.5, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, -1.0]], dtype=torch.float32)  # Atom positions in Bohr
    types = torch.tensor([0, 1, 2, 1], dtype=torch.long)  # Atom types
    atomic_number = torch.tensor([1, 6, 8], dtype=torch.int8)  # Mapping of atom types to atomic numbers
    distance_scale = 5.29177210903e-11  # Convert Bohr to meters
    energy_scale = 1000.0/6.02214076e23  # Convert kJ/mol to Joules

    # Use the ZBL class to compute the energy.

    zbl = ZBL(10.0, 5, atomic_number, distance_scale=distance_scale, energy_scale=energy_scale)
    energy = zbl.post_reduce(torch.zeros((1,)), types, pos, torch.zeros_like(types))[0]

    # Compare to the expected value.

    def compute_interaction(pos1, pos2, z1, z2):
        delta = pos1-pos2
        r = torch.sqrt(torch.dot(delta, delta))
        x = r / (0.8854/(z1**0.23 + z2**0.23))
        phi = 0.1818*torch.exp(-3.2*x) + 0.5099*torch.exp(-0.9423*x) + 0.2802*torch.exp(-0.4029*x) + 0.02817*torch.exp(-0.2016*x)
        cutoff = 0.5*(torch.cos(r*torch.pi/10.0) + 1.0)
        return cutoff*phi*(138.935/5.29177210903e-2)*z1*z2/r

    expected = 0
    for i in range(len(pos)):
        for j in range(i):
            expected += compute_interaction(pos[i], pos[j], atomic_number[types[i]], atomic_number[types[j]])
    torch.testing.assert_allclose(expected, energy)
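For reference, compute_interaction above evaluates the ZBL pair energy from equations 9 and 10 of the paper cited in torchmdnet/priors/zbl.py, multiplied by the same cosine cutoff the prior applies. The factor 138.935 is Coulomb's constant in kJ mol^-1 nm e^-2, and dividing by 5.29177210903e-2 nm/Bohr converts it to the test's Bohr coordinates:

E_{ij}(r) = \frac{1}{2}\left(\cos\frac{\pi r}{r_c} + 1\right)\,\phi\!\left(\frac{r}{a}\right)\,\frac{Z_i Z_j e^2}{4\pi\varepsilon_0 r},
\qquad a = \frac{0.8854\,a_0}{Z_i^{0.23} + Z_j^{0.23}},

\phi(x) = 0.1818\,e^{-3.2x} + 0.5099\,e^{-0.9423x} + 0.2802\,e^{-0.4029x} + 0.02817\,e^{-0.2016x}.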

def test_multiple_priors():
    # Create a model from a config file.

    dataset = DummyDataset(has_atomref=True)
    config_file = join(dirname(__file__), 'priors.yaml')
    args = load_example_args('equivariant-transformer', config_file=config_file)
    prior_models = create_prior_models(args, dataset)
    args['prior_args'] = [p.get_init_args() for p in prior_models]
    model = LNNP(args, prior_model=prior_models)
    priors = model.model.prior_model

    # Make sure the priors were created correctly.

    assert len(priors) == 2
    assert isinstance(priors[0], ZBL)
    assert isinstance(priors[1], Atomref)
    assert priors[0].cutoff_distance == 4.0
    assert priors[0].max_num_neighbors == 50

    # Save and load a checkpoint, and make sure the priors are correct.

    with tempfile.NamedTemporaryFile() as f:
        torch.save(model, f)
        f.seek(0)
        model2 = torch.load(f)
        priors2 = model2.model.prior_model
        assert len(priors2) == 2
        assert isinstance(priors2[0], ZBL)
        assert isinstance(priors2[1], Atomref)
        assert priors2[0].cutoff_distance == priors[0].cutoff_distance
        assert priors2[0].max_num_neighbors == priors[0].max_num_neighbors
9 changes: 7 additions & 2 deletions tests/utils.py
@@ -4,8 +4,10 @@
from torch_geometric.data import Dataset, Data


-def load_example_args(model_name, remove_prior=False, **kwargs):
-    with open(join(dirname(dirname(__file__)), "examples", "ET-QM9.yaml"), "r") as f:
+def load_example_args(model_name, remove_prior=False, config_file=None, **kwargs):
+    if config_file is None:
+        config_file = join(dirname(dirname(__file__)), "examples", "ET-QM9.yaml")
+    with open(config_file, "r") as f:
        args = yaml.load(f, Loader=yaml.FullLoader)
    args["model"] = model_name
    args["seed"] = 1234
@@ -69,6 +71,9 @@ def _get_atomref(self):
                return self.atomref

            DummyDataset.get_atomref = _get_atomref
+        self.atomic_number = torch.arange(max(atom_types)+1)
+        self.distance_scale = 1.0
+        self.energy_scale = 1.0

    def get(self, idx):
        features = dict(z=self.z[idx].clone(), pos=self.pos[idx].clone())
33 changes: 20 additions & 13 deletions torchmdnet/datasets/hdf.py
@@ -1,6 +1,7 @@
import torch
from torch_geometric.data import Dataset, Data
import h5py
+import numpy as np


class HDF5(Dataset):
@@ -27,7 +28,12 @@ def __init__(self, filename, **kwargs):
        files = [h5py.File(f, "r") for f in self.filename.split(";")]
        for file in files:
            for group_name in file:
-                self.num_molecules += len(file[group_name]["energy"])
+                if group_name == '_metadata':
+                    group = file[group_name]
+                    for name in group:
+                        setattr(self, name, torch.tensor(np.array(group[name])))
+                else:
+                    self.num_molecules += len(file[group_name]["energy"])
            file.close()

    def setup_index(self):
@@ -36,18 +42,19 @@ def setup_index(self):
        self.index = []
        for file in files:
            for group_name in file:
-                group = file[group_name]
-                types = group["types"]
-                pos = group["pos"]
-                energy = group["energy"]
-                if "forces" in group:
-                    self.has_forces = True
-                    forces = group["forces"]
-                    for i in range(len(energy)):
-                        self.index.append((types, pos, energy, forces, i))
-                else:
-                    for i in range(len(energy)):
-                        self.index.append((types, pos, energy, i))
+                if group_name != '_metadata':
+                    group = file[group_name]
+                    types = group["types"]
+                    pos = group["pos"]
+                    energy = group["energy"]
+                    if "forces" in group:
+                        self.has_forces = True
+                        forces = group["forces"]
+                        for i in range(len(energy)):
+                            self.index.append((types, pos, energy, forces, i))
+                    else:
+                        for i in range(len(energy)):
+                            self.index.append((types, pos, energy, i))

        assert self.num_molecules == len(self.index), (
            "Mismatch between previously calculated "
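To illustrate the new `_metadata` group, here is a minimal sketch of writing an HDF5 file this reader accepts. The file name, group name, shapes, and values are hypothetical; the three metadata names match the attributes the ZBL prior expects the dataset to expose:

import h5py
import numpy as np

with h5py.File("example.h5", "w") as f:
    meta = f.create_group("_metadata")
    meta["atomic_number"] = np.array([1, 6, 8], dtype=np.int8)   # atom type -> atomic number
    meta["distance_scale"] = 1e-10                 # coordinates stored in Angstroms
    meta["energy_scale"] = 1000.0 / 6.02214076e23  # energies stored in kJ/mol
    mol = f.create_group("molecule_0")
    mol["types"] = np.zeros((5, 4), dtype=np.int64)      # (n_conformations, n_atoms)
    mol["pos"] = np.zeros((5, 4, 3), dtype=np.float32)   # positions per conformation
    mol["energy"] = np.zeros((5, 1), dtype=np.float64)   # per-conformation energies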
59 changes: 45 additions & 14 deletions torchmdnet/models/model.py
@@ -65,16 +65,8 @@ def create_model(args, prior_model=None, mean=None, std=None):

    # prior model
    if args["prior_model"] and prior_model is None:
-        assert "prior_args" in args, (
-            f"Requested prior model {args['prior_model']} but the "
-            f'arguments are lacking the key "prior_args".'
-        )
-        assert hasattr(priors, args["prior_model"]), (
-            f'Unknown prior model {args["prior_model"]}. '
-            f'Available models are {", ".join(priors.__all__)}'
-        )
        # instantiate prior model if it was not passed to create_model (i.e. when loading a model)
-        prior_model = getattr(priors, args["prior_model"])(**args["prior_args"])
+        prior_model = create_prior_models(args)

    # create output network
    output_prefix = "Equivariant" if is_equivariant else ""
@@ -113,6 +105,40 @@ def load_model(filepath, args=None, device="cpu", **kwargs):
    return model.to(device)


def create_prior_models(args, dataset=None):
    """Parse the prior_model configuration option and create the prior models."""
    prior_models = []
    if args['prior_model']:
        prior_model = args['prior_model']
        prior_names = []
        prior_args = []
        if not isinstance(prior_model, list):
            prior_model = [prior_model]
        for prior in prior_model:
            if isinstance(prior, dict):
                for key, value in prior.items():
                    prior_names.append(key)
                    if value is None:
                        prior_args.append({})
                    else:
                        prior_args.append(value)
            else:
                prior_names.append(prior)
                prior_args.append({})
        if 'prior_args' in args:
            prior_args = args['prior_args']
            if not isinstance(prior_args, list):
                prior_args = [prior_args]
        for name, arg in zip(prior_names, prior_args):
            assert hasattr(priors, name), (
                f"Unknown prior model {name}. "
                f"Available models are {', '.join(priors.__all__)}"
            )
            # initialize the prior model
            prior_models.append(getattr(priors, name)(dataset=dataset, **arg))
    return prior_models


class TorchMD_Net(nn.Module):
    def __init__(
        self,
@@ -127,15 +153,17 @@ def __init__(
        self.representation_model = representation_model
        self.output_model = output_model

-        self.prior_model = prior_model
        if not output_model.allow_prior_model and prior_model is not None:
-            self.prior_model = None
+            prior_model = None
            rank_zero_warn(
                (
                    "Prior model was given but the output model does "
                    "not allow prior models. Dropping the prior model."
                )
            )
+        if isinstance(prior_model, priors.base.BasePrior):
+            prior_model = [prior_model]
+        self.prior_model = None if prior_model is None else torch.nn.ModuleList(prior_model)

        self.derivative = derivative

@@ -150,7 +178,8 @@ def reset_parameters(self):
        self.representation_model.reset_parameters()
        self.output_model.reset_parameters()
        if self.prior_model is not None:
-            self.prior_model.reset_parameters()
+            for prior in self.prior_model:
+                prior.reset_parameters()

    def forward(
        self,
@@ -179,7 +208,8 @@ def forward(

        # apply atom-wise prior model
        if self.prior_model is not None:
-            x = self.prior_model.pre_reduce(x, z, pos, batch)
+            for prior in self.prior_model:
+                x = prior.pre_reduce(x, z, pos, batch)

        # aggregate atoms
        x = self.output_model.reduce(x, batch)
@@ -193,7 +223,8 @@ def forward(

        # apply molecular-wise prior model
        if self.prior_model is not None:
-            y = self.prior_model.post_reduce(y, z, pos, batch)
+            for prior in self.prior_model:
+                y = prior.post_reduce(y, z, pos, batch)

        # compute gradients with respect to coordinates
        if self.derivative:
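The prior_model option parsed by create_prior_models above accepts either a bare prior name or a list mixing names and name-to-arguments mappings, so both of these sketches are valid; dataset here is assumed to be something like the tests' DummyDataset:

args = {"prior_model": "Atomref"}  # single prior, no explicit arguments
args = {"prior_model": [           # multiple priors, as in tests/priors.yaml
    {"ZBL": {"cutoff_distance": 4.0, "max_num_neighbors": 50}},
    "Atomref",
]}
prior_models = create_prior_models(args, dataset)

When a checkpoint is being loaded, a prior_args list already present in args overrides the arguments parsed from prior_model.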
3 changes: 3 additions & 0 deletions torchmdnet/priors/__init__.py
@@ -1 +1,4 @@
from torchmdnet.priors.atomref import Atomref
from torchmdnet.priors.zbl import ZBL

__all__ = ['Atomref', 'ZBL']
54 changes: 54 additions & 0 deletions torchmdnet/priors/zbl.py
@@ -0,0 +1,54 @@
import torch
from torchmdnet.priors.base import BasePrior
from torchmdnet.models.utils import Distance, CosineCutoff

class ZBL(BasePrior):
    """This class implements the Ziegler-Biersack-Littmark (ZBL) potential for screened nuclear repulsion.
    It is described in https://doi.org/10.1007/978-3-642-68779-2_5 (equations 9 and 10 on page 147). It
    is an empirical potential that does a good job of describing the repulsion between atoms at very short
    distances.
    To use this prior, the Dataset must provide the following attributes.
    atomic_number: 1D tensor of length max_z.  atomic_number[z] is the atomic number of atoms with atom type z.
    distance_scale: multiply by this factor to convert coordinates stored in the dataset to meters
    energy_scale: multiply by this factor to convert energies stored in the dataset to Joules (*not* J/mol)
    """
    def __init__(self, cutoff_distance, max_num_neighbors, atomic_number=None, distance_scale=None, energy_scale=None, dataset=None):
        super(ZBL, self).__init__()
        if atomic_number is None:
            atomic_number = dataset.atomic_number
        if distance_scale is None:
            distance_scale = dataset.distance_scale
        if energy_scale is None:
            energy_scale = dataset.energy_scale
        atomic_number = torch.as_tensor(atomic_number, dtype=torch.int8)
        self.register_buffer("atomic_number", atomic_number)
        self.distance = Distance(0, cutoff_distance, max_num_neighbors=max_num_neighbors)
        self.cutoff = CosineCutoff(cutoff_upper=cutoff_distance)
        self.cutoff_distance = cutoff_distance
        self.max_num_neighbors = max_num_neighbors
        self.distance_scale = distance_scale
        self.energy_scale = energy_scale

    def get_init_args(self):
        return {'cutoff_distance': self.cutoff_distance,
                'max_num_neighbors': self.max_num_neighbors,
                'atomic_number': self.atomic_number,
                'distance_scale': self.distance_scale,
                'energy_scale': self.energy_scale}

    def reset_parameters(self):
        pass

    def post_reduce(self, y, z, pos, batch):
        edge_index, distance, _ = self.distance(pos, batch)
        atomic_number = self.atomic_number[z[edge_index]]
        # 5.29e-11 is the Bohr radius in meters.  All other numbers are magic constants from the ZBL potential.
        a = 0.8854*5.29177210903e-11/(atomic_number[0]**0.23 + atomic_number[1]**0.23)
        d = distance*self.distance_scale/a
        f = 0.1818*torch.exp(-3.2*d) + 0.5099*torch.exp(-0.9423*d) + 0.2802*torch.exp(-0.4029*d) + 0.02817*torch.exp(-0.2016*d)
        f *= self.cutoff(distance)
        # Compute the energy, converting to the dataset's units.  Multiply by 0.5 because every atom pair
        # appears twice.
        return y + 0.5*(2.30707755e-28/self.energy_scale/self.distance_scale)*torch.sum(f*atomic_number[0]*atomic_number[1]/distance, dim=-1)
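For reference, the constant 2.30707755e-28 in the return line is e^2/4πε0 expressed in J·m, the Coulomb prefactor that post_reduce then converts to the dataset's units through energy_scale and distance_scale:

\frac{e^2}{4\pi\varepsilon_0} = \left(1.602176634\times10^{-19}\,\mathrm{C}\right)^2 \times 8.9875517874\times10^{9}\,\mathrm{N\,m^2\,C^{-2}} \approx 2.30707755\times10^{-28}\,\mathrm{J\,m}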