# ase_calculator.py
import numpy as np
import logging
# GemNet imports
from gemnet.model.gemnet import GemNet
from gemnet.training.data_container import DataContainer
# ASE imports
from ase.md import MDLogger
from ase.md.velocitydistribution import (
MaxwellBoltzmannDistribution,
Stationary,
)
from ase.md.verlet import VelocityVerlet
from ase.md.langevin import Langevin
from ase import units, Atoms
from ase.io.trajectory import Trajectory
from ase.calculators.calculator import Calculator, all_changes
from ase.calculators.emt import EMT
from ase.calculators.lj import LennardJones
class Molecule(DataContainer):
    """
    DataContainer specialisation that holds exactly one molecule.

    The parent initialiser is not called; every attribute used here is set
    up by hand in ``__init__`` instead.
    """

    def __init__(self, R, Z, cutoff, int_cutoff, triplets_only=False, bias_id=False):
        """
        Parameters
        ----------
            R:
                Atom positions; must expose ``.shape`` equal to ``(len(Z), 3)``.
            Z:
                Atomic numbers, one entry per atom.
            cutoff:
                Stored as ``self.cutoff``.
            int_cutoff:
                Stored as ``self.int_cutoff``.
            triplets_only: bool
                If True only the triplet index keys are registered, otherwise
                the quadruplet index keys are registered as well.
            bias_id: bool
                NOTE(review): accepted but currently without effect, the
                attribute always ends up as False (see below).
        """
        n_atoms = len(Z)

        # index names that are always required
        triplet_keys = [
            "batch_seg",
            "id_undir",
            "id_swap",
            "id_c",
            "id_a",
            "id3_expand_ba",
            "id3_reduce_ca",
            "Kidx3",
        ]
        # index names that are only required when quadruplets are used
        quadruplet_keys = [
            "id4_int_b",
            "id4_int_a",
            "id4_reduce_ca",
            "id4_expand_db",
            "id4_reduce_cab",
            "id4_expand_abd",
            "Kidx4",
            "id4_reduce_intm_ca",
            "id4_expand_intm_db",
            "id4_reduce_intm_ab",
            "id4_expand_intm_ab",
        ]
        if triplets_only:
            self.index_keys = triplet_keys
        else:
            self.index_keys = triplet_keys + quadruplet_keys

        self.triplets_only = triplets_only
        self.cutoff = cutoff
        self.int_cutoff = int_cutoff
        self.keys = ["N", "Z", "R", "F", "E"]

        assert R.shape == (n_atoms, 3)

        self.R = R
        self.Z = Z
        self.N = np.array([n_atoms], dtype=np.int32)
        # energy and forces are placeholders filled with zeros
        self.E = np.zeros((1, 1), dtype=np.float32)
        self.F = np.zeros((n_atoms, 3), dtype=np.float32)

        # NOTE(review): the original code first stored the `bias_id` argument
        # and then overwrote it with False, so the argument is effectively
        # ignored. Confirm whether that is intended.
        self.bias_id = False
        self.delta = False
        self.N_cumsum = np.concatenate([[0], np.cumsum(self.N)])
        self.addID = False

        # merge both dtype dictionaries into a single one
        self.dtypes, extra_dtypes = self.get_dtypes()
        self.dtypes.update(extra_dtypes)

        self.device = "cpu"

    def get(self):
        """
        Return the molecule (item 0 of this container) without the "E" and "F"
        entries, with every remaining entry moved to ``self.device``.
        """
        sample = self[0]

        # energy and forces are unknown, they are what the model should predict
        sample.pop("E")
        sample.pop("F")

        # move everything to the selected device
        for name, value in sample.items():
            sample[name] = value.to(self.device)
        return sample

    def update(self, R):
        """
        Replace the stored atom positions.
        Graph representation of the molecule might change if the atom positions are updated.

        Parameters
        ----------
            R: array with the same shape as the currently stored positions
                New positions of the atoms (GNNCalculator passes ``atoms.positions``).
        """
        assert R.shape == self.R.shape
        self.R = R

    def to(self, device):
        """
        Select the device onto which .get() moves the returned entries.
        """
        self.device = device
class GNNCalculator(Calculator):
    """
    A custom ase calculator that computes energy and forces acting on atoms of a molecule using GNNs,
    e.g. GemNet.

    Parameters
    ----------
        molecule
            Captures data of all atoms. Contains indices etc. Must provide
            ``update(R=...)`` and ``get()``.
        model
            The trained GemNet model. Must provide ``predict(inputs)`` returning
            ``(energy, forces)``.
        atoms: ase.Atoms
            ASE atoms instance.
        restart: str
            Prefix for restart file. May contain a directory. Default is None: don't restart.
        add_atom_energies: bool
            If True, the per-element reference energies in ``atom_energies``
            are summed over all atoms and added to the predicted energy.
        label: str
            Name used for all files.
    """

    implemented_properties = ["energy", "forces"]

    def __init__(
        self,
        molecule,
        model,
        atoms=None,
        restart=None,
        add_atom_energies=False,
        label="gemnet_calc",  # ase settings
        **kwargs,
    ):
        super().__init__(restart=restart, label=label, atoms=atoms, **kwargs)
        self.molecule = molecule
        self.model = model

        # atom energies: EPBE0_atom (in eV) from QM7-X, keyed by atomic number
        self.add_atom_energies = add_atom_energies
        self.atom_energies = {
            1: -13.641404161,
            6: -1027.592489146,
            7: -1484.274819088,
            8: -2039.734879322,
            16: -10828.707468187,
            17: -12516.444619523,
        }

    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        """
        Predict energy and forces for ``atoms`` and store them in ``self.results``.

        Parameters
        ----------
            atoms: ase.Atoms
                Structure to evaluate. If None, the structure stored on the
                calculator (``self.atoms``) is used.
            properties: list of str
                Requested properties. Defaults to ["energy", "forces"]; both
                are always computed regardless of this list.
            system_changes: list of str
                Forwarded unchanged to the base class.

        Raises
        ------
            ValueError
                If no structure is passed and none is stored on the calculator.
            KeyError
                If ``add_atom_energies`` is set and an element has no entry in
                ``atom_energies``.
        """
        # avoid a mutable default argument
        if properties is None:
            properties = ["energy", "forces"]
        super().calculate(atoms, properties, system_changes)

        # the signature allows atoms=None, so fall back to the stored structure
        if atoms is None:
            atoms = self.atoms
        if atoms is None:
            raise ValueError(
                "GNNCalculator.calculate needs an Atoms object: none was passed "
                "and none is stored on the calculator."
            )

        # atoms.positions changes in each time step
        # -> need to recompute indices
        self.molecule.update(R=atoms.positions)

        # get new indices etc.
        inputs = self.molecule.get()

        # predict the energy and forces
        energy, forces = self.model.predict(inputs)

        energy = float(energy)  # to scalar
        # optionally add the atomic reference energies
        if self.add_atom_energies:
            energy += np.sum([self.atom_energies[z] for z in atoms.numbers])

        # store energy and forces in the calculator dictionary
        self.results["energy"] = energy
        # the molecule may live on a non-CPU device (see Molecule.to), and
        # Tensor.numpy() only works for CPU tensors detached from the graph
        self.results["forces"] = forces.detach().cpu().numpy()
class MDSimulator:
    """
    Runs a MD simulation on the Atoms object created from data and perform MD simulation for max_steps

    Parameters
    ----------
        molecule
            Captures data of all atoms. Its ``R`` and ``Z`` attributes provide
            the initial positions and the atomic numbers.
        model
            The trained GemNet model. Only used when calculator is 'model'.
        dynamics: str
            Name of the MD integrator (case-insensitive). Implemented: 'langevin' or 'verlet'.
        max_steps: int
            Maximum number of simulation steps.
        time: float
            Integration time step for Newton's law in femtoseconds.
        temperature: float
            The temperature in Kelvin.
        langevin_friction: float
            Only used when dynamics are 'langevin'. A friction coefficient, typically 1e-4 to 1e-2.
        interval: int
            Write only every <interval> time step to trajectory file.
        traj_path: str
            Path of the file where to save the calculated trajectory.
        vel: N-array, default=None
            If set, the atoms are initialized with these velocities. Otherwise
            the momenta are drawn from a Maxwell-Boltzmann distribution and the
            center-of-mass momentum is set to zero.
        logfile: str
            File name or open file, where to log md simulation. "-" refers to standard output.
        calculator: str
            Which calculator drives the simulation: 'model' (GNNCalculator
            wrapping ``model``), 'EMT' (Effective Medium Theory) or 'LJ'
            (Lennard-Jones).

    Raises
    ------
        UserWarning
            If ``calculator`` or ``dynamics`` is not one of the known names.
    """

    def __init__(
        self,
        molecule,
        model,
        dynamics: str = "langevin",
        max_steps: int = 100,  # max_steps * time is total time length of trajectory
        time: float = 0.5,  # in fs
        temperature: float = 300,  # in K
        langevin_friction: float = 0.002,
        interval: int = 10,
        traj_path="md_sim.traj",
        vel=None,
        logfile="-",
        calculator="model",
    ):
        self.max_steps = max_steps

        atoms = Atoms(
            positions=molecule.R, numbers=molecule.Z
        )  # positions in A, numbers in integers (1=H, etc.)

        if calculator == "model":
            atoms.calc = GNNCalculator(molecule, model=model, atoms=atoms)
        elif calculator == "EMT":
            atoms.calc = EMT()  # Effective Medium Theory
        elif calculator == "LJ":
            atoms.calc = LennardJones()  # Lennard Jones
        else:
            # exception type kept as UserWarning for backward compatibility
            raise UserWarning(
                f"Unknown MD calculator. I only know 'model', 'EMT' and 'LJ' but {calculator} was given."
            )

        # Initializes velocities
        if vel is not None:
            atoms.set_velocities(vel)
        else:
            # Set the momenta to a Maxwell-Boltzmann distribution
            MaxwellBoltzmannDistribution(
                atoms,
                temp=temperature * units.kB,  # kB: Boltzmann constant, eV/K
                # temperature_K = temperature # only works in newer ase versions
            )
            # Set the center-of-mass momentum to zero
            Stationary(atoms)

        # Select MD simulation
        integrator = dynamics.lower()
        if integrator == "verlet":
            logging.info("Selected MD integrator: Verlet")
            # total energy will always be constant
            self.dyn = VelocityVerlet(atoms, timestep=time * units.fs)
        elif integrator == "langevin":
            logging.info("Selected MD integrator: Langevin")
            # each atom is coupled to a heat bath through a fluctuating force and a friction term
            self.dyn = Langevin(
                atoms,
                timestep=time * units.fs,
                temperature=temperature * units.kB,  # kB: Boltzmann constant, eV/K
                # temperature_K = temperature, # only works in newer ase versions
                friction=langevin_friction,
            )
        else:
            # exception type kept as UserWarning for backward compatibility
            raise UserWarning(
                f"Unknown MD integrator. I only know 'verlet' and 'langevin' but {dynamics} was given."
            )

        logging.info(f"Save trajectory to {traj_path}")
        self.traj = Trajectory(traj_path, "w", atoms)
        # write every <interval>-th step to the trajectory, log every step
        self.dyn.attach(self.traj.write, interval=interval)
        self.dyn.attach(
            MDLogger(self.dyn, atoms, logfile, peratom=False, mode="a"),
            interval=1,
        )

    def run(self):
        """
        Run the simulation for ``max_steps`` steps and close the trajectory file.

        The trajectory is closed even if the integrator raises, so the file
        handle is not leaked.
        """
        try:
            self.dyn.run(self.max_steps)
        finally:
            self.traj.close()