
Commit

bump requirements for qwen2
winglian committed Jan 22, 2024
1 parent 370e5fa commit 9c229c4
Showing 2 changed files with 4 additions and 5 deletions.
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,10 +1,10 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.7.0
-transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
+transformers==4.37.0
 tokenizers==0.15.0
 bitsandbytes>=0.41.1
-accelerate @ git+https://github.com/huggingface/accelerate.git@0d2280dadc6a93413a5496613b7fdda3a4d2551b
+accelerate==0.26.1
 deepspeed
 addict
 fire
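transformers 4.37.0 is the first release that ships the Qwen2 architecture, which is why the floating git pins can be replaced with fixed releases here. A quick sanity check (a sketch, not part of the commit) that an environment matches the new pins:

# Sketch only: confirm the pinned releases are installed and that the
# Qwen2 classes ship with this transformers version.
import accelerate
import transformers
from transformers import Qwen2Config, Qwen2ForCausalLM  # available as of transformers 4.37.0

assert transformers.__version__ == "4.37.0", transformers.__version__
assert accelerate.__version__ == "0.26.1", accelerate.__version__
print("Qwen2 classes importable:", Qwen2Config.__name__, Qwen2ForCausalLM.__name__)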
5 changes: 2 additions & 3 deletions src/axolotl/utils/models.py
@@ -434,15 +434,14 @@ def load_model(
                 cfg.is_llama_derived_model
                 or cfg.is_falcon_derived_model
                 or cfg.is_mistral_derived_model
-                or model_config.model_type == "mixtral"
-                or model_config.model_type == "qwen2"
+                or model_config.model_type in ["mixtral", "qwen2"]
             ):
                 model_kwargs["attn_implementation"] = "flash_attention_2"
                 model_config._attn_implementation = (  # pylint: disable=protected-access
                     "flash_attention_2"
                 )
         else:
-            if model_config.model_type == "mixtral":
+            if model_config.model_type in ["mixtral", "qwen2"]:
                 model_kwargs["attn_implementation"] = "flash_attention_2"
                 model_config._attn_implementation = (  # pylint: disable=protected-access
                     "flash_attention_2"
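The models.py change folds the qwen2 check into the same membership test as mixtral, so both branches of the surrounding if/else hand either architecture flash_attention_2. A minimal stand-alone sketch of that consolidated check (hypothetical helper names; the real code does this inline inside load_model):

# Hypothetical helper illustrating the consolidated membership check.
FLASH_ATTN_MODEL_TYPES = ["mixtral", "qwen2"]

def apply_flash_attention_2(model_kwargs: dict, model_config) -> None:
    """Request flash_attention_2 when the model type is on the shared list."""
    if model_config.model_type in FLASH_ATTN_MODEL_TYPES:
        model_kwargs["attn_implementation"] = "flash_attention_2"
        # same protected attribute the diff sets
        model_config._attn_implementation = "flash_attention_2"  # pylint: disable=protected-access

Adding another architecture to the flash-attention path is now a single membership test per branch rather than a chain of equality checks.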
