diff --git a/src/axolotl/utils/gradient_checkpointing/unsloth.py b/src/axolotl/utils/gradient_checkpointing/unsloth.py
index fbe8346be2..7a14614b18 100644
--- a/src/axolotl/utils/gradient_checkpointing/unsloth.py
+++ b/src/axolotl/utils/gradient_checkpointing/unsloth.py
@@ -14,6 +14,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+from packaging import version
+
+torch_version = version.parse(torch.__version__)
+
+if torch_version < version.parse("2.4.0"):
+    torch_cuda_amp_custom_fwd = torch.cuda.amp.custom_fwd
+    torch_cuda_amp_custom_bwd = torch.cuda.amp.custom_bwd
+else:
+    torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
+    torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
 
 
 class Unsloth_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
@@ -25,7 +35,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
     """
 
     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch_cuda_amp_custom_fwd
     def forward(ctx, forward_function, hidden_states, *args):
         saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
         with torch.no_grad():
@@ -36,7 +46,7 @@ def forward(ctx, forward_function, hidden_states, *args):
         return output
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch_cuda_amp_custom_bwd
     def backward(ctx, dY):
         (hidden_states,) = ctx.saved_tensors
         hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
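
For context, below is a minimal self-contained sketch of the same version-gated decorator shim the diff introduces, applied to a toy `torch.autograd.Function`. The `ScaleByTwo` class and the usage at the bottom are illustrative only (they are not part of this PR), and the example assumes a CUDA device is available.

```python
# Minimal sketch (not from the PR): version-gated aliases for the AMP custom
# autograd decorators, applied to a toy autograd.Function.
import torch
from packaging import version

if version.parse(torch.__version__) < version.parse("2.4.0"):
    # Older PyTorch: the decorators live under torch.cuda.amp and take no arguments.
    custom_fwd = torch.cuda.amp.custom_fwd
    custom_bwd = torch.cuda.amp.custom_bwd
else:
    # PyTorch >= 2.4: torch.cuda.amp.custom_fwd/custom_bwd are deprecated in favor
    # of the device-agnostic torch.amp variants, which require a device_type.
    # Calling them with only device_type returns a decorator.
    custom_fwd = torch.amp.custom_fwd(device_type="cuda")
    custom_bwd = torch.amp.custom_bwd(device_type="cuda")


class ScaleByTwo(torch.autograd.Function):
    """Toy example (hypothetical): y = 2 * x, using the shimmed decorators."""

    @staticmethod
    @custom_fwd
    def forward(ctx, x):
        return x * 2

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        # d(2x)/dx = 2, so the upstream gradient is simply doubled.
        return grad_out * 2


if __name__ == "__main__":
    x = torch.randn(4, device="cuda", requires_grad=True)
    y = ScaleByTwo.apply(x).sum()
    y.backward()
    print(x.grad)  # expected: a tensor of 2s
```

Because the shim resolves to a plain decorator object in both branches, the class body can use a single `@custom_fwd` / `@custom_bwd` spelling regardless of the installed PyTorch version, which is the same design the diff applies to `Unsloth_Offloaded_Gradient_Checkpointer`.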