Qwen2 #1166

Merged (5 commits) on Jan 22, 2024
4 changes: 2 additions & 2 deletions requirements.txt

@@ -1,10 +1,10 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.7.0
-transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
+transformers==4.37.0
 tokenizers==0.15.0
 bitsandbytes>=0.41.1
-accelerate @ git+https://github.com/huggingface/accelerate.git@0d2280dadc6a93413a5496613b7fdda3a4d2551b
+accelerate==0.26.1
 deepspeed
 addict
 fire
2 changes: 1 addition & 1 deletion src/axolotl/core/trainer_builder.py

@@ -905,7 +905,7 @@ def build_collator(
             ]
         ]
         if use_batch_sampler_collator:
-            if self.cfg.model_config_type == "mixtral":
+            if self.cfg.model_config_type in ["mixtral", "qwen2"]:
                 collator = V2BatchSamplerDataCollatorForSeq2Seq
             else:
                 collator = BatchSamplerDataCollatorForSeq2Seq
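With sample packing, the batch sampler hands the collator groups of tokenized examples that share one packed row; qwen2 now joins mixtral in taking the V2 collator for that layout. The sketch below is illustrative only (it is not axolotl's V2BatchSamplerDataCollatorForSeq2Seq) and assumes each example dict carries input_ids and labels, with a distinct per-sample id written into the attention mask so the unpad patch in the next file can recover sample boundaries:

# Illustrative sketch, not axolotl's actual collator.
def pack_groups(features):
    """Concatenate each group of tokenized examples into one packed row."""
    packed = []
    for group in features:  # features: list of groups destined for one row
        row = {"input_ids": [], "labels": [], "attention_mask": []}
        for sample_id, example in enumerate(group, start=1):
            row["input_ids"] += example["input_ids"]
            row["labels"] += example["labels"]
            # distinct id per packed sample (0 is reserved for padding), so
            # downstream code can recover per-sample boundaries
            row["attention_mask"] += [sample_id] * len(example["input_ids"])
        packed.append(row)
    return packed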
12 changes: 12 additions & 0 deletions src/axolotl/monkeypatch/qwen2/__init__.py

@@ -0,0 +1,12 @@
+"""
+Patches to support multipack for qwen2
+"""
+import transformers
+
+from axolotl.monkeypatch.utils import get_unpad_data
+
+
+def replace_qwen2_attn_with_multipack_flash_attn():
+    transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
+        get_unpad_data
+    )
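The patch rebinds the module-level _get_unpad_data helper that transformers' flash-attention path calls when unpadding inputs, so sequence lengths are derived per packed sample rather than per batch row. Below is a rough sketch of the contract the replacement has to satisfy, assuming the multipack attention mask stores a distinct sample id per packed sequence with 0 for padding; axolotl's actual get_unpad_data in axolotl.monkeypatch.utils is implemented differently:

import torch
import torch.nn.functional as F


def get_unpad_data_sketch(attention_mask: torch.Tensor):
    """Return (indices, cu_seqlens, max_seqlen) for flash-attn varlen kernels."""
    device = attention_mask.device
    seqlens = []
    for row in attention_mask:
        for sample_id in row.unique():  # sorted, so packed order is preserved
            if sample_id > 0:  # 0 marks padding
                seqlens.append(int((row == sample_id).sum()))
    seqlens_in_batch = torch.tensor(seqlens, dtype=torch.int32, device=device)
    # positions of all non-padding tokens in the flattened batch
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    # cumulative sequence lengths, with a leading zero
    cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
    )
    return indices, cu_seqlens, max_seqlen_in_batch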
17 changes: 6 additions & 11 deletions src/axolotl/utils/config.py

@@ -142,17 +142,12 @@ def normalize_config(cfg):
     )
 
     cfg.is_qwen_derived_model = (
-        (
-            hasattr(model_config, "model_type")
-            and model_config.model_type
-            in [
-                "qwen",
-            ]
-        )
-        or cfg.is_qwen_derived_model
-        or "qwen" in cfg.base_model.lower()
-        or (cfg.model_type and "qwen" in cfg.model_type.lower())
-    )
+        hasattr(model_config, "model_type")
+        and model_config.model_type
+        in [
+            "qwen",
+        ]
+    ) or cfg.is_qwen_derived_model
 
     if isinstance(cfg.learning_rate, str):
         cfg.learning_rate = float(cfg.learning_rate)
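The tightened predicate means qwen detection no longer fires on a bare "qwen" substring in the model name, which would have mis-flagged qwen2 checkpoints as qwen(1)-derived. A hypothetical before/after check (the class and checkpoint name are illustrative, not from the PR):

class Cfg:  # minimal stand-in for axolotl's config object
    base_model = "SomeOrg/qwen2-style-model"  # hypothetical qwen2 checkpoint
    model_type = None
    is_qwen_derived_model = None

config_model_type = "qwen2"  # what AutoConfig would report for this model

# old predicate: the substring test flags the model as qwen(1)-derived
old = config_model_type in ["qwen"] or "qwen" in Cfg.base_model.lower()  # True
# new predicate: only a real model_type of "qwen" or an explicit override
new = config_model_type in ["qwen"] or bool(Cfg.is_qwen_derived_model)  # False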
12 changes: 10 additions & 2 deletions src/axolotl/utils/models.py

@@ -334,6 +334,14 @@ def load_model(
         LOG.info("patching mixtral with flash attention")
         replace_mixtral_attn_with_multipack_flash_attn()
 
+    if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
+        from axolotl.monkeypatch.qwen2 import (
+            replace_qwen2_attn_with_multipack_flash_attn,
+        )
+
+        LOG.info("patching qwen2 with flash attention")
+        replace_qwen2_attn_with_multipack_flash_attn()
+
     if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
         from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
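The patch is gated on all three of model_config_type == "qwen2", flash_attention, and sample_packing, and it runs before the model is instantiated. A minimal sketch of the equivalent ordering outside axolotl, with a placeholder checkpoint id:

from transformers import AutoModelForCausalLM

from axolotl.monkeypatch.qwen2 import replace_qwen2_attn_with_multipack_flash_attn

# rebind transformers' module-level _get_unpad_data before any forward pass
replace_qwen2_attn_with_multipack_flash_attn()

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-qwen2-model",  # placeholder checkpoint id
    attn_implementation="flash_attention_2",
)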
@@ -426,14 +434,14 @@ def load_model(
             cfg.is_llama_derived_model
             or cfg.is_falcon_derived_model
             or cfg.is_mistral_derived_model
-            or model_config.model_type == "mixtral"
+            or model_config.model_type in ["mixtral", "qwen2"]
         ):
             model_kwargs["attn_implementation"] = "flash_attention_2"
             model_config._attn_implementation = (  # pylint: disable=protected-access
                 "flash_attention_2"
             )
         else:
-            if model_config.model_type == "mixtral":
+            if model_config.model_type in ["mixtral", "qwen2"]:
                 model_kwargs["attn_implementation"] = "flash_attention_2"
                 model_config._attn_implementation = (  # pylint: disable=protected-access
                     "flash_attention_2"
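Both branches steer qwen2 onto transformers' flash-attention-2 code path, via the attn_implementation kwarg and the private _attn_implementation config field. A short sketch of where those values land, again with a placeholder checkpoint id; in transformers 4.37 this selection should resolve to the Qwen2FlashAttention2 layers:

from transformers import AutoConfig, AutoModelForCausalLM

model_config = AutoConfig.from_pretrained("some-org/some-qwen2-model")  # placeholder
model_config._attn_implementation = "flash_attention_2"  # pylint: disable=protected-access

model_kwargs = {"attn_implementation": "flash_attention_2"}
model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-qwen2-model",  # placeholder
    config=model_config,
    **model_kwargs,
)
print(model.config._attn_implementation)  # "flash_attention_2"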