diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py
index 1045070cd..512fd2455 100644
--- a/bitsandbytes/__init__.py
+++ b/bitsandbytes/__init__.py
@@ -3,7 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from . import device_setup, utils, research
+from . import device_setup, research, utils
 from .autograd._functions import (
     MatmulLtState,
     bmm_cublas,
diff --git a/bitsandbytes/backends/__init__.py b/bitsandbytes/backends/__init__.py
index 496e7d671..bf8a76cba 100644
--- a/bitsandbytes/backends/__init__.py
+++ b/bitsandbytes/backends/__init__.py
@@ -1,5 +1,6 @@
 from bitsandbytes.cextension import COMPILED_WITH_CUDA
 
+
 class Backends:
     """
     An dict class for device backends that registered with 8bits and 4bits functions.
@@ -25,4 +26,4 @@ def register_backend(cls, backend_name: str, backend_instance):
 from .cuda import CUDABackend
 cuda_backend = CUDABackend(torch.device("cuda").type)
 Backends.register_backend(cuda_backend.get_name(), cuda_backend)
-# TODO: register more backends support
\ No newline at end of file
+# TODO: register more backends support
diff --git a/bitsandbytes/backends/basic_backend.py b/bitsandbytes/backends/basic_backend.py
index 8565c5f73..b97723d81 100644
--- a/bitsandbytes/backends/basic_backend.py
+++ b/bitsandbytes/backends/basic_backend.py
@@ -1,6 +1,8 @@
 from abc import ABC, abstractmethod
-import torch
 from typing import Optional, Tuple
+
+import torch
+
 from bitsandbytes.functional import QuantState
 
 
diff --git a/bitsandbytes/backends/cuda.py b/bitsandbytes/backends/cuda.py
index 7680bf2a1..965138a69 100644
--- a/bitsandbytes/backends/cuda.py
+++ b/bitsandbytes/backends/cuda.py
@@ -1,25 +1,30 @@
-import torch
-from torch import Tensor
 import ctypes as ct
 from typing import Optional, Tuple
+
+import torch
+from torch import Tensor
+
+from bitsandbytes.cextension import lib
 from bitsandbytes.functional import (
-    pre_call,
-    post_call,
+    CUBLAS_Context,
+    QuantState,
+    coo_zeros,
+    dequantize_blockwise,
+    dtype2bytes,
+    get_4bit_type,
     get_colrow_absmax,
     get_ptr,
-    is_on_gpu,
-    coo_zeros,
     get_transform_buffer,
+    is_on_gpu,
+    post_call,
+    pre_call,
     prod,
-    get_4bit_type,
     quantize_blockwise,
-    dequantize_blockwise,
-    dtype2bytes,
 )
-from bitsandbytes.functional import CUBLAS_Context, QuantState
-from bitsandbytes.cextension import lib
+
 from .basic_backend import DeviceBackends
+
 
 class CUDABackend(DeviceBackends):
     def __init__(self, backend_name: str):
         self.backend_name = backend_name
diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index 0848784c0..dab34982e 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -2,6 +2,7 @@
 from warnings import warn
 
 import torch
+
 from bitsandbytes.device_setup.cuda.main import CUDASetup
 
 setup = CUDASetup.get_instance()
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index c2fb491dd..f8a9723cb 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -2081,6 +2081,7 @@ def pipeline_test(A, batch_size):
 
 from bitsandbytes.backends import Backends
 
+
 # 8 bits common functions
 def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
     assert A.device.type in Backends.devices, f"Device backend for {A.device.type} is not supported"
@@ -2127,4 +2128,3 @@ def quantize_4bit(
 def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
     assert A.device.type in Backends.devices, f"Device backend for {A.device.type} is not supported"
     return Backends.devices[A.device.type].dequantize_4bit(A, quant_state=quant_state, absmax=absmax, out=out, blocksize=blocksize, quant_type=quant_type)
-