fix: fix checkpoint filename for the PyTorch backend (#1585)
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **Bug Fixes**
  - Improved handling of model checkpoints during training to ensure
    compatibility with different file suffixes (`.pb` and `.pth`).

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz authored Jul 13, 2024
1 parent 7f14124 commit 13dc51f
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion dpgen/generator/run.py
@@ -811,7 +811,13 @@ def run_train_dp(iter_index, jdata, mdata):
 elif training_finetune_model is not None:
     init_flag = f" --finetune old/init{suffix}"
 command = f"{train_command} train {train_input_file}{extra_flags}"
-command = f"{{ if [ ! -f model.ckpt.index ]; then {command}{init_flag}; else {command} --restart model.ckpt; fi }}"
+if suffix == ".pb":
+    ckpt_suffix = ".index"
+elif suffix == ".pth":
+    ckpt_suffix = ".pt"
+else:
+    raise RuntimeError(f"Unknown suffix {suffix}")
+command = f"{{ if [ ! -f model.ckpt{ckpt_suffix} ]; then {command}{init_flag}; else {command} --restart model.ckpt; fi }}"
 command = f"/bin/sh -c {shlex.quote(command)}"
 commands.append(command)
 command = f"{train_command} freeze"
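
To see the effect of the change in isolation, here is a minimal standalone sketch of the same logic. The helper names `checkpoint_file` and `build_train_command` and the table `CKPT_SUFFIX` are hypothetical and do not exist in `dpgen/generator/run.py`. Only the suffix mapping (`.pb` to `.index`, `.pth` to `.pt`), the shell template, and the `shlex.quote` wrapping are taken from the diff.

```python
import shlex

# Mapping taken from the diff: the model suffix decides which checkpoint
# file signals that an earlier training run can be restarted.
# CKPT_SUFFIX itself is a hypothetical name used only for this sketch.
CKPT_SUFFIX = {".pb": ".index", ".pth": ".pt"}


def checkpoint_file(suffix: str) -> str:
    """Return the checkpoint filename to test for (hypothetical helper)."""
    if suffix not in CKPT_SUFFIX:
        # Same error as in the diff for an unrecognized suffix.
        raise RuntimeError(f"Unknown suffix {suffix}")
    return f"model.ckpt{CKPT_SUFFIX[suffix]}"


def build_train_command(train_command, train_input_file, suffix,
                        init_flag="", extra_flags=""):
    """Build the restart-aware shell command (hypothetical helper)."""
    command = f"{train_command} train {train_input_file}{extra_flags}"
    ckpt = checkpoint_file(suffix)
    # Start fresh (with the init flag) when no checkpoint exists,
    # otherwise restart from model.ckpt.
    script = (
        f"{{ if [ ! -f {ckpt} ]; then {command}{init_flag}; "
        f"else {command} --restart model.ckpt; fi }}"
    )
    # Quote the whole script so it is passed to /bin/sh as one argument.
    return f"/bin/sh -c {shlex.quote(script)}"


if __name__ == "__main__":
    print(build_train_command("dp", "input.json", ".pb"))
    print(build_train_command("dp", "input.json", ".pth"))
```

With `suffix=".pth"` the generated script tests for `model.ckpt.pt`, whereas before this commit it always tested for `model.ckpt.index`, whatever the suffix.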
