new additions
younesbelkada committed Feb 2, 2024
1 parent a71efa8 commit c1ec5f8
Showing 3 changed files with 45 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.0
rev: v0.1.15
hooks:
- id: ruff
args:
44 changes: 42 additions & 2 deletions bitsandbytes/nn/modules.py
@@ -222,9 +222,50 @@ def to(self, *args, **kwargs):


class Linear4bit(nn.Linear):
"""
This class is the base module for the 4-bit quantization algorithm presented in [QLoRA](https://arxiv.org/abs/2305.14314).
QLoRA 4-bit linear layers use blockwise k-bit quantization under the hood, with the possibility of selecting various
compute datatypes such as FP4 and NF4.
To quantize a linear layer, first load the original fp16 / bf16 weights into
the Linear4bit module, then call `quantized_module.to("cuda")` to quantize the fp16 / bf16 weights.
Example:
```python
import torch
import torch.nn as nn
import bitsandbytes as bnb
from bitsandbytes.nn import Linear4bit
fp16_model = nn.Sequential(
    nn.Linear(64, 64),
    nn.Linear(64, 64)
)
quantized_model = nn.Sequential(
    Linear4bit(64, 64),
    Linear4bit(64, 64)
)
quantized_model.load_state_dict(fp16_model.state_dict())
quantized_model = quantized_model.to(0) # Quantization happens here
```
"""
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None):
super().__init__(input_features, output_features, bias, device)
"""
Initialize Linear4bit class.
Args:
input_features (`int`):
Number of input features of the linear layer.
output_features (`int`):
Number of output features of the linear layer.
bias (`bool`, defaults to `True`):
Whether the linear class uses the bias term as well.
"""
self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self)
# self.persistent_buffers = [] # TODO consider as way to save quant state
self.compute_dtype = compute_dtype
@@ -440,8 +481,7 @@ def __init__(self, input_features, output_features, bias=True, has_fp16_weights=
Number of output features of the linear layer.
bias (`bool`, defaults to `True`):
Whether the linear class uses the bias term as well.
"""

"""
assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
self.state = bnb.MatmulLtState()
self.index = index
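For context, here is a minimal sketch of how the constructor arguments documented above could be used, following the same load-then-move pattern as the docstring example. It assumes a CUDA device is available; the NF4 quant type, fp16 compute dtype, and layer sizes are illustrative choices, not part of this commit.

```python
import torch
import torch.nn as nn
from bitsandbytes.nn import Linear4bit

# Start from an ordinary fp16 model whose weights we want to quantize.
fp16_model = nn.Sequential(
    nn.Linear(64, 64),
    nn.Linear(64, 64),
).to(torch.float16)

# quant_type picks the 4-bit data type ("fp4" or "nf4"), compute_dtype is the dtype used
# for the matmul after dequantization, and compress_statistics additionally quantizes the
# quantization statistics (the "double quantization" described in the QLoRA paper).
quantized_model = nn.Sequential(
    Linear4bit(64, 64, compute_dtype=torch.float16, quant_type="nf4", compress_statistics=True),
    Linear4bit(64, 64, compute_dtype=torch.float16, quant_type="nf4", compress_statistics=True),
)

quantized_model.load_state_dict(fp16_model.state_dict())
quantized_model = quantized_model.to("cuda")  # the fp16 weights are quantized to 4-bit here

with torch.no_grad():
    out = quantized_model(torch.randn(1, 64, dtype=torch.float16, device="cuda"))
```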
3 changes: 2 additions & 1 deletion docs/source/quantization.mdx
@@ -8,4 +8,5 @@ Below you will find the docstring of the quantization primitives exposed in bitsandbytes


# Linear4bit (QLoRA)
... TODO: to be filled out ...

[[autodoc]] bitsandbytes.nn.Linear4bit
