Commit

doc cleanup

matthewdouglas committed Dec 4, 2024
1 parent 582bf22 commit 1ae7c6b

Showing 2 changed files with 100 additions and 96 deletions.
194 changes: 99 additions & 95 deletions bitsandbytes/functional.py
@@ -483,17 +483,13 @@ def _get_tensor_stream(tensor: Tensor) -> ct.c_void_p:


def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]:
"""
Get the ctypes pointer from a PyTorch Tensor.
"""Gets the memory address of the first element of a tenso
Parameters
----------
A : torch.tensor
The PyTorch tensor.
Args:
A (`Optional[Tensor]`): A PyTorch tensor.
Returns
-------
ctypes.c_void_p
Returns:
`Optional[ct.c_void_p]`: A pointer to the underlying tensor data.
"""
if A is None:
return None
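
For reference, a minimal sketch of the function documented in this hunk (CPU tensor; `get_ptr` wraps `Tensor.data_ptr()` in a `ctypes.c_void_p`):

import torch
from bitsandbytes.functional import get_ptr

t = torch.ones(4)
p = get_ptr(t)
print(p.value == t.data_ptr())  # True: the pointer addresses the first element
print(get_ptr(None))            # None is passed through unchanged
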
@@ -863,30 +859,31 @@ def quantize_blockwise(
blocksize=4096,
nested=False,
) -> Tuple[torch.Tensor, QuantState]:
"""
Quantize tensor A in blocks of size 4096 values.
"""Quantize a tensor in blocks of values.
Quantizes tensor A by dividing it into blocks of 4096 values.
Then the absolute maximum value within these blocks is calculated
for the non-linear quantization.
The input tensor is quantized by dividing it into blocks of `blocksize` values.
The absolute maximum value within these blocks is calculated for scaling
the non-linear quantization.
Parameters
----------
A : torch.Tensor
The input tensor.
code : torch.Tensor
The quantization map.
absmax : torch.Tensor
The absmax values.
out : torch.Tensor
The output tensor (8-bit).
Args:
A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes.
code (`torch.Tensor`, *optional*):
A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type.
For more details, see [8-Bit Approximations for Parallelism in Deep Learning](https://arxiv.org/abs/1511.04561).
absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values.
out (`torch.Tensor`, *optional*): A tensor to use to store the result.
blocksize (`int`, *optional*):
The size of the blocks. Defaults to 4096.
Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
nested (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False.
Returns
-------
torch.Tensor:
The 8-bit tensor.
tuple(torch.Tensor, torch.Tensor):
The quantization state to undo the quantization.
Raises:
ValueError: Raised when the input data type is not supported.
Returns:
`Tuple[torch.Tensor, QuantState]`: A tuple containing the quantization results.
- `torch.Tensor`: The quantized tensor.
- [`QuantState`]: The state object used to undo the quantization.
"""

if code is None:
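
A minimal usage sketch of the API documented above (assuming a CUDA device; the shape and blocksize are illustrative):

import torch
import bitsandbytes.functional as F

A = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
A8, state = F.quantize_blockwise(A, blocksize=256)
print(A8.dtype, A8.shape)    # torch.uint8, one quantized byte per input element
print(state.absmax.shape)    # one absmax scaling value per 256-element block
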
@@ -967,31 +964,38 @@ def dequantize_blockwise(
blocksize: int = 4096,
nested=False,
) -> torch.Tensor:
"""
Dequantizes blockwise quantized values.
"""Dequantize a tensor in blocks of values.
Dequantizes the tensor A with maximum absolute values absmax in
blocks of size 4096.
The input tensor is dequantized by dividing it into blocks of `blocksize` values.
The absolute maximum value within these blocks is used for scaling
the non-linear dequantization.
Parameters
----------
A : torch.Tensor
The input 8-bit tensor.
quant_state : QuantState
Object with code, absmax and other quantization state components.
absmax : torch.Tensor
The absmax values.
code : torch.Tensor
The quantization map.
out : torch.Tensor
Dequantized output tensor (default: float32)
Args:
A (`torch.Tensor`): The quantized input tensor.
quant_state ([`QuantState`], *optional*):
The quantization state as returned by [`quantize_blockwise`].
Required if `absmax` is not provided.
absmax (`torch.Tensor`, *optional*):
A tensor containing the scaling values.
Required if `quant_state` is not provided and ignored otherwise.
code (`torch.Tensor`, *optional*):
A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type.
For more details, see [8-Bit Approximations for Parallelism in Deep Learning](https://arxiv.org/abs/1511.04561).
Ignored when `quant_state` is provided.
out (`torch.Tensor`, *optional*): A tensor to use to store the result.
blocksize (`int`, *optional*):
The size of the blocks. Defaults to 4096.
Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
Ignored when `quant_state` is provided.
Raises:
ValueError: Raised when the input data type is not supported.
Returns
-------
torch.Tensor:
Dequantized tensor (default: float32)
Returns:
`torch.Tensor`:
The dequantized tensor. The datatype is indicated by `quant_state.dtype` and defaults to `torch.float32`.
"""

assert quant_state is not None or absmax is not None
if code is None and quant_state is None:
if "dynamic" not in name2qmap:
@@ -1166,31 +1170,30 @@ def quantize_4bit(
quant_type="fp4",
quant_storage=torch.uint8,
) -> Tuple[torch.Tensor, QuantState]:
"""
Quantize tensor A in blocks of 4-bit values.
"""Quantize tensor A in blocks of 4-bit values.
Quantizes tensor A by dividing it into blocks which are independently quantized to FP4.
Quantizes tensor A by dividing it into blocks which are independently quantized.
Parameters
----------
A : torch.Tensor
The input tensor.
absmax : torch.Tensor
The absmax values.
out : torch.Tensor
The output tensor.
blocksize : int
The blocksize used in quantization.
quant_type : str
The 4-bit quantization data type {fp4, nf4}
Args:
A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes.
absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values.
out (`torch.Tensor`, *optional*): A tensor to use to store the result.
blocksize (`int`, *optional*):
The size of the blocks. Defaults to 64.
Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
compress_statistics (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False.
quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.
quant_storage (`torch.dtype`, *optional*): The dtype of the tensor used to store the result. Defaults to `torch.uint8`.
Returns
-------
torch.Tensor:
Tensor with packed 4-bit values.
tuple(torch.Tensor, torch.Size, torch.dtype, int):
The quantization state to undo the quantization.
Raises:
ValueError: Raised when the input data type is not supported.
Returns:
Tuple[`torch.Tensor`, `QuantState`]: A tuple containing the quantization results.
- `torch.Tensor`: The quantized tensor with packed 4-bit values.
- [`QuantState`]: The state object used to undo the quantization.
"""

if A.device.type != "cuda":
raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}")
if quant_type not in ["fp4", "nf4"]:
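
A minimal NF4 sketch of the call documented above (CUDA only, per the device check; shape and blocksize are illustrative):

import torch
import bitsandbytes.functional as F

W = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
W4, state = F.quantize_4bit(W, blocksize=64, quant_type="nf4")
print(W4.dtype, W4.numel())  # torch.uint8, two 4-bit values packed per byte
print(state.quant_type)      # "nf4"
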
@@ -1297,32 +1300,33 @@ def dequantize_4bit(
blocksize: int = 64,
quant_type="fp4",
) -> torch.Tensor:
"""
Dequantizes FP4 blockwise quantized values.
"""Dequantizes a packed 4-bit quantized tensor.
Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize.
The input tensor is dequantized by dividing it into blocks of `blocksize` values.
The absolute maximum value within these blocks is used for scaling
the non-linear dequantization.
Parameters
----------
A : torch.Tensor
The input tensor (packed 4-bit values).
quant_state : QuantState
object with quantisation stats, incl. absmax values, original tensor shape and original dtype.
absmax : torch.Tensor
The absmax values.
out : torch.Tensor
Dequantized output tensor.
blocksize : int
The blocksize used in quantization.
quant_type : str
The 4-bit quantization data type {fp4, nf4}
Args:
A (`torch.Tensor`): The quantized input tensor.
quant_state ([`QuantState`], *optional*):
The quantization state as returned by [`quantize_4bit`].
Required if `absmax` is not provided.
absmax (`torch.Tensor`, *optional*):
A tensor containing the scaling values.
Required if `quant_state` is not provided and ignored otherwise.
out (`torch.Tensor`, *optional*): A tensor to use to store the result.
blocksize (`int`, *optional*):
The size of the blocks. Defaults to 64.
Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.
Raises:
ValueError: Raised when the input data type or blocksize is not supported.
Returns
-------
torch.Tensor:
Dequantized tensor.
Returns:
`torch.Tensor`: The dequantized tensor.
"""

if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
raise ValueError(
f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]",
@@ -2277,7 +2281,7 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten
Args:
A (`torch.Tensor`): The first matrix operand with the data type `torch.int8`.
B (`torch.Tensor`): The second matrix operand with the data type `torch.int8`.
out (`torch.Tensor, *optional*): A pre-allocated tensor used to store the result.
out (`torch.Tensor`, *optional*): A pre-allocated tensor used to store the result.
dtype (`torch.dtype`, *optional*): The expected data type of the output. Defaults to `torch.int32`.
Raises:
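
A minimal sketch of the int8 matmul (assuming a CUDA device and that the second operand is laid out as (out_features, in_features), as for a linear layer; shapes are illustrative):

import torch
import bitsandbytes.functional as F

A = torch.randint(-128, 127, (16, 64), dtype=torch.int8, device="cuda")
B = torch.randint(-128, 127, (32, 64), dtype=torch.int8, device="cuda")
C = F.int8_linear_matmul(A, B)
print(C.dtype, C.shape)  # torch.int32 accumulators, shape (16, 32)
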
@@ -2384,7 +2388,7 @@ def int8_mm_dequant(
A (`torch.Tensor` with dtype `torch.int32`): The result of a quantized int8 matrix multiplication.
row_stats (`torch.Tensor`): The row-wise quantization statistics for the lhs operand of the matrix multiplication.
col_stats (`torch.Tensor`): The column-wise quantization statistics for the rhs operand of the matrix multiplication.
out (`torch.Tensor], *optional*): A pre-allocated tensor to store the output of the operation.
out (`torch.Tensor`, *optional*): A pre-allocated tensor to store the output of the operation.
bias (`torch.Tensor`, *optional*): An optional bias vector to add to the result.
Returns:
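
An end-to-end sketch of the int8 pipeline, assuming `int8_vectorwise_quant` from the same module is used to produce the int8 operands and their row-wise statistics (CUDA device, float16 inputs; shapes are illustrative):

import torch
import bitsandbytes.functional as F

x = torch.randn(16, 64, device="cuda", dtype=torch.float16)  # activations
w = torch.randn(32, 64, device="cuda", dtype=torch.float16)  # weight in (out, in) layout
x8, x_stats, _ = F.int8_vectorwise_quant(x)
w8, w_stats, _ = F.int8_vectorwise_quant(w)
c32 = F.int8_linear_matmul(x8, w8)             # int32 result of the quantized matmul
y = F.int8_mm_dequant(c32, x_stats, w_stats)   # rescaled back to floating point
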
@@ -2454,7 +2458,7 @@ def get_colrow_absmax(
row_stats (`torch.Tensor`, *optional*): If provided, calculation of row statistics is skipped.
col_stats (`torch.Tensor`, *optional*): If provided, calculation of column statistics is skipped.
nnz_block_ptr (`torch.Tensor`, *optional*): Not used.
threshold (`float`, `optional`):
threshold (`float`, *optional*):
An optional threshold for sparse decomposition of outlier features.
No outliers are held back when 0.0. Defaults to 0.0.
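
A hedged sketch of the statistics helper (assuming a three-value return of row-wise absmax, column-wise absmax, and an outlier mask that is only populated when `threshold` is greater than 0.0; shapes are illustrative):

import torch
import bitsandbytes.functional as F

A = torch.randn(16, 64, device="cuda", dtype=torch.float16)
row_stats, col_stats, outlier_mask = F.get_colrow_absmax(A, threshold=6.0)
print(row_stats.shape, col_stats.shape)  # per-row and per-column scaling statistics
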
2 changes: 1 addition & 1 deletion docs/source/_toctree.yml
@@ -59,7 +59,7 @@
- title: k-bit quantizers
sections:
- local: reference/nn/linear8bit
title: 8-bit quantizer
title: LLM.int8()
- local: reference/nn/linear4bit
title: 4-bit quantizer
- local: reference/nn/embeddings
