From 342935cff35d70a260c428d1b419a2b5d8989a99 Mon Sep 17 00:00:00 2001
From: Sunny Liu
Date: Wed, 13 Nov 2024 15:17:34 -0500
Subject: [PATCH] Update unsloth for torch.cuda.amp deprecation (#2042)

* update deprecated unsloth torch cuda amp decorator

* WIP fix torch.cuda.amp deprecation

* lint

* relaxing torch version requirement

* remove use of partial

* remove use of partial

* lint

---------

Co-authored-by: sunny
---
 .../utils/gradient_checkpointing/unsloth.py  | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/utils/gradient_checkpointing/unsloth.py b/src/axolotl/utils/gradient_checkpointing/unsloth.py
index fbe8346be2..7a14614b18 100644
--- a/src/axolotl/utils/gradient_checkpointing/unsloth.py
+++ b/src/axolotl/utils/gradient_checkpointing/unsloth.py
@@ -14,6 +14,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+from packaging import version
+
+torch_version = version.parse(torch.__version__)
+
+if torch_version < version.parse("2.4.0"):
+    torch_cuda_amp_custom_fwd = torch.cuda.amp.custom_fwd
+    torch_cuda_amp_custom_bwd = torch.cuda.amp.custom_bwd
+else:
+    torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
+    torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
 
 
 class Unsloth_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
@@ -25,7 +35,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
     """
 
     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch_cuda_amp_custom_fwd
     def forward(ctx, forward_function, hidden_states, *args):
         saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
         with torch.no_grad():
@@ -36,7 +46,7 @@ def forward(ctx, forward_function, hidden_states, *args):
         return output
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch_cuda_amp_custom_bwd
     def backward(ctx, dY):
         (hidden_states,) = ctx.saved_tensors
         hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
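
For context on why this patch works: PyTorch 2.4 deprecated torch.cuda.amp.custom_fwd and torch.cuda.amp.custom_bwd in favor of torch.amp.custom_fwd and torch.amp.custom_bwd, which take a device_type keyword and, when called without a function argument, return a decorator. Both branches of the version check above therefore bind a plain decorator name that can be applied to the static methods unchanged. Below is a minimal, self-contained sketch of the same shim pattern applied to a toy autograd Function; the Square class is hypothetical and exists only to exercise the decorators, it is not part of axolotl.

import torch
from packaging import version

# Same shim as the patch: a bare decorator on torch < 2.4, and a
# decorator produced by the new keyword-only factory on torch >= 2.4.
if version.parse(torch.__version__) < version.parse("2.4.0"):
    custom_fwd = torch.cuda.amp.custom_fwd
    custom_bwd = torch.cuda.amp.custom_bwd
else:
    custom_fwd = torch.amp.custom_fwd(device_type="cuda")
    custom_bwd = torch.amp.custom_bwd(device_type="cuda")


class Square(torch.autograd.Function):
    """Hypothetical toy Function used only to exercise the shim."""

    @staticmethod
    @custom_fwd
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output


if __name__ == "__main__" and torch.cuda.is_available():
    x = torch.randn(4, device="cuda", requires_grad=True)
    with torch.autocast(device_type="cuda"):
        Square.apply(x).sum().backward()
    print(x.grad)  # equals 2 * x

The sketch runs on either side of the 2.4 boundary, which is the same reason the patch could relax the torch version requirement rather than pin it.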