diff --git a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/flash_attn.py b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/flash_attn.py index 26e26d01..782145cc 100644 --- a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/flash_attn.py +++ b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/flash_attn.py @@ -138,9 +138,12 @@ def _flash_attention_forward_with_posids( and sliding_window is not None and key_states.shape[1] > sliding_window ) - flash_kwargs = ( - {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} - ) + + # set flash_kwargs only if use_sliding_windows is true and sliding_window is set; + # otherwise, flash_attn falls back to its default of window_size = -1 + flash_kwargs = {} + if use_sliding_windows and sliding_window: + flash_kwargs = {"window_size": (sliding_window, sliding_window)} try: if is_flash_attn_greater_or_equal("2.4.1"): diff --git a/sample-configurations/CONTENTS.yaml b/sample-configurations/CONTENTS.yaml index e2eccbc1..fec756de 100644 --- a/sample-configurations/CONTENTS.yaml +++ b/sample-configurations/CONTENTS.yaml @@ -36,4 +36,30 @@ framework_configs: - shortname: aadp-padding-free plugins: - attention-and-distributed-packing - filename: aadp-padding-free-sample-configuration.yaml \ No newline at end of file + filename: aadp-padding-free-sample-configuration.yaml + + - shortname: accelerated-peft-bnb-padding-free + plugins: + - accelerated-peft + - attention-and-distributed-packing + filename: accelerated-peft-bnb-nf4-padding-free-sample-configuration.yaml + + - shortname: accelerated-peft-autogptq-padding-free + plugins: + - accelerated-peft + - attention-and-distributed-packing + filename: accelerated-peft-autogptq-padding-free-sample-configuration.yaml + + - shortname: accelerated-peft-bnb-foak-padding-free + plugins: + - accelerated-peft + - attention-and-distributed-packing + - fused-ops-and-kernels + filename: accelerated-peft-bnb-nf4-foak-padding-free-sample-configuration.yaml + + - shortname: accelerated-peft-autogptq-foak-padding-free + plugins: + - accelerated-peft + - attention-and-distributed-packing + - fused-ops-and-kernels + filename: accelerated-peft-autogptq-foak-padding-free-sample-configuration.yaml \ No newline at end of file diff --git a/sample-configurations/accelerated-peft-autogptq-foak-padding-free-sample-configuration.yaml b/sample-configurations/accelerated-peft-autogptq-foak-padding-free-sample-configuration.yaml new file mode 100644 index 00000000..a331154e --- /dev/null +++ b/sample-configurations/accelerated-peft-autogptq-foak-padding-free-sample-configuration.yaml @@ -0,0 +1,58 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + # Configurations to accelerate data packing/padding in training + training: + + # attention module configurations + # e.g. padding-free modifications to attention layer + attention: + + # this controls the configurations for padding-free computation of flash attention + padding_free: + method: huggingface + peft: + + # quantization-related acceleration + # e.g., kernels for quantized base weights + quantization: + + # AutoGPTQ quantized base weights. + auto_gptq: + + # Kernel to be used for GPTQ linear layer + # NOTE: Not all kernels are suitable for PEFT training; need to use + # kernels that support autograd forward / backward. The best + # recommendation at the moment is "triton_v2".
+ kernel: triton_v2 + + # If true, then an already-quantized checkpoint is expected + # to be passed into TrainingArguments.model_name_or_path + from_quantized: true + + # Setting this to false will create GPTQ-LORA using the local autogptq package; + # if true, will create the legacy implementation of GPTQ-LORA using the external + # `auto_gptq` package. Refer to the README for installation instructions + use_external_lib: false + fused_ops_and_kernels: + + # load unsloth optimizations for these 4bit base layer weights. + # currently only supports "auto_gptq" and "bitsandbytes" + base_layer: auto_gptq + + # activate various unsloth optimizations + # NOTE: currently supports only all-or-nothing. + + # fused kernels for lora linear layers + fused_lora: true + + # fast loss triton kernels + fast_loss: true + + # fast rms norm triton kernels + fast_rsm_layernorm: true + + # fast RoPE embedding triton kernels + fast_rope_embeddings: true diff --git a/sample-configurations/accelerated-peft-autogptq-padding-free-sample-configuration.yaml b/sample-configurations/accelerated-peft-autogptq-padding-free-sample-configuration.yaml new file mode 100644 index 00000000..f3054449 --- /dev/null +++ b/sample-configurations/accelerated-peft-autogptq-padding-free-sample-configuration.yaml @@ -0,0 +1,38 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + # Configurations to accelerate data packing/padding in training + training: + + # attention module configurations + # e.g. padding-free modifications to attention layer + attention: + + # this controls the configurations for padding-free computation of flash attention + padding_free: + method: huggingface + peft: + + # quantization-related acceleration + # e.g., kernels for quantized base weights + quantization: + + # AutoGPTQ quantized base weights. + auto_gptq: + + # Kernel to be used for GPTQ linear layer + # NOTE: Not all kernels are suitable for PEFT training; need to use + # kernels that support autograd forward / backward. The best + # recommendation at the moment is "triton_v2". + kernel: triton_v2 + + # If true, then an already-quantized checkpoint is expected + # to be passed into TrainingArguments.model_name_or_path + from_quantized: true + + # Setting this to false will create GPTQ-LORA using the local autogptq package; + # if true, will create the legacy implementation of GPTQ-LORA using the external + # `auto_gptq` package. Refer to the README for installation instructions + use_external_lib: false diff --git a/sample-configurations/accelerated-peft-bnb-nf4-foak-padding-free-sample-configuration.yaml b/sample-configurations/accelerated-peft-bnb-nf4-foak-padding-free-sample-configuration.yaml new file mode 100644 index 00000000..32d077ae --- /dev/null +++ b/sample-configurations/accelerated-peft-bnb-nf4-foak-padding-free-sample-configuration.yaml @@ -0,0 +1,53 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + # Configurations to accelerate data packing/padding in training + training: + + # attention module configurations + # e.g.
padding-free modifications to attention layer + attention: + + # this controls the configurations for padding-free computation of flash attention + padding_free: + method: huggingface + peft: + + # quantization-related acceleration + # e.g., kernels for quantized base weights + quantization: + + # For loading BitsAndBytes quantized layers + # to serve as 4bit base-weights for LoRA PEFT-tuning. + # NOTE: currently AutoGPTQ is not properly integrated into huggingface / + # bitsandbytes, thus the recommended quant_type is either "nf4" + # or "fp4". + # bitsandbytes: + bitsandbytes: + quant_type: nf4 + + # If True, then neither get_peft_model nor prepare_model_for_kbit_training + # will be called. + no_peft_model: false + fused_ops_and_kernels: + + # load unsloth optimizations for these 4bit base layer weights. + # currently only supports "auto_gptq" and "bitsandbytes" + base_layer: bitsandbytes + + # activate various unsloth optimizations + # NOTE: currently supports only all-or-nothing. + + # fused kernels for lora linear layers + fused_lora: true + + # fast loss triton kernels + fast_loss: true + + # fast rms norm triton kernels + fast_rsm_layernorm: true + + # fast RoPE embedding triton kernels + fast_rope_embeddings: true diff --git a/sample-configurations/accelerated-peft-bnb-nf4-padding-free-sample-configuration.yaml b/sample-configurations/accelerated-peft-bnb-nf4-padding-free-sample-configuration.yaml new file mode 100644 index 00000000..bad3f571 --- /dev/null +++ b/sample-configurations/accelerated-peft-bnb-nf4-padding-free-sample-configuration.yaml @@ -0,0 +1,33 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + # Configurations to accelerate data packing/padding in training + training: + + # attention module configurations + # e.g. padding-free modifications to attention layer + attention: + + # this controls the configurations for padding-free computation of flash attention + padding_free: + method: huggingface + peft: + + # quantization-related acceleration + # e.g., kernels for quantized base weights + quantization: + + # For loading BitsAndBytes quantized layers + # to serve as 4bit base-weights for LoRA PEFT-tuning. + # NOTE: currently AutoGPTQ is not properly integrated into huggingface / + # bitsandbytes, thus the recommended quant_type is either "nf4" + # or "fp4". + # bitsandbytes: + bitsandbytes: + quant_type: nf4 + + # If True, then neither get_peft_model nor prepare_model_for_kbit_training + # will be called.
+ no_peft_model: false diff --git a/scripts/benchmarks/benchmark.py b/scripts/benchmarks/benchmark.py index 91d52601..a43f34c8 100644 --- a/scripts/benchmarks/benchmark.py +++ b/scripts/benchmarks/benchmark.py @@ -79,6 +79,7 @@ KEYWORD_ALLOC_DELTA = "alloc_delta" HF_ARG_TRAINING_DATA_PATH = "training_data_path" HF_ARG_RESPONSE_TEMPLATE = "response_template" +HF_ARG_DATASET_TEXT_FIELD = "dataset_text_field" HF_ARG_SKIP_MEMORY_METRIC = "skip_memory_metrics" RESULT_FIELD_ALLOCATED_GPU_MEM = "mem_torch_mem_alloc_in_bytes" RESULT_FIELD_PEAK_ALLOCATED_GPU_MEM = "mem_peak_torch_mem_alloc_in_bytes" @@ -164,9 +165,15 @@ def __init__( input_field: str = "input", dataset_text_field: str = "output", chat_template: str = None, + response_template: str = None, + additional_dataset_kwargs: Dict = {}, ) -> None: - self.dataset_split = datasets.load_dataset(dataset_name, split=dataset_split) + self.dataset_split = datasets.load_dataset( + dataset_name, + split=dataset_split, + **additional_dataset_kwargs + ) self.kwargs = { "formatting": formatting, @@ -177,6 +184,7 @@ def __init__( } self.training_paths = {} # cache to store the training paths self.data_save_path = data_save_path + self.response_template = response_template def prepare_dataset( self, @@ -186,6 +194,16 @@ def prepare_dataset( if model_name in self.training_paths: return self.training_paths[model_name] + if self.response_template: + if response_template is not None: + warnings.warn( + "Response template detected in the data processing stanza, " + "overriding the scenario response template. " + f"*** Old ***\n{response_template}\n" + f"*** New ***\n{self.response_template}" + ) + response_template = self.response_template + if self.kwargs["tokenize"]: tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -257,8 +275,8 @@ def convert_keyvalue_arguments_to_list(args_dict: Dict): # otherwise if a regular argument if val is None: warnings.warn( - f"Argument '{arg}' is not a true/false argument andhad a 'None' value ", - "and thus will be ignored.", + f"Argument '{arg}' is not a true/false argument and " + "had a 'None' value and thus will be ignored.", ) continue @@ -668,8 +686,10 @@ def prepare_arguments(args, benchmark_dataset: BenchmarkDataset): print(f"Scenario '{_scn_name}' has matrix '{k}' of len {len(v)}") scn_factor *= len(v) + # scenario-specific constants should overwrite any matching values in defaults + defaults = {k: v for k, v in defaults.items() if k not in scenario_constants} # update defaults with scenario constants - constants = {**scenario_constants, **defaults} + constants = {**defaults, **scenario_constants} # Remove any empty variables and combine matrices to dictionary to cartesian product on combined_matrices = {**scenario_matrices, **experiment_matrices} products = ConfigUtils.cartesian_product_on_dict(combined_matrices) @@ -684,12 +704,9 @@ def prepare_arguments(args, benchmark_dataset: BenchmarkDataset): # prepare the dataset training_path = benchmark_dataset.prepare_dataset( x["model_name_or_path"], - ( - x[HF_ARG_RESPONSE_TEMPLATE] - if HF_ARG_RESPONSE_TEMPLATE in x - else constants.get(HF_ARG_RESPONSE_TEMPLATE) - ), + constants.get(HF_ARG_RESPONSE_TEMPLATE), ) + # update the training data path for this experiment x[HF_ARG_TRAINING_DATA_PATH] = training_path diff --git a/scripts/benchmarks/compare_with_reference.py b/scripts/benchmarks/compare_with_reference.py index 953ead5c..71c0e57a 100644 --- a/scripts/benchmarks/compare_with_reference.py +++ b/scripts/benchmarks/compare_with_reference.py @@ -172,9 +172,19 @@ def main( help="the acceptable relative difference from the
reference value.", ) - parser.add_argument("--indices", default=DEFAULT_INDICES, nargs="+") + parser.add_argument( + "--indices", + default=DEFAULT_INDICES, + nargs="+", + help="list of column names to use as the index when merging the old and new benchmark results", + ) - parser.add_argument("--plot_columns", default=DEFAULT_PLOT_COLUMNS, nargs="+") + parser.add_argument( + "--plot_columns", + default=DEFAULT_PLOT_COLUMNS, + nargs="+", + help="list of metric columns in the benchmark results to analyze visually", + ) args = parser.parse_args() main( diff --git a/scripts/benchmarks/refs_orca/benchmarks.csv b/scripts/benchmarks/refs_orca/benchmarks.csv new file mode 100644 index 00000000..c055de88 --- /dev/null +++ b/scripts/benchmarks/refs_orca/benchmarks.csv @@ -0,0 +1,41 @@ +fp16,framework_config,learning_rate,lora_alpha,lora_dropout,mem_nvidia_mem_reserved,mem_peak_torch_mem_alloc_in_bytes,mem_torch_mem_alloc_in_bytes,model_name_or_path,num_gpus,peft_method,per_device_train_batch_size,r,target_modules,torch_dtype,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second +,none,2e-5,,,77527.0,72468863488,43468103168,mistralai/Mistral-7B-v0.1,1,,4,,,float16,0.36955297470092774,362.6052,5.516,0.689,2383.937 +,none,2e-5,,,54982.0,38899449344,28984259072,mistralai/Mistral-7B-v0.1,2,,2,,,float16,0.3714465112686157,320.0386,6.249,0.781,1102.661 +,none,2e-5,,,76911.0,72465051648,43467904512,mistralai/Mistral-7B-v0.1,1,,8,,,float16,0.3604792728424072,356.7945,5.605,0.35,2933.722 +,none,2e-5,,,58821.0,42812754432,28984268288,mistralai/Mistral-7B-v0.1,2,,4,,,float16,0.3584610557556152,231.2066,8.65,0.541,1930.1 +,aadp-padding-free,2e-5,,,71665.0,72470858752,43468621312,mistralai/Mistral-7B-v0.1,1,,4,,,float16,0.3722459201812744,291.8317,6.853,0.857,1874.076 +,aadp-padding-free,2e-5,,,53231.0,38670022656,28984259072,mistralai/Mistral-7B-v0.1,2,,2,,,float16,0.37080725002288817,305.9316,6.537,0.817,886.777 +,aadp-padding-free,2e-5,,,75107.0,72452382720,43467883008,mistralai/Mistral-7B-v0.1,1,,8,,,float16,0.365696475982666,213.0649,9.387,0.587,2566.894 +,aadp-padding-free,2e-5,,,54301.0,39207462400,28984429056,mistralai/Mistral-7B-v0.1,2,,4,,,float16,0.36816050434112546,176.9167,11.305,0.707,1584.933 +True,accelerated-peft-bnb,2e-4,16,0.1,13927.0,10322870272,4306494976,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.3517864627838135,341.8255,5.851,0.731,2528.858 +True,accelerated-peft-bnb,2e-4,16,0.1,7789.0,6435678720,2244413952,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.3503192720413208,767.1497,2.607,0.326,460.007 +True,accelerated-peft-bnb,2e-4,16,0.1,23271.0,16366288896,4306296320,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.3807634754180908,317.0715,6.308,0.394,3301.262 +True,accelerated-peft-bnb,2e-4,16,0.1,11944.0,9927788032,2244423168,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.3727553825378418,392.5799,5.095,0.318,1136.716 +True,accelerated-peft-bnb-padding-free,2e-4,16,0.1,7905.0,6284557312,4306259456,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.3511854820251465,342.8506,5.833,0.729,1595.199 +True,accelerated-peft-bnb-padding-free,2e-4,16,0.1,6277.0,4901515776,2244413952,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.353749755859375,772.1829,2.59,0.324,351.333
+True,accelerated-peft-bnb-padding-free,2e-4,16,0.1,11113.0,6883823104,4307159552,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.3688219118118286,210.9553,9.481,0.593,2592.564 +True,accelerated-peft-bnb-padding-free,2e-4,16,0.1,6972.0,5302929408,2244420096,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.37023702239990236,398.2732,5.022,0.314,704.042 +True,accelerated-peft-bnb-foak,2e-4,16,0.1,12751.0,9080075776,4306494976,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.3516845245361328,314.1839,6.366,0.796,2751.344 +True,accelerated-peft-bnb-foak,2e-4,16,0.1,7943.0,6377985024,2244413952,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.3619532585144043,348.4855,5.739,0.717,1012.65 +True,accelerated-peft-bnb-foak,2e-4,16,0.1,21411.0,13907961856,4306296320,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.36872802734375,272.9149,7.328,0.458,3835.394 +True,accelerated-peft-bnb-foak,2e-4,16,0.1,11558.0,9761232384,2244423168,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.3788349189758301,190.089,10.521,0.658,2347.595 +True,accelerated-peft-bnb-foak-padding-free,2e-4,16,0.1,7039.0,5898254848,4306259456,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.5028951711654663,245.3348,8.152,1.019,2229.26 +True,accelerated-peft-bnb-foak-padding-free,2e-4,16,0.1,6407.0,4856763904,2244413952,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.4248905544281006,307.9769,6.494,0.812,880.888 +True,accelerated-peft-bnb-foak-padding-free,2e-4,16,0.1,9223.0,6381574656,4306274816,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.7981027908325196,167.693,11.927,0.745,3261.406 +True,accelerated-peft-bnb-foak-padding-free,2e-4,16,0.1,6950.0,5258480640,2244616704,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.6221811065673828,155.5657,12.856,0.804,1802.46 +True,accelerated-peft-autogptq,2e-4,16,0.1,13269.0,10353179648,4336804352,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.3607017431259155,312.6812,6.396,0.8,2764.567 +True,accelerated-peft-autogptq,2e-4,16,0.1,8276.0,6452275200,2261091840,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.3964675521850586,767.574,2.606,0.326,459.752 +True,accelerated-peft-autogptq,2e-4,16,0.1,23229.0,16396598272,4336605696,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.37328079795837404,305.2517,6.552,0.409,3429.091 +True,accelerated-peft-autogptq,2e-4,16,0.1,12347.0,9945317888,2261101056,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.37490358924865724,388.2367,5.151,0.322,1149.433 +True,accelerated-peft-autogptq-padding-free,2e-4,16,0.1,7307.0,6314883072,4336585216,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.4573829107284546,304.2352,6.574,0.822,1797.672 +True,accelerated-peft-autogptq-padding-free,2e-4,16,0.1,6791.0,4916809216,2261091840,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.4595833406448364,751.4294,2.662,0.333,361.036 +True,accelerated-peft-autogptq-padding-free,2e-4,16,0.1,11083.0,6914132480,4337468928,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.5659423522949218,194.3331,10.292,0.643,2814.318 
+True,accelerated-peft-autogptq-padding-free,2e-4,16,0.1,7347.0,5320475648,2261114368,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.38012803745269774,386.4168,5.176,0.323,725.644 +True,accelerated-peft-autogptq-foak,2e-4,16,0.1,12825.0,9110761472,4337180672,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.47571081733703613,287.4231,6.958,0.87,3007.51 +True,accelerated-peft-autogptq-foak,2e-4,16,0.1,8359.0,6395514880,2261091840,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.42341567993164064,359.3769,5.565,0.696,981.961 +True,accelerated-peft-autogptq-foak,2e-4,16,0.1,21441.0,13938271232,4336605696,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.3734574165344238,260.8468,7.667,0.479,4012.837 +True,accelerated-peft-autogptq-foak,2e-4,16,0.1,12080.0,9778762240,2261101056,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.4276853837966919,186.9887,10.696,0.668,2386.519 +True,accelerated-peft-autogptq-foak-padding-free,2e-4,16,0.1,6895.0,5930145792,4336568832,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.5187568559646606,225.1688,8.882,1.11,2428.911 +True,accelerated-peft-autogptq-foak-padding-free,2e-4,16,0.1,6816.0,4874150400,2261091840,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.43749408531188966,311.8489,6.413,0.802,869.95 +True,accelerated-peft-autogptq-foak-padding-free,2e-4,16,0.1,9135.0,6411245056,4337468928,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.7783932838439941,157.2511,12.719,0.795,3477.972 +True,accelerated-peft-autogptq-foak-padding-free,2e-4,16,0.1,7477.0,5276182528,2261097984,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.6172291069030762,158.2566,12.638,0.79,1771.812 diff --git a/scripts/benchmarks/refs_orca/requirements.txt b/scripts/benchmarks/refs_orca/requirements.txt new file mode 100644 index 00000000..b021ef7c --- /dev/null +++ b/scripts/benchmarks/refs_orca/requirements.txt @@ -0,0 +1,125 @@ +accelerate==0.33.0 +aiohappyeyeballs==2.3.5 +aiohttp==3.10.1 +aiosignal==1.3.1 +asttokens==2.4.1 +async-timeout==4.0.3 +attrs==24.2.0 +auto_gptq @ git+https://github.com/AutoGPTQ/AutoGPTQ.git@caf343b1826301c15f90e2e119cabd0347acfcdf +bitsandbytes==0.43.3 +certifi==2024.7.4 +charset-normalizer==3.3.2 +coloredlogs==15.0.1 +comm==0.2.2 +contourpy==1.2.1 +cramjam==2.8.3 +cycler==0.12.1 +datasets==2.20.0 +debugpy==1.8.5 +decorator==5.1.1 +dill==0.3.8 +docstring_parser==0.16 +einops==0.8.0 +exceptiongroup==1.2.2 +executing==2.0.1 +fastparquet==2024.5.0 +filelock==3.15.4 +fire==0.6.0 +flash-attn==2.6.3 +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@0fe0867656a01c9e030d77d8007c70fa775e5668#egg=fms_acceleration&subdirectory=plugins/framework +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@0fe0867656a01c9e030d77d8007c70fa775e5668#egg=fms_acceleration_aadp&subdirectory=plugins/attention-and-distributed-packing +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@0fe0867656a01c9e030d77d8007c70fa775e5668#egg=fms_acceleration_foak&subdirectory=plugins/fused-ops-and-kernels +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@0fe0867656a01c9e030d77d8007c70fa775e5668#egg=fms_acceleration_peft&subdirectory=plugins/accelerated-peft +-e 
git+https://github.com/foundation-model-stack/fms-hf-tuning.git@a8ab68ffaa0d3b49aeb6753bccfdf807672eba69#egg=fms_hf_tuning +fonttools==4.53.1 +frozenlist==1.4.1 +fsspec==2024.5.0 +gekko==1.2.1 +huggingface-hub==0.24.5 +humanfriendly==10.0 +idna==3.7 +ipykernel==6.29.5 +ipython==8.26.0 +jedi==0.19.1 +Jinja2==3.1.4 +joblib==1.4.2 +jupyter_client==8.6.2 +jupyter_core==5.7.2 +kiwisolver==1.4.5 +markdown-it-py==3.0.0 +MarkupSafe==2.1.5 +matplotlib==3.9.1.post1 +matplotlib-inline==0.1.7 +mdurl==0.1.2 +mpmath==1.3.0 +multidict==6.0.5 +multiprocess==0.70.16 +nest-asyncio==1.6.0 +networkx==3.3 +ninja==1.11.1.1 +numpy==1.26.4 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.20.5 +nvidia-nvjitlink-cu12==12.6.20 +nvidia-nvtx-cu12==12.1.105 +packaging==24.1 +pandas==2.2.2 +parso==0.8.4 +peft==0.12.0 +pexpect==4.9.0 +pillow==10.4.0 +platformdirs==4.2.2 +prompt_toolkit==3.0.47 +protobuf==5.27.3 +psutil==6.0.0 +ptyprocess==0.7.0 +pure_eval==0.2.3 +pyarrow==17.0.0 +pyarrow-hotfix==0.6 +Pygments==2.18.0 +pyparsing==3.1.2 +python-dateutil==2.9.0.post0 +pytz==2024.1 +PyYAML==6.0.2 +pyzmq==26.1.0 +regex==2024.7.24 +requests==2.32.3 +rich==13.7.1 +rouge==1.0.1 +ruamel.yaml==0.18.6 +ruamel.yaml.clib==0.2.8 +safetensors==0.4.4 +scikit-learn==1.5.1 +scipy==1.14.0 +sentencepiece==0.2.0 +shtab==1.7.1 +simpleeval==0.9.13 +six==1.16.0 +stack-data==0.6.3 +sympy==1.13.1 +termcolor==2.4.0 +threadpoolctl==3.5.0 +tokenizers==0.19.1 +torch==2.4.0 +tornado==6.4.1 +tqdm==4.66.5 +traitlets==5.14.3 +transformers==4.42.4 +triton==3.0.0 +trl==0.9.6 +typing_extensions==4.12.2 +tyro==0.8.5 +tzdata==2024.1 +urllib3==2.2.2 +wcwidth==0.2.13 +websockets==12.0 +xxhash==3.4.1 +yarl==1.9.4 diff --git a/scripts/benchmarks/scenarios-orca.yaml b/scripts/benchmarks/scenarios-orca.yaml new file mode 100644 index 00000000..5c435852 --- /dev/null +++ b/scripts/benchmarks/scenarios-orca.yaml @@ -0,0 +1,109 @@ +# This file holds a sample full-finetuning scenario and +# demonstrates various pretokenization scenarios + +# the data_processing stanza is optional +# - if it is missing, then the defaults is to use alpaca +# with instruct formatting and no tokenization + +# - this is an older style method which does not rely on +# chat templates, this will also do instruct formatting +# - but if tokenize = True, this works only if +# sft_trainer accepts pretokenized dataset +# data_processing: +# dataset_name: yahma/alpaca-cleaned +# formatting: "instruct" +# tokenize: True +# input_field: input + +# - this is the new style, with the chat templates for formatting +# - this is the best approach to keep things flexible and +# allows to configure many different datasets +# - there is an option of setting tokenize is True or False + +data_processing: + dataset_name: microsoft/orca-math-word-problems-200k + chat_template: | + {%- for message in messages %} + USER: + {{ message['question'] }} + + ASSISTANT: + {{ message['answer'] }} + {%- endfor %} + dataset_split: "train[:2000]" + tokenize: True + response_template: "\n\nASSISTANT:" + +# scenarios +scenarios: + - name: full-finetuning + arguments: + learning_rate: 2e-5 + torch_dtype: float16 + gradient_accumulation_steps: 2 + max_steps: null + packing: False + model_name_or_path: + - 'mistralai/Mistral-7B-v0.1' + response_template: null + 
dataset_text_field: null + + - name: padding-free + framework_config: + - aadp-padding-free + arguments: + learning_rate: 2e-5 + torch_dtype: float16 + gradient_accumulation_steps: 2 + max_steps: null + packing: False + model_name_or_path: + - 'mistralai/Mistral-7B-v0.1' + response_template: null + dataset_text_field: null + + - name: accelerated-peft-bnb + framework_config: + - accelerated-peft-bnb + - accelerated-peft-bnb-padding-free + - accelerated-peft-bnb-foak + - accelerated-peft-bnb-foak-padding-free + arguments: + fp16: True + learning_rate: 2e-4 + torch_dtype: float16 + peft_method: lora + r: 16 + lora_alpha: 16 + lora_dropout: 0.1 + target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"] + max_steps: null + gradient_accumulation_steps: 2 + packing: False + model_name_or_path: + - 'mistralai/Mistral-7B-v0.1' + response_template: null + dataset_text_field: null + + - name: accelerated-peft-gptq + framework_config: + - accelerated-peft-autogptq + - accelerated-peft-autogptq-padding-free + - accelerated-peft-autogptq-foak + - accelerated-peft-autogptq-foak-padding-free + arguments: + learning_rate: 2e-4 + fp16: True + torch_dtype: float16 + peft_method: lora + r: 16 + lora_alpha: 16 + lora_dropout: 0.1 + target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"] + max_steps: null + gradient_accumulation_steps: 2 + packing: False + model_name_or_path: + - 'TheBloke/Mistral-7B-v0.1-GPTQ' + response_template: null + dataset_text_field: null diff --git a/scripts/benchmarks/scenarios-pretok.yaml b/scripts/benchmarks/scenarios-pretok.yaml deleted file mode 100644 index b7c9a442..00000000 --- a/scripts/benchmarks/scenarios-pretok.yaml +++ /dev/null @@ -1,62 +0,0 @@ -# This file holds a sample full-finetuning scenario and -# demonstrates various pretokenization scenarios - -# the data_processing stanza is optional -# - if it is missing, then the defaults is to use alpaca -# with instruct formatting and no tokenization - -# - this is an older style method which does not rely on -# chat templates, this will also do instruct formatting -# - but if tokenize = True, this works only if -# sft_trainer accepts pretokenized dataset -# data_processing: -# dataset_name: yahma/alpaca-cleaned -# formatting: "instruct" -# tokenize: True -# input_field: input - -# - this is the new style, with the chat templates for formatting -# - this is the best approach to keep things flexible and -# allows to configure many different datasets -# - there is an option of setting tokenize is True or False -data_processing: - dataset_name: yahma/alpaca-cleaned - chat_template: | - {%- for message in messages %} - {% if message['input'] != '' %} - Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. - - {% else %} - Below is an instruction that describes a task. Write a response that appropriately completes the request. 
- - {% endif %} - ### Instruction: - {{ message['instruction'] }} - - {% if message['input'] != '' %} - ### Input: - {{ message['input'] }} - - {% endif %} - ### Response: - {{ message['output'] + eos_token }} - {% endfor %} - tokenize: True - -# scenarios -scenarios: - - name: full-finetuning - arguments: - learning_rate: 2e-5 - model_name_or_path: - - 'mistralai/Mistral-7B-v0.1' - torch_dtype: float16 - - - name: padding-free - framework_config: - - ilab-padding-free - arguments: - learning_rate: 2e-5 - model_name_or_path: - - 'mistralai/Mistral-7B-v0.1' - torch_dtype: float16 \ No newline at end of file diff --git a/scripts/benchmarks/scenarios.yaml b/scripts/benchmarks/scenarios.yaml index ecf2ec8c..2eb22872 100644 --- a/scripts/benchmarks/scenarios.yaml +++ b/scripts/benchmarks/scenarios.yaml @@ -8,6 +8,33 @@ # multiple arguments. # - So anything that is critical for the scenario MUST be specified here # and not in the defaults, e.g. fp16 + +# This stanza will be used in future to replace the custom processing functions in data_processing.py +# data_processing: +# dataset_name: yahma/alpaca-cleaned +# chat_template: | +# {%- for message in messages %} +# {% if message['input'] != '' %} +# Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. + +# {% else %} +# Below is an instruction that describes a task. Write a response that appropriately completes the request. + +# {% endif %} +# ### Instruction: +# {{ message['instruction'] }} + +# {% if message['input'] != '' %} +# ### Input: +# {{ message['input'] }} + +# {% endif %} +# ### Response: +# {{ message['output'] + eos_token }} +# {% endfor %} +# tokenize: True + + scenarios: - name: full-finetuning arguments: diff --git a/scripts/generate_sample_configurations.py b/scripts/generate_sample_configurations.py index 3dd80b92..b778c3c7 100644 --- a/scripts/generate_sample_configurations.py +++ b/scripts/generate_sample_configurations.py @@ -21,6 +21,7 @@ import logging import os import re +from copy import deepcopy # Third Party from ruamel.yaml import YAML @@ -182,6 +183,10 @@ def read_configuration(path: str) -> Dict: ("accelerated-peft-autogptq-foak", (KEY_AUTO_GPTQ, KEY_AUTO_GPTQ_FOAK)), ("accelerated-peft-bnb-nf4-foak", (KEY_BNB_NF4, KEY_BNB_NF4_FOAK)), ("aadp-padding-free", (KEY_AADP_PADDING_FREE,)), + ("accelerated-peft-autogptq-padding-free", (KEY_AADP_PADDING_FREE,KEY_AUTO_GPTQ)), + ("accelerated-peft-bnb-nf4-padding-free", (KEY_AADP_PADDING_FREE,KEY_BNB_NF4)), + ("accelerated-peft-autogptq-foak-padding-free", (KEY_AADP_PADDING_FREE,KEY_AUTO_GPTQ, KEY_AUTO_GPTQ_FOAK)), + ("accelerated-peft-bnb-nf4-foak-padding-free", (KEY_AADP_PADDING_FREE,KEY_BNB_NF4, KEY_BNB_NF4_FOAK)), ] @@ -265,9 +270,9 @@ def _merge(result: Dict, new_contents: Dict): # now merge contents in CONFIGURATIONS to form the final # sample configuration for combi_tag, combi in COMBINATIONS: - # merging the configuration contents for this particular combination - config = merge_configs([CONFIGURATIONS[tag] for tag in combi]) + # if keys are not the same separate the merges + config = merge_configs([deepcopy(CONFIGURATIONS[tag]) for tag in combi]) indent_yaml(config) # add the indent # writing the configuration contents diff --git a/scripts/run_benchmarks.sh b/scripts/run_benchmarks.sh index 5fb83b99..71c68fc0 100644 --- a/scripts/run_benchmarks.sh +++ b/scripts/run_benchmarks.sh @@ -105,6 +105,7 @@ if [ ! 
"$NO_OVERWRITE" = "true" ]; then fi # run the bench +PYTHONPATH=. \ python $WORKING_DIR/benchmark.py \ --num_gpus $NUM_GPUS_MATRIX \ --scenarios_config_path $SCENARIOS_CONFIG \ @@ -137,9 +138,18 @@ PYTHONPATH=. \ 'error_messages' \ 'acceleration_framework_config_file' -if [ "$DRY_RUN" = "true" ]; then - echo "DRY_RUN=True, will skip compare with reference logic" -else - PYTHONPATH=. \ - python $WORKING_DIR/compare_with_reference.py --result_dir $RESULT_DIR -fi + +# For every new benchmark run, it is good practice to perform a regression check +# against a previous known set of benchmark results. This repo provides a convenient comparison +# tool that analyses the differences of metrics like loss and throughput between an old and new set +# of benchmark results. +# To use this tool simply run the following python command +# PYTHONPATH=. \ +# python $WORKING_DIR/compare_with_reference.py +# The following arguments can be used to further configure the analysis, otherwise it uses default values +# arguments: +# --result_dir +# --reference_benchmark_filepath +# --threshold_ratio +# --indices +# --plot_columns \ No newline at end of file diff --git a/tox.ini b/tox.ini index 5c7c596f..f512639c 100644 --- a/tox.ini +++ b/tox.ini @@ -38,6 +38,7 @@ commands = # NOTE: when there are more plugins install here python -m fms_acceleration.cli install -e {toxinidir}/plugins/accelerated-peft python -m fms_acceleration.cli install -e {toxinidir}/plugins/fused-ops-and-kernels + python -m fms_acceleration.cli install -e {toxinidir}/plugins/attention_and_distributed_packing # run the benchmark script bash scripts/run_benchmarks.sh {posargs:"1 2" benchmark_outputs}