From 1302e310491a9f6a5911eb1763a75302b74db285 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 8 Dec 2024 14:50:40 -0500
Subject: [PATCH] Transformers version flexibility and FSDP optimizer patch (#2155)

* allow flexibility in transformers version for FSDP
* more flexibility with dev versions of 4.47.0.dev0
* add patch for fsdp
* fix typo
* correct fn name
* stray character
* fix patch
* reset Trainer too
* also reset Trainer.training_step
* allow tests/patched to run more than one process on e2e runner
* skip tests/patched in e2e for now since it's run in regular pytest
---
 cicd/cicd.sh                                  |  2 +-
 docker/Dockerfile-base                        |  2 +-
 requirements.txt                              |  2 +-
 src/axolotl/core/trainer_builder.py           | 58 ++++++++++----
 src/axolotl/monkeypatch/trainer_fsdp_optim.py | 80 +++++++++++++++++++
 src/axolotl/utils/models.py                   |  7 ++
 tests/conftest.py                             | 11 ++-
 7 files changed, 142 insertions(+), 20 deletions(-)
 create mode 100644 src/axolotl/monkeypatch/trainer_fsdp_optim.py

diff --git a/cicd/cicd.sh b/cicd/cicd.sh
index 79b3cc95e0..c3e46920d8 100755
--- a/cicd/cicd.sh
+++ b/cicd/cicd.sh
@@ -2,6 +2,6 @@
 set -e
 
 pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
-pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/patched/
+# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
 pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
 pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base
index 7eab3b3e43..4b24bfc3ae 100644
--- a/docker/Dockerfile-base
+++ b/docker/Dockerfile-base
@@ -16,7 +16,7 @@ ENV PYTHON_VERSION=$PYTHON_VERSION
 ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
 
 RUN apt-get update \
-    && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
+    && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
     && wget \
         https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
     && mkdir /root/.conda \
diff --git a/requirements.txt b/requirements.txt
index 1d21cb354c..ae1b1838d5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.14.0
-transformers==4.47.0
+transformers>=4.46.3
 tokenizers>=0.20.1
 bitsandbytes==0.45.0
 accelerate==1.2.0
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index baac94da80..691437bc65 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -22,6 +22,7 @@
 import torch
 import transformers
 from datasets import Dataset
+from packaging import version
 from peft.optimizers import create_loraplus_optimizer
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
@@ -973,7 +974,13 @@ def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> Non
         for key, metrics in self._stored_metrics[train_eval].items():
             logs[key] = torch.tensor(metrics).mean().item()
         del self._stored_metrics[train_eval]
-        return super().log(logs, start_time)
+
+        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+            try:
+                return super().log(logs, start_time)
+            except TypeError:
+                return super().log(logs)  # transformers<=4.46
+        return super().log(logs)  # transformers<=4.46
 
     def store_metrics(
         self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
@@ -1165,9 +1172,13 @@ def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> Non
         for key, metrics in self._stored_metrics[train_eval].items():
             logs[key] = torch.tensor(metrics).mean().item()
         del self._stored_metrics[train_eval]
-        return super(DPOTrainer, self).log(  # pylint: disable=bad-super-call
-            logs, start_time
-        )
+
+        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+            return super(DPOTrainer, self).log(  # pylint: disable=bad-super-call
+                logs, start_time
+            )
+        # transformers<=4.46
+        return super(DPOTrainer, self).log(logs)  # pylint: disable=bad-super-call
 
 
 class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
@@ -1185,9 +1196,13 @@ def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> Non
         for key, metrics in self._stored_metrics[train_eval].items():
             logs[key] = torch.tensor(metrics).mean().item()
         del self._stored_metrics[train_eval]
-        return super(ORPOTrainer, self).log(  # pylint: disable=bad-super-call
-            logs, start_time
-        )
+
+        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+            return super(ORPOTrainer, self).log(  # pylint: disable=bad-super-call
+                logs, start_time
+            )
+        # transformers<=4.46
+        return super(ORPOTrainer, self).log(logs)  # pylint: disable=bad-super-call
 
 
 class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
@@ -1232,9 +1247,13 @@ def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> Non
         for key, metrics in self._stored_metrics[train_eval].items():
             logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
         del self._stored_metrics[train_eval]
-        return super(KTOTrainer, self).log(  # pylint: disable=bad-super-call
-            logs, start_time
-        )
+
+        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+            return super(KTOTrainer, self).log(  # pylint: disable=bad-super-call
+                logs, start_time
+            )
+        # transformers<=4.46
+        return super(KTOTrainer, self).log(logs)  # pylint: disable=bad-super-call
 
 
 class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
@@ -1252,9 +1271,13 @@ def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> Non
         for key, metrics in self._stored_metrics[train_eval].items():
             logs[key] = torch.tensor(metrics).mean().item()
         del self._stored_metrics[train_eval]
-        return super(CPOTrainer, self).log(  # pylint: disable=bad-super-call
-            logs, start_time
-        )
+
+        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+            return super(CPOTrainer, self).log(  # pylint: disable=bad-super-call
+                logs, start_time
+            )
+        # transformers<=4.46
+        return super(CPOTrainer, self).log(logs)  # pylint: disable=bad-super-call
 
 
 class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
@@ -1266,9 +1289,12 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
 
     def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
         # TODO remove once trl supports the updated to the Trainer.log method
-        return super(RewardTrainer, self).log(  # pylint: disable=bad-super-call
-            logs, start_time
-        )
+        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+            return super(RewardTrainer, self).log(  # pylint: disable=bad-super-call
+                logs, start_time
+            )
+        # transformers<=4.46
+        return super(RewardTrainer, self).log(logs)  # pylint: disable=bad-super-call
 
 
 class TrainerBuilderBase(abc.ABC):
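A note on the repeated log() overrides above: the diff assumes transformers 4.47 adds a start_time argument to Trainer.log(), while 4.46 and earlier accept only the logs dict, so each trainer subclass gates on the installed version before forwarding start_time. The gate reduces to the sketch below; log_compat and its trainer_log parameter are hypothetical illustration, not part of the patch:

    from typing import Callable, Dict, Optional

    import transformers
    from packaging import version


    def log_compat(
        trainer_log: Callable, logs: Dict[str, float], start_time: Optional[float] = None
    ):
        # transformers >= 4.47.0.dev0: log() accepts (logs, start_time)
        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
            return trainer_log(logs, start_time)
        # transformers <= 4.46: log() accepts only (logs)
        return trainer_log(logs)

Using ">= 4.47.0.dev0" rather than ">= 4.47.0" is what lets the newer call signature kick in on dev builds of 4.47 as well as the release.
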
diff --git a/src/axolotl/monkeypatch/trainer_fsdp_optim.py b/src/axolotl/monkeypatch/trainer_fsdp_optim.py
new file mode 100644
index 0000000000..835dea69b5
--- /dev/null
+++ b/src/axolotl/monkeypatch/trainer_fsdp_optim.py
@@ -0,0 +1,80 @@
+"""
+fix for FSDP optimizer save in trainer w 4.47.0
+"""
+import inspect
+import logging
+
+from transformers.trainer import Trainer
+
+from axolotl.monkeypatch.unsloth_ import detab_code
+
+LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")
+
+ORIGINAL_TRAINER_CODE = """
+
+    delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
+
+"""
+
+PATCHED_TRAINER_CODE = """
+
+    delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
+
+"""
+
+
+def get_training_loop_code() -> str:
+    training_loop = inspect.getsource(
+        Trainer._inner_training_loop  # pylint: disable=protected-access
+    )
+    return training_loop
+
+
+def check_training_loop_is_patchable() -> bool:
+    training_loop = get_training_loop_code()
+    training_loop, _ = detab_code(training_loop)
+    return ORIGINAL_TRAINER_CODE in training_loop
+
+
+def patch_training_loop_for_fsdp():
+    """
+    monkeypatch for fixing the training loop for fsdp with optimizer save
+    """
+
+    try:
+        training_loop = get_training_loop_code()
+    except OSError:
+        return
+    Trainer._original_inner_training_loop = (  # pylint: disable=protected-access
+        training_loop
+    )
+    training_loop, _ = detab_code(training_loop)
+    if ORIGINAL_TRAINER_CODE not in training_loop:
+        return
+
+    training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
+    training_loop = training_loop.replace(
+        "def _inner_training_loop(",
+        "def _fixed_inner_training_loop(",
+        1,
+    )
+
+    # load imports necessary
+    import transformers.trainer
+
+    items_to_import = []
+    for item in dir(transformers.trainer):
+        if item in training_loop:
+            items_to_import.append(item)
+
+    exec(  # pylint: disable=exec-used  # nosec B102
+        "from transformers.trainer import ("
+        + ", ".join(x for x in items_to_import)
+        + ")",
+        globals(),
+    )
+    exec(training_loop, globals())  # pylint: disable=exec-used  # nosec B102
+    LOG.info("patching _inner_training_loop for fsdp optimizer save")
+    Trainer._inner_training_loop = (  # pylint: disable=protected-access
+        _fixed_inner_training_loop  # pylint: disable=undefined-variable  # noqa: F821
+    )
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 88a8aa581f..99095c1bfc 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -380,6 +380,13 @@ def apply_patches(self) -> None:
         plugin_manager = PluginManager.get_instance()
         plugin_manager.pre_model_load(self.cfg)
 
+        if self.cfg.fsdp:
+            from axolotl.monkeypatch.trainer_fsdp_optim import (
+                patch_training_loop_for_fsdp,
+            )
+
+            patch_training_loop_for_fsdp()
+
         if self.cfg.gradient_checkpointing == "unsloth":
             transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
 
diff --git a/tests/conftest.py b/tests/conftest.py
index 1295d34b64..a9dde9dd88 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -119,18 +119,27 @@ def temp_dir():
 
 @pytest.fixture(scope="function", autouse=True)
 def cleanup_monkeypatches():
+    from transformers import Trainer
     from transformers.models.llama.modeling_llama import LlamaFlashAttention2
 
     original_fa2_forward = LlamaFlashAttention2.forward
+    original_trainer_inner_training_loop = (
+        Trainer._inner_training_loop  # pylint: disable=protected-access
+    )
+    original_trainer_training_step = Trainer.training_step
     # monkey patches can happen inside the tests
     yield
     # Reset LlamaFlashAttention2 forward
     LlamaFlashAttention2.forward = original_fa2_forward
+    Trainer._inner_training_loop = (  # pylint: disable=protected-access
+        original_trainer_inner_training_loop
+    )
+    Trainer.training_step = original_trainer_training_step
 
     # Reset other known monkeypatches
     modules_to_reset: list[tuple[str, list[str]]] = [
         ("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
-        ("transformers.trainer",),
+        ("transformers.trainer", ["Trainer"]),
         ("transformers.loss.loss_utils",),
     ]
     for module_name_tuple in modules_to_reset:
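How the new monkeypatch works: patch_training_loop_for_fsdp() fetches the source of Trainer._inner_training_loop with inspect.getsource, rewrites the delay_optimizer_creation assignment so that optimizer creation is also delayed when self.is_fsdp_enabled is true (presumably so the optimizer is built only after FSDP has wrapped the model, which the module docstring ties to optimizer saving), re-execs the edited source under a new name, and rebinds it as Trainer._inner_training_loop. If ORIGINAL_TRAINER_CODE is not found in the installed transformers source, the function returns without patching, so it degrades to a no-op on versions whose training loop differs. Below is a toy sketch of that rewrite-and-exec technique on a stand-in class; Worker, patch_run, and the replaced strings are hypothetical illustration, not part of the patch:

    # Save as a script and run it; inspect.getsource needs the source file.
    import inspect
    import textwrap


    class Worker:
        def run(self):
            mode = "fast"
            return mode


    def patch_run():
        # fetch the current source of the method we want to alter
        src = textwrap.dedent(inspect.getsource(Worker.run))
        if 'mode = "fast"' not in src:
            return  # source changed upstream; skip patching instead of breaking
        # substitute the one line we care about, then rename the function
        src = src.replace('mode = "fast"', 'mode = "safe"')
        src = src.replace("def run(", "def _fixed_run(", 1)
        namespace: dict = {}
        exec(src, namespace)  # nosec B102 - mirrors the exec in trainer_fsdp_optim.py
        # rebind the freshly compiled function onto the class
        Worker.run = namespace["_fixed_run"]


    patch_run()
    print(Worker().run())  # -> "safe"

The conftest.py change above exists for the same reason the real patch mutates Trainer in place: tests that trigger the monkeypatch would otherwise leak the patched _inner_training_loop and training_step into later tests, so the fixture snapshots and restores both around each test.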