FSDP + QLoRA #1378

Merged · 12 commits · Mar 8, 2024

This PR wires QLoRA training into FSDP: example configs for Llama-2 and Mixtral, an FSDP auto-wrap policy covering PEFT prompt-tuning modules and trainable LoRA layers, and a trainer override that installs a QLoRA-aware FullyShardedDataParallelPlugin.
70 changes: 70 additions & 0 deletions examples/llama-2/qlora-fsdp.yml
@@ -0,0 +1,70 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: yahma/alpaca-cleaned
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./qlora-out

adapter: qlora
lora_model_dir:

sequence_len: 512
sample_packing: false
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 4
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.00001

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
  - full_shard
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
special_tokens:
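For orientation, the `fsdp` and `fsdp_config` keys above are passed through to Hugging Face `TrainingArguments` by the trainer-builder change later in this diff. A minimal sketch of the equivalent direct call, with values copied from this config (the surrounding axolotl plumbing is assumed):

```python
from transformers import TrainingArguments

# Sketch only: how the YAML fsdp keys map onto TrainingArguments.
args = TrainingArguments(
    output_dir="./qlora-out",
    fsdp="full_shard",
    fsdp_config={"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer"},
    gradient_accumulation_steps=4,
    per_device_train_batch_size=4,  # axolotl's micro_batch_size
    learning_rate=1e-5,
    bf16=True,  # `bf16: auto` resolves to this on Ampere-class GPUs
)
```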
74 changes: 74 additions & 0 deletions examples/mistral/mixtral-qlora-fsdp.yml
@@ -0,0 +1,74 @@
base_model: mistralai/Mixtral-8x7B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./qlora-out

model_config:
  output_router_logits: true

adapter: qlora
lora_model_dir:

sequence_len: 1024
sample_packing: false
pad_to_sequence_len: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
  - full_shard
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
special_tokens:
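One Mixtral-specific detail: `model_config.output_router_logits: true` makes the model return its router logits so the MoE load-balancing auxiliary loss is added during training. A sketch of the equivalent direct setup (illustrative, not taken from this PR):

```python
from transformers import AutoConfig, AutoModelForCausalLM

# Enable the router aux loss for Mixtral MoE fine-tuning.
config = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
config.output_router_logits = True
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1", config=config
)
```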
3 changes: 2 additions & 1 deletion requirements.txt
@@ -3,7 +3,7 @@ packaging==23.2
peft==0.9.0
transformers==4.38.2
tokenizers==0.15.0
-bitsandbytes>=0.41.1
+bitsandbytes>=0.43.0
accelerate==0.26.1
deepspeed==0.13.1
pydantic==2.6.3
@@ -40,3 +40,4 @@ gcsfs
# adlfs

trl>=0.7.9
+fastcore>=1.5.29

(The bitsandbytes floor rises to 0.43.0, the first release advertising FSDP-compatible QLoRA support, which this change depends on.)
Empty file.
55 changes: 55 additions & 0 deletions src/axolotl/core/policies/auto_wrap.py
@@ -0,0 +1,55 @@
"""module for building the auto wrap policy for FSDP"""
import functools

from peft import PrefixEncoder, PromptEmbedding, PromptEncoder
from torch.distributed.fsdp.wrap import (
_or_policy,
lambda_auto_wrap_policy,
transformer_auto_wrap_policy,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer

SUPPORTED_AUTO_WRAP_MODEL_TYPES = [
"llama",
"mistral",
"mixtral",
]


def get_wrapping_policy_factory(model_type):
if model_type == "llama":
layer_to_wrap = LlamaDecoderLayer
elif model_type == "mistral":
layer_to_wrap = MistralDecoderLayer
elif model_type == "mixtral":
layer_to_wrap = MixtralDecoderLayer

def get_wrapping_policy():
"""This checks for lora layers (has weight and requires_grad)"""

def lambda_policy_fn(module):
return (
len(list(module.named_children())) == 0
and getattr(module, "weight", None) is not None
and module.weight.requires_grad
)

lambda_policy = functools.partial(
lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn
)
transformer_layer_name = layer_to_wrap
transformer_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=(
PrefixEncoder,
PromptEncoder,
PromptEmbedding,
transformer_layer_name,
),
)
policies = [lambda_policy, transformer_wrap_policy]
return functools.partial(_or_policy, policies=policies)

return get_wrapping_policy
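For orientation, a minimal sketch of how this factory is consumed (mirroring the trainer change below; the plugin kwargs shown are illustrative):

```python
from accelerate import FullyShardedDataParallelPlugin

from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory

# Two-step construction: the factory resolves the decoder-layer class, and the
# returned function builds the functools.partial(_or_policy, ...) callable.
wrapping_policy = get_wrapping_policy_factory("llama")
fsdp_plugin = FullyShardedDataParallelPlugin(
    auto_wrap_policy=wrapping_policy(),
    use_orig_params=False,
    limit_all_gathers=True,
)
```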
62 changes: 62 additions & 0 deletions src/axolotl/core/trainer_builder.py
@@ -8,6 +8,7 @@
import importlib.util
import logging
import math
import os
import sys
from abc import abstractmethod
from dataclasses import dataclass, field
@@ -17,7 +18,10 @@

import torch
import transformers
from accelerate import FullyShardedDataParallelPlugin
from accelerate.utils import str_to_bool
from datasets import Dataset
from torch.distributed.fsdp import MixedPrecision
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
@@ -30,6 +34,7 @@
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer

from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
@@ -191,6 +196,10 @@ class AxolotlTrainingArguments(TrainingArguments):
        default=1e-6,
        metadata={"help": "loraplus learning rate for lora embedding layers."},
    )
    qlora: bool = field(
        default=False,
        metadata={"help": "whether this is a qlora training"},
    )


class AxolotlTrainer(Trainer):
@@ -468,6 +477,56 @@ def push_to_hub(self, *args, **kwargs) -> str:

        return super().push_to_hub(*args, **kwargs)

    @wraps(Trainer.create_accelerator_and_postprocess)
    def create_accelerator_and_postprocess(self):
        rank = int(os.environ.get("LOCAL_RANK", 0))
        res = super().create_accelerator_and_postprocess()

        if self.args.qlora is False:
            return res

        # the rest of this method override is specific to fsdp + qlora (for now)
        sync_module_states = (
            str_to_bool(os.environ.get("FSDP_SYNC_MODULE_STATES", "True")) == 1
        )

        mp_policy = None
        amp = os.environ["ACCELERATE_MIXED_PRECISION"]
        if amp in ("fp16", "bf16"):
            # both mixed-precision modes currently force full-precision FSDP
            # dtypes for the quantized-base + LoRA setup
            mp_policy = MixedPrecision(
                param_dtype=torch.float32,
                reduce_dtype=torch.float32,
                buffer_dtype=torch.float32,
            )

        # If we later figure out how we want to parameterize autocasting buffers:
        # mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32)
        # load_param_skip_names = ['inv_freq']

        if self.is_fsdp_enabled:
            wrapping_policy = get_wrapping_policy_factory(self.args.model_type)
            fsdp_plugin = FullyShardedDataParallelPlugin(
                auto_wrap_policy=wrapping_policy(),
                cpu_offload=False,
                use_orig_params=False,
                limit_all_gathers=True,
                # ranks != 0 materialize empty modules on the GPU and receive
                # rank 0's weights via sync_module_states
                param_init_fn=lambda module: module.to_empty(
                    device=torch.device("cuda"), recurse=False
                )
                if (rank != 0 and sync_module_states)
                else None,
                mixed_precision_policy=mp_policy,
            )
            self.accelerator.state.fsdp_plugin = fsdp_plugin

        return res


class AxolotlMambaTrainer(AxolotlTrainer):
"""
@@ -787,6 +846,9 @@ def build(self, total_num_steps):
        if self.cfg.fsdp_config:
            training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)

        if self.cfg.adapter == "qlora":
            training_arguments_kwargs["qlora"] = True

        # deepspeed
        if self.cfg.deepspeed:
            training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
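For context, the `param_init_fn` lambda in the trainer override above implements the standard low-memory FSDP loading pattern: only rank 0 holds real weights, while the other ranks build the model on the meta device, materialize empty modules on the GPU, and receive rank 0's weights when FSDP broadcasts module states. A standalone sketch of the same idea (model construction and process-group setup are assumed, not shown):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Sketch only: on ranks != 0 the model is assumed to live on the meta device,
# so its parameters occupy no real memory until materialized.
def materialize_module(module: torch.nn.Module) -> torch.nn.Module:
    # Allocate uninitialized CUDA storage; real values arrive when FSDP
    # broadcasts rank 0's weights (sync_module_states=True).
    return module.to_empty(device=torch.device("cuda"), recurse=False)

# fsdp_model = FSDP(
#     model,
#     sync_module_states=True,
#     param_init_fn=materialize_module if rank != 0 else None,
# )
```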
2 changes: 1 addition & 1 deletion src/axolotl/utils/bench.py
@@ -24,9 +24,9 @@ def wrapper(*args, **kwargs):
            or not torch.cuda.is_available()
            or device == "auto"
            or torch.device(device).type == "cpu"
            or torch.device(device).type == "meta"
        ):
            return default_value

        return func(*args, **kwargs)

    return wrapper
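The new `"meta"` case exists because of the loading pattern above: on ranks other than 0, model parameters start on the meta device, where there is no allocated memory to benchmark. A quick illustration (values chosen arbitrarily):

```python
import torch

# Meta tensors carry shape and dtype but no storage, so nothing can be measured.
t = torch.empty(4, 4, device="meta")
print(t.device.type)                         # "meta"
print(torch.device("meta").type == "meta")   # True, so the decorator short-circuits
```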