Add ROCm support (#7)
* Add ROCm support

Co-authored-by: [  ] <[email protected]>

* Disable half2 by default when using HIP

---------

Co-authored-by: [  ] <[email protected]>
ardfork and BlankParenthesis authored Jun 6, 2023
1 parent 45de2b5 commit 43e3059
Showing 11 changed files with 85 additions and 4 deletions.
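
For context, a ROCm build of PyTorch reports its HIP version through torch.version.hip, and that check is what this commit keys its ROCm-specific behaviour on (the extra compiler flag in cuda_ext.py and the half2 defaults in model_init.py). A minimal detection sketch, assuming any recent PyTorch install:

    import torch

    # torch.version.hip is a version string on ROCm builds and None on CUDA
    # builds; torch.version.cuda behaves the other way around.
    if torch.version.hip:
        print("ROCm/HIP build:", torch.version.hip)
    else:
        print("CUDA build:", torch.version.cuda)
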
4 changes: 3 additions & 1 deletion cuda_ext.py
@@ -53,8 +53,10 @@ def find_msvc():
os.path.join(library_dir, "exllama_ext/cuda_func/q4_mlp.cu"),
os.path.join(library_dir, "exllama_ext/cpu_func/rep_penalty.cpp")
],
extra_include_paths = [os.path.join(library_dir, "exllama_ext")],
verbose = verbose,
extra_ldflags = ["cublas.lib"] if windows else []
extra_ldflags = ["cublas.lib"] if windows else [],
extra_cuda_cflags = ["-U__HIP_NO_HALF_CONVERSIONS__"] if torch.version.hip else []
# extra_cflags = ["-ftime-report", "-DTORCH_USE_CUDA_DSA"]
)

4 changes: 2 additions & 2 deletions exllama_ext/cuda_compat.cuh
@@ -41,8 +41,8 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)

//

#ifdef __CUDA_ARCH__
#if __CUDA_ARCH__ < 700
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)

__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
7 changes: 7 additions & 0 deletions exllama_ext/cuda_func/column_remap.cu
@@ -1,7 +1,14 @@
#include "column_remap.cuh"
#include "../util.cuh"

// Using a block size of 1024 crashes with "Memory access fault by GPU node-1
// (Agent handle: 0x012345678912) on address 0x012345678912. Reason: Page not
// present or supervisor privilege."
#if defined(USE_ROCM)
const int SHUF_BLOCKSIZE_X = 256;
#else
const int SHUF_BLOCKSIZE_X = 1024;
#endif
const int SHUF_BLOCKSIZE_Y = 16;

__global__ void column_remap_kernel
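
The crash note above concerns the thread-block width. Assuming the usual ceil-division launch pattern (the actual launch code is outside this hunk, so the dimensions below are hypothetical), the block size constants determine how many blocks tile the matrix:

    def ceil_div(a, b):
        return (a + b - 1) // b

    SHUF_BLOCKSIZE_X = 256   # ROCm value after this change; 1024 on CUDA
    SHUF_BLOCKSIZE_Y = 16

    def grid_dims(width, height):
        # Each block covers SHUF_BLOCKSIZE_X columns and SHUF_BLOCKSIZE_Y rows,
        # so the grid tiles the matrix in both directions.
        return ceil_div(width, SHUF_BLOCKSIZE_X), ceil_div(height, SHUF_BLOCKSIZE_Y)

    print(grid_dims(4096, 4096))   # (16, 256) with the ROCm block size
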
3 changes: 3 additions & 0 deletions exllama_ext/cuda_func/half_matmul.cu
@@ -2,6 +2,9 @@
#include "../util.cuh"
#include "../matrix.cuh"
#include "../cuda_compat.cuh"
#if defined(USE_ROCM)
#include "../hip_compat.cuh"
#endif

// Block size

6 changes: 6 additions & 0 deletions exllama_ext/cuda_func/half_matmul.cuh
@@ -6,6 +6,12 @@
#include <cstdint>
#include <ATen/cuda/CUDAContext.h>

// Workaround for hipify_python using rocblas instead of hipblas.
#if defined(USE_ROCM)
#include <hipblas/hipblas.h>
#define rocblas_handle hipblasHandle_t
#endif

void half_matmul_cuda
(
const half* x,
3 changes: 3 additions & 0 deletions exllama_ext/cuda_func/q4_matmul.cu
@@ -4,6 +4,9 @@
#include "../matrix.cuh"
#include "../cuda_compat.cuh"
#include "../cuda_buffers.cuh"
#if defined(USE_ROCM)
#include "../hip_compat.cuh"
#endif

const int THREADS_X = 32; // Block size and thread count along columns in w and out
const int THREADS_Y = 1; // Block size and thread count along rows in x and out
6 changes: 6 additions & 0 deletions exllama_ext/cuda_func/q4_matmul.cuh
@@ -10,6 +10,12 @@
#include "q4_matrix.cuh"
#include "../tuning.h"

// Workaround for hipify_python using rocblas instead of hipblas.
#if defined(USE_ROCM)
#include <hipblas/hipblas.h>
#define rocblas_handle hipblasHandle_t
#endif

void q4_matmul_cuda
(
ExLlamaTuning* tuningParams,
3 changes: 3 additions & 0 deletions exllama_ext/cuda_func/q4_mlp.cu
@@ -4,6 +4,9 @@
#include "../cuda_buffers.cuh"
#include "../util.cuh"
#include "../matrix.cuh"
#if defined(USE_ROCM)
#include "../hip_compat.cuh"
#endif

const int THREADS_X = 32;
const int THREADS_Y = 4;
45 changes: 45 additions & 0 deletions exllama_ext/hip_compat.cuh
@@ -0,0 +1,45 @@
#ifndef _hip_compat_cuh
#define _hip_compat_cuh

// Workaround for a bug in hipamd, backported from upstream.
__device__ __forceinline__ __half __compat_hrcp(__half x) {
return __half_raw{
static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
}

__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
}

#define hrcp __compat_hrcp
#define h2rcp __compat_h2rcp

// Workaround for hipify_python using rocblas instead of hipblas.
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
hipblasOperation_t transA,
hipblasOperation_t transB,
int m,
int n,
int k,
const half* alpha,
const half* AP,
int lda,
const half* BP,
int ldb,
const half* beta,
half* CP,
int ldc) {
return hipblasHgemm(handle, transA, transB, m, n, k,
reinterpret_cast<const hipblasHalf *>(alpha),
reinterpret_cast<const hipblasHalf *>(AP), lda,
reinterpret_cast<const hipblasHalf *>(BP), ldb,
reinterpret_cast<const hipblasHalf *>(beta),
reinterpret_cast<hipblasHalf *>(CP), ldc);
}

#define rocblas_handle hipblasHandle_t
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_hgemm __compat_hipblasHgemm

#endif
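
The "hipify_python using rocblas instead of hipblas" comments refer to the source translation step: when the extension is built for ROCm, PyTorch's hipify rewrites cuBLAS identifiers to rocBLAS names, and the defines above map those names back onto the hipBLAS API that the wrapper actually calls. A toy illustration of that rewriting (purely illustrative; the real tool, torch.utils.hipify, does far more than string replacement):

    # Shows why cuBLAS calls in the original .cu files end up spelled "rocblas_*"
    # after hipification, which hip_compat.cuh then #defines back to hipBLAS.
    CUBLAS_TO_ROCBLAS = {
        "cublasHandle_t": "rocblas_handle",
        "cublasHgemm": "rocblas_hgemm",
        "CUBLAS_OP_N": "rocblas_operation_none",
    }

    def toy_hipify(source: str) -> str:
        for cuda_name, rocm_name in CUBLAS_TO_ROCBLAS.items():
            source = source.replace(cuda_name, rocm_name)
        return source

    print(toy_hipify("cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, ...)"))
    # -> rocblas_hgemm(handle, rocblas_operation_none, rocblas_operation_none, m, n, k, ...)
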
4 changes: 4 additions & 0 deletions exllama_ext/util.cuh
@@ -6,7 +6,11 @@
#include <cstdint>
#include <cstdio>

#if defined(USE_ROCM)
#define cudaUnspecified hipErrorUnknown
#else
#define cudaUnspecified cudaErrorApiFailureBase
#endif

// React to failure on return code != cudaSuccess

4 changes: 3 additions & 1 deletion model_init.py
@@ -1,6 +1,7 @@
from model import ExLlama, ExLlamaCache, ExLlamaConfig
from tokenizer import ExLlamaTokenizer
import argparse, sys, os, glob
from torch import version as torch_version

def add_args(parser):

@@ -23,11 +24,12 @@ def add_args(parser):
parser.add_argument("-mmnh2", "--matmul_no_half2", action = "store_true", help = "Don't use half2 in Q4 matmul kernel")
parser.add_argument("-snh2", "--silu_no_half2", action = "store_true", help = "Don't use half2 in SiLU kernel")
parser.add_argument("-nh2", "--no_half2", action = "store_true", help = "(All of the above) disable half2 in all kernela")
parser.add_argument("-fh2", "--force_half2", action = "store_true", help = "Force enable half2 even if unsupported")


def post_parse(args):

if args.no_half2:
if args.no_half2 or torch_version.hip and not args.force_half2:
args.rmsnorm_no_half2 = True
args.rope_no_half2 = True
args.matmul_no_half2 = True
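
Since "and" binds tighter than "or" in Python, the new condition reads: disable half2 when --no_half2 is given, or when running on a HIP build without --force_half2. A small standalone restatement of that behaviour (not the project's code, just the same expression exercised directly):

    def disable_half2(no_half2, is_hip, force_half2):
        # Same expression as in post_parse(): no_half2 OR (is_hip AND NOT force_half2)
        return no_half2 or is_hip and not force_half2

    assert disable_half2(False, True, False)       # HIP build: half2 off by default
    assert not disable_half2(False, True, True)    # --force_half2 re-enables it on HIP
    assert disable_half2(True, True, True)         # --no_half2 always wins
    assert not disable_half2(False, False, False)  # CUDA build: half2 stays on
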
