From 40907c68877a944cb82fece9e52abd6d3ecee002 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 9 Dec 2024 07:25:10 -0500
Subject: [PATCH] upgrade deepspeed to 0.16.1 (#2157)

---
 requirements.txt                              | 21 +++--
 src/axolotl/monkeypatch/trainer_grad_accum.py | 84 +++++++++++++++++++
 src/axolotl/utils/models.py                   |  6 ++
 3 files changed, 102 insertions(+), 9 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index ae1b1838d5..361524e561 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,22 +1,30 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
+
+# START section of dependencies that don't install on Darwin/MacOS
+bitsandbytes==0.45.0
+triton>=2.3.0
+mamba-ssm==1.2.0.post1
+flash-attn==2.7.0.post2
+xformers>=0.0.23.post1
+autoawq==0.2.7.post3
+liger-kernel==0.4.2
+# END section
+
 packaging==23.2
 peft==0.14.0
 transformers>=4.46.3
 tokenizers>=0.20.1
-bitsandbytes==0.45.0
 accelerate==1.2.0
 datasets==3.1.0
-deepspeed==0.15.4
+deepspeed==0.16.1
 pydantic==2.6.3
 addict
 fire
 PyYAML>=6.0
 requests
-flash-attn==2.7.0.post2
 sentencepiece
 wandb
 einops
-xformers>=0.0.23.post1
 optimum==1.16.2
 hf_transfer
 colorama
@@ -31,11 +39,6 @@ art
 gradio==3.50.2
 tensorboard
 python-dotenv==1.0.1
-autoawq==0.2.7.post3
-triton>=2.3.0
-liger-kernel==0.4.2
-
-mamba-ssm==1.2.0.post1
 
 # remote filesystems
 s3fs>=2024.5.0
diff --git a/src/axolotl/monkeypatch/trainer_grad_accum.py b/src/axolotl/monkeypatch/trainer_grad_accum.py
index 97e6b7f2c5..39435ebac1 100644
--- a/src/axolotl/monkeypatch/trainer_grad_accum.py
+++ b/src/axolotl/monkeypatch/trainer_grad_accum.py
@@ -205,3 +205,87 @@ def patch_forward_for_ga():
     LlamaForCausalLM.forward = (  # pylint: disable=protected-access
         _fixed_forward  # pylint: disable=undefined-variable  # noqa: F821
     )
+
+
+ORIGINAL_TRAINER_CODE = """
+                context = (
+                    functools.partial(self.accelerator.no_sync, model=model)
+                    if i != len(batch_samples) - 1
+                    else contextlib.nullcontext
+                )
+                with context():
+                    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
+"""
+
+PATCHED_TRAINER_CODE = """
+                disable_deepspeed_no_sync = (
+                    self.accelerator.distributed_type == DistributedType.DEEPSPEED
+                    and self.accelerator.deepspeed_engine_wrapped.engine.zero_optimization_partition_gradients()
+                )
+                context = (
+                    functools.partial(self.accelerator.no_sync, model=model)
+                    if i != len(batch_samples) - 1 and not disable_deepspeed_no_sync
+                    else contextlib.nullcontext
+                )
+                with context():
+                    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
+"""
+
+
+def get_training_loop_code() -> str:
+    training_loop = inspect.getsource(
+        Trainer._inner_training_loop  # pylint: disable=protected-access
+    )
+    return training_loop
+
+
+def check_training_loop_is_patchable() -> bool:
+    training_loop = get_training_loop_code()
+    training_loop, _ = detab_code(training_loop)
+    return ORIGINAL_TRAINER_CODE in training_loop
+
+
+def patch_training_loop_for_deepspeed_0_16_x():
+    """
+    monkeypatch for fixing the training loop for deepspeed GA
+
+    see https://github.com/huggingface/transformers/pull/35157
+    """
+
+    try:
+        training_loop = get_training_loop_code()
+    except OSError:
+        return
+    Trainer._original_inner_training_loop = (  # pylint: disable=protected-access
+        training_loop
+    )
+    training_loop, _ = detab_code(training_loop)
+    if ORIGINAL_TRAINER_CODE not in training_loop:
+        return
+
+    training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
+    training_loop = training_loop.replace(
+        "def _inner_training_loop(",
+        "def _fixed_inner_training_loop(",
+        1,
+    )
+
+    # load imports necessary
+    import transformers.trainer
+
+    items_to_import = []
+    for item in dir(transformers.trainer):
+        if item in training_loop:
+            items_to_import.append(item)
+
+    exec(  # pylint: disable=exec-used  # nosec B102
+        "from transformers.trainer import ("
+        + ", ".join(x for x in items_to_import)
+        + ")",
+        globals(),
+    )
+    exec(training_loop, globals())  # pylint: disable=exec-used  # nosec B102
+    LOG.info("patching _inner_training_loop for deepspeed GA")
+    Trainer._inner_training_loop = (  # pylint: disable=protected-access
+        _fixed_inner_training_loop  # pylint: disable=undefined-variable  # noqa: F821
+    )
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 99095c1bfc..a350f24295 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -386,6 +386,12 @@ def apply_patches(self) -> None:
             )
 
             patch_training_loop_for_fsdp()
+        elif self.cfg.deepspeed:
+            from axolotl.monkeypatch.trainer_grad_accum import (
+                patch_training_loop_for_deepspeed_0_16_x,
+            )
+
+            patch_training_loop_for_deepspeed_0_16_x()
         if self.cfg.gradient_checkpointing == "unsloth":
             transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
 
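Note on the patched trainer block: DeepSpeed refuses `accelerator.no_sync` when ZeRO gradient partitioning (stage 2/3) is active, so the replacement code falls back to the null context for every gradient-accumulation micro-batch instead of wrapping all but the last one in `no_sync`, mirroring the transformers fix referenced in the docstring (huggingface/transformers#35157). Below is a minimal, standalone sketch of that selection logic; the helper name `pick_grad_sync_context` and the `zero_partitions_gradients` flag are illustrative stand-ins for the engine query used in the patch (`engine.zero_optimization_partition_gradients()`), not part of this commit.

import contextlib
import functools


def pick_grad_sync_context(no_sync, model, i, num_batches, zero_partitions_gradients):
    """Choose the gradient-sync context for one accumulation micro-batch."""
    # Skip gradient sync on all but the last micro-batch, unless DeepSpeed ZeRO
    # partitions gradients, in which case no_sync is unsupported and every
    # micro-batch runs under the default (null) context.
    if i != num_batches - 1 and not zero_partitions_gradients:
        return functools.partial(no_sync, model=model)
    return contextlib.nullcontext


# Toy usage: with 4 accumulation steps and gradient partitioning enabled
# (ZeRO stage 2/3), no micro-batch is wrapped in no_sync.
for step in range(4):
    ctx = pick_grad_sync_context(
        no_sync=lambda model=None: contextlib.nullcontext(),  # stand-in for accelerator.no_sync
        model=None,
        i=step,
        num_batches=4,
        zero_partitions_gradients=True,
    )
    print(f"micro-batch {step}: {'nullcontext' if ctx is contextlib.nullcontext else 'no_sync'}")
    with ctx():
        pass  # self.training_step(model, inputs, num_items_in_batch) would run here

With `zero_partitions_gradients=False`, the sketch reproduces the stock behaviour: only the last micro-batch leaves `no_sync` and triggers a gradient sync.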