From eecac81231e57fabdb4d9c84bfb08cc7492d76f6 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 23 Nov 2024 02:23:26 -0500 Subject: [PATCH] feat: support jax backend (#1671) ## Summary by CodeRabbit - **New Features** - Added support for the "jax" backend, allowing users to work with model files using the new `.savedmodel` suffix. - **Bug Fixes** - Updated error messages to include "jax" in the list of supported backends, improving clarity for users encountering unsupported backend issues. --------- Signed-off-by: Jinzhe Zeng Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- dpgen/generator/run.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/dpgen/generator/run.py b/dpgen/generator/run.py index 5ee6a6a4d..0dd3fb02f 100644 --- a/dpgen/generator/run.py +++ b/dpgen/generator/run.py @@ -129,13 +129,13 @@ def _get_model_suffix(jdata) -> str: """Return the model suffix based on the backend.""" mlp_engine = jdata.get("mlp_engine", "dp") if mlp_engine == "dp": - suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"} + suffix_map = {"tensorflow": ".pb", "pytorch": ".pth", "jax": ".savedmodel"} backend = jdata.get("train_backend", "tensorflow") if backend in suffix_map: suffix = suffix_map[backend] else: raise ValueError( - f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch'." + f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch', 'jax'." ) return suffix else: @@ -766,6 +766,8 @@ def run_train_dp(iter_index, jdata, mdata): # assert train_command == "dp", "The 'train_command' should be 'dp'" # the tests should be updated to run this command if suffix == ".pth": train_command += " --pt" + elif suffix == ".savedmodel": + train_command += " --jax" # paths iter_name = make_iter_name(iter_index) @@ -803,6 +805,8 @@ def run_train_dp(iter_index, jdata, mdata): ckpt_suffix = ".index" elif suffix == ".pth": ckpt_suffix = ".pt" + elif suffix == ".savedmodel": + ckpt_suffix = ".jax" else: raise RuntimeError(f"Unknown suffix {suffix}") command = f"{{ if [ ! -f model.ckpt{ckpt_suffix} ]; then {command}{init_flag}; else {command} --restart model.ckpt; fi }}" @@ -840,6 +844,10 @@ def run_train_dp(iter_index, jdata, mdata): ] elif suffix == ".pth": forward_files += [os.path.join("old", "model.ckpt.pt")] + elif suffix == ".savedmodel": + forward_files += [os.path.join("old", "model.ckpt.jax")] + else: + raise RuntimeError(f"Unknown suffix {suffix}") elif training_init_frozen_model is not None or training_finetune_model is not None: forward_files.append(os.path.join("old", f"init{suffix}")) @@ -860,6 +868,10 @@ def run_train_dp(iter_index, jdata, mdata): ] elif suffix == ".pth": backward_files += ["model.ckpt.pt"] + elif suffix == ".savedmodel": + backward_files += ["model.ckpt.jax"] + else: + raise RuntimeError(f"Unknown suffix {suffix}") if not jdata.get("one_h5", False): init_data_sys_ = jdata["init_data_sys"]