diff --git a/dpgen/generator/run.py b/dpgen/generator/run.py index fa70d1cc3..5ee6a6a4d 100644 --- a/dpgen/generator/run.py +++ b/dpgen/generator/run.py @@ -1061,7 +1061,9 @@ def find_only_one_key(lmp_lines, key): return found[0] -def revise_lmp_input_model(lmp_lines, task_model_list, trj_freq, deepmd_version="1"): +def revise_lmp_input_model( + lmp_lines, task_model_list, trj_freq, deepmd_version="1", use_ele_temp=0 +): idx = find_only_one_key(lmp_lines, ["pair_style", "deepmd"]) graph_list = " ".join(task_model_list) if Version(deepmd_version) < Version("1"): @@ -1070,13 +1072,22 @@ def revise_lmp_input_model(lmp_lines, task_model_list, trj_freq, deepmd_version= trj_freq, ) else: - lmp_lines[idx] = ( - "pair_style deepmd %s out_freq %d out_file model_devi.out\n" - % ( - graph_list, - trj_freq, + if use_ele_temp == 0: + lmp_lines[idx] = ( + "pair_style deepmd %s out_freq %d out_file model_devi.out\n" + % ( + graph_list, + trj_freq, + ) + ) + elif use_ele_temp == 1: + lmp_lines[idx] = ( + "pair_style deepmd %s out_freq %d out_file model_devi.out fparam ${ELE_TEMP}\n" + % ( + graph_list, + trj_freq, + ) ) - ) return lmp_lines @@ -1340,6 +1351,8 @@ def _make_model_devi_revmat(iter_index, jdata, mdata, conf_systems): sys_idx = expand_idx(cur_job["sys_idx"]) if len(sys_idx) != len(list(set(sys_idx))): raise RuntimeError("system index should be uniq") + + use_ele_temp = jdata.get("use_ele_temp", 0) mass_map = jdata["mass_map"] use_plm = jdata.get("model_devi_plumed", False) use_plm_path = jdata.get("model_devi_plumed_path", False) @@ -1446,6 +1459,7 @@ def _make_model_devi_revmat(iter_index, jdata, mdata, conf_systems): task_model_list, trj_freq, deepmd_version=deepmd_version, + use_ele_temp=use_ele_temp, ) else: if len(lmp_lines[template_pair_deepmd_idx].split()) != ( @@ -1466,6 +1480,7 @@ def _make_model_devi_revmat(iter_index, jdata, mdata, conf_systems): task_model_list, trj_freq, deepmd_version=deepmd_version, + use_ele_temp=use_ele_temp, ) # use revise_lmp_input_model to raise error message if "part_style" or "deepmd" not found else: @@ -1474,6 +1489,7 @@ def _make_model_devi_revmat(iter_index, jdata, mdata, conf_systems): task_model_list, trj_freq, deepmd_version=deepmd_version, + use_ele_temp=use_ele_temp, ) lmp_lines = revise_lmp_input_dump(