Skip to content

Commit

Permalink
strip out hacky qlora-fsdp workarounds now that qlora-fsdp fixes are …
Browse files Browse the repository at this point in the history
…upstreamed (#1428)
  • Loading branch information
winglian authored Mar 21, 2024
1 parent 7d55607 commit 2a1589f
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 323 deletions.
8 changes: 7 additions & 1 deletion examples/llama-2/qlora-fsdp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 4
optimizer: paged_adamw_8bit
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001

Expand Down Expand Up @@ -66,5 +66,11 @@ weight_decay: 0.0
fsdp:
- full_shard
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: SHARDED_STATE_DICT
special_tokens:
7 changes: 3 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.9.0
transformers @ git+https://github.com/huggingface/transformers.git@f6261d7d81edd036fc53bfede65fe91f01a661aa
transformers @ git+https://github.com/huggingface/transformers.git@73a73b415e36f41481369f6129cb4b62bb127a78
tokenizers==0.15.0
bitsandbytes>=0.43.0
accelerate==0.26.1
bitsandbytes==0.43.0
accelerate==0.28.0
deepspeed==0.13.1
pydantic==2.6.3
addict
Expand Down Expand Up @@ -40,4 +40,3 @@ gcsfs
# adlfs

trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda500348226cdc97d90
fastcore>=1.5.29
Empty file.
55 changes: 0 additions & 55 deletions src/axolotl/core/policies/auto_wrap.py

This file was deleted.

52 changes: 5 additions & 47 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import importlib.util
import logging
import math
import os
import sys
from abc import abstractmethod
from collections import defaultdict
Expand All @@ -19,10 +18,7 @@

import torch
import transformers
from accelerate import FullyShardedDataParallelPlugin
from accelerate.utils import str_to_bool
from datasets import Dataset
from torch.distributed.fsdp import MixedPrecision
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
Expand All @@ -35,7 +31,6 @@
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer

from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
Expand Down Expand Up @@ -591,51 +586,14 @@ def push_to_hub(self, *args, **kwargs) -> str:

@wraps(Trainer.create_accelerator_and_postprocess)
def create_accelerator_and_postprocess(self):
rank = int(os.environ.get("LOCAL_RANK", 0))
res = super().create_accelerator_and_postprocess()

if self.args.qlora is False:
return res

# the rest of this method override is specific to fsdp + qlora (for now)
sync_module_states = (
str_to_bool(os.environ.get("FSDP_SYNC_MODULE_STATES", "True")) == 1
)

mp_policy = None
amp = os.environ["ACCELERATE_MIXED_PRECISION"]
if amp == "fp16":
mp_policy = MixedPrecision(
param_dtype=torch.float32,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
)
elif amp == "bf16":
mp_policy = MixedPrecision(
param_dtype=torch.float32,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
)

# If somehow we figure out how we want to parameterize we want to autocast buffers...
# mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32)
# load_param_skip_names = ['inv_freq']

if self.is_fsdp_enabled:
wrapping_policy = get_wrapping_policy_factory(self.args.model_type)
fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=wrapping_policy(),
cpu_offload=False,
use_orig_params=False,
limit_all_gathers=True,
param_init_fn=lambda module: module.to_empty(
device=torch.device("cuda"), recurse=False
)
if (rank != 0 and sync_module_states)
else None,
mixed_precision_policy=mp_policy,
)
self.accelerator.state.fsdp_plugin = fsdp_plugin
if (
"limit_all_gathers" in self.args.fsdp_config
and self.args.fsdp_config["limit_all_gathers"]
):
self.accelerator.state.fsdp_plugin.limit_all_gathers = True

return res

Expand Down
Loading

0 comments on commit 2a1589f

Please sign in to comment.