Workaround Low-Mem-Mode Patch for GPTQ-LoRA (#26)
* workaround low-mem patch

* resolve conflicts and define patch function

* resolve conflicts and define patch function

* Apply suggestions from code review

Co-authored-by: Yu Chin Fabian Lim <[email protected]>

* revert hack to avoid low memory bug in HF memory metrics calculation

* reversed formatting

* reverse more formatting

---------

Co-authored-by: Yu Chin Fabian Lim <[email protected]>
achew010 and fabianlim authored May 29, 2024
1 parent c3c1cdd commit 25171a0
Showing 4 changed files with 109 additions and 29 deletions.
@@ -16,14 +16,77 @@
# https://spdx.dev/learn/handling-license-info/

# Standard
from typing import Callable, List
from typing import Any, Callable, List
import importlib

# Third Party
from peft import LoraConfig
from peft.tuners.lora.gptq import QuantLinear as LoraLinearGPTQ
import torch


# This function may be moved after merging
# https://github.com/foundation-model-stack/fms-acceleration/pull/25
def _patch_target_module(
    to_patch: str,
    replace_with: Any,
    target_module: str = None,
):
    to_patch = to_patch.split(".")
    assert len(to_patch) > 1, "must have an object to patch"

    to_patch, obj_name_to_patch = to_patch[:-1], to_patch[-1]
    to_patch = ".".join(to_patch)
    source = importlib.import_module(to_patch)
    original_obj = getattr(source, obj_name_to_patch)
    setattr(source, obj_name_to_patch, replace_with)

    if target_module is not None:
        # reload the target module; it will re-import and bind the patched object
        target_module = importlib.import_module(target_module)
        importlib.reload(target_module)

        # then restore the original object on the source module
        setattr(source, obj_name_to_patch, original_obj)
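
For reference, this helper is applied further down in this change roughly as follows: the replacement is visible only to the reloaded `auto_gptq.modeling._base` module, while `auto_gptq.modeling._utils` itself gets its original attribute back. A minimal sketch, assuming the plugin package is importable as `fms_acceleration_peft` (the in-repo code uses the relative import `.autogptq_utils`):

```
# Illustration only: requires auto_gptq; the absolute package path is an assumption.
from fms_acceleration_peft.autogptq_utils import (
    _patch_target_module,
    make_sure_no_tensor_in_meta_device,
)

# Bind the patched helper into auto_gptq.modeling._utils, reload
# auto_gptq.modeling._base so it picks up the patched name, then restore
# the original attribute on _utils.
_patch_target_module(
    to_patch="auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device",
    replace_with=make_sure_no_tensor_in_meta_device,
    target_module="auto_gptq.modeling._base",
)
```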


def make_sure_no_tensor_in_meta_device(
    model,
    use_triton: bool,
    desc_act: bool,
    group_size: int,
    bits: int,
    disable_exllama: bool,
    disable_exllamav2: bool,
    use_marlin: bool = False,
    use_tritonv2: bool = False,
):
    # Third Party
    # guarded import
    from auto_gptq.utils.import_utils import ( # pylint: disable=import-outside-toplevel,import-error
        dynamically_import_QuantLinear,
    )

    QuantLinear = dynamically_import_QuantLinear(
        use_triton,
        desc_act,
        group_size,
        bits=bits,
        disable_exllama=disable_exllama,
        disable_exllamav2=disable_exllamav2,
        use_marlin=use_marlin,
        use_tritonv2=use_tritonv2,
    )
    # unlike the upstream AutoGPTQ helper, skip modules that have no bias at all,
    # and only materialize a zero CPU bias for QuantLinear modules whose bias
    # was left on the meta device
    for _, m in model.named_modules():
        bias = getattr(m, "bias", None)
        if bias is not None:
            if isinstance(m, QuantLinear) and bias.device == torch.device("meta"):
                m.register_buffer(
                    "bias",
                    torch.zeros((m.outfeatures), dtype=torch.float16, device="cpu"),
                )


def replace_module_peft(self, parent_module, child_name, new_module, old_module):

    # replace the lora linear
@@ -27,6 +27,7 @@
from peft import LoraConfig, prepare_model_for_kbit_training
from peft.tuners.lora.model import LoraModel
from transformers import AutoModelForCausalLM, TrainingArguments
from transformers.modeling_utils import is_fsdp_enabled
import torch
import torch.distributed

@@ -48,14 +49,15 @@ def __init__(self, configurations: Dict[str, Dict]):
)

    def model_loader(self, model_name: str, **kwargs):

        # guarded imports
        # Third Party
        from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig #pylint: disable=import-outside-toplevel,import-error
        from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error

        # Local
        from .autogptq_utils import patch_forward_to_view_attributes_before_call #pylint: disable=import-outside-toplevel
        from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
            patch_forward_to_view_attributes_before_call,
        )

        # Currently we allow only a quantized checkpoint to be loaded, we do not
        # implement the quantization process here.
@@ -84,20 +86,6 @@ def model_loader(self, model_name: str, **kwargs):
        low_cpu_mem_usage = kwargs.get("low_cpu_mem_usage")
        attn_implementation = kwargs.get("attn_implementation")

        if low_cpu_mem_usage:
            # Note that low_cpu_mem_usage is typically set to
            # transformers.modeling_utils.is_fsdp_enabled.
            # e.g.,
            # https://github.com/huggingface/transformers/blob/a98c41798cf6ed99e1ff17e3792d6e06a2ff2ff3/src/transformers/modeling_utils.py#L2989-L2990
            # but not doing that now as AutoGPTQ will call make_sure_no_tensor_in_meta_device
            # https://github.com/AutoGPTQ/AutoGPTQ/blob/ea829c7bbe83561c2b1de26795b6592992373ef7/auto_gptq/modeling/_base.py#L982C17-L982C51
            # which does not properly check if a QuantLayer has a bias set or not,
            # https://github.com/AutoGPTQ/AutoGPTQ/blob/ea829c7bbe83561c2b1de26795b6592992373ef7/auto_gptq/modeling/_utils.py#L514
            raise ValueError(
                "low_cpu_mem_usage set to True. This may raise error if model has no bias, "
                "due to AutoGPTQ bug. Not supporting at the moment."
            )

        # there are some kwargs that won't be passed to AutoModel, so we need
        # to patch them in
        _old_from_config = AutoModelForCausalLM.from_config
@@ -107,14 +95,40 @@ def model_loader(self, model_name: str, **kwargs):
)
AutoModelForCausalLM.from_config = _from_config # patch

        # NOTE: need to set the device map as below as we want to
        # use AutoGPTQ for training.
        # this is a HF method that checks if the low_cpu_mem mode is enabled
        # via HF accelerate
        if is_fsdp_enabled():
            # Local
            from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
                _patch_target_module,
                make_sure_no_tensor_in_meta_device,
            )

            # We patch `make_sure_no_tensor_in_meta_device`
            # from autogptq to avoid errors on models without bias
            _patch_target_module(
                to_patch="auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device",
                replace_with=make_sure_no_tensor_in_meta_device,
                target_module="auto_gptq.modeling._base",
            )
            low_cpu_mem_usage = True

        # NOTE: need to set the device map as below as we want to use AutoGPTQ for training.
        # device_map is for inference only
        # ref: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference
        # Thus we set it as below to effectively disable it.
        device_map = (
            {"": torch.cuda.current_device()} if torch.cuda.is_available() else None
        )
        # https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference
        # For low_cpu_mem_usage = True, we have to set the device map to load checkpoints to "cpu"
        # to avoid GPU consumption before training.
        # This approach diverts the consumption to CPU memory; a better approach would be
        # to load the checkpoints onto the meta device. QLoRA currently takes the same
        # CPU-loading approach and runs into the same issue,
        # see https://github.com/huggingface/transformers/pull/25107#issuecomment-2134833262
        device_map = {
            "": (
                (torch.cuda.current_device() if not low_cpu_mem_usage else "cpu")
                if torch.cuda.is_available()
                else None
            )
        }
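        # Illustratively, assuming a single visible GPU, the map resolves to:
        #   low_cpu_mem_usage = False -> {"": 0}      (load directly onto the GPU)
        #   low_cpu_mem_usage = True  -> {"": "cpu"}  (load on CPU first; FSDP shards it later in trainer.prepare)
        #   no CUDA available         -> {"": None}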

        # currently only enable triton_v2, because the triton kernels are the only ones
        # that have a backward pass (needed for training)
@@ -204,9 +218,11 @@ def augmentation(
        # Third Party
        from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error
        from auto_gptq.utils.peft_utils import GPTQLoraModel, get_gptq_peft_model #pylint: disable=import-outside-toplevel,import-error

        # Local
        from .autogptq_utils import create_new_module_peft, replace_module_peft #pylint: disable=import-outside-toplevel
        from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
            create_new_module_peft,
            replace_module_peft,
        )

        (peft_config,) = modifiable_args # unpack modifiable args

3 changes: 2 additions & 1 deletion scripts/benchmarks/README.md
@@ -164,6 +164,7 @@ We currently compute the memory values in the report by taking the largest of su
For allocated memory value
```
max([
stage0_mem,
stage0_mem + stage1_allocated_delta,
stage0_mem + stage1_allocated_delta + stage2_allocated_delta,
...
@@ -173,13 +174,13 @@ max([
For peak memory value
```
max([
stage0_mem,
stage0_mem + stage1_allocated_delta + stage1_peaked_delta,
stage0_mem + stage1_allocated_delta + stage2_allocated_delta + stage2_peaked_delta,
...
])
```
Notice that we do not include `stage0_mem` alone when computing the max value. This is to avoid misleading comparisons between GPTQ-LoRA and the other variants. GPTQ-LoRA + FSDP currently does not support low-memory mode, as mentioned [here](https://github.com/foundation-model-stack/fms-acceleration/issues/18); its `stage0_mem` value is larger than expected because the model is loaded in full before the trainer is initialized and only sharded afterwards inside `trainer.prepare`. This would skew comparisons against other variants that are loaded in low-memory mode and therefore report a smaller `stage0_mem`. Once low-memory mode is supported for GPTQ-LoRA, we will include `stage0_mem` back in the max computation.
We compare the memory values reported by Nvidia-SMI and Torch in this PR: [Memory Benchmarking](https://github.com/foundation-model-stack/fms-acceleration/pull/14).
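Putting the two formulas together (as updated in this change, with `stage0_mem` included as a candidate), a minimal sketch of how these values could be computed from per-stage deltas; the names are illustrative only, not the actual variables in `benchmark.py`:
```
def summarize_gpu_memory(stage0_mem, stage_deltas):
    # stage_deltas: list of (allocated_delta, peaked_delta), one tuple per trainer stage
    alloc_candidates = [stage0_mem]
    peak_candidates = [stage0_mem]
    running = stage0_mem
    for allocated_delta, peaked_delta in stage_deltas:
        running += allocated_delta
        alloc_candidates.append(running)
        peak_candidates.append(running + peaked_delta)
    return max(alloc_candidates), max(peak_candidates)

# e.g. stage0 = model load, stage1 = trainer init, stage2 = train
alloc_mem, peak_mem = summarize_gpu_memory(10.0, [(2.0, 1.5), (4.0, 3.0)])
# alloc_mem == 16.0, peak_mem == 19.0
```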
4 changes: 2 additions & 2 deletions scripts/benchmarks/benchmark.py
@@ -112,8 +112,8 @@ def extract_gpu_memory_metrics(output_metrics) -> Tuple[float]:
return 0, 0

trainer_stage_order = [
(HF_TRAINER_LOG_GPU_STAGE_BEFORE_INIT, False),
(HF_TRAINER_LOG_GPU_STAGE_INIT, False),
(HF_TRAINER_LOG_GPU_STAGE_BEFORE_INIT, True),
(HF_TRAINER_LOG_GPU_STAGE_INIT, True),
(HF_TRAINER_LOG_GPU_STAGE_TRAIN, True),
]
alloc_running_sum = 0
