Skip to content

Commit

Permalink
[Unified Checkpoint] Fix checkpoint quant log (PaddlePaddle#9606) (Pa…
Browse files Browse the repository at this point in the history
  • Loading branch information
wtmlon authored Dec 12, 2024
1 parent 05e6f06 commit e473a81
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions paddlenlp/trainer/unified_checkpoint/async_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def _file_save_async_or_sync(
if isinstance(state_dict[k], paddle.Tensor):
state_dict[k] = state_dict.pop(k).cpu().numpy()

state_dict = quant_unified_optimizer(state_dict, state_dict_type, ckpt_quant_stage)
if state_dict_type == "optimizer_weight" and ckpt_quant_stage != "O0":
state_dict = quant_unified_optimizer(state_dict, state_dict_type, ckpt_quant_stage)
safe_save_file(state_dict, path, metadata={"format": "np"})
else:
if len(state_dict.keys()) == 0:
Expand Down Expand Up @@ -206,9 +207,10 @@ def _save_file_async_in_process(
signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00")
logger.info(f"Start to async save {path}")
state_dict = _read_state_dict_from_shm(meta_dict, shm) # numpy array
state_dict = quant_unified_optimizer(
state_dict, state_dict_type, ckpt_quant_stage, async_save=True
) # ckpt quantization
if state_dict_type == "optimizer_weight" and ckpt_quant_stage != "O0":
state_dict = quant_unified_optimizer(
state_dict, state_dict_type, ckpt_quant_stage, async_save=True
) # ckpt quantization
safe_save_file(state_dict, path, {"format": "np"})
del state_dict
saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
Expand Down

0 comments on commit e473a81

Please sign in to comment.