Skip to content

Commit

Permalink
merge #1662 to devel (#1669)
Browse files Browse the repository at this point in the history
#1662 was wrongly merged into master.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Enhanced input generation for LAMMPS to accommodate electron
temperature settings.
- Added a new parameter, `use_ele_temp`, for improved flexibility in
input handling.

- **Bug Fixes**
- Ensured correct retrieval and passing of the `use_ele_temp` value
during model deviation tasks.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
njzjz authored Nov 23, 2024
2 parents b72b25d + 6943db5 commit a5bf532
Showing 1 changed file with 23 additions and 7 deletions.
30 changes: 23 additions & 7 deletions dpgen/generator/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,7 +1061,9 @@ def find_only_one_key(lmp_lines, key):
return found[0]


def revise_lmp_input_model(lmp_lines, task_model_list, trj_freq, deepmd_version="1"):
def revise_lmp_input_model(
lmp_lines, task_model_list, trj_freq, deepmd_version="1", use_ele_temp=0
):
idx = find_only_one_key(lmp_lines, ["pair_style", "deepmd"])
graph_list = " ".join(task_model_list)
if Version(deepmd_version) < Version("1"):
Expand All @@ -1070,13 +1072,22 @@ def revise_lmp_input_model(lmp_lines, task_model_list, trj_freq, deepmd_version=
trj_freq,
)
else:
lmp_lines[idx] = (
"pair_style deepmd %s out_freq %d out_file model_devi.out\n"
% (
graph_list,
trj_freq,
if use_ele_temp == 0:
lmp_lines[idx] = (
"pair_style deepmd %s out_freq %d out_file model_devi.out\n"
% (
graph_list,
trj_freq,
)
)
elif use_ele_temp == 1:
lmp_lines[idx] = (
"pair_style deepmd %s out_freq %d out_file model_devi.out fparam ${ELE_TEMP}\n"
% (
graph_list,
trj_freq,
)
)
)
return lmp_lines


Expand Down Expand Up @@ -1340,6 +1351,8 @@ def _make_model_devi_revmat(iter_index, jdata, mdata, conf_systems):
sys_idx = expand_idx(cur_job["sys_idx"])
if len(sys_idx) != len(list(set(sys_idx))):
raise RuntimeError("system index should be uniq")

use_ele_temp = jdata.get("use_ele_temp", 0)
mass_map = jdata["mass_map"]
use_plm = jdata.get("model_devi_plumed", False)
use_plm_path = jdata.get("model_devi_plumed_path", False)
Expand Down Expand Up @@ -1446,6 +1459,7 @@ def _make_model_devi_revmat(iter_index, jdata, mdata, conf_systems):
task_model_list,
trj_freq,
deepmd_version=deepmd_version,
use_ele_temp=use_ele_temp,
)
else:
if len(lmp_lines[template_pair_deepmd_idx].split()) != (
Expand All @@ -1466,6 +1480,7 @@ def _make_model_devi_revmat(iter_index, jdata, mdata, conf_systems):
task_model_list,
trj_freq,
deepmd_version=deepmd_version,
use_ele_temp=use_ele_temp,
)
# use revise_lmp_input_model to raise error message if "part_style" or "deepmd" not found
else:
Expand All @@ -1474,6 +1489,7 @@ def _make_model_devi_revmat(iter_index, jdata, mdata, conf_systems):
task_model_list,
trj_freq,
deepmd_version=deepmd_version,
use_ele_temp=use_ele_temp,
)

lmp_lines = revise_lmp_input_dump(
Expand Down

0 comments on commit a5bf532

Please sign in to comment.