diff --git a/examples/llama-2/qlora-fsdp.yml b/examples/llama-2/qlora-fsdp.yml index da6c06020f..30916ed45a 100644 --- a/examples/llama-2/qlora-fsdp.yml +++ b/examples/llama-2/qlora-fsdp.yml @@ -36,7 +36,7 @@ wandb_log_model: gradient_accumulation_steps: 4 micro_batch_size: 4 num_epochs: 4 -optimizer: paged_adamw_8bit +optimizer: adamw_torch lr_scheduler: cosine learning_rate: 0.00001 @@ -66,5 +66,11 @@ weight_decay: 0.0 fsdp: - full_shard fsdp_config: + fsdp_limit_all_gathers: true + fsdp_sync_module_states: true + fsdp_offload_params: true + fsdp_use_orig_params: false + fsdp_cpu_ram_efficient_loading: true fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer + fsdp_state_dict_type: SHARDED_STATE_DICT special_tokens: diff --git a/requirements.txt b/requirements.txt index aaa27c547b..75ce7a0d8a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ packaging==23.2 peft==0.9.0 -transformers @ git+https://github.com/huggingface/transformers.git@f6261d7d81edd036fc53bfede65fe91f01a661aa +transformers @ git+https://github.com/huggingface/transformers.git@73a73b415e36f41481369f6129cb4b62bb127a78 tokenizers==0.15.0 -bitsandbytes>=0.43.0 -accelerate==0.26.1 +bitsandbytes==0.43.0 +accelerate==0.28.0 deepspeed==0.13.1 pydantic==2.6.3 addict @@ -40,4 +40,3 @@ gcsfs # adlfs trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda500348226cdc97d90 -fastcore>=1.5.29 diff --git a/src/axolotl/core/policies/__init__.py b/src/axolotl/core/policies/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/axolotl/core/policies/auto_wrap.py b/src/axolotl/core/policies/auto_wrap.py deleted file mode 100644 index d42b62ee08..0000000000 --- a/src/axolotl/core/policies/auto_wrap.py +++ /dev/null @@ -1,55 +0,0 @@ -"""module for building the auto wrap policy for FSDP""" -import functools - -from peft import PrefixEncoder, PromptEmbedding, PromptEncoder -from torch.distributed.fsdp.wrap import ( - _or_policy, - lambda_auto_wrap_policy, - transformer_auto_wrap_policy, -) -from transformers.models.llama.modeling_llama import LlamaDecoderLayer -from transformers.models.mistral.modeling_mistral import MistralDecoderLayer -from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer - -SUPPORTED_AUTO_WRAP_MODEL_TYPES = [ - "llama", - "mistral", - "mixtral", -] - - -def get_wrapping_policy_factory(model_type): - if model_type == "llama": - layer_to_wrap = LlamaDecoderLayer - elif model_type == "mistral": - layer_to_wrap = MistralDecoderLayer - elif model_type == "mixtral": - layer_to_wrap = MixtralDecoderLayer - - def get_wrapping_policy(): - """This checks for lora layers (has weight and requires_grad)""" - - def lambda_policy_fn(module): - return ( - len(list(module.named_children())) == 0 - and getattr(module, "weight", None) is not None - and module.weight.requires_grad - ) - - lambda_policy = functools.partial( - lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn - ) - transformer_layer_name = layer_to_wrap - transformer_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls=( - PrefixEncoder, - PromptEncoder, - PromptEmbedding, - transformer_layer_name, - ), - ) - policies = [lambda_policy, transformer_wrap_policy] - return functools.partial(_or_policy, policies=policies) - - return get_wrapping_policy diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 53f6cca903..c2d622ceec 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -8,7 +8,6 @@ import importlib.util import logging import math -import os import sys from abc import abstractmethod from collections import defaultdict @@ -19,10 +18,7 @@ import torch import transformers -from accelerate import FullyShardedDataParallelPlugin -from accelerate.utils import str_to_bool from datasets import Dataset -from torch.distributed.fsdp import MixedPrecision from torch.optim.lr_scheduler import OneCycleLR from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from transformers import ( @@ -35,7 +31,6 @@ from transformers.utils import is_sagemaker_mp_enabled from trl import DPOTrainer -from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory from axolotl.loraplus import create_loraplus_optimizer from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler @@ -591,51 +586,14 @@ def push_to_hub(self, *args, **kwargs) -> str: @wraps(Trainer.create_accelerator_and_postprocess) def create_accelerator_and_postprocess(self): - rank = int(os.environ.get("LOCAL_RANK", 0)) res = super().create_accelerator_and_postprocess() - if self.args.qlora is False: - return res - - # the rest of this method override is specific to fsdp + qlora (for now) - sync_module_states = ( - str_to_bool(os.environ.get("FSDP_SYNC_MODULE_STATES", "True")) == 1 - ) - - mp_policy = None - amp = os.environ["ACCELERATE_MIXED_PRECISION"] - if amp == "fp16": - mp_policy = MixedPrecision( - param_dtype=torch.float32, - reduce_dtype=torch.float32, - buffer_dtype=torch.float32, - ) - elif amp == "bf16": - mp_policy = MixedPrecision( - param_dtype=torch.float32, - reduce_dtype=torch.float32, - buffer_dtype=torch.float32, - ) - - # If somehow we figure out how we want to parameterize we want to autocast buffers... - # mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32) - # load_param_skip_names = ['inv_freq'] - if self.is_fsdp_enabled: - wrapping_policy = get_wrapping_policy_factory(self.args.model_type) - fsdp_plugin = FullyShardedDataParallelPlugin( - auto_wrap_policy=wrapping_policy(), - cpu_offload=False, - use_orig_params=False, - limit_all_gathers=True, - param_init_fn=lambda module: module.to_empty( - device=torch.device("cuda"), recurse=False - ) - if (rank != 0 and sync_module_states) - else None, - mixed_precision_policy=mp_policy, - ) - self.accelerator.state.fsdp_plugin = fsdp_plugin + if ( + "limit_all_gathers" in self.args.fsdp_config + and self.args.fsdp_config["limit_all_gathers"] + ): + self.accelerator.state.fsdp_plugin.limit_all_gathers = True return res diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 40090a07c0..41fd471e65 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -5,16 +5,14 @@ import math import os import types -from typing import Any, Dict, List, Optional, Tuple, Type, Union # noqa: F401 +from typing import Any, Dict, Optional, Tuple, Union # noqa: F401 import addict import bitsandbytes as bnb -import safetensors import torch import transformers from accelerate import init_empty_weights -from bitsandbytes.nn import Linear4bit, Params4bit -from fastcore.parallel import parallel +from bitsandbytes.nn import Params4bit from peft import ( LoftQConfig, PeftConfig, @@ -23,7 +21,7 @@ prepare_model_for_kbit_training, ) from peft.tuners.lora import QuantLinear -from torch import Tensor, nn +from torch import nn from transformers import ( # noqa: F401 AddedToken, AutoConfig, @@ -35,9 +33,7 @@ PreTrainedTokenizerBase, ) from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled -from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub -from axolotl.core.policies.auto_wrap import SUPPORTED_AUTO_WRAP_MODEL_TYPES from axolotl.models.mamba import fix_mamba_attn_for_loss from axolotl.monkeypatch.multipack import ( SUPPORTED_MULTIPACK_MODEL_TYPES, @@ -272,117 +268,6 @@ def load_tokenizer(cfg): return tokenizer -def replace_linear( - model: nn.Module, - linear_replacement: Type[nn.Module], - quant_config: Union[dict, None] = None, - skip_modules=None, - **kwargs, -): - """ - Replace linear modules with a new Linear module. - Parameters: - model (`torch.nn.Module`): - Input model or `torch.nn.Module` as the function is run recursively. - linear_replacement (`torch.nn.Module`): - The linear module that replaces the old one. Only expects standard arguments. - If other arguments need to be passed, use a lambda. - skip_modules (`List[str]`, *optional*, defaults to `lm_head`): - List of modules names not to convert. Defaults to `lm_head`. - """ - if skip_modules is None: - skip_modules = ["lm_head"] - for name, module in model.named_children(): - if len(list(module.children())) > 0: - replace_linear( - module, linear_replacement, quant_config, skip_modules, **kwargs - ) - - if isinstance(module, torch.nn.Linear) and name not in skip_modules: - if issubclass(linear_replacement, Linear4bit): - model._modules[ # pylint: disable=protected-access - name - ] = linear_replacement( - module.in_features, - module.out_features, - module.bias is not None, - **kwargs, - ) - else: - raise ValueError( - f"Unsupported linear replacement: {type(linear_replacement)}" - ) - return model - - -def load_and_quantize( - module: nn.Module, - name: str, - value: Tensor, - device: torch.device = None, - dtype: torch.dtype = None, - skip_names: Optional[List[str]] = None, - is_meta_rank: bool = False, - low_memory: bool = True, - verbose: bool = False, - quant_method: str = "bnb", -): - """ - Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`. - - Quantizes `Params4bit` on `device` then places on "cpu" if low_memory=True or "meta" if is_meta_rank=True. - """ - - if skip_names is None: - skip_names = [] - - def place_on_device(value): - if is_meta_rank: - device = "meta" - elif low_memory: - device = "cpu" - else: - device = "cuda" - return value.to(device=device, dtype=dtype) - - if any(skip_name in name for skip_name in skip_names): - if verbose: - print(f"Skipping {name} because it is in skip_names") - return - - module_key, _, value_key = name.rpartition(".") - try: - submodule = module.get_submodule(module_key) - except AttributeError as exc: - print(f"Module {module_key} not found:\n{exc}") - return - - try: - if quant_method == "bnb": - param = submodule.get_parameter(value_key) - if isinstance(param, Params4bit): - # With `sync_module_states=True`, a meta device Params4bit needs to be the same - # shape as the quantized Params4bit with an initialized quant_state. However, - # FSDP only syncs parameters and buffers, so the quant_state isn't copied. This - # workaround quantizes Params4bit to initialize quant_state on all ranks, then - # replaces Params4bit's data with a meta tensor to free memory on non-rank 0. - value = type(param)( - value.to(device=device, dtype=dtype).data, **param.__dict__ - ).cuda(device) - if is_meta_rank: - value = type(param)(value.data.to("meta"), **value.__dict__) - elif low_memory: - value = type(param)(value.data.to("cpu"), **value.__dict__) - else: - value = type(param)(place_on_device(value).data) - - except AttributeError: - # it's a buffer - value = place_on_device(value) - - setattr(submodule, value_key, value) - - def load_model( cfg: DictDefault, tokenizer: PreTrainedTokenizerBase, @@ -568,6 +453,7 @@ def load_model( "bnb_4bit_compute_dtype": cfg.torch_dtype, "bnb_4bit_use_double_quant": True, "bnb_4bit_quant_type": "nf4", + "bnb_4bit_quant_storage": torch.bfloat16, } if cfg.bnb_config_kwargs: @@ -617,78 +503,10 @@ def load_model( model_kwargs["attn_implementation"] = "eager" model_config._attn_implementation = "eager" # pylint: disable=protected-access - qlora_fsdp = ( - cfg.fsdp - and cfg.adapter == "qlora" - and model_config.model_type in SUPPORTED_AUTO_WRAP_MODEL_TYPES - ) + qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora" try: - if qlora_fsdp: - if cfg.bf16 or cfg.bfloat16: - torch_dtype, compute_dtype = torch.float32, torch.bfloat16 - elif cfg.fp16 or cfg.float16: - torch_dtype, compute_dtype = torch.float32, torch.float16 - else: - torch_dtype, compute_dtype = torch.float32, torch.float16 - - with init_empty_weights(): - LOG.info("Loading model with empty weights.") - model = AutoModelForCausalLM.from_config(model_config) - model.model = replace_linear( - model.model, - Linear4bit, - compute_dtype=compute_dtype, - quant_type="nf4", - quant_storage=torch_dtype, - ) - - model.is_loaded_in_4bit = True - - # Grab the safetensors files that hold the weights - try: - idx = hub.cached_file(base_model, SAFE_WEIGHTS_INDEX_NAME) - files, _ = hub.get_checkpoint_shard_files(base_model, idx) - except OSError: - try: - # This means the model doesn't have a model.safetensors.index.json because it is not sharded - files = [] - files.append(hub.cached_file(base_model, SAFE_WEIGHTS_NAME)) - except OSError as exc: - # This means the model probably doesn't have a safetensors file - raise exc - - # Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly - # and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage - def load_and_quantize_parallel(name_param, model, **kwargs): - name, param = name_param - load_and_quantize(model, name, param, **kwargs) - - param_count = sum((p.numel() for n, p in model.named_parameters())) - for filename in files: - weights = safetensors.torch.load_file(filename) - quant_method = "bnb" - devprops = torch.cuda.get_device_properties(torch.cuda.current_device()) - left = int(os.cpu_count() / torch.cuda.device_count()) - right = int( - 8 * (devprops.total_memory / 1e9 / 40) * (70 / (param_count / 1e9)) - ) - n_workers = min(left, right) - parallel( - load_and_quantize_parallel, - weights.items(), - n_workers=n_workers, - threadpool=True, - model=model, - dtype=torch_dtype, - device=cfg.local_rank, - skip_names=[], - is_meta_rank=(cfg.local_rank != 0), - verbose=False, - quant_method=quant_method, - ) - - elif ( + if ( model_config.model_type == "llama" and not cfg.trust_remote_code and not cfg.gptq @@ -715,32 +533,6 @@ def load_and_quantize_parallel(name_param, model, **kwargs): if cfg.flash_attn_fuse_qkv: LOG.info("patching with fused QKV") replace_llama_qkv_with_fused(model) - # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention: - # This is a WIP, still an issue with the backward pass - # RuntimeError: grad can be implicitly created only for scalar outputs - # TODO: try config.sequence_parallel = False - # # https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/tests/models/test_gpt_neox.py#L12 - # # https://github.com/HazyResearch/flash-attention/tree/main/training#model-components - # # add `**kwargs` to https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/flash_attn/models/gpt.py#L442 - # from flash_attn.utils.pretrained import state_dict_from_pretrained - # from flash_attn.models.gpt import GPTLMHeadModel - # from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config - # from transformers import GPTNeoXConfig - # config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(base_model)) - # config.use_flash_attn = True - # config.fused_bias_fc = True - # config.fused_mlp = True # GPT-NeoX-20B uses "gelu_fast" - # config.activation_function = "gelu_fast" - # config.fused_dropout_add_ln = True - # # config.residual_in_fp32 = True - # - # model: GPTLMHeadModel = GPTLMHeadModel.from_pretrained( - # base_model, - # config, - # dtype=torch_dtype, - # device=cfg.device, - # ) - # model.train() # sets to train instead of eval mode elif model_type == "MambaLMHeadModel": # FIXME this is janky at best and hacked together to make it work MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index e52f35ccca..380264a7ac 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -304,6 +304,10 @@ def setup_fsdp_envs(cfg): os.environ["FSDP_OFFLOAD_PARAMS"] = "true" if cfg.fsdp_config.fsdp_sync_module_states: os.environ["FSDP_SYNC_MODULE_STATES"] = "true" + if cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: + os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true" + if cfg.fsdp_config.fsdp_use_orig_params: + os.environ["FSDP_USE_ORIG_PARAMS"] = "true" if cfg.fsdp_config.fsdp_state_dict_type: os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap: diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py index ee6f06d875..68afcdec4a 100644 --- a/tests/e2e/test_mixtral.py +++ b/tests/e2e/test_mixtral.py @@ -77,7 +77,7 @@ def test_qlora_w_fa2(self, temp_dir): model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) assert ( model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype - == torch.uint8 + == torch.float32 ) assert (Path(temp_dir) / "adapter_model.bin").exists() @@ -131,7 +131,7 @@ def test_qlora_wo_fa2(self, temp_dir): model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) assert ( model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype - == torch.uint8 + == torch.float32 ) assert (Path(temp_dir) / "adapter_model.bin").exists()