From 702de2f635284437cba574ff1faca340c9806847 Mon Sep 17 00:00:00 2001
From: Will Johnson
Date: Wed, 11 Dec 2024 00:55:32 -0500
Subject: [PATCH] feat: Checkpoint utils safetensors (#116)

* checkpoint conversion handle non-dcp case

Signed-off-by: Yu Chin Fabian Lim

* improvements

Signed-off-by: Yu Chin Fabian Lim

* fix: sharded safetensors save

Signed-off-by: Will Johnson

* fix: lint

Signed-off-by: Will Johnson

* fmt

Signed-off-by: Will Johnson

---------

Signed-off-by: Yu Chin Fabian Lim
Signed-off-by: Will Johnson
Co-authored-by: Yu Chin Fabian Lim
---
 .../utils/checkpoint_utils.py                 | 167 +++++++++++++-----
 1 file changed, 124 insertions(+), 43 deletions(-)

diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
index d8d33b18..fb8ab1bc 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
@@ -14,14 +14,17 @@

 # Standard
 from collections import defaultdict
-from typing import List
+from typing import Dict, List, Union
 import json
 import os
 import re
+import shutil

 # Third Party
 from accelerate.logging import get_logger
 from accelerate.utils.constants import FSDP_MODEL_NAME, OPTIMIZER_NAME
+from huggingface_hub import split_torch_state_dict_into_shards
+from safetensors.torch import load_file, save_file
 from torch.distributed.checkpoint.default_planner import (
     DefaultLoadPlanner,
     DefaultSavePlanner,
@@ -29,6 +32,7 @@
 from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 from transformers import PretrainedConfig
+from transformers.utils import CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
 import torch
 import torch.distributed.checkpoint as dcp

@@ -213,24 +217,10 @@ def _dict_from_json_file(resolved_config_file):
     return os.path.dirname(result)


-# function to get the ScatterMoE state dict from its DCP checkpoint
-# - if the original pretrained_model_name_or_path is specified, will use the checkpoint as hints
-#   to map the ScatterMoE checkpoint to that of the original model. This is useful so that we
-#   can restore the checkpoint to be loaded by the original architecture.
-def recover_original_state_dict_from_dcp_checkpoint(
+# function to get the state dict from dcp_checkpoint
+def get_state_dict_from_dcp_checkpoint(
     dcp_checkpoint_dir: str,
-    pretrained_model_name_or_path: str = None,
 ):
-    """
-    Parameters:
-        dcp_checkpoint_dir (str): the DCP to be converted.
-        pretrained_model_name_or_path (str): Optional, if provided we will
-            use the hints to remap the
-    """
-
-    # reference dcp_to_torch_save from torch.distributed.checkpoint.format_utils.py
-    # - strategy is to use _EmptyStateDictLoadPlanner to populate the state dict, then we remap
-
     # guarded, load some internal functions
     # pylint: disable=import-outside-toplevel
     # Third Party
@@ -245,11 +235,46 @@ def recover_original_state_dict_from_dcp_checkpoint(
         planner=_EmptyStateDictLoadPlanner(),
         no_dist=True,
     )
-    sd = sd[KEY_MODEL]
+    return sd[KEY_MODEL]
+
+
+# function to get state dict from regular checkpoint
+# - note this assumes sharded safetensors, we do not support
+#   the non-sharded case for now
+def get_state_dict_from_safe_checkpoint(
+    safe_checkpoint_dir: str,
+):
+    # Load the index
+    safe_index_file = os.path.join(safe_checkpoint_dir, SAFE_WEIGHTS_INDEX_NAME)
+    with open(safe_index_file, "r", encoding="utf-8") as f:
+        index = json.load(f)
+
+    sd = {}
+    shard_files = list(set(index["weight_map"].values()))
+    for shard_file in shard_files:
+        for key, v in load_file(os.path.join(safe_checkpoint_dir, shard_file)).items():
+            sd[key] = v
+
+    return sd

-    # if not provided
-    if pretrained_model_name_or_path is None:
-        return sd
+
+# function to get the ScatterMoE state dict from its DCP checkpoint
+# - if the original pretrained_model_name_or_path is specified, will use the checkpoint as hints
+#   to map the ScatterMoE checkpoint to that of the original model. This is useful so that we
+#   can restore the checkpoint to be loaded by the original architecture.
+def recover_original_state_dict_from_checkpoint(
+    sd: Dict,
+    pretrained_model_name_or_path: str = None,
+):
+    """
+    Parameters:
+        sd (Dict): the ScatterMoE state dict to be converted.
+        pretrained_model_name_or_path (str): Optional, if provided we will
+            use the hints to remap the state dict to the original architecture.
+    """
+
+    # reference dcp_to_torch_save from torch.distributed.checkpoint.format_utils.py
+    # - strategy is to use _EmptyStateDictLoadPlanner to populate the state dict, then we remap

     # now do the remap
     loc = get_resolved_checkpoint_location(pretrained_model_name_or_path)
@@ -398,6 +423,37 @@ def _infer_prefixes_and_module_names(
     return sd


+def save_sharded_safetensors(
+    input_state_dict: Dict,
+    save_directory: str,
+    metadata: Dict,
+    max_shard_size: Union[int, str] = "5GB",
+):
+    filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
+        ".safetensors", "{suffix}.safetensors"
+    )
+    state_dict_split = split_torch_state_dict_into_shards(
+        input_state_dict,
+        filename_pattern=filename_pattern,
+        max_shard_size=max_shard_size,
+    )
+    index = {
+        "metadata": state_dict_split.metadata,
+        "weight_map": state_dict_split.tensor_to_filename,
+    }
+    # Save the index
+    with open(
+        os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8"
+    ) as f:
+        content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+        f.write(content)
+
+    filename_to_tensors = state_dict_split.filename_to_tensors.items()
+    for shard_file, tensors in filename_to_tensors:
+        shard = {tensor: input_state_dict[tensor].contiguous() for tensor in tensors}
+        save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
+
+
 # --------------------------- SCRIPT -------------------------


@@ -417,8 +473,8 @@ def _infer_prefixes_and_module_names(
     )

     parser.add_argument(
-        "dcp_checkpoint_dir",
-        help="Path to the distributed checkpoint.",
+        "checkpoint_dir",
+        help="Path to the checkpoint.",
     )

     parser.add_argument(
@@ -432,37 +488,62 @@ def _infer_prefixes_and_module_names(
             "the original pretrained model checkpoint (from which this "
             "checkpoint is obtained)."
         ),
+        default=None,
     )

     args = parser.parse_args()

-    # search for the checkpint. By the code above, it must
+    # search for an FSDP checkpoint. If it is an FSDP checkpoint, it must
     # start with FSDP_MODEL_NAME
-    if args.dcp_checkpoint_dir.startswith(FSDP_MODEL_NAME):
-        checkpoint_dir = args.dcp_checkpoint_dir
+    if args.checkpoint_dir.startswith(FSDP_MODEL_NAME):
+        checkpoint_dir = args.checkpoint_dir
+        loader = get_state_dict_from_dcp_checkpoint
     else:
         checkpoint_dir = [
             x
-            for x in os.listdir(args.dcp_checkpoint_dir)
-            if os.path.isdir(os.path.join(args.dcp_checkpoint_dir, x))
+            for x in os.listdir(args.checkpoint_dir)
+            if os.path.isdir(os.path.join(args.checkpoint_dir, x))
             and x.startswith(FSDP_MODEL_NAME)
         ]
-        if len(checkpoint_dir) > 1:
+        if len(checkpoint_dir) == 1:
+            checkpoint_dir = os.path.join(args.checkpoint_dir, checkpoint_dir[0])
+            loader = get_state_dict_from_dcp_checkpoint
+        elif len(checkpoint_dir) > 1:
             raise ValueError(
-                f"Found > 1 dirs in dcp checkpoint dir {args.dcp_checkpoint_dir} "
+                f"Found > 1 dirs in dcp checkpoint dir {args.checkpoint_dir} "
                 f"that starts with {FSDP_MODEL_NAME}. Please spectify the exact dir."
             )
-        if len(checkpoint_dir) == 0:
-            raise ValueError(
-                f"Found no dirs in dcp checkpoint dir {args.dcp_checkpoint_dir} "
-                f"that starts with {FSDP_MODEL_NAME}. Nothing to convert"
-            )
-        checkpoint_dir = os.path.join(args.dcp_checkpoint_dir, checkpoint_dir[0])
-
-    # get the converted statedict
-    state_dict = recover_original_state_dict_from_dcp_checkpoint(
-        checkpoint_dir, args.pretrained_model_name_or_path
+        else:
+            # then take it as a safetensors checkpoint
+            # - do not support .bin checkpoints
+            checkpoint_dir = args.checkpoint_dir
+            loader = get_state_dict_from_safe_checkpoint
+
+    # - pretrained model name
+    _name_or_path = args.pretrained_model_name_or_path
+
+    # assume output directory exists, we do not create it
+    # - copy the config file if exists
+    config_file = os.path.join(checkpoint_dir, CONFIG_NAME)
+    target_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+    if os.path.exists(config_file):
+        shutil.copyfile(config_file, target_config_file)
+
+        # try to populate pretrained_model_name_or_path from the config path
+        # if it was None
+        if not _name_or_path:
+            with open(target_config_file, "r", encoding="utf-8") as file:
+                _name_or_path = json.load(file).get("_name_or_path")
+
+    # get the state_dict
+    state_dict = loader(checkpoint_dir)
+
+    # recover the original state dict
+    state_dict = recover_original_state_dict_from_checkpoint(state_dict, _name_or_path)
+
+    # save it as a safetensors file
+    save_sharded_safetensors(
+        {k: v.contiguous() for k, v in state_dict.items()},
+        args.output_dir,
         metadata={"format": "pt"},
     )
-
-    # save it
-    torch.save(state_dict, args.output_dir)
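
Usage note (not part of the patch): the sketch below shows how the converted utilities might be exercised directly from Python, assuming the plugin is installed so that fms_acceleration_moe.utils.checkpoint_utils is importable. The checkpoint path, output directory, and model name are placeholders, and the output directory must already exist, since neither the script nor save_sharded_safetensors creates it.

# Hypothetical example; paths and the model name are placeholders, not values from the patch.
from fms_acceleration_moe.utils.checkpoint_utils import (
    get_state_dict_from_safe_checkpoint,
    recover_original_state_dict_from_checkpoint,
    save_sharded_safetensors,
)

checkpoint_dir = "outputs/checkpoint-100"  # sharded safetensors ScatterMoE checkpoint
output_dir = "outputs/converted"           # must already exist
name_or_path = "org/base-moe-model"        # hints used to remap back to the original model

# 1. load the sharded safetensors state dict (needs model.safetensors.index.json)
sd = get_state_dict_from_safe_checkpoint(checkpoint_dir)

# 2. remap ScatterMoE keys back to the original model's parameter names
sd = recover_original_state_dict_from_checkpoint(sd, name_or_path)

# 3. write the converted weights back out as sharded safetensors plus an index
save_sharded_safetensors(
    {k: v.contiguous() for k, v in sd.items()},
    output_dir,
    metadata={"format": "pt"},
)

This is the same load, remap, and save flow that the __main__ script performs; when the checkpoint directory contains an FSDP_MODEL_NAME subdirectory, the script swaps in get_state_dict_from_dcp_checkpoint as the loader instead.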