Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 24, 2023
1 parent cf1f538 commit ab9afc0
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 49 deletions.
8 changes: 2 additions & 6 deletions dpgen/simplify/arginfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,8 @@ def general_simplify_arginfo() -> Argument:
doc_model_devi_f_trust_hi = (
"The higher bound of forces for the selection for the model deviation."
)
doc_model_devi_e_trust_lo = (
"The lower bound of energy per atom for the selection for the model deviation. Requires DeePMD-kit version >=2.2.2."
)
doc_model_devi_e_trust_hi = (
"The higher bound of energy per atom for the selection for the model deviation. Requires DeePMD-kit version >=2.2.2."
)
doc_model_devi_e_trust_lo = "The lower bound of energy per atom for the selection for the model deviation. Requires DeePMD-kit version >=2.2.2."
doc_model_devi_e_trust_hi = "The higher bound of energy per atom for the selection for the model deviation. Requires DeePMD-kit version >=2.2.2."

return [
Argument("labeled", bool, optional=True, default=False, doc=doc_labeled),
Expand Down
5 changes: 4 additions & 1 deletion dpgen/simplify/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,10 @@ def post_model_devi(iter_index, jdata, mdata):
subsys = sys_entire[name][idx]
if f_devi >= f_trust_hi or e_devi >= e_trust_hi:
sys_failed.append(subsys)
elif f_trust_lo <= f_devi < f_trust_hi or e_trust_lo <= e_devi < e_trust_hi:
elif (
f_trust_lo <= f_devi < f_trust_hi
or e_trust_lo <= e_devi < e_trust_hi
):
sys_candinate.append(subsys)
elif f_devi < f_trust_lo and e_devi < e_trust_lo:
sys_accurate.append(subsys)
Expand Down
113 changes: 71 additions & 42 deletions tests/simplify/test_post_model_devi.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import sys
import shutil
import os
import shutil
import sys
import unittest
from pathlib import Path

import dpdata
import numpy as np
from pathlib import Path

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
__package__ = "simplify"
Expand All @@ -29,59 +29,88 @@ def setUp(self):
"forces": np.zeros((1, 1, 3), dtype=np.float32),
}
)
self.system.to_deepmd_npy(self.work_path / "data.rest.old" / self.system.formula)
self.system.to_deepmd_npy(
self.work_path / "data.rest.old" / self.system.formula
)
model_devi = np.array([[0, 0.2, 0.1, 0.15, 0.2, 0.1, 0.15, 0.2]])
np.savetxt(self.work_path / "details", model_devi, fmt=["%12d"] + ["%19.6e" for _ in range(7)], header="data.rest.old/" + self.system.formula + "\n step max_devi_v min_devi_v avg_devi_v max_devi_f min_devi_f avg_devi_f devi_e")
np.savetxt(
self.work_path / "details",
model_devi,
fmt=["%12d"] + ["%19.6e" for _ in range(7)],
header="data.rest.old/"
+ self.system.formula
+ "\n step max_devi_v min_devi_v avg_devi_v max_devi_f min_devi_f avg_devi_f devi_e",
)

def tearDown(self):
shutil.rmtree("iter.000001", ignore_errors=True)

def test_post_model_devi_f_candidate(self):
dpgen.simplify.simplify.post_model_devi(1, {
"model_devi_f_trust_lo": 0.15,
"model_devi_f_trust_hi": 0.25,
"model_devi_e_trust_lo": float("inf"),
"model_devi_e_trust_hi": float("inf"),
"iter_pick_number": 1,
}, {})
dpgen.simplify.simplify.post_model_devi(
1,
{
"model_devi_f_trust_lo": 0.15,
"model_devi_f_trust_hi": 0.25,
"model_devi_e_trust_lo": float("inf"),
"model_devi_e_trust_hi": float("inf"),
"iter_pick_number": 1,
},
{},
)
assert (self.work_path / "data.picked" / self.system.formula).exists()

def test_post_model_devi_e_candidate(self):
dpgen.simplify.simplify.post_model_devi(1, {
"model_devi_e_trust_lo": 0.15,
"model_devi_e_trust_hi": 0.25,
"model_devi_f_trust_lo": float("inf"),
"model_devi_f_trust_hi": float("inf"),
"iter_pick_number": 1,
}, {})
dpgen.simplify.simplify.post_model_devi(
1,
{
"model_devi_e_trust_lo": 0.15,
"model_devi_e_trust_hi": 0.25,
"model_devi_f_trust_lo": float("inf"),
"model_devi_f_trust_hi": float("inf"),
"iter_pick_number": 1,
},
{},
)
assert (self.work_path / "data.picked" / self.system.formula).exists()

def test_post_model_devi_f_failed(self):
with self.assertRaises(RuntimeError):
dpgen.simplify.simplify.post_model_devi(1, {
"model_devi_f_trust_lo": 0.0,
"model_devi_f_trust_hi": 0.0,
"model_devi_e_trust_lo": float("inf"),
"model_devi_e_trust_hi": float("inf"),
"iter_pick_number": 1,
}, {})

dpgen.simplify.simplify.post_model_devi(
1,
{
"model_devi_f_trust_lo": 0.0,
"model_devi_f_trust_hi": 0.0,
"model_devi_e_trust_lo": float("inf"),
"model_devi_e_trust_hi": float("inf"),
"iter_pick_number": 1,
},
{},
)

def test_post_model_devi_e_failed(self):
with self.assertRaises(RuntimeError):
dpgen.simplify.simplify.post_model_devi(1, {
"model_devi_e_trust_lo": 0.0,
"model_devi_e_trust_hi": 0.0,
"model_devi_f_trust_lo": float("inf"),
"model_devi_f_trust_hi": float("inf"),
"iter_pick_number": 1,
}, {})
dpgen.simplify.simplify.post_model_devi(
1,
{
"model_devi_e_trust_lo": 0.0,
"model_devi_e_trust_hi": 0.0,
"model_devi_f_trust_lo": float("inf"),
"model_devi_f_trust_hi": float("inf"),
"iter_pick_number": 1,
},
{},
)

def test_post_model_devi_accurate(self):
dpgen.simplify.simplify.post_model_devi(1, {
"model_devi_e_trust_lo": 0.3,
"model_devi_e_trust_hi": 0.4,
"model_devi_f_trust_lo": 0.3,
"model_devi_f_trust_hi": 0.4,
"iter_pick_number": 1,
}, {})
dpgen.simplify.simplify.post_model_devi(
1,
{
"model_devi_e_trust_lo": 0.3,
"model_devi_e_trust_hi": 0.4,
"model_devi_f_trust_lo": 0.3,
"model_devi_f_trust_hi": 0.4,
"iter_pick_number": 1,
},
{},
)
assert (self.work_path / "data.accurate" / self.system.formula).exists()

0 comments on commit ab9afc0

Please sign in to comment.