feat: Checkpoint utils safetensors (#116)
* checkpoint conversion handle non-dcp case

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* improvements

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* fix: sharded safetensors save

Signed-off-by: Will Johnson <[email protected]>

* fix: lint

Signed-off-by: Will Johnson <[email protected]>

* fmt

Signed-off-by: Will Johnson <[email protected]>

---------

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Will Johnson <[email protected]>
Co-authored-by: Yu Chin Fabian Lim <[email protected]>
willmj and fabianlim authored Dec 11, 2024
1 parent 733992a commit 702de2f
Showing 1 changed file with 124 additions and 43 deletions.
@@ -14,21 +14,25 @@

# Standard
from collections import defaultdict
from typing import List
from typing import Dict, List, Union
import json
import os
import re
import shutil

# Third Party
from accelerate.logging import get_logger
from accelerate.utils.constants import FSDP_MODEL_NAME, OPTIMIZER_NAME
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import load_file, save_file
from torch.distributed.checkpoint.default_planner import (
    DefaultLoadPlanner,
    DefaultSavePlanner,
)
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from transformers import PretrainedConfig
from transformers.utils import CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
import torch
import torch.distributed.checkpoint as dcp

@@ -213,24 +217,10 @@ def _dict_from_json_file(resolved_config_file):
return os.path.dirname(result)


# function to get the ScatterMoE state dict from its DCP checkpoint
# - if the original pretrained_model_name_or_path is specified, will use the checkpoint as hints
# to map the ScatterMoE checkpoint to that of the original model. This is useful so that we
# can restore the checkpoint to be loaded by the original architecture.
def recover_original_state_dict_from_dcp_checkpoint(
# function to get the state dict from dcp_checkpoint
def get_state_dict_from_dcp_checkpoint(
dcp_checkpoint_dir: str,
pretrained_model_name_or_path: str = None,
):
"""
Parameters:
dcp_checkpoint_dir (str): the DCP to be converted.
pretrained_model_name_or_path (str): Optional, if provided we will
use the hints to remap the
"""

# reference dcp_to_torch_save from torch.distributed.checkpoint.format_utils.py
# - strategy is to use _EmptyStateDictLoadPlanner to populate the state dict, then we remap

# guarded, load some internal functions
# pylint: disable=import-outside-toplevel
# Third Party
@@ -245,11 +235,46 @@ def recover_original_state_dict_from_dcp_checkpoint(
planner=_EmptyStateDictLoadPlanner(),
no_dist=True,
)
sd = sd[KEY_MODEL]
    return sd[KEY_MODEL]
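For reference, the same conversion can be sketched with PyTorch's public format_utils helper instead of the internal _load_state_dict / _EmptyStateDictLoadPlanner pair used above. This is an illustrative assumption, not the code in this commit; the checkpoint path and the "model" key are hypothetical.

# Sketch only (assumes PyTorch >= 2.2): convert a DCP checkpoint into a single
# torch.save file, then read the nested "model" entry, mirroring KEY_MODEL above.
import torch
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

dcp_to_torch_save("checkpoint-100/pytorch_model_fsdp_0", "/tmp/full_state.pt")  # hypothetical paths
full_sd = torch.load("/tmp/full_state.pt")
model_sd = full_sd.get("model", full_sd)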


# function to get state dict from regular checkpoint
# - note this assumes sharded safetensors, we do not support
# the non-sharded case for now
def get_state_dict_from_safe_checkpoint(
    safe_checkpoint_dir: str,
):
    # Load the index
    safe_index_file = os.path.join(safe_checkpoint_dir, SAFE_WEIGHTS_INDEX_NAME)
    with open(safe_index_file, "r", encoding="utf-8") as f:
        index = json.load(f)

    sd = {}
    shard_files = list(set(index["weight_map"].values()))
    for shard_file in shard_files:
        for key, v in load_file(os.path.join(safe_checkpoint_dir, shard_file)).items():
            sd[key] = v

    return sd
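For context, get_state_dict_from_safe_checkpoint assumes a sharded safetensors layout like the sketch below; the shard file names and tensor keys are made up for illustration.

# Sketch only: model.safetensors.index.json maps each tensor name to the shard
# that stores it, e.g.
#   {"metadata": {"total_size": 123456789},
#    "weight_map": {"model.embed_tokens.weight": "model-00001-of-00002.safetensors",
#                   "lm_head.weight": "model-00002-of-00002.safetensors"}}
# and each shard can be read on its own (path is hypothetical):
from safetensors.torch import load_file

shard = load_file("checkpoint-100/model-00001-of-00002.safetensors", device="cpu")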

# if not provided
if pretrained_model_name_or_path is None:
return sd

# function to get the ScatterMoE state dict from its DCP checkpoint
# - if the original pretrained_model_name_or_path is specified, will use the checkpoint as hints
# to map the ScatterMoE checkpoint to that of the original model. This is useful so that we
# can restore the checkpoint to be loaded by the original architecture.
def recover_original_state_dict_from_checkpoint(
sd: Dict,
pretrained_model_name_or_path: str = None,
):
"""
Parameters:
sd (Dict): the state dict to be converted.
pretrained_model_name_or_path (str): Optional, if provided we will
use the hints to remap the
"""

# reference dcp_to_torch_save from torch.distributed.checkpoint.format_utils.py
# - strategy is to use _EmptyStateDictLoadPlanner to populate the state dict, then we remap

# now do the remap
loc = get_resolved_checkpoint_location(pretrained_model_name_or_path)
@@ -398,6 +423,37 @@ def _infer_prefixes_and_module_names(
return sd


def save_sharded_safetensors(
    input_state_dict: Dict,
    save_directory: str,
    metadata: Dict,
    max_shard_size: Union[int, str] = "5GB",
):
    filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
        ".safetensors", "{suffix}.safetensors"
    )
    state_dict_split = split_torch_state_dict_into_shards(
        input_state_dict,
        filename_pattern=filename_pattern,
        max_shard_size=max_shard_size,
    )
    index = {
        "metadata": state_dict_split.metadata,
        "weight_map": state_dict_split.tensor_to_filename,
    }
    # Save the index
    with open(
        os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8"
    ) as f:
        content = json.dumps(index, indent=2, sort_keys=True) + "\n"
        f.write(content)

    filename_to_tensors = state_dict_split.filename_to_tensors.items()
    for shard_file, tensors in filename_to_tensors:
        shard = {tensor: input_state_dict[tensor].contiguous() for tensor in tensors}
        save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
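As a usage sketch, the helper above works on any plain state dict; the tensor names, shapes, output directory and shard size below are hypothetical, while metadata={"format": "pt"} matches what the script further down passes.

# Sketch only: shard and save a toy state dict with save_sharded_safetensors.
import os

import torch

toy_sd = {
    "linear.weight": torch.randn(1024, 1024),
    "linear.bias": torch.randn(1024),
}
os.makedirs("converted", exist_ok=True)  # the helper assumes the directory exists
save_sharded_safetensors(
    {k: v.contiguous() for k, v in toy_sd.items()},
    "converted",
    metadata={"format": "pt"},
    max_shard_size="1GB",
)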


# --------------------------- SCRIPT -------------------------


@@ -417,8 +473,8 @@ def _infer_prefixes_and_module_names(
)

parser.add_argument(
"dcp_checkpoint_dir",
help="Path to the distributed checkpoint.",
"checkpoint_dir",
help="Path to the checkpoint.",
)

parser.add_argument(
@@ -432,37 +488,62 @@ def _infer_prefixes_and_module_names(
"the original pretrained model checkpoint (from which this "
"checkpoint is obtained)."
),
default=None,
)

args = parser.parse_args()

# search for the checkpoint. By the code above, it must
# search for an FSDP checkpoint. If it is an FSDP checkpoint, it must
# start with FSDP_MODEL_NAME
if args.dcp_checkpoint_dir.startswith(FSDP_MODEL_NAME):
checkpoint_dir = args.dcp_checkpoint_dir
if args.checkpoint_dir.startswith(FSDP_MODEL_NAME):
checkpoint_dir = args.checkpoint_dir
loader = get_state_dict_from_dcp_checkpoint
else:
checkpoint_dir = [
x
for x in os.listdir(args.dcp_checkpoint_dir)
if os.path.isdir(os.path.join(args.dcp_checkpoint_dir, x))
for x in os.listdir(args.checkpoint_dir)
if os.path.isdir(os.path.join(args.checkpoint_dir, x))
and x.startswith(FSDP_MODEL_NAME)
]
if len(checkpoint_dir) > 1:
if len(checkpoint_dir) == 1:
checkpoint_dir = os.path.join(args.checkpoint_dir, checkpoint_dir[0])
loader = get_state_dict_from_dcp_checkpoint
elif len(checkpoint_dir) > 1:
raise ValueError(
f"Found > 1 dirs in dcp checkpoint dir {args.dcp_checkpoint_dir} "
f"Found > 1 dirs in dcp checkpoint dir {args.checkpoint_dir} "
f"that starts with {FSDP_MODEL_NAME}. Please spectify the exact dir."
)
if len(checkpoint_dir) == 0:
raise ValueError(
f"Found no dirs in dcp checkpoint dir {args.dcp_checkpoint_dir} "
f"that starts with {FSDP_MODEL_NAME}. Nothing to convert"
)
checkpoint_dir = os.path.join(args.dcp_checkpoint_dir, checkpoint_dir[0])

# get the converted statedict
state_dict = recover_original_state_dict_from_dcp_checkpoint(
checkpoint_dir, args.pretrained_model_name_or_path
else:
# then take it as a safetensors checkpoint
# - do not support .bin checkpoints
checkpoint_dir = args.checkpoint_dir
loader = get_state_dict_from_safe_checkpoint
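Outside the command line, the loader chosen by the branch above can also be called directly; the directory names in this sketch are hypothetical.

# Sketch only: the two on-disk layouts the script accepts.
# DCP / FSDP case: a directory whose name starts with FSDP_MODEL_NAME, e.g.
sd = get_state_dict_from_dcp_checkpoint("checkpoint-100/pytorch_model_fsdp_0")
# sharded safetensors case: a directory holding model-*.safetensors shards plus
# model.safetensors.index.json, e.g.
sd = get_state_dict_from_safe_checkpoint("checkpoint-100")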

# - pretrained model name
_name_or_path = args.pretrained_model_name_or_path

# assume output directory exists, we do not create it
# - copy the config file if exists
config_file = os.path.join(checkpoint_dir, CONFIG_NAME)
target_config_file = os.path.join(args.output_dir, CONFIG_NAME)
if os.path.exists(config_file):
shutil.copyfile(config_file, target_config_file)

# try to populate pretrained_model_name_or_path from the config path
# if it was None
if not _name_or_path:
with open(target_config_file, "r", encoding="utf-8") as file:
_name_or_path = json.load(file).get("_name_or_path")

# get the state_dict
state_dict = loader(checkpoint_dir)

# recover the original state dict
state_dict = recover_original_state_dict_from_checkpoint(state_dict, _name_or_path)

# save it as a safetensors file
save_sharded_safetensors(
{k: v.contiguous() for k, v in state_dict.items()},
args.output_dir,
metadata={"format": "pt"},
)

# save it
torch.save(state_dict, args.output_dir)
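Since the script copies config.json and writes sharded safetensors plus model.safetensors.index.json into the output directory, the converted checkpoint should load back with the original architecture. A minimal sketch, assuming the output directory is "converted":

# Sketch only: reload the converted checkpoint via transformers.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("converted", torch_dtype="auto")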
