From 9c229c4b55c1ad4af46ff2799cf0f86709e6f54e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 22 Jan 2024 17:19:16 -0500 Subject: [PATCH] bump requirements for qwen2 --- requirements.txt | 4 ++-- src/axolotl/utils/models.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/requirements.txt b/requirements.txt index 4583850326..d7ac3d4597 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ packaging==23.2 peft==0.7.0 -transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0 +transformers==4.37.0 tokenizers==0.15.0 bitsandbytes>=0.41.1 -accelerate @ git+https://github.com/huggingface/accelerate.git@0d2280dadc6a93413a5496613b7fdda3a4d2551b +accelerate==0.26.1 deepspeed addict fire diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 6be11a77d2..c883edb375 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -434,15 +434,14 @@ def load_model( cfg.is_llama_derived_model or cfg.is_falcon_derived_model or cfg.is_mistral_derived_model - or model_config.model_type == "mixtral" - or model_config.model_type == "qwen2" + or model_config.model_type in ["mixtral", "qwen2"] ): model_kwargs["attn_implementation"] = "flash_attention_2" model_config._attn_implementation = ( # pylint: disable=protected-access "flash_attention_2" ) else: - if model_config.model_type == "mixtral": + if model_config.model_type in ["mixtral", "qwen2"]: model_kwargs["attn_implementation"] = "flash_attention_2" model_config._attn_implementation = ( # pylint: disable=protected-access "flash_attention_2"