-
Notifications
You must be signed in to change notification settings - Fork 75
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #134 from peastman/zbl
Implement ZBL potential
- Loading branch information
Showing
8 changed files
with
254 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
activation: silu | ||
aggr: add | ||
atom_filter: -1 | ||
attn_activation: silu | ||
batch_size: 128 | ||
coord_files: null | ||
cutoff_lower: 0.0 | ||
cutoff_upper: 5.0 | ||
derivative: false | ||
distance_influence: both | ||
early_stopping_patience: 150 | ||
ema_alpha_neg_dy: 1.0 | ||
ema_alpha_y: 1.0 | ||
embed_files: null | ||
embedding_dimension: 256 | ||
energy_files: null | ||
y_weight: 1.0 | ||
force_files: null | ||
neg_dy_weight: 1.0 | ||
inference_batch_size: 128 | ||
load_model: null | ||
lr: 0.0004 | ||
lr_factor: 0.8 | ||
lr_min: 1.0e-07 | ||
lr_patience: 15 | ||
lr_warmup_steps: 10000 | ||
max_num_neighbors: 64 | ||
max_z: 100 | ||
model: equivariant-transformer | ||
neighbor_embedding: true | ||
ngpus: -1 | ||
num_epochs: 3000 | ||
num_heads: 8 | ||
num_layers: 8 | ||
num_nodes: 1 | ||
num_rbf: 64 | ||
num_workers: 6 | ||
output_model: Scalar | ||
precision: 32 | ||
prior_model: | ||
- ZBL: | ||
cutoff_distance: 4.0 | ||
max_num_neighbors: 50 | ||
- Atomref | ||
rbf_type: expnorm | ||
redirect: false | ||
reduce_op: add | ||
save_interval: 10 | ||
splits: null | ||
standardize: false | ||
test_interval: 10 | ||
test_size: null | ||
train_size: 110000 | ||
trainable_rbf: false | ||
val_size: 10000 | ||
weight_decay: 0.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,4 @@ | ||
from torchmdnet.priors.atomref import Atomref | ||
from torchmdnet.priors.zbl import ZBL | ||
|
||
__all__ = ['Atomref', 'ZBL'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import torch | ||
from torchmdnet.priors.base import BasePrior | ||
from torchmdnet.models.utils import Distance, CosineCutoff | ||
|
||
class ZBL(BasePrior): | ||
"""This class implements the Ziegler-Biersack-Littmark (ZBL) potential for screened nuclear repulsion. | ||
Is is described in https://doi.org/10.1007/978-3-642-68779-2_5 (equations 9 and 10 on page 147). It | ||
is an empirical potential that does a good job of describing the repulsion between atoms at very short | ||
distances. | ||
To use this prior, the Dataset must provide the following attributes. | ||
atomic_number: 1D tensor of length max_z. atomic_number[z] is the atomic number of atoms with atom type z. | ||
distance_scale: multiply by this factor to convert coordinates stored in the dataset to meters | ||
energy_scale: multiply by this factor to convert energies stored in the dataset to Joules (*not* J/mol) | ||
""" | ||
def __init__(self, cutoff_distance, max_num_neighbors, atomic_number=None, distance_scale=None, energy_scale=None, dataset=None): | ||
super(ZBL, self).__init__() | ||
if atomic_number is None: | ||
atomic_number = dataset.atomic_number | ||
if distance_scale is None: | ||
distance_scale = dataset.distance_scale | ||
if energy_scale is None: | ||
energy_scale = dataset.energy_scale | ||
atomic_number = torch.as_tensor(atomic_number, dtype=torch.int8) | ||
self.register_buffer("atomic_number", atomic_number) | ||
self.distance = Distance(0, cutoff_distance, max_num_neighbors=max_num_neighbors) | ||
self.cutoff = CosineCutoff(cutoff_upper=cutoff_distance) | ||
self.cutoff_distance = cutoff_distance | ||
self.max_num_neighbors = max_num_neighbors | ||
self.distance_scale = distance_scale | ||
self.energy_scale = energy_scale | ||
|
||
def get_init_args(self): | ||
return {'cutoff_distance': self.cutoff_distance, | ||
'max_num_neighbors': self.max_num_neighbors, | ||
'atomic_number': self.atomic_number, | ||
'distance_scale': self.distance_scale, | ||
'energy_scale': self.energy_scale} | ||
|
||
def reset_parameters(self): | ||
pass | ||
|
||
def post_reduce(self, y, z, pos, batch): | ||
edge_index, distance, _ = self.distance(pos, batch) | ||
atomic_number = self.atomic_number[z[edge_index]] | ||
# 5.29e-11 is the Bohr radius in meters. All other numbers are magic constants from the ZBL potential. | ||
a = 0.8854*5.29177210903e-11/(atomic_number[0]**0.23 + atomic_number[1]**0.23) | ||
d = distance*self.distance_scale/a | ||
f = 0.1818*torch.exp(-3.2*d) + 0.5099*torch.exp(-0.9423*d) + 0.2802*torch.exp(-0.4029*d) + 0.02817*torch.exp(-0.2016*d) | ||
f *= self.cutoff(distance) | ||
# Compute the energy, converting to the dataset's units. Multiply by 0.5 because every atom pair | ||
# appears twice. | ||
return y + 0.5*(2.30707755e-28/self.energy_scale/self.distance_scale)*torch.sum(f*atomic_number[0]*atomic_number[1]/distance, dim=-1) |
Oops, something went wrong.