From 73f02e864838f0f52dfe763d26240aa4a7e73e65 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 25 Nov 2024 12:33:42 -0500 Subject: [PATCH] int8 matmul fallback for inner dims not divisible by 4 --- bitsandbytes/functional.py | 15 ++++++++++++--- tests/test_functional.py | 2 +- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index d79c39c41..11874b739 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2314,9 +2314,6 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten shapeC = (*shapeB[:-1], shapeA[0]) - if out is None: - out = torch.empty(shapeC, device=A.device, dtype=dtype) - k, m = shapeA n = prod(shapeB[:-1]) lda = shapeA[-1] # Weights (outputs, inputs) @@ -2327,6 +2324,18 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten lda == ldb ), f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}" + # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. + # We'll fall back to a slower fp32 calculation in this circumstance. + # Fortunately, this should not be very common. + if lda % 4 != 0: + result = torch.matmul(B.float(), A.float().t()).to(torch.int32) + if out is not None: + result = out.copy_(result) + return result + + if out is None: + out = torch.empty(shapeC, device=A.device, dtype=dtype) + is_on_gpu([A, B, out]) with _cuda_device_of(A): diff --git a/tests/test_functional.py b/tests/test_functional.py index 3adeb1a96..20375a02e 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -427,7 +427,7 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose): @pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2")) -@pytest.mark.parametrize("dim3", [512], ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim3", [499, 512], ids=id_formatter("dim3")) @pytest.mark.parametrize("dim4", [512], ids=id_formatter("dim4")) @pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims")) @pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb"))