
Commit

test autodoc
younesbelkada committed Feb 2, 2024
1 parent ab42c5f commit a71efa8
Showing 2 changed files with 50 additions and 2 deletions.
42 changes: 42 additions & 0 deletions bitsandbytes/nn/modules.py
@@ -397,9 +397,51 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k


class Linear8bitLt(nn.Linear):
"""
This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm.
To read more about it, have a look at the paper.
In order to quantize a linear layer one should first load the original fp16 / bf16 weights into
the Linear8bitLt module, then call `int8_module.to("cuda")` to quantize the fp16 weights.
Example:
```python
import torch
import torch.nn as nn
import bitsandbytes as bnb
from bnb.nn import Linear8bitLt
fp16_model = nn.Sequential(
nn.Linear(64, 64),
nn.Linear(64, 64)
)
int8_model = nn.Sequential(
Linear8bitLt(64, 64, has_fp16_weights=False),
Linear8bitLt(64, 64, has_fp16_weights=False)
)
int8_model.load_state_dict(fp16_model.state_dict())
int8_model = int8_model.to(0) # Quantization happens here
```
"""
    def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
                 memory_efficient_backward=False, threshold=0.0, index=None, device=None):
        """
        Initialize Linear8bitLt class.

        Args:
            input_features (`int`):
                Number of input features of the linear layer.
            output_features (`int`):
                Number of output features of the linear layer.
            bias (`bool`, defaults to `True`):
                Whether the linear class uses the bias term as well.
        """
        super().__init__(input_features, output_features, bias, device)

        assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
        self.state = bnb.MatmulLtState()
        self.index = index
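For context, here is a minimal sketch of how the quantized module above could be used for inference. This is not part of the commit; the `threshold=6.0` value (the outlier cutoff suggested in the LLM.int8() paper) and the tensor shapes are illustrative assumptions:

```python
import torch
import torch.nn as nn

from bitsandbytes.nn import Linear8bitLt

# Inference-only int8 layers, as in the docstring example; a nonzero
# threshold enables the mixed-precision outlier decomposition.
int8_model = nn.Sequential(
    Linear8bitLt(64, 64, has_fp16_weights=False, threshold=6.0),
    Linear8bitLt(64, 64, has_fp16_weights=False, threshold=6.0),
)
int8_model = int8_model.to("cuda")  # fp16 weights are quantized to int8 here

# Inputs stay in half precision; the matmul runs through the int8 kernels.
x = torch.randn(1, 64, dtype=torch.float16, device="cuda")
with torch.no_grad():
    out = int8_model(x)
print(out.dtype, out.shape)  # torch.float16 torch.Size([1, 64])
```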
10 changes: 8 additions & 2 deletions docs/source/quantization.mdx
@@ -1,5 +1,11 @@
-# Linear8bitLt (LLM.int8)
-... TODO: to be filled out ...
+# Quantization primitives
+
+Below you will find the docstrings of the quantization primitives exposed in bitsandbytes.
+
+## Linear8bitLt
+
+[[autodoc]] bitsandbytes.nn.Linear8bitLt
+
 
 # Linear4bit (QLoRA)
 ... TODO: to be filled out ...

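As an aside, `[[autodoc]]` is the hf-doc-builder directive that pulls the referenced object's docstring into the rendered page. The remaining TODO would presumably be filled the same way once Linear4bit gains a docstring; a hypothetical sketch, assuming the class is exposed as `bitsandbytes.nn.Linear4bit`:

```
## Linear4bit

[[autodoc]] bitsandbytes.nn.Linear4bit
```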