Skip to content

Commit

Permalink
dprc: add restart_from_iter option (#1322)
Browse files Browse the repository at this point in the history
Add the `restart_from_iter` option in each iteration of DPRc simulations
to restart from a previous iteration instead of initial structures.

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] authored Sep 4, 2023
1 parent 1f0505d commit 317c674
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 8 deletions.
4 changes: 4 additions & 0 deletions dpgen/generator/arginfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ def model_devi_amber_args() -> list[Argument]:
)
doc_model_devi_f_trust_lo = "Lower bound of forces for the selection. If dict, should be set for each index in sys_idx, respectively."
doc_model_devi_f_trust_hi = "Upper bound of forces for the selection. If dict, should be set for each index in sys_idx, respectively."
doc_restart_from_iter = "The iteration index to restart the simulation from. If not given, the simulation is restarted from `sys_configs`."

return [
# make model devi args
Expand All @@ -497,6 +498,9 @@ def model_devi_amber_args() -> list[Argument]:
sub_fields=[
Argument("sys_idx", list, optional=False, doc=doc_sys_idx),
Argument("trj_freq", int, optional=False, doc=doc_trj_freq),
Argument(
"restart_from_iter", int, optional=True, doc=doc_restart_from_iter
),
],
),
Argument("low_level", str, optional=False, doc=doc_low_level),
Expand Down
21 changes: 17 additions & 4 deletions dpgen/generator/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -1775,10 +1775,23 @@ def _make_model_devi_amber(
create_path(task_path)
# link restart file
loc_conf_name = "init.rst7"
os.symlink(
os.path.join(os.path.join("..", "confs"), conf_name + ".rst7"),
os.path.join(task_path, loc_conf_name),
)
if cur_job.get("restart_from_iter") is None:
os.symlink(
os.path.join(os.path.join("..", "confs"), conf_name + ".rst7"),
os.path.join(task_path, loc_conf_name),
)
else:
restart_from_iter = cur_job["restart_from_iter"]
restart_iter_name = make_iter_name(restart_from_iter)
os.symlink(
os.path.relpath(
os.path.join(
restart_iter_name, model_devi_name, task_name, "rc.rst7"
),
task_path,
),
os.path.join(task_path, loc_conf_name),
)
cwd_ = os.getcwd()
# chdir to task path
os.chdir(task_path)
Expand Down
Empty file added tests/generator/amber/MON.parm7
Empty file.
Empty file.
2 changes: 1 addition & 1 deletion tests/generator/param-amber.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"numb_models": 4,
"mdin_prefix": "amber",
"parm7_prefix": "amber",
"sys_prefix": "amber",
"sys_configs_prefix": "amber",
"disang_prefix": "amber",
"sys_configs": [
[
Expand Down
60 changes: 57 additions & 3 deletions tests/generator/test_make_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,12 +533,66 @@ def test_make_model_devi(self):
jdata = json.load(fp)
with open(machine_file) as fp:
mdata = json.load(fp)
jdata["sys_prefix"] = os.path.abspath(jdata["sys_prefix"])
# TODO: these should be normalized in the main program
jdata["sys_configs_prefix"] = os.path.abspath(jdata["sys_configs_prefix"])
jdata["disang_prefix"] = os.path.abspath(jdata["disang_prefix"])
jdata["mdin_prefix"] = os.path.abspath(jdata["mdin_prefix"])
jdata["parm7_prefix"] = os.path.abspath(jdata["mdin_prefix"])
_make_fake_models(0, jdata["numb_models"])
make_model_devi(0, jdata, mdata)
_check_pb(self, 0)
_check_confs(self, 0, jdata)
_check_traj_dir(self, 0)
self._check_input(0)

def test_restart_from_iter(self):
if os.path.isdir("iter.000000"):
shutil.rmtree("iter.000000")
if os.path.isdir("iter.000001"):
shutil.rmtree("iter.000001")
with open(param_amber_file) as fp:
jdata = json.load(fp)
with open(machine_file) as fp:
mdata = json.load(fp)
jdata["model_devi_jobs"].append(
{
"sys_idx": [0],
"restart_from_iter": 0,
}
)
jdata["sys_configs_prefix"] = os.path.abspath(jdata["sys_configs_prefix"])
jdata["disang_prefix"] = os.path.abspath(jdata["disang_prefix"])
jdata["mdin_prefix"] = os.path.abspath(jdata["mdin_prefix"])
jdata["parm7_prefix"] = os.path.abspath(jdata["mdin_prefix"])
_make_fake_models(0, jdata["numb_models"])
make_model_devi(0, jdata, mdata)
_check_pb(self, 0)
self._check_input(0)
restart_text = "This is the fake restart file to test `restart_from_iter`"
with open(
os.path.join(
"iter.%06d" % 0, "01.model_devi", "task.000.000000", "rc.rst7"
),
"w",
) as fw:
fw.write(restart_text)
_make_fake_models(1, jdata["numb_models"])
make_model_devi(1, jdata, mdata)
_check_pb(self, 1)
self._check_input(1)
with open(
os.path.join(
"iter.%06d" % 1, "01.model_devi", "task.000.000000", "init.rst7"
)
) as f:
assert f.read() == restart_text

def _check_input(self, iter_idx: int):
md_dir = os.path.join("iter.%06d" % iter_idx, "01.model_devi")
assert os.path.isfile(os.path.join(md_dir, "init0.mdin"))
assert os.path.isfile(os.path.join(md_dir, "qmmm0.parm7"))
tasks = glob.glob(os.path.join(md_dir, "task.*"))
for tt in tasks:
assert os.path.isfile(os.path.join(tt, "init.rst7"))
assert os.path.isfile(os.path.join(tt, "TEMPLATE.disang"))


if __name__ == "__main__":
Expand Down

0 comments on commit 317c674

Please sign in to comment.