Merge branch 'main' into dockerfile

rickardp authored Feb 27, 2024
2 parents efcd78f + 0488566 commit 530676c
Showing 22 changed files with 364 additions and 262 deletions.
60 changes: 25 additions & 35 deletions .github/workflows/python-package.yml
@@ -15,10 +15,13 @@ on:
- 'setup.py'
- 'pyproject.toml'
- 'pytest.ini'
- '**/*.md'
release:
types: [ published ]

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true

jobs:

##
@@ -43,16 +46,10 @@ jobs:
uses: jwlawson/actions-setup-cmake@<version>
with:
cmake-version: '3.26.x'
- name: Add msbuild to PATH
uses: microsoft/setup-msbuild@<version>
if: ${{ startsWith(matrix.os, 'windows') }}
# Check out dependencies code
- uses: actions/checkout@v4
name: Check out NVidia cub
with:
repository: nvidia/cub
ref: 1.11.0
path: dependencies/cub
- name: Setup MSVC
if: startsWith(matrix.os, 'windows')
#uses: microsoft/setup-msbuild@<version> # to use msbuild
uses: ilammy/msvc-dev-cmd@<version> # to use cl
# Compile C++ code
- name: Build C++
shell: bash
@@ -62,18 +59,14 @@ jobs:
build_arch=${{ matrix.arch }}
if [ ${build_os:0:6} == ubuntu -a ${build_arch} == aarch64 ]; then
# Allow cross-compile on aarch64
sudo apt-get install -y gcc-aarch64-linux-gnu binutils-aarch64-linux-gnu
fi
if [ ${build_os:0:5} == macos -a ${build_arch} == aarch64 ]; then
sudo apt-get install -y gcc-aarch64-linux-gnu binutils-aarch64-linux-gnu g++-aarch64-linux-gnu
cmake -DCMAKE_C_COMPILER=aarch64-linux-gnu-gcc -DCMAKE_CXX_COMPILER=aarch64-linux-gnu-g++ -DCOMPUTE_BACKEND=cpu .
elif [ ${build_os:0:5} == macos -a ${build_arch} == aarch64 ]; then
cmake -DCMAKE_OSX_ARCHITECTURES=arm64 -DCOMPUTE_BACKEND=cpu .
else
cmake -DCOMPUTE_BACKEND=cpu .
fi
if [ ${build_os:0:7} == windows ]; then
pwsh -Command "msbuild bitsandbytes.vcxproj /property:Configuration=Release"
else
make
fi
cmake --build . --config Release
mkdir -p output/${{ matrix.os }}/${{ matrix.arch }}
( shopt -s nullglob && cp bitsandbytes/*.{so,dylib,dll} output/${{ matrix.os }}/${{ matrix.arch }}/ )
- name: Upload build artifact
@@ -121,37 +114,34 @@ jobs:
id: cuda-toolkit
with:
cuda: ${{ env.CUDA_VERSION }}
method: 'local'
# sub-packages: '["nvcc","cudart","nvrtc_dev","cublas_dev","cusparse_dev","visual_studio_integration"]'
- name: Add msbuild to PATH
uses: microsoft/setup-msbuild@<version>
if: ${{ startsWith(matrix.os, 'windows') }}
# Check out dependencies code
- uses: actions/checkout@v4
name: Check out NVidia cub
with:
repository: nvidia/cub
ref: 1.11.0
path: dependencies/cub
method: 'network'
sub-packages: '["nvcc","cudart","cusparse","cublas","thrust","nvrtc_dev","cublas_dev","cusparse_dev"]'
linux-local-args: '["--toolkit"]'
use-github-cache: false
- name: Setup MSVC
if: startsWith(matrix.os, 'windows')
#uses: microsoft/setup-msbuild@<version> # to use msbuild
uses: ilammy/msvc-dev-cmd@<version> # to use cl
# Compile C++ code
- name: Build C++
shell: bash
run: |
set -ex
build_os=${{ matrix.os }}
build_arch=${{ matrix.arch }}
[[ "${{ matrix.os }}" = windows-* ]] && python3 -m pip install ninja
for NO_CUBLASLT in ON OFF; do
if [ ${build_os:0:6} == ubuntu ]; then
image=nvidia/cuda:${{ env.CUDA_VERSION }}-devel-ubuntu22.04
echo "Using image $image"
docker run --platform linux/$build_arch -i -w /src -v $PWD:/src $image sh -c \
"apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
&& cmake -DCOMPUTE_BACKEND=cuda -DNO_CUBLASLT=${NO_CUBLASLT} . \
&& make"
&& cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"50;52;60;61;70;75;80;86;89;90\" -DNO_CUBLASLT=${NO_CUBLASLT} . \
&& cmake --build ."
else
cmake -DCOMPUTE_BACKEND=cuda -DNO_CUBLASLT=${NO_CUBLASLT} .
pwsh -Command "msbuild bitsandbytes.vcxproj /property:Configuration=Release"
cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DNO_CUBLASLT=${NO_CUBLASLT} -DCMAKE_BUILD_TYPE=Release -S .
cmake --build . --config Release
fi
done
mkdir -p output/${{ matrix.os }}/${{ matrix.arch }}
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -18,6 +18,6 @@ repos:
args:
- --fix=lf
- repo: https://github.com/crate-ci/typos
rev: v1.17.2
rev: v1.18.2
hooks:
- id: typos
52 changes: 42 additions & 10 deletions CMakeLists.txt
@@ -33,7 +33,7 @@ endif()

set(BNB_OUTPUT_NAME "bitsandbytes")

message(STATUS "Building with backend ${COMPUTE_BACKEND}")
message(STATUS "Configuring ${PROJECT_NAME} (Backend: ${COMPUTE_BACKEND})")

if(${COMPUTE_BACKEND} STREQUAL "cuda")
if(APPLE)
@@ -82,6 +82,31 @@ if(BUILD_CUDA)
message(FATAL_ERROR "CUDA Version > 12 is not supported")
endif()

# CMake < 3.23.0 does not define CMAKE_CUDA_ARCHITECTURES_ALL.
if(CMAKE_VERSION VERSION_LESS "3.23.0")
message(STATUS "CMake < 3.23.0; determining CUDA architectures supported...")

# 11.x and 12.x both support these at a minimum.
set(CMAKE_CUDA_ARCHITECTURES_ALL 50 52 53 60 61 62 70 72 75 80)
set(CMAKE_CUDA_ARCHITECTURES_ALL_MAJOR 50 60 70 80)

# CUDA 11.1 adds Ampere support for GA102-GA107.
if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "11.1")
list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL 86)
endif()

# CUDA 11.4 adds Ampere support for GA10B.
if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "11.4")
list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL 87)
endif()

# CUDA 11.8 adds support for Ada and Hopper.
if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "11.8")
list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL 89 90)
list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL_MAJOR 90)
endif()
endif()

string(APPEND CMAKE_CUDA_FLAGS " --use_fast_math")
if(PTXAS_VERBOSE)
# Verbose? Outputs register usage information, and other things...
@@ -103,10 +128,18 @@ if(BUILD_CUDA)
message(STATUS "CUDA Capabilities Available: ${POSSIBLE_CAPABILITIES}")
message(STATUS "CUDA Capabilities Selected: ${COMPUTE_CAPABILITY}")

foreach(capability ${COMPUTE_CAPABILITY})
string(APPEND CMAKE_CUDA_FLAGS " -gencode arch=compute_${capability},code=sm_${capability}")
endforeach()

# Use the "real" option to build native cubin for all selections.
# Ensure we build the PTX for the latest version.
# This behavior of adding a PTX (virtual) target for the highest architecture
# is similar to how the "all" and "all-major" options would behave in CMake >= 3.23.
# TODO: Consider bumping CMake requirement and using CMAKE_CUDA_ARCHITECTURES=[all | native] by default
list(REMOVE_DUPLICATES COMPUTE_CAPABILITY)
list(SORT COMPUTE_CAPABILITY COMPARE NATURAL)
list(POP_BACK COMPUTE_CAPABILITY _LATEST_CAPABILITY)
list(TRANSFORM COMPUTE_CAPABILITY APPEND "-real" OUTPUT_VARIABLE CMAKE_CUDA_ARCHITECTURES)
list(APPEND CMAKE_CUDA_ARCHITECTURES ${_LATEST_CAPABILITY})

message(STATUS "CUDA Targets: ${CMAKE_CUDA_ARCHITECTURES}")
message(STATUS "CUDA NVCC Flags: ${CMAKE_CUDA_FLAGS}")

list(APPEND SRC_FILES ${CUDA_FILES})
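
To make the new target selection concrete, here is a small Python sketch (illustrative only, not project code; the input is a hypothetical COMPUTE_CAPABILITY value). Each selected capability becomes a `-real` entry, i.e. a native cubin, while the highest capability is also kept as a plain entry so its PTX is embedded for forward compatibility:

# Sketch only: mirrors the CMake list operations above in Python.
caps = ["86", "75", "80", "75"]       # hypothetical COMPUTE_CAPABILITY value
caps = sorted(set(caps), key=int)     # REMOVE_DUPLICATES + SORT ... NATURAL
latest = caps.pop()                   # POP_BACK -> "86"
archs = [f"{c}-real" for c in caps]   # TRANSFORM ... APPEND "-real" (cubin only)
archs.append(latest)                  # plain entry -> cubin + PTX for the newest
print(";".join(archs))                # 75-real;80-real;86
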
@@ -149,7 +182,6 @@ endif()
# Weird MSVC hacks
if(MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2 /fp:fast")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2 /fp:fast")
endif()

set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX)
@@ -182,10 +214,10 @@ if(WIN32)
endif()
set_target_properties(bitsandbytes PROPERTIES OUTPUT_NAME ${BNB_OUTPUT_NAME})
if(MSVC)
set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY_RELEASE bitsandbytes)
set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY_DEBUG bitsandbytes)
set_target_properties(bitsandbytes PROPERTIES RUNTIME_OUTPUT_DIRECTORY_RELEASE bitsandbytes)
set_target_properties(bitsandbytes PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG bitsandbytes)
set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY_RELEASE "${PROJECT_SOURCE_DIR}/bitsandbytes")
set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY_DEBUG "${PROJECT_SOURCE_DIR}/bitsandbytes")
set_target_properties(bitsandbytes PROPERTIES RUNTIME_OUTPUT_DIRECTORY_RELEASE "${PROJECT_SOURCE_DIR}/bitsandbytes")
set_target_properties(bitsandbytes PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG "${PROJECT_SOURCE_DIR}/bitsandbytes")
endif()

set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY bitsandbytes)
6 changes: 3 additions & 3 deletions bitsandbytes/cuda_setup/main.py
@@ -30,7 +30,7 @@

DYNAMIC_LIBRARY_SUFFIX = { "Darwin": ".dylib", "Windows": ".dll", "Linux": ".so"}.get(platform.system(), ".so")
if platform.system() == "Windows": # Windows
CUDA_RUNTIME_LIBS = ["nvcuda.dll"]
CUDA_RUNTIME_LIBS = ["cudart64_110.dll", "cudart64_12.dll"]
else: # Linux or other
# these are the most common libs names
# libcudart.so is missing by default for a conda install with PyTorch 2.0 and instead
@@ -161,7 +161,7 @@ def run_cuda_setup(self):
self.add_log_entry('3. CUDA not installed')
self.add_log_entry('4. You have multiple conflicting CUDA libraries')
self.add_log_entry('5. Required library not pre-compiled for this bitsandbytes release!')
self.add_log_entry('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.')
self.add_log_entry('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=118`.')
self.add_log_entry('CUDA SETUP: The CUDA version for the compile might depend on your conda install. Inspect CUDA version via `conda list | grep cuda`.')
self.add_log_entry('='*80)
self.add_log_entry('')
@@ -268,7 +268,7 @@ def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None:
"BNB_CUDA_VERSION=122 python ..."
"OR set the environmental variable in your .bashrc: export BNB_CUDA_VERSION=122"
"In the case of a manual override, make sure you set the LD_LIBRARY_PATH, e.g."
"export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2")
"export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.2")
CUDASetup.get_instance().add_log_entry(warning_msg, is_warning=True)


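For context on the renamed Windows candidates: `nvcuda.dll` is the driver-API library, whereas `cudart64_110.dll` and `cudart64_12.dll` are the CUDA 11.x and 12.x runtime DLLs the compiled binary actually needs. A minimal probing sketch (illustrative only, not the library's actual lookup code):

# Illustrative sketch: check which CUDA runtime DLL can be loaded on Windows.
import ctypes
import platform

if platform.system() == "Windows":
    for name in ["cudart64_110.dll", "cudart64_12.dll"]:
        try:
            ctypes.CDLL(name)  # raises OSError if the DLL is not found
        except OSError:
            continue
        print(f"found CUDA runtime: {name}")
        break
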
15 changes: 15 additions & 0 deletions bitsandbytes/functional.py
@@ -706,6 +706,21 @@ def to(self, device):
self.state2.absmax = self.state2.absmax.to(device)
self.state2.code = self.state2.code.to(device)

def __eq__(self, other):
if not isinstance(other, QuantState):
return False

return (
torch.allclose(self.absmax, other.absmax, atol=1e-6) and
self.shape == other.shape and
torch.allclose(self.code, other.code, atol=1e-6) and
self.dtype == other.dtype and
self.blocksize == other.blocksize and
self.quant_type == other.quant_type and
(self.offset == other.offset if self.offset is not None and other.offset is not None else self.offset is other.offset) and
(self.state2 == other.state2 if self.state2 is not None and other.state2 is not None else self.state2 is other.state2)
)


def quantize_blockwise(
A: Tensor,
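A minimal sketch of the new `QuantState.__eq__` in use (assumes a CUDA device; `quantize_4bit` returns a packed tensor plus its `QuantState`, as seen in the `modules.py` changes below):

import torch
import bitsandbytes as bnb

w = torch.randn(64, 64, dtype=torch.float16, device="cuda")
_, s1 = bnb.functional.quantize_4bit(w)
_, s2 = bnb.functional.quantize_4bit(w.clone())

assert s1 == s2             # same input -> matching absmax/code/shape/dtype
assert s1 != "not a state"  # comparing against a non-QuantState is simply False
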
45 changes: 41 additions & 4 deletions bitsandbytes/nn/modules.py
@@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import copy
from typing import Any, Dict, Optional, TypeVar, Union, overload
import warnings

@@ -191,7 +192,7 @@ class Params4bit(torch.nn.Parameter):
def __new__(
cls,
data: Optional[torch.Tensor] = None,
requires_grad=True,
requires_grad=False, # quantized weights should be frozen by default
quant_state: Optional[QuantState] = None,
blocksize: int = 64,
compress_statistics: bool = True,
@@ -214,6 +215,37 @@ def __new__(
self.module = module
return self

def __getstate__(self):
state = self.__dict__
state["data"] = self.data
state["requires_grad"] = self.requires_grad
return state

def __setstate__(self, state):
self.requires_grad = state["requires_grad"]
self.blocksize = state["blocksize"]
self.compress_statistics = state["compress_statistics"]
self.quant_type = state["quant_type"]
self.quant_state = state["quant_state"]
self.data = state["data"]
self.quant_storage = state["quant_storage"]
self.bnb_quantized = state["bnb_quantized"]
self.module = state["module"]

def __deepcopy__(self, memo):
new_instance = type(self).__new__(type(self))
state = self.__getstate__()
new_instance.__setstate__(state)
new_instance.quant_state = copy.deepcopy(state["quant_state"])
new_instance.data = copy.deepcopy(state["data"])
return new_instance

def __copy__(self):
new_instance = type(self).__new__(type(self))
state = self.__getstate__()
new_instance.__setstate__(state)
return new_instance

@classmethod
def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit":
self = torch.Tensor._make_subclass(cls, data.to(device))
@@ -227,8 +259,13 @@ def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any],

def _quantize(self, device):
w = self.data.contiguous().cuda(device)
w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics,
quant_type=self.quant_type, quant_storage=self.quant_storage)
w_4bit, quant_state = bnb.functional.quantize_4bit(
w,
blocksize=self.blocksize,
compress_statistics=self.compress_statistics,
quant_type=self.quant_type,
quant_storage=self.quant_storage,
)
self.data = w_4bit
self.quant_state = quant_state
if self.module is not None:
@@ -275,7 +312,7 @@ class Linear4bit(nn.Linear):
compute datatypes such as FP4 and NF4.
In order to quantize a linear layer one should first load the original fp16 / bf16 weights into
the Linear8bitLt module, then call `quantized_module.to("cuda")` to quantize the fp16 / bf16 weights.
the Linear4bit module, then call `quantized_module.to("cuda")` to quantize the fp16 / bf16 weights.
Example:
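The `__getstate__`/`__setstate__`/`__deepcopy__`/`__copy__` additions make quantized parameters picklable and copyable. A hedged usage sketch (assumes a CUDA device; `Linear4bit` quantizes its weight when moved to the GPU):

import copy

import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear4bit(64, 64, compute_dtype=torch.float16).cuda()  # quantizes weight

clone = copy.deepcopy(layer.weight)                   # previously raised on Params4bit
assert clone.quant_state == layer.weight.quant_state  # uses the new QuantState.__eq__
assert clone.requires_grad is False                   # frozen by default per this commit
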
4 changes: 3 additions & 1 deletion csrc/kernels.cu
@@ -654,6 +654,8 @@ __global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const f
for(int j = threadIdx.x; j < BLOCK_ESTIMATE; j+=blockDim.x)
temp_storage.smem_qidx[j] = -1;

__syncthreads();
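// (The barrier above guarantees smem_qidx is fully initialized by all
// threads before the quantile-index computation below reads and writes it.)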

if(threadIdx.x < 256)
{
float q_interval = (1.0f-(2.0f*offset))/255.0f;
@@ -3073,7 +3075,7 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
//// 4. do dequantization from register of B into second pair of registers
//// 5. store (4) into fragment
//// 6. matmul aggregate into fragment C
//// 7. aggreecate files of C into shared memory block C
//// 7. aggregate files of C into shared memory block C
//// 8. sum (7)
//// 9. write outputs to matmul output matrix
//}
(15 more changed files not shown.)