Skip to content

Commit

Permalink
refactor base backend registering
Browse files Browse the repository at this point in the history
Co-authored-by: Aarni Koskela <[email protected]>
  • Loading branch information
jianan-gu and akx authored Feb 7, 2024
1 parent 1ab611e commit b933f9f
Showing 1 changed file with 7 additions and 22 deletions.
29 changes: 7 additions & 22 deletions bitsandbytes/backends/__init__.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,14 @@
from bitsandbytes.cextension import COMPILED_WITH_CUDA
import typing
import torch

class Backends:
"""
An dict class for device backends that registered with 8bits and 4bits functions.
The values of this device backends are lowercase strings, e.g., ``"cuda"``. They can
be accessed as attributes with key-value, e.g., ``Backends.device["cuda"]``.
"""

devices = {}
from bitsandbytes.cextension import COMPILED_WITH_CUDA
from bitsandbytes.backends.base import Backend

@classmethod
def register_backend(cls, backend_name: str, backend_instance):
assert backend_name.lower() in {
"cpu",
"cuda",
"xpu",
}, "register device backend choices in [cpu, cuda, xpu]"
backends: Dict[str, Backend] = {}

cls.devices[backend_name.lower()] = backend_instance
def register_backend(backend_name: str, backend_instance: Backend):
backends[backend_name.lower()] = backend_instance

if COMPILED_WITH_CUDA:
from .cuda import CUDABackend
cuda_backend = CUDABackend(torch.device("cuda").type)
Backends.register_backend(cuda_backend.get_name(), cuda_backend)
# TODO: register more backends support
register_backend("cuda", CUDABackend())

0 comments on commit b933f9f

Please sign in to comment.