From f5a828aa20f9373c1c84e2ff628974d261d52857 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 22 Jan 2024 18:24:15 -0500 Subject: [PATCH] Qwen2 (#1166) * qwen2 multipack support * fix qwen derived model check so it doesn't break qwen2 * fixes to ensure qwen2 packing works * bump requirements for qwen2 * requirements typo --- requirements.txt | 4 ++-- src/axolotl/core/trainer_builder.py | 2 +- src/axolotl/monkeypatch/qwen2/__init__.py | 12 ++++++++++++ src/axolotl/utils/config.py | 17 ++++++----------- src/axolotl/utils/models.py | 12 ++++++++++-- 5 files changed, 31 insertions(+), 16 deletions(-) create mode 100644 src/axolotl/monkeypatch/qwen2/__init__.py diff --git a/requirements.txt b/requirements.txt index 4583850326..41cc8e1054 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ packaging==23.2 peft==0.7.0 -transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0 +transformers==4.37.0 tokenizers==0.15.0 bitsandbytes>=0.41.1 -accelerate @ git+https://github.com/huggingface/accelerate.git@0d2280dadc6a93413a5496613b7fdda3a4d2551b +accelerate==0.26.1 deepspeed addict fire diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 334948dd9d..b8309c3633 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -905,7 +905,7 @@ def build_collator( ] ] if use_batch_sampler_collator: - if self.cfg.model_config_type == "mixtral": + if self.cfg.model_config_type in ["mixtral", "qwen2"]: collator = V2BatchSamplerDataCollatorForSeq2Seq else: collator = BatchSamplerDataCollatorForSeq2Seq diff --git a/src/axolotl/monkeypatch/qwen2/__init__.py b/src/axolotl/monkeypatch/qwen2/__init__.py new file mode 100644 index 0000000000..40c54d21e9 --- /dev/null +++ b/src/axolotl/monkeypatch/qwen2/__init__.py @@ -0,0 +1,12 @@ +""" +Patches to support multipack for qwen2 +""" +import transformers + +from axolotl.monkeypatch.utils import get_unpad_data + + +def replace_qwen2_attn_with_multipack_flash_attn(): + transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index ca7d037ddc..b045648410 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -142,17 +142,12 @@ def normalize_config(cfg): ) cfg.is_qwen_derived_model = ( - ( - hasattr(model_config, "model_type") - and model_config.model_type - in [ - "qwen", - ] - ) - or cfg.is_qwen_derived_model - or "qwen" in cfg.base_model.lower() - or (cfg.model_type and "qwen" in cfg.model_type.lower()) - ) + hasattr(model_config, "model_type") + and model_config.model_type + in [ + "qwen", + ] + ) or cfg.is_qwen_derived_model if isinstance(cfg.learning_rate, str): cfg.learning_rate = float(cfg.learning_rate) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 9707a4b65c..c883edb375 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -334,6 +334,14 @@ def load_model( LOG.info("patching mixtral with flash attention") replace_mixtral_attn_with_multipack_flash_attn() + if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing: + from axolotl.monkeypatch.qwen2 import ( + replace_qwen2_attn_with_multipack_flash_attn, + ) + + LOG.info("patching qwen2 with flash attention") + replace_qwen2_attn_with_multipack_flash_attn() + if cfg.is_llama_derived_model and cfg.sample_packing and not inference: from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask @@ -426,14 +434,14 @@ def load_model( cfg.is_llama_derived_model or cfg.is_falcon_derived_model or cfg.is_mistral_derived_model - or model_config.model_type == "mixtral" + or model_config.model_type in ["mixtral", "qwen2"] ): model_kwargs["attn_implementation"] = "flash_attention_2" model_config._attn_implementation = ( # pylint: disable=protected-access "flash_attention_2" ) else: - if model_config.model_type == "mixtral": + if model_config.model_type in ["mixtral", "qwen2"]: model_kwargs["attn_implementation"] = "flash_attention_2" model_config._attn_implementation = ( # pylint: disable=protected-access "flash_attention_2"