Merge branch 'TimDettmers:main' into setup-pep660
matthewdouglas authored Feb 28, 2024
2 parents 636c62d + f9eba9c commit f9c2cc5
Showing 10 changed files with 132 additions and 68 deletions.
10 changes: 7 additions & 3 deletions .github/workflows/python-package.yml
@@ -15,10 +15,13 @@ on:
- 'setup.py'
- 'pyproject.toml'
- 'pytest.ini'
- '**/*.md'
release:
types: [ published ]

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
  cancel-in-progress: true

jobs:

##
@@ -53,7 +56,8 @@ jobs:
build_os=${{ matrix.os }}
build_arch=${{ matrix.arch }}
if [ ${build_os:0:6} == ubuntu -a ${build_arch} == aarch64 ]; then
# Allow cross-compile om aarch64
# Allow cross-compile on aarch64
sudo apt-get update
sudo apt-get install -y gcc-aarch64-linux-gnu binutils-aarch64-linux-gnu g++-aarch64-linux-gnu
cmake -DCMAKE_C_COMPILER=aarch64-linux-gnu-gcc -DCMAKE_CXX_COMPILER=aarch64-linux-gnu-g++ -DCOMPUTE_BACKEND=cpu .
elif [ ${build_os:0:5} == macos -a ${build_arch} == aarch64 ]; then
@@ -125,7 +129,7 @@ jobs:
docker run --platform linux/$build_arch -i -w /src -v $PWD:/src $image sh -c \
"apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
&& cmake -DCOMPUTE_BACKEND=cuda -DNO_CUBLASLT=${NO_CUBLASLT} . \
&& cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"50;52;60;61;70;75;80;86;89;90\" -DNO_CUBLASLT=${NO_CUBLASLT} . \
&& cmake --build ."
else
cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DNO_CUBLASLT=${NO_CUBLASLT} -DCMAKE_BUILD_TYPE=Release -S .
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -18,6 +18,6 @@ repos:
args:
- --fix=lf
- repo: https://github.com/crate-ci/typos
rev: v1.17.2
rev: v1.18.2
hooks:
- id: typos
52 changes: 42 additions & 10 deletions CMakeLists.txt
@@ -33,7 +33,7 @@ endif()

set(BNB_OUTPUT_NAME "bitsandbytes")

message(STATUS "Building with backend ${COMPUTE_BACKEND}")
message(STATUS "Configuring ${PROJECT_NAME} (Backend: ${COMPUTE_BACKEND})")

if(${COMPUTE_BACKEND} STREQUAL "cuda")
if(APPLE)
@@ -82,6 +82,31 @@ if(BUILD_CUDA)
message(FATAL_ERROR "CUDA Version > 12 is not supported")
endif()

# CMake < 3.23.0 does not define CMAKE_CUDA_ARCHITECTURES_ALL.
if(CMAKE_VERSION VERSION_LESS "3.23.0")
message(STATUS "CMake < 3.23.0; determining CUDA architectures supported...")

# 11.x and 12.x both support these at a minimum.
set(CMAKE_CUDA_ARCHITECTURES_ALL 50 52 53 60 61 62 70 72 75 80)
set(CMAKE_CUDA_ARCHITECTURES_ALL_MAJOR 50 60 70 80)

# CUDA 11.1 adds Ampere support for GA102-GA107.
if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "11.1")
list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL 86)
endif()

# CUDA 11.4 adds Ampere support for GA10B.
if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "11.4")
list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL 87)
endif()

# CUDA 11.8 adds support for Ada and Hopper.
if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "11.8")
list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL 89 90)
list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL_MAJOR 90)
endif()
endif()
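To make the version gating above concrete, here is a minimal Python sketch (illustrative only, not part of the commit) of the same capability table, keyed on the nvcc version checks described in the CMake comments:

```py
def supported_cuda_architectures(nvcc_version: tuple) -> list:
    """Mirror the CMake fallback table above (CUDA 11.0+ baseline)."""
    archs = [50, 52, 53, 60, 61, 62, 70, 72, 75, 80]
    if nvcc_version >= (11, 1):
        archs.append(86)        # Ampere GA102-GA107
    if nvcc_version >= (11, 4):
        archs.append(87)        # Ampere GA10B
    if nvcc_version >= (11, 8):
        archs.extend([89, 90])  # Ada, Hopper
    return archs

print(supported_cuda_architectures((11, 8)))
# [50, 52, 53, 60, 61, 62, 70, 72, 75, 80, 86, 87, 89, 90]
```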

string(APPEND CMAKE_CUDA_FLAGS " --use_fast_math")
if(PTXAS_VERBOSE)
# Verbose? Outputs register usage information, and other things...
@@ -103,10 +128,18 @@ if(BUILD_CUDA)
message(STATUS "CUDA Capabilities Available: ${POSSIBLE_CAPABILITIES}")
message(STATUS "CUDA Capabilities Selected: ${COMPUTE_CAPABILITY}")

foreach(capability ${COMPUTE_CAPABILITY})
string(APPEND CMAKE_CUDA_FLAGS " -gencode arch=compute_${capability},code=sm_${capability}")
endforeach()

# Use the "real" option to build native cubin for all selections.
# Ensure we build the PTX for the latest version.
# This behavior of adding a PTX (virtual) target for the highest architecture
# is similar to how the "all" and "all-major" options would behave in CMake >= 3.23.
# TODO: Consider bumping CMake requirement and using CMAKE_CUDA_ARCHITECTURES=[all | native] by default
list(REMOVE_DUPLICATES COMPUTE_CAPABILITY)
list(SORT COMPUTE_CAPABILITY COMPARE NATURAL)
list(POP_BACK COMPUTE_CAPABILITY _LATEST_CAPABILITY)
list(TRANSFORM COMPUTE_CAPABILITY APPEND "-real" OUTPUT_VARIABLE CMAKE_CUDA_ARCHITECTURES)
list(APPEND CMAKE_CUDA_ARCHITECTURES ${_LATEST_CAPABILITY})

message(STATUS "CUDA Targets: ${CMAKE_CUDA_ARCHITECTURES}")
message(STATUS "CUDA NVCC Flags: ${CMAKE_CUDA_FLAGS}")

list(APPEND SRC_FILES ${CUDA_FILES})
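For readers less familiar with CMake's architecture suffixes, here is a hedged Python sketch (not part of the commit) of the list transformation above: each selected capability gets a `-real` suffix so only a native cubin is built for it, while the highest capability is left bare so PTX is also embedded for forward compatibility.

```py
def cuda_targets(selected: list) -> list:
    """Mirror the REMOVE_DUPLICATES / SORT / POP_BACK / TRANSFORM sequence above."""
    caps = sorted(set(selected), key=int)   # dedupe + natural sort
    latest = caps.pop()                     # POP_BACK into _LATEST_CAPABILITY
    targets = [f"{c}-real" for c in caps]   # cubin-only targets
    targets.append(latest)                  # bare entry: cubin + PTX
    return targets

print(cuda_targets(["50", "52", "75", "86", "90"]))
# ['50-real', '52-real', '75-real', '86-real', '90']
```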
@@ -149,7 +182,6 @@ endif()
# Weird MSVC hacks
if(MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2 /fp:fast")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2 /fp:fast")
endif()

set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX)
@@ -182,10 +214,10 @@ if(WIN32)
endif()
set_target_properties(bitsandbytes PROPERTIES OUTPUT_NAME ${BNB_OUTPUT_NAME})
if(MSVC)
set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY_RELEASE bitsandbytes)
set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY_DEBUG bitsandbytes)
set_target_properties(bitsandbytes PROPERTIES RUNTIME_OUTPUT_DIRECTORY_RELEASE bitsandbytes)
set_target_properties(bitsandbytes PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG bitsandbytes)
set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY_RELEASE "${PROJECT_SOURCE_DIR}/bitsandbytes")
set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY_DEBUG "${PROJECT_SOURCE_DIR}/bitsandbytes")
set_target_properties(bitsandbytes PROPERTIES RUNTIME_OUTPUT_DIRECTORY_RELEASE "${PROJECT_SOURCE_DIR}/bitsandbytes")
set_target_properties(bitsandbytes PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG "${PROJECT_SOURCE_DIR}/bitsandbytes")
endif()

set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY bitsandbytes)
6 changes: 3 additions & 3 deletions bitsandbytes/cuda_setup/main.py
@@ -30,7 +30,7 @@

DYNAMIC_LIBRARY_SUFFIX = { "Darwin": ".dylib", "Windows": ".dll", "Linux": ".so"}.get(platform.system(), ".so")
if platform.system() == "Windows": # Windows
CUDA_RUNTIME_LIBS = ["nvcuda.dll"]
CUDA_RUNTIME_LIBS = ["cudart64_110.dll", "cudart64_12.dll"]
else: # Linux or other
# these are the most common libs names
# libcudart.so is missing by default for a conda install with PyTorch 2.0 and instead
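As a hedged illustration of how such a list of candidate names is typically consumed (a hypothetical helper, not this file's actual search logic, which also scans environment paths):

```py
import ctypes
import platform

CUDA_RUNTIME_LIBS = (
    ["cudart64_110.dll", "cudart64_12.dll"]
    if platform.system() == "Windows"
    else ["libcudart.so", "libcudart.so.11.0", "libcudart.so.12"]
)

def load_cuda_runtime():
    # Try each known CUDA runtime name until one loads.
    for name in CUDA_RUNTIME_LIBS:
        try:
            return ctypes.CDLL(name)
        except OSError:
            continue
    return None  # caller falls back to its error/diagnostic path
```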
@@ -161,7 +161,7 @@ def run_cuda_setup(self):
self.add_log_entry('3. CUDA not installed')
self.add_log_entry('4. You have multiple conflicting CUDA libraries')
self.add_log_entry('5. Required library not pre-compiled for this bitsandbytes release!')
self.add_log_entry('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.')
self.add_log_entry('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=118`.')
self.add_log_entry('CUDA SETUP: The CUDA version for the compile might depend on your conda install. Inspect CUDA version via `conda list | grep cuda`.')
self.add_log_entry('='*80)
self.add_log_entry('')
@@ -268,7 +268,7 @@ def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None:
"BNB_CUDA_VERSION=122 python ..."
"OR set the environmental variable in your .bashrc: export BNB_CUDA_VERSION=122"
"In the case of a manual override, make sure you set the LD_LIBRARY_PATH, e.g."
"export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2")
"export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.2")
CUDASetup.get_instance().add_log_entry(warning_msg, is_warning=True)


2 changes: 1 addition & 1 deletion csrc/kernels.cu
@@ -3075,7 +3075,7 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
//// 4. do dequantization from register of B into second pair of registers
//// 5. store (4) into fragment
//// 6. matmul aggregate into fragment C
//// 7. aggreecate files of C into shared memory block C
//// 7. aggregate files of C into shared memory block C
//// 8. sum (7)
//// 9. write outputs to matmul output matrix
//}
10 changes: 6 additions & 4 deletions docs/source/installation.mdx
@@ -1,6 +1,6 @@
# Installation

bitsandbytes is only supported on CUDA GPUs for CUDA versions **10.2 - 12.0**. Select your operating system below to see the installation instructions.
bitsandbytes is only supported on CUDA GPUs for CUDA versions **11.0 - 12.3**. Select your operating system below to see the installation instructions.

<hfoptions id="OS system">
<hfoption id="Linux">
@@ -21,7 +21,9 @@ To install from PyPI.
pip install bitsandbytes
```

To compile from source, you need CMake >= **3.22.1** and Python >= **3.10** installed. Make sure you have a compiler installed to compile C++ (gcc, make, headers, etc.). For example, to install a compiler and CMake on Ubuntu:
## Alternative: Compiling from source

To compile from source, you need CMake >= **3.22.1** and Python >= **3.8** installed. Make sure you have a compiler installed to compile C++ (gcc, make, headers, etc.). For example, to install a compiler and CMake on Ubuntu:

```bash
apt-get install -y build-essential cmake
@@ -47,7 +49,7 @@ pip install .

Windows systems require Visual Studio with C++ support as well as an installation of the CUDA SDK.

You'll need to build bitsandbytes from source. To compile from source, you need CMake >= **3.22.1** and Python >= **3.10** installed. You should also install CUDA Toolkit by following the [CUDA Installation Guide for Windows](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) guide from NVIDIA.
You'll need to build bitsandbytes from source. To compile from source, you need CMake >= **3.22.1** and Python >= **3.8** installed. You should also install CUDA Toolkit by following the [CUDA Installation Guide for Windows](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) guide from NVIDIA.

```bash
git clone https://github.com/TimDettmers/bitsandbytes.git && cd bitsandbytes/
@@ -81,7 +83,7 @@ Then locally install the CUDA version you need with this script from bitsandbytes
```bash
wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/install_cuda.sh
# Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH
# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121, 122}
# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121, 122, 123}
# EXPORT_TO_BASH in {0, 1} with 0=False and 1=True

# For example, the following installs CUDA 11.7 to ~/local/cuda-11.7 and exports the path to your .bashrc
61 changes: 54 additions & 7 deletions docs/source/integrations.mdx
@@ -1,8 +1,8 @@
# Transformers

With Transformers it's very easy to load any model in 4 or 8-bit, quantizing them on the fly with bitsandbytes primitives.
With Transformers it's very easy to load any model in 4 or 8-bit, quantizing them on the fly with `bitsandbytes` primitives.

Please review the [bitsandbytes section in the Accelerate docs](https://huggingface.co/docs/transformers/v4.37.2/en/quantization#bitsandbytes).
Please review the [`bitsandbytes` section in the Transformers docs](https://huggingface.co/docs/transformers/main/en/quantization#bitsandbytes).

Details about the BitsAndBytesConfig can be found [here](https://huggingface.co/docs/transformers/v4.37.2/en/main_classes/quantization#transformers.BitsAndBytesConfig).
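As a brief, hedged example (the model id is a placeholder; check the linked docs for the current API), loading a model in 4-bit with a compute dtype looks like:

```py
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # matmuls run in bf16
)

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # placeholder model id
    quantization_config=quantization_config,
    device_map="auto",
)
```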

@@ -21,13 +21,60 @@ quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dty
# PEFT
With `PEFT`, you can use QLoRA out of the box with `LoraConfig` and a 4-bit base model.

Please review the [bitsandbytes section in the Accelerate docs](https://huggingface.co/docs/peft/developer_guides/quantization#quantize-a-model).
Please review the [bitsandbytes section in the PEFT docs](https://huggingface.co/docs/peft/developer_guides/quantization#quantize-a-model).
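
A minimal sketch, assuming recent `transformers` and `peft` releases (the model id and hyperparameters are illustrative), of attaching LoRA adapters to a 4-bit base model:

```py
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

base_model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # placeholder model id
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

lora_config = LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM")
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()  # only the LoRA adapters are trainable
```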

# Accelerate

Bitsandbytes is also easily usable from within Accelerate.
Bitsandbytes is also easily usable from within Accelerate, where you can quantize any PyTorch model simply by passing a quantization config, e.g.:

```py
import torch
from accelerate import init_empty_weights
from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model
from mingpt.model import GPT  # requires the minGPT package

model_config = GPT.get_default_config()
model_config.model_type = 'gpt2-xl'
model_config.vocab_size = 50257
model_config.block_size = 1024

with init_empty_weights():
    empty_model = GPT(model_config)

bnb_quantization_config = BnbQuantizationConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # optional
    bnb_4bit_use_double_quant=True,         # optional
    bnb_4bit_quant_type="nf4"               # optional
)

quantized_model = load_and_quantize_model(
    empty_model,
    weights_location=weights_location,  # path to the pretrained weights
    bnb_quantization_config=bnb_quantization_config,
    device_map="auto"
)
```

For further details, e.g. model saving, CPU offloading, and fine-tuning, please review the [`bitsandbytes` section in the Accelerate docs](https://huggingface.co/docs/accelerate/en/usage_guides/quantization).



# PyTorch Lightning and Lightning Fabric

Bitsandbytes is available from within both:
- [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), a deep learning framework for professional AI researchers and machine learning engineers who need maximal flexibility without sacrificing performance at scale;
- and [Lightning Fabric](https://lightning.ai/docs/fabric/stable/), a fast and lightweight way to scale PyTorch models without boilerplate.

Please review the [bitsandbytes section in the PyTorch Lightning docs](https://lightning.ai/docs/pytorch/stable/common/precision_intermediate.html#quantization-via-bitsandbytes).


# Lit-GPT

Bitsandbytes is integrated into [Lit-GPT](https://github.com/Lightning-AI/lit-gpt), a hackable implementation of state-of-the-art open-source large language models, based on Lightning Fabric, where it can be used for quantization during training, finetuning, and inference.

Please review the [bitsandbytes section in the Lit-GPT quantization docs](https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md).


Please review the [bitsandbytes section in the Accelerate docs](https://huggingface.co/docs/accelerate/en/usage_guides/quantization).

# Trainer for the optimizers
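
As a hedged sketch (the `optim` value follows recent Transformers releases; verify against your installed version), an 8-bit bitsandbytes optimizer can be selected directly in `TrainingArguments`:

```py
from transformers import TrainingArguments

# "adamw_bnb_8bit" routes Trainer's optimizer to bitsandbytes' 8-bit AdamW.
training_args = TrainingArguments(
    output_dir="out",
    optim="adamw_bnb_8bit",
)
```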

@@ -40,5 +87,5 @@ e.g. for transformers state that you can load any model in 8-bit / 4-bit precisi

# Blog posts

- [Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA](https://huggingface.co/blog/4bit-transformers-bitsandbytes)
- [A Gentle Introduction to 8-bit Matrix Multiplication for transformers at scale using Hugging Face Transformers, Accelerate and bitsandbytes](https://huggingface.co/blog/hf-bitsandbytes-integration)
- [Making LLMs even more accessible with `bitsandbytes`, 4-bit quantization and QLoRA](https://huggingface.co/blog/4bit-transformers-bitsandbytes)
- [A Gentle Introduction to 8-bit Matrix Multiplication for transformers at scale using Hugging Face Transformers, Accelerate and `bitsandbytes`](https://huggingface.co/blog/hf-bitsandbytes-integration)
10 changes: 5 additions & 5 deletions examples/int8_inference_huggingface.py
@@ -1,24 +1,24 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import LlamaForCausalLM, LlamaTokenizer

MAX_NEW_TOKENS = 128
model_name = 'decapoda-research/llama-7b-hf'
model_name = 'meta-llama/Llama-2-7b-hf'

text = 'Hamburg is in which country?\n'
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer = LlamaTokenizer.from_pretrained(model_name)
input_ids = tokenizer(text, return_tensors="pt").input_ids

free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3)
max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'

n_gpus = torch.cuda.device_count()
max_memory = {i: max_memory for i in range(n_gpus)}

model = AutoModelForCausalLM.from_pretrained(
model = LlamaForCausalLM.from_pretrained(
    model_name,
    device_map='auto',
    load_in_8bit=True,
    max_memory=max_memory
)

generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
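
The `max_memory` expression above leaves roughly 2 GB of headroom per device; a hedged sketch of the resulting mapping (actual numbers depend on your GPUs):

```py
import torch

headroom_gb = 2  # keep ~2 GB free per device for activations and fragmentation
per_gpu = f"{int(torch.cuda.mem_get_info()[0] / 1024**3) - headroom_gb}GB"
max_memory = {i: per_gpu for i in range(torch.cuda.device_count())}
print(max_memory)  # e.g. {0: '22GB', 1: '22GB'} on two 24 GB GPUs
```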
15 changes: 5 additions & 10 deletions install_cuda.py
@@ -4,26 +4,21 @@
from urllib.request import urlretrieve

cuda_versions = {
"92": "https://developer.nvidia.com/compute/cuda/9.2/Prod2/local_installers/cuda_9.2.148_396.37_linux",
"100": "https://developer.nvidia.com/compute/cuda/10.0/Prod/local_installers/cuda_10.0.130_410.48_linux",
"101": "https://developer.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda_10.1.105_418.39_linux.run",
"102": "https://developer.download.nvidia.com/compute/cuda/10.2/Prod/local_installers/cuda_10.2.89_440.33.01_linux.run",
"110": "https://developer.download.nvidia.com/compute/cuda/11.0.3/local_installers/cuda_11.0.3_450.51.06_linux.run",
"111": "https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run",
"112": "https://developer.download.nvidia.com/compute/cuda/11.2.2/local_installers/cuda_11.2.2_460.32.03_linux.run",
"113": "https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.19.01_linux.run",
"114": "https://developer.download.nvidia.com/compute/cuda/11.4.4/local_installers/cuda_11.4.4_470.82.01_linux.run",
"115": "https://developer.download.nvidia.com/compute/cuda/11.5.2/local_installers/cuda_11.5.2_495.29.05_linux.run",
"116": "https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installers/cuda_11.6.2_510.47.03_linux.run",
"117": "https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run",
"117": "https://developer.download.nvidia.com/compute/cuda/11.7.1/local_installers/cuda_11.7.1_515.65.01_linux.run",
"118": "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run",
"120": "https://developer.download.nvidia.com/compute/cuda/12.0.0/local_installers/cuda_12.0.0_525.60.13_linux.run",
"121": "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run",
"122": "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run",
"123": "https://developer.download.nvidia.com/compute/cuda/12.3.1/local_installers/cuda_12.3.1_545.23.08_linux.run",
"120": "https://developer.download.nvidia.com/compute/cuda/12.0.1/local_installers/cuda_12.0.1_525.85.12_linux.run",
"121": "https://developer.download.nvidia.com/compute/cuda/12.1.1/local_installers/cuda_12.1.1_530.30.02_linux.run",
"122": "https://developer.download.nvidia.com/compute/cuda/12.2.2/local_installers/cuda_12.2.2_535.104.05_linux.run",
"123": "https://developer.download.nvidia.com/compute/cuda/12.3.2/local_installers/cuda_12.3.2_545.23.08_linux.run",
}


def install_cuda(version, base_path, download_path):
    formatted_version = f"{version[:-1]}.{version[-1]}"
    folder = f"cuda-{formatted_version}"
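
The keys in `cuda_versions` simply drop the dot; a small illustration (assuming, as in the table above, that the final character is always the minor version) of the mapping:

```py
def format_version(version: str) -> str:
    # "118" -> "11.8", "123" -> "12.3" (last digit is the minor version)
    return f"{version[:-1]}.{version[-1]}"

assert format_version("118") == "11.8"
assert format_version("123") == "12.3"
```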