diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 378b3941f..fc2e6651e 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2445,9 +2445,9 @@ def mm_dequant( def get_colrow_absmax( A: torch.Tensor, - row_stats: torch.Tensor = None, - col_stats: torch.Tensor = None, - nnz_block_ptr: torch.Tensor = None, + row_stats: Optional[torch.Tensor] = None, + col_stats: Optional[torch.Tensor] = None, + nnz_block_ptr: Optional[torch.Tensor] = None, threshold=0.0, ): # Note: prior impl only works with fp16