int8 sparse decomp: small perf improvement
matthewdouglas committed Oct 30, 2024
1 parent 32979b4 commit 521da0c
Showing 3 changed files with 16 additions and 21 deletions.
8 changes: 4 additions & 4 deletions bitsandbytes/autograd/_functions.py
@@ -319,10 +319,10 @@ def forward(
 
         if ctx.needs_input_grad[1]:
             # Slower path
-            CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
+            CA, CAt, SCA, SCAt, outlier_cols = F.double_quant(A.to(torch.float16), threshold=state.threshold)
         else:
             # Fast path
-            CA, SCA, coo_tensorA = F.int8_vectorwise_quant(A.to(torch.float16), threshold=state.threshold)
+            CA, SCA, outlier_cols = F.int8_vectorwise_quant(A.to(torch.float16), threshold=state.threshold)
             CAt = SCAt = None
 
         has_grad = False
@@ -339,8 +339,8 @@ def forward(
                 # 2. Quantize B
                 state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))
 
-            if state.threshold > 0.0 and coo_tensorA is not None:
-                state.idx = torch.unique(coo_tensorA._indices()[1]).long()
+            if state.threshold > 0.0 and outlier_cols is not None:
+                state.idx = outlier_cols
 
                 # Zero out the outliers in the transposed 8bit inputs.
                 if CAt is not None:
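For reference, the old and new ways of deriving the outlier column indices are equivalent. A small standalone sanity check (illustrative only, not part of the commit; the synthetic tensor and the usual LLM.int8() threshold of 6.0 are assumptions) shows why `state.idx` can now be assigned directly:

# Illustrative sanity check (not part of the repository): the COO-based
# recovery of outlier columns and the new direct computation agree.
import torch

A = torch.randn(4, 8)
A[:, 3] = 12.0            # inject one outlier column
threshold = 6.0           # typical LLM.int8() outlier threshold

# Old: materialize a sparse COO tensor of the outlier values, then deduplicate columns.
coo_tensorA = torch.masked_fill(A, A.abs() < threshold, 0.0).to_sparse_coo()
idx_old = torch.unique(coo_tensorA._indices()[1]).long()

# New: find columns containing at least one outlier, no sparse tensor needed.
outlier_cols = torch.argwhere((A.abs() >= threshold).any(dim=0)).view(-1)

assert torch.equal(idx_old, outlier_cols)   # both are tensor([3])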
14 changes: 5 additions & 9 deletions bitsandbytes/functional.py
@@ -2546,7 +2546,7 @@ def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None,
     # Note: for inference, use the new int8_vectorwise_quant.
 
     # Use CUDA kernel for rowwise and COO tensor
-    quant_row, row_stats, coo_tensor = int8_vectorwise_quant(A, threshold=threshold)
+    quant_row, row_stats, outlier_cols = int8_vectorwise_quant(A, threshold=threshold)
 
     # PyTorch impl for colwise
     _, col_stats, outlier_mask = get_colrow_absmax(A, threshold=threshold)
@@ -2559,7 +2559,7 @@ def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None,
     if out_col is not None:
         quant_col = out_col.copy_(quant_col)
 
-    return quant_row, quant_col, row_stats, col_stats.flatten().float(), coo_tensor
+    return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols
 
 
 def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
@@ -2574,13 +2574,9 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
 
     if threshold > 0.0:
         # TODO we could improve perf of this
-
-        # A.masked_fill(A.abs() < threshold, 0.0).to_sparse_coo()
-        # coo_tensor = extract_outliers_new(A, threshold)
-        coo_tensor = torch.masked_fill(A, A.abs() < threshold, 0.0).to_sparse_coo()
-
+        outlier_cols = torch.argwhere((A.abs() >= threshold).any(dim=0)).view(-1)
     else:
-        coo_tensor = None
+        outlier_cols = None
 
     with torch.cuda.device_of(A):
         lib.cint8_vector_quant(
@@ -2593,7 +2589,7 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
             get_tensor_stream(A),
         )
 
-    return out_row, row_stats, coo_tensor
+    return out_row, row_stats, outlier_cols
 
 
 @deprecated(
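The core of the speedup is in int8_vectorwise_quant: instead of building a sparse COO copy of A that holds every outlier value, it now only records which columns contain an outlier. A usage sketch of the updated return signature (illustrative only; assumes a CUDA device, a synthetic input, and an explicitly chosen 6.0 threshold):

# Usage sketch for the updated int8_vectorwise_quant (illustrative; requires CUDA).
import torch
import bitsandbytes.functional as F

A = torch.randn(16, 64, dtype=torch.float16, device="cuda")
A[:, 5] = 8.0  # synthetic outlier column above the 6.0 threshold

CA, SCA, outlier_cols = F.int8_vectorwise_quant(A, threshold=6.0)
# CA:           int8 row-wise quantized tensor with the same shape as A
# SCA:          per-row absmax statistics used for dequantization
# outlier_cols: 1-D int64 tensor of outlier column indices, here tensor([5], device='cuda:0')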
15 changes: 7 additions & 8 deletions bitsandbytes/research/autograd/_functions.py
@@ -217,12 +217,11 @@ def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None):
         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.view(-1, A.shape[-1]).contiguous()
-        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
+        CA, CAt, SCA, SCAt, outlier_cols = F.double_quant(A.to(torch.float16), threshold=state.threshold)
 
-        if state.threshold > 0.0 and coo_tensorA is not None:
+        if state.threshold > 0.0 and outlier_cols is not None:
             if state.has_fp16_weights:
-                # idx = torch.unique(coo_tensorA.colidx).long()
-                idx = torch.unique(coo_tensorA._indices()[1]).long()
+                idx = outlier_cols
                 CA[:, idx] = 0
                 # CAt[:, idx] = 0
                 subA = A[:, idx]
@@ -257,9 +256,9 @@ def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None):
         else:
             has_grad = False
 
-        if coo_tensorA is not None and not state.has_fp16_weights:
+        if outlier_cols is not None and not state.has_fp16_weights:
             # extract outliers
-            state.idx = torch.unique(coo_tensorA._indices()[1]).long()
+            state.idx = outlier_cols
 
             # outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
             outliers = state.CB[:, state.idx.long()].clone()
@@ -287,7 +286,7 @@ def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None):
             output = output.to(A.dtype).add_(bias)
 
         # 4. Mixed-precision decomposition matmul
-        if coo_tensorA is not None and subA is not None:
+        if outlier_cols is not None and subA is not None:
             output += torch.matmul(subA, state.subB)
 
         # 5. Save state
@@ -327,7 +326,7 @@ def backward(ctx, grad_output):
         if len(grad_output.shape) == 3:
            grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
 
-        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
+        Cgrad, Cgradt, SCgrad, SCgradt, outlier_cols = F.double_quant(grad_output.to(torch.float16))
 
         if req_gradB:
             # print('back A shape', A.shape)
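The downstream mixed-precision decomposition is unchanged: outlier columns of A stay in fp16 and are multiplied against the matching weight slice, while everything else goes through the int8 path. A minimal pure-PyTorch sketch of that decomposition (mathematical form only; the decomposed_matmul helper below is hypothetical and keeps the "base" part in full precision instead of int8):

# Minimal sketch of the LLM.int8()-style decomposition (illustration only).
import torch

def decomposed_matmul(A: torch.Tensor, B: torch.Tensor, threshold: float = 6.0) -> torch.Tensor:
    # A: (m, k) activations, B: (k, n) weights.
    outlier_cols = torch.argwhere((A.abs() >= threshold).any(dim=0)).view(-1)

    A_base = A.clone()
    A_base[:, outlier_cols] = 0       # part that bitsandbytes runs through int8 kernels
    subA = A[:, outlier_cols]         # outlier activations, kept in higher precision
    subB = B[outlier_cols, :]         # matching weight rows, kept in higher precision

    return A_base @ B + subA @ subB   # exact decomposition of A @ B

A = torch.randn(4, 8)
B = torch.randn(8, 3)
A[:, 2] = 10.0                        # synthetic outlier column
torch.testing.assert_close(decomposed_matmul(A, B), A @ B)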
