Commit

more apis
stevhliu committed Feb 29, 2024
1 parent 9f77a71 commit ead726e
Showing 5 changed files with 99 additions and 22 deletions.
62 changes: 40 additions & 22 deletions bitsandbytes/nn/modules.py
@@ -21,16 +21,7 @@

class StableEmbedding(torch.nn.Embedding):
"""
Custom embedding layer designed for stable training in NLP tasks. The stable
embedding layer improves stability during optimization for models with word
embeddings, addressing issues related to the non-uniform distribution of input
tokens.
This stable embedding layer is initialized with Xavier uniform initialization,
followed by layer normalization. It is designed to support aggressive quantization,
addressing extreme gradient variations in non-uniform input distributions. The
stability of training is enhanced by using 32-bit optimizer states specifically
for this layer.
Custom embedding layer designed to improve training stability for NLP tasks by using 32-bit optimizer states. It reduces the gradient variations that can result from quantization and from the non-uniform distribution of input tokens. This embedding layer is initialized with Xavier uniform initialization followed by layer normalization.
Example:
@@ -47,14 +38,11 @@ class StableEmbedding(torch.nn.Embedding):
```
Attributes:
norm (torch.nn.LayerNorm): Layer normalization applied after the embedding.
norm (`torch.nn.LayerNorm`): Layer normalization applied after the embedding.
Methods:
reset_parameters(): Reset embedding parameters using Xavier uniform initialization.
forward(input: Tensor) -> Tensor: Forward pass through the stable embedding layer.
Reference:
- [8-bit optimizer paper](https://arxiv.org/pdf/2110.02861.pdf)
"""
def __init__(
self,
@@ -71,14 +59,22 @@ def __init__(
) -> None:
"""
Args:
num_embeddings (`int`): The number of unique embeddings (vocabulary size).
embedding_dim (`int`): The dimensionality of the embedding.
padding_idx (`Optional[int]`): If specified, pads the output with zeros at the given index.
max_norm (`Optional[float]`): If given, renormalizes embeddings to have a maximum L2 norm.
norm_type (`float`, defaults to `2.0`): The p-norm to compute for the max_norm option.
scale_grad_by_freq (`bool`): Scale gradient by frequency during backpropagation.
sparse (`bool`): If True, computes sparse gradients; False, computes dense gradients.
_weight (`Optional[Tensor]`): Pre-trained embeddings.
num_embeddings (`int`):
The number of unique embeddings (vocabulary size).
embedding_dim (`int`):
The dimensionality of the embedding.
padding_idx (`Optional[int]`):
Pads the output with zeros at the given index.
max_norm (`Optional[float]`):
Renormalizes embeddings to have a maximum L2 norm.
norm_type (`float`, defaults to `2.0`):
The p-norm to compute for the `max_norm` option.
scale_grad_by_freq (`bool`, defaults to `False`):
Scale gradient by frequency during backpropagation.
sparse (`bool`, defaults to `False`):
Computes dense gradients. Set to `True` to compute sparse gradients instead.
_weight (`Optional[Tensor]`):
Pretrained embeddings.
"""
super().__init__(
num_embeddings,
@@ -131,6 +127,9 @@ def forward(self, input: Tensor) -> Tensor:


class Embedding(torch.nn.Embedding):
"""
Embedding class to store and retrieve word embeddings from their indices.
"""
def __init__(
self,
num_embeddings: int,
@@ -143,6 +142,25 @@ def __init__(
_weight: Optional[Tensor] = None,
device: Optional[device] = None,
) -> None:
"""
Args:
num_embeddings (`int`):
The number of unique embeddings (vocabulary size).
embedding_dim (`int`):
The dimensionality of the embedding.
padding_idx (`Optional[int]`):
Pads the output with zeros at the given index.
max_norm (`Optional[float]`):
Renormalizes embeddings to have a maximum L2 norm.
norm_type (`float`, defaults to `2.0`):
The p-norm to compute for the `max_norm` option.
scale_grad_by_freq (`bool`, defaults to `False`):
Scale gradient by frequency during backpropagation.
sparse (`bool`, defaults to `False`):
Computes dense gradients. Set to `True` to compute sparse gradients instead.
_weight (`Optional[Tensor]`):
Pretrained embeddings.
"""
super().__init__(
num_embeddings,
embedding_dim,
8 changes: 8 additions & 0 deletions docs/source/_toctree.yml
@@ -48,3 +48,11 @@
title: RMSprop
- local: reference/optim/sgd
title: SGD
- title: k-bit quantizers
sections:
- local: reference/nn/linear8bit
title: 8-bit quantizer
- local: reference/nn/linear4bit
title: 4-bit quantizer
- local: reference/nn/embeddings
title: Embedding
15 changes: 15 additions & 0 deletions docs/source/reference/nn/embeddings.mdx
@@ -0,0 +1,15 @@
# Embedding

The embedding class is used to store and retrieve word embeddings from their indices. There are two types of embeddings in bitsandbytes: the standard PyTorch [`Embedding`] class and the [`StableEmbedding`] class.

The [`StableEmbedding`] class was introduced in the [8-bit Optimizers via Block-wise Quantization](https://hf.co/papers/2110.02861) paper to reduce the gradient variance that results from the non-uniform distribution of input tokens. It is designed to support aggressive quantization by keeping 32-bit optimizer states for this layer.
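Both classes can be used as drop-in replacements for `torch.nn.Embedding`. A minimal sketch (the sizes below are illustrative):

```py
import torch
import bitsandbytes as bnb

# standard embedding lookup
emb = bnb.nn.Embedding(num_embeddings=1024, embedding_dim=64)

# stable variant: Xavier uniform initialization followed by layer normalization,
# intended to be paired with 32-bit optimizer states for this layer
stable_emb = bnb.nn.StableEmbedding(num_embeddings=1024, embedding_dim=64)

input_ids = torch.randint(0, 1024, (2, 16))  # a batch of token indices
hidden = stable_emb(input_ids)               # shape: (2, 16, 64)
```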

## Embedding

[[autodoc]] bitsandbytes.nn.Embedding
- __init__

## StableEmbedding

[[autodoc]] bitsandbytes.nn.StableEmbedding
- __init__
23 changes: 23 additions & 0 deletions docs/source/reference/nn/linear4bit.mdx
@@ -0,0 +1,23 @@
# 4-bit quantization

[QLoRA](https://hf.co/papers/2305.14314) is a finetuning method that quantizes a model to 4-bit and adds a set of low-rank adaptation (LoRA) weights that are tuned by backpropagating gradients through the quantized weights. This method also introduces a new data type, 4-bit NormalFloat (`LinearNF4`), in addition to the standard Float4 data type (`LinearFP4`). `LinearNF4` is adapted for weights initialized from a normal distribution and can improve performance.
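A rough usage sketch is shown below (it assumes a CUDA device is available and keeps the constructor arguments to a minimum; `LinearNF4` is used here, but `LinearFP4` takes the same positional arguments):

```py
import torch
import bitsandbytes as bnb

# a 4-bit linear layer used in place of torch.nn.Linear(64, 64)
layer = bnb.nn.LinearNF4(64, 64)
layer = layer.to("cuda")  # the weights are quantized to 4-bit when moved to the GPU

x = torch.randn(8, 64, device="cuda", dtype=torch.float16)
out = layer(x)  # the forward pass runs against the quantized weights
```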

## Linear4bit

[[autodoc]] bitsandbytes.nn.Linear4bit
- __init__

## LinearFP4

[[autodoc]] bitsandbytes.nn.LinearFP4
- __init__

## LinearNF4

[[autodoc]] bitsandbytes.nn.LinearNF4
- __init__

## Params4bit

[[autodoc]] bitsandbytes.nn.Params4bit
- __init__
13 changes: 13 additions & 0 deletions docs/source/reference/nn/linear8bit.mdx
@@ -0,0 +1,13 @@
# 8-bit quantization

[LLM.int8()](https://hf.co/papers/2208.07339) is a quantization method that doesn't degrade performance, which makes large model inference more accessible. The key is to extract the outliers from the inputs and weights and multiply them in 16-bit. All other values are quantized to Int8 and multiplied in 8-bit before being dequantized back to 16-bit. The outputs from the 16-bit and 8-bit multiplications are combined to produce the final output.
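A rough usage sketch is shown below (it assumes a CUDA device is available; `has_fp16_weights=False` keeps the weights in Int8 for inference, and `threshold` controls which magnitudes are treated as outliers):

```py
import torch
import bitsandbytes as bnb

# an 8-bit linear layer used in place of torch.nn.Linear(64, 64)
layer = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False, threshold=6.0)
layer = layer.to("cuda")  # the weights are quantized to Int8 when moved to the GPU

x = torch.randn(8, 64, device="cuda", dtype=torch.float16)
out = layer(x)  # outlier features are handled in 16-bit, the rest in 8-bit
```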

## Linear8bitLt

[[autodoc]] bitsandbytes.nn.Linear8bitLt
- __init__

## Int8Params

[[autodoc]] bitsandbytes.nn.Int8Params
- __init__
