Type annotations, cleanup
matthewdouglas committed Nov 5, 2024
1 parent b5d6135 commit b1c4adc
Showing 5 changed files with 31 additions and 49 deletions.
37 changes: 20 additions & 17 deletions bitsandbytes/autograd/_functions.py
@@ -244,25 +244,26 @@ def get_tile_inds(format, device):
@dataclass
class MatmulLtState:
_tile_indices: Optional[torch.Tensor] = None

force_no_igemmlt: bool = False
CB = None
CxB = None # TODO: Deprecate/remove
SB = None
SCB = None

CxBt = None # TODO: Deprecate/remove
SBt = None
CBt = None
CB: Optional[torch.Tensor] = None
CxB: Optional[torch.Tensor] = None # TODO: Deprecate/remove
SB: Optional[torch.Tensor] = None
SCB: Optional[torch.Tensor] = None

CxBt: Optional[torch.Tensor] = None # TODO: Deprecate/remove
SBt: Optional[torch.Tensor] = None
CBt: Optional[torch.Tensor] = None

subB = None
subB: Optional[torch.Tensor] = None

outlier_pool = None
outlier_pool: Optional[GlobalOutlierPooler] = None
has_accumulated_gradients = False
threshold = 0.0
idx = None
idx: Optional[torch.Tensor] = None
is_training = True
has_fp16_weights = True
memory_efficient_backward = False
use_pool = False
formatB = "row" # TODO: Deprecate/remove

@@ -313,10 +314,10 @@ def forward(
if A.dtype != torch.float16:
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")

# 1. Quantize A. Note that as a side-effect, outliers are suppressed.
if len(A.shape) == 3:
A = A.reshape(-1, A.shape[-1])

# 1. Quantize A. Note that as a side-effect, outliers are suppressed in CA/CAt.
if ctx.needs_input_grad[1]:
# Slower path
CA, CAt, SCA, SCAt, outlier_cols = F.double_quant(A.to(torch.float16), threshold=state.threshold)
@@ -366,6 +367,8 @@ def forward(

# 3. Int8 Matmul
out32 = F.int8_linear_matmul(CA, state.CB)

# Dequantize matmul result
if bias is None or bias.dtype == torch.float16:
# we apply the fused bias here
output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=bias).to(A.dtype)
@@ -375,7 +378,7 @@ def forward(

# 4. Mixed-precision decomposition matmul
if subA is not None and state.subB is not None:
output += torch.matmul(subA, state.subB.to(subA.dtype))
output += torch.matmul(subA, state.subB)

# 5. Save state
ctx.state = state
@@ -399,15 +402,15 @@ def forward(
return output

@staticmethod
def backward(ctx, grad_output):
def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor):
if ctx.is_empty:
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None

req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
CAt, subA, A = ctx.tensors
SCAt, idx = ctx.tensor_states
state = ctx.state
state: MatmulLtState = ctx.state
grad_A = grad_B = grad_bias = None

if req_gradBias:
@@ -499,7 +502,7 @@ def matmul(
out: Optional[torch.Tensor] = None,
state: Optional[MatmulLtState] = None,
threshold=0.0,
bias=None,
bias: Optional[torch.Tensor] = None,
):
state = state or MatmulLtState()
if threshold > 0.0:
@@ -512,7 +515,7 @@ def matmul_4bit(
B: torch.Tensor,
quant_state: F.QuantState,
out: Optional[torch.Tensor] = None,
bias=None,
bias: Optional[torch.Tensor] = None,
):
assert quant_state is not None

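The MatMul8bitLt pieces annotated above follow the usual absmax int8 flow: quantize the fp16 activations row-wise, run the int8 matmul against the quantized weight `CB`, then dequantize with the row/column scales (`SCA`/`SCB`). A minimal pure-PyTorch sketch of that flow, for orientation only — `quantize_rowwise` and `int8_matmul_dequant_reference` are hypothetical names, and the real kernels differ in layout, fusion, and outlier handling:

```python
import torch

def quantize_rowwise(x: torch.Tensor):
    # Per-row absmax quantization to int8, returning the int8 tensor and the scales.
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8)
    return torch.round(x * (127.0 / absmax)).to(torch.int8), absmax

def int8_matmul_dequant_reference(A: torch.Tensor, W: torch.Tensor, bias=None):
    # A: (tokens, in_features) activations; W: (out_features, in_features) weights.
    CA, SCA = quantize_rowwise(A)   # per-token scales
    CB, SCB = quantize_rowwise(W)   # per-output-channel scales
    # int8 x int8 with int32 accumulation (emulated here with an int32 matmul on CPU).
    out32 = CA.to(torch.int32) @ CB.to(torch.int32).t()
    # Undo both quantization scales (each side contributed a factor of 127 / absmax).
    out = out32.to(torch.float32) * (SCA * SCB.t()) / (127.0 * 127.0)
    if bias is not None:
        out = out + bias
    return out.to(A.dtype)

A = torch.randn(4, 64)
W = torch.randn(8, 64)
print(int8_matmul_dequant_reference(A, W).shape)  # torch.Size([4, 8])
```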
18 changes: 0 additions & 18 deletions bitsandbytes/cextension.py
@@ -1,21 +1,3 @@
"""
extract factors the build is dependent on:
[X] compute capability
[ ] TODO: Q - What if we have multiple GPUs of different makes?
- CUDA version
- Software:
- CPU-only: only CPU quantization functions (no optimizer, no matrix multiple)
- CuBLAS-LT: full-build 8-bit optimizer
- no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)
evaluation:
- if paths faulty, return meaningful error
- else:
- determine CUDA version
- determine capabilities
- based on that set the default path
"""

import ctypes as ct
import logging
import os
12 changes: 7 additions & 5 deletions bitsandbytes/functional.py
@@ -2279,7 +2279,9 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten
ldb = shapeB[-1] # Activations (batch, tokens, inputs)
ldc = shapeC[-1] # Output (batch, tokens, outputs)

assert lda == ldb, f"igemmlt only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}"
assert (
lda == ldb
), f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}"

is_on_gpu([A, B, out])

@@ -2288,7 +2290,7 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten
ptrA = get_ptr(A)
ptrB = get_ptr(B)
ptrC = get_ptr(out)
ptrRowScale = get_ptr(None)
ptrRowScale = None
m = ct.c_int32(m)
n = ct.c_int32(n)
k = ct.c_int32(k)
@@ -2303,7 +2305,7 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten
has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)

if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
raise NotImplementedError("igemmlt not implemented!")
raise NotImplementedError("int8_linear_matmul not implemented!")

if has_error:
raise RuntimeError(
@@ -2369,7 +2371,7 @@ def get_colrow_absmax(
col_stats: Optional[torch.Tensor] = None,
nnz_block_ptr: Optional[torch.Tensor] = None,
threshold=0.0,
):
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
# Note: prior impl only works with fp16
assert A.is_floating_point()

@@ -2395,7 +2397,7 @@ def get_colrow_absmax(
return row_stats, col_stats, outlier_mask


def get_row_absmax(A, threshold=0.0):
def get_row_absmax(A: torch.Tensor, threshold=0.0):
assert A.dtype == torch.float16

rows = prod(A.shape[:-1])
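The functional.py hunks mostly tighten messages and annotate return types; the new `Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]` on `get_colrow_absmax` reads as `(row_stats, col_stats, outlier_mask)`. A rough reference of what those three values mean, assuming values above the threshold are masked out of the statistics — a sketch only, not the library kernel:

```python
import torch
from typing import Optional, Tuple

def colrow_absmax_reference(
    A: torch.Tensor, threshold: float = 0.0
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    absA = A.float().abs()
    outlier_mask = None
    if threshold > 0.0:
        # Entries above the threshold are treated as outliers and excluded from the stats.
        outlier_mask = absA > threshold
        absA = absA.masked_fill(outlier_mask, 0.0)
    row_stats = absA.amax(dim=-1)  # one scale per row (token)
    col_stats = absA.amax(dim=-2)  # one scale per column (feature)
    return row_stats, col_stats, outlier_mask

A = torch.randn(4, 16, dtype=torch.float16) * 2
rows, cols, mask = colrow_absmax_reference(A, threshold=6.0)
print(rows.shape, cols.shape, mask.sum().item())
```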
10 changes: 4 additions & 6 deletions bitsandbytes/nn/modules.py
@@ -566,11 +566,11 @@ def __init__(
class Int8Params(torch.nn.Parameter):
def __new__(
cls,
data=None,
data: Optional[torch.Tensor] = None,
requires_grad=True,
has_fp16_weights=False,
CB=None,
SCB=None,
CB: Optional[torch.Tensor] = None,
SCB: Optional[torch.Tensor] = None,
):
if data is None:
data = torch.empty(0)
@@ -881,7 +881,6 @@ def __init__(
output_features: int,
bias=True,
has_fp16_weights=True,
memory_efficient_backward=False,
threshold=0.0,
index=None,
device=None,
@@ -898,13 +897,12 @@ def __init__(
Whether the linear class uses the bias term as well.
"""
super().__init__(input_features, output_features, bias, device)
assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
self.state = bnb.MatmulLtState()
self.index = index

self.state.threshold = threshold
self.state.has_fp16_weights = has_fp16_weights
self.state.memory_efficient_backward = memory_efficient_backward

if threshold > 0.0 and not has_fp16_weights:
self.state.use_pool = True

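With `memory_efficient_backward` dropped from both the layer and `MatmulLtState`, constructing the layer only takes the remaining arguments shown in the diff above. A hedged usage sketch — argument values are illustrative, and weight quantization happens when the module is moved to the GPU:

```python
import torch
import bitsandbytes as bnb

# The memory_efficient_backward keyword is gone; the remaining arguments match the diff.
layer = bnb.nn.Linear8bitLt(
    768,                     # input_features
    3072,                    # output_features
    bias=True,
    has_fp16_weights=False,  # keep int8 weights for inference
    threshold=6.0,           # enables the mixed-precision outlier path
)
# layer = layer.to("cuda")  # int8 quantization of the weight happens on .to()/.cuda()
# out = layer(torch.randn(2, 768, dtype=torch.float16, device="cuda"))
```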
3 changes: 0 additions & 3 deletions bitsandbytes/research/autograd/_functions.py
@@ -328,9 +328,6 @@ def backward(ctx, grad_output):
grad_B = torch.matmul(grad_output.t(), A)

if req_gradA:
# if state.CBt is not None:
# gradA32, SgradA32 = F.igemmlt(Cgrad, state.CBt.t())
# grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
if state.CB is not None:
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
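The surviving grad_A branch dequantizes the int8 weight `CB` with its per-row scale `SCB` (absmax mapped to ±127) before the matmul with `grad_output`. A small standalone illustration of that dequantize-then-matmul step, with made-up shapes:

```python
import torch

tokens, out_features, in_features = 4, 8, 16
CB = torch.randint(-127, 128, (out_features, in_features), dtype=torch.int8)
SCB = torch.rand(out_features) * 10.0        # per-output-row absmax scales
grad_output = torch.randn(tokens, out_features)

# Dequantize: CB stores values scaled into [-127, 127], so scale back by SCB / 127.
W = CB.float() * (SCB.unsqueeze(1) / 127.0)  # (out_features, in_features)

# grad_A = dL/dY @ W, mirroring torch.matmul(grad_output, CB) in the hunk above,
# where CB has already been dequantized in place.
grad_A = torch.matmul(grad_output, W)
print(grad_A.shape)                          # torch.Size([4, 16])
```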
