Add fsdp+qlora support (#160)
deep-diver authored May 8, 2024
1 parent 84f8c92 commit 606d2e9
Showing 5 changed files with 35 additions and 2 deletions.
25 changes: 25 additions & 0 deletions recipes/accelerate_configs/fsdp_qlora.yaml
@@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: false
fsdp_offload_params: true
fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
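
The options above map directly onto PyTorch FSDP behaviour: `FULL_SHARD` shards parameters, gradients and optimizer state across ranks, `fsdp_cpu_ram_efficient_loading` together with `fsdp_sync_module_states` lets a single rank materialize the full checkpoint and broadcast it to the others, and `SHARDED_STATE_DICT` makes each rank checkpoint only its own shard. A minimal sketch (not part of the commit) that loads the config and surfaces those options, assuming it is run from the repository root with PyYAML installed:

```python
# Hypothetical helper: read the new accelerate config and print the FSDP
# options that matter most for QLoRA sharding.
import yaml

with open("recipes/accelerate_configs/fsdp_qlora.yaml") as f:
    cfg = yaml.safe_load(f)

fsdp = cfg["fsdp_config"]
print(fsdp["fsdp_sharding_strategy"])          # FULL_SHARD: shard params, grads and optimizer state
print(fsdp["fsdp_cpu_ram_efficient_loading"])  # True: only one rank loads the full weights into CPU RAM
print(fsdp["fsdp_sync_module_states"])         # True: that rank broadcasts the weights to the other ranks
print(fsdp["fsdp_state_dict_type"])            # SHARDED_STATE_DICT: each rank saves only its shard
print(cfg["num_processes"])                    # 2: matches the 2 x A6000 setup referenced in the README
```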
4 changes: 4 additions & 0 deletions scripts/README.md
@@ -7,6 +7,7 @@ In the handbook, we provide three main ways to align LLMs for chat:
- Full fine-tuning on a multi-GPU machine with DeepSpeed ZeRO-3 (tested on an 8 x A100 (80GB) node).
- LoRA or QLoRA fine-tuning on a single consumer 24GB GPU (tested on an RTX 4090).
- LoRA fine-tuning on a multi-GPU machine with DeepSpeed ZeRO-3 (tested on a 2 x A100s (80GB)).
- QLoRA fine-tuning on a multi-GPU machine with FSDP (tested on 2 x A6000s (48GB)).

In practice, we find comparable performance for both full and QLoRA fine-tuning, with the latter having the advantage of producing small adapter weights that are fast to upload and download from the Hugging Face Hub. Here are the general commands to fine-tune your models:

@@ -22,6 +23,9 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con

# LoRA training with ZeRO-3 on two or more GPUs
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml --num_processes={num_gpus} scripts/run_{task}.py recipes/{model_name}/{task}/config_qlora.yaml --load_in_4bit=false

# QLoRA training with FSDP on two or more GPUs
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/fsdp_qlora.yaml --num_processes={num_gpus} scripts/run_{task}.py recipes/{model_name}/{task}/config_qlora.yaml --torch_dtype=bfloat16 --bnb_4bit_quant_storage=bfloat16
```
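
On the model side, the two extra flags in the FSDP command make every tensor FSDP has to flatten share one dtype: the packed 4-bit weights are stored in `bfloat16` rather than the default `uint8`, matching the dtype of the non-quantized modules. A minimal sketch of the equivalent `transformers` call, assuming any causal-LM checkpoint (the model id below is just a placeholder):

```python
# Rough load-time equivalent of --torch_dtype=bfloat16 --bnb_4bit_quant_storage=bfloat16
# (a sketch, not code from this repository).
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,  # pack the 4-bit params into bf16 so FSDP can shard them
)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",        # placeholder checkpoint
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,         # keep non-quantized modules (embeddings, norms) in bf16 too
)
```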

Here `{task}` refers to the type of training you wish to run. Currently the following tasks are supported:
4 changes: 2 additions & 2 deletions setup.py
@@ -42,7 +42,7 @@
# * If a dependency is fast-moving (e.g. transformers), pin to the exact version
_deps = [
"accelerate>=0.29.2",
"bitsandbytes==0.41.2.post2",
"bitsandbytes>=0.43.0",
"black==23.1.0",
"datasets>=2.18.0",
"deepspeed==0.12.2",
@@ -57,7 +57,7 @@
"numpy>=1.24.2",
"packaging>=23.0",
"parameterized>=0.9.0",
"peft==0.7.1",
"peft>=0.9.0",
"protobuf<=3.20.2", # Needed to avoid conflicts with `transformers`
"pytest",
"safetensors>=0.3.3",
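
The loosened pins are presumably what make the FSDP+QLoRA path available, since 4-bit quant-storage support landed in recent bitsandbytes and PEFT releases. A small optional check (not part of the commit) that the installed versions satisfy the new pins before launching a run:

```python
# Sanity-check the two dependencies whose pins were raised in setup.py.
from importlib.metadata import version
from packaging.version import Version

for pkg, minimum in {"bitsandbytes": "0.43.0", "peft": "0.9.0"}.items():
    installed = Version(version(pkg))
    assert installed >= Version(minimum), f"{pkg} {installed} is older than {minimum}"
    print(f"{pkg} {installed} OK")
```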
3 changes: 3 additions & 0 deletions src/alignment/configs.py
@@ -185,6 +185,9 @@ class ModelArguments:
default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}
)
use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"})
bnb_4bit_quant_storage: Optional[str] = field(
default="uint8", metadata={"help": "storage type to pack the quanitzed 4-bit prarams."}
)

def __post_init__(self):
if self.load_in_8bit and self.load_in_4bit:
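
Because `ModelArguments` is parsed with `HfArgumentParser`, the new field can be set from a recipe YAML or as the `--bnb_4bit_quant_storage=bfloat16` override used in the README command above. A stripped-down sketch of that plumbing, using a hypothetical dataclass that reproduces only the new field:

```python
# Not from the repo: minimal dataclass showing how HfArgumentParser turns the
# CLI override into the value later consumed by get_quantization_config().
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class MiniModelArguments:
    bnb_4bit_quant_storage: Optional[str] = field(
        default="uint8",
        metadata={"help": "storage type to pack the quantized 4-bit params."},
    )


parser = HfArgumentParser(MiniModelArguments)
(args,) = parser.parse_args_into_dataclasses(["--bnb_4bit_quant_storage", "bfloat16"])
print(args.bnb_4bit_quant_storage)  # -> "bfloat16"
```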
1 change: 1 addition & 0 deletions src/alignment/model_utils.py
@@ -51,6 +51,7 @@ def get_quantization_config(model_args: ModelArguments) -> BitsAndBytesConfig |
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
bnb_4bit_quant_storage=model_args.bnb_4bit_quant_storage,
)
elif model_args.load_in_8bit:
quantization_config = BitsAndBytesConfig(