From 2752d5f9582dea722ba61e25f18146ff1dc3167d Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 21 Feb 2024 19:24:21 -0500
Subject: [PATCH] multipack for gemma (#1313)

* multipack for gemma

* chore: lint

* handle cache_position kwarg in updated llama modeling

* add position_ids to rotary embed call for updated llama modeling
---
 examples/gemma/qlora.yml                   | 22 +++++++++----------
 requirements.txt                           |  2 +-
 .../monkeypatch/llama_attn_hijack_flash.py | 11 ++++++++--
 src/axolotl/monkeypatch/multipack.py       |  6 ++++-
 4 files changed, 26 insertions(+), 15 deletions(-)

diff --git a/examples/gemma/qlora.yml b/examples/gemma/qlora.yml
index abbd2f0347..3a9f21da0c 100644
--- a/examples/gemma/qlora.yml
+++ b/examples/gemma/qlora.yml
@@ -1,49 +1,49 @@
 # use google/gemma-7b if you have access
-base_model: mhenrichsen/gemma-7b 
+base_model: mhenrichsen/gemma-7b
 model_type: AutoModelForCausalLM
 tokenizer_type: AutoTokenizer
- 
+
 load_in_8bit: false
 load_in_4bit: true
 strict: false
- 
+
 # huggingface repo
 datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 val_set_size: 0.1
 output_dir: ./out
- 
+
 adapter: qlora
 lora_r: 32
 lora_alpha: 16
 lora_dropout: 0.05
 lora_target_linear: true
- 
+
 sequence_len: 4096
 sample_packing: false
 pad_to_sequence_len: false
- 
+
 wandb_project:
 wandb_entity:
 wandb_watch:
 wandb_name:
 wandb_log_model:
- 
- 
+
+
 gradient_accumulation_steps: 3
 micro_batch_size: 2
 num_epochs: 4
 optimizer: adamw_bnb_8bit
 lr_scheduler: cosine
 learning_rate: 0.0002
- 
+
 train_on_inputs: false
 group_by_length: false
 bf16: auto
 fp16:
 tf32: false
- 
+
 gradient_checkpointing: true
 early_stopping_patience:
 resume_from_checkpoint:
@@ -51,7 +51,7 @@ local_rank:
 logging_steps: 1
 xformers_attention:
 flash_attention: true
- 
+
 warmup_ratio: 0.1
 evals_per_epoch: 4
 eval_table_size:
diff --git a/requirements.txt b/requirements.txt
index e20940f649..a5986fa4ff 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft @ git+https://github.com/huggingface/peft.git
-transformers @ git+https://github.com/huggingface/transformers.git@bebeeee01275c32fccec3fa36d8b148d3813a7dc
+transformers @ git+https://github.com/huggingface/transformers.git@ae49b218c3d718df90d8e4a109016450fb8f0632
 tokenizers==0.15.0
 bitsandbytes>=0.41.1
 accelerate==0.26.1
diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index 4bded9b027..86dde18a6a 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -275,7 +275,9 @@ def flashattn_forward_with_s2attn(
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    cos, sin = self.rotary_emb(
+        value_states, seq_len=kv_seq_len, position_ids=position_ids
+    )
     query_states, key_states = apply_rotary_pos_emb(
         query_states, key_states, cos, sin, position_ids
     )
@@ -425,7 +427,9 @@ def flashattn_forward(
 
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    cos, sin = self.rotary_emb(
+        value_states, seq_len=kv_seq_len, position_ids=position_ids
+    )
     query_states, key_states = apply_rotary_pos_emb(
         query_states, key_states, cos, sin, position_ids
     )
@@ -688,6 +692,9 @@ def llama_model_forward(
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
+    cache_position: Optional[  # pylint: disable=unused-argument
+        torch.LongTensor
+    ] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
     output_attentions = (
         output_attentions
diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py
index 640b3b0c3c..65a79a8782 100644
--- a/src/axolotl/monkeypatch/multipack.py
+++ b/src/axolotl/monkeypatch/multipack.py
@@ -6,7 +6,7 @@
 from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
 from axolotl.monkeypatch.utils import get_unpad_data
 
-SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi"]
+SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi", "gemma"]
 
 
 def patch_for_multipack(model_type):
@@ -28,3 +28,7 @@ def patch_for_multipack(model_type):
         transformers.models.phi.modeling_phi._get_unpad_data = (  # pylint: disable=protected-access
             get_unpad_data
         )
+    elif model_type == "gemma":
+        transformers.models.gemma.modeling_gemma._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
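
Below is a minimal usage sketch of the gemma branch this patch adds to src/axolotl/monkeypatch/multipack.py. SUPPORTED_MULTIPACK_MODEL_TYPES and patch_for_multipack are the names from the diff above; the AutoConfig lookup and surrounding flow are assumptions added only for illustration and are not code from this PR.

# Hedged sketch (not part of this patch): apply the multipack monkeypatch
# for a gemma model before building the trainer.
from transformers import AutoConfig

from axolotl.monkeypatch.multipack import (
    SUPPORTED_MULTIPACK_MODEL_TYPES,
    patch_for_multipack,
)

# Model id taken from examples/gemma/qlora.yml; used here purely as an example.
model_config = AutoConfig.from_pretrained("mhenrichsen/gemma-7b")

if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
    # For "gemma" this swaps transformers.models.gemma.modeling_gemma._get_unpad_data
    # for axolotl's get_unpad_data, so flash attention unpads each packed
    # sub-sequence separately when sample packing is enabled.
    patch_for_multipack(model_config.model_type)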