multipack for gemma (#1313)
* multipack for gemma

* chore: lint

* handle cache_position kwarg in updated llama modeling

* add position_ids to rotary embed call for updated llama modeling
winglian authored Feb 22, 2024
1 parent 9e300ac commit 2752d5f
Showing 4 changed files with 26 additions and 15 deletions.
22 changes: 11 additions & 11 deletions examples/gemma/qlora.yml
@@ -1,57 +1,57 @@
# use google/gemma-7b if you have access
base_model: mhenrichsen/gemma-7b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

# huggingface repo
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
val_set_size: 0.1
output_dir: ./out

adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true

sequence_len: 4096
sample_packing: false
pad_to_sequence_len: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:


gradient_accumulation_steps: 3
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 4
eval_table_size:
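The packing behavior this commit wires up for gemma is driven by a few keys in the config above (sequence_len, sample_packing, pad_to_sequence_len). A minimal sketch, not part of the commit, for loading the example config and inspecting those keys (assumes PyYAML is installed):

# Sketch only: read the example config and print the packing-related keys.
import yaml

with open("examples/gemma/qlora.yml", encoding="utf-8") as fin:
    cfg = yaml.safe_load(fin)

# the example ships with sample_packing disabled; enabling it is what
# exercises the multipack path patched later in this commit
print(cfg.get("sequence_len"), cfg.get("sample_packing"), cfg.get("pad_to_sequence_len"))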
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git@bebeeee01275c32fccec3fa36d8b148d3813a7dc
transformers @ git+https://github.com/huggingface/transformers.git@ae49b218c3d718df90d8e4a109016450fb8f0632
tokenizers==0.15.0
bitsandbytes>=0.41.1
accelerate==0.26.1
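The transformers pin moves to a newer upstream commit; the working assumption is that this build already ships the gemma modeling module the patch below hooks into. A quick sanity-check sketch, not part of the commit:

# Sanity-check sketch: confirm the installed transformers build exposes the
# gemma modeling module before applying the monkeypatch.
import importlib.util

assert importlib.util.find_spec("transformers.models.gemma.modeling_gemma") is not None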
11 changes: 9 additions & 2 deletions src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -275,7 +275,9 @@ def flashattn_forward_with_s2attn(
    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    cos, sin = self.rotary_emb(
        value_states, seq_len=kv_seq_len, position_ids=position_ids
    )
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )
@@ -425,7 +427,9 @@ def flashattn_forward(
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    cos, sin = self.rotary_emb(
        value_states, seq_len=kv_seq_len, position_ids=position_ids
    )
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )
@@ -688,6 +692,9 @@ def llama_model_forward(
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[ # pylint: disable=unused-argument
        torch.LongTensor
    ] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = (
        output_attentions
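The two rotary-embedding edits and the new cache_position parameter track signature changes in the updated upstream llama modeling: the rotary embedding is now called with position_ids, and the model forward receives a cache_position argument that the patched function accepts and ignores. A minimal compatibility sketch, an illustration rather than the axolotl implementation (the helper name is hypothetical):

# Illustration only: call the rotary embedding with position_ids when the
# installed signature expects it, otherwise fall back to the older call.
import inspect

def call_rotary_emb(rotary_emb, value_states, kv_seq_len, position_ids):
    if "position_ids" in inspect.signature(rotary_emb.forward).parameters:
        # newer llama modeling: cos/sin are gathered per position id
        return rotary_emb(value_states, seq_len=kv_seq_len, position_ids=position_ids)
    # older llama modeling: cos/sin cover the full key/value sequence length
    return rotary_emb(value_states, seq_len=kv_seq_len)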
6 changes: 5 additions & 1 deletion src/axolotl/monkeypatch/multipack.py
@@ -6,7 +6,7 @@
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data

SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi"]
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi", "gemma"]


def patch_for_multipack(model_type):
@@ -28,3 +28,7 @@ def patch_for_multipack(model_type):
        transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
            get_unpad_data
        )
    elif model_type == "gemma":
        transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
            get_unpad_data
        )
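With "gemma" added to SUPPORTED_MULTIPACK_MODEL_TYPES, the patch is selected by model type and swaps the module-level _get_unpad_data in transformers' gemma modeling code for axolotl's packing-aware version. A hedged usage sketch, assuming a transformers build with gemma support as pinned above:

# Usage sketch, not part of the commit: apply the multipack patch for a
# gemma checkpoint before building the trainer.
import transformers
from axolotl.monkeypatch.multipack import (
    SUPPORTED_MULTIPACK_MODEL_TYPES,
    patch_for_multipack,
)

config = transformers.AutoConfig.from_pretrained("mhenrichsen/gemma-7b")
if config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
    patch_for_multipack(config.model_type)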
