diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index bd2bd5832..ac7bb8e7b 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -21,16 +21,7 @@ class StableEmbedding(torch.nn.Embedding):
     """
-    Custom embedding layer designed for stable training in NLP tasks. The stable
-    embedding layer improves stability during optimization for models with word
-    embeddings, addressing issues related to the non-uniform distribution of input
-    tokens.
-
-    This stable embedding layer is initialized with Xavier uniform initialization,
-    followed by layer normalization. It is designed to support aggressive quantization,
-    addressing extreme gradient variations in non-uniform input distributions. The
-    stability of training is enhanced by using 32-bit optimizer states specifically
-    for this layer.
+    Custom embedding layer designed to improve stability during training for NLP tasks by using 32-bit optimizer states. It reduces the gradient variations that can result from quantization. This embedding layer is initialized with Xavier uniform initialization followed by layer normalization.

     Example:

@@ -47,14 +38,11 @@ class StableEmbedding(torch.nn.Embedding):
     ```

     Attributes:
-        norm (torch.nn.LayerNorm): Layer normalization applied after the embedding.
+        norm (`torch.nn.LayerNorm`): Layer normalization applied after the embedding.

     Methods:
         reset_parameters(): Reset embedding parameters using Xavier uniform initialization.
         forward(input: Tensor) -> Tensor: Forward pass through the stable embedding layer.
-
-    Reference:
-        - [8-bit optimizer paper](https://arxiv.org/pdf/2110.02861.pdf)
     """

     def __init__(
         self,
@@ -71,14 +59,22 @@ def __init__(
     ) -> None:
         """
         Args:
-            num_embeddings (`int`): The number of unique embeddings (vocabulary size).
-            embedding_dim (`int`): The dimensionality of the embedding.
-            padding_idx (`Optional[int]`): If specified, pads the output with zeros at the given index.
-            max_norm (`Optional[float]`): If given, renormalizes embeddings to have a maximum L2 norm.
-            norm_type (`float`, defaults to `2.0`): The p-norm to compute for the max_norm option.
-            scale_grad_by_freq (`bool`): Scale gradient by frequency during backpropagation.
-            sparse (`bool`): If True, computes sparse gradients; False, computes dense gradients.
-            _weight (`Optional[Tensor]`): Pre-trained embeddings.
+            num_embeddings (`int`):
+                The number of unique embeddings (vocabulary size).
+            embedding_dim (`int`):
+                The dimensionality of the embedding.
+            padding_idx (`Optional[int]`):
+                Pads the output with zeros at the given index.
+            max_norm (`Optional[float]`):
+                Renormalizes embeddings to have a maximum L2 norm.
+            norm_type (`float`, defaults to `2.0`):
+                The p-norm to compute for the `max_norm` option.
+            scale_grad_by_freq (`bool`, defaults to `False`):
+                Scale gradient by frequency during backpropagation.
+            sparse (`bool`, defaults to `False`):
+                Computes dense gradients. Set to `True` to compute sparse gradients instead.
+            _weight (`Optional[Tensor]`):
+                Pretrained embeddings.
         """
         super().__init__(
             num_embeddings,
@@ -131,6 +127,9 @@ def forward(self, input: Tensor) -> Tensor:


 class Embedding(torch.nn.Embedding):
+    """
+    Embedding class to store and retrieve word embeddings from their indices.
+    """
     def __init__(
         self,
         num_embeddings: int,
@@ -143,6 +142,25 @@ def __init__(
         _weight: Optional[Tensor] = None,
         device: Optional[device] = None,
     ) -> None:
+        """
+        Args:
+            num_embeddings (`int`):
+                The number of unique embeddings (vocabulary size).
+            embedding_dim (`int`):
+                The dimensionality of the embedding.
+            padding_idx (`Optional[int]`):
+                Pads the output with zeros at the given index.
+            max_norm (`Optional[float]`):
+                Renormalizes embeddings to have a maximum L2 norm.
+            norm_type (`float`, defaults to `2.0`):
+                The p-norm to compute for the `max_norm` option.
+            scale_grad_by_freq (`bool`, defaults to `False`):
+                Scale gradient by frequency during backpropagation.
+            sparse (`bool`, defaults to `False`):
+                Computes dense gradients. Set to `True` to compute sparse gradients instead.
+            _weight (`Optional[Tensor]`):
+                Pretrained embeddings.
+        """
         super().__init__(
             num_embeddings,
             embedding_dim,
diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 6db060286..87c4242de 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -48,3 +48,11 @@
       title: RMSprop
     - local: reference/optim/sgd
       title: SGD
+  - title: k-bit quantizers
+    sections:
+    - local: reference/nn/linear8bit
+      title: 8-bit quantizer
+    - local: reference/nn/linear4bit
+      title: 4-bit quantizer
+    - local: reference/nn/embeddings
+      title: Embedding
diff --git a/docs/source/reference/nn/embeddings.mdx b/docs/source/reference/nn/embeddings.mdx
new file mode 100644
index 000000000..e725ecb17
--- /dev/null
+++ b/docs/source/reference/nn/embeddings.mdx
@@ -0,0 +1,15 @@
+# Embedding
+
+The embedding class is used to store and retrieve word embeddings from their indices. There are two types of embeddings in bitsandbytes: the standard PyTorch [`Embedding`] class and the [`StableEmbedding`] class.
+
+The [`StableEmbedding`] class was introduced in the [8-bit Optimizers via Block-wise Quantization](https://hf.co/papers/2110.02861) paper to reduce the gradient variance that results from the non-uniform distribution of input tokens. This class is designed to support quantization.
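+
+The minimal sketch below shows how [`StableEmbedding`] can stand in for `torch.nn.Embedding`; the vocabulary size, embedding dimension, and optimizer choice are illustrative only:
+
+```py
+import torch
+import bitsandbytes as bnb
+
+# illustrative sizes: a 1024-token vocabulary embedded into 128 dimensions
+emb = bnb.nn.StableEmbedding(num_embeddings=1024, embedding_dim=128)
+
+token_ids = torch.randint(0, 1024, (2, 16))  # batch of 2 sequences, 16 tokens each
+hidden = emb(token_ids)                      # -> (2, 16, 128), layer-normalized embeddings
+
+# StableEmbedding registers itself to keep 32-bit optimizer state even when the rest
+# of the model is trained with an 8-bit optimizer such as bnb.optim.Adam8bit
+optimizer = bnb.optim.Adam8bit(emb.parameters(), lr=1e-3)
+```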
+
+## Embedding
+
+[[autodoc]] bitsandbytes.nn.Embedding
+    - __init__
+
+## StableEmbedding
+
+[[autodoc]] bitsandbytes.nn.StableEmbedding
+    - __init__
diff --git a/docs/source/reference/nn/linear4bit.mdx b/docs/source/reference/nn/linear4bit.mdx
new file mode 100644
index 000000000..88aec707d
--- /dev/null
+++ b/docs/source/reference/nn/linear4bit.mdx
@@ -0,0 +1,23 @@
+# 4-bit quantization
+
+[QLoRA](https://hf.co/papers/2305.14314) is a finetuning method that quantizes a model to 4-bit precision and adds a set of low-rank adaptation (LoRA) weights to the model, which are tuned by backpropagating gradients through the quantized weights. This method also introduces a new data type, 4-bit NormalFloat (`LinearNF4`), in addition to the standard Float4 data type (`LinearFP4`). `LinearNF4` is adapted for weights initialized from a normal distribution and can improve performance.
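+
+The sketch below shows one way a 4-bit linear layer might be created from an existing half-precision layer; the layer sizes and compute dtype are illustrative, and a CUDA device is assumed:
+
+```py
+import torch
+import bitsandbytes as bnb
+
+fp16_linear = torch.nn.Linear(64, 64)
+
+# NF4 variant of Linear4bit; compute_dtype controls the dtype used for the matmul
+nf4_linear = bnb.nn.LinearNF4(64, 64, compute_dtype=torch.float16)
+nf4_linear.load_state_dict(fp16_linear.state_dict())
+nf4_linear = nf4_linear.to("cuda")  # the weights are quantized to 4-bit when moved to the GPU
+
+x = torch.randn(8, 64, dtype=torch.float16, device="cuda")
+out = nf4_linear(x)
+```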
+
+## Linear4bit
+
+[[autodoc]] bitsandbytes.nn.Linear4bit
+    - __init__
+
+## LinearFP4
+
+[[autodoc]] bitsandbytes.nn.LinearFP4
+    - __init__
+
+## LinearNF4
+
+[[autodoc]] bitsandbytes.nn.LinearNF4
+    - __init__
+
+## Params4bit
+
+[[autodoc]] bitsandbytes.nn.Params4bit
+    - __init__
diff --git a/docs/source/reference/nn/linear8bit.mdx b/docs/source/reference/nn/linear8bit.mdx
new file mode 100644
index 000000000..73254fe67
--- /dev/null
+++ b/docs/source/reference/nn/linear8bit.mdx
@@ -0,0 +1,13 @@
+# 8-bit quantization
+
+[LLM.int8()](https://hf.co/papers/2208.07339) is a quantization method that doesn't degrade performance, which makes large model inference more accessible. The key is to extract the outliers from the inputs and weights and multiply them in 16-bit. All other values are quantized to Int8 and multiplied in 8-bit before being dequantized back to 16-bit. The outputs from the 16-bit and 8-bit multiplications are combined to produce the final output.
+
+## Linear8bitLt
+
+[[autodoc]] bitsandbytes.nn.Linear8bitLt
+    - __init__
+
+## Int8Params
+
+[[autodoc]] bitsandbytes.nn.Int8Params
+    - __init__
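+
+The sketch below shows how a half-precision linear layer might be swapped for [`Linear8bitLt`]; the layer sizes and threshold are illustrative, and a CUDA device is assumed:
+
+```py
+import torch
+import bitsandbytes as bnb
+
+fp16_linear = torch.nn.Linear(64, 64)
+
+int8_linear = bnb.nn.Linear8bitLt(
+    64,
+    64,
+    has_fp16_weights=False,  # keep the weights in Int8 for inference
+    threshold=6.0,           # outlier threshold from the LLM.int8() paper
+)
+int8_linear.load_state_dict(fp16_linear.state_dict())
+int8_linear = int8_linear.to("cuda")  # the weights are quantized to Int8 when moved to the GPU
+
+x = torch.randn(8, 64, dtype=torch.float16, device="cuda")
+out = int8_linear(x)
+```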