Skip to content

Commit

Permalink
optionally save the final FSDP model as a sharded state dict (#1828)
Browse files Browse the repository at this point in the history
* efficiently save very large llms when using FSDP

* fix parsing and index of sharded chunks

* only save fsdp on main process

* debugging for rename

* save sharded state dict

* remove unused new param

* get state dict directly

* tweak acc merge fsdp to shard the weight files

* sharded_state_dict alongside save_safetensors seems to hang on checkpoint save
  • Loading branch information
winglian authored Aug 19, 2024
1 parent b1d2921 commit e299312
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 3 deletions.
204 changes: 204 additions & 0 deletions src/axolotl/cli/merge_sharded_fsdp_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
"""
This module provides a CLI to merge sharded FSDP model checkpoints into a single combined checkpoint
"""
import json
import logging
import os
import shutil
from pathlib import Path
from typing import Dict, Union

import fire
import torch
import torch.distributed.checkpoint as dist_cp
import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
import transformers
from accelerate.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
is_torch_version,
)
from dotenv import load_dotenv
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file as safe_save_file
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner

from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs

LOG = logging.getLogger("axolotl.cli.merge_sharded_fsdp_weights")


class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
"""
A custom planner to cast tensors to bfloat16 on the fly during loading.
"""

def commit_tensor(self, read_item, tensor): # pylint: disable=unused-argument
tensor.copy_(tensor.to(torch.bfloat16))


def _distributed_checkpoint_to_merged_weights(
checkpoint_dir: Union[str, Path],
save_path: str,
safe_serialization: bool = False,
max_shard_size: str = "5GB",
):
"""
Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`
Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
"""

state_dict: Dict = {}
save_path_ = Path(save_path)
save_path_.mkdir(exist_ok=True)
dist_cp_format_utils._load_state_dict( # pylint: disable=protected-access
state_dict,
storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
planner=BFloat16CastPlanner(), # pylint: disable=protected-access
no_dist=True,
)

# To handle if state is a dict like {model: {...}}
if len(state_dict.keys()) == 1:
state_dict = state_dict[list(state_dict)[0]]

# Ensure all tensors are in bfloat16
for key, value in state_dict.items():
if isinstance(value, torch.Tensor) and value.dtype != torch.bfloat16:
state_dict[key] = value.to(torch.bfloat16)

weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME

filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
)
# Save index if sharded
index = None
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}

# Save the model
filename_to_tensors = state_dict_split.filename_to_tensors.items()

for shard_file, tensors in filename_to_tensors:
shard = {tensor: state_dict[tensor] for tensor in tensors}

if safe_serialization:
safe_save_file(
shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"}
)
else:
torch.save(shard, os.path.join(save_path_, shard_file))

if index is not None:
save_index_file = (
SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
)
save_index_file = os.path.join(save_path_, save_index_file)
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as fout:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
fout.write(content)

return save_path_


def merge_fsdp_weights(
checkpoint_dir: str,
output_path: str,
safe_serialization: bool = False,
remove_checkpoint_dir: bool = False,
):
"""
Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if
`safe_serialization` else `pytorch_model.bin`.
Note: this is a CPU-bound process.
Args:
checkpoint_dir (`str`):
The directory containing the FSDP checkpoints (can be either the model or optimizer).
output_path (`str`):
The path to save the merged checkpoint.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the merged weights with safetensors (recommended).
remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
Whether to remove the checkpoint directory after merging.
"""
checkpoint_dir_ = Path(checkpoint_dir)
from accelerate.state import PartialState

if not is_torch_version(">=", "2.3.0"):
raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`")

# Verify that the checkpoint directory exists
if not checkpoint_dir_.exists():
model_path_exists = (checkpoint_dir_ / "pytorch_model_fsdp_0").exists()
optimizer_path_exists = (checkpoint_dir_ / "optimizer_0").exists()
err = f"Tried to load from {checkpoint_dir_} but couldn't find a valid metadata file."
if model_path_exists and optimizer_path_exists:
err += (
" However, potential model and optimizer checkpoint directories exist."
)
err += f"Please pass in either {checkpoint_dir_}/pytorch_model_fsdp_0 or {checkpoint_dir_}/optimizer_0"
err += "instead."
elif model_path_exists:
err += " However, a potential model checkpoint directory exists."
err += (
f"Please try passing in {checkpoint_dir_}/pytorch_model_fsdp_0 instead."
)
elif optimizer_path_exists:
err += " However, a potential optimizer checkpoint directory exists."
err += f"Please try passing in {checkpoint_dir_}/optimizer_0 instead."
raise ValueError(err)

# To setup `save` to work
state = PartialState()
if state.is_main_process:
LOG.info(f"Merging FSDP weights from {checkpoint_dir_}")
save_path = _distributed_checkpoint_to_merged_weights(
checkpoint_dir_, output_path, safe_serialization
)
LOG.info(f"Successfully merged FSDP weights and saved to {save_path}")
if remove_checkpoint_dir:
LOG.info(f"Removing old checkpoint directory {checkpoint_dir_}")
shutil.rmtree(checkpoint_dir_)
state.wait_for_everyone()


def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.merge_lora = True

parsed_cfg = load_cfg(
config,
**kwargs,
)

fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
merge_fsdp_weights(
checkpoint_dir=str(fsdp_dir),
output_path=str(Path(parsed_cfg.output_dir) / "merged"),
safe_serialization=True,
)


if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)
21 changes: 18 additions & 3 deletions src/axolotl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import transformers.modelcard
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model
from datasets import Dataset
from peft import PeftModel
from pkg_resources import get_distribution # type: ignore
Expand Down Expand Up @@ -194,9 +195,12 @@ def terminate_handler(_, __, model_weakref):
if hasattr(module, "_post_training"):
module._post_training(model, name) # pylint: disable=protected-access

state_dict_type = "FULL_STATE_DICT"
if trainer.is_fsdp_enabled:
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
if cfg.fsdp_final_state_dict_type:
state_dict_type = cfg.fsdp_final_state_dict_type
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.")

if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
Expand All @@ -208,7 +212,18 @@ def terminate_handler(_, __, model_weakref):
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp:
trainer.save_model(cfg.output_dir)
if (
state_dict_type == "SHARDED_STATE_DICT"
and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT"
):
save_fsdp_model(
trainer.accelerator.state.fsdp_plugin,
trainer.accelerator,
trainer.model,
cfg.output_dir,
)
elif state_dict_type == "FULL_STATE_DICT":
trainer.save_model(cfg.output_dir)
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
trainer.accelerator.wait_for_everyone()
Expand Down
17 changes: 17 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,9 @@ class Config:
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
fsdp: Optional[List[str]] = None
fsdp_config: Optional[Dict[str, Any]] = None
fsdp_final_state_dict_type: Optional[
Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]
] = None

val_set_size: Optional[float] = Field(default=0.0)

Expand Down Expand Up @@ -1148,6 +1151,20 @@ def check_fsdp_offload_w_8bit_optimizer(cls, data):
)
return data

@model_validator(mode="before")
@classmethod
def check_fsdp_sharded_state_dict_w_safetensors(cls, data):
if (
data.get("fsdp")
and data.get("save_safetensors")
and data.get("fsdp_config")
and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT"
):
raise ValueError(
"FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
)
return data

@model_validator(mode="before")
@classmethod
def check_causal_lm_evals(cls, data):
Expand Down

0 comments on commit e299312

Please sign in to comment.