Skip to content

Commit

Permalink
format in CI
Browse files Browse the repository at this point in the history
  • Loading branch information
jianan-gu committed Feb 7, 2024
1 parent 9f23308 commit b41c1c4
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 15 deletions.
2 changes: 1 addition & 1 deletion bitsandbytes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import device_setup, utils, research
from . import device_setup, research, utils
from .autograd._functions import (
MatmulLtState,
bmm_cublas,
Expand Down
3 changes: 2 additions & 1 deletion bitsandbytes/backends/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from bitsandbytes.cextension import COMPILED_WITH_CUDA


class Backends:
"""
An dict class for device backends that registered with 8bits and 4bits functions.
Expand All @@ -25,4 +26,4 @@ def register_backend(cls, backend_name: str, backend_instance):
from .cuda import CUDABackend
cuda_backend = CUDABackend(torch.device("cuda").type)

Check failure on line 27 in bitsandbytes/backends/__init__.py

View workflow job for this annotation

GitHub Actions / Lint

Ruff (F821)

bitsandbytes/backends/__init__.py:27:32: F821 Undefined name `torch`
Backends.register_backend(cuda_backend.get_name(), cuda_backend)
# TODO: register more backends support
# TODO: register more backends support
4 changes: 3 additions & 1 deletion bitsandbytes/backends/basic_backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from abc import ABC, abstractmethod
import torch
from typing import Optional, Tuple

import torch

from bitsandbytes.functional import QuantState


Expand Down
27 changes: 16 additions & 11 deletions bitsandbytes/backends/cuda.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
import torch
from torch import Tensor
import ctypes as ct
from typing import Optional, Tuple

import torch
from torch import Tensor

from bitsandbytes.cextension import lib
from bitsandbytes.functional import (
pre_call,
post_call,
CUBLAS_Context,
QuantState,
coo_zeros,
dequantize_blockwise,
dtype2bytes,
get_4bit_type,
get_colrow_absmax,
get_ptr,
is_on_gpu,
coo_zeros,
get_transform_buffer,
is_on_gpu,
post_call,
pre_call,
prod,
get_4bit_type,
quantize_blockwise,
dequantize_blockwise,
dtype2bytes,
)
from bitsandbytes.functional import CUBLAS_Context, QuantState
from bitsandbytes.cextension import lib

from .basic_backend import DeviceBackends


class CUDABackend(DeviceBackends):
def __init__(self, backend_name: str):
self.backend_name = backend_name
Expand Down
1 change: 1 addition & 0 deletions bitsandbytes/cextension.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from warnings import warn

import torch

from bitsandbytes.device_setup.cuda.main import CUDASetup

setup = CUDASetup.get_instance()
Expand Down
2 changes: 1 addition & 1 deletion bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2081,6 +2081,7 @@ def pipeline_test(A, batch_size):

from bitsandbytes.backends import Backends

Check failure on line 2082 in bitsandbytes/functional.py

View workflow job for this annotation

GitHub Actions / Lint

Ruff (E402)

bitsandbytes/functional.py:2082:1: E402 Module level import not at top of file


# 8 bits common functions
def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
assert A.device.type in Backends.devices, f"Device backend for {A.device.type} is not supported"
Expand Down Expand Up @@ -2127,4 +2128,3 @@ def quantize_4bit(
def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
assert A.device.type in Backends.devices, f"Device backend for {A.device.type} is not supported"
return Backends.devices[A.device.type].dequantize_4bit(A, quant_state=quant_state, absmax=absmax, out=out, blocksize=blocksize, quant_type=quant_type)

0 comments on commit b41c1c4

Please sign in to comment.