Commit f61d8bc
update docstrings
matthewdouglas committed Nov 18, 2024
1 parent 4bced86 commit f61d8bc
Showing 1 changed file with 112 additions and 3 deletions.
115 changes: 112 additions & 3 deletions bitsandbytes/functional.py
@@ -437,6 +437,20 @@ def get_special_format_str():


def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]):
"""Verifies that the input tensors are all on the same device.
An input tensor may also be marked as `paged`, in which case the device placement is ignored.
Args:
tensors (Iterable[Optional[torch.Tensor]]): A list of tensors to verify.
Raises:
`RuntimeError`: Raised when the verification fails.
Returns:
`Literal[True]`
"""

on_gpu = True
gpu_ids = set()
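
A minimal usage sketch for the behavior documented above (shapes are illustrative; assumes a CUDA device is available):

import torch
import bitsandbytes.functional as F

a = torch.randn(4, 4, device="cuda")
b = torch.randn(4, 4, device="cuda")

F.is_on_gpu([a, b])          # True: both tensors share a device
# F.is_on_gpu([a, b.cpu()])  # would raise RuntimeError: operands span devices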

@@ -1199,7 +1213,7 @@ def quantize_4bit(

with _cuda_device_of(A):
args = (
- get_ptr(None),
+ None,
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
@@ -1346,7 +1360,7 @@ def dequantize_4bit(

with _cuda_device_of(A):
args = (
- get_ptr(None),
+ None,
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
@@ -2255,6 +2269,25 @@ def igemmlt(


def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
"""Performs an 8-bit integer matrix multiplication.
A linear transformation is applied such that `out = A @ B.T`. When possible, integer tensor core hardware is
utilized to accelerate the operation.
Args:
A (`torch.Tensor`): The first matrix operand with the data type `torch.int8`.
B (`torch.Tensor`): The second matrix operand with the data type `torch.int8`.
out (`torch.Tensor, *optional*): A pre-allocated tensor used to store the result.
dtype (`torch.dtype`, *optional*): The expected data type of the output. Defaults to `torch.int32`.
Raises:
`NotImplementedError`: The operation is not supported in the current environment.
`RuntimeError`: Raised when the cannot be completed for any other reason.
Returns:
`torch.Tensor`: The result of the operation.
"""

#
# To use the IMMA tensor core kernels without special Turing/Ampere layouts,
# cublasLt has some rules, namely: A must be transposed, B must not be transposed.
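
A sketch of calling the function as documented above (shapes are illustrative; assumes a CUDA device with int8 kernel support):

import torch
import bitsandbytes.functional as F

A = torch.randint(-128, 127, (16, 64), dtype=torch.int8, device="cuda")
B = torch.randint(-128, 127, (32, 64), dtype=torch.int8, device="cuda")

out = F.int8_linear_matmul(A, B)  # int32 tensor of shape (16, 32), out = A @ B.T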
@@ -2336,6 +2369,19 @@ def int8_mm_dequant(
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
):
"""Performs dequantization on the result of a quantized int8 matrix multiplication.
Args:
A (`torch.Tensor` with dtype `torch.int32`): The result of a quantized int8 matrix multiplication.
row_stats (`torch.Tensor`): The row-wise quantization statistics for the lhs operand of the matrix multiplication.
col_stats (`torch.Tensor`): The column-wise quantization statistics for the rhs operand of the matrix multiplication.
out (`torch.Tensor], *optional*): A pre-allocated tensor to store the output of the operation.
bias (`torch.Tensor`, *optional*): An optional bias vector to add to the result.
Returns:
`torch.Tensor`: The dequantized result with an optional bias, with dtype `torch.float16`.
"""

assert A.dtype == torch.int32

if bias is not None:
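
A sketch of how the pieces documented in this diff might fit together for an int8 linear layer (the wiring and shapes are illustrative; `int8_vectorwise_quant` is documented further down):

import torch
import bitsandbytes.functional as F

x = torch.randn(16, 64, dtype=torch.float16, device="cuda")
w = torch.randn(32, 64, dtype=torch.float16, device="cuda")

x_q, x_scales, _ = F.int8_vectorwise_quant(x)  # row-wise int8 quantization
w_q, w_scales, _ = F.int8_vectorwise_quant(w)

acc = F.int8_linear_matmul(x_q, w_q)            # int32 accumulator
y = F.int8_mm_dequant(acc, x_scales, w_scales)  # float16, approximately x @ w.T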
@@ -2409,6 +2455,20 @@ def get_colrow_absmax(


def get_row_absmax(A: torch.Tensor, threshold=0.0):
"""Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
Args:
A (`torch.Tensor` with dtype `torch.float16`): The input matrix.
threshold (`float`, *optional*):
An optional threshold for sparse decomposition of outlier features.
No outliers are held back when 0.0. Defaults to 0.0.
Returns:
`torch.Tensor` with dtype `torch.float32`: The absolute maximum value for each row, with outliers ignored.
"""

assert A.dtype == torch.float16

rows = prod(A.shape[:-1])
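
A short sketch of the statistics computation described above (shapes are illustrative):

import torch
import bitsandbytes.functional as F

A = torch.randn(16, 64, dtype=torch.float16, device="cuda")
row_stats = F.get_row_absmax(A, threshold=0.0)  # float32, one absmax per row of A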
@@ -2520,6 +2580,37 @@ def double_quant(
out_row: Optional[torch.Tensor] = None,
threshold=0.0,
):
"""Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
The statistics are determined both row-wise and column-wise (transposed).
For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
<Tip>
This function is useful for training, but for inference it is advised to use [`int8_vectorwise_quant`] instead.
This implementation performs additional column-wise transposed calculations which are not optimized.
</Tip>
Args:
A (`torch.Tensor` with dtype `torch.float16`): The input matrix.
col_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantization scales.
row_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantization scales.
out_col (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantized data.
out_row (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantized data.
threshold (`float`, *optional*):
An optional threshold for sparse decomposition of outlier features.
No outliers are held back when 0.0. Defaults to 0.0.
Returns:
`Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing the quantized tensor and relevant statistics.
- `torch.Tensor` with dtype `torch.int8`: The row-wise quantized data.
- `torch.Tensor` with dtype `torch.int8`: The column-wise quantized data.
- `torch.Tensor` with dtype `torch.float32`: The row-wise quantization scales.
- `torch.Tensor` with dtype `torch.float32`: The column-wise quantization scales.
- `torch.Tensor` with dtype `torch.int32`, *optional*: A list of column indices which contain outlier features.
"""

# TODO: Optimize/write CUDA kernel for this?
# Note: for inference, use the new int8_vectorwise_quant.
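
A usage sketch matching the documented return order (the variable names and threshold value are illustrative; 6.0 is the outlier threshold suggested in the LLM.int8() paper):

import torch
import bitsandbytes.functional as F

A = torch.randn(16, 64, dtype=torch.float16, device="cuda")

out_row, out_col, row_stats, col_stats, outlier_cols = F.double_quant(A, threshold=6.0)
# out_row/out_col: int8 row-wise and column-wise (transposed) quantized data
# row_stats/col_stats: float32 scales; outlier_cols: int32 indices or None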

@@ -2541,6 +2632,24 @@ def double_quant(


def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
"""Quantizes a tensor with dtype `torch.float16` to `torch.int8` in accordance to the `LLM.int8()` algorithm.
For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
Args:
A (`torch.Tensor` with dtype `torch.float16`): The input tensor.
threshold (`float`, *optional*):
An optional threshold for sparse decomposition of outlier features.
No outliers are held back when 0.0. Defaults to 0.0.
Returns:
`Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing the quantized tensor and relevant statistics.
- `torch.Tensor` with dtype `torch.int8`: The quantized data.
- `torch.Tensor` with dtype `torch.float32`: The quantization scales.
- `torch.Tensor` with dtype `torch.int32`, *optional*: A list of column indices which contain outlier features.
"""

assert A.dtype == torch.half
is_on_gpu([A])
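
A minimal sketch of the call (threshold value is illustrative):

import torch
import bitsandbytes.functional as F

A = torch.randn(16, 64, dtype=torch.float16, device="cuda")

A_q, scales, outlier_cols = F.int8_vectorwise_quant(A, threshold=6.0)
# A_q: int8 data; scales: float32 row-wise statistics
# outlier_cols: column indices held out when threshold > 0.0, else None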

@@ -2838,7 +2947,7 @@ def vectorwise_dequant(xq, max1, quant_type="vector"):


@deprecated(
"This function is deprecated and will be removed in a future release. Consider using `mm_dequant` instead.",
"This function is deprecated and will be removed in a future release. Consider using `int8_mm_dequant` instead.",
category=FutureWarning,
)
def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
