Skip to content

Commit

Permalink
Merge branch 'devel' into simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
thangckt authored Jul 10, 2024
2 parents d6d1d12 + 2f39d12 commit 314d159
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 18 deletions.
1 change: 0 additions & 1 deletion .git_archival.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
node: $Format:%H$
node-date: $Format:%cI$
describe-name: $Format:%(describe:tags=true,match=*[0-9]*)$
ref-names: $Format:%D$
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ repos:

# Python
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.5
rev: v0.5.1
hooks:
- id: ruff
args: ["--fix"]
Expand Down
32 changes: 27 additions & 5 deletions dpgen/generator/arginfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,14 @@ def data_args() -> list[Argument]:
# Training


def training_args() -> list[Argument]:
def training_args_common() -> list[Argument]:
doc_numb_models = "Number of models to be trained in 00.train. 4 is recommend."
return [
Argument("numb_models", int, optional=False, doc=doc_numb_models),
]


def training_args_dp() -> list[Argument]:
"""Traning arguments.
Returns
Expand All @@ -90,7 +97,6 @@ def training_args() -> list[Argument]:
doc_train_backend = (
"The backend of the training. Currently only support tensorflow and pytorch."
)
doc_numb_models = "Number of models to be trained in 00.train. 4 is recommend."
doc_training_iter0_model_path = "The model used to init the first iter training. Number of element should be equal to numb_models."
doc_training_init_model = "Iteration > 0, the model parameters will be initilized from the model trained at the previous iteration. Iteration == 0, the model parameters will be initialized from training_iter0_model_path."
doc_default_training_param = "Training parameters for deepmd-kit in 00.train. You can find instructions from `DeePMD-kit documentation <https://docs.deepmodeling.org/projects/deepmd/>`_."
Expand Down Expand Up @@ -133,7 +139,6 @@ def training_args() -> list[Argument]:
default="tensorflow",
doc=doc_train_backend,
),
Argument("numb_models", int, optional=False, doc=doc_numb_models),
Argument(
"training_iter0_model_path",
list[str],
Expand Down Expand Up @@ -224,6 +229,19 @@ def training_args() -> list[Argument]:
]


def training_args() -> Variant:
doc_mlp_engine = "Machine learning potential engine. Currently, only DeePMD-kit (defualt) is supported."
doc_dp = "DeePMD-kit."
return Variant(
"mlp_engine",
[
Argument("dp", dict, training_args_dp(), doc=doc_dp),
],
default_tag="dp",
doc=doc_mlp_engine,
)


# Exploration
def model_devi_jobs_template_args() -> Argument:
doc_template = (
Expand Down Expand Up @@ -987,7 +1005,11 @@ def run_jdata_arginfo() -> Argument:
return Argument(
"run_jdata",
dict,
sub_fields=basic_args() + data_args() + training_args() + fp_args(),
sub_variants=model_devi_args() + [fp_style_variant_type_args()],
sub_fields=basic_args() + data_args() + training_args_common() + fp_args(),
sub_variants=[
training_args(),
*model_devi_args(),
fp_style_variant_type_args(),
],
doc=doc_run_jdata,
)
44 changes: 36 additions & 8 deletions dpgen/generator/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,19 @@

def _get_model_suffix(jdata) -> str:
"""Return the model suffix based on the backend."""
suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"}
backend = jdata.get("train_backend", "tensorflow")
if backend in suffix_map:
suffix = suffix_map[backend]
mlp_engine = jdata.get("mlp_engine", "dp")
if mlp_engine == "dp":
suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"}
backend = jdata.get("train_backend", "tensorflow")
if backend in suffix_map:
suffix = suffix_map[backend]
else:
raise ValueError(
f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch'."
)
return suffix
else:
raise ValueError(
f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch'."
)
return suffix
raise ValueError(f"Unsupported engine: {mlp_engine}")


def get_job_names(jdata):
Expand Down Expand Up @@ -270,6 +274,14 @@ def dump_to_deepmd_raw(dump, deepmd_raw, type_map, fmt="gromacs/gro", charge=Non


def make_train(iter_index, jdata, mdata):
mlp_engine = jdata.get("mlp_engine", "dp")
if mlp_engine == "dp":
return make_train_dp(iter_index, jdata, mdata)
else:
raise ValueError(f"Unsupported engine: {mlp_engine}")


def make_train_dp(iter_index, jdata, mdata):
# load json param
# train_param = jdata['train_param']
train_input_file = default_train_input_file
Expand Down Expand Up @@ -714,6 +726,14 @@ def get_nframes(system):


def run_train(iter_index, jdata, mdata):
mlp_engine = jdata.get("mlp_engine", "dp")
if mlp_engine == "dp":
return make_train_dp(iter_index, jdata, mdata)
else:
raise ValueError(f"Unsupported engine: {mlp_engine}")


def run_train_dp(iter_index, jdata, mdata):
# print("debug:run_train:mdata", mdata)
# load json param
numb_models = jdata["numb_models"]
Expand Down Expand Up @@ -899,6 +919,14 @@ def run_train(iter_index, jdata, mdata):


def post_train(iter_index, jdata, mdata):
mlp_engine = jdata.get("mlp_engine", "dp")
if mlp_engine == "dp":
return post_train_dp(iter_index, jdata, mdata)
else:
raise ValueError(f"Unsupported engine: {mlp_engine}")


def post_train_dp(iter_index, jdata, mdata):
# load json param
numb_models = jdata["numb_models"]
# paths
Expand Down
4 changes: 3 additions & 1 deletion dpgen/simplify/arginfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
fp_style_siesta_args,
fp_style_vasp_args,
training_args,
training_args_common,
)


Expand Down Expand Up @@ -201,10 +202,11 @@ def simplify_jdata_arginfo() -> Argument:
*data_args(),
*general_simplify_arginfo(),
# simplify use the same training method as run
*training_args(),
*training_args_common(),
*fp_args(),
],
sub_variants=[
training_args(),
fp_style_variant_type_args(),
],
doc=doc_run_jdata,
Expand Down
8 changes: 8 additions & 0 deletions dpgen/simplify/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ def get_multi_system(path: Union[str, list[str]], jdata: dict) -> dpdata.MultiSy


def init_model(iter_index, jdata, mdata):
mlp_engine = jdata.get("mlp_engine", "dp")
if mlp_engine == "dp":
init_model_dp(iter_index, jdata, mdata)
else:
raise TypeError(f"unsupported engine {mlp_engine}")


def init_model_dp(iter_index, jdata, mdata):
training_init_model = jdata.get("training_init_model", False)
if not training_init_model:
return
Expand Down
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies = [
'paramiko',
'custodian',
'GromacsWrapper>=0.8.0',
'GromacsWrapper>=0.9.0; python_version >= "3.12"',
'dpdispatcher>=0.3.11',
'netCDF4',
'dargs>=0.4.0',
Expand Down Expand Up @@ -64,8 +65,9 @@ test = [
"dpgui",
"coverage",
"pymatgen-analysis-defects<2023.08.22",
# To be fixed: https://github.com/Becksteinlab/GromacsWrapper/issues/263
'setuptools; python_version >= "3.12"',
# https://github.com/materialsproject/pymatgen/issues/3882
# https://github.com/kuelumbus/rdkit-pypi/issues/102
"numpy<2",
]
gui = [
"dpgui",
Expand Down

0 comments on commit 314d159

Please sign in to comment.