Add xpu support for 4bit (nf4/fp4) and 8bit quantization #1257

Draft · wants to merge 20 commits into base: multi-backend-refactor
Changes from all commits
75 changes: 68 additions & 7 deletions .github/workflows/python-package.yml
@@ -58,6 +58,7 @@ jobs:
# This job matrix builds the CUDA versions of the libraries for platforms that support CUDA (Linux x64/aarch64 + Windows x64)
##
build-shared-libs-cuda:
if: github.ref_name != 'multi-backend-refactor'
strategy:
matrix:
os: [ubuntu-latest, windows-latest]
@@ -107,7 +108,7 @@ jobs:
os: [ubuntu-latest]
arch: [x86_64]
rocm_version:
["6.1.2"]
["6.1.2", "6.2"]
runs-on: ${{ matrix.os }} # One day, we could run them on native agents. Azure supports this now but it's planned only for Q3 2023 for hosted agents
steps:
- uses: actions/checkout@v4
@@ -116,10 +117,23 @@
uses: docker/setup-qemu-action@v2
- name: Clean up disk space
run: |
sudo rm -rf /usr/share/dotnet
sudo rm -rf /opt/ghc
sudo rm -rf "/usr/local/share/boost"
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
sudo rm -rf \
/usr/share/dotnet \
/opt/ghc \
"/usr/local/share/boost" \
"$AGENT_TOOLSDIRECTORY" \
/opt/hostedtoolcache \
/opt/google/chrome \
/opt/microsoft/msedge \
/opt/microsoft/powershell \
/opt/pipx \
/usr/lib/mono \
/usr/local/julia* \
/usr/local/lib/android \
/usr/local/lib/node_modules \
/usr/local/share/chromium \
/usr/local/share/powershell \
/usr/share/swift
- name: Build C++
run: bash .github/scripts/build-rocm.sh
env:
@@ -135,7 +149,7 @@
build-wheels:
needs:
- build-shared-libs
- build-shared-libs-cuda
# - build-shared-libs-cuda  # excluded to reduce package size and build times for the preview release
- build-shared-libs-rocm
strategy:
matrix:
@@ -153,6 +167,13 @@
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 1 # shallow clone
- name: Fetch tags for dynamic versioning in setup.py
run: |
git fetch --depth=1 origin --tags
echo "Available Git tags:"
git tag -n
- name: Download build artifact
uses: actions/download-artifact@v4
with:
@@ -170,7 +191,8 @@ jobs:
python-version: ${{ matrix.python-version }}
cache: pip
- run: pip install build wheel
- run: python -m build .
# For now, run the two build commands below instead of the prior `python -m build .`, which didn't allow us to access git tags
- run: python -m build --sdist && python -m build --wheel
- name: Determine and Set Platform Tag, then Tag Wheel
shell: bash
run: |
@@ -184,6 +206,45 @@
path: dist/bitsandbytes-*.whl
retention-days: 7

upload-pre-release-wheels:
name: Create release and upload artifacts
runs-on: ubuntu-latest
if: github.ref_name == 'multi-backend-refactor'
permissions:
contents: write
needs:
- build-wheels
steps:
- name: Download and rename artifacts
uses: actions/download-artifact@v4
with:
path: tmp/
pattern: "bdist_wheel_*"
merge-multiple: true
- name: Inspect tmp directory after downloading artifacts
run: ls -alFR tmp/
- name: Move and rename wheel files with pattern replacement
run: |
mkdir -p wheels/
find tmp/ -type f -name '*.whl' -print0 | while IFS= read -r -d '' wheel; do
wheel_filename=$(basename "$wheel")
# Remove the git hash, e.g. `+1234567`, for a stable download link on the multi-backend pre-release
cleaned_filename=$(echo "$wheel_filename" | sed -E 's/\+[0-9a-f]{7}-/-/g')
mv "$wheel" "wheels/$cleaned_filename"
done
- name: Inspect wheels directory after renaming files
run: ls -alFR wheels/
- name: Create release and upload artifacts
uses: softprops/action-gh-release@v2
with:
files: wheels/*.whl
prerelease: true
name: Multi-Backend Preview
tag_name: continuous-release_multi-backend-refactor
make_latest: false
draft: false
target_commitish: ${{ github.sha }}

audit-wheels:
needs: build-wheels
runs-on: ubuntu-latest
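The `sed -E 's/\+[0-9a-f]{7}-/-/g'` call in the upload job above strips the local-version segment (the `+<git-hash>` part) from each wheel filename so the pre-release download links stay stable across builds. A minimal Python sketch of the same substitution, using a hypothetical filename rather than one from an actual CI run:

import re

def clean_wheel_name(wheel_filename: str) -> str:
    # Drop the "+<7-hex-char git hash>-" segment, mirroring the sed expression
    # `s/\+[0-9a-f]{7}-/-/g` in the "Move and rename wheel files" step above.
    return re.sub(r"\+[0-9a-f]{7}-", "-", wheel_filename)

# Hypothetical example (not a real artifact name):
print(clean_wheel_name("bitsandbytes-0.44.0.dev0+1234567-py3-none-manylinux_2_24_x86_64.whl"))
# -> bitsandbytes-0.44.0.dev0-py3-none-manylinux_2_24_x86_64.whl
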
2 changes: 2 additions & 0 deletions .gitignore
@@ -151,6 +151,8 @@ dmypy.json
# vim
*.swp

# BNB-specific stuff
dependencies
cuda_build
output/
bitsandbytes/_version.py
5 changes: 5 additions & 0 deletions CMakeLists.txt
@@ -82,6 +82,11 @@ if(BUILD_CUDA)
# This needs to be added *before* we try to enable the CUDA language so CMake's compiler check passes.
if(MSVC AND MSVC_VERSION VERSION_GREATER_EQUAL 1940)
string(APPEND CMAKE_CUDA_FLAGS " --allow-unsupported-compiler")

# This is needed to build with VS2022 17.11+ and CUDA < 12.4.
if (MSVC_VERSION VERSION_GREATER_EQUAL 1941)
string(APPEND CMAKE_CUDA_FLAGS " -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH")
endif()
endif()

enable_language(CUDA) # This will fail if CUDA is not found
5 changes: 3 additions & 2 deletions bitsandbytes/__init__.py
@@ -3,6 +3,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Import the dynamically generated version from _version.py (see setup.py)
from ._version import __version__ # isort: skip # type: ignore

import torch

from . import research, utils
@@ -73,5 +76,3 @@
"optim.optimizer.Optimizer8bit": False,
"optim.optimizer.MockArgs": False,
}

__version__ = "0.43.3.dev"
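With the change above, `__version__` is no longer hard-coded in `__init__.py`; it is imported from a `bitsandbytes/_version.py` file that `setup.py` generates from git tags (which is why the workflow now fetches tags before building and why the generated file is git-ignored). The generation code itself is not part of this diff; a hypothetical sketch of that kind of tag-based versioning, with made-up helper names, could look like:

import subprocess

def derive_version(fallback: str = "0.43.3.dev0") -> str:
    # Hypothetical helper: turn `git describe` output such as
    # "0.43.2-14-g1234567" into a PEP 440 version like "0.43.2.dev14+1234567".
    try:
        described = subprocess.check_output(
            ["git", "describe", "--tags", "--long"], text=True
        ).strip()
        tag, commits_since_tag, short_hash = described.rsplit("-", 2)
        return f"{tag}.dev{commits_since_tag}+{short_hash.lstrip('g')}"
    except (subprocess.CalledProcessError, FileNotFoundError, ValueError):
        return fallback

def write_version_file(path: str = "bitsandbytes/_version.py") -> None:
    # Write the module that `from ._version import __version__` picks up.
    with open(path, "w") as f:
        f.write(f'__version__ = "{derive_version()}"\n')

A `+<hash>` local-version segment like the one produced here is exactly what the release job's sed call removes from the wheel filenames.
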
19 changes: 14 additions & 5 deletions bitsandbytes/autograd/_functions.py
@@ -221,7 +221,7 @@ def backward(ctx, grad_output):

def supports_igemmlt(device: torch.device) -> bool:
"""check if this device supports the optimized int8 kernel"""
if device == torch.device("cpu"):
if device.type in ("cpu", "xpu"):
return True
if torch.version.hip:
return False if BNB_HIP_VERSION < 601 else True
@@ -321,7 +321,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):

# Cast A to fp16
A_dtype = torch.float16
if A.device == torch.device("cpu"):
if A.device.type in ("cpu", "xpu"):
A_dtype = torch.bfloat16
if A.dtype != A_dtype:
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to {A_dtype} during quantization")
@@ -463,7 +463,9 @@ def backward(ctx, grad_output):
if len(grad_output.shape) == 3:
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = None, None, None, None, None
if req_gradB or (req_gradA and state.CBt is not None):
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
if req_gradB:
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
@@ -575,8 +577,15 @@ def matmul_4bit(
bias=None,
):
assert quant_state is not None
if (A.numel() == A.shape[-1] or A.device.type == "cpu") and A.requires_grad == False:
# CPU backend does not require A to be a vector
if A.device.type == "cpu" and A.requires_grad == False:
if getattr(quant_state, "ipex", False):
out = F.gemv_4bit(A, B, out, state=quant_state)
if bias is not None:
out += bias
return out
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state)
elif A.numel() == A.shape[-1] and A.requires_grad == False:
if A.shape[-1] % quant_state.blocksize != 0:
warn(
f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
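The new `matmul_4bit` branch above gives CPU a dedicated inference path: when the input does not require gradients, an IPEX-prepacked quant state goes through `F.gemv_4bit`, everything else falls back to the `MatMul4Bit` autograd function, and the vector-input fast path is kept for other devices. A hedged usage sketch, assuming the multi-backend preview exposes `quantize_4bit`/`matmul_4bit` on CPU tensors the same way the CUDA backend does:

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

# NF4-quantize a weight on the CPU backend (assumption: supported as on CUDA).
weight = torch.randn(4096, 4096, dtype=torch.bfloat16)
q_weight, quant_state = F.quantize_4bit(weight, quant_type="nf4", blocksize=64)

x = torch.randn(1, 4096, dtype=torch.bfloat16)
with torch.no_grad():
    # With requires_grad=False on CPU this takes the new branch: gemv_4bit if
    # the state was prepacked by IPEX (quant_state.ipex), else MatMul4Bit.
    y = bnb.matmul_4bit(x, q_weight.t(), quant_state=quant_state)
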
24 changes: 14 additions & 10 deletions bitsandbytes/backends/cpu_xpu_common.py
@@ -15,6 +15,7 @@

ipex_cpu = ipex if ipex._C._has_cpu() else None
ipex_xpu = ipex if ipex._C._has_xpu() else None
ipex_cpu_only = ipex._C._has_cpu() and (not ipex._C._has_xpu())
except BaseException:
ipex_cpu = None
ipex_xpu = None
@@ -342,7 +343,7 @@ def quantize_4bit_impl(
scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
# map [-1, 1] to nf4/fp4
out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8)
out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8, device=A.device)
if quant_type == "nf4":
for i in range(len(NF4_QUANT_TABLE)):
out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i
@@ -438,11 +439,11 @@ def dequantize_4bit_impl(
if quant_state.nested:
raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")

if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(quant_state, "op_context"):
assert quant_state.op_context is not None
A = quant_state.op_context.to_public(quant_state.op_context.get_weight())
A = A.reshape(-1)
absmax = quant_state.op_context.get_scales().reshape(-1)
if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False):
A = torch.ops.ipex_prepack.woq_linear_unpack_weight(
A, "nf4", quant_state.shape, 2
)
quant_state.ipex = False

if out is None:
out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
@@ -452,7 +453,9 @@
out_uint8 = torch.empty(A.size(0) * 2, dtype=torch.uint8, device=A.device)
out_uint8[::2] = A.bitwise_and(0xF)
out_uint8[1::2] = A.bitwise_right_shift(4)
out_dq = torch.empty(out_uint8.shape).to(quant_state.dtype)
out_dq = torch.empty(out_uint8.shape).to(quant_state.dtype).to(A.device)
# quant_state.code is fp32, cast to quant_state dtype to avoid the mismatch issue
quant_state.code = quant_state.code.to(quant_state.dtype)
for i in range(len(quant_state.code)):
out_dq[out_uint8 == i] = quant_state.code[i]

@@ -510,9 +513,10 @@ def gemm_4bit_impl(
torch.Tensor:
GEMM output tensor.
"""
if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(state, "op_context"):
assert state.op_context is not None
output = torch.ops.torch_ipex.ipex_woq_linear(A, state.op_context.get_data_handle())
if ((ipex_cpu and _ipex_cpu_version_prereq(2, 5)) or (ipex_xpu and _ipex_xpu_version_prereq(2, 5))) and getattr(state, "ipex", False):
output = torch.ops.torch_ipex.woq_linear(A, B, "nf4", state.shape,
state.new_scales, state.new_zeros, None, None, state.blocksize,
ipex_cpu.quantization.WoqLowpMode.BF16, 1, state.compensation)
else:
dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize).t()
output = torch.matmul(A, dqB.to(A.dtype))
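The non-IPEX fallback in `dequantize_4bit_impl` above unpacks two 4-bit indices from each stored byte (low nibble first) and then maps every index through the nf4/fp4 code table. A standalone sketch of that unpack-and-lookup step, using a made-up 4-entry code table instead of the real 16-entry NF4/FP4 constants:

import torch

code = torch.tensor([-1.0, -0.5, 0.5, 1.0], dtype=torch.bfloat16)  # toy code table

packed = torch.tensor([0x31, 0x02], dtype=torch.uint8)  # two 4-bit indices per byte

# Same unpacking as in the diff: low nibble first, then the high nibble.
idx = torch.empty(packed.numel() * 2, dtype=torch.uint8)
idx[::2] = packed.bitwise_and(0xF)         # -> [1, 2]
idx[1::2] = packed.bitwise_right_shift(4)  # -> [3, 0]

# The diff assigns code values in a loop over the table; a gather by index
# performs the same lookup in one shot.
out = code[idx.long()]                      # -> [-0.5, 1.0, 0.5, -1.0]

This is also where the new `device=A.device` / `.to(A.device)` arguments in the diff matter: allocating the scratch tensors on the input's device keeps the same code path working for XPU tensors, not just CPU ones.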