diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index f9ccdc2e1..7d7547130 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2429,7 +2429,32 @@ def get_colrow_absmax( nnz_block_ptr: Optional[torch.Tensor] = None, threshold=0.0, ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - # Note: prior impl only works with fp16 + """Determine the quantization statistics for input matrix `A` in accordance with the `LLM.int8()` algorithm. + + The row-wise and column-wise absmax values are determined. + + For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339). + + + This function is useful for training, but for inference it is advised to use [`get_row_absmax`] instead. + The column-wise quantization scales are not typically needed in inference scenarios. + + + Args: + A (`torch.Tensor` with dtype `torch.float16`): Input tensor. + row_stats (`torch.Tensor`, *optional*): If provided, calculation of row statistics is skipped. + col_stats (`torch.Tensor`, *optional*): If provided, calculation of column statistics is skipped. + nnz_block_ptr (`torch.Tensor`, *optional*): Not used. + threshold (`float`, *optional*): + An optional threshold for sparse decomposition of outlier features. + No outliers are held back when 0.0. Defaults to 0.0. + + Returns: + `Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing quantization statistics. + - `torch.Tensor` with dtype `torch.float32`: The row-wise quantization statistics. + - `torch.Tensor` with dtype `torch.float32`: The column-wise quantization statistics. + - `torch.Tensor` with dtype `torch.bool`, *optional*: A mask indicating the locations of outliers in the input tensor. + """ assert A.is_floating_point() outlier_mask = None