refine structures of backends
jianan-gu committed Feb 7, 2024
1 parent b933f9f commit 8b4baaa
Showing 6 changed files with 372 additions and 340 deletions.
2 changes: 1 addition & 1 deletion bitsandbytes/backends/__init__.py
@@ -1,4 +1,4 @@
import typing
from typing import Dict
import torch

from bitsandbytes.cextension import COMPILED_WITH_CUDA
133 changes: 133 additions & 0 deletions bitsandbytes/backends/base.py
@@ -0,0 +1,133 @@
from abc import ABC, abstractmethod
from typing import Optional, Tuple

import torch

from bitsandbytes.utils import QuantState


class Backend(ABC):
"""Base class for devices backends that will implement their own 8bits and 4bits functions."""

@abstractmethod
def double_quant(
self,
A,
col_stats=None,
row_stats=None,
out_col=None,
out_row=None,
threshold=0.0,
):
raise NotImplementedError

@abstractmethod
def transform(
self,
A,
to_order,
from_order="row",
out=None,
transpose=False,
state=None,
ld=None,
):
raise NotImplementedError

@abstractmethod
def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
raise NotImplementedError

@abstractmethod
def mm_dequant(
self,
A,
quant_state,
row_stats,
col_stats,
out=None,
new_row_stats=None,
new_col_stats=None,
bias=None,
):
raise NotImplementedError

@abstractmethod
def extract_outliers(self, A, SA, idx):
raise NotImplementedError

@abstractmethod
def quantize_4bit(
self,
A: torch.Tensor,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize=64,
compress_statistics=False,
quant_type="fp4",
quant_storage=torch.uint8,
) -> Tuple[torch.Tensor, QuantState]:
"""
Quantize tensor A in blocks of 4-bit values.
Quantizes tensor A by dividing it into blocks which are independently quantized to FP4.
Parameters
----------
A : torch.Tensor
The input tensor.
absmax : torch.Tensor
The absmax values.
out : torch.Tensor
The output tensor.
blocksize : int
The blocksize used in quantization.
quant_type : str
The 4-bit quantization data type {fp4, nf4}
Returns
-------
torch.Tensor:
Tensor with packed 4-bit values.
tuple(torch.Tensor, torch.Size, torch.dtype, int):
The quantization state to undo the quantization.
"""
raise NotImplementedError

@abstractmethod
def dequantize_4bit(
self,
A: torch.Tensor,
quant_state: Optional[QuantState] = None,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize: int = 64,
quant_type="fp4",
) -> torch.Tensor:
"""
Dequantizes FP4 blockwise quantized values.
Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize.
Parameters
----------
A : torch.Tensor
The input tensor (packed 4-bit values).
quant_state : QuantState
object with quantisation stats, incl. absmax values, original tensor shape and original dtype.
absmax : torch.Tensor
The absmax values.
out : torch.Tensor
Dequantized output tensor.
blocksize : int
The blocksize used in quantization.
quant_type : str
The 4-bit quantization data type {fp4, nf4}
Returns
-------
torch.Tensor:
Dequantized tensor.
"""
raise NotImplementedError
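
For orientation, a minimal sketch (not part of this commit) of what a concrete device backend would look like against the new abstract class; the name MyDeviceBackend and the stub bodies are illustrative only:

import torch

from bitsandbytes.backends.base import Backend


class MyDeviceBackend(Backend):
    # Every abstract method declared on Backend must be overridden,
    # otherwise Python refuses to instantiate the subclass.
    def double_quant(self, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
        raise NotImplementedError  # device-specific int8 row/col-wise quantization

    def transform(self, A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
        raise NotImplementedError  # layout conversion for this device

    def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
        raise NotImplementedError  # int8 matrix multiplication

    def mm_dequant(self, A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None):
        raise NotImplementedError  # dequantize int8 matmul output

    def extract_outliers(self, A, SA, idx):
        raise NotImplementedError  # gather outlier columns

    def quantize_4bit(self, A, absmax=None, out=None, blocksize=64, compress_statistics=False, quant_type="fp4", quant_storage=torch.uint8):
        raise NotImplementedError  # blockwise 4-bit quantization (fp4/nf4)

    def dequantize_4bit(self, A, quant_state=None, absmax=None, out=None, blocksize=64, quant_type="fp4"):
        raise NotImplementedError  # blockwise 4-bit dequantization


backend = MyDeviceBackend()  # instantiation succeeds because every abstract method is overridden
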
94 changes: 0 additions & 94 deletions bitsandbytes/backends/basic_backend.py

This file was deleted.

88 changes: 12 additions & 76 deletions bitsandbytes/backends/cuda.py
@@ -2,12 +2,10 @@
from typing import Optional, Tuple

import torch
from torch import Tensor

from bitsandbytes.cextension import lib
from bitsandbytes.functional import (
CUBLAS_Context,
QuantState,
coo_zeros,
dequantize_blockwise,
dtype2bytes,
@@ -22,19 +20,14 @@
quantize_blockwise,
)

from .basic_backend import DeviceBackends
from bitsandbytes.utils import QuantState

from .base import Backend

class CUDABackend(DeviceBackends):
def __init__(self, backend_name: str):
self.backend_name = backend_name

def get_name(self) -> str:
return self.backend_name

@classmethod
class CUDABackend(Backend):
def double_quant(
cls, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0
self, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0
):
device = A.device
assert A.dtype == torch.half
@@ -128,8 +121,7 @@ def double_quant(

return out_row, out_col, row_stats, col_stats, coo_tensor

@classmethod
def transform(cls, A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
def transform(self, A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
prev_device = pre_call(A.device)
if state is None: state = (A.shape, from_order)
else: from_order = state[1]
@@ -172,8 +164,7 @@ def transform(cls, A, to_order, from_order='row', out=None, transpose=False, sta

return out, new_state

@classmethod
def igemmlt(cls, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
shapeA = SA[0]
shapeB = SB[0]
dimsA = len(shapeA)
@@ -272,9 +263,8 @@ def igemmlt(cls, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):

return out, Sout

@classmethod
def mm_dequant(
cls,
self,
A,
quant_state,
row_stats,
@@ -324,8 +314,7 @@ def mm_dequant(

return out

@classmethod
def extract_outliers(cls, A, SA, idx):
def extract_outliers(self, A, SA, idx):
shapeA = SA[0]
formatA = SA[1]
assert formatA in ["col_turing", "col_ampere"]
@@ -351,42 +340,16 @@ def extract_outliers(cls, A, SA, idx):

return out

@classmethod
def quantize_4bit(
cls,
A: Tensor,
self,
A: torch.Tensor,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize=64,
compress_statistics=False,
quant_type='fp4',
quant_storage=torch.uint8,
) -> Tuple[Tensor, QuantState]:
"""
Quantize tensor A in blocks of 4-bit values.
Quantizes tensor A by dividing it into blocks which are independently quantized to FP4.
Parameters
----------
A : torch.Tensor
The input tensor.
absmax : torch.Tensor
The absmax values.
out : torch.Tensor
The output tensor.
blocksize : int
The blocksize used in quantization.
quant_type : str
The 4-bit quantization data type {fp4, nf4}
Returns
-------
torch.Tensor:
Tensor with packed 4-bit values.
tuple(torch.Tensor, torch.Size, torch.dtype, int):
The quantization state to undo the quantization.
"""
) -> Tuple[torch.Tensor, QuantState]:
if A.device.type != 'cuda':
raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}')
if quant_type not in ['fp4', 'nf4']:
@@ -442,34 +405,7 @@ def quantize_4bit(

return out, state

@classmethod
def dequantize_4bit(cls, A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
"""
Dequantizes FP4 blockwise quantized values.
Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize.
Parameters
----------
A : torch.Tensor
The input tensor (packed 4-bit values).
quant_state : QuantState
object with quantisation stats, incl. absmax values, original tensor shape and original dtype.
absmax : torch.Tensor
The absmax values.
out : torch.Tensor
Dequantized output tensor.
blocksize : int
The blocksize used in quantization.
quant_type : str
The 4-bit quantization data type {fp4, nf4}
Returns
-------
torch.Tensor:
Dequantized tensor.
"""
def dequantize_4bit(self, A: torch.Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, quant_type='fp4') -> torch.Tensor:
if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
if quant_type not in ['fp4', 'nf4']:

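A rough usage sketch of the refactored CUDA backend, showing the switch from classmethods to plain instance methods; instantiating CUDABackend directly here is for illustration only (backend selection/registration presumably lives in bitsandbytes/backends/__init__.py, which is truncated above):

import torch

from bitsandbytes.backends.cuda import CUDABackend

backend = CUDABackend()  # no backend_name constructor argument anymore

A = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
# Pack A into 4-bit values; the returned QuantState carries absmax, original shape and dtype.
packed, state = backend.quantize_4bit(A, blocksize=64, quant_type="nf4")
# Round-trip back to fp16 using the stored quantization state.
A_dq = backend.dequantize_4bit(packed, quant_state=state)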