Quickfix: Accelerate YAML and LoRA Fused Ops #92

Merged (11 commits) on Oct 14, 2024
plugins/accelerated-peft/README.md: 2 changes (1 addition, 1 deletion)
@@ -11,7 +11,7 @@ Plugin | Description | Depends | Loading | Augmentation | Callbacks


### Key Points
- fix upcasting (resulting in slowdown) issue for `bnb` plugin, originally discovered by inventors of [Unsloth](https://unsloth.ai/blog/mistral-benchmark).
- fix upcasting (resulting in slowdown) issue for `bnb` plugin, originally discovered by the inventors of [Unsloth](https://unsloth.ai/blog/mistral-benchmark). **NOTE**: we recommend using *mixed precision* with 4-bit quantization for better performance, as per our benchmarks.
- `bnb` properly configured to work with FSDP following [this guide](https://huggingface.co/docs/bitsandbytes/main/en/fsdp_qlora).
- `triton_v2` kernels are not yet properly integrated into huggingface optimum.
- `triton_v2` kernels are [the only 4bit kernels that work for training](https://github.com/AutoGPTQ/AutoGPTQ/issues/633).
@@ -230,6 +230,11 @@ def augmentation(
train_args: TrainingArguments,
modifiable_args: Tuple[LoraConfig],
):
# - when using our prepare peft, we will enforce the mixed precision settings
assert (
train_args.bf16 is True or train_args.fp16 is True
), f"{self.__class__} requires mixed precision argument `--fp16` or `--bf16`"

(peft_config,) = modifiable_args # unpack modifiable args

# some assertions
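For context, a minimal sketch of a configuration that satisfies the mixed-precision assertion added in this hunk, using the standard Hugging Face `TrainingArguments` (the `output_dir` value is only a placeholder):

```python
from transformers import TrainingArguments

# Setting bf16=True (equivalent to passing --bf16) satisfies the assertion above;
# --fp16 would work as well. Leaving both unset raises the AssertionError.
train_args = TrainingArguments(
    output_dir="./results",  # placeholder path
    bf16=True,
)
```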
plugins/fused-ops-and-kernels/README.md: 8 changes (5 additions, 3 deletions)
@@ -13,18 +13,20 @@ This library contains fused operations and custom kernels, to be expanded over time

Plugin | Description | Depends | Loading | Augmentation | Callbacks
--|--|--|--|--|--
[fast_quantized_peft](./src/fms_accelerate_foak/framework_plugin_fast_quantized_peft.py) | LoRA fused ops, fast cross-entropy, fast rms, fast RoPE | Contains extracted code | | ✅
[fast_quantized_peft](./src/fms_accelerate_foak/framework_plugin_fast_quantized_peft.py) | LoRA fused ops, fast cross-entropy, fast rms, fast RoPE (**Disabled**) | Contains extracted code | | ✅
[fast_kernels](./src/fms_accelerate_foak/framework_plugin_fast_kernels.py) | Enhanced version of `fast_quantized_peft`, also works for full-FT and non-quant peft | Contains extracted code | | ✅

### Supported DataType Settings
**Compatibility Matrix with Mixed Precision**
torch_dtype | Mixed Precision | Full-FT-FOAK | PEFT-FOAK | QPEFT-FOAK
-- | -- | -- | -- | --
FLOAT16 | - | ✗ Not Allowed | ✗| ✗
FLOAT16 | - | **Compatible** | **Compatible** | ✗
FLOAT16 | FP16 | ValueError: <br>Attempting to <br>unscale FP16 gradients. <br>[See here](https://github.com/huggingface/peft/blob/main/docs/source/developer_guides/troubleshooting.md) | **Compatible** | **Compatible**
BFLOAT16 | - | ✗ | ✗ | ✗
BFLOAT16 | - | **Compatible** | **Compatible** | ✗
BFLOAT16 | BF16 | **Compatible** | **Compatible** | [Less Performant](https://github.com/foundation-model-stack/fms-acceleration/issues/84)

NOTE: this chart is also a good reference for supported types, even for the non-FOAK case.
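As a reading aid only (not part of the plugin API), the matrix above can be encoded as a simple lookup; the mode names are shorthand for the three FOAK columns:

```python
# Hypothetical helper, assuming exactly the combinations tabulated above.
SUPPORTED = {
    ("float16", None):    {"full-ft", "peft"},            # no mixed precision
    ("float16", "fp16"):  {"peft", "qpeft"},              # full-FT hits the FP16 unscale ValueError
    ("bfloat16", None):   {"full-ft", "peft"},
    ("bfloat16", "bf16"): {"full-ft", "peft", "qpeft"},   # qpeft works but is less performant
}

def is_supported(torch_dtype: str, mixed_precision: str | None, mode: str) -> bool:
    return mode in SUPPORTED.get((torch_dtype, mixed_precision), set())
```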

### Code Extracted from Unsloth


@@ -126,10 +126,14 @@ def augmentation(
train_args: TrainingArguments,
modifiable_args: Tuple[LoraConfig],
):
# assert that plugin requires mixed precision to be set
assert (
train_args.bf16 is True or train_args.fp16 is True
), f"{self.__class__} requires mixed precision argument `--fp16` or `--bf16`"
has_quant = getattr(model, "quantization_method", None)

if has_quant:
# - only in the quantized case do we enforce the mixed precision settings
# - this is mostly for the fused-loras
assert (
train_args.bf16 is True or train_args.fp16 is True
), f"{self.__class__} requires mixed precision argument `--fp16` or `--bf16`"

# This is designed to be a passthrough if the training scenario is
# full finetuning or standard peft, fused-lora rules (only meant for qpeft)
@@ -138,7 +142,7 @@ def augmentation(

# some logic to omit terms from the filter if logic precludes
omitted = set()
if getattr(model, "quantization_method", None) is None:
if has_quant is None:
# - fused_lora only required for quant-peft
omitted.add("fused_lora")

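To illustrate the gating above with stand-in objects (these stub classes are not the real model types): only quantized models trigger the mixed-precision assertion, and non-quantized models simply have `fused_lora` omitted from the filter.

```python
class QuantStub:            # stands in for a GPTQ/bnb-quantized model
    quantization_method = "gptq"

class FullPrecisionStub:    # stands in for a regular full-precision model
    pass

for model in (QuantStub(), FullPrecisionStub()):
    has_quant = getattr(model, "quantization_method", None)
    omitted = set()
    if has_quant is None:
        omitted.add("fused_lora")  # fused-lora rules only apply to quant-peft
    print(type(model).__name__, "omits:", omitted)
```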
@@ -69,11 +69,11 @@ def forward(ctx, X : torch.Tensor,

e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS, dropout=dropout_gate)
g = matmul_lora(X, upW, upW_quant, upA, upB, upS, dropout=dropout_up)
e += gate_bias
g += up_bias
if gate_bias is not None: e += gate_bias
if up_bias is not None: g += up_bias
h = _forward_function(e, g)
i = matmul_lora(h, downW, downW_quant, downA, downB, downS, dropout=dropout_down)
i += down_bias
if down_bias is not None: i += down_bias

# Extract post-dropout X for use in backward computation
_dropped_X = []
@@ -261,9 +261,9 @@ def forward(ctx, X : torch.Tensor,
K = matmul_lora(X, KW, KW_quant, KA, KB, KS, dropout=dropout_K)
V = matmul_lora(X, VW, VW_quant, VA, VB, VS, dropout=dropout_V)

Q += Q_bias
K += K_bias
V += V_bias
if Q_bias is not None: Q += Q_bias
if K_bias is not None: K += K_bias
if V_bias is not None: V += V_bias

# Extract post-dropout X for use in backward computation
_dropped_X = []
@@ -406,7 +406,7 @@ def forward(ctx, X : torch.Tensor,
W, W_quant, bias, A, B, S, dropout_O):
dtype = X.dtype
XW = matmul_lora(X, W, W_quant, A, B, S, dropout=dropout_O)
XW += bias
if bias is not None: XW += bias

# Extract post-dropout X for use in backward computation
if dropout_O is not None:
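The `None` guards above matter because a base layer created without a bias yields `None` for the bias slot, and adding `None` to a tensor raises a TypeError; a minimal standalone illustration:

```python
import torch

x = torch.ones(2)
for bias in (torch.full((2,), 0.5), None):  # e.g. Linear(bias=True) vs Linear(bias=False)
    if bias is not None:                    # same guard as in the fused ops above
        x += bias
print(x)  # tensor([1.5000, 1.5000]); the None bias was skipped instead of crashing
```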
@@ -247,8 +247,8 @@ def forward(
e = matmul_lora(X, gateW, gateA, gateB, gateS, dropout=dropout_gate)
upW = dequant248(up_qweight, up_scales, up_qzeros, up_g_idx, up_bits)
g = matmul_lora(X, upW, upA, upB, upS, dropout=dropout_up)
e += gate_bias
g += up_bias
if gate_bias is not None: e += gate_bias
if up_bias is not None: g += up_bias
# f = torch.nn.functional.silu(e)
# h = f * g
h = swiglu_fg_kernel(e, g)
@@ -257,7 +257,7 @@ def forward(
down_qweight, down_scales, down_qzeros, down_g_idx, down_bits
)
i = matmul_lora(h, downW, downA, downB, downS, dropout=dropout_down)
i += down_bias
if down_bias is not None: i += down_bias

ctx.custom_saved_tensors = (
gate_qweight,
@@ -529,9 +529,9 @@ def forward(
K = matmul_lora(X, KW, KA, KB, KS, dropout=dropout_K)
V = matmul_lora(X, VW, VA, VB, VS, dropout=dropout_V)

Q += Q_bias
K += K_bias
V += V_bias
if Q_bias is not None: Q += Q_bias
if K_bias is not None: K += K_bias
if V_bias is not None: V += V_bias

ctx.custom_saved_tensors = (
Q_qweight,
@@ -774,7 +774,7 @@ def forward(
):
W = dequant248(O_qweight, O_scales, O_qzeros, O_g_idx, O_bits)
XW = matmul_lora(X, W, A, B, S, dropout=dropout_O)
XW += O_bias
if O_bias is not None: XW += O_bias
del W
ctx.custom_saved_tensors = (
O_qweight,
@@ -843,6 +843,6 @@ def apply_lora_o(self, X):
# added by [email protected]
# this version can be directly patched on the output linear
def apply_lora_o_v2(self, X):
Oqstate, O_bias, OA, OB, OS, dropout = get_lora_parameters(self.o_proj)
Oqstate, O_bias, OA, OB, OS, dropout = get_lora_parameters(self)
O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), O_bias, OA, OB, OS, dropout)
return O
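A self-contained toy (these classes are stand-ins, not the real modules) showing why the v2 variant reads its LoRA parameters from `self`: when the function is patched directly onto the output linear, `self` already is that linear, so reaching into `self.o_proj` would look one level too deep.

```python
import types

class OProjStub:                 # stands in for the output Linear carrying LoRA params
    lora = "params-on-o_proj"

class AttentionStub:             # stands in for the attention block that owns o_proj
    def __init__(self):
        self.o_proj = OProjStub()

def apply_v1(self, X):           # bound to the attention block -> reaches into self.o_proj
    return self.o_proj.lora

def apply_v2(self, X):           # bound to the o_proj linear itself -> params live on self
    return self.lora

attn = AttentionStub()
attn.apply_lora_o = types.MethodType(apply_v1, attn)
attn.o_proj.forward = types.MethodType(apply_v2, attn.o_proj)
print(attn.apply_lora_o(None), attn.o_proj.forward(None))  # both resolve the same params
```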
scripts/benchmarks/accelerate.yaml: 3 changes (2 additions, 1 deletion)
@@ -14,9 +14,10 @@ fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP

# this controls the FSDP pipelining
fsdp_backward_prefetch_policy: BACKWARD_PRE # set to BACKWARD_PRE for the most time-efficient pipeline
fsdp_backward_prefetch: BACKWARD_PRE # set to BACKWARD_PRE for the most time-efficient pipeline
# but requires the most memory. BACKWARD_POST is the less
# memory intensive option
fsdp_backward_prefetch_policy: BACKWARD_PRE # kept for backward compatibility with accelerate<1.0

# setting this to true will increase forward memory by prefetching the next FSDP all-gather, while performing
# the current forward pass.
scripts/benchmarks/compare_with_reference.py: 28 changes (17 additions, 11 deletions)
@@ -63,24 +63,30 @@ def compare_results(df, ref, plot_columns, threshold_ratio=0.1):
ref_series = ref[column].fillna(0)
df_series = df[column].fillna(0)
# Extract outliers based on some threshold % difference on the reference
ds = abs(df_series - ref_series) / (ref_series + 1e-9)
outliers = ds.index[ds > threshold_ratio].to_list()
cmp = ref_series.to_frame()
cmp['metric'] = column
cmp = cmp.join(df_series.to_frame(), lsuffix='_ref')
cmp = cmp.rename(columns={f'{column}_ref': 'reference', column: 'new'})
cmp['ds'] = cmp.apply(
lambda x: (
abs(x.reference - x.new) / (x.reference + 1e-9)
), axis=1
)
outliers = cmp[cmp.ds > threshold_ratio]
outliers = outliers.drop('ds', axis=1)

plot_chart(
ax,
ref_series,
df_series,
cmp['reference'],
cmp['new'],
title=f"Metric: {column}",
xlabel="Reference",
ylabel="New",
)
charts.append((ax, f"compare-{column}.jpg"))
total_outliers += [
[column, *outlier, ref_series[outlier].item(), df_series[outlier].item()]
for outlier in outliers
]
outliers_df = pd.DataFrame(
total_outliers, columns=["scenario", *df.index.names, "reference", "new"]
)
total_outliers.append(outliers)

outliers_df = pd.concat(total_outliers)
return outliers_df, outliers, charts
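A toy run of the join-based comparison above, on synthetic numbers rather than real benchmark output, showing how an outlier gets flagged:

```python
import pandas as pd

ref_series = pd.Series({"run-a": 100.0, "run-b": 200.0}, name="train_tokens_per_second")
df_series  = pd.Series({"run-a": 101.0, "run-b": 150.0}, name="train_tokens_per_second")

cmp = ref_series.to_frame().join(df_series.to_frame(), lsuffix="_ref")
cmp = cmp.rename(columns={"train_tokens_per_second_ref": "reference",
                          "train_tokens_per_second": "new"})
cmp["ds"] = (cmp["reference"] - cmp["new"]).abs() / (cmp["reference"] + 1e-9)
print(cmp[cmp["ds"] > 0.1])  # run-b deviates by 25% and is reported as an outlier
```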


scripts/benchmarks/display_bench_results.py: 7 changes (6 additions, 1 deletion)
@@ -38,7 +38,12 @@ def main(
df[c] = constant[c]
kept += 1

df = df.reset_index(drop=True).drop("output_dir", axis=1)
df = df.reset_index(drop=True)
try:
df = df.drop("output_dir", axis=1)
except KeyError:
pass # output_dir not found

df.reindex(sorted(df.columns), axis=1).to_csv(output_filename, index=False)
print("***************** Report Created ******************")
print(f"Total lines: '{len(df)}'")