Skip to content

Commit

Permalink
feat: support jax backend (#1671)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Added support for the "jax" backend, allowing users to work with model
files using the new `.savedmodel` suffix.
- **Bug Fixes**
- Updated error messages to include "jax" in the list of supported
backends, improving clarity for users encountering unsupported backend
issues.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and coderabbitai[bot] authored Nov 23, 2024
1 parent a5bf532 commit eecac81
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions dpgen/generator/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,13 @@ def _get_model_suffix(jdata) -> str:
"""Return the model suffix based on the backend."""
mlp_engine = jdata.get("mlp_engine", "dp")
if mlp_engine == "dp":
suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"}
suffix_map = {"tensorflow": ".pb", "pytorch": ".pth", "jax": ".savedmodel"}
backend = jdata.get("train_backend", "tensorflow")
if backend in suffix_map:
suffix = suffix_map[backend]
else:
raise ValueError(
f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch'."
f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch', 'jax'."
)
return suffix
else:
Expand Down Expand Up @@ -766,6 +766,8 @@ def run_train_dp(iter_index, jdata, mdata):
# assert train_command == "dp", "The 'train_command' should be 'dp'" # the tests should be updated to run this command
if suffix == ".pth":
train_command += " --pt"
elif suffix == ".savedmodel":
train_command += " --jax"

# paths
iter_name = make_iter_name(iter_index)
Expand Down Expand Up @@ -803,6 +805,8 @@ def run_train_dp(iter_index, jdata, mdata):
ckpt_suffix = ".index"
elif suffix == ".pth":
ckpt_suffix = ".pt"
elif suffix == ".savedmodel":
ckpt_suffix = ".jax"
else:
raise RuntimeError(f"Unknown suffix {suffix}")
command = f"{{ if [ ! -f model.ckpt{ckpt_suffix} ]; then {command}{init_flag}; else {command} --restart model.ckpt; fi }}"
Expand Down Expand Up @@ -840,6 +844,10 @@ def run_train_dp(iter_index, jdata, mdata):
]
elif suffix == ".pth":
forward_files += [os.path.join("old", "model.ckpt.pt")]
elif suffix == ".savedmodel":
forward_files += [os.path.join("old", "model.ckpt.jax")]
else:
raise RuntimeError(f"Unknown suffix {suffix}")
elif training_init_frozen_model is not None or training_finetune_model is not None:
forward_files.append(os.path.join("old", f"init{suffix}"))

Expand All @@ -860,6 +868,10 @@ def run_train_dp(iter_index, jdata, mdata):
]
elif suffix == ".pth":
backward_files += ["model.ckpt.pt"]
elif suffix == ".savedmodel":
backward_files += ["model.ckpt.jax"]
else:
raise RuntimeError(f"Unknown suffix {suffix}")

if not jdata.get("one_h5", False):
init_data_sys_ = jdata["init_data_sys"]
Expand Down

0 comments on commit eecac81

Please sign in to comment.