From 370e5facc726b4180aa08f77eab49fb4de3ac931 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 22 Jan 2024 17:00:53 -0500 Subject: [PATCH] fixes to ensure qwen2 packing works --- src/axolotl/core/trainer_builder.py | 2 +- src/axolotl/utils/models.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 334948dd9d..b8309c3633 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -905,7 +905,7 @@ def build_collator( ] ] if use_batch_sampler_collator: - if self.cfg.model_config_type == "mixtral": + if self.cfg.model_config_type in ["mixtral", "qwen2"]: collator = V2BatchSamplerDataCollatorForSeq2Seq else: collator = BatchSamplerDataCollatorForSeq2Seq diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 2132a45636..6be11a77d2 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -435,6 +435,7 @@ def load_model( or cfg.is_falcon_derived_model or cfg.is_mistral_derived_model or model_config.model_type == "mixtral" + or model_config.model_type == "qwen2" ): model_kwargs["attn_implementation"] = "flash_attention_2" model_config._attn_implementation = ( # pylint: disable=protected-access