From 87455e7f3228972416904f2a23cc1b67a1d2d855 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 23 Jul 2024 01:41:11 -0400 Subject: [PATCH] swaps to use newer sample packing for mistral (#1773) * swaps to use newer sample packing for mistral * fix multipack patch test * patch the common fa utils * update for refactor of flash attn unpad * remove un-needed drop attn mask for mistral * bump transformers to main to pick up latest mistral fix for 12b and refactor of fa2 * update test --- requirements.txt | 2 +- .../monkeypatch/mistral_attn_hijack_flash.py | 10 +++ src/axolotl/monkeypatch/multipack.py | 32 +++++--- src/axolotl/monkeypatch/unsloth_.py | 79 ++++++++++--------- src/axolotl/utils/models.py | 21 ++--- src/axolotl/utils/trainer.py | 4 +- tests/e2e/patched/test_model_patches.py | 8 +- 7 files changed, 86 insertions(+), 70 deletions(-) diff --git a/requirements.txt b/requirements.txt index 39f6c8a77f..b2aac0dd04 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ packaging==23.2 peft==0.11.1 -transformers==4.42.4 +transformers @ git+https://github.com/huggingface/transformers.git@0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf tokenizers==0.19.1 bitsandbytes==0.43.1 accelerate==0.32.0 diff --git a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py index c5425dd520..1cbc4278ba 100644 --- a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py @@ -2,6 +2,7 @@ # pylint: disable=duplicate-code import logging +from functools import partial from typing import List, Optional, Tuple, Union import torch @@ -45,6 +46,15 @@ def replace_mistral_attn_with_flash_attn( ) +def patch_mistral_cross_entropy(): + from flash_attn.losses.cross_entropy import CrossEntropyLoss + + LOG.info("patching with flash_attn.losses.cross_entropy") + transformers.models.mistral.modeling_mistral.CrossEntropyLoss = partial( + CrossEntropyLoss, inplace_backward=True + ) + + @torch.jit.script def _make_sliding_window_causal_mask( bsz: int, diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 017adb2bfd..a2ce0e64fd 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -11,6 +11,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [ "llama", + "mistral", "mixtral", "qwen2", "qwen2_moe", @@ -25,6 +26,19 @@ def patch_for_multipack(model_type, model_name=None): + if model_type == "gemmoe": + patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe") + elif model_type == "deepseek_v2": + patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek") + elif hasattr(transformers, "modeling_flash_attention_utils"): + transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + if model_type == "mixtral" and is_deepspeed_zero3_enabled(): + patch_mixtral_moe_forward_zero3() + return + + # retain for legacy if model_type == "mixtral": transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access get_unpad_data @@ -32,9 +46,15 @@ def patch_for_multipack(model_type, model_name=None): if is_deepspeed_zero3_enabled(): patch_mixtral_moe_forward_zero3() elif model_type == "llama": - transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) + if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"): + transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "mistral": + if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"): + transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) elif model_type == "qwen2": transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access get_unpad_data @@ -63,12 +83,6 @@ def patch_for_multipack(model_type, model_name=None): transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access get_unpad_data ) - elif model_type == "gemmoe": - patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe") - elif model_type == "jamba": - patch_remote(model_name, ".configuration_jamba", ".modeling_jamba") - elif model_type == "deepseek_v2": - patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek") def patch_remote(model_name, config_name, modeling_name): diff --git a/src/axolotl/monkeypatch/unsloth_.py b/src/axolotl/monkeypatch/unsloth_.py index b1f0bddc00..5b1f0061de 100644 --- a/src/axolotl/monkeypatch/unsloth_.py +++ b/src/axolotl/monkeypatch/unsloth_.py @@ -99,48 +99,51 @@ def check_self_attn_is_patchable() -> bool: return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv -def integrate_cross_entropy_loss_patch(): - forward = get_forward_code() - LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access - forward, _ = detab_code(forward) - assert ORIGINAL_CEL_CODE in forward, "Original forward code not found" +def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None: + if model_type == "llama": + forward = get_forward_code() + LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access + forward, _ = detab_code(forward) + assert ORIGINAL_CEL_CODE in forward, "Original forward code not found" + + forward = forward.replace( + "@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", "" + ) + forward = forward.replace( + "@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)", + "", + ) + forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE) + forward = forward.replace( + "def forward(", + "def fast_cross_entropy_loss_forward(", + 1, + ) - forward = forward.replace( - "@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", "" - ) - forward = forward.replace( - "@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)", - "", - ) - forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE) - forward = forward.replace( - "def forward(", - "def fast_cross_entropy_loss_forward(", - 1, - ) + # load imports necessary + import transformers.models.llama.modeling_llama - # load imports necessary - import transformers.models.llama.modeling_llama + items_to_import = [] + for item in dir(transformers.models.llama.modeling_llama): + if item in forward: + items_to_import.append(item) - items_to_import = [] - for item in dir(transformers.models.llama.modeling_llama): - if item in forward: - items_to_import.append(item) - - exec( # pylint: disable=exec-used # nosec B102 - "from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss", - globals(), - ) + exec( # pylint: disable=exec-used # nosec B102 + "from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss", + globals(), + ) - exec( # pylint: disable=exec-used # nosec B102 - "from transformers.models.llama.modeling_llama import (" - + ", ".join(x for x in items_to_import) - + ")", - globals(), - ) - exec(forward, globals()) # pylint: disable=exec-used # nosec B102 - LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True) - LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821 + exec( # pylint: disable=exec-used # nosec B102 + "from transformers.models.llama.modeling_llama import (" + + ", ".join(x for x in items_to_import) + + ")", + globals(), + ) + exec(forward, globals()) # pylint: disable=exec-used # nosec B102 + LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True) + LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821 + else: + raise ValueError("Unsupported model type") def detab_code(code: str) -> Tuple[str, str]: diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 6185f0102f..339195df79 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -367,7 +367,7 @@ def load_model( integrate_cross_entropy_loss_patch, ) - integrate_cross_entropy_loss_patch() + integrate_cross_entropy_loss_patch(model_type="llama") if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o: from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora @@ -424,7 +424,7 @@ def load_model( if cfg.unsloth_cross_entropy_loss: from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch - integrate_cross_entropy_loss_patch() + integrate_cross_entropy_loss_patch(model_type="llama") if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o: from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora @@ -432,23 +432,12 @@ def load_model( patch_self_attn_lora() # Modify mistral derived models - if ( - cfg.model_config_type == "mistral" - and cfg.flash_attention - and cfg.sample_packing - ): + if cfg.model_config_type == "mistral" and cfg.flash_attn_cross_entropy_loss: from axolotl.monkeypatch.mistral_attn_hijack_flash import ( - replace_mistral_attn_with_flash_attn, + patch_mistral_cross_entropy, ) - LOG.info("patching mistral with flash attention") - replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing) - - if cfg.is_llama_derived_model and cfg.sample_packing and not inference: - from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask - - LOG.info("patching _expand_mask") - hijack_expand_mask() + patch_mistral_cross_entropy() model_kwargs: Dict[str, Any] = {} diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index a16baaae0f..65c2d424e5 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -189,9 +189,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): max_input_len = np.max(get_dataset_lengths(train_dataset)) LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True) - if ( - cfg.is_mistral_derived_model and cfg.flash_attention - ) or cfg.model_config_type == "mamba": + if cfg.model_config_type == "mamba": LOG.info("dropping attention_mask column") train_dataset = train_dataset.remove_columns("attention_mask") if eval_dataset: diff --git a/tests/e2e/patched/test_model_patches.py b/tests/e2e/patched/test_model_patches.py index eecd1b3c11..170c37fd6c 100644 --- a/tests/e2e/patched/test_model_patches.py +++ b/tests/e2e/patched/test_model_patches.py @@ -4,6 +4,8 @@ import unittest +import transformers + from axolotl.common.cli import TrainerCliArgs from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -87,9 +89,9 @@ def test_mistral_multipack(self, temp_dir): normalize_config(cfg) cli_args = TrainerCliArgs() tokenizer = load_tokenizer(cfg) - model, _ = load_model(cfg, tokenizer, inference=cli_args.inference) + load_model(cfg, tokenizer, inference=cli_args.inference) assert ( - "axolotl.monkeypatch.mistral_attn_hijack_flash" - in model.model.layers[0].self_attn.forward.__module__ + "torch.jit" + in transformers.modeling_flash_attention_utils._get_unpad_data.__module__ # pylint: disable=protected-access )