Skip to content

Commit

Permalink
Always link cublasLt, use runtime checks for support
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewdouglas committed Mar 4, 2024
1 parent f9eba9c commit 15d3c8f
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 64 deletions.
26 changes: 12 additions & 14 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -122,20 +122,18 @@ jobs:
build_os=${{ matrix.os }}
build_arch=${{ matrix.arch }}
[[ "${{ matrix.os }}" = windows-* ]] && python3 -m pip install ninja
for NO_CUBLASLT in ON OFF; do
if [ ${build_os:0:6} == ubuntu ]; then
image=nvidia/cuda:${{ matrix.cuda_version }}-devel-ubuntu22.04
echo "Using image $image"
docker run --platform linux/$build_arch -i -w /src -v $PWD:/src $image sh -c \
"apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
&& cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"50;52;60;61;70;75;80;86;89;90\" -DNO_CUBLASLT=${NO_CUBLASLT} . \
&& cmake --build ."
else
cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DNO_CUBLASLT=${NO_CUBLASLT} -DCMAKE_BUILD_TYPE=Release -S .
cmake --build . --config Release
fi
done
if [ ${build_os:0:6} == ubuntu ]; then
image=nvidia/cuda:${{ matrix.cuda_version }}-devel-ubuntu22.04
echo "Using image $image"
docker run --platform linux/$build_arch -i -w /src -v $PWD:/src $image sh -c \
"apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
&& cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"50;52;60;61;70;75;80;86;89;90\" . \
&& cmake --build ."
else
cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCMAKE_BUILD_TYPE=Release -S .
cmake --build . --config Release
fi
mkdir -p output/${{ matrix.os }}/${{ matrix.arch }}
( shopt -s nullglob && cp bitsandbytes/*.{so,dylib,dll} output/${{ matrix.os }}/${{ matrix.arch }}/ )
- name: Upload build artifact
Expand Down
22 changes: 7 additions & 15 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# For MSVC: `cmake -B build . && cmake --build build --config Release`
# You can also use the following options and variables
# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend
# - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support
# - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version
# is whatever CMake finds on your path.
# - COMPUTE_CAPABILITY: Which GPU Arch/Compute codes to provide to NVCC.
Expand Down Expand Up @@ -39,10 +38,8 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda")
if(APPLE)
message(FATAL_ERROR "CUDA is not supported on macOS" )
endif()
option(NO_CUBLASLT "Disable CUBLAS" OFF)
set(BUILD_CUDA ON)
set(BUILD_MPS OFF)
message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}")
elseif(${COMPUTE_BACKEND} STREQUAL "mps")
if(NOT APPLE)
message(FATAL_ERROR "MPS is only supported on macOS" )
Expand Down Expand Up @@ -145,9 +142,7 @@ if(BUILD_CUDA)
list(APPEND SRC_FILES ${CUDA_FILES})

string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
if(NO_CUBLASLT)
string(APPEND BNB_OUTPUT_NAME "_nocublaslt")
endif()

add_compile_definitions(BUILD_CUDA)
elseif(BUILD_MPS)
if(NOT APPLE)
Expand All @@ -173,13 +168,11 @@ else()
set(GPU_SOURCES)
endif()


if(WIN32)
# Export all symbols
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
endif()

# Weird MSVC hacks
if(MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2 /fp:fast")
endif()
Expand All @@ -192,12 +185,11 @@ target_include_directories(bitsandbytes PUBLIC csrc include)

if(BUILD_CUDA)
target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cusparse)
if(NO_CUBLASLT)
target_compile_definitions(bitsandbytes PUBLIC NO_CUBLASLT)
else()
target_link_libraries(bitsandbytes PUBLIC CUDA::cublasLt)
endif()

# Note: As of CUDA 11.0, cublas depends on cublasLt.
# See: https://gitlab.kitware.com/cmake/cmake/-/merge_requests/6857/diffs
# It is listed here for assurance. In CMake > 3.23.0, it's implicit when linking CUDA::cublas.
target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cusparse)

set_target_properties(bitsandbytes
PROPERTIES
Expand All @@ -220,4 +212,4 @@ if(MSVC)
set_target_properties(bitsandbytes PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG "${PROJECT_SOURCE_DIR}/bitsandbytes")
endif()

set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY bitsandbytes)
set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${PROJECT_SOURCE_DIR}/bitsandbytes")
17 changes: 1 addition & 16 deletions bitsandbytes/cuda_setup/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
- Software:
- CPU-only: only CPU quantization functions (no optimizer, no matrix multiply)
- CuBLAS-LT: full-build 8-bit optimizer
- no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)
evaluation:
- if paths faulty, return meaningful error
Expand Down Expand Up @@ -86,11 +85,6 @@ def generate_instructions(self):
self.add_log_entry('CUDA SETUP: Before you try again running bitsandbytes, make sure old CUDA 10.0 versions are uninstalled and removed from $LD_LIBRARY_PATH variables.')
return


has_cublaslt = is_cublasLt_compatible(self.cc)
if not has_cublaslt:
make_cmd += '_nomatmul'

self.add_log_entry('CUDA SETUP: Something unexpected happened. Please compile from source:')
self.add_log_entry('git clone https://github.com/TimDettmers/bitsandbytes.git')
self.add_log_entry('cd bitsandbytes')
Expand Down Expand Up @@ -372,22 +366,13 @@ def evaluate_cuda_setup():
"https://github.com/TimDettmers/bitsandbytes/blob/main/how_to_use_nonpytorch_cuda.md"
)


# 7.5 is the minimum CC vor cublaslt
has_cublaslt = is_cublasLt_compatible(cc)

# TODO:
# (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
# (2) Multiple CUDA versions installed

# we use ls -l instead of nvcc to determine the cuda version
# since most installations will have the libcudart.so installed, but not the compiler

binary_name = f"libbitsandbytes_cuda{cuda_version_string}"
if not has_cublaslt:
# if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt
binary_name += "_nocublaslt"

binary_name = f"{binary_name}{DYNAMIC_LIBRARY_SUFFIX}"
binary_name = f"libbitsandbytes_cuda{cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}"

return binary_name, cudart_path, cc, cuda_version_string
2 changes: 1 addition & 1 deletion bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1960,7 +1960,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
)

if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)")
raise NotImplementedError("igemmlt not available (probably CC < 7.5)")

if has_error:
print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}')
Expand Down
45 changes: 30 additions & 15 deletions csrc/ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,6 @@ int roundoff(int v, int d) {
return (v + d - 1) / d * d;
}


#ifdef NO_CUBLASLT
#else
template<int ORDER> cublasLtOrder_t get_order()
{
switch(ORDER)
Expand Down Expand Up @@ -332,8 +329,6 @@ template cublasLtOrder_t get_order<COL>();
template cublasLtOrder_t get_order<COL32>();
template cublasLtOrder_t get_order<COL_TURING>();
template cublasLtOrder_t get_order<COL_AMPERE>();
#endif


template<int ORDER> int get_leading_dim(int dim1, int dim2)
{
Expand Down Expand Up @@ -366,10 +361,33 @@ template int get_leading_dim<ROW>(int dim1, int dim2);
template int get_leading_dim<COL>(int dim1, int dim2);
template int get_leading_dim<COL32>(int dim1, int dim2);

// TODO: Check overhead. Maybe not worth it; just check in Python lib once,
// and avoid calling lib functions w/o support for them.
// TODO: Address GTX 1660, any other 7.5 devices maybe not supported.
inline bool igemmlt_supported() {
int device;
int ccMajor;

CUDA_CHECK_RETURN(cudaGetDevice(&device));
CUDA_CHECK_RETURN(cudaDeviceGetAttribute(&ccMajor, cudaDevAttrComputeCapabilityMajor, device));

if (ccMajor >= 8)
return true;

if (ccMajor < 7)
return false;

int ccMinor;
CUDA_CHECK_RETURN(cudaDeviceGetAttribute(&ccMinor, cudaDevAttrComputeCapabilityMinor, device));

return ccMinor >= 5;
}

template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
{
#ifdef NO_CUBLASLT
#else
if (!igemmlt_supported())
return;

cublasLtOrder_t orderA = get_order<SRC>();
cublasLtOrder_t orderOut = get_order<TARGET>();
int ldA = get_leading_dim<SRC>(dim1, dim2);
Expand Down Expand Up @@ -408,7 +426,6 @@ template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void trans
if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc));
if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc));
if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc));
#endif
}

template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
Expand All @@ -422,9 +439,9 @@ template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandl

template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{
#ifdef NO_CUBLASLT
return ERR_NOT_IMPLEMENTED;
#else
if (!igemmlt_supported())
return ERR_NOT_IMPLEMENTED;

int has_error = 0;
cublasLtMatmulDesc_t matmulDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
Expand Down Expand Up @@ -479,7 +496,6 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle
printf("error detected");

return has_error;
#endif // NO_CUBLASLT
}

int fill_up_to_nearest_multiple(int value, int multiple)
Expand Down Expand Up @@ -595,8 +611,8 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
{

#ifdef NO_CUBLASLT
#else
if (!igemmlt_supported())
return;

cusparseSpMatDescr_t descA;
cusparseDnMatDescr_t descB, descC;
Expand Down Expand Up @@ -644,7 +660,6 @@ void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_val
CHECK_CUSPARSE( cusparseDestroyDnMat(descB) );
CHECK_CUSPARSE( cusparseDestroyDnMat(descC) );
CUDA_CHECK_RETURN( cudaFree(dBuffer) );
#endif
}

template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
Expand Down
2 changes: 1 addition & 1 deletion docs/source/installation.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ For Linux systems, make sure your hardware meets the following requirements to u
| 8-bit optimizers/quantization | NVIDIA Kepler (GTX 780 or newer) |

> [!WARNING]
> bitsandbytes >= 0.39.1 no longer includes Kepler binaries in pip installations. This requires manual compilation, and you should follow the general steps and use `cuda11x_nomatmul_kepler` for Kepler-targeted compilation.
> bitsandbytes >= 0.39.1 no longer includes Kepler binaries in pip installations. This requires manual compilation, and you should follow the general steps and use CUDA 11.x for Kepler-targeted compilation.
To install from PyPI.

Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ def pytest_runtest_call(item):
try:
item.runtest()
except NotImplementedError as nie:
if "NO_CUBLASLT" in str(nie):
pytest.skip("CUBLASLT not available")
if "CC < 7.5" in str(nie):
pytest.skip("INT8 tensor cores not available")
raise
except AssertionError as ae:
if str(ae) == "Torch not compiled with CUDA enabled":
Expand Down

0 comments on commit 15d3c8f

Please sign in to comment.