Skip to content

Commit

Permalink
int8 matmul fallback for inner dims not divisible by 4
Browse files: browse the repository at this point in the history
  • Loading branch information
matthewdouglas committed Nov 25, 2024
1 parent df941ec commit 73f02e8
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
15 changes: 12 additions & 3 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2314,9 +2314,6 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten

shapeC = (*shapeB[:-1], shapeA[0])

if out is None:
out = torch.empty(shapeC, device=A.device, dtype=dtype)

k, m = shapeA
n = prod(shapeB[:-1])
lda = shapeA[-1] # Weights (outputs, inputs)
Expand All @@ -2327,6 +2324,18 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten
lda == ldb
), f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}"

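# cuBLASLt does not support int8 matmul when the inner dimension is not divisible by 4.
# In that case we fall back to a slower fp32 matmul and cast the result to int32.
# Fortunately, this should not be very common.
# NOTE(review): fp32 represents integers exactly only up to 2**24, so a very large inner
# dimension could in principle lose precision in this fallback - confirm this is acceptable.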
if lda % 4 != 0:
result = torch.matmul(B.float(), A.float().t()).to(torch.int32)
if out is not None:
result = out.copy_(result)
return result

if out is None:
out = torch.empty(shapeC, device=A.device, dtype=dtype)

is_on_gpu([A, B, out])

with _cuda_device_of(A):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):

@pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [512], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim3", [499, 512], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", [512], ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims"))
@pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb"))
Expand Down

0 comments on commit 73f02e8

Please sign in to comment.