Skip to content

Commit

Permalink
Transformers version flexibility and FSDP optimizer patch (#2155)
Browse files Browse the repository at this point in the history
* allow flexibility in transformers version for FSDP

* more flexibility with dev versions of 4.47.0.dev0

* add patch for fsdp

* fix typo

* correct fn name

* stray character

* fix patch

* reset Trainer too

* also reset Trainer.training_step

* allow tests/patched to run more than one process on e2e runner

* skip tests/patched in e2e for now since it's run in regular pytest
  • Loading branch information
winglian authored Dec 8, 2024
1 parent be5f554 commit 1302e31
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 20 deletions.
2 changes: 1 addition & 1 deletion cicd/cicd.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
set -e

pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/patched/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
2 changes: 1 addition & 1 deletion docker/Dockerfile-base
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST

RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.14.0
transformers==4.47.0
transformers>=4.46.3
tokenizers>=0.20.1
bitsandbytes==0.45.0
accelerate==1.2.0
Expand Down
58 changes: 42 additions & 16 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import torch
import transformers
from datasets import Dataset
from packaging import version
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
Expand Down Expand Up @@ -973,7 +974,13 @@ def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> Non
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super().log(logs, start_time)

if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
try:
return super().log(logs, start_time)
except TypeError:
return super().log(logs) # transformers<=4.46
return super().log(logs) # transformers<=4.46

def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
Expand Down Expand Up @@ -1165,9 +1172,13 @@ def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> Non
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)

if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call


class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
Expand All @@ -1185,9 +1196,13 @@ def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> Non
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)

if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(ORPOTrainer, self).log(logs) # pylint: disable=bad-super-call


class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
Expand Down Expand Up @@ -1232,9 +1247,13 @@ def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> Non
for key, metrics in self._stored_metrics[train_eval].items():
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)

if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(KTOTrainer, self).log(logs) # pylint: disable=bad-super-call


class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
Expand All @@ -1252,9 +1271,13 @@ def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> Non
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)

if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(CPOTrainer, self).log(logs) # pylint: disable=bad-super-call


class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
Expand All @@ -1266,9 +1289,12 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):

def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call


class TrainerBuilderBase(abc.ABC):
Expand Down
80 changes: 80 additions & 0 deletions src/axolotl/monkeypatch/trainer_fsdp_optim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""
fix for FSDP optimizer save in trainer w 4.47.0
"""
import inspect
import logging

from transformers.trainer import Trainer

from axolotl.monkeypatch.unsloth_ import detab_code

LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")

ORIGINAL_TRAINER_CODE = """
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
"""

PATCHED_TRAINER_CODE = """
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
"""


def get_training_loop_code() -> str:
training_loop = inspect.getsource(
Trainer._inner_training_loop # pylint: disable=protected-access
)
return training_loop


def check_training_loop_is_patchable() -> bool:
training_loop = get_training_loop_code()
training_loop, _ = detab_code(training_loop)
return ORIGINAL_TRAINER_CODE in training_loop


def patch_training_loop_for_fsdp():
"""
monkeypatch for fixing the training loop for fsdp with optimizer save
"""

try:
training_loop = get_training_loop_code()
except OSError:
return
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
training_loop
)
training_loop, _ = detab_code(training_loop)
if ORIGINAL_TRAINER_CODE not in training_loop:
return

training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
training_loop = training_loop.replace(
"def _inner_training_loop(",
"def _fixed_inner_training_loop(",
1,
)

# load imports necessary
import transformers.trainer

items_to_import = []
for item in dir(transformers.trainer):
if item in training_loop:
items_to_import.append(item)

exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop for fsdp optimizer save")
Trainer._inner_training_loop = ( # pylint: disable=protected-access
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
)
7 changes: 7 additions & 0 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,13 @@ def apply_patches(self) -> None:
plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(self.cfg)

if self.cfg.fsdp:
from axolotl.monkeypatch.trainer_fsdp_optim import (
patch_training_loop_for_fsdp,
)

patch_training_loop_for_fsdp()

if self.cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper

Expand Down
11 changes: 10 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,27 @@ def temp_dir():

@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
from transformers import Trainer
from transformers.models.llama.modeling_llama import LlamaFlashAttention2

original_fa2_forward = LlamaFlashAttention2.forward
original_trainer_inner_training_loop = (
Trainer._inner_training_loop # pylint: disable=protected-access
)
original_trainer_training_step = Trainer.training_step
# monkey patches can happen inside the tests
yield
# Reset LlamaFlashAttention2 forward
LlamaFlashAttention2.forward = original_fa2_forward
Trainer._inner_training_loop = ( # pylint: disable=protected-access
original_trainer_inner_training_loop
)
Trainer.training_step = original_trainer_training_step

# Reset other known monkeypatches
modules_to_reset: list[tuple[str, list[str]]] = [
("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
("transformers.trainer",),
("transformers.trainer", ["Trainer"]),
("transformers.loss.loss_utils",),
]
for module_name_tuple in modules_to_reset:
Expand Down

0 comments on commit 1302e31

Please sign in to comment.