diff --git a/dpgen/generator/arginfo.py b/dpgen/generator/arginfo.py index 33dd64c88..0ad63f22e 100644 --- a/dpgen/generator/arginfo.py +++ b/dpgen/generator/arginfo.py @@ -485,6 +485,7 @@ def model_devi_amber_args() -> list[Argument]: ) doc_model_devi_f_trust_lo = "Lower bound of forces for the selection. If dict, should be set for each index in sys_idx, respectively." doc_model_devi_f_trust_hi = "Upper bound of forces for the selection. If dict, should be set for each index in sys_idx, respectively." + doc_restart_from_iter = "The iteration index to restart the simulation from. If not given, the simulation is restarted from `sys_configs`." return [ # make model devi args @@ -497,6 +498,9 @@ def model_devi_amber_args() -> list[Argument]: sub_fields=[ Argument("sys_idx", list, optional=False, doc=doc_sys_idx), Argument("trj_freq", int, optional=False, doc=doc_trj_freq), + Argument( + "restart_from_iter", int, optional=True, doc=doc_restart_from_iter + ), ], ), Argument("low_level", str, optional=False, doc=doc_low_level), diff --git a/dpgen/generator/run.py b/dpgen/generator/run.py index 46eea3f58..847cdb594 100644 --- a/dpgen/generator/run.py +++ b/dpgen/generator/run.py @@ -1775,10 +1775,23 @@ def _make_model_devi_amber( create_path(task_path) # link restart file loc_conf_name = "init.rst7" - os.symlink( - os.path.join(os.path.join("..", "confs"), conf_name + ".rst7"), - os.path.join(task_path, loc_conf_name), - ) + if cur_job.get("restart_from_iter") is None: + os.symlink( + os.path.join(os.path.join("..", "confs"), conf_name + ".rst7"), + os.path.join(task_path, loc_conf_name), + ) + else: + restart_from_iter = cur_job["restart_from_iter"] + restart_iter_name = make_iter_name(restart_from_iter) + os.symlink( + os.path.relpath( + os.path.join( + restart_iter_name, model_devi_name, task_name, "rc.rst7" + ), + task_path, + ), + os.path.join(task_path, loc_conf_name), + ) cwd_ = os.getcwd() # chdir to task path os.chdir(task_path) diff --git a/tests/generator/amber/MON.parm7 b/tests/generator/amber/MON.parm7 new file mode 100644 index 000000000..e69de29bb diff --git a/tests/generator/amber/init_-1.20.disang b/tests/generator/amber/init_-1.20.disang new file mode 100644 index 000000000..e69de29bb diff --git a/tests/generator/param-amber.json b/tests/generator/param-amber.json index 6dc97aee9..905af9118 100644 --- a/tests/generator/param-amber.json +++ b/tests/generator/param-amber.json @@ -12,7 +12,7 @@ "numb_models": 4, "mdin_prefix": "amber", "parm7_prefix": "amber", - "sys_prefix": "amber", + "sys_configs_prefix": "amber", "disang_prefix": "amber", "sys_configs": [ [ diff --git a/tests/generator/test_make_md.py b/tests/generator/test_make_md.py index 165475939..c6138ee52 100644 --- a/tests/generator/test_make_md.py +++ b/tests/generator/test_make_md.py @@ -533,12 +533,66 @@ def test_make_model_devi(self): jdata = json.load(fp) with open(machine_file) as fp: mdata = json.load(fp) - jdata["sys_prefix"] = os.path.abspath(jdata["sys_prefix"]) + # TODO: these should be normalized in the main program + jdata["sys_configs_prefix"] = os.path.abspath(jdata["sys_configs_prefix"]) + jdata["disang_prefix"] = os.path.abspath(jdata["disang_prefix"]) + jdata["mdin_prefix"] = os.path.abspath(jdata["mdin_prefix"]) + jdata["parm7_prefix"] = os.path.abspath(jdata["mdin_prefix"]) _make_fake_models(0, jdata["numb_models"]) make_model_devi(0, jdata, mdata) _check_pb(self, 0) - _check_confs(self, 0, jdata) - _check_traj_dir(self, 0) + self._check_input(0) + + def test_restart_from_iter(self): + if os.path.isdir("iter.000000"): + shutil.rmtree("iter.000000") + if os.path.isdir("iter.000001"): + shutil.rmtree("iter.000001") + with open(param_amber_file) as fp: + jdata = json.load(fp) + with open(machine_file) as fp: + mdata = json.load(fp) + jdata["model_devi_jobs"].append( + { + "sys_idx": [0], + "restart_from_iter": 0, + } + ) + jdata["sys_configs_prefix"] = os.path.abspath(jdata["sys_configs_prefix"]) + jdata["disang_prefix"] = os.path.abspath(jdata["disang_prefix"]) + jdata["mdin_prefix"] = os.path.abspath(jdata["mdin_prefix"]) + jdata["parm7_prefix"] = os.path.abspath(jdata["mdin_prefix"]) + _make_fake_models(0, jdata["numb_models"]) + make_model_devi(0, jdata, mdata) + _check_pb(self, 0) + self._check_input(0) + restart_text = "This is the fake restart file to test `restart_from_iter`" + with open( + os.path.join( + "iter.%06d" % 0, "01.model_devi", "task.000.000000", "rc.rst7" + ), + "w", + ) as fw: + fw.write(restart_text) + _make_fake_models(1, jdata["numb_models"]) + make_model_devi(1, jdata, mdata) + _check_pb(self, 1) + self._check_input(1) + with open( + os.path.join( + "iter.%06d" % 1, "01.model_devi", "task.000.000000", "init.rst7" + ) + ) as f: + assert f.read() == restart_text + + def _check_input(self, iter_idx: int): + md_dir = os.path.join("iter.%06d" % iter_idx, "01.model_devi") + assert os.path.isfile(os.path.join(md_dir, "init0.mdin")) + assert os.path.isfile(os.path.join(md_dir, "qmmm0.parm7")) + tasks = glob.glob(os.path.join(md_dir, "task.*")) + for tt in tasks: + assert os.path.isfile(os.path.join(tt, "init.rst7")) + assert os.path.isfile(os.path.join(tt, "TEMPLATE.disang")) if __name__ == "__main__":