Enable XPU and optimize CPU/XPU ops (#1418)
* enable new ipex API

ipex weight is 4D so we cannot transpose

fix dequant

check requires_grad

* use ipex op in backward

* enable backward

* Multi backend refactor (#8)

* AMD: Clarify diagnostic messages; free up disk space for CI build

* Add build job for rocm

* Add rocm build script

* Copy shared obj file into output_dir

* upload build artifacts and enable wheels build

* Remove cuda build temporarily

* Add ROCm version to .so filename

* Add rocm_version to whls build

* Revert "Remove cuda build temporarily"

This reverts commit 1413c5f.

* Add rocm_version env var

* Remove thrust header files

* Print node info

* print cuda node info

* Revert "print cuda node info"

This reverts commit cdb209a.

* Revert "Print node info"

This reverts commit 7e9a65c.

* Add rocm arch to compile command

* Rename .so files to rocm

* Update default gpu arch

* Skip cpu based igemmlt int tests on ROCm

* Update Documentation

* Update upstream repo name

* Update docs

* Update string format

Co-authored-by: Aarni Koskela <[email protected]>

* Remove pre-release option for torch install

* Update pytorch install path

Co-authored-by: Titus <[email protected]>

* Add messages for Heuristics error

* Remove toolcache for disk space

* print disk usage

* Clean disk space for linux

* Fix for ubuntu

* Add sudo for apt clean

* Update clean up disk list

* remove disk usage print

* Add BNB_BACKEND variable

* Update diagnostic functions for ROCm

* Fix tuple error

* Fix library detection bug for recursive and symlink cases

* fix pre-commit errors

* Remove recursive path lib search

* Create function for runtime lib patterns

* Update logger format

Co-authored-by: Aarni Koskela <[email protected]>

* Update error reporting

Co-authored-by: Aarni Koskela <[email protected]>

* Remove commented code

Co-authored-by: Aarni Koskela <[email protected]>

* Update error reporting

Co-authored-by: Aarni Koskela <[email protected]>

* Update error reporting

* Create hip diagnostics functions

* Fix Typo

* Fix pre-commit checks

---------

Co-authored-by: Aarni Koskela <[email protected]>
Co-authored-by: Titus <[email protected]>

* check grad before using ipex (#1358)

* Enable packaging for ROCm 6.2 (#1367)

* Enable 6.2 build

* Update documentation for 6.2.0 pip install

* Update for VS2022 17.11 compatibility with CUDA < 12.4 (#1341)

* Update for VS2022 17.11 compatibility with CUDA < 12.4

* Try again

* Enable continuous releases for multi-backend-refactor branch

* Update release workflow

* Publish continuous release for multi-backend

* continuous release: revert wheel renaming due to install err

* Revert "continuous release: revert wheel renaming due to install err"

This reverts commit 0a2b539.

* add dynamic tag-based versioning + git hash for dev vers

* docs: update w/ changes from `main`

* get tags for dynamic versioning

* fine-tune continuous release params

* reduce the pkg size + build times for the preview release

* refine docs for multi-backend alpha release (#1380)

* refine docs for multi-backend alpha release

* docs: further tweaks to multi-backend alpha docs

* docs: further tweaks to multi-backend alpha docs

* docs: further tweaks to multi-backend alpha docs

* docs: add multi-backend feedback links

* docs: add request for contributions

* docs: small fixes

* docs: small fixes

* docs: add info about `main` continuous build

* docs: further tweaks to multi-backend alpha docs

* docs: further tweaks to multi-backend alpha docs

* docs: remove 2 obsolete lines

---------

Co-authored-by: pnunna93 <[email protected]>
Co-authored-by: Aarni Koskela <[email protected]>
Co-authored-by: Titus <[email protected]>
Co-authored-by: Matthew Douglas <[email protected]>

* Revert "enable backward"

This reverts commit cd7bf21.

* Revert "use ipex op in backward"

This reverts commit b8df1aa.

* fix finetune

* check training

* fix gemv check

* reformat

* avoid double quant in backward if not needed

* Zh/xpu support (#9)

* Add xpu support

* Add xpu support for int8

* Add xpu dequant kernel support

* update code

* remove debug comments

* remove redundant comments

* Add xpu integration for woqlinear

* correct the comments

* Update cpu_xpu_common.py

---------

Co-authored-by: zhuhong61 <[email protected]>
Co-authored-by: zhuhong61 <[email protected]>

* avoid importing triton for CPU and XPU backends

* fix setup in docker without git config

* xpu does not support torch.compile for now

Signed-off-by: jiqing-feng <[email protected]>

* update xpu

Signed-off-by: jiqing-feng <[email protected]>

* update 4bit compute dtype

* fix xpu int8 path

Signed-off-by: jiqing-feng <[email protected]>

* optimize 4bit dequant

Signed-off-by: jiqing-feng <[email protected]>

* fix xpu dequant

Signed-off-by: jiqing-feng <[email protected]>

* add empty cache in each xpu op

* add nf4 dequant ipex kernel

* fix dequant 4bit op

* empty cache has a negative effect on 4bit gemv

* fix xpu save

* fix save

* xpu uses float16 by default

Signed-off-by: jiqing-feng <[email protected]>

* rm empty cache as it causes slower perf

Signed-off-by: jiqing-feng <[email protected]>

* fix xpu save

Signed-off-by: jiqing-feng <[email protected]>

* fix 8bit int8 param device

Signed-off-by: jiqing-feng <[email protected]>

* fix 8bit int8 param device

Signed-off-by: jiqing-feng <[email protected]>

* fix 8bit int8 param device

Signed-off-by: jiqing-feng <[email protected]>

* fix 8bit int8 param device

Signed-off-by: jiqing-feng <[email protected]>

* fix format

* update readme: Intel CPU and XPU do not need to build the csrc code

* fix format

* fix import

---------

Signed-off-by: jiqing-feng <[email protected]>
Co-authored-by: pnunna93 <[email protected]>
Co-authored-by: Aarni Koskela <[email protected]>
Co-authored-by: Titus <[email protected]>
Co-authored-by: Matthew Douglas <[email protected]>
Co-authored-by: zhuhong61 <[email protected]>
Co-authored-by: zhuhong61 <[email protected]>
7 people authored Nov 29, 2024
1 parent cd73601 commit b2ac423
Showing 10 changed files with 246 additions and 101 deletions.
8 changes: 6 additions & 2 deletions bitsandbytes/__init__.py
@@ -17,11 +17,10 @@
matmul_cublas,
mm_cublas,
)
from .backends import register_backend
from .backends import backends, register_backend
from .backends.cpu import CPUBackend
from .backends.npu import NPUBackend
from .cextension import lib
from .nn import modules

features = {"multi_backend"}
supported_torch_devices = {
@@ -64,6 +63,11 @@
if hasattr(torch, "npu") and torch.npu.is_available():
register_backend("npu", NPUBackend())


# import module after decided backends
if backends:
from .nn import modules

# TODO: Other potential backends:
# XLA - Google TPU / PJRT runtime
# HPU - Habana / Intel Gaudi
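
A quick way to confirm the effect of the registration logic above from an interactive session (a sketch; it assumes `backends` is the plain registry dict exported by `bitsandbytes.backends`, as imported in this hunk):

    import bitsandbytes as bnb
    from bitsandbytes.backends import backends

    print("multi_backend" in bnb.features)  # True on the multi-backend branch
    print(list(backends))                   # registered backend names, e.g. ["cpu"] or ["cpu", "xpu"]
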
17 changes: 13 additions & 4 deletions bitsandbytes/autograd/_functions.py
@@ -221,7 +221,7 @@ def backward(ctx, grad_output):

def supports_igemmlt(device: torch.device) -> bool:
"""check if this device supports the optimized int8 kernel"""
if device == torch.device("cpu"):
if device == torch.device("cpu") or torch.device("xpu"):
return True
if torch.version.hip:
return False if BNB_HIP_VERSION < 601 else True
@@ -463,7 +463,9 @@ def backward(ctx, grad_output):
if len(grad_output.shape) == 3:
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = None, None, None, None, None
if req_gradB or (req_gradA and state.CBt):
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
if req_gradB:
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
@@ -575,8 +577,15 @@ def matmul_4bit(
bias=None,
):
assert quant_state is not None
if (A.numel() == A.shape[-1] or A.device.type == "cpu") and A.requires_grad == False:
# CPU backend does not require A to be a vector
if A.device.type in ("cpu", "xpu") and A.requires_grad == False:
if getattr(quant_state, "ipex", False):
out = F.gemv_4bit(A, B.t(), out, state=quant_state)
if bias is not None:
out += bias
return out
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state)
elif A.numel() == A.shape[-1] and A.requires_grad == False:
if A.shape[-1] % quant_state.blocksize != 0:
warn(
f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
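
The `matmul_4bit` change above adds a dedicated CPU/XPU inference path: when `quant_state.ipex` is set it calls the fused IPEX kernel via `F.gemv_4bit`, otherwise it falls back to dequantize-then-matmul. A hedged usage sketch that exercises this path on CPU (it assumes the `quantize_4bit`/`matmul_4bit` signatures visible in this commit and that the CPU backend handles nf4 with these defaults; illustrative, not authoritative):

    import torch
    import bitsandbytes.functional as F
    from bitsandbytes.autograd._functions import matmul_4bit

    W = torch.randn(256, 256)              # (out_features, in_features), mirroring Linear4bit usage
    packed, state = F.quantize_4bit(W, blocksize=64, quant_type="nf4")

    x = torch.randn(4, 256)                # requires_grad is False, so the inference branch is taken
    with torch.no_grad():
        # state.ipex is not set here, so this goes through the dequant + matmul fallback
        y = matmul_4bit(x, packed.t(), quant_state=state)
    print(y.shape)                         # torch.Size([4, 256])
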
70 changes: 42 additions & 28 deletions bitsandbytes/backends/cpu_xpu_common.py
@@ -15,6 +15,7 @@

ipex_cpu = ipex if ipex._C._has_cpu() else None
ipex_xpu = ipex if ipex._C._has_xpu() else None
ipex_cpu_only = ipex._C._has_cpu() and (not ipex._C._has_xpu())
except BaseException:
ipex_cpu = None
ipex_xpu = None
@@ -55,7 +56,7 @@ def _ipex_xpu_version_prereq(major, minor):

def _maybe_torch_compile(func):
# torch.compile requires g++ and pytorch >= 2.0
if gxx_available and _torch_version_prereq(2, 0):
if gxx_available and _torch_version_prereq(2, 0) and not ipex_xpu:
options = {}
# fx_graph_cache requires pytorch >= 2.2
if _torch_version_prereq(2, 2):
@@ -181,7 +182,7 @@ def igemmlt_impl(A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32)
A_reshaped = A.reshape(m, k)

# torch._int_mm is available on CPU since torch 2.4
if _torch_version_prereq(2, 4):
if _torch_version_prereq(2, 4) and A.device.type == "cpu":
C = torch._int_mm(A_reshaped, B.T).to(dtype)
else:
C = torch.matmul(A_reshaped.float(), B.t().float()).to(dtype)
@@ -233,8 +234,10 @@ def mm_dequant_impl(
out_shape = (out_shape[0] * out_shape[1], out_shape[2])

if compute_dtype not in [torch.float32, torch.bfloat16]:
warnings.warn(f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use float instead")
compute_dtype = torch.float32
warnings.warn(
f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use bfloat16 instead"
)
compute_dtype = torch.bfloat16
A_reshaped = A.reshape(out_shape).to(compute_dtype)
row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype)
col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype)
@@ -342,7 +345,7 @@ def quantize_4bit_impl(
scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
# map [-1, 1] to nf4/fp4
out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8)
out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8, device=A.device)
if quant_type == "nf4":
for i in range(len(NF4_QUANT_TABLE)):
out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i
@@ -408,7 +411,6 @@ def dequantize_4bit_impl(
torch.Tensor:
Dequantized tensor.
"""

if A.shape[0] == 1:
transpose = False
A = A.squeeze(0)
@@ -438,23 +440,18 @@
if quant_state.nested:
raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")

if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(quant_state, "op_context"):
assert quant_state.op_context is not None
A = quant_state.op_context.to_public(quant_state.op_context.get_weight())
A = A.reshape(-1)
absmax = quant_state.op_context.get_scales().reshape(-1)

if out is None:
out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False):
A = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2)
quant_state.ipex = False

n = out.numel()
# Map nf4 to [-1, 1]
out_uint8 = torch.empty(A.size(0) * 2, dtype=torch.uint8, device=A.device)
out_uint8[::2] = A.bitwise_and(0xF)
out_uint8[1::2] = A.bitwise_right_shift(4)
out_dq = torch.empty(out_uint8.shape).to(quant_state.dtype)
for i in range(len(quant_state.code)):
out_dq[out_uint8 == i] = quant_state.code[i]
out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
n = out_dq.numel()
out_dq[::2] = A & 0xF
out_dq[1::2] = A >> 4
# quant_state.code is fp32, cast to quant_state dtype to avoid the mismatch issue
quant_state.code = quant_state.code.to(quant_state.dtype)
out_dq = quant_state.code[out_dq]

# Apply scales
if out_dq.numel() != n:
@@ -464,12 +461,17 @@
blocks += 1 if n % blocksize > 0 else 0
rem = n % blocksize
has_rem = rem > 0
out_reshaped = out.reshape(-1)
out_reshaped[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(
-1
)

if has_rem:
if out is None:
out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
out_reshaped = out.reshape(-1)
out_reshaped[: n - rem] = (
out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)
).reshape(-1)
out_reshaped[n - rem :] = out_dq[n - rem :] * absmax[-1]
else:
out = (out_dq.view(-1, blocksize) * absmax.view(-1, 1)).reshape(quant_state.shape).to(quant_state.dtype)

# take transpose here because weight is transposed (again) for computation
if transpose:
@@ -510,9 +512,21 @@ def gemm_4bit_impl(
torch.Tensor:
GEMM output tensor.
"""
if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(state, "op_context"):
assert state.op_context is not None
output = torch.ops.torch_ipex.ipex_woq_linear(A, state.op_context.get_data_handle())
if getattr(state, "ipex", False):
output = torch.ops.torch_ipex.woq_linear(
A,
B,
"nf4",
state.shape,
state.new_scales,
state.new_zeros,
None,
None,
state.blocksize,
ipex_cpu.quantization.WoqLowpMode.BF16,
1,
state.compensation,
)
else:
dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize).t()
output = torch.matmul(A, dqB.to(A.dtype))
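
The rewritten `dequantize_4bit_impl` above replaces the per-code-value loop with a vectorized nibble unpack followed by a table lookup and per-block scaling. A self-contained sketch of that pattern in plain PyTorch (toy code table, packing and scales; the real NF4 table, packing layout and remainder handling live in bitsandbytes):

    import torch

    def dequant_4bit_sketch(packed, code, absmax, blocksize, shape):
        # packed: uint8, two 4-bit indices per byte; code: 16-entry lookup table
        idx = torch.empty(packed.numel() * 2, dtype=torch.int64, device=packed.device)
        idx[::2] = packed.reshape(-1) & 0xF                  # low nibble
        idx[1::2] = packed.reshape(-1) >> 4                  # high nibble
        out = code.to(torch.float32)[idx]                    # indices -> code values in [-1, 1]
        out = out.view(-1, blocksize) * absmax.view(-1, 1)   # apply per-block scales
        return out.reshape(shape)

    # toy round trip: pack 8 indices into 4 bytes, then dequantize as 2 blocks of 4
    code = torch.linspace(-1, 1, 16)
    vals = torch.randint(0, 16, (8,), dtype=torch.uint8)
    packed = (vals[1::2] << 4) | vals[0::2]
    absmax = torch.tensor([2.0, 0.5])
    print(dequant_4bit_sketch(packed, code, absmax, blocksize=4, shape=(2, 4)))
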
95 changes: 87 additions & 8 deletions bitsandbytes/backends/xpu.py
@@ -5,9 +5,36 @@
from bitsandbytes.utils import QuantState

from .base import Backend
from .cpu_xpu_common import (
dequantize_4bit_impl,
double_quant_impl,
gemm_4bit_impl,
igemmlt_impl,
mm_dequant_impl,
quantize_4bit_impl,
)

Tensor = torch.Tensor


def assert_on_xpu(tensors):
on_xpu = True
for t in tensors:
if t is None:
continue # NULL pointers are fine
on_xpu &= t.device.type == "xpu"
if not on_xpu:
raise TypeError(
"All input tensors need to be on XPU, but found some tensors to not be on XPU:\n"
f" {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}"
)
return on_xpu


class XPUBackend(Backend):
mm_dequant_compute_dtype = torch.bfloat16
mm_dequant_output_dtype = torch.bfloat16

def double_quant(
self,
A: torch.Tensor,
@@ -17,7 +44,9 @@ def double_quant(
out_row: Optional[torch.Tensor] = None,
threshold=0.0,
):
raise NotImplementedError
assert_on_xpu([A, col_stats, row_stats, out_col, out_row])
output = double_quant_impl(A, col_stats, row_stats, out_col, out_row, threshold)
return output

def transform(
self,
@@ -29,7 +58,23 @@ def transform(
state: Optional[Tuple[torch.Size, str]] = None,
ld=None,
):
raise NotImplementedError
"""
Transform tensor A to to_order. It is originally designed for CUDA.
For XPU, it returns the original tensor if transpose=False.
Otherwise, it returns the transpose of A
"""
assert_on_xpu([A, out])
if transpose:
if out is not None:
out.copy_(A.T)
else:
out = A.T
else:
if out is not None:
out.copy_(A)
else:
out = A
return out, state

def igemmlt(
self,
@@ -41,7 +86,9 @@ def igemmlt(
Sout: Optional[Tuple[torch.Size, str]] = None,
dtype=torch.int32,
) -> Union[torch.Tensor, Tuple[Optional[Tuple[torch.Tensor, Tuple[torch.Size, str]]]]]:
raise NotImplementedError
assert_on_xpu([A, B])
output = igemmlt_impl(A, B, SA, SB, out, Sout, dtype)
return output

def mm_dequant(
self,
@@ -54,15 +101,30 @@ def mm_dequant(
new_col_stats: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
assert_on_xpu([A, row_stats, col_stats, out, bias])
output = mm_dequant_impl(
A,
quant_state,
row_stats,
col_stats,
out,
new_row_stats,
new_col_stats,
bias,
self.mm_dequant_compute_dtype,
self.mm_dequant_output_dtype,
)
return output

def extract_outliers(
self,
A: torch.Tensor,
SA: Tuple[torch.Size, str],
idx: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError
assert_on_xpu([A])
output = A[:, idx].contiguous()
return output

def quantize_4bit(
self,
@@ -74,7 +136,12 @@ def quantize_4bit(
quant_type: Literal["fp4", "nf4"] = "fp4",
quant_storage=torch.uint8,
) -> Tuple[torch.Tensor, QuantState]:
raise NotImplementedError
if blocksize is None:
blocksize = 64
assert_on_xpu([A, absmax, out])
assert quant_storage == torch.uint8, "XPU backend only supports uint8 quant_storage"
output = quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type)
return output

def dequantize_4bit(
self,
@@ -85,7 +152,15 @@ def dequantize_4bit(
blocksize: int = 64,
quant_type: Literal["fp4", "nf4"] = "fp4",
) -> torch.Tensor:
raise NotImplementedError
if blocksize is None:
blocksize = 64
assert_on_xpu([A, absmax, out])
if quant_type == "nf4":
output = torch.ops.torch_ipex.dequantize_4bit(A, "nf4", quant_state.shape, absmax, None, blocksize).t()
else:
output = dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type)

return output

def gemv_4bit(
self,
@@ -96,7 +171,11 @@ def gemv_4bit(
transposed_B=False,
state: QuantState = None,
) -> torch.Tensor:
raise NotImplementedError
assert_on_xpu([A, B, out])
if state is None:
raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")
output = gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state)
return output

def dequantize_blockwise(
self,
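
For completeness, the backend above only takes effect once it is registered at import time. A sketch of that wiring, mirroring the CPU/NPU registration visible in `bitsandbytes/__init__.py` (the exact XPU availability probe used on this branch sits in a collapsed part of the diff, so the `hasattr`/`is_available` check here is an assumption):

    import torch
    from bitsandbytes.backends import register_backend
    from bitsandbytes.backends.xpu import XPUBackend

    if hasattr(torch, "xpu") and torch.xpu.is_available():
        register_backend("xpu", XPUBackend())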