Qwen2 (#1166)
* qwen2 multipack support

* fix qwen derived model check so it doesn't break qwen2

* fixes to ensure qwen2 packing works

* bump requirements for qwen2

* requirements typo
winglian authored Jan 22, 2024
1 parent 3d2b5dd commit 7f1d59b
Showing 5 changed files with 31 additions and 16 deletions.
requirements.txt: 4 changes (2 additions, 2 deletions)

@@ -1,10 +1,10 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.7.0
-transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
+transformers==4.37.0
 tokenizers==0.15.0
 bitsandbytes>=0.41.1
-accelerate @ git+https://github.com/huggingface/accelerate.git@0d2280dadc6a93413a5496613b7fdda3a4d2551b
+accelerate==0.26.1
 deepspeed
 addict
 fire
src/axolotl/core/trainer_builder.py: 2 changes (1 addition, 1 deletion)

@@ -905,7 +905,7 @@ def build_collator(
             ]
         ]
         if use_batch_sampler_collator:
-            if self.cfg.model_config_type == "mixtral":
+            if self.cfg.model_config_type in ["mixtral", "qwen2"]:
                 collator = V2BatchSamplerDataCollatorForSeq2Seq
             else:
                 collator = BatchSamplerDataCollatorForSeq2Seq
src/axolotl/monkeypatch/qwen2/__init__.py: 12 changes (12 additions, 0 deletions, new file)

@@ -0,0 +1,12 @@
+"""
+Patches to support multipack for qwen2
+"""
+import transformers
+
+from axolotl.monkeypatch.utils import get_unpad_data
+
+
+def replace_qwen2_attn_with_multipack_flash_attn():
+    transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
+        get_unpad_data
+    )
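
For context, `_get_unpad_data` is the helper that transformers' flash-attention path uses to flatten a padded batch into the variable-length layout the flash-attn varlen kernels expect, and the `get_unpad_data` swapped in here (from `axolotl.monkeypatch.utils`) makes that step aware of sample packing, where the attention mask numbers each packed sample (1, 2, 3, ...) instead of holding only 0/1. The sketch below is a simplified illustration of that idea, not axolotl's implementation; the function name `unpad_packed_mask` and the tensor values are made up for the example.

```python
import torch
import torch.nn.functional as F


def unpad_packed_mask(attention_mask: torch.Tensor):
    """Hypothetical, simplified packing-aware _get_unpad_data.

    attention_mask labels each packed sample with its own index (1, 2, 3, ...)
    and padding with 0, so sub-sequence boundaries survive packing.
    """
    # length of every packed sub-sequence, in order of appearance
    seqlens = torch.tensor(
        [int((attention_mask == i).sum()) for i in range(1, int(attention_mask.max()) + 1)],
        dtype=torch.int32,
    )
    # flat indices of all non-padding tokens
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # cumulative sequence lengths, prefixed with 0, as the varlen kernels expect
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, int(seqlens.max())


# two samples of lengths 3 and 2 packed into one row, then padding
mask = torch.tensor([[1, 1, 1, 2, 2, 0, 0, 0]])
indices, cu_seqlens, max_len = unpad_packed_mask(mask)
print(indices.tolist())     # [0, 1, 2, 3, 4]
print(cu_seqlens.tolist())  # [0, 3, 5]
print(max_len)              # 3
```

Without a packing-aware helper, a plain 0/1 mask over a packed row is treated as one contiguous sequence, so packed samples could attend across their boundaries.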
src/axolotl/utils/config.py: 17 changes (6 additions, 11 deletions)

@@ -142,17 +142,12 @@ def normalize_config(cfg):
     )

     cfg.is_qwen_derived_model = (
-        (
-            hasattr(model_config, "model_type")
-            and model_config.model_type
-            in [
-                "qwen",
-            ]
-        )
-        or cfg.is_qwen_derived_model
-        or "qwen" in cfg.base_model.lower()
-        or (cfg.model_type and "qwen" in cfg.model_type.lower())
-    )
+        hasattr(model_config, "model_type")
+        and model_config.model_type
+        in [
+            "qwen",
+        ]
+    ) or cfg.is_qwen_derived_model

     if isinstance(cfg.learning_rate, str):
         cfg.learning_rate = float(cfg.learning_rate)
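
The deleted substring checks are the fix the commit message describes ("fix qwen derived model check so it doesn't break qwen2"): matching "qwen" against the base model name or model_type also matches every qwen2 checkpoint, which would wrongly route them through the qwen(1)-specific code paths. A minimal illustration, with example model ids:

```python
# Why the name-based check was dropped: it cannot tell qwen (v1) from qwen2 checkpoints.
for base_model in ["Qwen/Qwen-7B", "Qwen/Qwen1.5-7B-Chat"]:  # example model ids
    print(base_model, "qwen" in base_model.lower())  # True for both, so qwen2 would be mis-flagged
```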
src/axolotl/utils/models.py: 12 changes (10 additions, 2 deletions)

@@ -334,6 +334,14 @@ def load_model(
         LOG.info("patching mixtral with flash attention")
         replace_mixtral_attn_with_multipack_flash_attn()

+    if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
+        from axolotl.monkeypatch.qwen2 import (
+            replace_qwen2_attn_with_multipack_flash_attn,
+        )
+
+        LOG.info("patching qwen2 with flash attention")
+        replace_qwen2_attn_with_multipack_flash_attn()
+
     if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
         from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask

@@ -426,14 +434,14 @@ def load_model(
             cfg.is_llama_derived_model
             or cfg.is_falcon_derived_model
             or cfg.is_mistral_derived_model
-            or model_config.model_type == "mixtral"
+            or model_config.model_type in ["mixtral", "qwen2"]
         ):
             model_kwargs["attn_implementation"] = "flash_attention_2"
             model_config._attn_implementation = (  # pylint: disable=protected-access
                 "flash_attention_2"
             )
         else:
-            if model_config.model_type == "mixtral":
+            if model_config.model_type in ["mixtral", "qwen2"]:
                 model_kwargs["attn_implementation"] = "flash_attention_2"
                 model_config._attn_implementation = (  # pylint: disable=protected-access
                     "flash_attention_2"
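
Taken together, the new path only fires when cfg.model_config_type is "qwen2" and both flash_attention and sample_packing are enabled, after which the model is loaded with attn_implementation="flash_attention_2". A rough stand-alone sketch of the same sequence outside axolotl's load_model(); the checkpoint name and surrounding script are illustrative, not part of this commit:

```python
# Illustrative sketch only: inside axolotl, load_model() performs these steps when
# model_config_type == "qwen2" and both flash_attention and sample_packing are set.
from transformers import AutoModelForCausalLM

from axolotl.monkeypatch.qwen2 import replace_qwen2_attn_with_multipack_flash_attn

# swap transformers' _get_unpad_data for the packing-aware version
replace_qwen2_attn_with_multipack_flash_attn()

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen1.5-7B",  # example qwen2-architecture checkpoint
    attn_implementation="flash_attention_2",
    torch_dtype="auto",
)
```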
