Add Benchmarking Compatibility to PaddingFree Plugin (#66)
* add benchmarking on orca-math

Signed-off-by: 1000850000 user <[email protected]>

* modifications to address PR changes

Signed-off-by: 1000850000 user <[email protected]>

* additional fixes to scenarios template

Signed-off-by: 1000850000 user <[email protected]>

* Apply suggestions from code review

Co-authored-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: 1000850000 user <[email protected]>

* renamed scenarios template to specify dataset

Signed-off-by: 1000850000 user <[email protected]>

* added orca benchmarks as ref

Signed-off-by: 1000850000 user <[email protected]>

---------

Signed-off-by: 1000850000 user <[email protected]>
Co-authored-by: Yu Chin Fabian Lim <[email protected]>
achew010 and fabianlim authored Aug 19, 2024
1 parent 09a3104 commit 48426a1
Showing 16 changed files with 579 additions and 85 deletions.
@@ -138,9 +138,12 @@ def _flash_attention_forward_with_posids(
        and sliding_window is not None
        and key_states.shape[1] > sliding_window
    )
-    flash_kwargs = (
-        {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
-    )
+
+    # set flash_kwargs only if both use_sliding_windows=True and sliding_window are set;
+    # otherwise, flash_attn takes window_size = -1 as the default
+    flash_kwargs = {}
+    if use_sliding_windows and sliding_window:
+        flash_kwargs = {"window_size": (sliding_window, sliding_window)}

    try:
        if is_flash_attn_greater_or_equal("2.4.1"):
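A minimal standalone sketch of the corrected guard (hypothetical helper name; flash-attn treats window_size=(-1, -1) as "sliding window disabled", so the kwarg is only worth passing when a real window is configured):

def build_flash_kwargs(use_sliding_windows, sliding_window):
    # pass window_size only when a real window is configured;
    # otherwise fall back to flash-attn's default of (-1, -1)
    if use_sliding_windows and sliding_window:
        return {"window_size": (sliding_window, sliding_window)}
    return {}

assert build_flash_kwargs(True, 4096) == {"window_size": (4096, 4096)}
assert build_flash_kwargs(True, None) == {}
assert build_flash_kwargs(False, 4096) == {}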
28 changes: 27 additions & 1 deletion sample-configurations/CONTENTS.yaml
@@ -36,4 +36,30 @@ framework_configs:
  - shortname: aadp-padding-free
    plugins:
      - attention-and-distributed-packing
-    filename: aadp-padding-free-sample-configuration.yaml
+    filename: aadp-padding-free-sample-configuration.yaml
+
+  - shortname: accelerated-peft-bnb-padding-free
+    plugins:
+      - accelerated-peft
+      - attention-and-distributed-packing
+    filename: accelerated-peft-bnb-nf4-padding-free-sample-configuration.yaml
+
+  - shortname: accelerated-peft-autogptq-padding-free
+    plugins:
+      - accelerated-peft
+      - attention-and-distributed-packing
+    filename: accelerated-peft-autogptq-padding-free-sample-configuration.yaml
+
+  - shortname: accelerated-peft-bnb-foak-padding-free
+    plugins:
+      - accelerated-peft
+      - attention-and-distributed-packing
+      - fused-ops-and-kernels
+    filename: accelerated-peft-bnb-nf4-foak-padding-free-sample-configuration.yaml
+
+  - shortname: accelerated-peft-autogptq-foak-padding-free
+    plugins:
+      - accelerated-peft
+      - attention-and-distributed-packing
+      - fused-ops-and-kernels
+    filename: accelerated-peft-autogptq-foak-padding-free-sample-configuration.yaml
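A quick consistency check over this registry (a sketch, PyYAML assumed): every entry's filename should resolve to an actual sample configuration:

import os
import yaml

# load the registry of framework configurations
with open("sample-configurations/CONTENTS.yaml") as f:
    contents = yaml.safe_load(f)

# each shortname must point at an existing sample-configuration file
for entry in contents["framework_configs"]:
    path = os.path.join("sample-configurations", entry["filename"])
    assert os.path.exists(path), f"missing {path} for '{entry['shortname']}'"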
58 changes: 58 additions & 0 deletions sample-configurations/accelerated-peft-autogptq-foak-padding-free-sample-configuration.yaml
@@ -0,0 +1,58 @@
# FMS Acceleration Plugin Configuration.
#
# Each stanza incorporates various configurations for
# different fine-tuning / training tasks.
plugins:
  # Configurations to accelerate data packing/padding in training
  training:

    # attention module configurations
    # e.g. padding-free modifications to attention layer
    attention:

      # this controls the configurations for padding-free computation of flash attention
      padding_free:
        method: huggingface
  peft:

    # quantization-related acceleration
    # e.g., kernels for quantized base weights
    quantization:

      # AutoGPTQ quantized base weights.
      auto_gptq:

        # Kernel to be used for GPTQ linear layer
        # NOTE: Not all kernels are suitable for PEFT training; need to use
        # kernels that support autograd forward / backward. The best
        # recommendation at the moment is "triton_v2".
        kernel: triton_v2

        # If true, then will already expect quantized checkpoint
        # passed into TrainingArguments.model_name_or_path
        from_quantized: true

        # Setting to false, will create GPTQ-LORA using the local autogptq package.
        # If true, will create legacy implementation of GPTQ-LORA using external
        # `auto_gptq`. Refer to README for installation instructions.
        use_external_lib: false
      fused_ops_and_kernels:

        # load unsloth optimizations for these 4bit base layer weights.
        # currently only supports "auto_gptq" and "bitsandbytes"
        base_layer: auto_gptq

        # activate various unsloth optimizations
        # NOTE: currently supports only all-or-nothing.

        # fused kernels for lora linear layers
        fused_lora: true

        # fast loss triton kernels
        fast_loss: true

        # fast rms norm triton kernels
        fast_rsm_layernorm: true

        # fast RoPE embedding triton kernels
        fast_rope_embeddings: true
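How such a configuration resolves once parsed (a sketch, PyYAML assumed; the path is the one inferred from CONTENTS.yaml above, and the nesting is as shown in the file):

import yaml

path = (
    "sample-configurations/"
    "accelerated-peft-autogptq-foak-padding-free-sample-configuration.yaml"
)
with open(path) as f:
    cfg = yaml.safe_load(f)["plugins"]

# padding-free attention rides on the huggingface integration
assert cfg["training"]["attention"]["padding_free"]["method"] == "huggingface"
# GPTQ base weights use the triton_v2 kernel, with FOAK fused ops on top
assert cfg["peft"]["quantization"]["auto_gptq"]["kernel"] == "triton_v2"
assert cfg["peft"]["quantization"]["fused_ops_and_kernels"]["base_layer"] == "auto_gptq"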
38 changes: 38 additions & 0 deletions sample-configurations/accelerated-peft-autogptq-padding-free-sample-configuration.yaml
@@ -0,0 +1,38 @@
# FMS Acceleration Plugin Configuration.
#
# Each stanza incorporates various configurations for
# different fine-tuning / training tasks.
plugins:
  # Configurations to accelerate data packing/padding in training
  training:

    # attention module configurations
    # e.g. padding-free modifications to attention layer
    attention:

      # this controls the configurations for padding-free computation of flash attention
      padding_free:
        method: huggingface
  peft:

    # quantization-related acceleration
    # e.g., kernels for quantized base weights
    quantization:

      # AutoGPTQ quantized base weights.
      auto_gptq:

        # Kernel to be used for GPTQ linear layer
        # NOTE: Not all kernels are suitable for PEFT training; need to use
        # kernels that support autograd forward / backward. The best
        # recommendation at the moment is "triton_v2".
        kernel: triton_v2

        # If true, then will already expect quantized checkpoint
        # passed into TrainingArguments.model_name_or_path
        from_quantized: true

        # Setting to false, will create GPTQ-LORA using the local autogptq package.
        # If true, will create legacy implementation of GPTQ-LORA using external
        # `auto_gptq`. Refer to README for installation instructions.
        use_external_lib: false
53 changes: 53 additions & 0 deletions sample-configurations/accelerated-peft-bnb-nf4-foak-padding-free-sample-configuration.yaml
@@ -0,0 +1,53 @@
# FMS Acceleration Plugin Configuration.
#
# Each stanza incorporates various configurations for
# different fine-tuning / training tasks.
plugins:
  # Configurations to accelerate data packing/padding in training
  training:

    # attention module configurations
    # e.g. padding-free modifications to attention layer
    attention:

      # this controls the configurations for padding-free computation of flash attention
      padding_free:
        method: huggingface
  peft:

    # quantization-related acceleration
    # e.g., kernels for quantized base weights
    quantization:

      # For loading BitsAndBytes quantized layers
      # to serve as 4bit base-weights for LoRA PEFT-tuning.
      # NOTE: currently AutoGPTQ is not properly integrated into huggingface /
      # bitsandbytes, thus recommended quant_type to be either "nf4"
      # or "fp4".
      bitsandbytes:
        quant_type: nf4

        # If True, then no get_peft_model and prepare_model_for_kbit_training
        # will be called.
        no_peft_model: false
      fused_ops_and_kernels:

        # load unsloth optimizations for these 4bit base layer weights.
        # currently only supports "auto_gptq" and "bitsandbytes"
        base_layer: bitsandbytes

        # activate various unsloth optimizations
        # NOTE: currently supports only all-or-nothing.

        # fused kernels for lora linear layers
        fused_lora: true

        # fast loss triton kernels
        fast_loss: true

        # fast rms norm triton kernels
        fast_rsm_layernorm: true

        # fast RoPE embedding triton kernels
        fast_rope_embeddings: true
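For orientation (a sketch; this mapping is an assumption, not taken from this commit): the quant_type above corresponds to the standard 4-bit setting exposed by transformers' BitsAndBytesConfig:

from transformers import BitsAndBytesConfig

# "nf4" mirrors the plugin config's quant_type above ("fp4" is the other option)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
)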
33 changes: 33 additions & 0 deletions sample-configurations/accelerated-peft-bnb-nf4-padding-free-sample-configuration.yaml
@@ -0,0 +1,33 @@
# FMS Acceleration Plugin Configuration.
#
# Each stanza incorporates various configurations for
# different fine-tuning / training tasks.
plugins:
  # Configurations to accelerate data packing/padding in training
  training:

    # attention module configurations
    # e.g. padding-free modifications to attention layer
    attention:

      # this controls the configurations for padding-free computation of flash attention
      padding_free:
        method: huggingface
  peft:

    # quantization-related acceleration
    # e.g., kernels for quantized base weights
    quantization:

      # For loading BitsAndBytes quantized layers
      # to serve as 4bit base-weights for LoRA PEFT-tuning.
      # NOTE: currently AutoGPTQ is not properly integrated into huggingface /
      # bitsandbytes, thus recommended quant_type to be either "nf4"
      # or "fp4".
      bitsandbytes:
        quant_type: nf4

        # If True, then no get_peft_model and prepare_model_for_kbit_training
        # will be called.
        no_peft_model: false
35 changes: 26 additions & 9 deletions scripts/benchmarks/benchmark.py
@@ -79,6 +79,7 @@
KEYWORD_ALLOC_DELTA = "alloc_delta"
HF_ARG_TRAINING_DATA_PATH = "training_data_path"
HF_ARG_RESPONSE_TEMPLATE = "response_template"
+HF_ARG_DATASET_TEXT_FIELD = "dataset_text_field"
HF_ARG_SKIP_MEMORY_METRIC = "skip_memory_metrics"
RESULT_FIELD_ALLOCATED_GPU_MEM = "mem_torch_mem_alloc_in_bytes"
RESULT_FIELD_PEAK_ALLOCATED_GPU_MEM = "mem_peak_torch_mem_alloc_in_bytes"
@@ -164,9 +165,15 @@ def __init__(
        input_field: str = "input",
        dataset_text_field: str = "output",
        chat_template: str = None,
+        response_template: str = None,
+        additional_dataset_kwargs: Dict = {},
    ) -> None:

-        self.dataset_split = datasets.load_dataset(dataset_name, split=dataset_split)
+        self.dataset_split = datasets.load_dataset(
+            dataset_name,
+            split=dataset_split,
+            **additional_dataset_kwargs,
+        )

        self.kwargs = {
            "formatting": formatting,
@@ -177,6 +184,7 @@
        }
        self.training_paths = {}  # cache to store the training paths
        self.data_save_path = data_save_path
+        self.response_template = response_template

    def prepare_dataset(
        self,
@@ -186,6 +194,16 @@
        if model_name in self.training_paths:
            return self.training_paths[model_name]

+        if self.response_template:
+            if response_template is not None:
+                warnings.warn(
+                    "Response Template detected in data processing field, "
+                    "overriding response template. "
+                    f"*** Old ***\n{response_template}\n"
+                    f"*** New ***\n{self.response_template}"
+                )
+            response_template = self.response_template
+
        if self.kwargs["tokenize"]:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -257,8 +275,8 @@ def convert_keyvalue_arguments_to_list(args_dict: Dict):
        # otherwise if a regular argument
        if val is None:
            warnings.warn(
-                f"Argument '{arg}' is not a true/false argument andhad a 'None' value ",
-                "and thus will be ignored.",
+                f"Argument '{arg}' is not a true/false argument and "
+                "had a 'None' value and thus will be ignored.",
            )
            continue
@@ -668,8 +686,10 @@ def prepare_arguments(args, benchmark_dataset: BenchmarkDataset):
print(f"Scenario '{_scn_name}' has matrix '{k}' of len {len(v)}")
scn_factor *= len(v)

# scenario-specific constants should overwrite any similar values in defaults
defaults = {k:v for k, v in defaults.items() if k not in scenario_constants}
# update defaults with scenario constants
constants = {**scenario_constants, **defaults}
constants = {**defaults, **scenario_constants}
# Remove any empty variables and combine matrices to dictionary to cartesian product on
combined_matrices = {**scenario_matrices, **experiment_matrices}
products = ConfigUtils.cartesian_product_on_dict(combined_matrices)
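Why the merge order matters (a sketch with illustrative values): in {**defaults, **scenario_constants} the right-hand dict wins on key collisions, so scenario constants now override benchmark defaults:

defaults = {"response_template": "\n### Response:", "num_train_epochs": 1}
scenario_constants = {"response_template": None}  # scenario opts out of the default template

# later keys win: scenario-specific constants overwrite defaults
constants = {**defaults, **scenario_constants}
assert constants["response_template"] is None
assert constants["num_train_epochs"] == 1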
@@ -684,12 +704,9 @@
        # prepare the dataset
        training_path = benchmark_dataset.prepare_dataset(
            x["model_name_or_path"],
-            (
-                x[HF_ARG_RESPONSE_TEMPLATE]
-                if HF_ARG_RESPONSE_TEMPLATE in x
-                else constants.get(HF_ARG_RESPONSE_TEMPLATE)
-            ),
+            constants.get(HF_ARG_RESPONSE_TEMPLATE),
        )

        # update
        x[HF_ARG_TRAINING_DATA_PATH] = training_path
14 changes: 12 additions & 2 deletions scripts/benchmarks/compare_with_reference.py
@@ -172,9 +172,19 @@ def main(
help="the acceptable relative difference from the reference value.",
)

parser.add_argument("--indices", default=DEFAULT_INDICES, nargs="+")
parser.add_argument(
"--indices",
default=DEFAULT_INDICES,
nargs="+",
help="list of column names to use as index for merging between old and new benchmark results",
)

parser.add_argument("--plot_columns", default=DEFAULT_PLOT_COLUMNS, nargs="+")
parser.add_argument(
"--plot_columns",
default=DEFAULT_PLOT_COLUMNS,
nargs="+"
help="list of metric names in benchmark results to analyze visually",
)

args = parser.parse_args()
main(
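Example invocation (flag names from the parser above; index and column values are illustrative, drawn from fields visible in benchmark.py): python scripts/benchmarks/compare_with_reference.py --indices model_name_or_path --plot_columns mem_torch_mem_alloc_in_bytes mem_peak_torch_mem_alloc_in_bytes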