diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index c8ae7358d..45573538e 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -20,6 +20,7 @@ import logging import os from pathlib import Path +import re import torch @@ -44,13 +45,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path: override_value = os.environ.get("BNB_CUDA_VERSION") if override_value: - library_name_stem, _, library_name_ext = library_name.rpartition(".") - # `library_name_stem` will now be e.g. `libbitsandbytes_cuda118`; - # let's remove any trailing numbers: - library_name_stem = library_name_stem.rstrip("0123456789") - # `library_name_stem` will now be e.g. `libbitsandbytes_cuda`; - # let's tack the new version number and the original extension back on. - library_name = f"{library_name_stem}{override_value}.{library_name_ext}" + library_name = re.sub("cuda\d+", f"cuda{override_value}", library_name, count=1) logger.warning( f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n" "This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n" diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index fc79a54b0..b13f8b6c6 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -33,6 +33,12 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog): assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning? +def test_get_cuda_bnb_library_path_override_nocublaslt(monkeypatch, cuda111_noblas_spec, caplog): + monkeypatch.setenv("BNB_CUDA_VERSION", "125") + assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda125_nocublaslt" + assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning? + + def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec): monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt"