update async handler
DesmonDay committed Oct 25, 2024
1 parent 780040e commit e78fe82
Showing 7 changed files with 265 additions and 226 deletions.
248 changes: 248 additions & 0 deletions paddlenlp/trainer/unified_checkpoint/async_uc_hander.py
@@ -0,0 +1,248 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Asynchronous unified checkpoint handler."""

import multiprocessing
import os
import time
from multiprocessing import shared_memory

import paddle
import paddle.distributed as dist

from paddlenlp.transformers.utils import is_safetensors_available
from paddlenlp.utils.log import logger

if is_safetensors_available():
    from safetensors.numpy import save_file as safe_save_file

from .shared_memory_utils import (
    _read_state_dict_from_shm,
    _traverse_copy_to_shm,
    create_meta_dict,
)


class AsyncCheckpointHander:
    def __init__(self, args):
        # Mainly for asynchronous saving.
        self.args = args
        self.global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1

        self._shm_model_weight = None
        self._shm_master_weight = None
        self._shm_optimizer_weight = None
        self._meta_dict_model = None
        self._meta_dict_master_weight = None
        self._meta_dict_optim = None
        self._process_model_weight = None
        self._process_master_weight = None
        self._process_optimizer_weight = None
        self._lock = None
        self._shared_save_model_flag = None
        self._shared_save_master_weight_flag = None
        self._shared_save_optimizer_flag = None

        if "async_save" in self.args.unified_checkpoint_config:
            self._lock = multiprocessing.Lock()
            self._shared_save_model_path = multiprocessing.Array("c", 100000)
            self._shared_save_model_signal_path = multiprocessing.Array("c", 100000)
            self._shared_save_master_weight_path = multiprocessing.Array("c", 100000)
            self._shared_save_master_weight_signal_path = multiprocessing.Array("c", 100000)
            self._shared_save_optimizer_path = multiprocessing.Array("c", 100000)
            self._shared_save_optimizer_signal_path = multiprocessing.Array("c", 100000)
            self._shared_save_model_flag = multiprocessing.Array("i", 1)
            self._shared_save_master_weight_flag = multiprocessing.Array("i", 1)
            self._shared_save_optimizer_flag = multiprocessing.Array("i", 1)
    def _file_save_async_or_sync(
        self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight"
    ):
        if is_sync:
            for k in list(state_dict.keys()):
                if isinstance(state_dict[k], paddle.Tensor):
                    state_dict[k] = state_dict.pop(k).cpu().numpy()
            safe_save_file(state_dict, path, metadata={"format": "np"})
        else:
            if state_dict_type == "model_weight":
                if self._shm_model_weight is None:
                    self._meta_dict_model, buffer_size = create_meta_dict(state_dict)
                    self._shm_model_weight = shared_memory.SharedMemory(create=True, size=buffer_size)
                shm_state_dict = self._shm_model_weight
                meta_dict = self._meta_dict_model
                shared_save_flag = self._shared_save_model_flag
                shared_save_path = self._shared_save_model_path
                shared_save_signal_path = self._shared_save_model_signal_path
                if self._process_model_weight is None:
                    self._process_model_weight = multiprocessing.Process(
                        target=self._save_file_async_in_process,
                        args=(
                            meta_dict,
                            self._shm_model_weight.name,
                            self._shared_save_model_flag,
                            self._shared_save_model_path,
                            self._shared_save_model_signal_path,
                            self._lock,
                            state_dict_type,
                            self.global_rank,
                        ),
                    )
                    self._process_model_weight.start()
                process = self._process_model_weight
            elif state_dict_type == "master_weight":
                if self._shm_master_weight is None:
                    self._meta_dict_master_weight, buffer_size = create_meta_dict(state_dict)
                    self._shm_master_weight = shared_memory.SharedMemory(create=True, size=buffer_size)
                shm_state_dict = self._shm_master_weight
                meta_dict = self._meta_dict_master_weight
                shared_save_flag = self._shared_save_master_weight_flag
                shared_save_path = self._shared_save_master_weight_path
                shared_save_signal_path = self._shared_save_master_weight_signal_path
                if self._process_master_weight is None:
                    self._process_master_weight = multiprocessing.Process(
                        target=self._save_file_async_in_process,
                        args=(
                            meta_dict,
                            self._shm_master_weight.name,
                            self._shared_save_master_weight_flag,
                            self._shared_save_master_weight_path,
                            self._shared_save_master_weight_signal_path,
                            self._lock,
                            "model_weight"
                            if "skip_save_model_weight" in self.args.unified_checkpoint_config
                            else state_dict_type,
                            self.global_rank,
                        ),
                    )
                    self._process_master_weight.start()
                process = self._process_master_weight
            elif state_dict_type == "optimizer_weight":
                if self._shm_optimizer_weight is None:
                    self._meta_dict_optim, buffer_size = create_meta_dict(state_dict)
                    self._shm_optimizer_weight = shared_memory.SharedMemory(create=True, size=buffer_size)
                shm_state_dict = self._shm_optimizer_weight
                meta_dict = self._meta_dict_optim
                shared_save_flag = self._shared_save_optimizer_flag
                shared_save_path = self._shared_save_optimizer_path
                shared_save_signal_path = self._shared_save_optimizer_signal_path
                if self._process_optimizer_weight is None:
                    self._process_optimizer_weight = multiprocessing.Process(
                        target=self._save_file_async_in_process,
                        args=(
                            meta_dict,
                            self._shm_optimizer_weight.name,
                            self._shared_save_optimizer_flag,
                            self._shared_save_optimizer_path,
                            self._shared_save_optimizer_signal_path,
                            self._lock,
                            state_dict_type,
                            self.global_rank,
                        ),
                    )
                    self._process_optimizer_weight.start()
                process = self._process_optimizer_weight

            while True:  # wait until no process is saving.
                flag_value = shared_save_flag[0]
                if flag_value == 0:
                    break
                if not process.is_alive():
                    raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.")
                time.sleep(0.5)
                logger.info(f"Wait for the previous save process to finish saving {state_dict_type}")
            # only save model weight or save master weight, we enter this loop.
            self._reset_and_update(shared_save_path, path)
            self._reset_and_update(shared_save_signal_path, signal_path)
            _traverse_copy_to_shm(state_dict, meta_dict, shm_state_dict.buf)
            with self._lock:
                shared_save_flag[0] = 1
    def _save_file_async_in_process(
        self,
        meta_dict,
        shm_name,
        shared_save_flag,
        shared_save_path,
        shared_save_signal_path,
        lock,
        state_dict_type,
        global_rank,
    ):
        shm = shared_memory.SharedMemory(name=shm_name)
        while True:
            flag_value = shared_save_flag[0]  # if process uses `spawn`, cannot read this value.
            if flag_value == -1:  # stop process
                break
            if flag_value == 0:  # nothing to save
                continue
            if flag_value == 1:  # need to save
                path = shared_save_path[:].decode("utf-8").rstrip("\x00")
                signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00")
                logger.info(f"Start to async save {path}")
                state_dict = _read_state_dict_from_shm(meta_dict, shm)  # numpy array
                safe_save_file(state_dict, path, {"format": "np"})
                del state_dict
                saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
                paddle.save(global_rank, saved_signal_path)
                with lock:
                    shared_save_flag[0] = 0
            time.sleep(0.5)
        shm.close()
    def _reset_and_update(self, shared_array, new_value):
        # clear array
        for i in range(len(shared_array)):
            shared_array[i] = b"\0"
        # update array
        encoded_value = new_value.encode("utf-8")
        shared_array[: len(encoded_value)] = encoded_value
    def unlink_shared_memory(self):
        if not ("async_save" in self.args.unified_checkpoint_config):
            return

        if self._shared_save_model_flag is not None:
            while self._shared_save_model_flag[0] > 0:  # async process is saving
                if not self._process_model_weight.is_alive():
                    raise RuntimeError("The process that saves model_weight has been killed unexpectedly.")
                time.sleep(0.5)
            self._shared_save_model_flag[0] = -1
        if self._shared_save_master_weight_flag is not None:
            while self._shared_save_master_weight_flag[0] > 0:
                if not self._process_master_weight.is_alive():
                    raise RuntimeError("The process that saves master_weight has been killed unexpectedly.")
                time.sleep(0.5)
            self._shared_save_master_weight_flag[0] = -1
        if self._shared_save_optimizer_flag is not None:
            while self._shared_save_optimizer_flag[0] > 0:
                if not self._process_optimizer_weight.is_alive():
                    raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.")
                time.sleep(0.5)
            self._shared_save_optimizer_flag[0] = -1

        if self._shm_model_weight is not None:
            self._shm_model_weight.close()
            self._shm_model_weight.unlink()
            self._shm_model_weight = None
        if self._shm_master_weight is not None:
            self._shm_master_weight.close()
            self._shm_master_weight.unlink()
            self._shm_master_weight = None
        if self._shm_optimizer_weight is not None:
            self._shm_optimizer_weight.close()
            self._shm_optimizer_weight.unlink()
            self._shm_optimizer_weight = None

        if paddle.distributed.get_world_size() > 1:
            dist.barrier()
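
Note (not part of the commit): the new handler coordinates the trainer and a background save process through a shared integer flag (1 = a save request is pending, 0 = idle, -1 = shut down) plus a shared char array carrying the output path. The following minimal sketch illustrates only that handshake under those assumptions; every name in it (_worker, path_array, the example path) is hypothetical and not taken from the PaddleNLP API.

# Minimal standalone sketch of the flag handshake used by AsyncCheckpointHander.
# All names are illustrative; only the polling protocol mirrors the real class.
import multiprocessing
import time


def _worker(flag, path_array, lock):
    while True:
        value = flag[0]
        if value == -1:  # parent asked the worker to exit
            break
        if value == 1:  # a save request is pending
            path = path_array[:].decode("utf-8").rstrip("\x00")
            print(f"pretend to save checkpoint to {path}")
            with lock:
                flag[0] = 0  # mark the request as handled
        time.sleep(0.1)


if __name__ == "__main__":
    lock = multiprocessing.Lock()
    flag = multiprocessing.Array("i", 1)  # 0 = idle, 1 = save requested, -1 = stop
    path_array = multiprocessing.Array("c", 256)  # shared buffer for the target path

    worker = multiprocessing.Process(target=_worker, args=(flag, path_array, lock))
    worker.start()

    encoded = b"/tmp/model-00001.safetensors"  # hypothetical output path
    path_array[: len(encoded)] = encoded
    with lock:
        flag[0] = 1  # request a save

    while flag[0] != 0:  # wait until the worker has finished
        time.sleep(0.1)
    flag[0] = -1  # tell the worker to stop
    worker.join()

In the real class the tensor payload is additionally passed through multiprocessing.shared_memory and the worker writes a ".done" sentinel file, but the control flow is the same polling loop shown here.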
2 changes: 2 additions & 0 deletions paddlenlp/trainer/unified_checkpoint/check_uc.py
@@ -42,6 +42,8 @@
    update_master_weight_status,
)

__all__ = ["check_unified_checkpoint", "check_unified_optimizer"]


def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serialization=False):
index_filename = select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=False)
4 changes: 3 additions & 1 deletion paddlenlp/trainer/unified_checkpoint/uc_dynamic.py
@@ -55,6 +55,8 @@
    update_master_weight_status,
)

__all__ = ["load_unified_checkpoint_dynamically", "load_unified_optimizer_dynamically"]


def create_send_table(file_keyname_mappings, file_machine_mappings):
    send_table = {}
@@ -258,7 +260,7 @@ def distributed_send_recv(
return state_dict


def load_uc_dynamically(args, model, resume_from_checkpoint, safe_serialization=False):
def load_unified_checkpoint_dynamically(args, model, resume_from_checkpoint, safe_serialization=False):
    index_filename = select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=False)
    index_filename = os.path.join(resume_from_checkpoint, index_filename)

2 changes: 2 additions & 0 deletions paddlenlp/trainer/unified_checkpoint/uc_locally_load.py
@@ -51,6 +51,8 @@
    update_master_weight_status,
)

__all__ = ["load_unified_checkpoint_locally", "load_unified_optimizer_locally"]


def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, safe_serialization=False):
    """
2 changes: 2 additions & 0 deletions paddlenlp/trainer/unified_checkpoint/uc_sharding_v2.py
@@ -37,6 +37,8 @@
    mapping_optimizer_tp_actions,
)

__all__ = ["gather_splited_param_for_optimizer", "load_unified_optimizer_split_param"]


def merge_splited_param(
    state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, is_master_weights=False
7 changes: 7 additions & 0 deletions paddlenlp/trainer/unified_checkpoint/uc_single_card.py
@@ -54,6 +54,13 @@
    save_model_config,
)

__all__ = [
    "load_single_card_checkpoint",
    "load_single_card_optimizer",
    "save_single_card_checkpoint",
    "save_single_card_optimizer",
]


def save_file_sync(state_dict, path):
    for k in list(state_dict.keys()):