qlora-fsdp ram efficient loading with hf trainer (#1791)
* fix 405b with lower cpu ram requirements

* make sure to use double quant and only skip output embeddings

* set model attributes

* more fixes for sharded fsdp loading

* update the base model in example to use pre-quantized nf4-bf16 weights

* upstream fixes for qlora+fsdp
winglian authored Jul 30, 2024
1 parent dbf8fb5 commit 3ebf224
Showing 10 changed files with 52 additions and 14 deletions.
1 change: 0 additions & 1 deletion docker/Dockerfile-cloud
@@ -3,7 +3,6 @@ FROM winglian/axolotl:$BASE_TAG

ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"

1 change: 0 additions & 1 deletion docker/Dockerfile-cloud-no-tmux
@@ -3,7 +3,6 @@ FROM winglian/axolotl:$BASE_TAG

ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"

7 changes: 4 additions & 3 deletions examples/llama-3/qlora-fsdp-405b.yaml
@@ -1,4 +1,4 @@
base_model: meta-llama/Meta-Llama-3.1-405B
base_model: hugging-quants/Meta-Llama-3.1-405B-BNB-NF4-BF16
tokenizer_type: AutoTokenizer

load_in_4bit: true
@@ -10,10 +10,11 @@ datasets:
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out/qlora-llama3_1-405b
save_safetensors: true

adapter: qlora

sequence_len: 1024
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true

@@ -25,7 +26,7 @@ lora_target_linear: true

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
num_epochs: 2
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001
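The example now points at a base model that is already quantized: hugging-quants/Meta-Llama-3.1-405B-BNB-NF4-BF16 ships bitsandbytes NF4 weights with bf16 storage, so nothing has to be quantized on the fly from the full-precision 405B checkpoint, which is where most of the CPU RAM pressure came from. A minimal sketch of what that pre-quantized checkpoint looks like from plain transformers (illustrative only, not code from this commit; assumes bitsandbytes is installed):

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "hugging-quants/Meta-Llama-3.1-405B-BNB-NF4-BF16"

# The quantization settings travel inside the checkpoint's config.json; this is
# the quantization_config that models.py (further down) reads off model_config.
config = AutoConfig.from_pretrained(model_id)
print(config.quantization_config)  # 4-bit NF4 with bf16 compute/storage

# from_pretrained can then reuse the stored NF4 weights instead of quantizing
# bf16 shards in CPU RAM first.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
```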
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,9 +1,9 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.11.1
transformers==4.43.3
transformers @ git+https://github.com/huggingface/transformers.git@026a173a64372e9602a16523b8fae9de4b0ff428
tokenizers==0.19.1
bitsandbytes==0.43.1
bitsandbytes==0.43.3
accelerate==0.32.0
deepspeed==0.14.4
pydantic==2.6.3
4 changes: 3 additions & 1 deletion src/axolotl/cli/__init__.py
@@ -40,7 +40,7 @@
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_optim_env
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -382,6 +382,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):

prepare_optim_env(cfg)

prepare_opinionated_env(cfg)

normalize_config(cfg)

normalize_cfg_datasets(cfg)
4 changes: 3 additions & 1 deletion src/axolotl/core/trainer_builder.py
@@ -1243,7 +1243,9 @@ def build(self, total_num_steps):
if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
training_arguments_kwargs["fsdp_config"] = {
k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
}

if self.cfg.adapter == "qlora":
training_arguments_kwargs["qlora"] = True
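One detail worth knowing when reading the hunk above: Python's str.lstrip strips a set of characters rather than a literal prefix, so keys whose remainder begins with f, s, d, p or _ lose extra letters (for example, "fsdp_state_dict_type".lstrip("fsdp_") returns "tate_dict_type"). A prefix-safe version of the same renaming, shown purely for illustration and not as the code this commit ships:

```python
# Illustrative sketch: remove the literal "fsdp_" prefix that axolotl config keys
# carry, instead of relying on lstrip's character-set stripping.
def strip_fsdp_prefix(key: str) -> str:
    prefix = "fsdp_"
    return key[len(prefix):] if key.startswith(prefix) else key

fsdp_config = {"fsdp_state_dict_type": "FULL_STATE_DICT", "fsdp_offload_params": False}
print({strip_fsdp_prefix(k): v for k, v in fsdp_config.items()})
# {'state_dict_type': 'FULL_STATE_DICT', 'offload_params': False}
```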
8 changes: 8 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -235,6 +235,12 @@ class LoraConfig(BaseModel):
peft_use_rslora: Optional[bool] = None
peft_layer_replication: Optional[List[Tuple[int, int]]] = None

qlora_sharded_model_loading: Optional[bool] = Field(
default=False,
metadata={
"help": "load qlora model in sharded format for FSDP using answer.ai technique."
},
)
lora_on_cpu: Optional[bool] = None
gptq: Optional[bool] = None
bnb_config_kwargs: Optional[Dict[str, Any]] = None
@@ -939,6 +945,8 @@ def check_evals(cls, data):
@model_validator(mode="before")
@classmethod
def check_eval_packing(cls, data):
# TODO also should check test_datasets and val_set_size as we can skip
# if there are no eval datasets/splits
if (
data.get("sample_packing")
and data.get("eval_table_size")
18 changes: 18 additions & 0 deletions src/axolotl/utils/model_shard_quant.py
@@ -13,6 +13,7 @@
from torch import Tensor, nn
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from transformers.quantizers import AutoHfQuantizer
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub


@@ -173,6 +174,7 @@ def load_sharded_model_quant(
low_memory=True,
verbose=False,
loading_workers=2,
quantization_config=None,
):
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
@@ -186,15 +188,26 @@
compute_dtype=compute_dtype,
quant_type="nf4",
quant_storage=quant_storage,
compress_statistics=True, # bnb_4bit_use_double_quant
skip_modules=[
"lm_head",
"embed_out",
],
)
else:
# this is the more common case with HF transformers
# TODO can we detect the model arch and dynamically set skip_modules
model.model = _replace_linear(
model.model,
Linear4bit,
compute_dtype=compute_dtype,
quant_type="nf4",
quant_storage=quant_storage,
compress_statistics=True, # bnb_4bit_use_double_quant
skip_modules=[
"lm_head",
"embed_out",
],
)
model.is_loaded_in_4bit = True

@@ -251,6 +264,11 @@ def load_and_quantize_parallel(name_param, model, **kwargs):
quant_method=quant_method,
)

# these attributes are needed to inform transformers/peft of the quantization
model.is_quantized = True
model.quantization_method = "bitsandbytes"
model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config)

if cfg.local_rank == 0 and verbose:
print(f"Loaded model weights in {time.time()-start:.3f} seconds")
# cleanup any extra memory usage from parallel loading
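For orientation: the compress_statistics and skip_modules arguments added to the Linear4bit replacements above are the module-level counterparts of the usual BitsAndBytesConfig fields (double quantization, and leaving the output embedding unquantized), and the is_quantized / quantization_method / hf_quantizer attributes are what transformers and PEFT later consult to treat the model as bitsandbytes-quantized. This loading path is only taken when qlora_sharded_model_loading is enabled (or for dbrx), per the config field and load_model changes elsewhere in this commit. A rough config-level equivalent, as a sketch rather than code from this file:

```python
import torch
from transformers import BitsAndBytesConfig

# Approximately what the Linear4bit kwargs above express at the config level.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",                       # quant_type="nf4"
    bnb_4bit_use_double_quant=True,                  # compress_statistics=True
    bnb_4bit_compute_dtype=torch.bfloat16,           # compute_dtype
    bnb_4bit_quant_storage=torch.bfloat16,           # quant_storage (FSDP wants one flat dtype)
    llm_int8_skip_modules=["lm_head", "embed_out"],  # skip_modules
)
```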
9 changes: 8 additions & 1 deletion src/axolotl/utils/models.py
@@ -624,14 +624,21 @@ def load_model(
elif (
qlora_fsdp
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
and cfg.model_config_type == "dbrx"
and (cfg.model_config_type == "dbrx" or cfg.qlora_sharded_model_loading)
):
quant_storage = cfg.torch_dtype
quantization_config = hasattr(
model_config, "quantization_config"
) and getattr(model_config, "quantization_config")
quantization_config = (
quantization_config or model_kwargs["quantization_config"]
)
model = load_sharded_model_quant(
base_model,
model_config,
cfg,
quant_storage=quant_storage,
quantization_config=quantization_config,
)
skip_move_to_device = True
elif (
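A note on the quantization_config lookup in this hunk: hasattr(...) and getattr(...) evaluates to False when the attribute is absent, so the following or falls back to the BitsAndBytesConfig axolotl already placed in model_kwargs; with the pre-quantized NF4 base model from the example above, the config baked into the checkpoint wins. An equivalent written with a default, purely illustrative and not the commit's code:

```python
from typing import Any, Dict

def resolve_quantization_config(model_config: Any, model_kwargs: Dict[str, Any]) -> Any:
    """Prefer the quantization_config baked into the checkpoint's config;
    otherwise fall back to the one axolotl constructed (illustrative helper)."""
    return (
        getattr(model_config, "quantization_config", None)
        or model_kwargs["quantization_config"]
    )
```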
10 changes: 6 additions & 4 deletions src/axolotl/utils/trainer.py
@@ -393,10 +393,6 @@ def calc_sample_packing_eff_est(estimates: List[float]):
def setup_deepspeed_env(cfg, stage=None):
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
if cfg.bf16:
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
elif cfg.fp16:
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
if stage:
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
if stage == 3:
@@ -444,6 +440,12 @@ def prepare_optim_env(cfg):
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"


def prepare_opinionated_env(cfg):
if cfg.qlora_sharded_model_loading:
# model loading is forked after the tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
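On the new prepare_opinionated_env: the fast-tokenizer backend (the Rust tokenizers library) disables its own thread parallelism, with a warning, if the process forks after the tokenizer has already been used, and the sharded QLoRA loader forks workers after the tokenizer is created. Setting TOKENIZERS_PARALLELISM=false up front sidesteps that. A small standalone sketch of the ordering involved (the process pool is a stand-in for the parallel shard loading, not commit code):

```python
import os
from concurrent.futures import ProcessPoolExecutor

# Must be set before the tokenizer is first used, hence an env-prep step
# rather than something done at model-load time.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def load_shard(name: str) -> str:  # stand-in for per-shard loading work
    return f"loaded {name}"

if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=2) as pool:  # workers are forked here on Linux
        print(list(pool.map(load_shard, ["model-00001", "model-00002"])))
```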
