diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index e524e0203..a1c9e9b28 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -431,7 +431,7 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
         if req_gradA:
             if state.CB is not None:
                 CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
-                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
+                grad_A = torch.matmul(grad_output.to(ctx.dtype_A), CB).view(ctx.grad_shape)
             else:
                 raise Exception("State must contain CB matrix for backward")
 
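
For context, the change casts `grad_output` into `ctx.dtype_A` before the matmul instead of casting the product afterward. Below is a minimal standalone sketch of the dtype behavior this touches, not the library code: the shapes and the fp32/fp16 combination are made up for illustration. `torch.matmul` does not promote mismatched operand dtypes, so when `grad_output` arrives in a different dtype than the dequantized `CB` (which is already in `ctx.dtype_A`), the pre-cast form is the one that works, and it also avoids a trailing full-size cast of the product.

```python
import torch

# Hypothetical shapes/dtypes for illustration only; in the real backward,
# CB is the dequantized weight already converted to ctx.dtype_A.
grad_output = torch.randn(4, 8)                   # e.g. float32 from upstream
CB = torch.randn(8, 16, dtype=torch.float16)      # stands in for the dequantized CB

# Old form: matmul first, cast the product afterward. With mismatched
# operand dtypes this raises, since matmul does not type-promote.
try:
    grad_A = torch.matmul(grad_output, CB).to(torch.float16)
except RuntimeError as e:
    print(f"mismatched operands: {e}")

# New form: cast grad_output to the weight dtype up front; the product
# already comes out in that dtype, so no trailing .to() is needed.
grad_A = torch.matmul(grad_output.to(torch.float16), CB)
print(grad_A.dtype)  # torch.float16
```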