From 606d2e954fd17999af40e6fb4f712055ca11b2f0 Mon Sep 17 00:00:00 2001 From: Chansung Park Date: Wed, 8 May 2024 22:08:13 +0900 Subject: [PATCH] Add fsdp+qlora support (#160) --- recipes/accelerate_configs/fsdp_qlora.yaml | 25 ++++++++++++++++++++++ scripts/README.md | 4 ++++ setup.py | 4 ++-- src/alignment/configs.py | 3 +++ src/alignment/model_utils.py | 1 + 5 files changed, 35 insertions(+), 2 deletions(-) create mode 100644 recipes/accelerate_configs/fsdp_qlora.yaml diff --git a/recipes/accelerate_configs/fsdp_qlora.yaml b/recipes/accelerate_configs/fsdp_qlora.yaml new file mode 100644 index 00000000..f28a0f10 --- /dev/null +++ b/recipes/accelerate_configs/fsdp_qlora.yaml @@ -0,0 +1,25 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +fsdp_config: + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_backward_prefetch: BACKWARD_PRE + fsdp_cpu_ram_efficient_loading: true + fsdp_forward_prefetch: false + fsdp_offload_params: true + fsdp_sharding_strategy: FULL_SHARD + fsdp_state_dict_type: SHARDED_STATE_DICT + fsdp_sync_module_states: true + fsdp_use_orig_params: false +machine_rank: 0 +main_training_function: main +mixed_precision: 'no' +num_machines: 1 +num_processes: 2 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/scripts/README.md b/scripts/README.md index 01efe528..3860e41b 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -7,6 +7,7 @@ In the handbook, we provide three main ways to align LLMs for chat: - Full fine-tuning on a multi-GPU machine with DeepSpeed ZeRO-3 (tested on an 8 x A100 (80GB) node). - LoRA or QLoRA fine-tuning on a single consumer 24GB GPU (tested on an RTX 4090). - LoRA fine-tuning on a multi-GPU machine with DeepSpeed ZeRO-3 (tested on a 2 x A100s (80GB)). +- QLoRA fine-tuning on multi-GPU machine with FSDP (tested on a 2 x A6000s (48GB)). In practice, we find comparable performance for both full and QLoRA fine-tuning, with the latter having the advantage of producing small adapter weights that are fast to upload and download from the Hugging Face Hub. Here are the general commands to fine-tune your models: @@ -22,6 +23,9 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con # LoRA training with ZeRO-3 on two or more GPUs ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml --num_processes={num_gpus} scripts/run_{task}.py recipes/{model_name}/{task}/config_qlora.yaml --load_in_4bit=false + +# QLoRA training with FSDP on two or more GPUs +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/fsdp+qlora.yaml --num_processes={num_gpus} scripts/run_{task}.py recipes/{model_name}/{task}/config_qlora.yaml --torch_dtype=bfloat16 --bnb_4bit_quant_storage=bfloat16 ``` Here `{task}` refers to the type of training you wish to run. Currently the following tasks are supported: diff --git a/setup.py b/setup.py index 61737130..5ae2312e 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ # * If a dependency is fast-moving (e.g. transformers), pin to the exact version _deps = [ "accelerate>=0.29.2", - "bitsandbytes==0.41.2.post2", + "bitsandbytes>=0.43.0", "black==23.1.0", "datasets>=2.18.0", "deepspeed==0.12.2", @@ -57,7 +57,7 @@ "numpy>=1.24.2", "packaging>=23.0", "parameterized>=0.9.0", - "peft==0.7.1", + "peft>=0.9.0", "protobuf<=3.20.2", # Needed to avoid conflicts with `transformers` "pytest", "safetensors>=0.3.3", diff --git a/src/alignment/configs.py b/src/alignment/configs.py index 9881206d..208be0e2 100644 --- a/src/alignment/configs.py +++ b/src/alignment/configs.py @@ -185,6 +185,9 @@ class ModelArguments: default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"} ) use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"}) + bnb_4bit_quant_storage: Optional[str] = field( + default="uint8", metadata={"help": "storage type to pack the quanitzed 4-bit prarams."} + ) def __post_init__(self): if self.load_in_8bit and self.load_in_4bit: diff --git a/src/alignment/model_utils.py b/src/alignment/model_utils.py index 14cd9cbb..fe1ecadb 100644 --- a/src/alignment/model_utils.py +++ b/src/alignment/model_utils.py @@ -51,6 +51,7 @@ def get_quantization_config(model_args: ModelArguments) -> BitsAndBytesConfig | bnb_4bit_compute_dtype=compute_dtype, bnb_4bit_quant_type=model_args.bnb_4bit_quant_type, bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant, + bnb_4bit_quant_storage=model_args.bnb_4bit_quant_storage, ) elif model_args.load_in_8bit: quantization_config = BitsAndBytesConfig(