diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index feb6c766e..039139b95 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.2.0
+    rev: v0.1.15
     hooks:
       - id: ruff
         args:
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 304c1d405..8597e9503 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -222,9 +222,50 @@ def to(self, *args, **kwargs):
 
 
 class Linear4bit(nn.Linear):
+    """
+    This class is the base module for the 4-bit quantization algorithm presented in [QLoRA](https://arxiv.org/abs/2305.14314).
+    QLoRA 4-bit linear layers use blockwise k-bit quantization under the hood, with the possibility of selecting various
+    compute datatypes such as FP4 and NF4.
+
+    In order to quantize a linear layer, one should first load the original fp16 / bf16 weights into
+    the Linear4bit module, then call `quantized_module.to("cuda")` to quantize the fp16 / bf16 weights.
+
+    Example:
+
+    ```python
+    import torch
+    import torch.nn as nn
+
+    import bitsandbytes as bnb
+    from bitsandbytes.nn import Linear4bit
+
+    fp16_model = nn.Sequential(
+        nn.Linear(64, 64),
+        nn.Linear(64, 64)
+    )
+    quantized_model = nn.Sequential(
+        Linear4bit(64, 64),
+        Linear4bit(64, 64)
+    )
+
+    quantized_model.load_state_dict(fp16_model.state_dict())
+    quantized_model = quantized_model.to(0)  # Quantization happens here
+    ```
+    """
     def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None):
-        super().__init__(input_features, output_features, bias, device)
+        """
+        Initialize Linear4bit class.
+
+        Args:
+            input_features (`int`):
+                Number of input features of the linear layer.
+            output_features (`int`):
+                Number of output features of the linear layer.
+            bias (`bool`, defaults to `True`):
+                Whether the linear class uses the bias term as well.
+        """
+        super().__init__(input_features, output_features, bias, device)
         self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self)
         # self.persistent_buffers = [] # TODO consider as way to save quant state
         self.compute_dtype = compute_dtype
@@ -440,8 +481,7 @@ def __init__(self, input_features, output_features, bias=True, has_fp16_weights=
                 Number of output features of the linear layer.
             bias (`bool`, defaults to `True`):
                 Whether the linear class uses the bias term as well.
-        """
-
+        """
         assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
         self.state = bnb.MatmulLtState()
         self.index = index
diff --git a/docs/source/quantization.mdx b/docs/source/quantization.mdx
index 71f15abac..287b2b87a 100644
--- a/docs/source/quantization.mdx
+++ b/docs/source/quantization.mdx
@@ -8,4 +8,5 @@ Below you will find the docstring of the quantization primitives exposed in bits
 
 # Linear4bit (QLoRA)
 
-... TODO: to be filled out ...
+
+[[autodoc]] bitsandbytes.nn.Linear4bit
\ No newline at end of file
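
For anyone sanity-checking the documented usage pattern locally, below is a minimal sketch (not part of the patch) that extends the new docstring example with the `compute_dtype` and `quant_type` arguments visible in the `__init__` signature above. It assumes a CUDA-enabled install of `bitsandbytes` matching this diff and an available GPU; the final dtype check reflects the default `quant_storage=torch.uint8` from the signature.

```python
import torch
import torch.nn as nn

from bitsandbytes.nn import Linear4bit

# Reference model holding the original (unquantized) weights.
fp16_model = nn.Sequential(
    nn.Linear(64, 64),
    nn.Linear(64, 64),
)

# 4-bit counterpart using the NF4 data type and bf16 compute dtype
# exposed by the Linear4bit __init__ signature in this diff.
quantized_model = nn.Sequential(
    Linear4bit(64, 64, compute_dtype=torch.bfloat16, quant_type="nf4"),
    Linear4bit(64, 64, compute_dtype=torch.bfloat16, quant_type="nf4"),
)

# Load the original weights, then move to the GPU -- quantization happens on .to().
quantized_model.load_state_dict(fp16_model.state_dict())
quantized_model = quantized_model.to("cuda")

# The packed 4-bit weights should now be stored in the quant_storage dtype
# (torch.uint8 by default per the signature above).
print(quantized_model[0].weight.dtype)
```

A forward pass on CUDA inputs would then run the matmul in the bf16 compute dtype while the weights stay stored as blockwise-quantized 4-bit values.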