transformers 4.47.1 #2187

Merged
10 commits merged on Dec 17, 2024
requirements.txt: 19 changes (11 additions, 8 deletions)
@@ -11,22 +11,27 @@ liger-kernel==0.4.2
 # END section
 
 packaging==23.2
 
 peft==0.14.0
-transformers==4.47.0
+transformers==4.47.1
 tokenizers>=0.20.1
-accelerate==1.2.0
+accelerate==1.2.1
 datasets==3.1.0
 deepspeed==0.16.1
+trl==0.12.1
+
+optimum==1.16.2
+hf_transfer
+sentencepiece
+gradio==3.50.2
+
 pydantic==2.6.3
 addict
 fire
 PyYAML>=6.0
 requests
-sentencepiece
 wandb
 einops
-optimum==1.16.2
-hf_transfer
 colorama
 numba
 numpy>=1.24.4,<=2.0.1
@@ -36,7 +41,6 @@ scipy
 scikit-learn==1.4.2
 nvidia-ml-py==12.560.30
 art
-gradio==3.50.2
 tensorboard
 python-dotenv==1.0.1

@@ -45,7 +49,6 @@ s3fs>=2024.5.0
 gcsfs>=2024.5.0
 # adlfs
 
-trl==0.12.1
 zstandard==0.22.0
 fastcore

@@ -55,5 +58,5 @@ langdetect==1.0.9
 immutabledict==4.2.0
 antlr4-python3-runtime==4.13.2
 
-torchao==0.5.0
+torchao==0.7.0
 schedulefree==1.3.0
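
For anyone who wants to confirm their environment actually picked up the bumped pins, here is a minimal sketch using the standard-library importlib.metadata; it only covers the versions changed in this hunk set and is not part of the PR itself.

from importlib.metadata import PackageNotFoundError, version

# Pins changed in this PR; compare against what is actually installed.
expected = {
    "transformers": "4.47.1",
    "accelerate": "1.2.1",
    "torchao": "0.7.0",
}

for pkg, pin in expected.items():
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        installed = "not installed"
    status = "OK" if installed == pin else "MISMATCH"
    print(f"{pkg}: pinned {pin}, installed {installed} -> {status}")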
scripts/unsloth_install.py: 2 changes (1 addition, 1 deletion)
@@ -32,5 +32,5 @@
     raise RuntimeError(f"Torch = {v} too new!")
 x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
 print(
-    f'pip install unsloth-zoo==2024.11.7 && pip install --no-deps "unsloth[{x}]==2024.11.9"'
+    f'pip install unsloth-zoo==2024.12.1 && pip install --no-deps "unsloth[{x}]==2024.12.4"'
 )
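
As a rough illustration of what the updated line emits: the script formats an extras tag from the detected CUDA version and Ampere check, then prints a two-step pip command with the bumped unsloth-zoo and unsloth versions. The template "cu{}{}" below is a hypothetical stand-in, since the real template is defined earlier in the script and not shown in this diff.

# Hypothetical stand-in for the template defined earlier in unsloth_install.py.
x = "cu{}{}"
cuda, is_ampere = "12.1", True  # assumed detection results, for illustration only

x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
print(
    f'pip install unsloth-zoo==2024.12.1 && pip install --no-deps "unsloth[{x}]==2024.12.4"'
)
# prints: pip install unsloth-zoo==2024.12.1 && pip install --no-deps "unsloth[cu121-ampere]==2024.12.4"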
src/axolotl/monkeypatch/trainer_grad_accum.py: 26 changes (22 additions, 4 deletions)
@@ -6,17 +6,15 @@
 import logging
 
 from transformers import LlamaForCausalLM, Trainer
+from transformers.modeling_flash_attention_utils import _flash_attention_forward
 
 from axolotl.monkeypatch.unsloth_ import detab_code
 
 LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
 
 ORIGINAL_CONTEXT_CODE = """
 with self.compute_loss_context_manager():
-    if self.model_accepts_loss_kwargs:
-        loss = self.compute_loss(model, inputs)
-    else:
-        loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
 """
 
 PATCHED_CONTEXT_CODE = """
@@ -288,3 +286,23 @@ def patch_training_loop_for_deepspeed_0_16_x():
     Trainer._inner_training_loop = ( # pylint: disable=protected-access
         _fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
     )
+
+
+def patch_flash_attention_forward():
+    """
+    monkeypatch for fixing the forward pass for flash attention to ignore num_items_in_batch
+    """
+
+    import transformers.modeling_flash_attention_utils
+
+    def proxy_flash_attention_forward(*args, **kwargs):
+        kwargs.pop("num_items_in_batch", None)
+
+        return _flash_attention_forward(*args, **kwargs)
+
+    transformers.modeling_flash_attention_utils._flash_attention_forward = ( # pylint: disable=protected-access
+        proxy_flash_attention_forward
+    )
+    transformers.models.llama.modeling_llama._flash_attention_forward = ( # pylint: disable=protected-access
+        proxy_flash_attention_forward
+    )
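
The new patch_flash_attention_forward above is an instance of a small, general pattern: wrap a callable in a proxy that drops a keyword argument the callee does not accept, then rebind the module-level references to the proxy. A self-contained sketch of that pattern, with hypothetical names rather than the transformers API, might look like this:

def flash_attention_forward(query, key, value, dropout=0.0):
    """Stand-in for a kernel that does not accept num_items_in_batch."""
    return query, key, value, dropout


def make_kwarg_stripping_proxy(func, *dropped):
    """Return a wrapper that silently removes the named keyword arguments."""

    def proxy(*args, **kwargs):
        for name in dropped:
            kwargs.pop(name, None)
        return func(*args, **kwargs)

    return proxy


patched = make_kwarg_stripping_proxy(flash_attention_forward, "num_items_in_batch")

# The extra kwarg injected by the trainer's loss plumbing is now ignored
# instead of raising a TypeError.
print(patched("q", "k", "v", num_items_in_batch=8))

In the real patch the proxy closes over the original _flash_attention_forward imported at module level, so calls still reach the unmodified kernel.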
src/axolotl/utils/models.py: 15 changes (2 additions, 13 deletions)
@@ -380,19 +380,6 @@ def apply_patches(self) -> None:
         plugin_manager = PluginManager.get_instance()
         plugin_manager.pre_model_load(self.cfg)
 
-        if self.cfg.fsdp:
-            from axolotl.monkeypatch.trainer_fsdp_optim import (
-                patch_training_loop_for_fsdp,
-            )
-
-            patch_training_loop_for_fsdp()
-        elif self.cfg.deepspeed and self.cfg.gradient_accumulation_steps > 1:
-            from axolotl.monkeypatch.trainer_grad_accum import (
-                patch_training_loop_for_deepspeed_0_16_x,
-            )
-
-            patch_training_loop_for_deepspeed_0_16_x()
-
         if self.cfg.gradient_checkpointing == "unsloth":
             transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
 
@@ -401,10 +388,12 @@ def apply_patches(self) -> None:
 
         if self.cfg.model_config_type == "llama":
             from axolotl.monkeypatch.trainer_grad_accum import (
+                patch_flash_attention_forward,
                 patch_forward_for_ga,
                 patch_training_step_for_ga,
             )
 
+            patch_flash_attention_forward()
             patch_forward_for_ga()
             patch_training_step_for_ga()
 
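
A quick way to check that the new llama branch of apply_patches takes effect is to confirm that both module-level references end up pointing at the same proxy. A rough sketch, assuming this PR's axolotl and transformers 4.47.1 are installed; it is not part of the PR:

import transformers.modeling_flash_attention_utils as fa_utils
import transformers.models.llama.modeling_llama as llama_modeling

from axolotl.monkeypatch.trainer_grad_accum import patch_flash_attention_forward

patch_flash_attention_forward()

# Both references should now be the same kwarg-stripping proxy.
assert fa_utils._flash_attention_forward is llama_modeling._flash_attention_forward
print("patched with:", fa_utils._flash_attention_forward.__name__)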