From 9e48f9ffa04e80a712730039e6442f4b897d1e2f Mon Sep 17 00:00:00 2001 From: DimensionSTP Date: Fri, 17 May 2024 12:01:37 +0900 Subject: [PATCH] feat: a detail of ckpt name changed --- src/pipelines/pipeline.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/pipelines/pipeline.py b/src/pipelines/pipeline.py index 3140a2f..9624ac0 100644 --- a/src/pipelines/pipeline.py +++ b/src/pipelines/pipeline.py @@ -100,7 +100,9 @@ def train( or config.strategy == "deepspeed_stage_3_offload" ): for epoch in range(config.epoch): - ckpt_path = f"{config.callbacks.model_checkpoint.dirpath}/epoch{epoch}.ckpt" + ckpt_path = ( + f"{config.callbacks.model_checkpoint.dirpath}/epoch={epoch}.ckpt" + ) if os.path.exists(ckpt_path) and os.path.isdir(ckpt_path): convert_zero_checkpoint_to_fp32_state_dict( ckpt_path,