Unsloth gradient checkpointing offload (#1528)
* unsloth gradient checkpointing

* fix validation too

* fixes to make it work with mistral

* monkeypatch the checkpoint fn earlier
winglian authored Apr 16, 2024
1 parent 132eb74 commit 6319da1
Showing 5 changed files with 86 additions and 19 deletions.
30 changes: 12 additions & 18 deletions src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
@@ -516,24 +516,18 @@ def mistral_model_forward(
         past_key_value = past_key_values[idx] if past_key_values is not None else None

         if self.gradient_checkpointing and self.training:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    # None for past_key_value
-                    return module(*inputs)
-
-                return custom_forward
-
-            layer_outputs = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(decoder_layer),
-                hidden_states,
-                attention_mask,
-                position_ids,
-                past_key_value,
-                output_attentions,
-                None,
-                cu_seqlens,
-                max_seqlen,
+            layer_outputs = (
+                self._gradient_checkpointing_func(  # pylint: disable=protected-access
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    past_key_value,
+                    output_attentions,
+                    None,
+                    cu_seqlens,
+                    max_seqlen,
+                )
             )
         else:
             layer_outputs = decoder_layer(
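The hunk above swaps the closure-plus-reentrant torch.utils.checkpoint.checkpoint call for a call through self._gradient_checkpointing_func, the hook that transformers installs when gradient checkpointing is enabled. A minimal sketch of that indirection, assuming the recent transformers behavior where gradient_checkpointing_enable() builds the hook as a functools.partial of torch.utils.checkpoint.checkpoint; the toy linear layer is a hypothetical stand-in for a decoder layer:

import functools

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(8, 8)                    # hypothetical stand-in for a decoder layer
x = torch.randn(2, 8, requires_grad=True)

# roughly what gradient_checkpointing_enable() installs as model._gradient_checkpointing_func
# (the use_reentrant choice comes from gradient_checkpointing_kwargs)
grad_ckpt_func = functools.partial(checkpoint, use_reentrant=False)

# the patched forward then does the equivalent of:
out = grad_ckpt_func(layer.__call__, x)    # activations are recomputed during backward
out.sum().backward()
print(x.grad.shape)                        # torch.Size([2, 8])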
5 changes: 4 additions & 1 deletion src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -479,6 +479,7 @@ class Config:
     eval_causal_lm_metrics: Optional[List[str]] = None
     do_bench_eval: Optional[bool] = None
     bench_dataset: Optional[str] = None
+    bench_split: Optional[str] = None
     metric_for_best_model: Optional[str] = None
     greater_is_better: Optional[bool] = None

@@ -494,7 +495,9 @@ class Config:

     # torch_dtype: Optional[torch.dtype]

-    gradient_checkpointing: Optional[bool] = Field(default=False)
+    gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
+        default=False
+    )
     gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None

     unfrozen_parameters: Optional[List[str]] = None
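The schema change above lets gradient_checkpointing take the literal string "unsloth" as well as a boolean. A minimal sketch of the corresponding axolotl YAML setting (only this key is shown; the rest of a training config is unchanged):

# opt into Unsloth's offloaded gradient checkpointing instead of plain true/false
gradient_checkpointing: unsloth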
13 changes: 13 additions & 0 deletions src/axolotl/utils/gradient_checkpointing/__init__.py
@@ -0,0 +1,13 @@
"""custom checkpointing utils"""
from axolotl.utils.gradient_checkpointing.unsloth import (
Unsloth_Offloaded_Gradient_Checkpointer,
)


def hf_grad_checkpoint_unsloth_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
decoder_layer.__self__,
*args,
)
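The decoder_layer argument here is expected to be a bound __call__ method, which is what the patched model forward passes via decoder_layer.__call__, so .__self__ recovers the module itself before it is handed to the autograd Function. A quick illustration of that Python detail, using a hypothetical module:

import torch.nn as nn

block = nn.Linear(4, 4)         # hypothetical stand-in for a decoder layer
bound = block.__call__          # what the wrapper receives as `decoder_layer`
assert bound.__self__ is block  # __self__ gives back the module to checkpoint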
52 changes: 52 additions & 0 deletions src/axolotl/utils/gradient_checkpointing/unsloth.py
@@ -0,0 +1,52 @@
"""Unsloth checkpointing"""

# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
torch.autograd.Function
):
"""
Saves VRAM by smartly offloading to RAM.
Tiny hit to performance, since we mask the movement via non blocking calls.
"""

@staticmethod
@torch.cuda.amp.custom_fwd
def forward(ctx, forward_function, hidden_states, *args):
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
with torch.no_grad():
output = forward_function(hidden_states, *args)
ctx.save_for_backward(saved_hidden_states)
ctx.forward_function = forward_function
ctx.args = args
return output

@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, dY):
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad = True
with torch.enable_grad():
(output,) = ctx.forward_function(hidden_states, *ctx.args)
torch.autograd.backward(output, dY)
return (
None,
hidden_states.grad,
) + (
None,
) * len(ctx.args)
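forward() ships the layer input to CPU and runs the layer under no_grad; backward() pulls the input back to the GPU, re-runs the layer with grad enabled, and back-propagates the incoming gradient, so the saved hidden states live in system RAM rather than VRAM between forward and backward. A toy usage sketch, not from the commit, assuming a CUDA device and a block that returns a one-element tuple the way HF decoder layers do (ToyBlock is hypothetical):

import torch
import torch.nn as nn

from axolotl.utils.gradient_checkpointing.unsloth import (
    Unsloth_Offloaded_Gradient_Checkpointer,
)


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a decoder layer; returns a 1-tuple like HF layers."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states):
        return (self.proj(hidden_states),)


if torch.cuda.is_available():
    block = ToyBlock().cuda()
    x = torch.randn(2, 4, 16, device="cuda", requires_grad=True)
    # forward: x is copied to CPU and the block runs under no_grad
    (out,) = Unsloth_Offloaded_Gradient_Checkpointer.apply(block, x)
    # backward: x comes back to the GPU and the block is re-run with grad enabled
    out.sum().backward()
    print(x.grad.shape)  # torch.Size([2, 4, 16])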
5 changes: 5 additions & 0 deletions src/axolotl/utils/models.py
@@ -11,6 +11,7 @@
 import bitsandbytes as bnb
 import torch
 import transformers
+import transformers.modeling_utils
 from accelerate import init_empty_weights
 from bitsandbytes.nn import Params4bit
 from peft import (
@@ -44,6 +45,7 @@
 from axolotl.utils.chat_templates import chat_templates
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import zero_only
+from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
 from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant

@@ -310,6 +312,9 @@ def load_model(
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit

+    if cfg.gradient_checkpointing == "unsloth":
+        transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
+
     if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
         if cfg.flash_attention:
             from axolotl.monkeypatch.btlm_attn_hijack_flash import (
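Rebinding transformers.modeling_utils.checkpoint is the "monkeypatch the checkpoint fn earlier" piece of the commit message: in the transformers versions this targets, gradient_checkpointing_enable() builds its per-layer checkpoint function from that module-level name, so swapping it in before checkpointing is enabled routes every layer through the Unsloth offloaded checkpointer. A hedged sketch of the end-to-end wiring, with the model name as a placeholder:

import transformers.modeling_utils
from transformers import AutoModelForCausalLM

from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper

# must happen before gradient_checkpointing_enable() is called anywhere
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")  # placeholder
model.gradient_checkpointing_enable()  # layers now checkpoint through the wrapper above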
