diff --git a/CMakeLists.txt b/CMakeLists.txt
index 315e0ff1b..20dd2b45d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -3,7 +3,7 @@
 # For GCC: `cmake -B build . && cmake --build build`
 # For MSVC: `cmake -B build . && cmake --build build --config Release`
 # You can also use the following options and variables
-# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, `hip` or `mps` to select the backend
+# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, `hip`, `mps` or `npu` to select the backend
 # - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support
 # - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version
 #   is whatever CMake finds on your path.
@@ -29,11 +29,12 @@ set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
 set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
 set(MPS_FILES csrc/mps_ops.mm)
 set(METAL_FILES csrc/mps_kernels.metal)
+set(NPU_FILES csrc/npu_ops.cpp)

 # C++ sources are always included
 list(APPEND SRC_FILES ${CPP_FILES})

-set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)")
-set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps)
+set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps, npu)")
+set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps npu)
 option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF)

 if(APPLE)
@@ -69,6 +70,11 @@ elseif(${COMPUTE_BACKEND} STREQUAL "mps")
     set(BUILD_CUDA OFF)
     set(BUILD_HIP OFF)
     set(BUILD_MPS ON)
+elseif(${COMPUTE_BACKEND} STREQUAL "npu")
+    set(BUILD_CUDA OFF)
+    set(BUILD_HIP OFF)
+    set(BUILD_MPS OFF)
+    set(BUILD_NPU ON)
 else()
     set(BUILD_CUDA OFF)
     set(BUILD_HIP OFF)
@@ -232,6 +238,33 @@ elseif(BUILD_MPS)
         COMMENT "Compiling Metal kernels"
         VERBATIM)
     add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib")
+elseif(BUILD_NPU)
+    list(APPEND SRC_FILES ${NPU_FILES})
+
+    set(SOC_VERSION "Ascend910B4" CACHE STRING "system on chip type")
+    set(ASCEND_CANN_PACKAGE_PATH $ENV{ASCEND_HOME_PATH} CACHE
+        STRING "ASCEND CANN package installation directory"
+    )
+
+    # ${KERNEL_FILES} holds the AscendC kernel sources that are compiled into the kernel library;
+    # see ascendc_library in cmake/npu.cmake and add_library in cmake/cpu.cmake for reference.
+    # file(GLOB KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/csrc/npu_kernels.cpp)
+    file(GLOB KERNEL_FILES csrc/npu_kernels.cpp)
+
+    if(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
+        set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
+    elseif(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/tools/tikcpp/ascendc_kernel_cmake)
+        set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/tools/tikcpp/ascendc_kernel_cmake)
+    else()
+        message(FATAL_ERROR "ascendc_kernel_cmake does not exist, please check whether the CANN package is installed")
+    endif()
+    include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)
+
+    # ascendc_library adds the kernel files and generates the AscendC kernel library
+    ascendc_library(ascendc_kernels_npu STATIC ${KERNEL_FILES})
+
+    string(APPEND BNB_OUTPUT_NAME "_npu")
+    add_compile_definitions(BUILD_NPU)
 else()
     string(APPEND BNB_OUTPUT_NAME "_cpu")
     set(GPU_SOURCES)
@@ -249,7 +282,11 @@ endif()
 set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX)
 add_library(bitsandbytes SHARED ${SRC_FILES})
-target_compile_features(bitsandbytes PUBLIC cxx_std_14)
+if(BUILD_NPU)
+    target_compile_features(bitsandbytes PUBLIC cxx_std_17)
+else()
+    target_compile_features(bitsandbytes PUBLIC cxx_std_14)
+endif()
 target_include_directories(bitsandbytes PUBLIC csrc include)
@@ -306,6 +343,10 @@ if(BUILD_MPS)
     add_dependencies(bitsandbytes metallib)
     target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
 endif()
+if(BUILD_NPU)
+    target_compile_options(bitsandbytes PRIVATE -O2 -std=c++17)
+    target_link_libraries(bitsandbytes PRIVATE $ ascendc_kernels_npu)
+endif()

 if(WIN32)
     set_target_properties(bitsandbytes PROPERTIES PREFIX "lib")
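A quick way to confirm that an `-DCOMPUTE_BACKEND=npu` build is usable is to check the Ascend runtime from Python before exercising the library. This is a minimal sanity-check sketch, not part of the patch; it assumes `torch_npu` is installed and that the compiled `libbitsandbytes_npu` shared library has been placed in the package directory.

```python
# Illustrative post-build check (assumes torch_npu and a finished NPU build).
import torch
import torch_npu  # noqa: F401  # registers the "npu" device with PyTorch

assert torch.npu.is_available(), "no Ascend NPU visible to torch_npu"
print("CANN version reported by torch:", torch.version.cann)

import bitsandbytes as bnb  # should pick up libbitsandbytes_npu.<suffix> via cextension.py
print("loaded bitsandbytes from:", bnb.__file__)
```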
diff --git a/_typos.toml b/_typos.toml
index e4e7287fb..ff4c9ae06 100644
--- a/_typos.toml
+++ b/_typos.toml
@@ -3,12 +3,15 @@
 [default]
 extend-ignore-re = [
     "@Ther-nul", # valid Github user
+    "CANN", # CANN (Compute Architecture for Neural Networks) is a heterogeneous computing architecture for Ascend NPU
 ]

 [default.extend-identifiers]

 [type.py.extend-words]
 "BA" = "BA" # used as a commented-out variable in tests
+"cann" = "cann" # cann (Compute Architecture for Neural Networks) is a heterogeneous computing architecture for Ascend NPU
+
 [type.cuda.extend-words]
 "subtile" = "subtile"
diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py
index c705137c0..f850140a1 100644
--- a/bitsandbytes/__init__.py
+++ b/bitsandbytes/__init__.py
@@ -25,6 +25,7 @@ features = {"multi_backend"}
 supported_torch_devices = {
     "cuda",  # includes ROCm
+    "npu",  # Ascend NPU
     "xpu",  # Intel GPU
     "cpu",
 }
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index e188479f6..6440ab1b5 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -519,7 +519,12 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]

         # 1. Dequantize
         # 2. MatmulnN
-        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
+        if A.device.type == "npu":
+            output = torch.matmul(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t())
+            if bias is not None:
+                output += bias
+        else:
+            output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)

         # 3. Save state
         ctx.state = quant_state
@@ -550,7 +555,10 @@ def backward(ctx, grad_output):
         # not supported by PyTorch. TODO: create work-around
         # if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
         if req_gradA:
-            grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
+            if grad_output.device.type == "npu":
+                grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype))
+            else:
+                grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())

         return grad_A, grad_B, None, grad_bias, None
@@ -586,7 +594,7 @@ def matmul_4bit(
             return out
         else:
             return MatMul4Bit.apply(A, B, out, bias, quant_state)
-    elif A.numel() == A.shape[-1] and A.requires_grad == False:
+    elif A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "npu":
         if A.shape[-1] % quant_state.blocksize != 0:
             warn(
                 f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
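For context on the `MatMul4Bit` changes above: on NPU the weight is dequantized and multiplied with an explicit `torch.matmul`, and the bias is added separately instead of going through `torch.nn.functional.linear`. The sketch below restates that pattern; the `(out_features, in_features)` orientation of the dequantized weight is an assumption made for illustration, not a statement about the backend's internal layout.

```python
import torch

def npu_style_4bit_linear(A: torch.Tensor, W_deq: torch.Tensor, bias=None) -> torch.Tensor:
    """Explicit matmul + bias add, as used on the NPU path (illustrative sketch).

    A:     (..., in_features) activations
    W_deq: (out_features, in_features) dequantized weight (assumed orientation)
    """
    out = torch.matmul(A, W_deq.to(A.dtype).t())  # (..., out_features)
    if bias is not None:
        out = out + bias  # bias added explicitly rather than fused into F.linear
    return out
```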
diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py
index d2e0c2593..8fdf7569d 100644
--- a/bitsandbytes/backends/cpu_xpu_common.py
+++ b/bitsandbytes/backends/cpu_xpu_common.py
@@ -23,7 +23,7 @@
 gxx_available = False
 try:
-    subprocess.run(["g++", "--version"])
+    subprocess.run(["g++", "--version"], capture_output=True)  # hide terminal output
     gxx_available = True
 except BaseException:
     warnings.warn("g++ not found, torch.compile disabled for CPU/XPU.")
diff --git a/bitsandbytes/backends/npu.py b/bitsandbytes/backends/npu.py
index 1b3cb57d6..ecbc2f351 100644
--- a/bitsandbytes/backends/npu.py
+++ b/bitsandbytes/backends/npu.py
@@ -1,17 +1,32 @@
+import ctypes as ct
 from typing import Literal, Optional, Tuple, Union

 import torch

-from bitsandbytes.utils import QuantState
-
-from .base import Backend
-
 try:
     # to support Ascend NPU backend
     import torch_npu  # noqa: F401
 except ImportError:
     pass

+from bitsandbytes.cextension import lib
+from bitsandbytes.functional import (
+    get_4bit_type,
+    get_ptr,
+)
+from bitsandbytes.utils import QuantState
+
+from .base import Backend
+
+
+def assert_on_npu(tensors):
+    if not all(t.device.type == "npu" for t in tensors if t is not None):
+        raise TypeError(
+            "All input tensors need to be on NPU, but found some tensors not on NPU:\n"
+            f"{[(t.shape, t.device) if isinstance(t, torch.Tensor) else None for t in tensors]}"
+        )
+    return True
+

 class NPUBackend(Backend):
     def double_quant(
@@ -75,12 +90,62 @@ def quantize_4bit(
         A: torch.Tensor,
         absmax: Optional[torch.Tensor] = None,
         out: Optional[torch.Tensor] = None,
-        blocksize=64,
+        blocksize: Optional[int] = None,
         compress_statistics=False,
-        quant_type: Literal["fp4", "nf4"] = "fp4",
+        quant_type: Literal["fp4", "nf4"] = "nf4",
         quant_storage=torch.uint8,
     ) -> Tuple[torch.Tensor, QuantState]:
-        raise NotImplementedError
+        if quant_type not in ["nf4"]:
+            raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")
+        if compress_statistics:
+            raise NotImplementedError("compress_statistics is not implemented.")
+        if blocksize is None:
+            blocksize = 128
+
+        prev_device = torch.npu.current_device()
+        torch.npu.set_device(A.device)
+        if A.dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            data = [
+                -1.0,
+                -0.6961928009986877,
+                -0.5250730514526367,
+                -0.39491748809814453,
+                -0.28444138169288635,
+                -0.18477343022823334,
+                -0.09105003625154495,
+                0.0,
+                0.07958029955625534,
+
0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, + ] + data = torch.tensor(data, device="npu", dtype=torch.float32).view(1, -1) + absmax = A.view(-1, blocksize).abs().max(dim=1, keepdim=True).values + a = A.view(-1, blocksize) / absmax.float() + diff = torch.abs(a.unsqueeze(-1) - data) + out = (torch.argmin(diff, dim=-1) + 8) % 16 + out = out.reshape(-1, 2) + out = (out[:, 0] + out[:, 1] * 16).to(torch.uint8) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + assert_on_npu([A, absmax, out]) + torch.npu.set_device(prev_device) + + code = get_4bit_type(quant_type, device=A.device) + state = QuantState( + absmax=absmax, + shape=A.shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + ) + + return out, state def dequantize_4bit( self, @@ -88,10 +153,77 @@ def dequantize_4bit( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: int = 64, - quant_type: Literal["fp4", "nf4"] = "fp4", + blocksize: Optional[int] = None, + quant_type: Literal["fp4", "nf4"] = "nf4", ) -> torch.Tensor: - raise NotImplementedError + if blocksize is None: + blocksize = 128 + supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64] + if blocksize not in supported_blocksizes: + raise ValueError( + f"The blockwise of {blocksize} is not supported. Supported values: {supported_blocksizes}" + ) + + if quant_state is None: + assert absmax is not None and out is not None + quant_state = QuantState( + absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type + ) + else: + absmax = quant_state.absmax + + if out is None: + out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) + + n = out.numel() + + prev_device = torch.npu.current_device() + torch.npu.set_device(A.device) + assert_on_npu([A, absmax, out]) + + if quant_state.quant_type not in ["nf4"]: + raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.") + + if out.dtype == torch.float32: + lib.cdequantize_blockwise_fp32_nf4( + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + torch.npu.current_stream(), + ) + elif out.dtype == torch.float16: + lib.cdequantize_blockwise_fp16_nf4( + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + torch.npu.current_stream(), + ) + elif out.dtype == torch.bfloat16: + # bf16: bf16 -> fp32 -> op -> fp32 -> bf16 + absmax = absmax.to(torch.float32) + out = out.to(torch.float32) + lib.cdequantize_blockwise_fp32_nf4( + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + torch.npu.current_stream(), + ) + out = out.to(torch.bfloat16) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + torch.npu.set_device(prev_device) + is_transposed = True if A.shape[0] == 1 else False + + if is_transposed: + return out.t() + else: + return out def gemv_4bit( self, diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index cc5d8deff..ec329cbb6 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -25,6 +25,7 @@ from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_rocm_gpu_arch +from bitsandbytes.npu_specs 
import get_npu_specs logger = logging.getLogger(__name__) @@ -100,6 +101,10 @@ def get_native_library() -> BNBNativeLibrary: binary_path = cuda_binary_path else: logger.warning("Could not find the bitsandbytes %s binary at %r", BNB_BACKEND, cuda_binary_path) + npu_specs = get_npu_specs() + if npu_specs: + binary_path = PACKAGE_DIR / f"libbitsandbytes_npu{DYNAMIC_LIBRARY_SUFFIX}" + logger.debug(f"Loading bitsandbytes native library from: {binary_path}") dll = ct.cdll.LoadLibrary(str(binary_path)) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 66f14edf7..781e22541 100755 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -314,6 +314,12 @@ def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: b def cpu(self, non_blocking: bool = False): return self.to(device="cpu", non_blocking=non_blocking) + def npu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False): + # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)). + if isinstance(device, int): + device = f"npu:{device}" + return self.to(device="npu" if device is None else device, non_blocking=non_blocking) + def xpu(self, non_blocking: bool = False): return self.to(device="xpu", non_blocking=non_blocking) @@ -334,7 +340,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ... def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if device is not None and device.type in ["cuda", "cpu", "xpu"] and not self.bnb_quantized: + if device is not None and device.type in ["cuda", "cpu", "npu", "xpu"] and not self.bnb_quantized: return self._quantize(device) else: if self.quant_state is not None: @@ -497,7 +503,7 @@ def forward(self, x: torch.Tensor): self.weight.quant_state = self.quant_state else: print( - "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.", + "FP4 quantization state not initialized. 
Please call .cuda(), .npu() or .to(device) on the LinearFP4 layer first.", ) if not self.compute_type_is_set: self.set_compute_type(x) diff --git a/bitsandbytes/npu_specs.py b/bitsandbytes/npu_specs.py new file mode 100644 index 000000000..7c7cd707e --- /dev/null +++ b/bitsandbytes/npu_specs.py @@ -0,0 +1,20 @@ +import dataclasses + +import torch + +try: + import torch_npu # noqa: F401 +except ImportError: + pass + + +@dataclasses.dataclass(frozen=True) +class NPUSpecs: + cann_version_string: str + + +def get_npu_specs(): + if hasattr(torch, "npu") and torch.npu.is_available(): + return NPUSpecs(cann_version_string=torch.version.cann) + else: + return None diff --git a/csrc/npu_kernels.cpp b/csrc/npu_kernels.cpp new file mode 100644 index 000000000..c70e71681 --- /dev/null +++ b/csrc/npu_kernels.cpp @@ -0,0 +1,222 @@ +#include "kernel_operator.h" +#include "npu_ops.h" + +using namespace AscendC; + +constexpr int32_t BUFFER_NUM = 1; + +constexpr half Q_COFF_0 = -0.377685546875; +constexpr half Q_COFF_1 = -3.193359375; +constexpr half Q_COFF_2 = 0.583984375; +constexpr half Q_COFF_3 = 6.02734375; +constexpr half Q_COFF_4 = 1.9560546875; +constexpr half Q_COFF_5 = 7.08984375; + +#define CEIL32(num) (((num) + 32 - 1) / 32 * 32) +#define CEIL_BASE(num, base) (((num) + (base) - 1) / (base) * (base)) + + +template +class KernelDequantizeBlockwiseNf4 { +public: + __aicore__ inline KernelDequantizeBlockwiseNf4() {} + + __aicore__ inline void Init(GM_ADDR A, GM_ADDR absmax, GM_ADDR out, GM_ADDR tilingDevice, TPipe &pipe) + { + ASSERT(GetBlockNum() != 0 && "block dim can not be zero!"); + auto *tiling_data = reinterpret_cast<__gm__ BlockwiseNf4TilingData *>(tilingDevice); + this->blocksize = tiling_data->blocksize; + uint32_t coreNum = tiling_data->coreNum; + uint32_t singleCoreNumel = tiling_data->singleCoreNumel; + uint32_t singleCoreNumelTail = tiling_data->singleCoreNumelTail; + uint32_t numel = tiling_data->numel; + uint32_t ubSize = tiling_data->ubSize; + uint32_t blockIdx = (uint32_t)GetBlockIdx(); + if (coreNum - blockIdx == 1) { + this->CurCoreFP16Num = singleCoreNumelTail; + } else { + this->CurCoreFP16Num = singleCoreNumel; + } + constexpr uint32_t ELEMENT_BYTES = (TypeMode == 1) ? 
4 : 2; // FP32: 4bytes, FP16/BF16: 2bytes + uint32_t eachBatchPkgNum = (ubSize - 16 * ELEMENT_BYTES) / + (this->blocksize / 2 * BUFFER_NUM + ELEMENT_BYTES * BUFFER_NUM + this->blocksize * + (ELEMENT_BYTES * BUFFER_NUM + sizeof(half) + sizeof(uint32_t) + ELEMENT_BYTES)); + if (eachBatchPkgNum >= 32 / ELEMENT_BYTES) { + eachBatchPkgNum = (eachBatchPkgNum / (32 / ELEMENT_BYTES)) * (32 / ELEMENT_BYTES); + } else { + eachBatchPkgNum = (eachBatchPkgNum / 2) * 2; + } + this->eachBatchFP16Num = this->blocksize * eachBatchPkgNum; // 64 * 288 + + // gm, 32-byte alignment + uint32_t AOffset = singleCoreNumel / 2 * blockIdx; + uint32_t ABufferSize = singleCoreNumel / 2; + AGm.SetGlobalBuffer((__gm__ int8_t*)A + AOffset, ABufferSize); + uint32_t absmaxOffset = singleCoreNumel / this->blocksize * blockIdx; + uint32_t absmaxBufferSize = singleCoreNumel / this->blocksize; + absmaxGm.SetGlobalBuffer((__gm__ T*)absmax + absmaxOffset, absmaxBufferSize); + uint32_t outOffset = singleCoreNumel * blockIdx; + uint32_t outBufferSize = singleCoreNumel; + outGm.SetGlobalBuffer((__gm__ T*)out + outOffset, outBufferSize); + + // TQue, 32-byte alignment + pipe.InitBuffer(inQueueA, BUFFER_NUM, this->eachBatchFP16Num / 2); + pipe.InitBuffer(inQueueAbsmax, BUFFER_NUM, CEIL32(eachBatchPkgNum * ELEMENT_BYTES)); + pipe.InitBuffer(outQueueOut, BUFFER_NUM, this->eachBatchFP16Num * ELEMENT_BYTES); + + // TBuf, 32-byte alignment + pipe.InitBuffer(calcNf4ToFloat, 16 * ELEMENT_BYTES); + pipe.InitBuffer(calcAFP16, this->eachBatchFP16Num * sizeof(half)); + pipe.InitBuffer(calcAUint32, this->eachBatchFP16Num * sizeof(uint32_t)); + pipe.InitBuffer(calcAbsmaxBuf, this->eachBatchFP16Num * ELEMENT_BYTES); + } + + __aicore__ inline void Process(void) + { + Compute(); + } + +private: + __aicore__ inline void initNf4ToFloat(LocalTensor &nf4ToFloat) + { + if constexpr (TypeMode == 1) { + nf4ToFloat(0) = static_cast(-1.0); + nf4ToFloat(1) = static_cast(-0.6961928009986877); + nf4ToFloat(2) = static_cast(-0.5250730514526367); + nf4ToFloat(3) = static_cast(-0.39491748809814453); + nf4ToFloat(4) = static_cast(-0.28444138169288635); + nf4ToFloat(5) = static_cast(-0.18477343022823334); + nf4ToFloat(6) = static_cast(-0.09105003625154495); + nf4ToFloat(7) = static_cast(0.0); + nf4ToFloat(8) = static_cast(0.07958029955625534); + nf4ToFloat(9) = static_cast(0.16093020141124725); + nf4ToFloat(10) = static_cast(0.24611230194568634); + nf4ToFloat(11) = static_cast(0.33791524171829224); + nf4ToFloat(12) = static_cast(0.44070982933044434); + nf4ToFloat(13) = static_cast(0.5626170039176941); + nf4ToFloat(14) = static_cast(0.7229568362236023); + nf4ToFloat(15) = static_cast(1.0); + } else if constexpr (TypeMode == 2) { + nf4ToFloat(0) = static_cast(-1.0); + nf4ToFloat(1) = static_cast(-0.6962890625); + nf4ToFloat(2) = static_cast(-0.52490234375); + nf4ToFloat(3) = static_cast(-0.39501953125); + nf4ToFloat(4) = static_cast(-0.284423828125); + nf4ToFloat(5) = static_cast(-0.184814453125); + nf4ToFloat(6) = static_cast(-0.091064453125); + nf4ToFloat(7) = static_cast(0.0); + nf4ToFloat(8) = static_cast(0.07958984375); + nf4ToFloat(9) = static_cast(0.160888671875); + nf4ToFloat(10) = static_cast(0.24609375); + nf4ToFloat(11) = static_cast(0.337890625); + nf4ToFloat(12) = static_cast(0.440673828125); + nf4ToFloat(13) = static_cast(0.5625); + nf4ToFloat(14) = static_cast(0.72314453125); + nf4ToFloat(15) = static_cast(1.0); + } + } + + __aicore__ inline void Compute(void) + { + constexpr uint32_t ELEMENT_BYTES = (TypeMode == 1) ? 
4 : 2; // FP32: 4bytes, FP16/BF16: 2bytes + LocalTensor ALocal = inQueueA.AllocTensor(); + LocalTensor absmaxLocal = inQueueAbsmax.AllocTensor(); + LocalTensor outLocal = outQueueOut.AllocTensor(); + + LocalTensor AFP16 = calcAFP16.Get(); + LocalTensor AInt32 = calcAUint32.Get(); + LocalTensor absmaxBuf = calcAbsmaxBuf.Get(); + LocalTensor nf4ToFloat = calcNf4ToFloat.Get(); + initNf4ToFloat(nf4ToFloat); + + DataCopyParams dataCopyParams = {1, 0, 0, 0}; + uint32_t curBatchNumel = this->eachBatchFP16Num; + uint32_t curBatchPkgNum = curBatchNumel / this->blocksize; + + uint32_t batchCount = (this->CurCoreFP16Num + this->eachBatchFP16Num - 1) / this->eachBatchFP16Num; + for (uint32_t batchIdx = 0; batchIdx < batchCount; batchIdx++) { + if (batchCount - batchIdx == 1) { + curBatchNumel = this->CurCoreFP16Num - this->eachBatchFP16Num * batchIdx; + curBatchPkgNum = (curBatchNumel + this->blocksize - 1) / this->blocksize; + } + + dataCopyParams.blockLen = curBatchNumel / 2; // Byte + DataCopyPad(ALocal, AGm[this->eachBatchFP16Num / 2 * batchIdx], dataCopyParams, {true, 0, 0, 0}); + dataCopyParams.blockLen = ELEMENT_BYTES * curBatchPkgNum; // Byte + uint32_t gmOffset = this->eachBatchFP16Num / this->blocksize * batchIdx; + DataCopyPad(absmaxLocal, absmaxGm[gmOffset], dataCopyParams, {true, 0, 0, 0}); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + pipe_barrier(PIPE_ALL); + + LocalTensor AInt4 = ALocal.ReinterpretCast(); + Cast(AFP16, AInt4, RoundMode::CAST_NONE, curBatchNumel); + pipe_barrier(PIPE_V); + Adds(AFP16, AFP16, static_cast(8), curBatchNumel); + pipe_barrier(PIPE_V); + if constexpr (TypeMode == 1) { + Muls(AFP16, AFP16, static_cast(4), curBatchNumel); + } else { + Muls(AFP16, AFP16, static_cast(2), curBatchNumel); + } + pipe_barrier(PIPE_V); + Cast(AInt32, AFP16, RoundMode::CAST_ROUND, curBatchNumel); + pipe_barrier(PIPE_V); + LocalTensor AUint32 = AInt32.ReinterpretCast(); + Gather(outLocal, nf4ToFloat, AUint32, 0, curBatchNumel); + pipe_barrier(PIPE_V); + uint32_t dstShape[] = {curBatchPkgNum, this->blocksize}; + uint32_t srcShape[] = {curBatchPkgNum, 1}; + BroadCast(absmaxBuf, absmaxLocal, dstShape, srcShape); + pipe_barrier(PIPE_ALL); + Mul(outLocal, outLocal, absmaxBuf, curBatchNumel); + pipe_barrier(PIPE_ALL); + + dataCopyParams.blockLen = ELEMENT_BYTES * curBatchNumel; // Byte + DataCopyPad(outGm[batchIdx * this->eachBatchFP16Num], outLocal, dataCopyParams); + pipe_barrier(PIPE_MTE3); + } + pipe_barrier(PIPE_ALL); + + inQueueA.FreeTensor(ALocal); + inQueueAbsmax.FreeTensor(absmaxLocal); + outQueueOut.FreeTensor(outLocal); + } + +private: + TQue inQueueA; + TQue inQueueAbsmax; + TQue outQueueOut; + TBuf calcAFP16; + TBuf calcAUint32; + TBuf calcNf4ToFloat; + TBuf calcAbsmaxBuf; + GlobalTensor AGm; + GlobalTensor absmaxGm; + GlobalTensor outGm; + uint32_t blocksize; + uint32_t CurCoreFP16Num; + uint32_t eachBatchFP16Num; +}; + + + +extern "C" { + +__global__ __aicore__ void dequantize_blockwise_fp32_nf4(GM_ADDR A, GM_ADDR absmax, GM_ADDR out, GM_ADDR tiling) +{ + TPipe pipe; + KernelDequantizeBlockwiseNf4 op; + op.Init(A, absmax, out, tiling, pipe); + op.Process(); +} + +__global__ __aicore__ void dequantize_blockwise_fp16_nf4(GM_ADDR A, GM_ADDR absmax, GM_ADDR out, GM_ADDR tiling) +{ + TPipe pipe; + KernelDequantizeBlockwiseNf4 op; + op.Init(A, absmax, out, tiling, pipe); + op.Process(); +} + +} diff --git a/csrc/npu_ops.cpp b/csrc/npu_ops.cpp new file mode 100644 index 000000000..fb5ecef2f --- /dev/null +++ b/csrc/npu_ops.cpp @@ -0,0 +1,51 @@ 
+#include
+#include "acl/acl.h"
+#include "tiling/platform/platform_ascendc.h"
+#include "npu_ops.h"
+
+#include "aclrtlaunch_dequantize_blockwise_fp32_nf4.h"
+#include "aclrtlaunch_dequantize_blockwise_fp16_nf4.h"
+
+
+extern "C" {
+
+int32_t get_dequantize_blockwise_nf4_tiling(uint32_t blocksize, uint32_t n, BlockwiseNf4TilingData *tiling) {
+    tiling->ubSize = 196 * 1024;
+    uint32_t coreNum = 40;
+    uint32_t totalPkgNum = (n + blocksize - 1) / blocksize;
+    uint32_t singleCorePkgNum = (totalPkgNum + coreNum - 1) / coreNum;
+    coreNum = (totalPkgNum + singleCorePkgNum - 1) / singleCorePkgNum;
+    uint32_t singleCoreNumel = singleCorePkgNum * blocksize;
+    uint32_t singleCoreNumelTail = n % singleCoreNumel;
+    if (singleCoreNumelTail == 0) {
+        singleCoreNumelTail = singleCoreNumel;
+    }
+    tiling->coreNum = coreNum;
+    tiling->blocksize = blocksize;
+    tiling->numel = n;
+    tiling->singleCoreNumel = singleCoreNumel;
+    tiling->singleCoreNumelTail = singleCoreNumelTail;
+    return 0;
+}
+
+void dequantizeBlockwiseNf4(uint8_t *A, uint8_t *absmax, uint8_t *out, uint32_t blocksize, uint32_t n, void* stream, const uint32_t type_mode) {
+    uint32_t blockDim = 40;
+    size_t tilingSize = sizeof(struct BlockwiseNf4TilingData);
+    BlockwiseNf4TilingData *tilingHost;
+    tilingHost = (struct BlockwiseNf4TilingData *)malloc(tilingSize);
+    uint32_t error = get_dequantize_blockwise_nf4_tiling(blocksize, n, tilingHost);
+    if (error != 0) {
+        printf("[!] failed to compute blockwise nf4 dequantization tiling\n");
+    }
+    uint8_t *tilingDevice = nullptr;
+    aclrtMalloc((void **)&tilingDevice, tilingSize, ACL_MEM_MALLOC_NORMAL_ONLY);
+    aclrtMemcpyAsync((void *)tilingDevice, tilingSize, tilingHost, tilingSize, ACL_MEMCPY_HOST_TO_DEVICE, stream);
+    if (type_mode == 1) {
+        ACLRT_LAUNCH_KERNEL(dequantize_blockwise_fp32_nf4)(blockDim, stream, A, absmax, out, tilingDevice);
+    } else if (type_mode == 2) {
+        ACLRT_LAUNCH_KERNEL(dequantize_blockwise_fp16_nf4)(blockDim, stream, A, absmax, out, tilingDevice);
+    }
+    aclrtSynchronizeStream(stream);  // ensure the async copy and kernel finished before freeing the tiling buffers
+    aclrtFree(tilingDevice);
+    free(tilingHost);  // release the host-side tiling struct
+}
+
+}
diff --git a/csrc/npu_ops.h b/csrc/npu_ops.h
new file mode 100644
index 000000000..d7a26cd34
--- /dev/null
+++ b/csrc/npu_ops.h
@@ -0,0 +1,28 @@
+#ifndef NPU_OPS_H
+#define NPU_OPS_H
+#include
+
+#define CHECK_ACL(x)                                                                            \
+    do {                                                                                        \
+        aclError __ret = x;                                                                     \
+        if (__ret != ACL_ERROR_NONE) {                                                          \
+            std::cerr << __FILE__ << ":" << __LINE__ << " aclError:" << __ret << std::endl;     \
+        }                                                                                       \
+    } while (0);
+
+
+struct BlockwiseNf4TilingData {
+    uint32_t coreNum;
+    uint32_t blocksize;
+    uint32_t numel;
+    uint32_t singleCoreNumel;
+    uint32_t singleCoreNumelTail;
+    uint32_t ubSize;
+};
+
+extern "C" {
+
+void dequantizeBlockwiseNf4(uint8_t *A, uint8_t *absmax, uint8_t *out, uint32_t blocksize, uint32_t n, void* stream, const uint32_t type_mode);
+
+}
+#endif
diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp
index be6abc070..2d3031936 100644
--- a/csrc/pythonInterface.cpp
+++ b/csrc/pythonInterface.cpp
@@ -12,6 +12,9 @@
 #if BUILD_MPS
 // #include
 #endif
+#if BUILD_NPU
+#include
+#endif
 #include

 // We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary.
@@ -601,6 +604,14 @@ extern "C" #endif +#if BUILD_NPU + void cdequantize_blockwise_fp32_nf4(uint8_t *A, uint8_t *absmax, uint8_t *out, uint32_t blocksize, uint32_t n, void* stream) + { dequantizeBlockwiseNf4(A, absmax, out, blocksize, n, stream, 1); } + + void cdequantize_blockwise_fp16_nf4(uint8_t *A, uint8_t *absmax, uint8_t *out, uint32_t blocksize, uint32_t n, void* stream) + { dequantizeBlockwiseNf4(A, absmax, out, blocksize, n, stream, 2); } +#endif + void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); } } diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 615dfd95e..79613856f 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -210,6 +210,7 @@ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/YOUR_USERNAME/local/cuda-11.7 | **Apple Silicon (MPS)** | WIP | 3.10+ | M1/M2 chips | Planned | | **Intel CPU** | v2.5.0+ (`ipex`) | 3.10+ | Intel CPU | Alpha | | **Intel GPU** | v2.5.0+ (`ipex`) | 3.10+ | Intel GPU | Experimental | +| **Ascend NPU** | 2.1.0+ (`torch_npu`) | 3.10+ | Ascend NPU | Experimental | For each supported backend, follow the respective instructions below: @@ -251,6 +252,13 @@ Compatible hardware and functioning `import intel_extension_for_pytorch as ipex` Please refer to [the official Intel installations instructions](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=cpu&version=v2.4.0%2bcpu&os=linux%2fwsl2) for guidance on how to pip install the necessary `intel_extension_for_pytorch` dependency. + + + +Compatible hardware and functioning `import torch_npu` capable environment with Python `3.10` as the minimum requirement. + +Please refer to [the official Ascend installations instructions](https://www.hiascend.com/document/detail/zh/Pytorch/60RC3/configandinstg/instg/insg_0001.html) for guidance on how to pip install the necessary `torch_npu` dependency. + @@ -339,6 +347,31 @@ pip install -r requirements-dev.txt pip install -e . # `-e` for "editable" install, when developing BNB (otherwise leave that out) ``` + + + +#### Ascend NPU + +> [!TIP] +> Ascend NPU backend only supports building from source; for now, please follow the instructions below. + + +``` +# Install bitsandbytes from source +# Clone bitsandbytes repo, Ascend NPU backend is currently enabled on multi-backend-refactor branch +git clone -b multi-backend-refactor https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bitsandbytes/ + +# Install dependencies +pip install -r requirements-dev.txt + +# Compile & install +apt-get install -y build-essential cmake # install build tools dependencies, unless present +cmake -DCOMPUTE_BACKEND=npu -S . +make +pip install -e . # `-e` for "editable" install, when developing BNB (otherwise leave that out) +``` + +
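The NPU backend's `quantize_4bit` in `bitsandbytes/backends/npu.py` currently quantizes in plain PyTorch: compute a per-block absmax, snap each scaled value to the nearest of the 16 NF4 code points, and pack two 4-bit codes per byte. The following is a self-contained restatement of that procedure for readability, assuming the input length is a multiple of the blocksize; it is a sketch, not the library implementation.

```python
import torch

# Ascending NF4 code book, as listed in backends/npu.py and csrc/npu_kernels.cpp.
NF4_CODE = torch.tensor([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
    0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0,
])

def nf4_quantize_ref(A: torch.Tensor, blocksize: int = 128):
    """Reference blockwise NF4 quantization (assumes A.numel() % blocksize == 0)."""
    a = A.reshape(-1, blocksize).float()
    absmax = a.abs().max(dim=1, keepdim=True).values                 # one scale per block
    scaled = a / absmax                                               # now within [-1, 1]
    idx = (scaled.unsqueeze(-1) - NF4_CODE.to(A.device)).abs().argmin(dim=-1)  # nearest code, 0..15
    nibbles = ((idx + 8) % 16).reshape(-1, 2)                         # same nibble remap as the backend
    packed = (nibbles[:, 0] + nibbles[:, 1] * 16).to(torch.uint8)     # two 4-bit codes per byte
    return packed, absmax.squeeze(1)
```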
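The AscendC kernel in `csrc/npu_kernels.cpp` performs the inverse: unpack the nibbles, look each code up in the `nf4ToFloat` table, and multiply by the per-block absmax broadcast over the block. A pure-PyTorch restatement of that data flow (reusing `NF4_CODE` from the previous sketch; again only an illustration, not the kernel):

```python
def nf4_dequantize_ref(packed: torch.Tensor, absmax: torch.Tensor, blocksize: int = 128,
                       dtype: torch.dtype = torch.float16) -> torch.Tensor:
    """Reference blockwise NF4 dequantization: unpack -> table lookup -> scale by absmax."""
    low = packed % 16                                         # first code of each byte
    high = packed // 16                                       # second code of each byte
    nibbles = torch.stack([low, high], dim=1).reshape(-1)     # restore original element order
    idx = (nibbles.to(torch.int64) + 8) % 16                  # invert the nibble remap
    values = NF4_CODE.to(packed.device)[idx].reshape(-1, blocksize)
    return (values * absmax.reshape(-1, 1).float()).to(dtype).reshape(-1)
```

Round-tripping `nf4_dequantize_ref(*nf4_quantize_ref(x))` should reproduce `x` up to the usual NF4 quantization error.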
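On the host side, `get_dequantize_blockwise_nf4_tiling` in `csrc/npu_ops.cpp` splits the tensor into per-block packages, spreads them over at most 40 AI cores, and records the (possibly shorter) tail handled by the last core. The arithmetic is easy to sanity-check in Python; this mirror is illustrative only:

```python
def nf4_dequant_tiling(blocksize: int, n: int, max_cores: int = 40) -> dict:
    """Python mirror of get_dequantize_blockwise_nf4_tiling (illustrative)."""
    total_pkg = (n + blocksize - 1) // blocksize                  # number of quantization blocks
    per_core_pkg = (total_pkg + max_cores - 1) // max_cores       # blocks per core, rounded up
    core_num = (total_pkg + per_core_pkg - 1) // per_core_pkg     # cores actually used
    single_core_numel = per_core_pkg * blocksize
    tail = n % single_core_numel or single_core_numel             # last core's share
    return {
        "coreNum": core_num,
        "blocksize": blocksize,
        "numel": n,
        "singleCoreNumel": single_core_numel,
        "singleCoreNumelTail": tail,
        "ubSize": 196 * 1024,
    }

# e.g. nf4_dequant_tiling(128, 4096 * 4096)
# -> {'coreNum': 40, ..., 'singleCoreNumel': 419456, 'singleCoreNumelTail': 418432, ...}
```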
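Once built and installed, 4-bit NF4 inference on an Ascend NPU follows the usual bitsandbytes pattern: moving the layer to the `npu` device triggers quantization, per the `Params4bit.npu()`/`.to()` changes above. A rough usage sketch, assuming a working `torch_npu` environment (layer sizes are arbitrary):

```python
import torch
import torch_npu  # noqa: F401
import bitsandbytes as bnb

# Start from an ordinary fp16 layer and wrap its weight as 4-bit NF4.
dense = torch.nn.Linear(1024, 512, bias=True, dtype=torch.float16)
qlinear = bnb.nn.Linear4bit(1024, 512, bias=True, compute_dtype=torch.float16, quant_type="nf4")
qlinear.weight = bnb.nn.Params4bit(data=dense.weight.data, quant_type="nf4", requires_grad=False)
qlinear.bias = dense.bias

# Moving the module to the NPU quantizes the weight on the fly.
qlinear = qlinear.to("npu")

x = torch.randn(8, 1024, dtype=torch.float16, device="npu")
with torch.no_grad():
    y = qlinear(x)
print(y.shape)  # torch.Size([8, 512])
```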