diff --git a/deepspeed_configs/zero3_bf16_cpuoffload_all.json b/deepspeed_configs/zero3_bf16_cpuoffload_all.json index 72fde6e5f1..09ca6785b2 100644 --- a/deepspeed_configs/zero3_bf16_cpuoffload_all.json +++ b/deepspeed_configs/zero3_bf16_cpuoffload_all.json @@ -1,4 +1,6 @@ { + "zero_force_ds_cpu_optimizer": false, + "zero_allow_untested_optimizer": true, "zero_optimization": { "stage": 3, "offload_optimizer": { diff --git a/deepspeed_configs/zero3_bf16_cpuoffload_params.json b/deepspeed_configs/zero3_bf16_cpuoffload_params.json index ca051e03ba..41d4a21323 100644 --- a/deepspeed_configs/zero3_bf16_cpuoffload_params.json +++ b/deepspeed_configs/zero3_bf16_cpuoffload_params.json @@ -1,4 +1,6 @@ { + "zero_force_ds_cpu_optimizer": false, + "zero_allow_untested_optimizer": true, "zero_optimization": { "stage": 3, "offload_param": { diff --git a/examples/dbrx/16bit-lora.yaml b/examples/dbrx/16bit-lora.yaml new file mode 100644 index 0000000000..e5e3ea9216 --- /dev/null +++ b/examples/dbrx/16bit-lora.yaml @@ -0,0 +1,81 @@ +base_model: LnL-AI/dbrx-base-converted-v2 +trust_remote_code: true + +load_in_8bit: false +load_in_4bit: false +strict: false + +datasets: + - path: tatsu-lab/alpaca + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./out + +sequence_len: 512 +sample_packing: false +pad_to_sequence_len: false + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +adapter: lora +lora_model_dir: +lora_r: 8 +lora_alpha: 16 +lora_dropout: 0.05 +# w1, w2, & v1 will hang the trainer +lora_target_modules: + - q_proj # attn + - k_proj # attn + - v_proj # attn + - out_proj # attn + - layer # router +# - w1 +# - w2 +# - v1 + +gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 1 +optimizer: paged_adamw_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +gradient_checkpointing: false # don't use with fsdp_activation_checkpointing +gradient_checkpointing_kwargs: + use_reentrant: false +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 10 +evals_per_epoch: +saves_per_epoch: 1 +debug: +weight_decay: 0.0 +fsdp: + - full_shard + - auto_wrap +fsdp_config: + fsdp_limit_all_gathers: true + fsdp_sync_module_states: true + fsdp_offload_params: false + fsdp_use_orig_params: false + fsdp_cpu_ram_efficient_loading: true + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: DbrxBlock + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_activation_checkpointing: true diff --git a/examples/dbrx/8bit-lora.yaml b/examples/dbrx/8bit-lora.yaml new file mode 100644 index 0000000000..89e24db058 --- /dev/null +++ b/examples/dbrx/8bit-lora.yaml @@ -0,0 +1,81 @@ +base_model: LnL-AI/dbrx-base-converted-v2 +trust_remote_code: true + +load_in_8bit: true +load_in_4bit: false +strict: false + +datasets: + - path: tatsu-lab/alpaca + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./out + +sequence_len: 512 +sample_packing: false +pad_to_sequence_len: false + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +adapter: lora +lora_model_dir: +lora_r: 8 +lora_alpha: 16 +lora_dropout: 0.05 +# w1, w2, & v1 will hang the trainer +lora_target_modules: + - q_proj # attn + - k_proj # attn + - v_proj # attn + - out_proj # attn + - layer # router +# - w1 +# - w2 +# - v1 + +gradient_accumulation_steps: 1 
+micro_batch_size: 1
+num_epochs: 1
+optimizer: paged_adamw_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+evals_per_epoch:
+saves_per_epoch: 1
+debug:
+weight_decay: 0.0
+fsdp:
+  - full_shard
+  - auto_wrap
+fsdp_config:
+  fsdp_limit_all_gathers: true
+  fsdp_sync_module_states: true
+  fsdp_offload_params: false
+  fsdp_use_orig_params: false
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_transformer_layer_cls_to_wrap: DbrxBlock
+  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_activation_checkpointing: true
diff --git a/examples/dbrx/README.md b/examples/dbrx/README.md
new file mode 100644
index 0000000000..99ff3dd0b7
--- /dev/null
+++ b/examples/dbrx/README.md
@@ -0,0 +1,26 @@
+# DBRX MoE
+
+Currently, for LoRA, only the `q_proj`, `k_proj`, `v_proj`, `out_proj`, and `layer` Linear layers are trainable.
+
+We are using the "converted" base models based on [this issue](https://huggingface.co/databricks/dbrx-instruct/discussions/10)
+where the experts are fused as an `nn.Parameter` rather than an `nn.Linear` layer. However, the implementation
+is still a bit buggy, and attempting to train a LoRA adapter over the `w1`, `w2`, and `v1` layers
+results in the trainer hanging.
+
+
+### FSDP
+We've tested the [`LnL-AI/dbrx-base-converted-v2`](https://huggingface.co/LnL-AI/dbrx-base-converted-v2) model as the base model for FSDP.
+
+The high memory usage seen w/ FSDP is due to FSDP not supporting 8-bit optimizers.
+
+- 16-bit LoRA w/ FSDP
+  - ✅ w/o CPU Offload - 8x80GB uses ~80GiB/gpu
+  - ❌ w/ CPU Offload - the `paged_adamw_8bit` optimizer errors because it is placed on CPU
+- ✅ 8-bit LoRA w/ FSDP
+- ❌ 4-bit QLoRA w/ FSDP - errors w/: `Error an illegal memory access was encountered at line 90 in file /src/csrc/ops.cu`
+- ✅ bf16 full finetune w/ FSDP, freezing all but the first 8 layers (8x80GB uses ~78GiB/gpu)
+
+
+### Deepspeed
+
+WIP
diff --git a/examples/dbrx/fft-ds-zero3.yaml b/examples/dbrx/fft-ds-zero3.yaml
new file mode 100644
index 0000000000..68292707a4
--- /dev/null
+++ b/examples/dbrx/fft-ds-zero3.yaml
@@ -0,0 +1,56 @@
+base_model: LnL-AI/dbrx-base-converted-v2
+trust_remote_code: true
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: tatsu-lab/alpaca
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./out
+
+sequence_len: 512
+sample_packing: false
+pad_to_sequence_len: false
+
+unfrozen_parameters:
+  - transformer.blocks.[0-7].
+ +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 1 +optimizer: paged_adamw_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 10 +evals_per_epoch: +saves_per_epoch: 1 +debug: +weight_decay: 0.0 +deepspeed: deepspeed_configs/zero3_bf16.json diff --git a/examples/llama-2/qlora-fsdp.yml b/examples/llama-2/qlora-fsdp.yml index 30916ed45a..93b3b2a60a 100644 --- a/examples/llama-2/qlora-fsdp.yml +++ b/examples/llama-2/qlora-fsdp.yml @@ -65,12 +65,14 @@ deepspeed: weight_decay: 0.0 fsdp: - full_shard + - auto_wrap fsdp_config: fsdp_limit_all_gathers: true fsdp_sync_module_states: true fsdp_offload_params: true fsdp_use_orig_params: false fsdp_cpu_ram_efficient_loading: true + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer - fsdp_state_dict_type: SHARDED_STATE_DICT + fsdp_state_dict_type: FULL_STATE_DICT special_tokens: diff --git a/examples/mistral/bigstral-ds-zero3.yaml b/examples/mistral/bigstral-ds-zero3.yaml new file mode 100644 index 0000000000..cc0a44b2a4 --- /dev/null +++ b/examples/mistral/bigstral-ds-zero3.yaml @@ -0,0 +1,63 @@ +base_model: mistral-community/Mixtral-8x22B-v0.1 +model_type: AutoModelForCausalLM +tokenizer_type: LlamaTokenizer +trust_remote_code: true + +load_in_8bit: false +load_in_4bit: false +strict: false + +unfrozen_parameters: + - ^lm_head.weight$ + - ^model.embed_tokens.weight$ + - model.layers.4[4-9]+.block_sparse_moe.gate + - model.layers.4[4-9]+.block_sparse_moe.experts + - model.layers.5[0-5]+.block_sparse_moe.gate + - model.layers.5[0-5]+.block_sparse_moe.experts + +model_config: + output_router_logits: true + +datasets: + - path: tatsu-lab/alpaca + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.05 +output_dir: ./out + +sequence_len: 2048 +sample_packing: true +pad_to_sequence_len: true + +gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 3 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0001 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +save_total_limit: 1 +save_steps: +debug: +deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_params.json +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + eos_token: "<|im_end|>" +tokens: + - "<|im_start|>" diff --git a/examples/mistral/mistral-qlora-fsdp.yml b/examples/mistral/mistral-qlora-fsdp.yml new file mode 100644 index 0000000000..71ac1e701f --- /dev/null +++ b/examples/mistral/mistral-qlora-fsdp.yml @@ -0,0 +1,82 @@ +base_model: mistralai/Mixtral-8x7B-v0.1 +model_type: AutoModelForCausalLM +tokenizer_type: LlamaTokenizer +trust_remote_code: true + +load_in_8bit: false +load_in_4bit: true +strict: false + +datasets: + - path: tatsu-lab/alpaca + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.02 +output_dir: ./qlora-out + +model_config: + output_router_logits: true + +adapter: qlora +lora_model_dir: + +sequence_len: 1024 +sample_packing: false +pad_to_sequence_len: false + +lora_r: 32 
+lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: paged_adamw_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +loss_watchdog_threshold: 5.0 +loss_watchdog_patience: 3 + +warmup_steps: 10 +evals_per_epoch: 4 +eval_table_size: +eval_max_new_tokens: 128 +saves_per_epoch: 1 +debug: +weight_decay: 0.0 +fsdp: + - full_shard + - auto_wrap +fsdp_config: + fsdp_limit_all_gathers: true + fsdp_sync_module_states: true + fsdp_offload_params: false + fsdp_use_orig_params: false + fsdp_cpu_ram_efficient_loading: false + fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP +special_tokens: diff --git a/examples/mistral/mixtral-8x22b-qlora-fsdp.yml b/examples/mistral/mixtral-8x22b-qlora-fsdp.yml new file mode 100644 index 0000000000..ac80a2a756 --- /dev/null +++ b/examples/mistral/mixtral-8x22b-qlora-fsdp.yml @@ -0,0 +1,81 @@ +base_model: mistral-community/Mixtral-8x22B-v0.1 +model_type: AutoModelForCausalLM +tokenizer_type: LlamaTokenizer + +load_in_8bit: false +load_in_4bit: true +strict: false + +datasets: + - path: tatsu-lab/alpaca + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.02 +output_dir: ./qlora-out + +model_config: + output_router_logits: true + +adapter: qlora +lora_model_dir: + +sequence_len: 1024 +sample_packing: false +pad_to_sequence_len: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_torch +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: true + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +loss_watchdog_threshold: 5.0 +loss_watchdog_patience: 3 + +warmup_steps: 10 +evals_per_epoch: 4 +eval_table_size: +eval_max_new_tokens: 128 +saves_per_epoch: 1 +debug: +weight_decay: 0.0 +fsdp: + - full_shard + - auto_wrap +fsdp_config: + fsdp_limit_all_gathers: true + fsdp_sync_module_states: true + fsdp_offload_params: true + fsdp_use_orig_params: false + fsdp_cpu_ram_efficient_loading: true + fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP +special_tokens: diff --git a/examples/mistral/mixtral-qlora-fsdp.yml b/examples/mistral/mixtral-qlora-fsdp.yml index 32db7073b7..b6a07ae51c 100644 --- a/examples/mistral/mixtral-qlora-fsdp.yml +++ b/examples/mistral/mixtral-qlora-fsdp.yml @@ -39,7 +39,7 @@ wandb_log_model: gradient_accumulation_steps: 4 micro_batch_size: 2 num_epochs: 1 -optimizer: paged_adamw_8bit +optimizer: adamw_torch lr_scheduler: cosine learning_rate: 0.0002 @@ -47,7 +47,7 @@ train_on_inputs: false group_by_length: false bf16: auto fp16: -tf32: false +tf32: true gradient_checkpointing: true early_stopping_patience: @@ -69,6 +69,17 @@ debug: 
 weight_decay: 0.0
 fsdp:
   - full_shard
+  - auto_wrap
 fsdp_config:
+  fsdp_limit_all_gathers: true
+  fsdp_sync_module_states: true
+  fsdp_offload_params: true
+  fsdp_use_orig_params: false
+  fsdp_cpu_ram_efficient_loading: true
   fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
+  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_sharding_strategy: FULL_SHARD
+  fsdp_forward_prefetch: false
+  fsdp_backward_prefetch: BACKWARD_PRE
 special_tokens:
diff --git a/requirements.txt b/requirements.txt
index 785ede535e..f707946a02 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -41,3 +41,4 @@ gcsfs
 
 trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
 zstandard==0.22.0
+fastcore
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 35318b836d..900dcb7887 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -918,10 +918,6 @@ def get_callbacks(self):
         ):
             callbacks.append(SaveBetterTransformerModelCallback())
 
-        if self.cfg.use_wandb:
-            callbacks.append(
-                SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
-            )
         if self.cfg.use_mlflow and is_mlflow_available():
             from axolotl.utils.callbacks.mlflow_ import (
                 SaveAxolotlConfigtoMlflowCallback,
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index b6cd24672e..01e07640f9 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -9,6 +9,7 @@
 import torch
 import transformers.modelcard
+from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import Dataset
 from peft import PeftModel
@@ -81,6 +82,8 @@ def train(
     if cfg.adapter:
         msg += " and peft_config..."
     LOG.debug(msg)
+    # we wait until the last possible moment to set up the Accelerator
+    Accelerator()
     model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
     model.generation_config.do_sample = True
 
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 20887dccf8..0fbed08ca3 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -259,6 +259,7 @@ class ModelInputConfig(BaseModel):
 
     base_model: str
     base_model_config: Optional[str] = None
+    cls_model_config: Optional[str] = None
     tokenizer_config: Optional[str] = None
     tokenizer_use_fast: Optional[bool] = None
     tokenizer_legacy: Optional[bool] = None
@@ -971,9 +972,16 @@ def check_val_w_test_datasets(cls, data):
 
     @model_validator(mode="before")
     @classmethod
-    def check_fsdp_w_8bit_optimizer(cls, data):
-        if data.get("fsdp") and "bnb" in data.get("optimizer", ""):
-            raise ValueError(f"FSDP not compatible with {data.get('optimizer')}")
+    def check_fsdp_offload_w_8bit_optimizer(cls, data):
+        if (
+            data.get("fsdp")
+            and "8bit" in data.get("optimizer", "")
+            and data.get("fsdp_config")
+            and data["fsdp_config"].get("fsdp_offload_params")
+        ):
+            raise ValueError(
+                f"FSDP Offload not compatible with {data.get('optimizer')}"
+            )
         return data
 
     @model_validator(mode="before")
diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
index 313dd24e8c..ecb1bcc9ec 100644
--- a/src/axolotl/utils/distributed.py
+++ b/src/axolotl/utils/distributed.py
@@ -4,27 +4,25 @@
 import os
 import pickle  # nosec
 from contextlib import contextmanager
+from datetime import timedelta
 
 import torch
 import torch.distributed as dist
-from accelerate import Accelerator
+from 
accelerate import PartialState -accelerate = None # pylint: disable=invalid-name - - -def load_accelerate(): - global accelerate # pylint: disable=global-statement - accelerate = Accelerator() +distributed_state = None # pylint: disable=invalid-name def is_distributed(): """ Check if distributed training is initialized. """ - global accelerate # pylint: disable=global-statement - if not accelerate: - accelerate = Accelerator() - return dist.is_available() and dist.is_initialized() + global distributed_state # pylint: disable=global-statement + if not distributed_state: + timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800)) + distributed_state = PartialState(timeout=timedelta(seconds=timeout)) + + return distributed_state.use_distributed and distributed_state.initialized def barrier(): diff --git a/src/axolotl/utils/model_shard_quant.py b/src/axolotl/utils/model_shard_quant.py new file mode 100644 index 0000000000..65f23b9e0f --- /dev/null +++ b/src/axolotl/utils/model_shard_quant.py @@ -0,0 +1,259 @@ +""" +module to handle loading model on cpu/meta device for FSDP +""" +import os +import time +from typing import List, Optional, Type, Union + +import safetensors +import torch +from accelerate import init_empty_weights +from bitsandbytes.nn import Linear4bit, Params4bit +from fastcore.parallel import parallel +from torch import Tensor, nn +from tqdm import tqdm +from transformers import AutoModelForCausalLM +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub + + +def _replace_linear( + model: nn.Module, + linear_replacement: Type[nn.Module], + quant_config: Union[dict, None] = None, + skip_modules=None, + **kwargs, +): + """ + Replace linear modules with a new Linear module. + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. + linear_replacement (`torch.nn.Module`): + The linear module that replaces the old one. Only expects standard arguments. + If other arguments need to be passed, use a lambda. + skip_modules (`List[str]`, *optional*, defaults to `lm_head`): + List of modules names not to convert. Defaults to `lm_head`. + """ + if skip_modules is None: + skip_modules = ["lm_head"] + for name, module in model.named_children(): + if len(list(module.children())) > 0: + _replace_linear( + module, linear_replacement, quant_config, skip_modules, **kwargs + ) + + if isinstance(module, torch.nn.Linear) and name not in skip_modules: + if issubclass(linear_replacement, Linear4bit): + model._modules[ # pylint: disable=protected-access + name + ] = linear_replacement( + module.in_features, + module.out_features, + module.bias is not None, + **kwargs, + ) + else: + raise ValueError( + f"Unsupported linear replacement: {type(linear_replacement)}" + ) + return model + + +def load_and_quantize( + module: nn.Module, + name: str, + value: Tensor, + device: torch.device = None, + dtype: torch.dtype = None, + skip_names: Optional[List[str]] = None, + to_cpu: bool = False, + to_meta: bool = False, + verbose: bool = False, + quant_method: str = "bnb", +): + """ + Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`. + + Quantizes `Params4bit` on `device` then places on "cpu" if to_cpu=True or "meta" if to_meta=True. 
+ """ + + if not skip_names: + skip_names = [] + + def place_on_device(value): + if to_meta: + device = "meta" + elif to_cpu: + device = "cpu" + return value.to(device=device, dtype=dtype) + + if any(skip_name in name for skip_name in skip_names): + if verbose: + print(f"Skipping {name} because it is in skip_names") + return + + module_key, _, value_key = name.rpartition(".") + try: + submodule = module.get_submodule(module_key) + except AttributeError as exc: + print(f"Module {module_key} not found:\n{exc}") + return + + try: + if quant_method == "bnb": + param = submodule.get_parameter(value_key) + if isinstance(param, Params4bit): + # With `sync_module_states=True`, a meta device Params4bit needs to be the same + # shape as the quantized Params4bit with an initialized quant_state. However, + # FSDP only syncs parameters and buffers, so the quant_state isn't copied. This + # workaround quantizes Params4bit to initialize quant_state on all ranks, then + # replaces Params4bit's data with a meta tensor to free memory on non-rank 0. + value = type(param)( + value.to(device=device, dtype=dtype).data, **param.__dict__ + ).cuda(device) + if to_meta: + value = type(param)(value.data.to("meta"), **value.__dict__) + elif to_cpu: + value = type(param)(value.data.to("cpu"), **value.__dict__) + else: + value = type(param)(place_on_device(value).data) + + except AttributeError: + # it's a buffer + value = place_on_device(value) + + setattr(submodule, value_key, value) + + +def n_loading_workers(quant_method: str, param_count: float): + devprops = torch.cuda.get_device_properties(torch.cuda.current_device()) + left = int(os.cpu_count() / torch.cuda.device_count()) + model_params_b = 70 + right = int( + (4 if quant_method == "hqq" else 8) + * (devprops.total_memory / 1e9 / 40) + * (model_params_b / (param_count / 1e9)) + ) + return min(left, right) + + +def load_sharded_model( + model_name, + model_config, + cfg, + torch_dtype=torch.bfloat16, + low_memory=True, +): + if (low_memory and cfg.local_rank == 0) or not low_memory: + model = AutoModelForCausalLM.from_pretrained( + model_name, + use_cache=False, + torch_dtype=torch.float32, + _attn_implementation=model_config._attn_implementation, # pylint: disable=protected-access + trust_remote_code=cfg.trust_remote_code, + ) + dtype = torch_dtype if not cfg.float32 else None + model.to(dtype=dtype, device="cpu" if low_memory else cfg.local_rank) + else: + with init_empty_weights(): + model = AutoModelForCausalLM.from_config( + model_config, + torch_dtype=torch_dtype, + trust_remote_code=cfg.trust_remote_code, + ) + return model + + +def load_sharded_model_quant( + model_name, + model_config, + cfg, + compute_dtype=torch.bfloat16, + quant_storage=torch.float32, + low_memory=True, + verbose=False, + loading_workers=2, +): + with init_empty_weights(): + model = AutoModelForCausalLM.from_config( + model_config, + trust_remote_code=cfg.trust_remote_code, + ) + if hasattr(model, "transformer"): + model.transformer = _replace_linear( + model.transformer, + Linear4bit, + compute_dtype=compute_dtype, + quant_type="nf4", + quant_storage=quant_storage, + ) + else: + # this is the more common case with HF transformers + model.model = _replace_linear( + model.model, + Linear4bit, + compute_dtype=compute_dtype, + quant_type="nf4", + quant_storage=quant_storage, + ) + model.is_loaded_in_4bit = True + + # Grab the safetensors files that hold the weights + try: + idx = hub.cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME) + files, _ = 
hub.get_checkpoint_shard_files(model_name, idx) + except OSError: + try: + # This means the model doesn't have a model.safetensors.index.json because it is not sharded + files = [] + files.append(hub.cached_file(model_name, SAFE_WEIGHTS_NAME)) + except OSError as exc: + # This means the model probably doesn't have a safetensors file + raise exc + + # Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly + # and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage + def load_and_quantize_parallel(name_param, model, **kwargs): + name, param = name_param + load_and_quantize(model, name, param, **kwargs) + + quant_method = "bnb" + param_count = sum((p.numel() for n, p in model.named_parameters())) + + n_workers = ( + n_loading_workers(quant_method, param_count) + if loading_workers == -1 + else loading_workers + ) + if cfg.local_rank == 0 and verbose: + print(f"Using n_workers: {n_workers} for loading") + + start = time.time() + for filename in tqdm( + files, + desc="Loading & Quantizing Model Shards", + disable=cfg.local_rank != 0, + position=0, + ): + weights = safetensors.torch.load_file(filename) + parallel( + load_and_quantize_parallel, + iter(weights.items()), + n_workers=n_workers, + threadpool=True, + model=model, + dtype=quant_storage, + device=cfg.local_rank, + skip_names=[], + to_cpu=(low_memory and cfg.local_rank == 0), + to_meta=(low_memory and cfg.local_rank != 0), + verbose=verbose, + quant_method=quant_method, + ) + + if cfg.local_rank == 0 and verbose: + print(f"Loaded model weights in {time.time()-start:.3f} seconds") + # cleanup any extra memory usage from parallel loading + torch.cuda.empty_cache() + + return model diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 88bf50041b..0b15850518 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -45,10 +45,35 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import zero_only from axolotl.utils.lora_embeddings import get_linear_embedding_layers +from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant LOG = logging.getLogger("axolotl") +# copied from accelerator.FullyShardedDataParallelPlugin +def get_module_class_from_name(module, name): + """ + Gets a class from a module by its name. + + Args: + module (`torch.nn.Module`): The module to get the class from. + name (`str`): The name of the class. 
+ """ + modules_children = list(module.children()) + if module.__class__.__name__ == name: + return module.__class__ + + if len(modules_children) == 0: + return None + + for child_module in modules_children: + module_class = get_module_class_from_name(child_module, name) + if module_class is not None: + return module_class + + return None + + def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]): quant_config_exists = ( hasattr(model_config, "quantization_config") @@ -459,7 +484,7 @@ def load_model( "bnb_4bit_quant_type": "nf4", "bnb_4bit_quant_storage": torch.bfloat16, } - if not cfg.deepspeed: + if cfg.model_config_type in ["jamba", "qwen2_moe"] and not cfg.deepspeed: # for some reason, this causes the loss to be off by an order of magnitude # but deepspeed needs this still in bfloat16 bnb_config["bnb_4bit_quant_storage"] = torch.float32 @@ -470,6 +495,13 @@ def load_model( model_kwargs["quantization_config"] = BitsAndBytesConfig( **bnb_config, ) + elif cfg.adapter == "lora" and cfg.load_in_8bit: + bnb_config = { + "load_in_8bit": True, + } + model_kwargs["quantization_config"] = BitsAndBytesConfig( + **bnb_config, + ) if cfg.load_in_8bit and cfg.adapter is not None: model_kwargs["load_in_8bit"] = True @@ -517,7 +549,31 @@ def load_model( qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora" try: + skip_move_to_device = False if ( + cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading + ) and not qlora_fsdp: + model = load_sharded_model( + base_model, + model_config, + cfg, + torch_dtype=cfg.torch_dtype, + ) + skip_move_to_device = True + elif ( + qlora_fsdp + and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading + and cfg.model_config_type == "dbrx" + ): + quant_storage = cfg.torch_dtype + model = load_sharded_model_quant( + base_model, + model_config, + cfg, + quant_storage=quant_storage, + ) + skip_move_to_device = True + elif ( model_config.model_type == "llama" and not cfg.trust_remote_code and not cfg.gptq @@ -597,6 +653,11 @@ def load_model( **model_kwargs, ) else: + if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: + skip_move_to_device = True + if "device_map" in model_kwargs: + del model_kwargs["device_map"] + model = AutoModelForCausalLM.from_pretrained( base_model, config=model_config, @@ -670,13 +731,17 @@ def load_model( needs_fa2_dtype = cfg.adapter or cfg.fsdp skip_prepare_model_for_kbit_training = False - if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled(): + if is_deepspeed_zero3_enabled(): from deepspeed.utils import ( # pylint: disable=no-name-in-module set_z3_leaf_modules, ) - from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock - set_z3_leaf_modules(model, [MixtralSparseMoeBlock]) + if cfg.model_config_type == "mixtral": + moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock") + set_z3_leaf_modules(model, [moe_block]) + elif cfg.model_config_type == "dbrx": + moe_block = get_module_class_from_name(model, "DbrxFFN") + set_z3_leaf_modules(model, [moe_block]) if cfg.model_config_type == "qwen" and cfg.adapter == "lora": # Qwen doesn't play nicely with LoRA if this is enabled @@ -686,7 +751,8 @@ def load_model( if cfg.adapter == "lora" and loftq_bits: skip_prepare_model_for_kbit_training = True - if qlora_fsdp: + if qlora_fsdp or (cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading): + # make sure everything is in the same dtype skip_prepare_model_for_kbit_training = True if cfg.adapter in ["lora", "qlora"]: @@ -727,7 +793,7 @@ def load_model( cfg.ddp and 
not load_in_8bit
         and not (cfg.rl and cfg.load_in_4bit)
-        and not qlora_fsdp
+        and not skip_move_to_device
     ):  # TODO revalidate this conditional
         model.to(f"cuda:{cfg.local_rank}")
 
@@ -883,7 +949,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
 
     rank = int(os.environ.get("LOCAL_RANK", 0))
 
-    if cfg.fsdp and cfg.adapter == "qlora" and rank != 0:
+    if (
+        cfg.fsdp
+        and cfg.adapter
+        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+        and rank != 0
+    ):
         setup_quantized_meta_for_peft(model)
 
     if cfg.lora_model_dir:
@@ -908,7 +979,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
             LOG.warning(
                 "Exception caught during model.print_trainable_parameters(): %s", exc
             )
-    elif cfg.fsdp and cfg.adapter == "qlora":
+    elif (
+        cfg.fsdp
+        and cfg.adapter
+        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+        and rank != 0
+    ):
         setup_quantized_peft_meta_for_training(model)
 
     return model, lora_config
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 6625080755..2a8ed216d1 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -306,6 +306,8 @@ def calc_sample_packing_eff_est(estimates: List[float]):
 
 def setup_fsdp_envs(cfg):
     os.environ["ACCELERATE_USE_FSDP"] = "true"
+    if cfg.fsdp_config.fsdp_activation_checkpointing:
+        os.environ["FSDP_ACTIVATION_CHECKPOINTING"] = "true"
     if cfg.fsdp_config.fsdp_offload_params:
         os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
     if cfg.fsdp_config.fsdp_sync_module_states:
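
A minimal sketch of the rank-dependent placement rule used by the new low-memory loading path in src/axolotl/utils/model_shard_quant.py: weights are quantized as each shard is read, rank 0 keeps the real values on CPU, and every other rank holds shape-only meta tensors, relying on the example configs' fsdp_sync_module_states: true so the real values are broadcast when FSDP wraps the model. The helper names below (placement_for_rank, place) are illustrative only, not functions from this patch; the sketch assumes nothing beyond PyTorch.

    import torch

    def placement_for_rank(local_rank: int, low_memory: bool = True) -> dict:
        # Mirrors the kwargs computed in load_sharded_model_quant():
        #   to_cpu=(low_memory and cfg.local_rank == 0)
        #   to_meta=(low_memory and cfg.local_rank != 0)
        return {
            "to_cpu": low_memory and local_rank == 0,
            "to_meta": low_memory and local_rank != 0,
        }

    def place(value: torch.Tensor, to_cpu: bool, to_meta: bool) -> torch.Tensor:
        # Same placement idea as the place_on_device helper in load_and_quantize().
        if to_meta:
            return value.to(device="meta")  # shape/dtype only; frees memory on non-zero ranks
        if to_cpu:
            return value.to(device="cpu")  # rank 0 keeps the full weights off the GPU
        return value

    weight = torch.randn(4, 4)
    print(place(weight, **placement_for_rank(local_rank=0)).device)  # cpu
    print(place(weight, **placement_for_rank(local_rank=1)).device)  # meta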