Disable MLP Fused Ops if Not SwiGLU, Deprecate Fast Quantized Peft Plugin, Update Benchmarks #106

Merged: 7 commits merged on Nov 14, 2024
@@ -235,7 +235,7 @@ def get_callbacks_and_ready_for_train(
# the meta device fix for quantized models is since this transformers version
# or if trl is installed then its only for this version
if _transformers_version >= "4.45" and (
not _trl_installed or (_trl_installed and _trl_version >= "0.12")
not _trl_installed or (_trl_installed and _trl_version >= "0.11.4")
):
# guarded
# NOTE: replace this later with a more specific accelerate version check
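For context, the guard in this hunk enables the meta-device fix only when transformers is at least 4.45 and, when trl is installed, trl meets the minimum version (0.11.4 after this change). Below is a minimal, illustrative sketch of the same gate using `packaging` and `importlib.metadata` for version detection; the helper name and the packaging-based comparison are assumptions made for clarity, not the plugin's actual code, which compares pre-computed version strings.

```python
# Illustrative sketch only: not the plugin's code. It assumes packaging and
# importlib.metadata are acceptable for version detection; the plugin itself
# compares pre-computed version strings.
from importlib.metadata import PackageNotFoundError, version as pkg_version

from packaging.version import Version


def needs_meta_device_fix() -> bool:
    transformers_ok = Version(pkg_version("transformers")) >= Version("4.45")
    try:
        # if trl is installed, it must also meet the minimum version
        trl_ok = Version(pkg_version("trl")) >= Version("0.11.4")
    except PackageNotFoundError:
        # trl not installed: only the transformers floor applies
        trl_ok = True
    return transformers_ok and trl_ok
```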
1 change: 0 additions & 1 deletion plugins/fused-ops-and-kernels/README.md
@@ -13,7 +13,6 @@ This library contains fused operations and custom kernels, to be expanded over t

Plugin | Description | Depends | Loading | Augmentation | Callbacks
--|--|--|--|--|--
[fast_quantized_peft](./src/fms_accelerate_foak/framework_plugin_fast_quantized_peft.py) | LoRA fused ops, fast cross-entropy, fast rms, fast RoPE (**Disabled**) | Contains extracted code | | ✅
[fast_kernels](./src/fms_accelerate_foak/framework_plugin_fast_kernels.py) | Enhanced version of `fast_quantized_peft`, also works for full-FT and non-quant peft | Contains extracted code | | ✅

### Supported DataType Settings
@@ -14,4 +14,3 @@

# Local
from .framework_plugin_fast_kernels import FastKernelsAccelerationPlugin
from .framework_plugin_fast_quantized_peft import FastQuantizedPeftAccelerationPlugin
@@ -19,16 +19,21 @@
from fms_acceleration import AccelerationPlugin, AccelerationPluginConfigError
from peft import LoraConfig
from peft.tuners.lora.layer import LoraLayer
from transformers import TrainingArguments
from transformers import PretrainedConfig, TrainingArguments
import torch

# Local
from .framework_plugin_fast_quantized_peft import lora_adapters_switch_ddp_from_fsdp
from .utils import lora_adapters_switch_ddp_from_fsdp
from .models.utils import filter_mp_rules


# consider rewriting register_foak_model_patch_rules into something
# like this also
def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] = None):
def register_foak_model_patch_rules(
base_type: str,
filter_endswith: Set[str] = None,
config: PretrainedConfig = None,
):

# Third Party
from fms_acceleration.model_patcher import ( # pylint: disable=import-outside-toplevel
@@ -45,20 +50,21 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] =
mixtral,
)

# create model specific rules
rules = [
*gpt_bigcode.get_mp_rules(base_type),
*granite.get_mp_rules(base_type),
*granite.get_mp_rules(base_type, config),
*granitemoe.get_mp_rules(base_type),
*llama.get_mp_rules(base_type),
*mistral.get_mp_rules(base_type),
*llama.get_mp_rules(base_type, config),
*mistral.get_mp_rules(base_type, config),
*mixtral.get_mp_rules(base_type),
]

if filter_endswith is not None:
# filter rules
rules = [
r for r in rules if any(r.rule_id.endswith(x) for x in filter_endswith)
]
# for filtering rules that apply regardless of model arch
# - this would be useful for implementing switches for
# turning off rules that affect all models
if filter_endswith:
rules = filter_mp_rules(rules, filter_endswith)

for _rule in rules:
ModelPatcher.register(_rule)
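The inline suffix filter shown above is replaced by a `filter_mp_rules` helper imported from `.models.utils`. Here is a plausible sketch of that helper, inferred from the list comprehension it replaces and from the `drop=True` call that appears later in `granite.py`; the actual implementation may differ.

```python
# Plausible sketch of filter_mp_rules, inferred from the inline filter it
# replaces and the drop=True usage in granite.py; not the verified source.
from typing import List, Set


def filter_mp_rules(rules: List, filter_endswith: Set[str], drop: bool = False):
    if drop:
        # drop rules whose rule_id ends with any of the given suffixes
        return [
            r for r in rules
            if not any(r.rule_id.endswith(x) for x in filter_endswith)
        ]
    # keep only rules whose rule_id ends with one of the given suffixes
    return [
        r for r in rules
        if any(r.rule_id.endswith(x) for x in filter_endswith)
    ]
```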
@@ -151,18 +157,22 @@ def augmentation(

terms = set()
for k, v in self.configurations.items():
if isinstance(v, bool) and v is False:
continue

if k in FILTER_MAP and k not in omitted:
ts = FILTER_MAP[k]
if isinstance(ts, str):
ts = {ts}
if isinstance(v, bool) and v is False:
continue

terms.update(ts)

# wrapper function to register foak patches
# - the base layer setting below will be ignored in non quantized-lora settings
register_foak_model_patch_rules2(
base_type=self.configurations["base_layer"], filter_endswith=terms
register_foak_model_patch_rules(
base_type=self.configurations["base_layer"],
filter_endswith=terms,
config=model.config,
)
return model, modifiable_args

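To make the term-collection loop in `augmentation` concrete: each configuration key found in `FILTER_MAP` contributes one or more rule-id suffixes to `terms`, string values are promoted to single-element sets, and keys whose value is `False` are skipped. The self-contained rendering below uses placeholder map entries; the real `FILTER_MAP` lives in this plugin module and its keys and suffixes may differ.

```python
# Placeholder FILTER_MAP entries for illustration only; the plugin's actual
# map (and its keys/suffixes) may differ.
EXAMPLE_FILTER_MAP = {
    "fast_loss": "cross-ent",
    "fast_rms_layernorm": "rms",
    "fused_lora": {"qkvo", "mlp"},
}


def collect_terms(configurations: dict, omitted: set) -> set:
    terms = set()
    for k, v in configurations.items():
        if k in EXAMPLE_FILTER_MAP and k not in omitted:
            ts = EXAMPLE_FILTER_MAP[k]
            if isinstance(ts, str):
                ts = {ts}
            # a boolean False disables the corresponding rules
            if isinstance(v, bool) and v is False:
                continue
            terms.update(ts)
    return terms


# -> {'cross-ent', 'qkvo', 'mlp'}; fast_rms_layernorm=False contributes nothing
print(collect_terms(
    {"fast_loss": True, "fast_rms_layernorm": False, "fused_lora": True},
    omitted=set(),
))
```

The collected `terms` are then passed as `filter_endswith` to `register_foak_model_patch_rules`, so rules for disabled features are never registered with the `ModelPatcher`.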

This file was deleted.

@@ -14,6 +14,7 @@

# Standard
from functools import partial
import warnings

# Third Party
from fms_acceleration.model_patcher import (
@@ -22,15 +23,24 @@
combine_functions,
combine_triggers,
)
from transformers import PretrainedConfig

# Local
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
from .utils import (
KEY_MLP,
KEY_O,
KEY_QKV,
build_lora_fused_ops,
filter_mp_rules,
get_hidden_activation_fn_key,
trigger_fused_ops,
)


def get_mp_rules(base_type: str):
def get_mp_rules(base_type: str, config: PretrainedConfig = None):
"""
Function to access all patch rules in this module.
If it is a forward_builder rule with `base_type` in
@@ -47,7 +57,7 @@ def get_mp_rules(base_type: str):
except ImportError:
return []

return [
rules = [
# TODO: have a generic version of this rule
# - do regex on RMSNorm class name
# - check on the tensors required for fast_rms_layernorm
@@ -133,3 +143,15 @@ def get_mp_rules(base_type: str):
),
),
]

# perform model specific filtering
act = get_hidden_activation_fn_key(config)
if config and act != "silu":
warnings.warn(
f"Granite Rules: activation is {act}, "
"thus disabling LoRA fused-op for MLP, since only SwiGLU "
"is supported. This only affects quantized-peft."
)
rules = filter_mp_rules(rules, {"mlp"}, drop=True)

return rules
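The SwiGLU gate above depends on `get_hidden_activation_fn_key`, imported from the models' `utils` module in this PR. A plausible sketch follows, assuming the helper simply reads the activation name off a `transformers` config (most configs expose it as `hidden_act`); the exact attribute handling is an assumption, not the verified implementation.

```python
# Plausible sketch of get_hidden_activation_fn_key; assumes the activation
# name is read off the model config (commonly `hidden_act`). Not the
# verified implementation from models/utils.py.
from typing import Optional

from transformers import PretrainedConfig


def get_hidden_activation_fn_key(config: PretrainedConfig = None) -> Optional[str]:
    if config is None:
        return None
    # some architectures name the field differently
    for attr in ("hidden_act", "hidden_activation"):
        act = getattr(config, attr, None)
        if act is not None:
            return act
    return None
```

With this, a Granite config whose activation is, say, `gelu` would emit the warning and drop the `mlp` fused-op rules via `filter_mp_rules(rules, {"mlp"}, drop=True)`, while the default SwiGLU (`silu`) path keeps them.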