[docs] implement API docs (#1075)
* optims

* fix path

* fix path

* mdx

* fix path

* toctree

* fix

* optimizer, adagrad

* add init

* add

* more apis

* params

* clarify

* run pre-commit hooks

---------

Co-authored-by: Titus von Koeller <[email protected]>
stevhliu and Titus-von-Koeller authored Mar 7, 2024
1 parent 87e029b commit ac5d6ee
Showing 25 changed files with 1,389 additions and 44 deletions.
3 changes: 3 additions & 0 deletions .git-blame-ignore-revs
@@ -6,3 +6,6 @@ ea7c14f8ef64924f2d0ff80df3cdabf2c7299848

# Remove f-prefix from strings that don't use formatting
7727fa4c8c6c1ef2b109120aff4196a0a6bf3ed6

# format tests/linear_4bit.py
34735ba89de8235ea9da6ef409f814dcea9e2038
83 changes: 61 additions & 22 deletions bitsandbytes/nn/modules.py
@@ -21,16 +21,7 @@

class StableEmbedding(torch.nn.Embedding):
"""
Custom embedding layer designed for stable training in NLP tasks. The stable
embedding layer improves stability during optimization for models with word
embeddings, addressing issues related to the non-uniform distribution of input
tokens.
This stable embedding layer is initialized with Xavier uniform initialization,
followed by layer normalization. It is designed to support aggressive quantization,
addressing extreme gradient variations in non-uniform input distributions. The
stability of training is enhanced by using 32-bit optimizer states specifically
for this layer.
Custom embedding layer designed to improve the stability of training for NLP tasks. It keeps 32-bit optimizer states for this layer to reduce the gradient variations that can result from quantization and from the non-uniform distribution of input tokens. The embedding weights are initialized with Xavier uniform initialization, followed by layer normalization.
Example:
@@ -47,14 +38,11 @@ class StableEmbedding(torch.nn.Embedding):
```
Attributes:
norm (torch.nn.LayerNorm): Layer normalization applied after the embedding.
norm (`torch.nn.LayerNorm`): Layer normalization applied after the embedding.
Methods:
reset_parameters(): Reset embedding parameters using Xavier uniform initialization.
forward(input: Tensor) -> Tensor: Forward pass through the stable embedding layer.
Reference:
- [8-bit optimizer paper](https://arxiv.org/pdf/2110.02861.pdf)
"""
def __init__(
self,
Expand All @@ -71,14 +59,22 @@ def __init__(
) -> None:
"""
Args:
num_embeddings (`int`): The number of unique embeddings (vocabulary size).
embedding_dim (`int`): The dimensionality of the embedding.
padding_idx (`Optional[int]`): If specified, pads the output with zeros at the given index.
max_norm (`Optional[float]`): If given, renormalizes embeddings to have a maximum L2 norm.
norm_type (`float`, defaults to `2.0`): The p-norm to compute for the max_norm option.
scale_grad_by_freq (`bool`): Scale gradient by frequency during backpropagation.
sparse (`bool`): If True, computes sparse gradients; False, computes dense gradients.
_weight (`Optional[Tensor]`): Pre-trained embeddings.
num_embeddings (`int`):
The number of unique embeddings (vocabulary size).
embedding_dim (`int`):
The dimensionality of the embedding.
padding_idx (`Optional[int]`):
Pads the output with zeros at the given index.
max_norm (`Optional[float]`):
Renormalizes embeddings to have a maximum L2 norm.
norm_type (`float`, defaults to `2.0`):
The p-norm to compute for the `max_norm` option.
scale_grad_by_freq (`bool`, defaults to `False`):
Scale gradient by frequency during backpropagation.
sparse (`bool`, defaults to `False`):
Computes dense gradients. Set to `True` to compute sparse gradients instead.
_weight (`Optional[Tensor]`):
Pretrained embeddings.
"""
super().__init__(
num_embeddings,
@@ -131,6 +127,9 @@ def forward(self, input: Tensor) -> Tensor:


class Embedding(torch.nn.Embedding):
"""
Embedding class to store and retrieve word embeddings from their indices.
"""
def __init__(
self,
num_embeddings: int,
@@ -143,6 +142,25 @@ def __init__(
_weight: Optional[Tensor] = None,
device: Optional[device] = None,
) -> None:
"""
Args:
num_embeddings (`int`):
The number of unique embeddings (vocabulary size).
embedding_dim (`int`):
The dimensionality of the embedding.
padding_idx (`Optional[int]`):
Pads the output with zeros at the given index.
max_norm (`Optional[float]`):
Renormalizes embeddings to have a maximum L2 norm.
norm_type (`float`, defaults to `2.0`):
The p-norm to compute for the `max_norm` option.
scale_grad_by_freq (`bool`, defaults to `False`):
Scale gradient by frequency during backpropagation.
sparse (`bool`, defaults to `False`):
Computes dense gradients. Set to `True` to compute sparse gradients instead.
_weight (`Optional[Tensor]`):
Pretrained embeddings.
"""
super().__init__(
num_embeddings,
embedding_dim,
@@ -416,7 +434,19 @@ def forward(self, x: torch.Tensor):


class LinearFP4(Linear4bit):
"""
Linear layer that uses the FP4 (4-bit floating point) data type for its weights.
"""
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
"""
Args:
input_features (`int`):
Number of input features of the linear layer.
output_features (`int`):
Number of output features of the linear layer.
bias (`bool`, defaults to `True`):
Whether the linear layer uses a bias term.
"""
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', quant_storage, device)
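A hedged sketch of how an FP4 linear layer is typically created from an existing float layer (the feature sizes are illustrative, and a CUDA device is assumed because the weights are quantized when the layer is moved to the GPU):

```python
import torch
import bitsandbytes as bnb

fp16_linear = torch.nn.Linear(768, 3072)

fp4_linear = bnb.nn.LinearFP4(
    input_features=768,
    output_features=3072,
    compute_dtype=torch.float16,
)
fp4_linear.load_state_dict(fp16_linear.state_dict())
fp4_linear = fp4_linear.cuda()  # weights are quantized to FP4 on this move

x = torch.randn(1, 768, dtype=torch.float16, device="cuda")
y = fp4_linear(x)  # shape: (1, 3072)
```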


@@ -432,6 +462,15 @@ class LinearNF4(Linear4bit):
the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
'''
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
"""
Args:
input_features (`int`):
Number of input features of the linear layer.
output_features (`int`):
Number of output features of the linear layer.
bias (`bool`, defaults to `True`):
Whether the linear layer uses a bias term.
"""
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', quant_storage, device)
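The construction pattern mirrors `LinearFP4`, only with NF4 quantization; again, the sizes are illustrative and a CUDA device is assumed:

```python
import torch
import bitsandbytes as bnb

layer = bnb.nn.LinearNF4(768, 3072, compute_dtype=torch.bfloat16)
layer = layer.cuda()  # weights are quantized to NF4 when moved to the GPU

x = torch.randn(2, 768, dtype=torch.bfloat16, device="cuda")
y = layer(x)  # shape: (2, 3072)
```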


81 changes: 81 additions & 0 deletions bitsandbytes/optim/adagrad.py
@@ -20,6 +20,33 @@ def __init__(
percentile_clipping=100,
block_wise=True,
):
"""
Base Adagrad optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-2):
The learning rate.
lr_decay (`int`, defaults to 0):
The learning rate decay.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
initial_accumulator_value (`int`, defaults to 0):
The initial value for the gradient accumulator.
eps (`float`, defaults to 1e-10):
The epsilon value, which prevents division by zero in the optimizer.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Automatically adapts the clipping threshold by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
@@ -62,6 +89,33 @@ def __init__(
percentile_clipping=100,
block_wise=True,
):
"""
8-bit Adagrad optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-2):
The learning rate.
lr_decay (`int`, defaults to 0):
The learning rate decay.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
initial_accumulator_value (`int`, defaults to 0):
The initial value for the gradient accumulator.
eps (`float`, defaults to 1e-10):
The epsilon value, which prevents division by zero in the optimizer.
optim_bits (`int`, defaults to 8):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Automatically adapts the clipping threshold by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
@@ -105,6 +159,33 @@ def __init__(
percentile_clipping=100,
block_wise=True,
):
"""
32-bit Adagrad optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-2):
The learning rate.
lr_decay (`int`, defaults to 0):
The learning rate decay.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
initial_accumulator_value (`int`, defaults to 0):
The initial value for the gradient accumulator.
eps (`float`, defaults to 1e-10):
The epsilon value, which prevents division by zero in the optimizer.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Automatically adapts the clipping threshold by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay: