fix: warning of torch.amp.custom_bwd
Signed-off-by: Abhishek <[email protected]>
Abhishek-TAMU committed Nov 28, 2024
1 parent a44f96f commit 149eaaa
Showing 1 changed file with 2 additions and 2 deletions.
@@ -140,15 +140,15 @@ def quant_matmul_248(
 
 class QuantLinearFunction(torch.autograd.Function):
     @staticmethod
-    @custom_fwd(device_type='cuda')
+    @custom_fwd(device_type="cuda")
     def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
         output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq)
         ctx.save_for_backward(qweight, scales, qzeros, g_idx)
         ctx.bits, ctx.maxq = bits, maxq
         return output
 
     @staticmethod
-    @custom_bwd(device_type='cuda')
+    @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
         qweight, scales, qzeros, g_idx = ctx.saved_tensors
         bits, maxq = ctx.bits, ctx.maxq
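
For context, below is a minimal, self-contained sketch (not code from this repository) of the decorator form the diff settles on: torch.amp.custom_fwd / torch.amp.custom_bwd with an explicit device_type, which is the replacement PyTorch 2.4+ recommends in the FutureWarning it emits for the older torch.cuda.amp.custom_fwd / custom_bwd spellings. The MatMulFP16 op, its tensor shapes, and the cast_inputs argument are illustrative assumptions, not taken from the changed file.

import torch
from torch.amp import custom_fwd, custom_bwd


class MatMulFP16(torch.autograd.Function):
    """Toy custom op using the non-deprecated torch.amp decorators."""

    @staticmethod
    @custom_fwd(device_type="cuda", cast_inputs=torch.float16)
    def forward(ctx, a, b):
        # cast_inputs (optional) casts floating-point inputs to fp16 when
        # autocast is active before the forward body runs.
        ctx.save_for_backward(a, b)
        return a @ b

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        # custom_bwd runs backward under the same autocast state as forward.
        a, b = ctx.saved_tensors
        return grad_output @ b.t(), a.t() @ grad_output


if torch.cuda.is_available():
    x = torch.randn(4, 8, device="cuda", requires_grad=True)
    w = torch.randn(8, 3, device="cuda", requires_grad=True)
    with torch.autocast(device_type="cuda"):
        out = MatMulFP16.apply(x, w)
    out.sum().backward()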
