diff --git a/src/pipelines/pipeline.py b/src/pipelines/pipeline.py index 3140a2f..9624ac0 100644 --- a/src/pipelines/pipeline.py +++ b/src/pipelines/pipeline.py @@ -100,7 +100,9 @@ def train( or config.strategy == "deepspeed_stage_3_offload" ): for epoch in range(config.epoch): - ckpt_path = f"{config.callbacks.model_checkpoint.dirpath}/epoch{epoch}.ckpt" + ckpt_path = ( + f"{config.callbacks.model_checkpoint.dirpath}/epoch={epoch}.ckpt" + ) if os.path.exists(ckpt_path) and os.path.isdir(ckpt_path): convert_zero_checkpoint_to_fp32_state_dict( ckpt_path,