diff --git a/examples/dbrx/16bit-lora.yaml b/examples/dbrx/16bit-lora.yaml
index 55754191c6..5e0faa5477 100644
--- a/examples/dbrx/16bit-lora.yaml
+++ b/examples/dbrx/16bit-lora.yaml
@@ -51,7 +51,7 @@ bf16: auto
 fp16:
 tf32: false

-gradient_checkpointing: true
+gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
 gradient_checkpointing_kwargs:
   use_reentrant: false
 early_stopping_patience:
@@ -78,3 +78,4 @@ fsdp_config:
   fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
   fsdp_transformer_layer_cls_to_wrap: DbrxBlock
   fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_activation_checkpointing: true
diff --git a/examples/dbrx/8bit-lora.yaml b/examples/dbrx/8bit-lora.yaml
new file mode 100644
index 0000000000..5ed20c93a7
--- /dev/null
+++ b/examples/dbrx/8bit-lora.yaml
@@ -0,0 +1,81 @@
+base_model: LnL-AI/dbrx-base-converted-v2
+trust_remote_code: true
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: tatsu-lab/alpaca
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./out
+
+sequence_len: 512
+sample_packing: false
+pad_to_sequence_len: false
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+adapter: lora
+lora_model_dir:
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+# w1, w2, & v1 will hang the trainer
+lora_target_modules:
+  - q_proj # attn
+  - k_proj # attn
+  - v_proj # attn
+  - out_proj # attn
+  - layer # router
+#  - w1
+#  - w2
+#  - v1
+
+gradient_accumulation_steps: 1
+micro_batch_size: 1
+num_epochs: 1
+optimizer: paged_adamw_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+evals_per_epoch:
+saves_per_epoch: 1
+debug:
+weight_decay: 0.0
+fsdp:
+  - full_shard
+  - auto_wrap
+fsdp_config:
+  fsdp_limit_all_gathers: true
+  fsdp_sync_module_states: true
+  fsdp_offload_params: false
+  fsdp_use_orig_params: false
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_transformer_layer_cls_to_wrap: DbrxBlock
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_activation_checkpointing: true
diff --git a/examples/dbrx/README.md b/examples/dbrx/README.md
index c0f69d328c..99ff3dd0b7 100644
--- a/examples/dbrx/README.md
+++ b/examples/dbrx/README.md
@@ -16,7 +16,7 @@ The high memory usage seen w/ FSDP is due to FSDP not supporting 8bit optimizers
 - 16-bit LoRA w/ FSDP
   - ✅ w/o CPU Offload - 8x80GB uses ~80GiB/gpu
   - ❌ w/ CPU Offload - `paged_adamw_8bit` optimizer errors from being on cpu
-- ❓ 8-bit LoRA w/ FSDP - WIP, need to handle loading 8-bit quantized weights
+- ✅ 8-bit LoRA w/ FSDP
 - ❌ 4-bit QLoRA w/ FSDP - errors w/: `Error an illegal memory access was encountered at line 90 in file /src/csrc/ops.cu`
 - ✅ bf16 full finetune w/ FSDP, freezing all but first 8 layers (8x80GB uses ~78GiB/gpu)

diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index b6cd24672e..01e07640f9 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -9,6 +9,7 @@

 import torch
 import transformers.modelcard
+from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import Dataset
 from peft import PeftModel
@@ -81,6 +82,8 @@ def train(
     if cfg.adapter:
         msg += " and peft_config..."
     LOG.debug(msg)
+    # we wait until the last possible moment to setup Accelerator
+    Accelerator()
     model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
     model.generation_config.do_sample = True

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 877b289b5c..83f2251570 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -468,6 +468,13 @@ def load_model(
         model_kwargs["quantization_config"] = BitsAndBytesConfig(
             **bnb_config,
         )
+    elif cfg.adapter == "lora" and cfg.load_in_8bit:
+        bnb_config = {
+            "load_in_8bit": True,
+        }
+        model_kwargs["quantization_config"] = BitsAndBytesConfig(
+            **bnb_config,
+        )

     if cfg.load_in_8bit and cfg.adapter is not None:
         model_kwargs["load_in_8bit"] = True
@@ -517,10 +524,8 @@ def load_model(
     try:
         skip_move_to_device = False
         if (
-            cfg.fsdp
-            and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
-            and not qlora_fsdp
-        ):
+            cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+        ) and not qlora_fsdp:
             model = load_sharded_model(
                 base_model,
                 model_config,
@@ -704,7 +709,8 @@ def load_model(
     if cfg.adapter == "lora" and loftq_bits:
         skip_prepare_model_for_kbit_training = True

-    if qlora_fsdp:
+    if qlora_fsdp or (cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading):
+        # make sure everything is in the same dtype
         skip_prepare_model_for_kbit_training = True

     if cfg.adapter in ["lora", "qlora"]:
@@ -750,13 +756,6 @@ def load_model(
         # TODO revaldate this conditional
         model.to(f"cuda:{cfg.local_rank}")

-    if (
-        cfg.fsdp
-        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
-        and cfg.local_rank != 0
-    ):
-        setup_quantized_peft_meta_for_training(model)
-
     if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
         setattr(model, "is_parallelizable", True)
         setattr(model, "model_parallel", True)
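
Note (illustrative sketch, not part of the patch): the new elif branch in load_model() builds an explicit BitsAndBytesConfig for the 8-bit LoRA path, which is what examples/dbrx/8bit-lora.yaml turns on via load_in_8bit: true and adapter: lora. The snippet below is a rough standalone equivalent using the stock transformers/peft APIs with the base model and LoRA settings taken from that YAML; it is an assumption-laden illustration, not axolotl's actual loader or its cfg plumbing.

# Sketch only: load the DBRX base model in 8-bit and attach a LoRA adapter,
# roughly what the new `elif cfg.adapter == "lora" and cfg.load_in_8bit` branch
# plus examples/dbrx/8bit-lora.yaml wire up inside axolotl.
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Equivalent of bnb_config = {"load_in_8bit": True} passed as BitsAndBytesConfig(**bnb_config)
quant_config = BitsAndBytesConfig(load_in_8bit=True)

model = AutoModelForCausalLM.from_pretrained(
    "LnL-AI/dbrx-base-converted-v2",  # base_model from the example YAML
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

# LoRA targets match the example config; w1, w2, and v1 stay commented out
# there because training them hangs the trainer.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "layer"],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

Under the FSDP settings in the example configs, activation checkpointing is handled by fsdp_activation_checkpointing: true, which is why gradient_checkpointing is set to false in both YAML files.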