From 6319da1f9bdf47c27f9d8a516de0f8d93f93b767 Mon Sep 17 00:00:00 2001
From: Wing Lian <wing.lian@gmail.com>
Date: Tue, 16 Apr 2024 14:53:57 -0400
Subject: [PATCH] Unsloth gradient checkpointing offload (#1528)

* unsloth gradient checkpointing: offload checkpointed activations to CPU

* fix config validation so gradient_checkpointing also accepts "unsloth"

* fixes to make it work with the mistral flash attention monkeypatch

* monkeypatch the checkpoint fn earlier, before the model is instantiated in load_model
---
 .../monkeypatch/mistral_attn_hijack_flash.py  | 30 +++++------
 .../config/models/input/v0_4_1/__init__.py    |  5 +-
 .../utils/gradient_checkpointing/__init__.py  | 13 +++++
 .../utils/gradient_checkpointing/unsloth.py   | 52 +++++++++++++++++++
 src/axolotl/utils/models.py                   |  5 ++
 5 files changed, 86 insertions(+), 19 deletions(-)
 create mode 100644 src/axolotl/utils/gradient_checkpointing/__init__.py
 create mode 100644 src/axolotl/utils/gradient_checkpointing/unsloth.py
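
Usage note: after this patch, `gradient_checkpointing` accepts the string value
"unsloth" in addition to a boolean, so in a YAML config this should simply be
`gradient_checkpointing: unsloth`. A minimal sketch of the same toggle as an
axolotl-style cfg dict (field name taken from the config change below; the rest
of the cfg is omitted):

    from axolotl.utils.dict import DictDefault

    cfg = DictDefault(
        {
            # previously only True/False; "unsloth" selects the offloaded checkpointer
            "gradient_checkpointing": "unsloth",
        }
    )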

diff --git a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
index 8e43da1110..6ae2e75fa2 100644
--- a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
@@ -516,24 +516,18 @@ def mistral_model_forward(
         past_key_value = past_key_values[idx] if past_key_values is not None else None
 
         if self.gradient_checkpointing and self.training:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    # None for past_key_value
-                    return module(*inputs)
-
-                return custom_forward
-
-            layer_outputs = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(decoder_layer),
-                hidden_states,
-                attention_mask,
-                position_ids,
-                past_key_value,
-                output_attentions,
-                None,
-                cu_seqlens,
-                max_seqlen,
+            layer_outputs = (
+                self._gradient_checkpointing_func(  # pylint: disable=protected-access
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    past_key_value,
+                    output_attentions,
+                    None,
+                    cu_seqlens,
+                    max_seqlen,
+                )
             )
         else:
             layer_outputs = decoder_layer(
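
Note on the hunk above: the hand-rolled `create_custom_forward` closure is replaced
with the upstream `self._gradient_checkpointing_func(decoder_layer.__call__, ...)`
convention, so the monkeypatched mistral forward goes through whatever checkpoint
function transformers was configured with, including the unsloth wrapper added below.
Passing the bound `__call__` matters because the wrapper recovers the module from it
via `__self__`; a tiny illustrative sketch with a toy module (not axolotl code):

    import torch.nn as nn

    layer = nn.Linear(8, 8)
    bound = layer.__call__          # what the checkpointing func receives
    assert bound.__self__ is layer  # the wrapper can recover the module from the bound method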
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 0fbed08ca3..d99155ac25 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -479,6 +479,7 @@ class Config:
     eval_causal_lm_metrics: Optional[List[str]] = None
     do_bench_eval: Optional[bool] = None
     bench_dataset: Optional[str] = None
+    bench_split: Optional[str] = None
     metric_for_best_model: Optional[str] = None
     greater_is_better: Optional[bool] = None
 
@@ -494,7 +495,9 @@ class Config:
 
     # torch_dtype: Optional[torch.dtype]
 
-    gradient_checkpointing: Optional[bool] = Field(default=False)
+    gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
+        default=False
+    )
     gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
 
     unfrozen_parameters: Optional[List[str]] = None
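
The widened field type accepts either a boolean or the literal string "unsloth".
A standalone pydantic sketch (a hypothetical model, not the actual axolotl config
class) showing how the union validates:

    from typing import Any, Dict, Literal, Optional, Union

    from pydantic import BaseModel, Field


    class ExampleConfig(BaseModel):  # hypothetical, for illustration only
        gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
            default=False
        )
        gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None


    ExampleConfig(gradient_checkpointing=True)       # accepted
    ExampleConfig(gradient_checkpointing="unsloth")  # accepted
    # ExampleConfig(gradient_checkpointing="foo")    # raises a ValidationError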
diff --git a/src/axolotl/utils/gradient_checkpointing/__init__.py b/src/axolotl/utils/gradient_checkpointing/__init__.py
new file mode 100644
index 0000000000..4639fc266c
--- /dev/null
+++ b/src/axolotl/utils/gradient_checkpointing/__init__.py
@@ -0,0 +1,13 @@
+"""custom checkpointing utils"""
+from axolotl.utils.gradient_checkpointing.unsloth import (
+    Unsloth_Offloaded_Gradient_Checkpointer,
+)
+
+
+def hf_grad_checkpoint_unsloth_wrapper(
+    decoder_layer, *args, use_reentrant=None
+):  # pylint: disable=unused-argument
+    return Unsloth_Offloaded_Gradient_Checkpointer.apply(
+        decoder_layer.__self__,
+        *args,
+    )
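
`hf_grad_checkpoint_unsloth_wrapper` mirrors the calling convention transformers
uses for its checkpoint function: the first argument is the bound
`decoder_layer.__call__` and the `use_reentrant` kwarg is accepted but ignored. It
unwraps the module via `__self__` and hands everything to the autograd Function
below. A hedged usage sketch with a toy layer (hypothetical, requires a CUDA
device; the layer returns a 1-tuple the way HF decoder layers do):

    import torch
    import torch.nn as nn

    from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper


    class ToyLayer(nn.Module):  # stand-in for a decoder layer
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8)

        def forward(self, hidden_states):
            # HF decoder layers return tuples, which the checkpointer's backward expects
            return (self.proj(hidden_states),)


    layer = ToyLayer().cuda()
    x = torch.randn(2, 4, 8, device="cuda", requires_grad=True)

    # the wrapper receives the bound __call__, much as _gradient_checkpointing_func would pass it
    (out,) = hf_grad_checkpoint_unsloth_wrapper(layer.__call__, x, use_reentrant=True)
    out.sum().backward()  # the layer is recomputed on GPU and the input grad flows back to x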
diff --git a/src/axolotl/utils/gradient_checkpointing/unsloth.py b/src/axolotl/utils/gradient_checkpointing/unsloth.py
new file mode 100644
index 0000000000..fbe8346be2
--- /dev/null
+++ b/src/axolotl/utils/gradient_checkpointing/unsloth.py
@@ -0,0 +1,52 @@
+"""Unsloth checkpointing"""
+
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+
+class Unsloth_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
+    torch.autograd.Function
+):
+    """
+    Saves VRAM by smartly offloading to RAM.
+    Tiny hit to performance, since we mask the movement via non blocking calls.
+    """
+
+    @staticmethod
+    @torch.cuda.amp.custom_fwd
+    def forward(ctx, forward_function, hidden_states, *args):
+        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
+        with torch.no_grad():
+            output = forward_function(hidden_states, *args)
+        ctx.save_for_backward(saved_hidden_states)
+        ctx.forward_function = forward_function
+        ctx.args = args
+        return output
+
+    @staticmethod
+    @torch.cuda.amp.custom_bwd
+    def backward(ctx, dY):
+        (hidden_states,) = ctx.saved_tensors
+        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
+        hidden_states.requires_grad = True
+        with torch.enable_grad():
+            (output,) = ctx.forward_function(hidden_states, *ctx.args)
+        torch.autograd.backward(output, dY)
+        return (
+            None,
+            hidden_states.grad,
+        ) + (
+            None,
+        ) * len(ctx.args)
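
The mechanism above: forward kicks off a non-blocking copy of the input hidden
states to CPU and runs the layer under `no_grad`; backward moves the saved input
back to the GPU, recomputes the layer under `enable_grad`, and backpropagates
through the recomputed graph, returning a gradient only for the hidden states. A
quick sanity sketch (assumed toy setup, CUDA required) checking that the offloaded
path reproduces a plain backward:

    import torch
    import torch.nn as nn

    from axolotl.utils.gradient_checkpointing.unsloth import (
        Unsloth_Offloaded_Gradient_Checkpointer,
    )

    layer = nn.Sequential(nn.Linear(16, 16), nn.GELU()).cuda()


    def tuple_forward(hidden_states):
        # backward unpacks a 1-tuple, matching what HF decoder layers return
        return (layer(hidden_states),)


    x = torch.randn(4, 16, device="cuda", requires_grad=True)

    (out_ckpt,) = Unsloth_Offloaded_Gradient_Checkpointer.apply(tuple_forward, x)
    out_ckpt.sum().backward()
    grad_offloaded = x.grad.clone()

    x.grad = None
    layer(x).sum().backward()
    assert torch.allclose(grad_offloaded, x.grad)  # matches the non-checkpointed grad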
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 0b15850518..52d8db047f 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -11,6 +11,7 @@
 import bitsandbytes as bnb
 import torch
 import transformers
+import transformers.modeling_utils
 from accelerate import init_empty_weights
 from bitsandbytes.nn import Params4bit
 from peft import (
@@ -44,6 +45,7 @@
 from axolotl.utils.chat_templates import chat_templates
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import zero_only
+from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
 from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
 
@@ -310,6 +312,9 @@ def load_model(
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
 
+    if cfg.gradient_checkpointing == "unsloth":
+        transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
+
     if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
         if cfg.flash_attention:
             from axolotl.monkeypatch.btlm_attn_hijack_flash import (
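
Why the one-line monkeypatch in `load_model` is enough: in the transformers
versions this patch targets, `PreTrainedModel.gradient_checkpointing_enable`
builds `_gradient_checkpointing_func` from the module-level `checkpoint` name in
`transformers.modeling_utils`, so replacing that attribute before the model is
instantiated reroutes every checkpointed layer through the offloaded wrapper. A
hedged standalone sketch of the same mechanism outside axolotl (the model name is
only an example):

    import transformers
    import transformers.modeling_utils

    from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper

    # swap transformers' checkpoint function before the model is created,
    # mirroring the load_model() change above
    transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper

    model = transformers.AutoModelForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-v0.1"  # example model only
    )
    model.gradient_checkpointing_enable()  # checkpointed layers now route through the wrapper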