Add Ascend NPU support for nf4 quant #1422

Merged
49 changes: 45 additions & 4 deletions CMakeLists.txt
@@ -3,7 +3,7 @@
# For GCC: `cmake -B build . && cmake --build build`
# For MSVC: `cmake -B build . && cmake --build build --config Release`
# You can also use the following options and variables
# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, `hip` or `mps` to select the backend
# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, `hip`, `mps` or `npu` to select the backend
# - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support
# - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version
# is whatever CMake finds on your path.
@@ -29,11 +29,12 @@ set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
set(MPS_FILES csrc/mps_ops.mm)
set(METAL_FILES csrc/mps_kernels.metal)
set(NPU_FILES csrc/npu_ops.cpp)
# C++ sources are always included
list(APPEND SRC_FILES ${CPP_FILES})

set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)")
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps)
set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps, npu)")
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps npu)
option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF)

if(APPLE)
@@ -69,6 +70,11 @@ elseif(${COMPUTE_BACKEND} STREQUAL "mps")
set(BUILD_CUDA OFF)
set(BUILD_HIP OFF)
set(BUILD_MPS ON)
elseif(${COMPUTE_BACKEND} STREQUAL "npu")
set(BUILD_CUDA OFF)
set(BUILD_HIP OFF)
set(BUILD_MPS OFF)
set(BUILD_NPU ON)
else()
set(BUILD_CUDA OFF)
set(BUILD_HIP OFF)
@@ -232,6 +238,33 @@ elseif(BUILD_MPS)
COMMENT "Compiling Metal kernels"
VERBATIM)
add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib")
elseif(BUILD_NPU)
list(APPEND SRC_FILES ${NPU_FILES})

set(SOC_VERSION "Ascend910B4" CACHE STRING "system on chip type")
set(ASCEND_CANN_PACKAGE_PATH $ENV{ASCEND_HOME_PATH} CACHE
STRING "ASCEND CAN package installation directory"
)

# ${KERNEL_FILES} are used to compile the library; put files written in AscendC into ${KERNEL_FILES}.
# refer to ascendc_library in cmake/npu.cmake and add_library in cmake/cpu.cmake
# file(GLOB KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/csrc/npu_kernels.cpp)
file(GLOB KERNEL_FILES csrc/npu_kernels.cpp)

if(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
elseif(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/tools/tikcpp/ascendc_kernel_cmake)
set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/tools/tikcpp/ascendc_kernel_cmake)
else()
message(FATAL_ERROR "ascendc_kernel_cmake does not exist, please check whether the CANN package is installed")
endif()
include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)

# ascendc_library is used to add the kernel files and generate the AscendC library
ascendc_library(ascendc_kernels_npu STATIC ${KERNEL_FILES})

string(APPEND BNB_OUTPUT_NAME "_npu")
add_compile_definitions(BUILD_NPU)
else()
string(APPEND BNB_OUTPUT_NAME "_cpu")
set(GPU_SOURCES)
@@ -249,7 +282,11 @@ endif()

set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX)
add_library(bitsandbytes SHARED ${SRC_FILES})
target_compile_features(bitsandbytes PUBLIC cxx_std_14)
if(BUILD_NPU)
target_compile_features(bitsandbytes PUBLIC cxx_std_17)
else()
target_compile_features(bitsandbytes PUBLIC cxx_std_14)
endif()
target_include_directories(bitsandbytes PUBLIC csrc include)


@@ -306,6 +343,10 @@ if(BUILD_MPS)
add_dependencies(bitsandbytes metallib)
target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
endif()
if(BUILD_NPU)
target_compile_options(bitsandbytes PRIVATE -O2 -std=c++17)
target_link_libraries(bitsandbytes PRIVATE $<BUILD_INTERFACE:host_intf_pub> ascendc_kernels_npu)
endif()

if(WIN32)
set_target_properties(bitsandbytes PROPERTIES PREFIX "lib")
3 changes: 3 additions & 0 deletions _typos.toml
@@ -3,12 +3,15 @@
[default]
extend-ignore-re = [
"@Ther-nul", # valid Github user
"CANN", # CANN (Compute Architecture for Neural Networks) is a heterogeneous computing architecture for Ascend NPU
]

[default.extend-identifiers]

[type.py.extend-words]
"BA" = "BA" # used as a commented-out variable in tests
"cann" = "cann" # cann (Compute Architecture for Neural Networks) is a heterogeneous computing architecture for Ascend NPU


[type.cuda.extend-words]
"subtile" = "subtile"
1 change: 1 addition & 0 deletions bitsandbytes/__init__.py
@@ -25,6 +25,7 @@
features = {"multi_backend"}
supported_torch_devices = {
"cuda", # includes ROCm
"npu", # Ascend NPU
"xpu", # Intel GPU
"cpu",
}
14 changes: 11 additions & 3 deletions bitsandbytes/autograd/_functions.py
@@ -519,7 +519,12 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]

# 1. Dequantize
# 2. MatmulnN
output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
if A.device.type == "npu":
output = torch.matmul(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t())
if bias is not None:
output += bias
else:
output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)

# 3. Save state
ctx.state = quant_state
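
For context, a minimal sketch of what the NPU branch above computes: `torch.nn.functional.linear` is bypassed, so the dequantized weight is transposed explicitly and the bias is added separately. The helper name below is illustrative, not part of the PR.

```python
import torch
import bitsandbytes.functional as F

# Illustrative only: mirrors the NPU forward branch above.
# B is the packed 4-bit weight, quant_state its QuantState.
def npu_4bit_forward(A, B, quant_state, bias=None):
    W = F.dequantize_4bit(B, quant_state).to(A.dtype).t()  # dequantize, match dtype, transpose
    out = torch.matmul(A, W)                                # plain matmul instead of F.linear
    if bias is not None:
        out = out + bias
    return out
```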
@@ -550,7 +555,10 @@ def backward(ctx, grad_output):
# not supported by PyTorch. TODO: create work-around
# if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
if req_gradA:
grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
if grad_output.device.type == "npu":
grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype))
else:
grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())

return grad_A, grad_B, None, grad_bias, None

@@ -586,7 +594,7 @@ def matmul_4bit(
return out
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state)
elif A.numel() == A.shape[-1] and A.requires_grad == False:
elif A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "npu":
if A.shape[-1] % quant_state.blocksize != 0:
warn(
f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
2 changes: 1 addition & 1 deletion bitsandbytes/backends/cpu_xpu_common.py
@@ -23,7 +23,7 @@

gxx_available = False
try:
subprocess.run(["g++", "--version"])
subprocess.run(["g++", "--version"], capture_output=True) # hide terminal output
gxx_available = True
except BaseException:
warnings.warn("g++ not found, torch.compile disabled for CPU/XPU.")
152 changes: 142 additions & 10 deletions bitsandbytes/backends/npu.py
@@ -1,17 +1,32 @@
import ctypes as ct
from typing import Literal, Optional, Tuple, Union

import torch

from bitsandbytes.utils import QuantState

from .base import Backend

try:
# to support Ascend NPU backend
import torch_npu # noqa: F401
except ImportError:
pass

from bitsandbytes.cextension import lib
from bitsandbytes.functional import (
get_4bit_type,
get_ptr,
)
from bitsandbytes.utils import QuantState

from .base import Backend


def assert_on_npu(tensors):
if not all(t.device.type == "npu" for t in tensors if t is not None):
raise TypeError(
"All input tensors to be on NPU, but found some tensors not be on NPU:\n"
f"{[(t.shape, t.device) if isinstance(t, torch.Tensor) else None for t in tensors]}"
)
return True


class NPUBackend(Backend):
def double_quant(
@@ -75,23 +90,140 @@ def quantize_4bit(
A: torch.Tensor,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize=64,
blocksize: Optional[int] = None,
compress_statistics=False,
quant_type: Literal["fp4", "nf4"] = "fp4",
quant_type: Literal["fp4", "nf4"] = "nf4",
quant_storage=torch.uint8,
) -> Tuple[torch.Tensor, QuantState]:
raise NotImplementedError
if quant_type not in ["nf4"]:
raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")
if compress_statistics:
raise NotImplementedError("compress_statistics is not implemented.")
if blocksize is None:
blocksize = 128

prev_device = torch.npu.current_device()
torch.npu.set_device(A.device)
if A.dtype in [torch.float32, torch.float16, torch.bfloat16]:
data = [
-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0,
]
data = torch.tensor(data, device="npu", dtype=torch.float32).view(1, -1)
absmax = A.view(-1, blocksize).abs().max(dim=1, keepdim=True).values
a = A.view(-1, blocksize) / absmax.float()
diff = torch.abs(a.unsqueeze(-1) - data)
out = (torch.argmin(diff, dim=-1) + 8) % 16
out = out.reshape(-1, 2)
out = (out[:, 0] + out[:, 1] * 16).to(torch.uint8)
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
assert_on_npu([A, absmax, out])
torch.npu.set_device(prev_device)

code = get_4bit_type(quant_type, device=A.device)
state = QuantState(
absmax=absmax,
shape=A.shape,
dtype=A.dtype,
blocksize=blocksize,
code=code,
quant_type=quant_type,
)

return out, state
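
To illustrate the packing used above, here is a small CPU-only sketch of the same NF4 lookup: each value is scaled by its block's absmax, snapped to the nearest of the 16 NF4 codes, index-shifted by `(i + 8) % 16`, and two 4-bit codes are packed per byte. The blocksize of 4 and the helper name are chosen only for illustration; the backend uses blocksize 128 on the NPU.

```python
import torch

# CPU-only illustration of the NF4 packing performed by quantize_4bit above.
NF4 = torch.tensor([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
])

def nf4_pack(A: torch.Tensor, blocksize: int = 4):
    a = A.reshape(-1, blocksize)
    absmax = a.abs().max(dim=1, keepdim=True).values            # per-block scale
    idx = (a / absmax).unsqueeze(-1).sub(NF4).abs().argmin(-1)  # nearest NF4 code
    idx = (idx + 8) % 16                                        # same index shift as above
    pairs = idx.reshape(-1, 2)
    packed = (pairs[:, 0] + pairs[:, 1] * 16).to(torch.uint8)   # two 4-bit codes per byte
    return packed, absmax

A = torch.tensor([0.10, -0.40, 0.90, 0.05, -0.20, 0.30, -0.75, 0.60])
packed, absmax = nf4_pack(A)
print(packed.shape, absmax.flatten())  # 8 floats -> 4 bytes, one scale per block
```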

def dequantize_4bit(
self,
A: torch.Tensor,
quant_state: Optional[QuantState] = None,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize: int = 64,
quant_type: Literal["fp4", "nf4"] = "fp4",
blocksize: Optional[int] = None,
quant_type: Literal["fp4", "nf4"] = "nf4",
) -> torch.Tensor:
raise NotImplementedError
if blocksize is None:
blocksize = 128
supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64]
if blocksize not in supported_blocksizes:
raise ValueError(
f"The blockwise of {blocksize} is not supported. Supported values: {supported_blocksizes}"
)

if quant_state is None:
assert absmax is not None and out is not None
quant_state = QuantState(
absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type
)
else:
absmax = quant_state.absmax

if out is None:
out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)

n = out.numel()

prev_device = torch.npu.current_device()
torch.npu.set_device(A.device)
assert_on_npu([A, absmax, out])

if quant_state.quant_type not in ["nf4"]:
raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")

if out.dtype == torch.float32:
lib.cdequantize_blockwise_fp32_nf4(
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
torch.npu.current_stream(),
)
elif out.dtype == torch.float16:
lib.cdequantize_blockwise_fp16_nf4(
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
torch.npu.current_stream(),
)
elif out.dtype == torch.bfloat16:
# bf16: bf16 -> fp32 -> op -> fp32 -> bf16
absmax = absmax.to(torch.float32)
out = out.to(torch.float32)
lib.cdequantize_blockwise_fp32_nf4(
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
torch.npu.current_stream(),
)
out = out.to(torch.bfloat16)
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
torch.npu.set_device(prev_device)
is_transposed = True if A.shape[0] == 1 else False

if is_transposed:
return out.t()
else:
return out
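
A hypothetical round trip through the two methods above (requires torch_npu and a visible Ascend device; the high-level `bitsandbytes.functional` wrappers dispatch to this backend by device type):

```python
import torch
import bitsandbytes.functional as F

# Assumes torch_npu is installed and an Ascend NPU is visible.
W = torch.randn(4096, 4096, dtype=torch.float16, device="npu")
packed, state = F.quantize_4bit(W, blocksize=128, quant_type="nf4")  # uint8, two values per byte
W_hat = F.dequantize_4bit(packed, state)                             # back to fp16, W.shape
print(packed.dtype, W_hat.shape, (W - W_hat).abs().mean())
```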

def gemv_4bit(
self,
5 changes: 5 additions & 0 deletions bitsandbytes/cextension.py
@@ -25,6 +25,7 @@

from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR
from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_rocm_gpu_arch
from bitsandbytes.npu_specs import get_npu_specs

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -100,6 +101,10 @@ def get_native_library() -> BNBNativeLibrary:
binary_path = cuda_binary_path
else:
logger.warning("Could not find the bitsandbytes %s binary at %r", BNB_BACKEND, cuda_binary_path)
npu_specs = get_npu_specs()
if npu_specs:
binary_path = PACKAGE_DIR / f"libbitsandbytes_npu{DYNAMIC_LIBRARY_SUFFIX}"

logger.debug(f"Loading bitsandbytes native library from: {binary_path}")
dll = ct.cdll.LoadLibrary(str(binary_path))

10 changes: 8 additions & 2 deletions bitsandbytes/nn/modules.py
@@ -314,6 +314,12 @@ def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: b
def cpu(self, non_blocking: bool = False):
return self.to(device="cpu", non_blocking=non_blocking)

def npu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
# `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
if isinstance(device, int):
device = f"npu:{device}"
return self.to(device="npu" if device is None else device, non_blocking=non_blocking)

def xpu(self, non_blocking: bool = False):
return self.to(device="xpu", non_blocking=non_blocking)

@@ -334,7 +340,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

if device is not None and device.type in ["cuda", "cpu", "xpu"] and not self.bnb_quantized:
if device is not None and device.type in ["cuda", "cpu", "npu", "xpu"] and not self.bnb_quantized:
return self._quantize(device)
else:
if self.quant_state is not None:
@@ -497,7 +503,7 @@ def forward(self, x: torch.Tensor):
self.weight.quant_state = self.quant_state
else:
print(
"FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
"FP4 quantization state not initialized. Please call .cuda(), .npu() or .to(device) on the LinearFP4 layer first.",
)
if not self.compute_type_is_set:
self.set_compute_type(x)
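
End-to-end usage sketch, assuming torch_npu is installed and an Ascend NPU is visible: moving the layer to `npu` triggers nf4 quantization just like `.cuda()` does. `compress_statistics=False` is passed because the NPU backend above does not implement double quantization of the statistics.

```python
import torch
import bitsandbytes as bnb

# Assumes torch_npu is installed and an Ascend NPU is available.
layer = bnb.nn.Linear4bit(
    1024, 4096, bias=True,
    compute_dtype=torch.float16,
    compress_statistics=False,   # not implemented by the NPU backend above
    quant_type="nf4",
).to("npu")                      # .npu() also works once torch_npu patches nn.Module

x = torch.randn(8, 1024, dtype=torch.float16, device="npu")
y = layer(x)                     # dequantize + matmul on the NPU
```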