diff --git a/paddlenlp/trainer/unified_checkpoint/async_handler.py b/paddlenlp/trainer/unified_checkpoint/async_handler.py
index ffe098808c2f..7e8e11a0e892 100644
--- a/paddlenlp/trainer/unified_checkpoint/async_handler.py
+++ b/paddlenlp/trainer/unified_checkpoint/async_handler.py
@@ -80,7 +80,8 @@ def _file_save_async_or_sync(
                 if isinstance(state_dict[k], paddle.Tensor):
                     state_dict[k] = state_dict.pop(k).cpu().numpy()
 
-            state_dict = quant_unified_optimizer(state_dict, state_dict_type, ckpt_quant_stage)
+            if state_dict_type == "optimizer_weight" and ckpt_quant_stage != "O0":
+                state_dict = quant_unified_optimizer(state_dict, state_dict_type, ckpt_quant_stage)
             safe_save_file(state_dict, path, metadata={"format": "np"})
         else:
             if len(state_dict.keys()) == 0:
@@ -206,9 +207,10 @@ def _save_file_async_in_process(
             signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00")
             logger.info(f"Start to async save {path}")
            state_dict = _read_state_dict_from_shm(meta_dict, shm)  # numpy array
-            state_dict = quant_unified_optimizer(
-                state_dict, state_dict_type, ckpt_quant_stage, async_save=True
-            )  # ckpt quantization
+            if state_dict_type == "optimizer_weight" and ckpt_quant_stage != "O0":
+                state_dict = quant_unified_optimizer(
+                    state_dict, state_dict_type, ckpt_quant_stage, async_save=True
+                )  # ckpt quantization
             safe_save_file(state_dict, path, {"format": "np"})
             del state_dict
             saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
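
The diff guards both quantization call sites so that `quant_unified_optimizer` only runs when the state dict actually holds optimizer weights and a quantization stage other than `"O0"` (quantization disabled) was requested; model weights and the `"O0"` stage now bypass the call entirely. Below is a minimal standalone sketch of that guard, assuming a hypothetical `maybe_quantize` wrapper and a stub in place of the real `quant_unified_optimizer`; in the actual change the condition lives inline in `_file_save_async_or_sync` and `_save_file_async_in_process`.

```python
from typing import Dict

import numpy as np


def _stub_quant_unified_optimizer(state_dict: Dict[str, np.ndarray], state_dict_type: str, ckpt_quant_stage: str):
    """Hypothetical stand-in for quant_unified_optimizer, used only to exercise the guard."""
    print(f"quantizing {len(state_dict)} tensors at stage {ckpt_quant_stage}")
    return state_dict


def maybe_quantize(state_dict: Dict[str, np.ndarray], state_dict_type: str, ckpt_quant_stage: str):
    # Same condition as the diff: quantization applies only to optimizer weights,
    # and "O0" means quantization is off, so everything else is saved as-is.
    if state_dict_type == "optimizer_weight" and ckpt_quant_stage != "O0":
        return _stub_quant_unified_optimizer(state_dict, state_dict_type, ckpt_quant_stage)
    return state_dict


if __name__ == "__main__":
    weights = {"linear.w_0": np.zeros((2, 2), dtype=np.float32)}
    maybe_quantize(weights, "model_weight", "O1")      # skipped: not optimizer weights
    maybe_quantize(weights, "optimizer_weight", "O0")  # skipped: quantization disabled
    maybe_quantize(weights, "optimizer_weight", "O1")  # quantized
```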