diff --git a/src/axolotl/cli/merge_sharded_fsdp_weights.py b/src/axolotl/cli/merge_sharded_fsdp_weights.py new file mode 100644 index 0000000000..25408fd57e --- /dev/null +++ b/src/axolotl/cli/merge_sharded_fsdp_weights.py @@ -0,0 +1,204 @@ +""" +This module provides a CLI to merge sharded FSDP model checkpoints into a single combined checkpoint +""" +import json +import logging +import os +import shutil +from pathlib import Path +from typing import Dict, Union + +import fire +import torch +import torch.distributed.checkpoint as dist_cp +import torch.distributed.checkpoint.format_utils as dist_cp_format_utils +import transformers +from accelerate.utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + is_torch_version, +) +from dotenv import load_dotenv +from huggingface_hub import split_torch_state_dict_into_shards +from safetensors.torch import save_file as safe_save_file +from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner + +from axolotl.cli import load_cfg, print_axolotl_text_art +from axolotl.common.cli import TrainerCliArgs + +LOG = logging.getLogger("axolotl.cli.merge_sharded_fsdp_weights") + + +class BFloat16CastPlanner(_EmptyStateDictLoadPlanner): + """ + A custom planner to cast tensors to bfloat16 on the fly during loading. + """ + + def commit_tensor(self, read_item, tensor): # pylint: disable=unused-argument + tensor.copy_(tensor.to(torch.bfloat16)) + + +def _distributed_checkpoint_to_merged_weights( + checkpoint_dir: Union[str, Path], + save_path: str, + safe_serialization: bool = False, + max_shard_size: str = "5GB", +): + """ + Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save` + + Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`. + """ + + state_dict: Dict = {} + save_path_ = Path(save_path) + save_path_.mkdir(exist_ok=True) + dist_cp_format_utils._load_state_dict( # pylint: disable=protected-access + state_dict, + storage_reader=dist_cp.FileSystemReader(checkpoint_dir), + planner=BFloat16CastPlanner(), # pylint: disable=protected-access + no_dist=True, + ) + + # To handle if state is a dict like {model: {...}} + if len(state_dict.keys()) == 1: + state_dict = state_dict[list(state_dict)[0]] + + # Ensure all tensors are in bfloat16 + for key, value in state_dict.items(): + if isinstance(value, torch.Tensor) and value.dtype != torch.bfloat16: + state_dict[key] = value.to(torch.bfloat16) + + weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME + + filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace( + ".safetensors", "{suffix}.safetensors" + ) + state_dict_split = split_torch_state_dict_into_shards( + state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size + ) + # Save index if sharded + index = None + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + + # Save the model + filename_to_tensors = state_dict_split.filename_to_tensors.items() + + for shard_file, tensors in filename_to_tensors: + shard = {tensor: state_dict[tensor] for tensor in tensors} + + if safe_serialization: + safe_save_file( + shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"} + ) + else: + torch.save(shard, os.path.join(save_path_, shard_file)) + + if index is not None: + save_index_file = ( + SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME + ) + save_index_file = os.path.join(save_path_, save_index_file) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as fout: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + fout.write(content) + + return save_path_ + + +def merge_fsdp_weights( + checkpoint_dir: str, + output_path: str, + safe_serialization: bool = False, + remove_checkpoint_dir: bool = False, +): + """ + Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if + `SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if + `safe_serialization` else `pytorch_model.bin`. + + Note: this is a CPU-bound process. + + Args: + checkpoint_dir (`str`): + The directory containing the FSDP checkpoints (can be either the model or optimizer). + output_path (`str`): + The path to save the merged checkpoint. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the merged weights with safetensors (recommended). + remove_checkpoint_dir (`bool`, *optional*, defaults to `False`): + Whether to remove the checkpoint directory after merging. + """ + checkpoint_dir_ = Path(checkpoint_dir) + from accelerate.state import PartialState + + if not is_torch_version(">=", "2.3.0"): + raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`") + + # Verify that the checkpoint directory exists + if not checkpoint_dir_.exists(): + model_path_exists = (checkpoint_dir_ / "pytorch_model_fsdp_0").exists() + optimizer_path_exists = (checkpoint_dir_ / "optimizer_0").exists() + err = f"Tried to load from {checkpoint_dir_} but couldn't find a valid metadata file." + if model_path_exists and optimizer_path_exists: + err += ( + " However, potential model and optimizer checkpoint directories exist." + ) + err += f"Please pass in either {checkpoint_dir_}/pytorch_model_fsdp_0 or {checkpoint_dir_}/optimizer_0" + err += "instead." + elif model_path_exists: + err += " However, a potential model checkpoint directory exists." + err += ( + f"Please try passing in {checkpoint_dir_}/pytorch_model_fsdp_0 instead." + ) + elif optimizer_path_exists: + err += " However, a potential optimizer checkpoint directory exists." + err += f"Please try passing in {checkpoint_dir_}/optimizer_0 instead." + raise ValueError(err) + + # To setup `save` to work + state = PartialState() + if state.is_main_process: + LOG.info(f"Merging FSDP weights from {checkpoint_dir_}") + save_path = _distributed_checkpoint_to_merged_weights( + checkpoint_dir_, output_path, safe_serialization + ) + LOG.info(f"Successfully merged FSDP weights and saved to {save_path}") + if remove_checkpoint_dir: + LOG.info(f"Removing old checkpoint directory {checkpoint_dir_}") + shutil.rmtree(checkpoint_dir_) + state.wait_for_everyone() + + +def do_cli(config: Path = Path("examples/"), **kwargs): + # pylint: disable=duplicate-code + print_axolotl_text_art() + parser = transformers.HfArgumentParser((TrainerCliArgs)) + parsed_cli_args, _ = parser.parse_args_into_dataclasses( + return_remaining_strings=True + ) + parsed_cli_args.merge_lora = True + + parsed_cfg = load_cfg( + config, + **kwargs, + ) + + fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0" + merge_fsdp_weights( + checkpoint_dir=str(fsdp_dir), + output_path=str(Path(parsed_cfg.output_dir) / "merged"), + safe_serialization=True, + ) + + +if __name__ == "__main__": + load_dotenv() + fire.Fire(do_cli) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index b8890d4f7a..b21b0b269c 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -12,6 +12,7 @@ import transformers.modelcard from accelerate import Accelerator from accelerate.logging import get_logger +from accelerate.utils import save_fsdp_model from datasets import Dataset from peft import PeftModel from pkg_resources import get_distribution # type: ignore @@ -194,9 +195,12 @@ def terminate_handler(_, __, model_weakref): if hasattr(module, "_post_training"): module._post_training(model, name) # pylint: disable=protected-access + state_dict_type = "FULL_STATE_DICT" if trainer.is_fsdp_enabled: - trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") - LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.") + if cfg.fsdp_final_state_dict_type: + state_dict_type = cfg.fsdp_final_state_dict_type + trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type) + LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.") if cfg.relora_steps: if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit): @@ -208,7 +212,18 @@ def terminate_handler(_, __, model_weakref): # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file if cfg.fsdp: - trainer.save_model(cfg.output_dir) + if ( + state_dict_type == "SHARDED_STATE_DICT" + and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT" + ): + save_fsdp_model( + trainer.accelerator.state.fsdp_plugin, + trainer.accelerator, + trainer.model, + cfg.output_dir, + ) + elif state_dict_type == "FULL_STATE_DICT": + trainer.save_model(cfg.output_dir) elif cfg.deepspeed and is_deepspeed_zero3_enabled(): # Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading trainer.accelerator.wait_for_everyone() diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 5e690bb88e..dcc902c8c6 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -628,6 +628,9 @@ class Config: deepspeed: Optional[Union[str, Dict[str, Any]]] = None fsdp: Optional[List[str]] = None fsdp_config: Optional[Dict[str, Any]] = None + fsdp_final_state_dict_type: Optional[ + Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] + ] = None val_set_size: Optional[float] = Field(default=0.0) @@ -1148,6 +1151,20 @@ def check_fsdp_offload_w_8bit_optimizer(cls, data): ) return data + @model_validator(mode="before") + @classmethod + def check_fsdp_sharded_state_dict_w_safetensors(cls, data): + if ( + data.get("fsdp") + and data.get("save_safetensors") + and data.get("fsdp_config") + and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT" + ): + raise ValueError( + "FSDP SHARDED_STATE_DICT not compatible with save_safetensors" + ) + return data + @model_validator(mode="before") @classmethod def check_causal_lm_evals(cls, data):