Add xpu integration for woqlinear
zhuhong61 committed Nov 10, 2024
1 parent dc01ef9 commit c88901a
Showing 3 changed files with 18 additions and 10 deletions.
4 changes: 2 additions & 2 deletions bitsandbytes/backends/cpu_xpu_common.py
@@ -439,7 +439,7 @@ def dequantize_4bit_impl(
    if quant_state.nested:
        raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")

-    if ipex_cpu and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False):
+    if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False):
        A = torch.ops.ipex_prepack.woq_linear_unpack_weight(
            A, "nf4", quant_state.shape, 2
        )
@@ -513,7 +513,7 @@ def gemm_4bit_impl(
    torch.Tensor:
        GEMM output tensor.
    """
-    if ipex_cpu and _ipex_cpu_version_prereq(2, 5) and getattr(state, "ipex", False):
+    if ((ipex_cpu and _ipex_cpu_version_prereq(2, 5)) or (ipex_xpu and _ipex_xpu_version_prereq(2, 5))) and getattr(state, "ipex", False):
        output = torch.ops.torch_ipex.woq_linear(A, B, "nf4", state.shape,
            state.new_scales, state.new_zeros, None, None, state.blocksize,
            ipex_cpu.quantization.WoqLowpMode.BF16, 1, state.compensation)
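For readability, the dispatch guard added above can be decomposed as in the sketch below. This is only a restatement of the new condition, not code from the commit; it reuses the names already in scope in cpu_xpu_common.py (ipex_cpu, ipex_xpu, the version-prerequisite helpers, and the gemm_4bit_impl argument state).

# Sketch only: equivalent decomposition of the dispatch check in gemm_4bit_impl.
# The fused IPEX kernel is used when either the CPU or the XPU backend of
# intel_extension_for_pytorch (>= 2.5) is available AND the quant_state has
# already been repacked by enable_ipex_fusion (state.ipex is True).
cpu_ok = ipex_cpu and _ipex_cpu_version_prereq(2, 5)
xpu_ok = ipex_xpu and _ipex_xpu_version_prereq(2, 5)
use_fused_woq = (cpu_ok or xpu_ok) and getattr(state, "ipex", False)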
2 changes: 1 addition & 1 deletion bitsandbytes/nn/modules.py
@@ -467,7 +467,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):

    def set_ipex_linear(self, x: torch.Tensor):
        if (
-            x.device.type == "cpu"
+            (x.device.type == "cpu" or x.device.type == "xpu")
            and not getattr(self.weight.quant_state, "ipex", False)
            and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0
            and self.weight.quant_state.quant_type == "nf4"
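The visible part of the guard above only enables the IPEX path for nf4 layers whose packed width is an exact multiple of the blocksize and that have not already been converted. A small sketch with hypothetical values (not code from the repository) makes the divisibility requirement concrete:

# Hypothetical values illustrating the visible conditions in set_ipex_linear.
device_type = "xpu"                    # x.device.type; "cpu" also qualifies after this change
quant_type = "nf4"                     # only nf4 layers are eligible
shape, blocksize = (4096, 4096), 64    # quant_state.shape and quant_state.blocksize (64 divides 4096)
already_converted = False              # getattr(quant_state, "ipex", False)

eligible = (
    device_type in ("cpu", "xpu")
    and not already_converted
    and shape[1] % blocksize == 0
    and quant_type == "nf4"
)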
22 changes: 15 additions & 7 deletions bitsandbytes/utils.py
@@ -201,9 +201,10 @@ def unpack_tensor_to_dict(tensor_data):


def enable_ipex_fusion(linear):
-    from bitsandbytes.backends.cpu_xpu_common import _ipex_cpu_version_prereq
+    from bitsandbytes.backends.cpu_xpu_common import _ipex_cpu_version_prereq, _ipex_xpu_version_prereq
+    from bitsandbytes.backends.cpu_xpu_common import ipex_cpu_only, ipex_xpu

-    if _ipex_cpu_version_prereq(2, 5):
+    if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5):
        quant_state = linear.weight.quant_state
        new_weight, new_scales, new_zeros, _, compensation = \
            torch.ops.ipex_prepack.woq_linear_pack_weight(
@@ -217,11 +218,18 @@ def enable_ipex_fusion(linear):
                quant_state.blocksize,
                2,
            )
-        linear.weight.data = new_weight.data
-        setattr(linear.weight.quant_state, "ipex", True)
-        setattr(linear.weight.quant_state, "new_scales", new_scales)
-        setattr(linear.weight.quant_state, "new_zeros", new_zeros)
-        setattr(linear.weight.quant_state, "compensation", compensation)
+    elif ipex_xpu and _ipex_xpu_version_prereq(2, 5):
+        quant_state = linear.weight.quant_state
+        new_weight = linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2])
+
+        new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
+        new_zeros = None
+        compensation = None
+    linear.weight.data = new_weight.data
+    setattr(linear.weight.quant_state, "ipex", True)
+    setattr(linear.weight.quant_state, "new_scales", new_scales)
+    setattr(linear.weight.quant_state, "new_zeros", new_zeros)
+    setattr(linear.weight.quant_state, "compensation", compensation)


class QuantState:
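Taken together, the intended end-to-end flow looks roughly like the sketch below. It assumes an Intel XPU device, intel_extension_for_pytorch >= 2.5, and a multi-backend bitsandbytes build in which moving the module to the device quantizes the weight; the layer sizes are hypothetical, and enable_ipex_fusion is called directly only for illustration (the module is expected to trigger it itself through set_ipex_linear during its forward pass).

# Minimal usage sketch, under the assumptions stated above.
import torch
import bitsandbytes as bnb
from bitsandbytes.utils import enable_ipex_fusion

layer = bnb.nn.Linear4bit(4096, 4096, bias=False, quant_type="nf4", compute_dtype=torch.bfloat16)
layer = layer.to("xpu")        # quantizes the weight to nf4 on the XPU device
enable_ipex_fusion(layer)      # XPU branch: reshapes the packed weight, attaches new_scales, sets quant_state.ipex
x = torch.randn(1, 4096, dtype=torch.bfloat16, device="xpu")
out = layer(x)                 # gemm_4bit_impl now dispatches to torch.ops.torch_ipex.woq_linear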
