Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

documentation release v1 #1012

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
fbc0385
add optional dependency for preview to environment.yml
Titus-von-Koeller Feb 1, 2024
84b5fc0
Add additional sections, first optimizers, MacOS WIP
Titus-von-Koeller Feb 1, 2024
725d29a
drafting + refactoring new docs
Titus-von-Koeller Feb 1, 2024
58566e2
some changes
younesbelkada Feb 2, 2024
47cc3e9
run pre-commit hooks
Titus-von-Koeller Feb 2, 2024
c26645b
add mention of pre-commit to contributing
Titus-von-Koeller Feb 2, 2024
ab42c5f
fix
younesbelkada Feb 2, 2024
a71efa8
test autodoc
younesbelkada Feb 2, 2024
c1ec5f8
new additions
younesbelkada Feb 2, 2024
544114d
add subtilte
younesbelkada Feb 2, 2024
f735b35
add some content
younesbelkada Feb 2, 2024
daff94c
add more methods
younesbelkada Feb 2, 2024
301ee80
fix
younesbelkada Feb 2, 2024
683a72b
further docs updates
Titus-von-Koeller Feb 2, 2024
60a7699
Update _toctree.yml
younesbelkada Feb 2, 2024
543a7b1
fix link
Titus-von-Koeller Feb 3, 2024
2d73f4d
run pre-commit hooks
Titus-von-Koeller Feb 3, 2024
8f0fd8a
refactor + further docs
Titus-von-Koeller Feb 4, 2024
a3c45d3
Update README.md with new docs link
Titus-von-Koeller Feb 4, 2024
b370cee
list of blog posts
Titus-von-Koeller Feb 4, 2024
fd64f21
list of blog posts
Titus-von-Koeller Feb 4, 2024
38d323a
accept change suggestion
Titus-von-Koeller Feb 4, 2024
82485d0
accept suggestion
Titus-von-Koeller Feb 4, 2024
75cfb1c
accept suggestion
Titus-von-Koeller Feb 4, 2024
7a71390
Update docs/source/integrations.mdx
Titus-von-Koeller Feb 4, 2024
a84afcf
index instead of intro
Titus-von-Koeller Feb 4, 2024
d3709f4
fixup README, add docs link
Titus-von-Koeller Feb 4, 2024
e00cbc9
add instructions for creating docstrings
Titus-von-Koeller Feb 4, 2024
8a67759
final polish (except integrations)
Titus-von-Koeller Feb 4, 2024
d632531
fill out integrations section
Titus-von-Koeller Feb 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.15
rev: v0.2.0
Titus-von-Koeller marked this conversation as resolved.
Show resolved Hide resolved
hooks:
- id: ruff
args:
Expand Down
192 changes: 7 additions & 185 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,195 +1,17 @@
# bitsandbytes
# `bitsandbytes`

The bitsandbytes is a lightweight wrapper around CUDA custom functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and quantization functions.
The `bitsandbytes` library is a lightweight Python wrapper around CUDA custom functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and 8 & 4-bit quantization functions.

The library includes quantization primitives for 8-bit & 4-bit operations, through `bitsandbytes.nn.Linear8bitLt` and `bitsandbytes.nn.Linear4bit` and 8-bit optimizers through `bitsandbytes.optim` module.

There are ongoing efforts to support further hardware backends, i.e. Intel CPU + GPU, AMD GPU, Apple Silicon. Windows support is quite far along and is on its way as well.

Resources:
Titus-von-Koeller marked this conversation as resolved.
Show resolved Hide resolved
- [8-bit Optimizer Paper](https://arxiv.org/abs/2110.02861) -- [Video](https://www.youtube.com/watch?v=IxrlHAJtqKE) -- [Docs](https://bitsandbytes.readthedocs.io/en/latest/)
**Please head to the official documentation page:**

- [LLM.int8() Paper](https://arxiv.org/abs/2208.07339) -- [LLM.int8() Software Blog Post](https://huggingface.co/blog/hf-bitsandbytes-integration) -- [LLM.int8() Emergent Features Blog Post](https://timdettmers.com/2022/08/17/llm-int8-and-emergent-features/)

## TL;DR
**Requirements**
Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + CUDA > 10.0.

(Deprecated: CUDA 10.0 is deprecated and only CUDA >= 11.0) will be supported with release 0.39.0)

**Installation**:

``pip install bitsandbytes``

In some cases it can happen that you need to compile from source. If this happens please consider submitting a bug report with `python -m bitsandbytes` information. What now follows is some short instructions which might work out of the box if `nvcc` is installed. If these do not work see further below.

Compilation quickstart:
```bash
git clone https://github.com/timdettmers/bitsandbytes.git
cd bitsandbytes

# CUDA_VERSIONS in {110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122}
# make argument in {cuda110, cuda11x, cuda12x}
# if you do not know what CUDA you have, try looking at the output of: python -m bitsandbytes
CUDA_VERSION=117 make cuda11x
python setup.py install
```

**Using Int8 inference with HuggingFace Transformers**

```python
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
'decapoda-research/llama-7b-hf',
device_map='auto',
load_in_8bit=True,
max_memory={
i: f'{int(torch.cuda.mem_get_info(i)[0]/1024**3)-2}GB'
for i in range(torch.cuda.device_count())
}
)
```

A more detailed example, can be found in [examples/int8_inference_huggingface.py](examples/int8_inference_huggingface.py).

**Using 8-bit optimizer**:
1. Comment out optimizer: ``#torch.optim.Adam(....)``
2. Add 8-bit optimizer of your choice ``bnb.optim.Adam8bit(....)`` (arguments stay the same)
3. Replace embedding layer if necessary: ``torch.nn.Embedding(..) -> bnb.nn.Embedding(..)``


**Using 8-bit Inference**:
1. Comment out torch.nn.Linear: ``#linear = torch.nn.Linear(...)``
2. Add bnb 8-bit linear light module: ``linear = bnb.nn.Linear8bitLt(...)`` (base arguments stay the same)
3. There are two modes:
- Mixed 8-bit training with 16-bit main weights. Pass the argument ``has_fp16_weights=True`` (default)
- Int8 inference. Pass the argument ``has_fp16_weights=False``
4. To use the full LLM.int8() method, use the ``threshold=k`` argument. We recommend ``k=6.0``.
```python
# LLM.int8()
linear = bnb.nn.Linear8bitLt(dim1, dim2, bias=True, has_fp16_weights=False, threshold=6.0)
# inputs need to be fp16
out = linear(x.to(torch.float16))
```


## Features
- 8-bit Matrix multiplication with mixed precision decomposition
- LLM.int8() inference
- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB, Lion (saves 75% memory)
- Stable Embedding Layer: Improved stability through better initialization, and normalization
- 8-bit quantization: Quantile, Linear, and Dynamic quantization
- Fast quantile estimation: Up to 100x faster than other algorithms

## Requirements & Installation

Requirements: anaconda, cudatoolkit, pytorch

Hardware requirements:
- LLM.int8(): NVIDIA Turing (RTX 20xx; T4) or Ampere GPU (RTX 30xx; A4-A100); (a GPU from 2018 or newer).
- 8-bit optimizers and quantization: NVIDIA Kepler GPU or newer (>=GTX 78X).

Supported CUDA versions: 10.2 - 12.2

The bitsandbytes library is currently only supported on Linux distributions. Windows is not supported at the moment.

The requirements can best be fulfilled by installing pytorch via anaconda. You can install PyTorch by following the ["Get Started"](https://pytorch.org/get-started/locally/) instructions on the official website.

To install run:

``pip install bitsandbytes``

## Using bitsandbytes

### Using Int8 Matrix Multiplication

For straight Int8 matrix multiplication with mixed precision decomposition you can use ``bnb.matmul(...)``. To enable mixed precision decomposition, use the threshold parameter:
```python
bnb.matmul(..., threshold=6.0)
```

For instructions how to use LLM.int8() inference layers in your own code, see the TL;DR above or for extended instruction see [this blog post](https://huggingface.co/blog/hf-bitsandbytes-integration).

### Using the 8-bit Optimizers

With bitsandbytes 8-bit optimizers can be used by changing a single line of code in your codebase. For NLP models we recommend also to use the StableEmbedding layers (see below) which improves results and helps with stable 8-bit optimization. To get started with 8-bit optimizers, it is sufficient to replace your old optimizer with the 8-bit optimizer in the following way:
```python
import bitsandbytes as bnb

# adam = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995)) # comment out old optimizer
adam = bnb.optim.Adam8bit(model.parameters(), lr=0.001, betas=(0.9, 0.995)) # add bnb optimizer
adam = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995), optim_bits=8) # equivalent


torch.nn.Embedding(...) -> bnb.nn.StableEmbedding(...) # recommended for NLP models
```

Note that by default all parameter tensors with less than 4096 elements are kept at 32-bit even if you initialize those parameters with 8-bit optimizers. This is done since such small tensors do not save much memory and often contain highly variable parameters (biases) or parameters that require high precision (batch norm, layer norm). You can change this behavior like so:
```python
# parameter tensors with less than 16384 values are optimized in 32-bit
# it is recommended to use multiplies of 4096
adam = bnb.optim.Adam8bit(model.parameters(), min_8bit_size=16384)
```

### Change Bits and other Hyperparameters for Individual Parameters

If you want to optimize some unstable parameters with 32-bit Adam and others with 8-bit Adam, you can use the `GlobalOptimManager`. With this, we can also configure specific hyperparameters for particular layers, such as embedding layers. To do that, we need two things: (1) register the parameter while they are still on the CPU, (2) override the config with the new desired hyperparameters (anytime, anywhere). See our [guide](howto_config_override.md) for more details

### Fairseq Users

To use the Stable Embedding Layer, override the respective `build_embedding(...)` function of your model. Make sure to also use the `--no-scale-embedding` flag to disable scaling of the word embedding layer (nor replaced with layer norm). You can use the optimizers by replacing the optimizer in the respective file (`adam.py` etc.).

## Release and Feature History

For upcoming features and changes and full history see [Patch Notes](CHANGELOG.md).

## Errors

1. RuntimeError: CUDA error: no kernel image is available for execution on the device. [Solution](errors_and_solutions.md#No-kernel-image-available)
2. __fatbinwrap_.. [Solution](errors_and_solutions.md#fatbinwrap_)

## Compile from source
To compile from source, you need an installation of CUDA. If `nvcc` is not installed, you can install the CUDA Toolkit with nvcc through the following commands.

```bash
wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/install_cuda.sh
# Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH
# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121, 122}
# EXPORT_TO_BASH in {0, 1} with 0=False and 1=True

# For example, the following installs CUDA 11.7 to ~/local/cuda-11.7 and exports the path to your .bashrc
bash install_cuda.sh 117 ~/local 1
```

To use a specific CUDA version just for a single compile run, you can set the variable `CUDA_HOME`, for example the following command compiles `libbitsandbytes_cuda117.so` using compiler flags for cuda11x with the cuda version at `~/local/cuda-11.7`:

``CUDA_HOME=~/local/cuda-11.7 CUDA_VERSION=117 make cuda11x``

For more detailed instruction, please follow the [compile_from_source.md](compile_from_source.md) instructions.
**[https://huggingface.co/docs/bitsandbytes/main](https://huggingface.co/docs/bitsandbytes/main)**

## License

The majority of bitsandbytes is licensed under MIT, however portions of the project are available under separate license terms: Pytorch is licensed under the BSD license.
The majority of bitsandbytes is licensed under MIT, however small portions of the project are available under separate license terms, as the parts adapted from Pytorch are licensed under the BSD license.

We thank Fabio Cannizzo for his work on [FastBinarySearch](https://github.com/fabiocannizzo/FastBinarySearch) which we use for CPU quantization.

## How to cite us
If you found this library and found LLM.int8() useful, please consider citing our work:

```bibtex
@article{dettmers2022llmint8,
title={LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale},
author={Dettmers, Tim and Lewis, Mike and Belkada, Younes and Zettlemoyer, Luke},
journal={arXiv preprint arXiv:2208.07339},
year={2022}
}
```

For 8-bit optimizers or quantization routines, please consider citing the following work:

```bibtex
@article{dettmers2022optimizers,
title={8-bit Optimizers via Block-wise Quantization},
author={Dettmers, Tim and Lewis, Mike and Shleifer, Sam and Zettlemoyer, Luke},
journal={9th International Conference on Learning Representations, ICLR},
year={2022}
}
```
129 changes: 129 additions & 0 deletions bitsandbytes/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,42 @@


class StableEmbedding(torch.nn.Embedding):
"""
Titus-von-Koeller marked this conversation as resolved.
Show resolved Hide resolved
Custom embedding layer designed for stable training in NLP tasks. The stable
embedding layer improves stability during optimization for models with word
embeddings, addressing issues related to the non-uniform distribution of input
tokens.

This stable embedding layer is initialized with Xavier uniform initialization,
followed by layer normalization. It is designed to support aggressive quantization,
addressing extreme gradient variations in non-uniform input distributions. The
stability of training is enhanced by using 32-bit optimizer states specifically
for this layer.

Example:

```
# Initialize StableEmbedding layer with vocabulary size 1000, embedding dimension 300
embedding_layer = StableEmbedding(num_embeddings=1000, embedding_dim=300)

# Reset embedding parameters
embedding_layer.reset_parameters()

# Perform a forward pass with input tensor
input_tensor = torch.tensor([1, 2, 3])
output_embedding = embedding_layer(input_tensor)
```

Attributes:
norm (torch.nn.LayerNorm): Layer normalization applied after the embedding.

Methods:
reset_parameters(): Reset embedding parameters using Xavier uniform initialization.
forward(input: Tensor) -> Tensor: Forward pass through the stable embedding layer.

Reference:
- [8-bit optimizer paper](https://arxiv.org/pdf/2110.02861.pdf)
"""
def __init__(
self,
num_embeddings: int,
Expand All @@ -32,6 +68,17 @@ def __init__(
device=None,
dtype=None,
) -> None:
"""
Args:
num_embeddings (`int`): The number of unique embeddings (vocabulary size).
embedding_dim (`int`): The dimensionality of the embedding.
padding_idx (`Optional[int]`): If specified, pads the output with zeros at the given index.
max_norm (`Optional[float]`): If given, renormalizes embeddings to have a maximum L2 norm.
norm_type (`float`, defaults to `2.0`): The p-norm to compute for the max_norm option.
scale_grad_by_freq (`bool`): Scale gradient by frequency during backpropagation.
sparse (`bool`): If True, computes sparse gradients; False, computes dense gradients.
_weight (`Optional[Tensor]`): Pre-trained embeddings.
"""
super().__init__(
num_embeddings,
embedding_dim,
Expand Down Expand Up @@ -222,8 +269,49 @@ def to(self, *args, **kwargs):


class Linear4bit(nn.Linear):
"""
This class is the base module for the 4-bit quantization algorithm presented in [QLoRA](https://arxiv.org/abs/2305.14314).
QLoRA 4-bit linear layers uses blockwise k-bit quantization under the hood, with the possibility of selecting various
compute datatypes such as FP4 and NF4.

In order to quantize a linear layer one should first load the original fp16 / bf16 weights into
the Linear8bitLt module, then call `quantized_module.to("cuda")` to quantize the fp16 / bf16 weights.

Example:

```python
import torch
import torch.nn as nn

import bitsandbytes as bnb
from bnb.nn import Linear4bit

fp16_model = nn.Sequential(
nn.Linear(64, 64),
nn.Linear(64, 64)
)

quantized_model = nn.Sequential(
Linear4bit(64, 64),
Linear4bit(64, 64)
)

quantized_model.load_state_dict(fp16_model.state_dict())
quantized_model = quantized_model.to(0) # Quantization happens here
```
"""
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None):
"""
Initialize Linear4bit class.

Args:
input_features (`str`):
Number of input features of the linear layer.
output_features (`str`):
Number of output features of the linear layer.
bias (`bool`, defaults to `True`):
Whether the linear class uses the bias term as well.
"""
super().__init__(input_features, output_features, bias, device)
self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self)
# self.persistent_buffers = [] # TODO consider as way to save quant state
Expand Down Expand Up @@ -397,8 +485,49 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k


class Linear8bitLt(nn.Linear):
"""
This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm.
To read more about it, have a look at the paper.

In order to quantize a linear layer one should first load the original fp16 / bf16 weights into
the Linear8bitLt module, then call `int8_module.to("cuda")` to quantize the fp16 weights.

Example:

```python
import torch
import torch.nn as nn

import bitsandbytes as bnb
from bnb.nn import Linear8bitLt

fp16_model = nn.Sequential(
nn.Linear(64, 64),
nn.Linear(64, 64)
)

int8_model = nn.Sequential(
Linear8bitLt(64, 64, has_fp16_weights=False),
Linear8bitLt(64, 64, has_fp16_weights=False)
)

int8_model.load_state_dict(fp16_model.state_dict())
int8_model = int8_model.to(0) # Quantization happens here
```
"""
def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
memory_efficient_backward=False, threshold=0.0, index=None, device=None):
"""
Initialize Linear8bitLt class.

Args:
input_features (`str`):
Number of input features of the linear layer.
output_features (`str`):
Number of output features of the linear layer.
bias (`bool`, defaults to `True`):
Whether the linear class uses the bias term as well.
Titus-von-Koeller marked this conversation as resolved.
Show resolved Hide resolved
"""
super().__init__(input_features, output_features, bias, device)
assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
self.state = bnb.MatmulLtState()
Expand Down
Loading
Loading