feat: add liger kernel with fused cross entropy loss (#93)
* initial implementation of fused-linear-loss on llama

Signed-off-by: 1000850000 user <[email protected]>
Signed-off-by: Anh Uong <[email protected]>

* syntax fixes and remove unused code

Signed-off-by: Anh Uong <[email protected]>

* add new num_logits_to_keep arg for llama.forward()

Signed-off-by: Anh Uong <[email protected]>

* add mixtral model patch

Signed-off-by: Anh Uong <[email protected]>

* add mistral and granite model patch

Signed-off-by: Anh Uong <[email protected]>

* add benchmark

Signed-off-by: Anh Uong <[email protected]>

* add new liger benchmarks

Signed-off-by: Anh Uong <[email protected]>

* some fixes

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* revise benches

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* refactor to fused_ops

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* fix fmt + lint

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* update full benches and readme

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* fix fast foak configs

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* docs: update foak readme benchmarks

Signed-off-by: Anh Uong <[email protected]>

---------

Signed-off-by: 1000850000 user <[email protected]>
Signed-off-by: Anh Uong <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Co-authored-by: 1000850000 user <[email protected]>
Co-authored-by: Yu Chin Fabian Lim <[email protected]>
3 people authored Dec 2, 2024
1 parent c70ffe0 commit 733992a
Showing 25 changed files with 1,326 additions and 18 deletions.
2 changes: 1 addition & 1 deletion plugins/framework/src/fms_acceleration/framework_plugin.py
@@ -206,7 +206,7 @@ def _check_config_and_maybe_check_values(
t = list(t.keys())[0] # otherwise take the first value

if t not in values:
if default is None:
if t is not None or default is None:
raise AccelerationPluginConfigError(
f"{self.__class__.__name__}: Value at '{key}' was '{t}'. "
f"Not found in expected set '{values}'."
3 changes: 2 additions & 1 deletion plugins/fused-ops-and-kernels/.isort.cfg
@@ -10,4 +10,5 @@ known_firstparty=
known_localfolder=tuning

# skip code imported from unsloth
skip_glob=**/unsloth*/**
skip_glob=**/unsloth*/**,
**/liger*/**
4 changes: 3 additions & 1 deletion plugins/fused-ops-and-kernels/.pylintrc
@@ -53,7 +53,9 @@ ignore=CVS,protobufs
# format. Because '\\' represents the directory delimiter on Windows systems,
# it can't be used as an escape character.
# NOTE: do not lint code imported from unsloth
ignore-paths=.*fused_ops/unsloth_lora.*,.*kernels/unsloth*
ignore-paths=.*fused_ops/unsloth_lora.*,
.*fused_ops/liger_ce.*,
.*kernels/unsloth*,

# Files or directories matching the regular expression patterns are skipped.
# The regex matches against base names, not paths. The default value ignores
15 changes: 14 additions & 1 deletion plugins/fused-ops-and-kernels/README.md
@@ -79,10 +79,23 @@ It is relatively easy by following an existing template; in what follows we use
)
```
### Running Liger Kernel Benchmarks
Using the [scenarios-liger.yaml](../../scripts/benchmarks/scenarios-liger.yaml), this will run full fine tuning, LoRA PEFT, AutoGPTQ LoRA PEFT, and bits-and-bytes LoRA PEFT with the Triton kernels (Fast RMS, RoPE, CrossEnt) as a baseline, and then run again with the Liger kernel's LigerFusedLinearCrossEntropy (alongside Fast RMS and RoPE) to compare results. It only runs against Mistral and Llama models.
The benchmarks were run separately for each `num_gpu` entry; they can be run together in a single command, but running them separately is more efficient.
```sh
tox -e run-benches -- 1 "4 8 16 32" benchmark_outputs_1 scenarios-liger.yaml none
tox -e run-benches -- 2 "8 16 32 64" benchmark_outputs_2 scenarios-liger.yaml none
tox -e run-benches -- 4 "16 32 64 128" benchmark_outputs_3 scenarios-liger.yaml none
```


## Known Issues

- MixedPrecision `--fp16` or `--bf16` should be used with `fast_lora` (see the sketch after this list).
- `fast_lora` has issues with FSDP V1 when using the `peft` style of FSDP wrapping.
* This is because the adapter's forward functions are bypassed in the fused ops.
* For AutoGPTQ/QLoRA this is addressed by distributing the adapters using DDP so they will be unsharded in time for the fused ops.
- `fast_rope_embeddings` does not work with position_ids. Currently `position_ids` are ignored and could give wrong results.
- `fast_rope_embeddings` does not work with `position_ids`; it seems HF has deprecated passing these ids into the RoPE embedding methods.
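
As a minimal, hedged illustration of the first point above: the sketch below just turns on bf16 mixed precision through the standard Hugging Face `TrainingArguments`. The output directory and batch size are placeholder values, and how a particular training script exposes these flags is an assumption, not something this plugin prescribes.

```python
# Minimal sketch: pair the fused kernels (e.g. fast_lora) with mixed precision.
# TrainingArguments is the standard Hugging Face class; the path and batch size
# below are placeholders.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="benchmark_outputs",     # placeholder path
    bf16=True,                          # fast_lora expects --bf16 (or --fp16)
    per_device_train_batch_size=4,
)
```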
2 changes: 1 addition & 1 deletion plugins/fused-ops-and-kernels/configs/fast_kernels.yaml
@@ -22,4 +22,4 @@ training:
fast_rms_layernorm: True

# fast RoPE embedding triton kernels
fast_rope_embeddings: True
fast_rope_embeddings: True
25 changes: 25 additions & 0 deletions plugins/fused-ops-and-kernels/configs/fast_kernels_liger.yaml
@@ -0,0 +1,25 @@
training:

fused_ops_and_kernels:

# if under the training stanza, then putting
# base_layer and fused_lora here would be a misnomer
# - these should be under peft.quantized
# However, if they are specified, they will still
# be read. This is useful in use cases where
# the yaml is system generated and not shown
# to a user.

# activate various unsloth optimizations
# there are two versions of the plugin
# - the FastKernel version supports individual kernels
# - the FastQuantized version is all-or-nothing

# fast loss triton kernels
fast_loss: fused_ce_liger

# fast rms norm triton kernels
fast_rms_layernorm: True

# fast RoPE embedding triton kernels
fast_rope_embeddings: True
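
For reference, here is a small, illustrative Python sketch of how these knobs could be read and sanity-checked before they reach the plugin; the file path is hypothetical and the accepted values simply mirror the plugin `__init__` shown further down in this diff.

```python
# Illustrative only: load the YAML above and check the values the plugin accepts
# (False/True for the kernels, plus "fused_ce_liger" for fast_loss).
import yaml

with open("fast_kernels_liger.yaml") as f:   # hypothetical local path
    cfg = yaml.safe_load(f)["training"]["fused_ops_and_kernels"]

assert cfg["fast_loss"] in (False, True, "fused_ce_liger")
assert cfg["fast_rms_layernorm"] in (False, True)
assert cfg["fast_rope_embeddings"] in (False, True)
```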
@@ -0,0 +1,30 @@
# PEFT-related acceleration
peft:

# quantization-related acceleration
# e.g., kernels for quantized base weights
quantization:

fused_ops_and_kernels:

# load unsloth optimizations for these 4bit base layer weights.
# currently only supports "auto_gptq" and "bitsandbytes"
base_layer: auto_gptq

# activate various unsloth optimizations
# there are two versions of the plugin
# - the FastKernel version supports individual kernels
# - the FastQuantized version is all-or-nothing


# fused kernels for lora linear layers
fused_lora: True

# fast loss triton kernels
fast_loss: fused_ce_liger

# fast rms norm triton kernels
fast_rms_layernorm: True

# fast RoPE embedding triton kernels
fast_rope_embeddings: True
8 changes: 8 additions & 0 deletions plugins/fused-ops-and-kernels/pyproject.toml
@@ -29,3 +29,11 @@ only-include = ["src/fms_acceleration_foak"]

[tool.hatch.build.targets.wheel.sources]
"src" = ""

[tool.black]
force-exclude = '''
/(
.*unsloth.*
| .*liger.*
)/
'''
@@ -73,7 +73,10 @@ def register_foak_model_patch_rules(
# maybe this we should define envvars
FILTER_MAP = {
"fused_lora": {"qkvo", "mlp"},
"fast_loss": "cross-ent",
"fast_loss": {
True: "cross-ent",
"fused_ce_liger": "fused-lce",
},
"fast_rms_layernorm": "rms",
"fast_rope_embeddings": "rope",
}
@@ -109,19 +112,19 @@ def __init__(self, configurations: Dict[str, Dict]):
key="base_layer", values=["auto_gptq", "bitsandbytes"], default="auto_gptq"
)
self.configurations["fused_lora"] = self._check_config_and_maybe_check_values(
key="fused_lora", values=[False, True], default=True
key="fused_lora", values=[False, True], default=False
)
self.configurations["fast_loss"] = self._check_config_and_maybe_check_values(
key="fast_loss", values=[False, True], default=True
key="fast_loss", values=[False, True, "fused_ce_liger"], default=False
)
self.configurations["fast_rms_layernorm"] = (
self._check_config_and_maybe_check_values(
key="fast_rms_layernorm", values=[False, True], default=True
key="fast_rms_layernorm", values=[False, True], default=False
)
)
self.configurations["fast_rope_embeddings"] = (
self._check_config_and_maybe_check_values(
key="fast_rope_embeddings", values=[False, True], default=True
key="fast_rope_embeddings", values=[False, True], default=False
)
)

@@ -162,6 +165,8 @@ def augmentation(

if k in FILTER_MAP and k not in omitted:
ts = FILTER_MAP[k]
if isinstance(ts, dict) and v in ts:
ts = ts[v]
if isinstance(ts, str):
ts = {ts}

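To make the new `FILTER_MAP` behaviour concrete, here is a standalone sketch of the resolution logic added above: a plain string maps a flag directly to one filter tag, while a dict lets the configured value pick the tag, so `fast_loss: True` selects the `cross-ent` rules and `fast_loss: fused_ce_liger` selects the `fused-lce` ones. The `resolve` helper is invented for illustration and is not part of the plugin's API.

```python
# Standalone sketch of the filter resolution introduced in this commit;
# `resolve` is an illustrative helper, not the plugin's actual API.
FILTER_MAP = {
    "fast_loss": {True: "cross-ent", "fused_ce_liger": "fused-lce"},
    "fast_rms_layernorm": "rms",
    "fast_rope_embeddings": "rope",
}

def resolve(key, value):
    ts = FILTER_MAP[key]
    if isinstance(ts, dict) and value in ts:
        ts = ts[value]   # the configured value picks the filter tag
    if isinstance(ts, str):
        ts = {ts}        # normalize to a set of tags
    return ts

print(resolve("fast_loss", "fused_ce_liger"))  # {'fused-lce'}
print(resolve("fast_loss", True))              # {'cross-ent'}
```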
@@ -0,0 +1,27 @@
# Copyright 2024 Byron Hsu & Linkedin team. All rights reserved.
#
# BSD 2-CLAUSE LICENSE
# Copyright 2024 LinkedIn Corporation
# All Rights Reserved.
# Redistribution and use in source and binary forms, with or
# without modification, are permitted provided that the following
# conditions are met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from .fused_linear_cross_entropy_loss import lce_forward
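
As context for the `lce_forward` patch imported above, the following is a conceptual PyTorch sketch of what a fused linear + cross-entropy loss buys: the lm_head projection and the loss are computed chunk by chunk, so the full `[tokens, vocab]` logits tensor is never materialized at once. This is only an illustration of the idea; the actual Liger implementation is a Triton kernel with a memory-efficient backward pass, which this sketch does not capture.

```python
# Conceptual sketch only: not the Liger Triton kernel.
# Cross entropy over the lm_head projection computed in chunks, so that
# only a [chunk, vocab] slice of logits exists at any one time.
import torch
import torch.nn.functional as F

def chunked_linear_cross_entropy(hidden, lm_head_weight, labels, chunk_size=1024):
    # hidden: [num_tokens, hidden_dim], lm_head_weight: [vocab, hidden_dim]
    total_loss = hidden.new_zeros(())
    total_count = 0
    for h, y in zip(hidden.split(chunk_size), labels.split(chunk_size)):
        logits = h @ lm_head_weight.t()          # [chunk, vocab]
        mask = y != -100                         # skip ignored label positions
        if mask.any():
            total_loss = total_loss + F.cross_entropy(
                logits[mask], y[mask], reduction="sum"
            )
            total_count += mask.sum().item()
    return total_loss / max(total_count, 1)
```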