From b219ba6fcf60cb491143df5bb94ad0119b28786b Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Fri, 25 Oct 2024 17:41:26 +0800 Subject: [PATCH] update async handler --- .../unified_checkpoint/async_uc_hander.py | 250 ++++++++++++++++++ .../trainer/unified_checkpoint/check_uc.py | 6 +- .../trainer/unified_checkpoint/uc_dynamic.py | 4 +- .../unified_checkpoint/uc_locally_load.py | 2 + .../unified_checkpoint/uc_sharding_v2.py | 2 + .../unified_checkpoint/uc_single_card.py | 7 + .../unified_checkpoint/unified_checkpoint.py | 226 +--------------- 7 files changed, 270 insertions(+), 227 deletions(-) create mode 100644 paddlenlp/trainer/unified_checkpoint/async_uc_hander.py diff --git a/paddlenlp/trainer/unified_checkpoint/async_uc_hander.py b/paddlenlp/trainer/unified_checkpoint/async_uc_hander.py new file mode 100644 index 000000000000..d57386e9591a --- /dev/null +++ b/paddlenlp/trainer/unified_checkpoint/async_uc_hander.py @@ -0,0 +1,250 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Asynchronous unified checkpoint handler.""" + +import multiprocessing +import os +import time +from multiprocessing import shared_memory + +import paddle +import paddle.distributed as dist + +from paddlenlp.transformers.utils import is_safetensors_available +from paddlenlp.utils.log import logger + +if is_safetensors_available(): + from safetensors.numpy import save_file as safe_save_file + +from .shared_memory_utils import ( + _read_state_dict_from_shm, + _traverse_copy_to_shm, + create_meta_dict, +) + +__all__ = ["AsyncCheckpointHander"] + + +class AsyncCheckpointHander: + def __init__(self, args): + # Mainly for asynchronous saving. + self.args = args + self.global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 + + self._shm_model_weight = None + self._shm_master_weight = None + self._shm_optimizer_weight = None + self._meta_dict_model = None + self._meta_dict_master_weight = None + self._meta_dict_optim = None + self._process_model_weight = None + self._process_master_weight = None + self._process_optimizer_weight = None + self._lock = None + self._shared_save_model_flag = None + self._shared_save_master_weight_flag = None + self._shared_save_optimizer_flag = None + + if "async_save" in self.args.unified_checkpoint_config: + self._lock = multiprocessing.Lock() + self._shared_save_model_path = multiprocessing.Array("c", 100000) + self._shared_save_model_signal_path = multiprocessing.Array("c", 100000) + self._shared_save_master_weight_path = multiprocessing.Array("c", 100000) + self._shared_save_master_weight_signal_path = multiprocessing.Array("c", 100000) + self._shared_save_optimizer_path = multiprocessing.Array("c", 100000) + self._shared_save_optimizer_signal_path = multiprocessing.Array("c", 100000) + self._shared_save_model_flag = multiprocessing.Array("i", 1) + self._shared_save_master_weight_flag = multiprocessing.Array("i", 1) + self._shared_save_optimizer_flag = multiprocessing.Array("i", 1) + + def _file_save_async_or_sync( + self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight" + ): + if is_sync: + for k in list(state_dict.keys()): + if isinstance(state_dict[k], paddle.Tensor): + state_dict[k] = state_dict.pop(k).cpu().numpy() + safe_save_file(state_dict, path, metadata={"format": "np"}) + else: + if state_dict_type == "model_weight": + if self._shm_model_weight is None: + self._meta_dict_model, buffer_size = create_meta_dict(state_dict) + self._shm_model_weight = shared_memory.SharedMemory(create=True, size=buffer_size) + shm_state_dict = self._shm_model_weight + meta_dict = self._meta_dict_model + shared_save_flag = self._shared_save_model_flag + shared_save_path = self._shared_save_model_path + shared_save_signal_path = self._shared_save_model_signal_path + if self._process_model_weight is None: + self._process_model_weight = multiprocessing.Process( + target=self._save_file_async_in_process, + args=( + meta_dict, + self._shm_model_weight.name, + self._shared_save_model_flag, + self._shared_save_model_path, + self._shared_save_model_signal_path, + self._lock, + state_dict_type, + self.global_rank, + ), + ) + self._process_model_weight.start() + process = self._process_model_weight + elif state_dict_type == "master_weight": + if self._shm_master_weight is None: + self._meta_dict_master_weight, buffer_size = create_meta_dict(state_dict) + self._shm_master_weight = shared_memory.SharedMemory(create=True, size=buffer_size) + shm_state_dict = self._shm_master_weight + meta_dict = self._meta_dict_master_weight + shared_save_flag = self._shared_save_master_weight_flag + shared_save_path = self._shared_save_master_weight_path + shared_save_signal_path = self._shared_save_master_weight_signal_path + if self._process_master_weight is None: + self._process_master_weight = multiprocessing.Process( + target=self._save_file_async_in_process, + args=( + meta_dict, + self._shm_master_weight.name, + self._shared_save_master_weight_flag, + self._shared_save_master_weight_path, + self._shared_save_master_weight_signal_path, + self._lock, + "model_weight" + if "skip_save_model_weight" in self.args.unified_checkpoint_config + else state_dict_type, + self.global_rank, + ), + ) + self._process_master_weight.start() + process = self._process_master_weight + elif state_dict_type == "optimizer_weight": + if self._shm_optimizer_weight is None: + self._meta_dict_optim, buffer_size = create_meta_dict(state_dict) + self._shm_optimizer_weight = shared_memory.SharedMemory(create=True, size=buffer_size) + shm_state_dict = self._shm_optimizer_weight + meta_dict = self._meta_dict_optim + shared_save_flag = self._shared_save_optimizer_flag + shared_save_path = self._shared_save_optimizer_path + shared_save_signal_path = self._shared_save_optimizer_signal_path + if self._process_optimizer_weight is None: + self._process_optimizer_weight = multiprocessing.Process( + target=self._save_file_async_in_process, + args=( + meta_dict, + self._shm_optimizer_weight.name, + self._shared_save_optimizer_flag, + self._shared_save_optimizer_path, + self._shared_save_optimizer_signal_path, + self._lock, + state_dict_type, + self.global_rank, + ), + ) + self._process_optimizer_weight.start() + process = self._process_optimizer_weight + + while True: # wait until no process is saving. + flag_value = shared_save_flag[0] + if flag_value == 0: + break + if not process.is_alive(): + raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.") + time.sleep(0.5) + logger.info(f"Wait for the previous save process to finish saving {state_dict_type}") + # only save model weight or save master weight, we enter this loop. + self._reset_and_update(shared_save_path, path) + self._reset_and_update(shared_save_signal_path, signal_path) + _traverse_copy_to_shm(state_dict, meta_dict, shm_state_dict.buf) + with self._lock: + shared_save_flag[0] = 1 + + def _save_file_async_in_process( + self, + meta_dict, + shm_name, + shared_save_flag, + shared_save_path, + shared_save_signal_path, + lock, + state_dict_type, + global_rank, + ): + shm = shared_memory.SharedMemory(name=shm_name) + while True: + flag_value = shared_save_flag[0] # if process uses `spawn`, cannot read this value. + if flag_value == -1: # stop process + break + if flag_value == 0: # nothing to save + continue + if flag_value == 1: # need to save + path = shared_save_path[:].decode("utf-8").rstrip("\x00") + signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00") + logger.info(f"Start to async save {path}") + state_dict = _read_state_dict_from_shm(meta_dict, shm) # numpy array + safe_save_file(state_dict, path, {"format": "np"}) + del state_dict + saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}") + paddle.save(global_rank, saved_signal_path) + with lock: + shared_save_flag[0] = 0 + time.sleep(0.5) + shm.close() + + def _reset_and_update(self, shared_array, new_value): + # clear array + for i in range(len(shared_array)): + shared_array[i] = b"\0" + # update array + encoded_value = new_value.encode("utf-8") + shared_array[: len(encoded_value)] = encoded_value + + def unlink_shared_memory(self): + if not ("async_save" in self.args.unified_checkpoint_config): + return + + if self._shared_save_model_flag is not None: + while self._shared_save_model_flag[0] > 0: # async process is saving + if not self._process_model_weight.is_alive(): + raise RuntimeError("The process that saves model_weight has been killed unexpectedly.") + time.sleep(0.5) + self._shared_save_model_flag[0] = -1 + if self._shared_save_master_weight_flag is not None: + while self._shared_save_master_weight_flag[0] > 0: + if not self._process_master_weight.is_alive(): + raise RuntimeError("The process that saves master_weight has been killed unexpectedly.") + time.sleep(0.5) + self._shared_save_master_weight_flag[0] = -1 + if self._shared_save_optimizer_flag is not None: + while self._shared_save_optimizer_flag[0] > 0: + if not self._process_optimizer_weight.is_alive(): + raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.") + time.sleep(0.5) + self._shared_save_optimizer_flag[0] = -1 + + if self._shm_model_weight is not None: + self._shm_model_weight.close() + self._shm_model_weight.unlink() + self._shm_model_weight = None + if self._shm_master_weight is not None: + self._shm_master_weight.close() + self._shm_master_weight.unlink() + self._shm_master_weight = None + if self._shm_optimizer_weight is not None: + self._shm_optimizer_weight.close() + self._shm_optimizer_weight.unlink() + self._shm_optimizer_weight = None + + if paddle.distributed.get_world_size() > 1: + dist.barrier() diff --git a/paddlenlp/trainer/unified_checkpoint/check_uc.py b/paddlenlp/trainer/unified_checkpoint/check_uc.py index f03bc14e46e4..287abe01f9c0 100644 --- a/paddlenlp/trainer/unified_checkpoint/check_uc.py +++ b/paddlenlp/trainer/unified_checkpoint/check_uc.py @@ -42,6 +42,8 @@ update_master_weight_status, ) +__all__ = ["check_unified_checkpoint", "check_unified_optimizer"] + def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serialization=False): index_filename = select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=False) @@ -102,7 +104,7 @@ def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serializa else: local_resume = False local_resume = paddle.to_tensor([local_resume]) - dist.all_reduce(local_resume, op=dist.ReduceOp.PROD) + dist.all_reduce(local_resume, op=dist.ReduceOp.MIN) local_resume = local_resume.item() return local_resume @@ -226,7 +228,7 @@ def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False, else: local_resume = False local_resume = paddle.to_tensor([local_resume]) - dist.all_reduce(local_resume, op=dist.ReduceOp.PROD) + dist.all_reduce(local_resume, op=dist.ReduceOp.MIN) return local_resume.item() # check whether the optimizer checkpoint files are complete. diff --git a/paddlenlp/trainer/unified_checkpoint/uc_dynamic.py b/paddlenlp/trainer/unified_checkpoint/uc_dynamic.py index 05090189cd47..ee6cbb12dab6 100644 --- a/paddlenlp/trainer/unified_checkpoint/uc_dynamic.py +++ b/paddlenlp/trainer/unified_checkpoint/uc_dynamic.py @@ -55,6 +55,8 @@ update_master_weight_status, ) +__all__ = ["load_unified_checkpoint_dynamically", "load_unified_optimizer_dynamically"] + def create_send_table(file_keyname_mappings, file_machine_mappings): send_table = {} @@ -258,7 +260,7 @@ def distributed_send_recv( return state_dict -def load_uc_dynamically(args, model, resume_from_checkpoint, safe_serialization=False): +def load_unified_checkpoint_dynamically(args, model, resume_from_checkpoint, safe_serialization=False): index_filename = select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=False) index_filename = os.path.join(resume_from_checkpoint, index_filename) diff --git a/paddlenlp/trainer/unified_checkpoint/uc_locally_load.py b/paddlenlp/trainer/unified_checkpoint/uc_locally_load.py index abd4442dcb84..93ac3b1ae735 100644 --- a/paddlenlp/trainer/unified_checkpoint/uc_locally_load.py +++ b/paddlenlp/trainer/unified_checkpoint/uc_locally_load.py @@ -51,6 +51,8 @@ update_master_weight_status, ) +__all__ = ["load_unified_checkpoint_locally", "load_unified_optimizer_locally"] + def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, safe_serialization=False): """ diff --git a/paddlenlp/trainer/unified_checkpoint/uc_sharding_v2.py b/paddlenlp/trainer/unified_checkpoint/uc_sharding_v2.py index ab3f3a7f27b0..1c1602b46133 100644 --- a/paddlenlp/trainer/unified_checkpoint/uc_sharding_v2.py +++ b/paddlenlp/trainer/unified_checkpoint/uc_sharding_v2.py @@ -37,6 +37,8 @@ mapping_optimizer_tp_actions, ) +__all__ = ["gather_splited_param_for_optimizer", "load_unified_optimizer_split_param"] + def merge_splited_param( state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, is_master_weights=False diff --git a/paddlenlp/trainer/unified_checkpoint/uc_single_card.py b/paddlenlp/trainer/unified_checkpoint/uc_single_card.py index 84657c379419..cd16a3866a22 100644 --- a/paddlenlp/trainer/unified_checkpoint/uc_single_card.py +++ b/paddlenlp/trainer/unified_checkpoint/uc_single_card.py @@ -54,6 +54,13 @@ save_model_config, ) +__all__ = [ + "load_single_card_checkpoint", + "load_single_card_optimizer", + "save_single_card_checkpoint", + "save_single_card_optimizer", +] + def save_file_sync(state_dict, path): for k in list(state_dict.keys()): diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py index 7a229070201a..72b17c33c44c 100644 --- a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py +++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py @@ -14,14 +14,10 @@ import copy import json -import multiprocessing import os import sys -import time -from multiprocessing import shared_memory import paddle -import paddle.distributed as dist from paddle.distributed import fleet try: @@ -61,19 +57,13 @@ from paddlenlp.utils.nested import nested_copy if is_safetensors_available(): - from safetensors.numpy import save_file as safe_save_file - if sys.platform.startswith("win"): from safetensors.numpy import load_file else: from paddlenlp.utils.safetensors import fast_load_file as load_file +from .async_uc_hander import AsyncCheckpointHander from .check_uc import check_unified_checkpoint, check_unified_optimizer -from .shared_memory_utils import ( - _read_state_dict_from_shm, - _traverse_copy_to_shm, - create_meta_dict, -) from .uc_dynamic import ( load_unified_checkpoint_dynamically, load_unified_optimizer_dynamically, @@ -107,219 +97,7 @@ save_model_config, ) - -class AsyncCheckpointHander: - def __init__(self, args): - # Mainly for asynchronous saving. - self.args = args - self.global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 - - self._shm_model_weight = None - self._shm_master_weight = None - self._shm_optimizer_weight = None - self._meta_dict_model = None - self._meta_dict_master_weight = None - self._meta_dict_optim = None - self._process_model_weight = None - self._process_master_weight = None - self._process_optimizer_weight = None - self._lock = None - self._shared_save_model_flag = None - self._shared_save_master_weight_flag = None - self._shared_save_optimizer_flag = None - - if "async_save" in self.args.unified_checkpoint_config: - self._lock = multiprocessing.Lock() - self._shared_save_model_path = multiprocessing.Array("c", 100000) - self._shared_save_model_signal_path = multiprocessing.Array("c", 100000) - self._shared_save_master_weight_path = multiprocessing.Array("c", 100000) - self._shared_save_master_weight_signal_path = multiprocessing.Array("c", 100000) - self._shared_save_optimizer_path = multiprocessing.Array("c", 100000) - self._shared_save_optimizer_signal_path = multiprocessing.Array("c", 100000) - self._shared_save_model_flag = multiprocessing.Array("i", 1) - self._shared_save_master_weight_flag = multiprocessing.Array("i", 1) - self._shared_save_optimizer_flag = multiprocessing.Array("i", 1) - - def _file_save_async_or_sync( - self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight" - ): - if is_sync: - for k in list(state_dict.keys()): - if isinstance(state_dict[k], paddle.Tensor): - state_dict[k] = state_dict.pop(k).cpu().numpy() - safe_save_file(state_dict, path, metadata={"format": "np"}) - else: - if state_dict_type == "model_weight": - if self._shm_model_weight is None: - self._meta_dict_model, buffer_size = create_meta_dict(state_dict) - self._shm_model_weight = shared_memory.SharedMemory(create=True, size=buffer_size) - shm_state_dict = self._shm_model_weight - meta_dict = self._meta_dict_model - shared_save_flag = self._shared_save_model_flag - shared_save_path = self._shared_save_model_path - shared_save_signal_path = self._shared_save_model_signal_path - if self._process_model_weight is None: - self._process_model_weight = multiprocessing.Process( - target=self._save_file_async_in_process, - args=( - meta_dict, - self._shm_model_weight.name, - self._shared_save_model_flag, - self._shared_save_model_path, - self._shared_save_model_signal_path, - self._lock, - state_dict_type, - self.global_rank, - ), - ) - self._process_model_weight.start() - process = self._process_model_weight - elif state_dict_type == "master_weight": - if self._shm_master_weight is None: - self._meta_dict_master_weight, buffer_size = create_meta_dict(state_dict) - self._shm_master_weight = shared_memory.SharedMemory(create=True, size=buffer_size) - shm_state_dict = self._shm_master_weight - meta_dict = self._meta_dict_master_weight - shared_save_flag = self._shared_save_master_weight_flag - shared_save_path = self._shared_save_master_weight_path - shared_save_signal_path = self._shared_save_master_weight_signal_path - if self._process_master_weight is None: - self._process_master_weight = multiprocessing.Process( - target=self._save_file_async_in_process, - args=( - meta_dict, - self._shm_master_weight.name, - self._shared_save_master_weight_flag, - self._shared_save_master_weight_path, - self._shared_save_master_weight_signal_path, - self._lock, - "model_weight" - if "skip_save_model_weight" in self.args.unified_checkpoint_config - else state_dict_type, - self.global_rank, - ), - ) - self._process_master_weight.start() - process = self._process_master_weight - elif state_dict_type == "optimizer_weight": - if self._shm_optimizer_weight is None: - self._meta_dict_optim, buffer_size = create_meta_dict(state_dict) - self._shm_optimizer_weight = shared_memory.SharedMemory(create=True, size=buffer_size) - shm_state_dict = self._shm_optimizer_weight - meta_dict = self._meta_dict_optim - shared_save_flag = self._shared_save_optimizer_flag - shared_save_path = self._shared_save_optimizer_path - shared_save_signal_path = self._shared_save_optimizer_signal_path - if self._process_optimizer_weight is None: - self._process_optimizer_weight = multiprocessing.Process( - target=self._save_file_async_in_process, - args=( - meta_dict, - self._shm_optimizer_weight.name, - self._shared_save_optimizer_flag, - self._shared_save_optimizer_path, - self._shared_save_optimizer_signal_path, - self._lock, - state_dict_type, - self.global_rank, - ), - ) - self._process_optimizer_weight.start() - process = self._process_optimizer_weight - - while True: # wait until no process is saving. - flag_value = shared_save_flag[0] - if flag_value == 0: - break - if not process.is_alive(): - raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.") - time.sleep(0.5) - logger.info(f"Wait for the previous save process to finish saving {state_dict_type}") - # only save model weight or save master weight, we enter this loop. - self._reset_and_update(shared_save_path, path) - self._reset_and_update(shared_save_signal_path, signal_path) - _traverse_copy_to_shm(state_dict, meta_dict, shm_state_dict.buf) - with self._lock: - shared_save_flag[0] = 1 - - def _save_file_async_in_process( - self, - meta_dict, - shm_name, - shared_save_flag, - shared_save_path, - shared_save_signal_path, - lock, - state_dict_type, - global_rank, - ): - shm = shared_memory.SharedMemory(name=shm_name) - while True: - flag_value = shared_save_flag[0] # if process uses `spawn`, cannot read this value. - if flag_value == -1: # stop process - break - if flag_value == 0: # nothing to save - continue - if flag_value == 1: # need to save - path = shared_save_path[:].decode("utf-8").rstrip("\x00") - signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00") - logger.info(f"Start to async save {path}") - state_dict = _read_state_dict_from_shm(meta_dict, shm) # numpy array - safe_save_file(state_dict, path, {"format": "np"}) - del state_dict - saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}") - paddle.save(global_rank, saved_signal_path) - with lock: - shared_save_flag[0] = 0 - time.sleep(0.5) - shm.close() - - def _reset_and_update(self, shared_array, new_value): - # clear array - for i in range(len(shared_array)): - shared_array[i] = b"\0" - # update array - encoded_value = new_value.encode("utf-8") - shared_array[: len(encoded_value)] = encoded_value - - def unlink_shared_memory(self): - if not ("async_save" in self.args.unified_checkpoint_config): - return - - if self._shared_save_model_flag is not None: - while self._shared_save_model_flag[0] > 0: # async process is saving - if not self._process_model_weight.is_alive(): - raise RuntimeError("The process that saves model_weight has been killed unexpectedly.") - time.sleep(0.5) - self._shared_save_model_flag[0] = -1 - if self._shared_save_master_weight_flag is not None: - while self._shared_save_master_weight_flag[0] > 0: - if not self._process_master_weight.is_alive(): - raise RuntimeError("The process that saves master_weight has been killed unexpectedly.") - time.sleep(0.5) - self._shared_save_master_weight_flag[0] = -1 - if self._shared_save_optimizer_flag is not None: - while self._shared_save_optimizer_flag[0] > 0: - if not self._process_optimizer_weight.is_alive(): - raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.") - time.sleep(0.5) - self._shared_save_optimizer_flag[0] = -1 - - if self._shm_model_weight is not None: - self._shm_model_weight.close() - self._shm_model_weight.unlink() - self._shm_model_weight = None - if self._shm_master_weight is not None: - self._shm_master_weight.close() - self._shm_master_weight.unlink() - self._shm_master_weight = None - if self._shm_optimizer_weight is not None: - self._shm_optimizer_weight.close() - self._shm_optimizer_weight.unlink() - self._shm_optimizer_weight = None - - if paddle.distributed.get_world_size() > 1: - dist.barrier() +__all__ = ["UnifiedCheckpointHandler"] class UnifiedCheckpointHandler: