fix: Deprecation Warnings in AutoCast API (#113)
* fix: warning of torch.amp.custom_bwd

Signed-off-by: Abhishek <[email protected]>

* fix: warning of torch.amp.custom_bwd

Signed-off-by: Abhishek <[email protected]>

* fix: warning of torch.amp.custom_bwd

Signed-off-by: Abhishek <[email protected]>

* fix: warning of torch.amp.custom_bwd

Signed-off-by: Abhishek <[email protected]>

* fix: warning of torch.amp.custom_bwd

Signed-off-by: Abhishek <[email protected]>

* fix: fmt, lint

Signed-off-by: Abhishek <[email protected]>

* fix: warning of torch.amp.custom_bwd

Signed-off-by: Abhishek <[email protected]>

---------

Signed-off-by: Abhishek <[email protected]>
Abhishek-TAMU authored Dec 2, 2024
1 parent e7a0e2f commit c70ffe0
Showing 4 changed files with 16 additions and 18 deletions.
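Background on the API change: torch.cuda.amp.custom_fwd and torch.cuda.amp.custom_bwd are deprecated in recent PyTorch releases in favor of the device-agnostic torch.amp.custom_fwd / torch.amp.custom_bwd, which take an explicit device_type argument. The sketch below shows the migration pattern this commit applies throughout; ScaleFn is a made-up example class, not code from this repository.

import torch
# New, device-agnostic spelling; the old `from torch.cuda.amp import ...`
# emits a deprecation warning on recent PyTorch versions.
from torch.amp import custom_bwd, custom_fwd


class ScaleFn(torch.autograd.Function):
    """Hypothetical example: y = x * s as a custom autograd Function."""

    @staticmethod
    @custom_fwd(device_type="cuda")  # interoperates with the surrounding CUDA autocast region
    def forward(ctx, x, s):
        ctx.s = s
        return x * s

    @staticmethod
    @custom_bwd(device_type="cuda")  # backward runs under the same autocast state as forward
    def backward(ctx, grad_out):
        # No gradient for the plain-Python scale factor s.
        return grad_out * ctx.s, None

The replacement is intended to be behavior-preserving: passing device_type="cuda" pins the autocast device so the kernels keep the same mixed-precision behavior as before, just without the warning.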
@@ -17,7 +17,7 @@
 import itertools
 
 # Third Party
-from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.amp import custom_bwd, custom_fwd
 import torch
 import triton
 import triton.language as tl
@@ -140,15 +140,15 @@ def quant_matmul_248(
 
 class QuantLinearFunction(torch.autograd.Function):
     @staticmethod
-    @custom_fwd
+    @custom_fwd(device_type="cuda")
     def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
         output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq)
         ctx.save_for_backward(qweight, scales, qzeros, g_idx)
         ctx.bits, ctx.maxq = bits, maxq
         return output
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
         qweight, scales, qzeros, g_idx = ctx.saved_tensors
         bits, maxq = ctx.bits, ctx.maxq
@@ -23,8 +23,8 @@
 import torch
 
 # Local
-from .utils import lora_adapters_switch_ddp_from_fsdp
 from .models.utils import filter_mp_rules
+from .utils import lora_adapters_switch_ddp_from_fsdp
 
 
 # consider rewriting register_foak_model_patch_rules into something
@@ -57,7 +57,7 @@ class LoRA_MLP(torch.autograd.Function):
     Don't forget to see our blog post for more details!
     """
     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.custom_fwd(device_type='cuda')
     def forward(ctx, X : torch.Tensor,
                 gateW, gateW_quant, gate_bias, gateA, gateB, gateS,
                 upW, upW_quant, up_bias, upA, upB, upS,
@@ -104,7 +104,7 @@ def forward(ctx, X : torch.Tensor,
 
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, dY : torch.Tensor):
         gateW, gateW_quant, gateS, upW, upW_quant, upS, downW, downW_quant, downS, \
             _backward_function = ctx.custom_saved_tensors
@@ -251,7 +251,7 @@ class LoRA_QKV(torch.autograd.Function):
     dC/dBv = A.T @ X.T @ D(Wv)
     """
     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.custom_fwd(device_type='cuda')
     def forward(ctx, X : torch.Tensor,
                 QW, QW_quant, Q_bias, QA, QB, QS,
                 KW, KW_quant, K_bias, KA, KB, KS,
@@ -294,7 +294,7 @@ def forward(ctx, X : torch.Tensor,
         pass
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, dQ, dK, dV):
         QW, QW_quant, QS, KW, KW_quant, KS, VW, VW_quant, VS = \
             ctx.custom_saved_tensors
@@ -404,7 +404,7 @@ class LoRA_W(torch.autograd.Function):
     dC/dBv = A.T @ X.T @ D(Wv)
     """
     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.custom_fwd(device_type='cuda')
     def forward(ctx, X : torch.Tensor,
                 W, W_quant, bias, A, B, S, dropout_O):
         dtype = X.dtype
@@ -423,7 +423,7 @@ def forward(ctx, X : torch.Tensor,
         pass
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, dY : torch.Tensor):
         W, W_quant, S = ctx.custom_saved_tensors
         A, B, X, OX = ctx.saved_tensors
@@ -4,13 +4,11 @@
 
 # with modifications from The IBM Tuning Team
 
-import math
 from dataclasses import dataclass
 from logging import getLogger
 from typing import Optional
 
 import torch
-from torch.cuda.amp import custom_bwd, custom_fwd
 
 from .triton.kernels import dequant248
 from ..swiglu import swiglu_DWf_DW_dfg_kernel, swiglu_fg_kernel
@@ -213,7 +211,7 @@ class LoRA_MLP(torch.autograd.Function):
     """
 
     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.custom_fwd(device_type='cuda')
     def forward(
         ctx,
         X: torch.Tensor,
@@ -309,7 +307,7 @@ def forward(
         return i
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, dY: torch.Tensor):
         (
             gate_qweight,
@@ -497,7 +495,7 @@ class LoRA_QKV(torch.autograd.Function):
     """
 
     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.custom_fwd(device_type='cuda')
     def forward(
         ctx,
         X: torch.Tensor,
@@ -591,7 +589,7 @@ def forward(
         return Q, K, V
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, dQ, dK, dV):
         (
             Q_qweight,
@@ -770,7 +768,7 @@ class LoRA_W(torch.autograd.Function):
     """
 
     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.custom_fwd(device_type='cuda')
     def forward(
         ctx,
         X: torch.Tensor,
@@ -807,7 +805,7 @@ def forward(
         return XW
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, dY: torch.Tensor):
         O_qweight, O_scales, O_qzeros, O_g_idx, O_bits, S = ctx.custom_saved_tensors
         A, B, X, OX = ctx.saved_tensors
