From a71efa8b7660eebff5317d34783ee156074eb519 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Fri, 2 Feb 2024 06:17:10 +0000
Subject: [PATCH] test autodoc

---
 bitsandbytes/nn/modules.py   | 42 ++++++++++++++++++++++++++++++++++++
 docs/source/quantization.mdx | 10 +++++++--
 2 files changed, 50 insertions(+), 2 deletions(-)

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 922feae15..304c1d405 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -397,9 +397,51 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k
 
 
 class Linear8bitLt(nn.Linear):
+    """
+    This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm.
+    To read more about it, have a look at the paper.
+
+    To quantize a linear layer, first load the original fp16 / bf16 weights into the Linear8bitLt
+    module, then call `int8_module.to("cuda")` to quantize the weights to int8.
+
+    Example:
+
+    ```python
+    import torch
+    import torch.nn as nn
+
+    import bitsandbytes as bnb
+    from bitsandbytes.nn import Linear8bitLt
+
+    fp16_model = nn.Sequential(
+        nn.Linear(64, 64),
+        nn.Linear(64, 64)
+    )
+
+    int8_model = nn.Sequential(
+        Linear8bitLt(64, 64, has_fp16_weights=False),
+        Linear8bitLt(64, 64, has_fp16_weights=False)
+    )
+
+    int8_model.load_state_dict(fp16_model.state_dict())
+    int8_model = int8_model.to("cuda")  # Quantization happens here
+    ```
+    """
+
     def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, memory_efficient_backward=False,
                  threshold=0.0, index=None, device=None):
+        """
+        Initialize Linear8bitLt class.
+
+        Args:
+            input_features (`int`):
+                Number of input features of the linear layer.
+            output_features (`int`):
+                Number of output features of the linear layer.
+            bias (`bool`, defaults to `True`):
+                Whether the linear class uses the bias term as well.
+        """
         super().__init__(input_features, output_features, bias, device)
         assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
         self.state = bnb.MatmulLtState()
         self.index = index
diff --git a/docs/source/quantization.mdx b/docs/source/quantization.mdx
index a09d90c2d..71f15abac 100644
--- a/docs/source/quantization.mdx
+++ b/docs/source/quantization.mdx
@@ -1,5 +1,11 @@
-# Linear8bitLt (LLM.int8)
-... TODO: to be filled out ...
+# Quantization primitives
+
+Below you will find the docstrings of the quantization primitives exposed in bitsandbytes.
+
+## Linear8bitLt
+
+[[autodoc]] bitsandbytes.nn.Linear8bitLt
+
 # Linear4bit (QLoRA)
 ... TODO: to be filled out ...
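
Not part of the patch: the docstring example above stops right after quantization, so here is a minimal sketch of actually running inference through a quantized layer. It assumes a CUDA-capable GPU with bitsandbytes installed; the layer sizes, the `threshold=6.0` value, and the variable names are illustrative only.

```python
import torch
import torch.nn as nn

from bitsandbytes.nn import Linear8bitLt

# Original fp16 layer and its int8 replacement. threshold is the outlier
# threshold from the LLM.int8() paper (6.0 is the value suggested there);
# has_fp16_weights=False keeps only the int8 weights after quantization,
# i.e. inference mode.
fp16_layer = nn.Linear(64, 64).half()
int8_layer = Linear8bitLt(64, 64, has_fp16_weights=False, threshold=6.0)

# Load the fp16 weights, then move to the GPU: the weights are quantized
# to int8 during the .to("cuda") call.
int8_layer.load_state_dict(fp16_layer.state_dict())
int8_layer = int8_layer.to("cuda")

# Forward pass with an fp16 input on the same device.
x = torch.randn(4, 64, dtype=torch.float16, device="cuda")
with torch.no_grad():
    y = int8_layer(x)
print(y.shape)  # torch.Size([4, 64])
```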