diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py index 4bf9c446a..12b8f1506 100644 --- a/bitsandbytes/backends/cpu_xpu_common.py +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -65,7 +65,7 @@ def _maybe_torch_compile(func): return func -# @_maybe_torch_compile +@_maybe_torch_compile def double_quant_impl(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): """ Find absolute max values of each row/column of a tensor, and symmetrically quantize it to int8. @@ -198,7 +198,7 @@ def igemmlt_impl(A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32) return out, Sout -# @_maybe_torch_compile +@_maybe_torch_compile def mm_dequant_impl( A, quant_state, @@ -278,7 +278,7 @@ def mm_dequant_impl( } -# @_maybe_torch_compile +@_maybe_torch_compile def quantize_4bit_impl( A: Tensor, absmax: Tensor = None, @@ -374,7 +374,7 @@ def quantize_4bit_impl( return out.unsqueeze(0), state -# @_maybe_torch_compile +@_maybe_torch_compile def dequantize_4bit_impl( A: Tensor, quant_state=None, @@ -513,15 +513,11 @@ def gemm_4bit_impl( torch.Tensor: GEMM output tensor. """ - print("~~~~~~~~getattr ipex: ", getattr(state, "ipex", False)) - # if (ipex_cpu and _ipex_cpu_version_prereq(2, 5)) or ipex_xpu: if ipex_cpu and _ipex_cpu_version_prereq(2, 5) and getattr(state, "ipex", False): - print("=======cpu custom op path=========") output = torch.ops.torch_ipex.woq_linear(A, B, "nf4", state.shape, state.new_scales, state.new_zeros, None, None, state.blocksize, ipex_cpu.quantization.WoqLowpMode.BF16, 1, state.compensation) else: - print("======else path=========") dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize).t() output = torch.matmul(A, dqB.to(A.dtype)) if out is not None: diff --git a/bitsandbytes/backends/xpu.py b/bitsandbytes/backends/xpu.py index 7c8497d48..693b79017 100644 --- a/bitsandbytes/backends/xpu.py +++ b/bitsandbytes/backends/xpu.py @@ -149,12 +149,6 @@ def dequantize_4bit( if blocksize is None: blocksize = 64 assert_on_xpu([A, absmax, out]) - # result = dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type) - # print("+++++++++result: ", result) - # return dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type) - print("------A device: ", A.device) - print("------quant_state device: ", quant_state.shape[0]) - print("------absmax device: ", quant_state.absmax.device) output_dq = torch.ops.torch_ipex.dequantize_4bit( A, "nf4", @@ -164,7 +158,6 @@ def dequantize_4bit( blocksize ) output_dq = output_dq.t() - print("=====output_dq: ", output_dq) return output_dq def gemv_4bit(