Commit

update code

zhuhong61 committed Oct 28, 2024
1 parent f0e22ec commit d30c026
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 2 additions & 2 deletions bitsandbytes/autograd/_functions.py
@@ -518,7 +518,7 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None):
         # 1. Dequantize
         # 2. MatmulnN
         print("*******quant_state absmax: ", quant_state.absmax)
-        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype), bias)
+        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
 
         # 3. Save state
         ctx.state = quant_state
@@ -549,7 +549,7 @@ def backward(ctx, grad_output):
         # not supported by PyTorch. TODO: create work-around
         # if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
         if req_gradA:
-            grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype))
+            grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
 
         return grad_A, grad_B, None, grad_bias, None

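Taken together with the bitsandbytes/backends/xpu.py change below, the backend's dequantize_4bit now hands the weight back transposed, and the extra .t() in forward and backward restores the layout that torch.nn.functional.linear and the grad_A matmul expect. A minimal shape sketch of that bookkeeping, with a random tensor standing in for F.dequantize_4bit(B, quant_state); the (in_features, out_features) layout and the dimensions are illustrative assumptions, not taken from the diff:

import torch

in_features, out_features, batch = 16, 8, 4
A = torch.randn(batch, in_features)            # activations
W_dq = torch.randn(in_features, out_features)  # stand-in for the dequantized weight, assumed transposed

# forward: F.linear expects the weight as (out_features, in_features)
output = torch.nn.functional.linear(A, W_dq.t())   # -> (batch, out_features)

# backward w.r.t. A: (batch, out) @ (out, in) -> (batch, in)
grad_output = torch.randn(batch, out_features)
grad_A = torch.matmul(grad_output, W_dq.t())

print(output.shape, grad_A.shape)  # torch.Size([4, 8]) torch.Size([4, 16])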
4 changes: 4 additions & 0 deletions bitsandbytes/backends/xpu.py
@@ -149,6 +149,8 @@ def dequantize_4bit(
         if blocksize is None:
             blocksize = 64
         assert_on_xpu([A, absmax, out])
+        # result = dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type)
+        # print("+++++++++result: ", result)
         # return dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type)
         print("------A device: ", A.device)
         print("------quant_state device: ", quant_state.shape[0])
@@ -161,6 +163,8 @@ def dequantize_4bit(
None,
blocksize
)
+        output_dq = output_dq.t()
+        print("=====output_dq: ", output_dq)
return output_dq

def gemv_4bit(
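The XPU backend now transposes output_dq before returning it and prints the tensor for inspection. In the same debugging spirit, a small hypothetical helper (not part of bitsandbytes; the name and shapes are illustrative) can report whether a returned tensor matches quant_state.shape or its transpose:

import torch

def check_dequant_layout(output_dq: torch.Tensor, expected_shape: tuple) -> str:
    # Compare the dequantized weight's shape against the shape recorded in quant_state.
    if tuple(output_dq.shape) == tuple(expected_shape):
        return "matches quant_state.shape"
    if tuple(output_dq.shape) == tuple(reversed(expected_shape)):
        return "transposed relative to quant_state.shape"
    return f"unexpected shape {tuple(output_dq.shape)} vs {tuple(expected_shape)}"

# Example with made-up dimensions:
w = torch.randn(64, 128)
print(check_dequant_layout(w.t(), (64, 128)))  # -> "transposed relative to quant_state.shape"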
