transformers 4.47.1 (#2187)
* transformers 4.47.1

* drop monkeypatches

* can't remove patches yet

* make flash attention forward ignore the loss kwargs

* patch the flash attention in the modeling arch too

* remove fsdp and deepspeed patches

* cleanup PR

* bump accelerate and torchao, also logically reorder/group requirements

* meant to include torchao

* use official patch release
winglian authored Dec 17, 2024
1 parent f865464 commit 1f623e6
Showing 4 changed files with 36 additions and 26 deletions.
19 changes: 11 additions & 8 deletions requirements.txt
@@ -11,22 +11,27 @@ liger-kernel==0.4.2
 # END section
 
 packaging==23.2
+
 peft==0.14.0
-transformers==4.47.0
+transformers==4.47.1
 tokenizers>=0.20.1
-accelerate==1.2.0
+accelerate==1.2.1
 datasets==3.1.0
 deepspeed==0.16.1
+trl==0.12.1
+
+optimum==1.16.2
+hf_transfer
+sentencepiece
+gradio==3.50.2
+
 pydantic==2.6.3
 addict
 fire
 PyYAML>=6.0
 requests
-sentencepiece
 wandb
 einops
-optimum==1.16.2
-hf_transfer
 colorama
 numba
 numpy>=1.24.4,<=2.0.1
@@ -36,7 +41,6 @@ scipy
 scikit-learn==1.4.2
 nvidia-ml-py==12.560.30
 art
-gradio==3.50.2
 tensorboard
 python-dotenv==1.0.1
 
@@ -45,7 +49,6 @@ s3fs>=2024.5.0
 gcsfs>=2024.5.0
 # adlfs
 
-trl==0.12.1
 zstandard==0.22.0
 fastcore
 
@@ -55,5 +58,5 @@ langdetect==1.0.9
 immutabledict==4.2.0
 antlr4-python3-runtime==4.13.2
 
-torchao==0.5.0
+torchao==0.7.0
 schedulefree==1.3.0
2 changes: 1 addition & 1 deletion scripts/unsloth_install.py
@@ -32,5 +32,5 @@
     raise RuntimeError(f"Torch = {v} too new!")
 x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
 print(
-    f'pip install unsloth-zoo==2024.11.7 && pip install --no-deps "unsloth[{x}]==2024.11.9"'
+    f'pip install unsloth-zoo==2024.12.1 && pip install --no-deps "unsloth[{x}]==2024.12.4"'
 )
26 changes: 22 additions & 4 deletions src/axolotl/monkeypatch/trainer_grad_accum.py
@@ -6,17 +6,15 @@
 import logging
 
 from transformers import LlamaForCausalLM, Trainer
+from transformers.modeling_flash_attention_utils import _flash_attention_forward
 
 from axolotl.monkeypatch.unsloth_ import detab_code
 
 LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
 
 ORIGINAL_CONTEXT_CODE = """
 with self.compute_loss_context_manager():
-    if self.model_accepts_loss_kwargs:
-        loss = self.compute_loss(model, inputs)
-    else:
-        loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
 """
 
 PATCHED_CONTEXT_CODE = """
@@ -288,3 +286,23 @@ def patch_training_loop_for_deepspeed_0_16_x():
     Trainer._inner_training_loop = (  # pylint: disable=protected-access
         _fixed_inner_training_loop  # pylint: disable=undefined-variable # noqa: F821
     )
+
+
+def patch_flash_attention_forward():
+    """
+    monkeypatch for fixing the forward pass for flash attention to ignore num_items_in_batch
+    """
+
+    import transformers.modeling_flash_attention_utils
+
+    def proxy_flash_attention_forward(*args, **kwargs):
+        kwargs.pop("num_items_in_batch", None)
+
+        return _flash_attention_forward(*args, **kwargs)
+
+    transformers.modeling_flash_attention_utils._flash_attention_forward = (  # pylint: disable=protected-access
+        proxy_flash_attention_forward
+    )
+    transformers.models.llama.modeling_llama._flash_attention_forward = (  # pylint: disable=protected-access
+        proxy_flash_attention_forward
+    )
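The new patch_flash_attention_forward hook rebinds the module-level _flash_attention_forward symbol to a thin proxy that pops the num_items_in_batch loss kwarg before delegating to the original function imported at the top of the file; the diff rebinds it in the llama modeling module as well, which suggests that module holds its own module-level reference to the function. Below is a minimal, self-contained sketch of the same wrapper pattern — the names here are illustrative, not from the repo:

# Illustrative stand-in for the real kernel entry point: it does not accept
# num_items_in_batch, so passing that kwarg directly would raise a TypeError.
def fake_flash_attention_forward(query, key, value, *, dropout=0.0):
    return query


def make_kwarg_stripping_proxy(fn):
    """Wrap fn so callers may pass num_items_in_batch without breaking it."""

    def proxy(*args, **kwargs):
        kwargs.pop("num_items_in_batch", None)  # drop the loss-only kwarg
        return fn(*args, **kwargs)

    return proxy


patched = make_kwarg_stripping_proxy(fake_flash_attention_forward)

# The Trainer forwards loss kwargs alongside the usual arguments; the proxy
# silently discards the one the attention kernel cannot handle.
print(patched("q", "k", "v", dropout=0.1, num_items_in_batch=8))  # prints "q"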
15 changes: 2 additions & 13 deletions src/axolotl/utils/models.py
@@ -380,19 +380,6 @@ def apply_patches(self) -> None:
         plugin_manager = PluginManager.get_instance()
         plugin_manager.pre_model_load(self.cfg)
 
-        if self.cfg.fsdp:
-            from axolotl.monkeypatch.trainer_fsdp_optim import (
-                patch_training_loop_for_fsdp,
-            )
-
-            patch_training_loop_for_fsdp()
-        elif self.cfg.deepspeed and self.cfg.gradient_accumulation_steps > 1:
-            from axolotl.monkeypatch.trainer_grad_accum import (
-                patch_training_loop_for_deepspeed_0_16_x,
-            )
-
-            patch_training_loop_for_deepspeed_0_16_x()
-
         if self.cfg.gradient_checkpointing == "unsloth":
             transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
 
@@ -401,10 +388,12 @@ def apply_patches(self) -> None:
 
         if self.cfg.model_config_type == "llama":
             from axolotl.monkeypatch.trainer_grad_accum import (
+                patch_flash_attention_forward,
                 patch_forward_for_ga,
                 patch_training_step_for_ga,
             )
 
+            patch_flash_attention_forward()
             patch_forward_for_ga()
             patch_training_step_for_ga()
 
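A quick way to confirm the patch took effect after this change — a hypothetical snippet, assuming axolotl and transformers 4.47.1 are installed; it is not part of the commit:

# Hypothetical sanity check, not part of the commit: after the patch is
# applied, the symbol in transformers should point at the proxy function.
import transformers.modeling_flash_attention_utils as fa_utils

from axolotl.monkeypatch.trainer_grad_accum import patch_flash_attention_forward

patch_flash_attention_forward()
print(fa_utils._flash_attention_forward.__name__)  # proxy_flash_attention_forward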
