diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index e99526ea6..79ff48f44 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -518,7 +518,7 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
         # 1. Dequantize
         # 2. MatmulnN
         print("*******quant_state absmax: ", quant_state.absmax)
-        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype), bias)
+        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
 
         # 3. Save state
         ctx.state = quant_state
@@ -549,7 +549,7 @@ def backward(ctx, grad_output):
         # not supported by PyTorch. TODO: create work-around
         # if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
         if req_gradA:
-            grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype))
+            grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
 
         return grad_A, grad_B, None, grad_bias, None
 
diff --git a/bitsandbytes/backends/xpu.py b/bitsandbytes/backends/xpu.py
index 0305964d3..7c8497d48 100644
--- a/bitsandbytes/backends/xpu.py
+++ b/bitsandbytes/backends/xpu.py
@@ -149,6 +149,8 @@ def dequantize_4bit(
     if blocksize is None:
         blocksize = 64
     assert_on_xpu([A, absmax, out])
+    # result = dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type)
+    # print("+++++++++result: ", result)
     # return dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type)
     print("------A device: ", A.device)
     print("------quant_state device: ", quant_state.shape[0])
@@ -161,6 +163,8 @@ def dequantize_4bit(
         None,
         blocksize
     )
+    output_dq = output_dq.t()
+    print("=====output_dq: ", output_dq)
     return output_dq
 
 def gemv_4bit(
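
For context, a minimal sketch of the shape convention the added `.t()` calls appear to compensate for. It assumes the XPU dequantize path hands back the weight transposed, i.e. shaped (in_features, out_features); the tensor names and shapes below are made up for illustration and are not taken from the patch.

```python
import torch

# Hypothetical shapes, for illustration only.
batch, in_features, out_features = 2, 8, 4
A = torch.randn(batch, in_features)

# Assume the dequantized weight comes back as (in_features, out_features),
# i.e. transposed relative to what torch.nn.functional.linear expects.
W_dq = torch.randn(in_features, out_features)

# F.linear computes A @ W.T + bias and expects W of shape
# (out_features, in_features), so the transposed result must be flipped
# back before the call -- the role of the .t() added in forward().
out = torch.nn.functional.linear(A, W_dq.t())
print(out.shape)  # torch.Size([2, 4])

# The same convention applies in backward(): grad_A = grad_output @ W
# needs W of shape (out_features, in_features), hence the second .t().
grad_output = torch.randn(batch, out_features)
grad_A = torch.matmul(grad_output, W_dq.t())
print(grad_A.shape)  # torch.Size([2, 8])
```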