diff --git a/dpgen/generator/run.py b/dpgen/generator/run.py index 6d7183387..f8c887838 100644 --- a/dpgen/generator/run.py +++ b/dpgen/generator/run.py @@ -811,7 +811,13 @@ def run_train_dp(iter_index, jdata, mdata): elif training_finetune_model is not None: init_flag = f" --finetune old/init{suffix}" command = f"{train_command} train {train_input_file}{extra_flags}" - command = f"{{ if [ ! -f model.ckpt.index ]; then {command}{init_flag}; else {command} --restart model.ckpt; fi }}" + if suffix == ".pb": + ckpt_suffix = ".index" + elif suffix == ".pth": + ckpt_suffix = ".pt" + else: + raise RuntimeError(f"Unknown suffix {suffix}") + command = f"{{ if [ ! -f model.ckpt{ckpt_suffix} ]; then {command}{init_flag}; else {command} --restart model.ckpt; fi }}" command = f"/bin/sh -c {shlex.quote(command)}" commands.append(command) command = f"{train_command} freeze"