Add xpu support

zhuhong61 committed Jun 25, 2024
1 parent c79b1e9 commit 7f43430
Showing 4 changed files with 96 additions and 13 deletions.
2 changes: 1 addition & 1 deletion bitsandbytes/autograd/_functions.py
@@ -575,7 +575,7 @@ def matmul_4bit(
bias=None,
):
assert quant_state is not None
if (A.numel() == A.shape[-1] or A.device.type == "cpu") and A.requires_grad == False:
if (A.numel() == A.shape[-1] or A.device.type == "cpu" or A.device.type == "xpu") and A.requires_grad == False:
# CPU backend does not require A to be a vector
if A.shape[-1] % quant_state.blocksize != 0:
warn(
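The change above routes XPU tensors down the same no-grad path that CPU already uses. As an illustration only, here is a minimal standalone sketch (hypothetical helper name, not code from this commit) of the updated condition:

import torch

def takes_cpu_xpu_path(A: torch.Tensor) -> bool:
    # Mirrors the updated check in matmul_4bit: take the non-CUDA inference
    # path when A is effectively a single row, or lives on CPU/XPU, and no
    # gradient is required.
    is_vector = A.numel() == A.shape[-1]
    on_cpu_or_xpu = A.device.type in ("cpu", "xpu")
    return (is_vector or on_cpu_or_xpu) and not A.requires_grad

# A (1, 4096) CPU activation during inference takes this path.
print(takes_cpu_xpu_path(torch.randn(1, 4096)))  # True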
9 changes: 6 additions & 3 deletions bitsandbytes/backends/cpu_xpu_common.py
@@ -14,6 +14,7 @@

ipex_cpu = ipex if ipex._C._has_cpu() else None
ipex_xpu = ipex if ipex._C._has_xpu() else None
ipex_cpu_only = ipex._C._has_cpu() and (not ipex._C._has_xpu())
except BaseException:
ipex_cpu = None
ipex_xpu = None
@@ -333,7 +334,7 @@ def quantize_4bit_impl(
scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
# map [-1, 1] to nf4/fp4
out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8)
out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8, device=A.device)
if quant_type == "nf4":
for i in range(len(NF4_QUANT_TABLE)):
out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i
@@ -361,7 +362,7 @@ def quantize_4bit_impl(
quant_type=quant_type,
)

if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and input_shape[1] % blocksize == 0 and quant_type == "nf4":
if ipex_cpu_only and _ipex_cpu_version_prereq(2, 3) and input_shape[1] % blocksize == 0 and quant_type == "nf4":
# lowp_mode: lowest precision for computation
lowp_mode = ipex_cpu.quantization.WoqLowpMode.BF16
state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
@@ -455,8 +456,10 @@ def dequantize_4bit_impl(
out_uint8 = torch.empty(A.size(0) * 2, dtype=torch.uint8, device=A.device)
out_uint8[::2] = A.bitwise_and(0xF)
out_uint8[1::2] = A.bitwise_right_shift(4)
out_dq = torch.empty(out_uint8.shape).to(quant_state.dtype)
out_dq = torch.empty(out_uint8.shape).to(quant_state.dtype).to(A.device)
for i in range(len(quant_state.code)):
# quant_state.code is fp32, cast to quant_state dtype to avoid the mismatch issue
quant_state.code = quant_state.code.to(quant_state.dtype)
out_dq[out_uint8 == i] = quant_state.code[i]

# Apply scales
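The hunks above make the 4-bit helpers device-consistent: intermediates are allocated on A.device instead of implicitly on CPU, the IPEX weight-only prepack is restricted to CPU-only builds, and the fp32 code table is cast to the quantization dtype before lookup. A small self-contained sketch, not taken from the repository, of a device- and dtype-consistent code lookup in the same spirit:

import torch

def dequant_codes(packed: torch.Tensor, code: torch.Tensor, dtype=torch.bfloat16) -> torch.Tensor:
    # Unpack two 4-bit codes per byte on the same device as the packed input.
    out_uint8 = torch.empty(packed.numel() * 2, dtype=torch.uint8, device=packed.device)
    out_uint8[::2] = packed.bitwise_and(0xF)
    out_uint8[1::2] = packed.bitwise_right_shift(4)
    # Cast the fp32 code table once to the target dtype/device before indexing.
    code = code.to(dtype).to(packed.device)
    return code[out_uint8.long()]

table = torch.linspace(-1.0, 1.0, 16)                    # stand-in for an NF4/FP4 code table
packed = torch.randint(0, 256, (8,), dtype=torch.uint8)  # 8 bytes -> 16 codes
print(dequant_codes(packed, table).shape)                # torch.Size([16])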
80 changes: 72 additions & 8 deletions bitsandbytes/backends/xpu.py
@@ -5,6 +5,28 @@
from bitsandbytes.utils import QuantState

from .base import Backend
from .cpu_xpu_common import (
dequantize_4bit_impl,
double_quant_impl,
gemm_4bit_impl,
igemmlt_impl,
mm_dequant_impl,
quantize_4bit_impl,
)

Tensor = torch.Tensor
def assert_on_xpu(tensors):
on_xpu = True
for t in tensors:
if t is None:
continue # NULL pointers are fine
on_xpu &= t.device.type == "xpu"
if not on_xpu:
raise TypeError(
"All input tensors need to be on CPU, but found some tensors to not be on XPU:\n"
f" {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}"
)
return on_xpu


class XPUBackend(Backend):
@@ -17,7 +39,8 @@ def double_quant(
out_row: Optional[torch.Tensor] = None,
threshold=0.0,
):
raise NotImplementedError
assert_on_xpu([A, col_stats, row_stats, out_col, out_row])
return double_quant_impl(A, col_stats, row_stats, out_col, out_row, threshold)

def transform(
self,
@@ -29,7 +52,23 @@ def transform(
state: Optional[Tuple[torch.Size, str]] = None,
ld=None,
):
raise NotImplementedError
"""
Transform tensor A to to_order. It is originally designed for CUDA.
For XPU, it returns the original tensor if transpose=False.
Otherwise, it returns the transpose of A.
"""
assert_on_xpu([A, out])
if transpose:
if out is not None:
out.copy_(A.T)
else:
out = A.T
else:
if out is not None:
out.copy_(A)
else:
out = A
return out, state

def igemmlt(
self,
@@ -41,7 +80,8 @@ def igemmlt(
Sout: Optional[Tuple[torch.Size, str]] = None,
dtype=torch.int32,
) -> Union[torch.Tensor, Tuple[Optional[Tuple[torch.Tensor, Tuple[torch.Size, str]]]]]:
raise NotImplementedError
assert_on_xpu([A, B])
return igemmlt_impl(A, B, SA, SB, out, Sout, dtype)

def mm_dequant(
self,
@@ -54,15 +94,29 @@ def mm_dequant(
new_col_stats: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
assert_on_xpu([A, row_stats, col_stats, out, bias])
return mm_dequant_impl(
A,
quant_state,
row_stats,
col_stats,
out,
new_row_stats,
new_col_stats,
bias,
self.mm_dequant_compute_dtype,
self.mm_dequant_output_dtype,
)

def extract_outliers(
self,
A: torch.Tensor,
SA: Tuple[torch.Size, str],
idx: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError
assert_on_xpu([A])
return A[:, idx].contiguous()


def quantize_4bit(
self,
@@ -74,7 +128,11 @@
quant_type: Literal["fp4", "nf4"] = "fp4",
quant_storage=torch.uint8,
) -> Tuple[torch.Tensor, QuantState]:
raise NotImplementedError
if blocksize is None:
blocksize = 64
assert_on_xpu([A, absmax, out])
assert quant_storage == torch.uint8, "XPU backend only supports uint8 quant_storage"
return quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type)

def dequantize_4bit(
self,
@@ -85,7 +143,10 @@
blocksize: int = 64,
quant_type: Literal["fp4", "nf4"] = "fp4",
) -> torch.Tensor:
raise NotImplementedError
if blocksize is None:
blocksize = 64
assert_on_xpu([A, absmax, out])
return dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type)

def gemv_4bit(
self,
@@ -96,7 +157,10 @@
transposed_B=False,
state: QuantState = None,
) -> torch.Tensor:
raise NotImplementedError
assert_on_xpu([A, B, out])
if state is None:
raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")
return gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state)

def dequantize_blockwise(
self,
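Every XPU backend method above first runs assert_on_xpu over its tensor arguments and then delegates to the shared implementations in cpu_xpu_common. A standalone sketch (assumed helper name, not the repository code) of that device-guard pattern:

import torch

def check_device(tensors, device_type: str = "xpu") -> bool:
    # Raise if any tensor argument is not on the expected device; None is allowed.
    mismatched = [
        (tuple(t.shape), str(t.device))
        for t in tensors
        if isinstance(t, torch.Tensor) and t.device.type != device_type
    ]
    if mismatched:
        raise TypeError(f"All input tensors need to be on {device_type}, got: {mismatched}")
    return True

# CPU tensors fail the XPU check but pass a CPU check; None entries are skipped.
try:
    check_device([torch.ones(2, 2), None])
except TypeError as err:
    print(err)
print(check_device([torch.ones(2, 2), None], device_type="cpu"))  # True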
18 changes: 17 additions & 1 deletion bitsandbytes/nn/modules.py
@@ -310,6 +310,9 @@ def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: b
def cpu(self, non_blocking: bool = False):
return self.to(device="cpu", non_blocking=non_blocking)

def xpu(self, non_blocking: bool = False):
return self.to(device="xpu", non_blocking=non_blocking)

@overload
def to(
self: T,
@@ -327,7 +330,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

if device is not None and device.type in ["cuda", "cpu"] and not self.bnb_quantized:
if device is not None and device.type in ["cuda", "cpu", "xpu"] and not self.bnb_quantized:
return self._quantize(device)
else:
if self.quant_state is not None:
@@ -605,6 +608,19 @@ def cpu(self):
self.SCB = SCB
return self

def xpu(self):
# we store the 8-bit row-major weight
B = self.data.contiguous().bfloat16().cpu()
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
if CBt is not None:
del CBt
if SCBt is not None:
del SCBt
self.data = CB
self.CB = CB
self.SCB = SCB
return self

@overload
def to(
self: T,
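With the new xpu()/to("xpu") branches, moving a quantized module to an Intel GPU triggers quantization on transfer, just like the existing cuda/cpu paths. A hypothetical usage sketch (not part of this commit; it assumes a working XPU device and intel_extension_for_pytorch, and is guarded accordingly):

import torch
import bitsandbytes as bnb

if hasattr(torch, "xpu") and torch.xpu.is_available():
    layer = bnb.nn.Linear4bit(4096, 4096, compute_dtype=torch.bfloat16, quant_type="nf4")
    layer = layer.to("xpu")              # Params4bit.to() now quantizes on "xpu" as well
    x = torch.randn(1, 4096, dtype=torch.bfloat16, device="xpu")
    with torch.no_grad():
        y = layer(x)                     # served by matmul_4bit's CPU/XPU inference path
    print(y.shape)                       # torch.Size([1, 4096])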
