From 116c91097220da92ab4d892cfc907377de9959b0 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Tue, 6 Feb 2024 01:44:10 +0200 Subject: [PATCH] Rework CUDA setup and diagnostics --- bitsandbytes/__init__.py | 9 +- bitsandbytes/__main__.py | 108 +---- bitsandbytes/cextension.py | 148 +++++-- bitsandbytes/consts.py | 12 + bitsandbytes/cuda_setup/env_vars.py | 53 --- bitsandbytes/cuda_setup/main.py | 393 ------------------ bitsandbytes/cuda_specs.py | 44 ++ .../{cuda_setup => diagnostics}/__init__.py | 0 bitsandbytes/diagnostics/cuda.py | 174 ++++++++ bitsandbytes/diagnostics/main.py | 70 ++++ bitsandbytes/diagnostics/utils.py | 12 + bitsandbytes/functional.py | 4 +- bitsandbytes/optim/__init__.py | 2 - 13 files changed, 434 insertions(+), 595 deletions(-) create mode 100644 bitsandbytes/consts.py delete mode 100644 bitsandbytes/cuda_setup/env_vars.py delete mode 100644 bitsandbytes/cuda_setup/main.py create mode 100644 bitsandbytes/cuda_specs.py rename bitsandbytes/{cuda_setup => diagnostics}/__init__.py (100%) create mode 100644 bitsandbytes/diagnostics/cuda.py create mode 100644 bitsandbytes/diagnostics/main.py create mode 100644 bitsandbytes/diagnostics/utils.py diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index e54e933d9..a64b799e1 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import cuda_setup, research, utils +from . import research, utils from .autograd._functions import ( MatmulLtState, bmm_cublas, @@ -12,11 +12,8 @@ matmul_cublas, mm_cublas, ) -from .cextension import COMPILED_WITH_CUDA from .nn import modules - -if COMPILED_WITH_CUDA: - from .optim import adam +from .optim import adam __pdoc__ = { "libbitsandbytes": False, @@ -25,5 +22,3 @@ } __version__ = "0.43.0.dev" - -PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes" diff --git a/bitsandbytes/__main__.py b/bitsandbytes/__main__.py index 61b42e78f..e716b6f3f 100644 --- a/bitsandbytes/__main__.py +++ b/bitsandbytes/__main__.py @@ -1,108 +1,4 @@ -import glob -import os -import sys -from warnings import warn - -import torch - -HEADER_WIDTH = 60 - - -def find_dynamic_library(folder, filename): - for ext in ("so", "dll", "dylib"): - yield from glob.glob(os.path.join(folder, "**", filename + ext)) - - -def generate_bug_report_information(): - print_header("") - print_header("BUG REPORT INFORMATION") - print_header("") - print('') - - path_sources = [ - ("ANACONDA CUDA PATHS", os.environ.get("CONDA_PREFIX")), - ("/usr/local CUDA PATHS", "/usr/local"), - ("CUDA PATHS", os.environ.get("CUDA_PATH")), - ("WORKING DIRECTORY CUDA PATHS", os.getcwd()), - ] - try: - ld_library_path = os.environ.get("LD_LIBRARY_PATH") - if ld_library_path: - for path in set(ld_library_path.strip().split(os.pathsep)): - path_sources.append((f"LD_LIBRARY_PATH {path} CUDA PATHS", path)) - except Exception as e: - print(f"Could not parse LD_LIBRARY_PATH: {e}") - - for name, path in path_sources: - if path and os.path.isdir(path): - print_header(name) - print(list(find_dynamic_library(path, '*cuda*'))) - print("") - - -def print_header( - txt: str, width: int = HEADER_WIDTH, filler: str = "+" -) -> None: - txt = f" {txt} " if txt else "" - print(txt.center(width, filler)) - - -def print_debug_info() -> None: - from . import PACKAGE_GITHUB_URL - print( - "\nAbove we output some debug information. 
Please provide this info when " - f"creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose ...\n" - ) - - -def main(): - generate_bug_report_information() - - from . import COMPILED_WITH_CUDA - from .cuda_setup.main import get_compute_capabilities - - print_header("OTHER") - print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}") - print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities()}") - print_header("") - print_header("DEBUG INFO END") - print_header("") - print("Checking that the library is importable and CUDA is callable...") - print("\nWARNING: Please be sure to sanitize sensitive info from any such env vars!\n") - - try: - from bitsandbytes.optim import Adam - - p = torch.nn.Parameter(torch.rand(10, 10).cuda()) - a = torch.rand(10, 10).cuda() - - p1 = p.data.sum().item() - - adam = Adam([p]) - - out = a * p - loss = out.sum() - loss.backward() - adam.step() - - p2 = p.data.sum().item() - - assert p1 != p2 - print("SUCCESS!") - print("Installation was successful!") - except ImportError: - print() - warn( - f"WARNING: {__package__} is currently running as CPU-only!\n" - "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" - f"If you think that this is so erroneously,\nplease report an issue!" - ) - print_debug_info() - except Exception as e: - print(e) - print_debug_info() - sys.exit(1) - - if __name__ == "__main__": + from bitsandbytes.diagnostics.main import main + main() diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 858365f02..add3e6528 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -1,39 +1,123 @@ +""" +extract factors the build is dependent on: +[X] compute capability + [ ] TODO: Q - What if we have multiple GPUs of different makes? +- CUDA version +- Software: + - CPU-only: only CPU quantization functions (no optimizer, no matrix multiple) + - CuBLAS-LT: full-build 8-bit optimizer + - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`) + +evaluation: + - if paths faulty, return meaningful error + - else: + - determine CUDA version + - determine capabilities + - based on that set the default path +""" + import ctypes as ct -from warnings import warn +import logging +import os +from pathlib import Path +from typing import Optional import torch -from bitsandbytes.cuda_setup.main import CUDASetup +from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR +from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs + +logger = logging.getLogger(__name__) + + +def compute_override_library_name(library_name: str, override_value: str) -> str: + binary_name_stem, _, binary_name_ext = library_name.rpartition(".") + # `binary_name_stem` will now be e.g. `libbitsandbytes_cuda118`; + # let's remove any trailing numbers: + binary_name_stem = binary_name_stem.rstrip("0123456789") + # `binary_name_stem` will now be e.g. `libbitsandbytes_cuda`; + # let's tack the new version number and the original extension back on. 
+ binary_name = f"{binary_name_stem}{override_value}.{binary_name_ext}" + + logger.warning( + f'WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {binary_name}.\n' + 'This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n' + 'If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n' + 'If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n' + 'For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH: Optional[Path]: + library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}" + if not cuda_specs.has_cublaslt: + # if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt + library_name += "_nocublaslt" + library_name = f"{library_name}{DYNAMIC_LIBRARY_SUFFIX}" + + override_value = os.environ.get("BNB_CUDA_VERSION") + if override_value: + library_name = compute_override_library_name(library_name, override_value) + + cuda_binary_path = PACKAGE_DIR / library_name + if cuda_binary_path.exists(): + return cuda_binary_path + logger.warning("Could not find the bitsandbytes CUDA binary at %r", cuda_binary_path) + + +class BNBNativeLibrary: + _lib: ct.CDLL + compiled_with_cuda = False + + def __init__(self, lib: ct.CDLL): + self._lib = lib + + def __getattr__(self, item): + return getattr(self._lib, item) + + +class CudaBNBNativeLibrary(BNBNativeLibrary): + compiled_with_cuda = True + + def __init__(self, lib: ct.CDLL): + super().__init__(lib) + lib.get_context.restype = ct.c_void_p + lib.get_cusparse.restype = ct.c_void_p + lib.cget_managed_ptr.restype = ct.c_void_p + + +def get_native_library() -> BNBNativeLibrary: + binary_path = PACKAGE_DIR / f"libbitsandbytes_cpu{DYNAMIC_LIBRARY_SUFFIX}" + cuda_specs = get_cuda_specs() + if cuda_specs: + cuda_binary_path = get_cuda_binary_path(cuda_specs) + if cuda_binary_path: + binary_path = cuda_binary_path + logger.debug(f"Loading bitsandbytes native library from: {binary_path}") + dll = ct.cdll.LoadLibrary(str(binary_path)) + + if hasattr(dll, "get_context"): # only a CUDA-built library exposes this + return CudaBNBNativeLibrary(dll) + + logger.warning( + "The installed version of bitsandbytes was compiled without GPU support. " + "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable." + ) + return BNBNativeLibrary(dll) -setup = CUDASetup.get_instance() -if setup.initialized != True: - setup.run_cuda_setup() -lib = setup.lib try: - if lib is None and torch.cuda.is_available(): - CUDASetup.get_instance().generate_instructions() - CUDASetup.get_instance().print_log_stack() - raise RuntimeError(''' - CUDA Setup failed despite GPU being available. Please run the following command to get more information: - - python -m bitsandbytes - - Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them - to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes - and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''') - _ = lib.cadam32bit_grad_fp32 # runs on an error if the library could not be found -> COMPILED_WITH_CUDA=False - lib.get_context.restype = ct.c_void_p - lib.get_cusparse.restype = ct.c_void_p - lib.cget_managed_ptr.restype = ct.c_void_p - COMPILED_WITH_CUDA = True -except AttributeError as ex: - warn("The installed version of bitsandbytes was compiled without GPU support. 
" - "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.") - COMPILED_WITH_CUDA = False - print(str(ex)) - - -# print the setup details after checking for errors so we do not print twice -#if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': - #setup.print_log_stack() + lib = get_native_library() +except Exception as e: + lib = None + logger.error(f"Could not load bitsandbytes native library: {e}", exc_info=True) + if torch.cuda.is_available(): + logger.warning(""" +CUDA Setup failed despite CUDA being available. Please run the following command to get more information: + +python -m bitsandbytes + +Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them +to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes +and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues +""") diff --git a/bitsandbytes/consts.py b/bitsandbytes/consts.py new file mode 100644 index 000000000..2ad16de8f --- /dev/null +++ b/bitsandbytes/consts.py @@ -0,0 +1,12 @@ +from pathlib import Path +import platform + +DYNAMIC_LIBRARY_SUFFIX = { + 'Darwin': '.dylib', + 'Linux': '.so', + 'Windows': '.dll', +}.get(platform.system(), '.so') + +PACKAGE_DIR = Path(__file__).parent +PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes" +NONPYTORCH_DOC_URL = "https://github.com/TimDettmers/bitsandbytes/blob/main/docs/source/nonpytorchcuda.mdx" diff --git a/bitsandbytes/cuda_setup/env_vars.py b/bitsandbytes/cuda_setup/env_vars.py deleted file mode 100644 index 4b2549653..000000000 --- a/bitsandbytes/cuda_setup/env_vars.py +++ /dev/null @@ -1,53 +0,0 @@ -import os -from typing import Dict - - -def to_be_ignored(env_var: str, value: str) -> bool: - ignorable = { - "PWD", # PWD: this is how the shell keeps track of the current working dir - "OLDPWD", - "SSH_AUTH_SOCK", # SSH stuff, therefore unrelated - "SSH_TTY", - "GOOGLE_VM_CONFIG_LOCK_FILE", # GCP: requires elevated permissions, causing problems in VMs and Jupyter notebooks - "HOME", # Linux shell default - "TMUX", # Terminal Multiplexer - "XDG_DATA_DIRS", # XDG: Desktop environment stuff - "XDG_GREETER_DATA_DIR", # XDG: Desktop environment stuff - "XDG_RUNTIME_DIR", - "MAIL", # something related to emails - "SHELL", # binary for currently invoked shell - "DBUS_SESSION_BUS_ADDRESS", # hardware related - "PATH", # this is for finding binaries, not libraries - "LESSOPEN", # related to the `less` command - "LESSCLOSE", - "_", # current Python interpreter - } - return env_var in ignorable - - -def might_contain_a_path(candidate: str) -> bool: - return os.sep in candidate - - -def is_active_conda_env(env_var: str) -> bool: - return "CONDA_PREFIX" == env_var - - -def is_other_conda_env_var(env_var: str) -> bool: - return "CONDA" in env_var - - -def is_relevant_candidate_env_var(env_var: str, value: str) -> bool: - return is_active_conda_env(env_var) or ( - might_contain_a_path(value) and not - is_other_conda_env_var(env_var) and not - to_be_ignored(env_var, value) - ) - - -def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]: - return { - env_var: value - for env_var, value in os.environ.items() - if is_relevant_candidate_env_var(env_var, value) - } diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py deleted file mode 100644 index 14c7abbd8..000000000 --- a/bitsandbytes/cuda_setup/main.py +++ /dev/null @@ -1,393 +0,0 @@ -""" -extract factors the build is 
dependent on: -[X] compute capability - [ ] TODO: Q - What if we have multiple GPUs of different makes? -- CUDA version -- Software: - - CPU-only: only CPU quantization functions (no optimizer, no matrix multiply) - - CuBLAS-LT: full-build 8-bit optimizer - - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`) - -evaluation: - - if paths faulty, return meaningful error - - else: - - determine CUDA version - - determine capabilities - - based on that set the default path -""" - -import ctypes as ct -import errno -import os -from pathlib import Path -import platform -from typing import Set, Union -from warnings import warn - -import torch - -from .env_vars import get_potentially_lib_path_containing_env_vars - -DYNAMIC_LIBRARY_SUFFIX = { "Darwin": ".dylib", "Windows": ".dll", "Linux": ".so"}.get(platform.system(), ".so") -if platform.system() == "Windows": # Windows - CUDA_RUNTIME_LIBS = ["nvcuda.dll"] -else: # Linux or other - # these are the most common libs names - # libcudart.so is missing by default for a conda install with PyTorch 2.0 and instead - # we have libcudart.so.11.0 which causes a lot of errors before - # not sure if libcudart.so.12.0 exists in pytorch installs, but it does not hurt - CUDA_RUNTIME_LIBS = ["libcudart.so", "libcudart.so.11.0", "libcudart.so.12.0", "libcudart.so.12.1", "libcudart.so.12.2"] - - -class CUDASetup: - _instance = None - - def __init__(self): - raise RuntimeError("Call get_instance() instead") - - def generate_instructions(self): - if getattr(self, 'error', False): return - print(self.error) - self.error = True - if not self.cuda_available: - self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected or CUDA not installed.') - self.add_log_entry('CUDA SETUP: Solution 1): Your paths are probably not up-to-date. 
You can update them via: sudo ldconfig.') - self.add_log_entry('CUDA SETUP: Solution 2): If you do not have sudo rights, you can do the following:') - self.add_log_entry('CUDA SETUP: Solution 2a): Find the cuda library via: find / -name libcuda.so 2>/dev/null') - self.add_log_entry('CUDA SETUP: Solution 2b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_2a') - self.add_log_entry('CUDA SETUP: Solution 2c): For a permanent solution add the export from 2b into your .bashrc file, located at ~/.bashrc') - self.add_log_entry('CUDA SETUP: Solution 3): For a missing CUDA runtime library (libcudart.so), use `find / -name libcudart.so* and follow with step (2b)') - return - - if self.cudart_path is None: - self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA runtime library was not detected.') - self.add_log_entry('CUDA SETUP: Solution 1: To solve the issue the libcudart.so location needs to be added to the LD_LIBRARY_PATH variable') - self.add_log_entry('CUDA SETUP: Solution 1a): Find the cuda runtime library via: find / -name libcudart.so 2>/dev/null') - self.add_log_entry('CUDA SETUP: Solution 1b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_1a') - self.add_log_entry('CUDA SETUP: Solution 1c): For a permanent solution add the export from 1b into your .bashrc file, located at ~/.bashrc') - self.add_log_entry('CUDA SETUP: Solution 2: If no library was found in step 1a) you need to install CUDA.') - self.add_log_entry('CUDA SETUP: Solution 2a): Download CUDA install script: wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh') - self.add_log_entry('CUDA SETUP: Solution 2b): Install desired CUDA version to desired location. The syntax is bash cuda_install.sh CUDA_VERSION PATH_TO_INSTALL_INTO.') - self.add_log_entry('CUDA SETUP: Solution 2b): For example, "bash cuda_install.sh 113 ~/local/" will download CUDA 11.3 and install into the folder ~/local') - - return - - make_cmd = f'CUDA_VERSION={self.cuda_version_string}' - if len(self.cuda_version_string) < 3: - make_cmd += ' make cuda92' - elif self.cuda_version_string == '110': - make_cmd += ' make cuda110' - elif self.cuda_version_string[:2] == '11' and int(self.cuda_version_string[2]) > 0: - make_cmd += ' make cuda11x' - elif self.cuda_version_string[:2] == '12' and 1 >= int(self.cuda_version_string[2]) >= 0: - make_cmd += ' make cuda12x' - elif self.cuda_version_string == '100': - self.add_log_entry('CUDA SETUP: CUDA 10.0 not supported. Please use a different CUDA version.') - self.add_log_entry('CUDA SETUP: Before you try again running bitsandbytes, make sure old CUDA 10.0 versions are uninstalled and removed from $LD_LIBRARY_PATH variables.') - return - - - has_cublaslt = is_cublasLt_compatible(self.cc) - if not has_cublaslt: - make_cmd += '_nomatmul' - - self.add_log_entry('CUDA SETUP: Something unexpected happened. 
Please compile from source:') - self.add_log_entry('git clone https://github.com/TimDettmers/bitsandbytes.git') - self.add_log_entry('cd bitsandbytes') - self.add_log_entry(make_cmd) - self.add_log_entry('python setup.py install') - - def initialize(self): - if not getattr(self, 'initialized', False): - self.has_printed = False - self.lib = None - self.initialized = False - self.error = False - - def manual_override(self): - if not torch.cuda.is_available(): - return - override_value = os.environ.get('BNB_CUDA_VERSION') - if not override_value: - return - - binary_name_stem, _, binary_name_ext = self.binary_name.rpartition(".") - # `binary_name_stem` will now be e.g. `/foo/bar/libbitsandbytes_cuda118`; - # let's remove any trailing numbers: - binary_name_stem = binary_name_stem.rstrip("0123456789") - # `binary_name_stem` will now be e.g. `/foo/bar/libbitsandbytes_cuda`; - # let's tack the new version number and the original extension back on. - self.binary_name = f"{binary_name_stem}{override_value}.{binary_name_ext}" - - warn( - f'\n\n{"=" * 80}\n' - 'WARNING: Manual override via BNB_CUDA_VERSION env variable detected!\n' - 'BNB_CUDA_VERSION=XXX can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n' - 'If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n' - 'If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n' - 'For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH: Set[Path]: - return {Path(ld_path) for ld_path in paths_list_candidate.split(os.pathsep) if ld_path} - - -def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]: - existent_directories: Set[Path] = set() - for path in candidate_paths: - try: - if path.exists(): - existent_directories.add(path) - except PermissionError: - # Handle the PermissionError first as it is a subtype of OSError - # https://docs.python.org/3/library/exceptions.html#exception-hierarchy - pass - except OSError as exc: - if exc.errno != errno.ENAMETOOLONG: - raise exc - - non_existent_directories: Set[Path] = candidate_paths - existent_directories - if non_existent_directories: - CUDASetup.get_instance().add_log_entry( - f"The following directories listed in your path were found to be non-existent: {non_existent_directories}", - is_warning=False, - ) - - return existent_directories - - -def get_cuda_runtime_lib_paths(candidate_paths: Set[Path]) -> Set[Path]: - paths = set() - for libname in CUDA_RUNTIME_LIBS: - for path in candidate_paths: - try: - if (path / libname).is_file(): - paths.add(path / libname) - except PermissionError: - pass - return paths - - -def resolve_paths_list(paths_list_candidate: str) -> Set[Path]: - """ - Searches a given environmental var for the CUDA runtime library, - i.e. `libcudart.so`. - """ - return remove_non_existent_dirs(extract_candidate_paths(paths_list_candidate)) - - -def find_cuda_lib_in(paths_list_candidate: str) -> Set[Path]: - return get_cuda_runtime_lib_paths( - resolve_paths_list(paths_list_candidate) - ) - - -def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None: - if len(results_paths) > 1: - warning_msg = ( - f"Found duplicate {CUDA_RUNTIME_LIBS} files: {results_paths}.. " - "We select the PyTorch default libcudart.so, which is {torch.version.cuda}," - "but this might mismatch with the CUDA version that is needed for bitsandbytes." 
- "To override this behavior set the BNB_CUDA_VERSION= environmental variable" - "For example, if you want to use the CUDA version 122" - "BNB_CUDA_VERSION=122 python ..." - "OR set the environmental variable in your .bashrc: export BNB_CUDA_VERSION=122" - "In the case of a manual override, make sure you set the LD_LIBRARY_PATH, e.g." - "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2") - CUDASetup.get_instance().add_log_entry(warning_msg, is_warning=True) - - -def determine_cuda_runtime_lib_path() -> Union[Path, None]: - """ - Searches for a cuda installations, in the following order of priority: - 1. active conda env - 2. LD_LIBRARY_PATH - 3. any other env vars, while ignoring those that - - are known to be unrelated (see `bnb.cuda_setup.env_vars.to_be_ignored`) - - don't contain the path separator `/` - - If multiple libraries are found in part 3, we optimistically try one, - while giving a warning message. - """ - candidate_env_vars = get_potentially_lib_path_containing_env_vars() - - cuda_runtime_libs = set() - if "CONDA_PREFIX" in candidate_env_vars: - conda_libs_path = Path(candidate_env_vars["CONDA_PREFIX"]) / "lib" - - conda_cuda_libs = find_cuda_lib_in(str(conda_libs_path)) - warn_in_case_of_duplicates(conda_cuda_libs) - - if conda_cuda_libs: - cuda_runtime_libs.update(conda_cuda_libs) - - CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["CONDA_PREFIX"]} did not contain ' - f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True) - - if "LD_LIBRARY_PATH" in candidate_env_vars: - lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"]) - - if lib_ld_cuda_libs: - cuda_runtime_libs.update(lib_ld_cuda_libs) - warn_in_case_of_duplicates(lib_ld_cuda_libs) - - CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain ' - f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True) - - remaining_candidate_env_vars = { - env_var: value for env_var, value in candidate_env_vars.items() - if env_var not in {"CONDA_PREFIX", "LD_LIBRARY_PATH"} - } - - cuda_runtime_libs = set() - for env_var, value in remaining_candidate_env_vars.items(): - cuda_runtime_libs.update(find_cuda_lib_in(value)) - - if len(cuda_runtime_libs) == 0: - CUDASetup.get_instance().add_log_entry('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching in backup paths...') - cuda_runtime_libs.update(find_cuda_lib_in('/usr/local/cuda/lib64')) - - warn_in_case_of_duplicates(cuda_runtime_libs) - - cuda_setup = CUDASetup.get_instance() - cuda_setup.add_log_entry(f'DEBUG: Possible options found for libcudart.so: {cuda_runtime_libs}') - - return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else None - - -# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION -def get_cuda_version(): - major, minor = map(int, torch.version.cuda.split(".")) - - if major < 11: - CUDASetup.get_instance().add_log_entry('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). 
You will be only to use 8-bit optimizers and quantization routines!!')
-
-    return f'{major}{minor}'
-
-def get_compute_capabilities():
-    ccs = []
-    for i in range(torch.cuda.device_count()):
-        cc_major, cc_minor = torch.cuda.get_device_capability(torch.cuda.device(i))
-        ccs.append(f"{cc_major}.{cc_minor}")
-
-    ccs.sort(key=lambda v: tuple(map(int, str(v).split("."))))
-
-    return ccs
-
-
-def evaluate_cuda_setup():
-    cuda_setup = CUDASetup.get_instance()
-    if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
-        cuda_setup.add_log_entry('')
-        cuda_setup.add_log_entry('='*35 + 'BUG REPORT' + '='*35)
-        cuda_setup.add_log_entry(('Welcome to bitsandbytes. For bug reports, please run\n\npython -m bitsandbytes\n\n'),
-            ('and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues'))
-        cuda_setup.add_log_entry('='*80)
-
-    if not torch.cuda.is_available():
-        return f'libbitsandbytes_cpu{DYNAMIC_LIBRARY_SUFFIX}', None, None, None
-
-    cudart_path = determine_cuda_runtime_lib_path()
-    cc = get_compute_capabilities()[-1]  # we take the highest capability
-    cuda_version_string = get_cuda_version()
-
-    cuda_setup.add_log_entry(f"CUDA SETUP: PyTorch settings found: CUDA_VERSION={cuda_version_string}, Highest Compute Capability: {cc}.")
-    cuda_setup.add_log_entry(
-        "CUDA SETUP: To manually override the PyTorch CUDA version please see:"
-        "https://github.com/TimDettmers/bitsandbytes/blob/main/how_to_use_nonpytorch_cuda.md"
-    )
-
-    # 7.5 is the minimum CC vor cublaslt
-    has_cublaslt = is_cublasLt_compatible(cc)
-
-    # TODO:
-    # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
-    # (2) Multiple CUDA versions installed
-
-    # we use ls -l instead of nvcc to determine the cuda version
-    # since most installations will have the libcudart.so installed, but not the compiler
-
-    binary_name = f"libbitsandbytes_cuda{cuda_version_string}"
-    if not has_cublaslt:
-        # if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt
-        binary_name += "_nocublaslt"
-
-    binary_name = f"{binary_name}{DYNAMIC_LIBRARY_SUFFIX}"
-
-    return binary_name, cudart_path, cc, cuda_version_string

diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py
new file mode 100644
index 000000000..78b7fbd90
--- /dev/null
+++ b/bitsandbytes/cuda_specs.py
@@ -0,0 +1,44 @@
+import dataclasses
+from typing import List, Optional, Tuple
+
+import torch
+
+
+@dataclasses.dataclass(frozen=True)
+class CUDASpecs:
+    highest_compute_capability: Tuple[int, int]
+    cuda_version_string: str
+    cuda_version_tuple: Tuple[int, int]
+
+    @property
+    def has_cublaslt(self) -> bool:
+        return self.highest_compute_capability >= (7, 5)
+
+
+def get_compute_capabilities() -> List[Tuple[int, int]]:
+    return sorted(
+        torch.cuda.get_device_capability(torch.cuda.device(i))
+        for i in range(torch.cuda.device_count())
+    )
+
+
+def get_cuda_version_tuple() -> Tuple[int, int]:
+    # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
+    major, minor = map(int, torch.version.cuda.split("."))
+    return major, minor
+
+
+def get_cuda_version_string() -> str:
+    major, minor = get_cuda_version_tuple()
+    return f'{major}{minor}'
+
+
+def get_cuda_specs() -> Optional[CUDASpecs]:
+    if not torch.cuda.is_available():
+        return None
+
+    return CUDASpecs(
+        highest_compute_capability=(get_compute_capabilities()[-1]),
+        cuda_version_string=(get_cuda_version_string()),
+        cuda_version_tuple=get_cuda_version_tuple(),
+    )

diff --git a/bitsandbytes/cuda_setup/__init__.py b/bitsandbytes/diagnostics/__init__.py
similarity index 100%
rename from bitsandbytes/cuda_setup/__init__.py
rename to bitsandbytes/diagnostics/__init__.py
diff --git a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/cuda.py
new file mode 100644
index 000000000..fa97b8e19
--- /dev/null
+++ b/bitsandbytes/diagnostics/cuda.py
@@ -0,0 +1,174 @@
+import logging
+import os
+from pathlib import Path
+from typing import Dict, Iterable, Iterator
+
+import torch
+
+from bitsandbytes.cextension import get_cuda_binary_path
+from bitsandbytes.consts import NONPYTORCH_DOC_URL
+from bitsandbytes.cuda_specs import CUDASpecs
+from bitsandbytes.diagnostics.utils import print_dedented
+
+CUDART_PATH_PREFERRED_ENVVARS = ("CONDA_PREFIX", "LD_LIBRARY_PATH")
+
+CUDART_PATH_IGNORED_ENVVARS = {
+    "DBUS_SESSION_BUS_ADDRESS",  # hardware related
+    "GOOGLE_VM_CONFIG_LOCK_FILE",  # GCP: requires elevated permissions, causing problems in VMs and Jupyter notebooks
+    "HOME",  # Linux shell default
+    "LESSCLOSE",
+    "LESSOPEN",  # related to the `less` command
+    "MAIL",  # something related to emails
+    "OLDPWD",
+    "PATH",  # this is for finding binaries, not libraries
+    "PWD",  # PWD: this is how the shell keeps track of the current working dir
+    "SHELL",  # binary for currently invoked shell
+    "SSH_AUTH_SOCK",  # SSH stuff, therefore unrelated
+    "SSH_TTY",
+    "TMUX",  # Terminal Multiplexer
+    "XDG_DATA_DIRS",  # XDG: Desktop environment stuff
+    "XDG_GREETER_DATA_DIR",  # XDG: Desktop environment stuff
+    "XDG_RUNTIME_DIR",
+    "_",  # current Python interpreter
+}
+
+CUDA_RUNTIME_LIB_PATTERNS = (
+    "nvcuda*.dll",  # Windows
+    "libcudart*.so*",  # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc.
+)
+
+logger = logging.getLogger(__name__)
+
+
+def find_cuda_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path]:
+    for dir_string in paths_list_candidate.split(os.pathsep):
+        if not dir_string:
+            continue
+        try:
+            dir = Path(dir_string)
+            if not dir.exists():
+                logger.warning(
+                    f"The directory listed in your path is found to be non-existent: {dir}"
+                )
+                continue
+            for lib_pattern in CUDA_RUNTIME_LIB_PATTERNS:
+                for pth in dir.glob(lib_pattern):
+                    if pth.is_file():
+                        yield pth
+        except PermissionError:
+            pass
+
+
+def is_relevant_candidate_env_var(env_var: str, value: str) -> bool:
+    return (
+        env_var in CUDART_PATH_PREFERRED_ENVVARS  # is a preferred location
+        or (
+            os.sep in value  # might contain a path
+            and "CONDA" not in env_var  # not another conda envvar
+            and env_var not in CUDART_PATH_IGNORED_ENVVARS  # not ignored
+        )
+    )
+
+
+def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]:
+    return {
+        env_var: value
+        for env_var, value in os.environ.items()
+        if is_relevant_candidate_env_var(env_var, value)
+    }
+
+
+def find_cudart_libraries() -> Iterator[Path]:
+    """
+    Searches for CUDA installations, in the following order of priority:
+        1. active conda env
+        2. LD_LIBRARY_PATH
+        3. any other env vars, while ignoring those that
+            - are known to be unrelated
+            - don't contain the path separator `/`
+
+    If multiple libraries are found in part 3, we optimistically try one,
+    while giving a warning message.
+ """ + candidate_env_vars = get_potentially_lib_path_containing_env_vars() + + for envvar in CUDART_PATH_PREFERRED_ENVVARS: + if envvar in candidate_env_vars: + directory = candidate_env_vars[envvar] + yield from find_cuda_libraries_in_path_list(directory) + candidate_env_vars.pop(envvar) + + for env_var, value in candidate_env_vars.items(): + yield from find_cuda_libraries_in_path_list(value) + + +def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: + print( + f"PyTorch settings found: CUDA_VERSION={cuda_specs.cuda_version_string}, " + f"Highest Compute Capability: {cuda_specs.highest_compute_capability}.", + ) + + binary_path = get_cuda_binary_path(cuda_specs) + if not binary_path.exists(): + print_dedented( + f""" + Library not found: {binary_path}. Maybe you need to compile it from source? + If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION`, + for example, `make CUDA_VERSION=113`. + + The CUDA version for the compile might depend on your conda install, if using conda. + Inspect CUDA version via `conda list | grep cuda`. + """ + ) + + cuda_major, cuda_minor = cuda_specs.cuda_version_tuple + if cuda_major < 11: + print_dedented( + """ + WARNING: CUDA versions lower than 11 are currently not supported for LLM.int8(). + You will be only to use 8-bit optimizers and quantization routines! + """ + ) + + print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}") + + # 7.5 is the minimum CC for cublaslt + if not cuda_specs.has_cublaslt: + print_dedented( + """ + WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU! + If you run into issues with 8-bit matmul, you can try 4-bit quantization: + https://huggingface.co/blog/4bit-transformers-bitsandbytes + """, + ) + + # TODO: + # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) + # (2) Multiple CUDA versions installed + + +def print_cuda_runtime_diagnostics() -> None: + cudart_paths = list(find_cudart_libraries()) + if not cudart_paths: + print("CUDA SETUP: WARNING! CUDA runtime files not found in any environmental path.") + elif len(cudart_paths) > 1: + print_dedented( + f""" + Found duplicate CUDA runtime files (see below). + + We select the PyTorch default CUDA runtime, which is {torch.version.cuda}, + but this might mismatch with the CUDA version that is needed for bitsandbytes. + To override this behavior set the `BNB_CUDA_VERSION=` environmental variable. + + For example, if you want to use the CUDA version 122, + BNB_CUDA_VERSION=122 python ... + + OR set the environmental variable in your .bashrc: + export BNB_CUDA_VERSION=122 + + In the case of a manual override, make sure you set LD_LIBRARY_PATH, e.g. 
+            export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2,
+            """
+        )
+    for pth in cudart_paths:
+        print(f"* Found CUDA runtime at: {pth}")

diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py
new file mode 100644
index 000000000..a7f0c901e
--- /dev/null
+++ b/bitsandbytes/diagnostics/main.py
@@ -0,0 +1,70 @@
+import sys
+import traceback
+
+import torch
+
+from bitsandbytes.consts import PACKAGE_GITHUB_URL
+from bitsandbytes.cuda_specs import get_cuda_specs
+from bitsandbytes.diagnostics.cuda import (
+    print_cuda_diagnostics,
+    print_cuda_runtime_diagnostics,
+)
+from bitsandbytes.diagnostics.utils import print_dedented, print_header
+
+
+def sanity_check():
+    from bitsandbytes.optim import Adam
+
+    p = torch.nn.Parameter(torch.rand(10, 10).cuda())
+    a = torch.rand(10, 10).cuda()
+    p1 = p.data.sum().item()
+    adam = Adam([p])
+    out = a * p
+    loss = out.sum()
+    loss.backward()
+    adam.step()
+    p2 = p.data.sum().item()
+    assert p1 != p2
+
+
+def main():
+    print_header("")
+    print_header("BUG REPORT INFORMATION")
+    print_header("")
+
+    print_header("OTHER")
+    cuda_specs = get_cuda_specs()
+    print("CUDA specs:", cuda_specs)
+    if not torch.cuda.is_available():
+        print("Torch says CUDA is not available. Possible reasons:")
+        print("1. CUDA driver not installed")
+        print("2. CUDA not installed")
+        print("3. You have multiple conflicting CUDA libraries")
+    if cuda_specs:
+        print_cuda_diagnostics(cuda_specs)
+    print_cuda_runtime_diagnostics()
+    print_header("")
+    print_header("DEBUG INFO END")
+    print_header("")
+    print("Checking that the library is importable and CUDA is callable...")
+    try:
+        sanity_check()
+        print("SUCCESS!")
+        print("Installation was successful!")
+        return
+    except ImportError:
+        print(
+            f"WARNING: {__package__} is currently running as CPU-only!\n"
+            "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n"
+            "If you think that this is an error,\nplease report an issue!"
+        )
+    except Exception:
+        traceback.print_exc()
+    print_dedented(
+        f"""
+        Above we output some debug information.
+        Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose
+        WARNING: Please be sure to sanitize sensitive info from the output before posting it.
+ """ + ) + sys.exit(1) diff --git a/bitsandbytes/diagnostics/utils.py b/bitsandbytes/diagnostics/utils.py new file mode 100644 index 000000000..770209b9d --- /dev/null +++ b/bitsandbytes/diagnostics/utils.py @@ -0,0 +1,12 @@ +import textwrap + +HEADER_WIDTH = 60 + + +def print_header(txt: str, width: int = HEADER_WIDTH, filler: str = "+") -> None: + txt = f" {txt} " if txt else "" + print(txt.center(width, filler)) + + +def print_dedented(text): + print("\n".join(textwrap.dedent(text).strip().split("\n"))) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 9fc5e08f0..53d54d039 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -14,7 +14,7 @@ from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -from .cextension import COMPILED_WITH_CUDA, lib +from .cextension import lib # math.prod not compatible with python < 3.8 @@ -23,7 +23,7 @@ def prod(iterable): name2qmap = {} -if COMPILED_WITH_CUDA: +if lib and lib.compiled_with_cuda: """C FUNCTIONS FOR OPTIMIZERS""" str2optimizer32bit = { "adam": ( diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py index 6796b8e0e..b4c95793a 100644 --- a/bitsandbytes/optim/__init__.py +++ b/bitsandbytes/optim/__init__.py @@ -3,8 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from bitsandbytes.cextension import COMPILED_WITH_CUDA - from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit from .adamw import (