diff --git a/dpgen/generator/run.py b/dpgen/generator/run.py index 847cdb594..d773cb861 100644 --- a/dpgen/generator/run.py +++ b/dpgen/generator/run.py @@ -95,6 +95,7 @@ normalize, sepline, set_directory, + setup_ele_temp, ) from .arginfo import run_jdata_arginfo @@ -4024,10 +4025,6 @@ def post_fp_vasp(iter_index, jdata, rfailed=None): _sys = dpdata.LabeledSystem() dlog.info("Failed fp path: %s" % oo.replace("OUTCAR", "")) if len(_sys) == 1: - if all_sys is None: - all_sys = _sys - else: - all_sys.append(_sys) # save ele_temp, if any if os.path.exists(oo.replace("OUTCAR", "job.json")): with open(oo.replace("OUTCAR", "job.json")) as fp: @@ -4036,6 +4033,27 @@ def post_fp_vasp(iter_index, jdata, rfailed=None): assert use_ele_temp ele_temp = job_data["ele_temp"] all_te.append(ele_temp) + if use_ele_temp == 0: + raise RuntimeError( + "should not get ele temp at setting: use_ele_temp == 0" + ) + elif use_ele_temp == 1: + _sys.data["fparam"] = np.array(ele_temp).reshape(1, 1) + elif use_ele_temp == 2: + tile_te = np.tile(ele_temp, [_sys.get_natoms()]) + _sys.data["aparam"] = tile_te.reshape( + 1, _sys.get_natoms(), 1 + ) + else: + raise RuntimeError( + "invalid setting of use_ele_temp " + str(use_ele_temp) + ) + # check if ele_temp shape is correct + _sys.check_data() + if all_sys is None: + all_sys = _sys + else: + all_sys.append(_sys) elif len(_sys) >= 2: raise RuntimeError("The vasp parameter NSW should be set as 1") else: @@ -4045,29 +4063,6 @@ def post_fp_vasp(iter_index, jdata, rfailed=None): sys_data_path = os.path.join(work_path, "data.%s" % ss) all_sys.to_deepmd_raw(sys_data_path) all_sys.to_deepmd_npy(sys_data_path, set_size=len(sys_outcars)) - if all_te.size > 0: - assert len(all_sys) == all_sys.get_nframes() - assert len(all_sys) == all_te.size - all_te = np.reshape(all_te, [-1, 1]) - if use_ele_temp == 0: - raise RuntimeError( - "should not get ele temp at setting: use_ele_temp == 0" - ) - elif use_ele_temp == 1: - np.savetxt(os.path.join(sys_data_path, "fparam.raw"), all_te) - np.save( - os.path.join(sys_data_path, "set.000", "fparam.npy"), all_te - ) - elif use_ele_temp == 2: - tile_te = np.tile(all_te, [1, all_sys.get_natoms()]) - np.savetxt(os.path.join(sys_data_path, "aparam.raw"), tile_te) - np.save( - os.path.join(sys_data_path, "set.000", "aparam.npy"), tile_te - ) - else: - raise RuntimeError( - "invalid setting of use_ele_temp " + str(use_ele_temp) - ) if tcount == 0: rfail = 0.0 @@ -4512,6 +4507,13 @@ def run_iter(param_file, machine_file): update_mass_map(jdata) + # set up electron temperature + use_ele_temp = jdata.get("use_ele_temp", 0) + if use_ele_temp == 1: + setup_ele_temp(False) + elif use_ele_temp == 2: + setup_ele_temp(True) + if jdata.get("pretty_print", False): from monty.serialization import dumpfn diff --git a/dpgen/simplify/simplify.py b/dpgen/simplify/simplify.py index c392adbf5..089545026 100644 --- a/dpgen/simplify/simplify.py +++ b/dpgen/simplify/simplify.py @@ -44,7 +44,7 @@ train_task_fmt, ) from dpgen.remote.decide_machine import convert_mdata -from dpgen.util import expand_sys_str, load_file, normalize, sepline +from dpgen.util import expand_sys_str, load_file, normalize, sepline, setup_ele_temp from .arginfo import simplify_jdata_arginfo @@ -519,6 +519,13 @@ def run_iter(param_file, machine_file): jdata_arginfo = simplify_jdata_arginfo() jdata = normalize(jdata_arginfo, jdata) + # set up electron temperature + use_ele_temp = jdata.get("use_ele_temp", 0) + if use_ele_temp == 1: + setup_ele_temp(False) + elif use_ele_temp == 2: + setup_ele_temp(True) + if mdata.get("handlers", None): if mdata["handlers"].get("smtp", None): que = queue.Queue(-1) diff --git a/dpgen/util.py b/dpgen/util.py index 896a3d504..cd38d1473 100644 --- a/dpgen/util.py +++ b/dpgen/util.py @@ -9,7 +9,9 @@ import dpdata import h5py +import numpy as np from dargs import Argument +from dpdata.data_type import Axis, DataType from dpgen import dlog @@ -215,3 +217,30 @@ def load_file(filename: Union[str, os.PathLike]) -> dict: else: raise ValueError(f"Unsupported file format: {filename}") return data + + +def setup_ele_temp(atomic: bool): + """Set electronic temperature as required input data. + + Parameters + ---------- + atomic : bool + Whether to use atomic temperature or frame temperature + """ + if atomic: + ele_temp_data_type = DataType( + "aparam", + np.ndarray, + shape=(Axis.NFRAMES, Axis.NATOMS, 1), + required=False, + ) + else: + ele_temp_data_type = DataType( + "fparam", + np.ndarray, + shape=(Axis.NFRAMES, 1), + required=False, + ) + + dpdata.System.register_data_type(ele_temp_data_type) + dpdata.LabeledSystem.register_data_type(ele_temp_data_type) diff --git a/pyproject.toml b/pyproject.toml index b89858ba6..2e80d1e61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ classifiers = [ ] dependencies = [ 'numpy>=1.14.3', - 'dpdata>=0.2.6,!=0.2.11', + 'dpdata>=0.2.16', 'pymatgen>=2022.11.1', 'ase', 'monty>2.0.0', diff --git a/tests/generator/context.py b/tests/generator/context.py index 6da7932bc..7a6f1f2d3 100644 --- a/tests/generator/context.py +++ b/tests/generator/context.py @@ -17,6 +17,7 @@ _parse_calypso_input, # noqa: F401 ) from dpgen.generator.run import * # noqa: F403 +from dpgen.util import setup_ele_temp # noqa: F401 param_file = "param-mg-vasp.json" param_file_merge_traj = "param-mg-vasp_merge_traj.json" diff --git a/tests/generator/test_post_fp.py b/tests/generator/test_post_fp.py index 240ec99d6..a4f6adc5f 100644 --- a/tests/generator/test_post_fp.py +++ b/tests/generator/test_post_fp.py @@ -24,6 +24,7 @@ param_siesta_file, post_fp, post_fp_vasp, + setup_ele_temp, setUpModule, # noqa: F401 ) @@ -59,14 +60,21 @@ def setUp(self): self.ref_e = np.array(self.ref_e) self.ref_f = np.array(self.ref_f) self.ref_v = np.array(self.ref_v) + # backup dpdata system data type + self._system_dtypes = dpdata.System.DTYPES + self._labeled_system_dtypes = dpdata.LabeledSystem.DTYPES def tearDown(self): shutil.rmtree("iter.000000") + # recover + dpdata.System.DTYPES = self._system_dtypes + dpdata.LabeledSystem.DTYPES = self._labeled_system_dtypes def test_post_fp_vasp_0(self): with open(param_file) as fp: jdata = json.load(fp) jdata["use_ele_temp"] = 2 + setup_ele_temp(True) post_fp_vasp(0, jdata, rfailed=0.3) sys = dpdata.LabeledSystem("iter.000000/02.fp/data.000/", fmt="deepmd/raw") @@ -117,6 +125,7 @@ def test_post_fp_vasp_1(self): with open(param_file) as fp: jdata = json.load(fp) jdata["use_ele_temp"] = 1 + setup_ele_temp(False) post_fp_vasp(0, jdata, rfailed=0.3) sys = dpdata.LabeledSystem("iter.000000/02.fp/data.001/", fmt="deepmd/raw")