add bnb
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim committed May 30, 2024
1 parent 3f03ce4 commit 97f013c
Showing 8 changed files with 88 additions and 18 deletions.
2 changes: 1 addition & 1 deletion plugins/fused-ops-and-kernels/README.md
@@ -41,7 +41,7 @@ Notes on the extraction of code from [unsloth](https://github.com/unslothai/unsloth)
Path | Description | Extracted From | Modifications | Date
--|--|--|--|--
[fused_ops/unsloth_lora](./src/fms_acceleration_foak/fused_ops/unsloth_lora) | QLoRA fast dequant, activation kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | | 28 Jan 2024
[fused_ops/unsloth_lora/bnb](./src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb) | BNB fast lora | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | | 28 Jan 2024
[fused_ops/unsloth_lora/bnb](./src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb) | BNB fast lora | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `fast_lora.py` | 28 Jan 2024
[fused_ops/unsloth_lora/gptq](./src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq) | GPTQ fast dequant (triton_v2) | `jeromeku/main` @ [2839d39](https://github.com/jeromeku/unsloth/commit/2839d390ef3bb318904289bfb9a7751a782c4e44) | `fast_lora.py`<br>`triton/layers.py` | 6 Feb 2024
[kernels/unsloth](./src/fms_acceleration_foak/kernels/unsloth) | Fast RMS, RoPE, CrossEnt kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `cross_entropy_loss.py`<br>`rms_layernorm.py` | 28 Jan 2024
@@ -83,8 +83,7 @@ def __init__(self, configurations: Dict[str, Dict]):
        self._base_layer = self._check_config_and_maybe_check_values(
            key="peft.quantization.fused_ops_and_kernels.base_layer",
            values=[
                "auto_gptq",
                # "bitsandbytes" # enable later when we have BNB implemented
                "auto_gptq", "bitsandbytes"
            ],
        )

@@ -394,3 +394,10 @@ def apply_lora_o(self, X):
    O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
    return O
pass

# added by [email protected]
# this will be patchable on the actual module
def apply_lora_o_v2(self, X):
    OW, OW_quant, OA, OB, OS = get_lora_parameters(self)
    O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
    return O
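The `_v2` variant calls `get_lora_parameters(self)` directly, so it can be bound onto the projection module itself, as the comment notes. A minimal sketch of how such a binding could look; the `patch_forward` helper and the module path in the usage comment are illustrative assumptions, not the plugin's actual patching mechanism:

```python
# Sketch (assumption): bind a module-level LoRA forward such as apply_lora_o_v2
# onto a single layer instance, so that layer(X) dispatches to the fused op.
from types import MethodType

def patch_forward(layer, fused_fn):
    # fused_fn has signature (self, X); MethodType binds `layer` as `self`
    layer.forward = MethodType(fused_fn, layer)
    return layer

# hypothetical usage:
# patch_forward(model.model.layers[0].self_attn.o_proj, apply_lora_o_v2)
```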
@@ -5,20 +5,19 @@
import torch
import os

# NOTE: the default activation is swiglu in both cases
from ..fused_ops.unsloth_lora.bnb.fast_lora import (
    apply_lora_qkv as fused_op_qkv_bnb,
    apply_lora_o_v2 as fused_op_o_bnb,
    apply_lora_mlp_swiglu as fused_op_mlp_bnb,
)

from ..fused_ops.unsloth_lora.gptq.fast_lora import (
    apply_lora_qkv as fused_op_qkv_gptq,
    apply_lora_o_v2 as fused_op_o_gptq,
    apply_lora_mlp as fused_op_mlp_gptq,
)
from .model_patcher import ModelPatcherTrigger
from functools import partial


# simple utility function to guess if its lora layer
def _is_loralayer(module: torch.nn.Module, names: List[str] = None):
    if names is None:
        names = ["lora_A", "lora_B", "base_layer"]
    return all(hasattr(module, x) for x in names)

KEY_QKV = 'qkv'
KEY_O = 'o'
@@ -29,9 +28,22 @@ def _is_loralayer(module: torch.nn.Module, names: List[str] = None):
        KEY_QKV: fused_op_qkv_gptq,
        KEY_O: fused_op_o_gptq,
        KEY_MLP: fused_op_mlp_gptq
    },
    'bitsandbytes': {
        KEY_QKV: fused_op_qkv_bnb,
        KEY_O: fused_op_o_bnb,
        KEY_MLP: fused_op_mlp_bnb
    }
}

from functools import partial

# simple utility function to guess if its lora layer
def _is_loralayer(module: torch.nn.Module, names: List[str] = None):
    if names is None:
        names = ["lora_A", "lora_B", "base_layer"]
    return all(hasattr(module, x) for x in names)

# builds a triple of forward functions, that each can be attached
# on a series of QKV's, where if the first one is called, will call the
# fused op
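As a reading aid, the behaviour described in the comment above can be sketched as follows; the helper name and its exact signature are assumptions for illustration, not the collapsed code in this diff:

```python
# Rough sketch (assumption): three forward closures that share one fused QKV call.
# Whichever projection is called first runs the fused op and caches all three
# outputs; the other two then return their cached result without recomputing.
from functools import partial

def _build_fused_forwards(attn_module, fused_qkv_op):
    cache = {}

    def _forward(name, X):
        if not cache:
            Q, K, V = fused_qkv_op(attn_module, X)  # single fused kernel call
            cache.update({"q": Q, "k": K, "v": V})
        return cache.pop(name)

    # one closure each, attachable as the q_proj / k_proj / v_proj forwards
    return tuple(partial(_forward, name) for name in ("q", "k", "v"))
```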
@@ -99,7 +111,7 @@ def build_lora_fused_ops(
    # get the fused op
    fused_operation = FUSED_OPS[base_type][fused_op]

    # handle the QKVs
    # handle casting issues
    if base_type == "auto_gptq":

        # this is required due to this FSDP fix
@@ -144,11 +156,6 @@ def build_lora_fused_ops(
            is_method_forward=False,
        )

    else:
        raise NotImplementedError(
            f"Cannot build fused ops for base type '{base_type}'."
        )

    if fused_op == KEY_QKV:
        return [
            (ModelPatcherTrigger(check=_is_loralayer, module_name=name), forward)
8 changes: 7 additions & 1 deletion sample-configurations/CONTENTS.yaml
@@ -25,4 +25,10 @@ framework_configs:
    plugins:
      - accelerated-peft
      - fused-ops-and-kernels
    filename: accelerated-peft-autogptq-foak-sample-configuration.yaml
    filename: accelerated-peft-autogptq-foak-sample-configuration.yaml

  - shortname: accelerated-peft-bnb-foak
    plugins:
      - accelerated-peft
      - fused-ops-and-kernels
    filename: accelerated-peft-bnb-nf4-foak-sample-configuration.yaml
@@ -0,0 +1,44 @@
# FMS Acceleration Plugin Configuration.
#
# Each stanza incorporates various configurations for
# different fine-tuning / training tasks.
plugins:
  # PEFT-related acceleration
  peft:

    # quantization-related acceleration
    # e.g., kernels for quantized base weights
    quantization:

      # For loading BitsAndBytes quantized layers
      # to serve as 4bit base-weights for LoRA PEFT-tuning.
      # NOTE: currently AutoGPTQ is not properly integrated into huggingface /
      # bitsandbytes, thus recommended quant_type to be either "nf4"
      # or "fp4".
      # bitsandbytes:
      bitsandbytes:
        quant_type: nf4

        # If True, then no get_peft_model and prepare_model_for_kbit_training
        # will be called.
        no_peft_model: false
      fused_ops_and_kernels:

        # load unsloth optimizations for these 4bit base layer weights.
        # currently only support "auto_gptq" and "bitsandbytes"
        base_layer: bitsandbytes

        # activate various unsloth optimizations
        # NOTE: currently supports only all-or-nothing.

        # fused kernels for lora linear layers
        fused_lora: true

        # fast loss triton kernels
        fast_loss: true

        # fast rms norm triton kernels
        fast_rsm_layernorm: true

        # fast RoPE embedding triton kernels
        fast_rope_embeddings: true
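For orientation, the dotted key validated by the framework plugin earlier in this commit (`peft.quantization.fused_ops_and_kernels.base_layer`) resolves against this YAML as sketched below; the `get_key` helper and the hard-coded filename are assumptions for illustration only:

```python
# Sketch (assumption): resolve a dotted configuration key against the sample
# YAML above; the plugin uses its own configuration helper for this.
from functools import reduce
import yaml  # pyyaml

def get_key(cfg: dict, dotted_key: str):
    return reduce(lambda d, k: d[k], dotted_key.split("."), cfg)

with open("accelerated-peft-bnb-nf4-foak-sample-configuration.yaml") as f:
    cfg = yaml.safe_load(f)["plugins"]

assert get_key(cfg, "peft.quantization.fused_ops_and_kernels.base_layer") == "bitsandbytes"
assert get_key(cfg, "peft.quantization.bitsandbytes.quant_type") == "nf4"
```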
1 change: 1 addition & 0 deletions scripts/benchmarks/scenarios.yaml
@@ -52,6 +52,7 @@ scenarios:
  - name: accelerated-peft-bnb
    framework_config:
      - accelerated-peft-bnb
      - accelerated-peft-bnb-foak
    arguments:
      fp16: True
      learning_rate: 2e-4
6 changes: 6 additions & 0 deletions scripts/generate_sample_configurations.py
@@ -143,6 +143,7 @@ def read_configuration(path: str) -> Dict:
KEY_BNB_NF4 = "bnb-nf4"
KEY_BNB_NF4_BASELINE = "baseline-bnb-nf4"
KEY_AUTO_GPTQ_FOAK = "auto-gptq-foak"
KEY_BNB_NF4_FOAK = "bnb-nf4-foak"

CONFIGURATIONS = {
    KEY_AUTO_GPTQ: "plugins/accelerated-peft/configs/autogptq.yaml",
@@ -161,6 +162,10 @@ def read_configuration(path: str) -> Dict:
"plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml",
[("peft.quantization.fused_ops_and_kernels.base_layer", "auto_gptq")],
),
KEY_BNB_NF4_FOAK: (
"plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml",
[("peft.quantization.fused_ops_and_kernels.base_layer", "bitsandbytes")],
),
}

# list of (tag, combi) tuples
@@ -173,6 +178,7 @@ def read_configuration(path: str) -> Dict:
("accelerated-peft-bnb-nf4", (KEY_BNB_NF4,)),
("baseline-peft-bnb-nf4", (KEY_BNB_NF4_BASELINE,)),
("accelerated-peft-autogptq-foak", (KEY_AUTO_GPTQ, KEY_AUTO_GPTQ_FOAK)),
("accelerated-peft-bnb-nf4-foak", (KEY_BNB_NF4, KEY_BNB_NF4_FOAK)),
]

# TODO: throw error if merge conflicts
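Regarding the TODO above: a combination such as `(KEY_BNB_NF4, KEY_BNB_NF4_FOAK)` merges several plugin YAMLs into one sample configuration, so conflicting keys should be detected rather than silently overwritten. A rough sketch of such a merge, with illustrative names only:

```python
# Sketch (assumption): recursively merge the configuration fragments of a
# combination, raising on conflicting scalar values instead of overwriting them.
def deep_merge(base: dict, other: dict, path: str = "") -> dict:
    merged = dict(base)
    for key, value in other.items():
        key_path = f"{path}.{key}" if path else key
        if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
            merged[key] = deep_merge(merged[key], value, key_path)
        elif key in merged and merged[key] != value:
            raise ValueError(f"merge conflict at '{key_path}'")
        else:
            merged[key] = value
    return merged
```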
