Fix half2 with HIP #146

Open · wants to merge 4 commits into master
2 changes: 1 addition & 1 deletion cuda_ext.py

@@ -58,7 +58,7 @@ def find_msvc():
     extra_include_paths = [os.path.join(library_dir, "exllama_ext")],
     verbose = verbose,
     extra_ldflags = (["cublas.lib"] + ([f"/LIBPATH:{os.path.join(sys.base_prefix, 'libs')}"] if sys.base_prefix != sys.prefix else [])) if windows else [],
-    extra_cuda_cflags = ["-lineinfo"] + (["-U__HIP_NO_HALF_CONVERSIONS__", "-O3"] if torch.version.hip else []),
+    extra_cuda_cflags = ["-lineinfo"] + (["-O3"] if torch.version.hip else []),
     extra_cflags = ["-O3"]
     # extra_cflags = ["-ftime-report", "-DTORCH_USE_CUDA_DSA"]
 )
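For context (not part of the diff): dropping `-U__HIP_NO_HALF_CONVERSIONS__` leaves `__HIP_NO_HALF_CONVERSIONS__` in effect for the HIP build, which disables implicit float/half conversions; the remaining commits adjust the kernels so they only use explicit intrinsics. A minimal sketch of what that macro changes at the source level, using a hypothetical kernel that is not from the repo:

    // Hypothetical example, not from the repo: with __HIP_NO_HALF_CONVERSIONS__
    // (or __CUDA_NO_HALF_CONVERSIONS__) defined, the implicit float -> half
    // conversion is unavailable and the explicit intrinsic must be used.
    #include <cuda_fp16.h>          // hip_fp16.h provides the same API under HIP

    __global__ void scale_kernel(half* data, float factor, int n)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i >= n) return;
        // half h = factor;              // compiles only when implicit conversions are enabled
        half h = __float2half(factor);   // compiles either way
        data[i] = __hmul(data[i], h);
    }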
6 changes: 3 additions & 3 deletions exllama_ext/cuda_func/half_matmul.cu

@@ -41,8 +41,8 @@ __global__ void half_matmul_kernel
     for (int k = k0; k < k0 + BLOCKSIZE / 2; k++)
     {
         half2 x_item = *x_ptr++;
-        half2 x_item_0 = __half2half2(x_item.x);
-        half2 x_item_1 = __half2half2(x_item.y);
+        half2 x_item_0 = __low2half2(x_item);
+        half2 x_item_1 = __high2half2(x_item);
         half2 w_item_0 = *w_ptr; w_ptr += w_.width / 2;
         half2 w_item_1 = *w_ptr; w_ptr += w_.width / 2;
         acc = __hfma2(x_item_0, w_item_0, acc);

@@ -184,7 +184,7 @@ __global__ void half_matmul_small_kernel
         r = __hfma2(x_23, w_23, r);
     }

-    half rh = __hadd(r.x, r.y);
+    half rh = __hadd(__low2half(r), __high2half(r));

     __shared__ half accum[MAX_DIM_SMALL / S_BLOCKSIZE][S_THREADS_X];
     accum[threadIdx.y][threadIdx.x] = rh;
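This is the core pattern of the fix: reading `half2` lanes through `.x`/`.y` as `half` works with nvcc but is not portable to HIP, while `__low2half2`/`__high2half2` (broadcast one lane into both lanes) and `__low2half`/`__high2half` (extract one lane) are provided by both cuda_fp16.h and hip_fp16.h. A small, hypothetical sketch of the same idiom outside the repo's kernels:

    // Illustrative sketch only (hypothetical h2_dot kernel, not from the repo).
    #include <cuda_fp16.h>          // hip_fp16.h under HIP

    // Dot product of two fp16 vectors packed as half2; launch with a single
    // thread for the sketch.
    __global__ void h2_dot(const half2* x, const half2* y, half* out, int n2)
    {
        half2 acc = __float2half2_rn(0.0f);
        for (int i = 0; i < n2; i++)
        {
            // Broadcasting one lane into both lanes, as in half_matmul_kernel:
            // half2 x_lo = __low2half2(x[i]);   // (x.lo, x.lo)
            // half2 x_hi = __high2half2(x[i]);  // (x.hi, x.hi)
            acc = __hfma2(x[i], y[i], acc);      // pairwise multiply-accumulate
        }
        // Horizontal sum of the two lanes, as in half_matmul_small_kernel
        // above; replaces the non-portable __hadd(acc.x, acc.y).
        *out = __hadd(__low2half(acc), __high2half(acc));
    }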
2 changes: 1 addition & 1 deletion exllama_ext/cuda_func/q4_matmul.cu

@@ -131,7 +131,7 @@ __global__ void q4_matmul_kernel

     if constexpr (use_half2)
     {
         half result = __hadd(acc.x, acc.y);
+        half result = __hadd(__low2half(acc), __high2half(acc));
         atomicAdd(out_.item_ptr(x_row, w_column), result);
     }
     else
1 change: 1 addition & 0 deletions exllama_ext/cuda_func/q4_mlp.cu

@@ -13,6 +13,7 @@ const int THREADS_X = 32;
 const int THREADS_Y = 4;
 // const int MAX_DIMENSION = 8192;

+
 __device__ __forceinline__ half silu(half x)
 {
     half one = __float2half(1.0f);
4 changes: 2 additions & 2 deletions exllama_ext/cuda_func/rms_norm.cu

@@ -50,8 +50,8 @@ __global__ void rms_norm_row_product_kernel
     for (int k = 0; k < BLOCKSIZE_X / 2; k++)
    {
         half2 x2 = *x_ptr++;
-        float m0 = __half2float(x2.x);
-        float m1 = __half2float(x2.y);
+        float m0 = __low2float(x2);
+        float m1 = __high2float(x2);
         acc = fma(m0, m0, acc);
         acc = fma(m1, m1, acc);
     }
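Same idea with the float-returning lane extractors: `__low2float(x2)`/`__high2float(x2)` give the lanes already converted to float, equivalent to `__half2float(__low2half(x2))` but without touching `.x`/`.y`. A hypothetical helper mirroring the accumulation loop above:

    // Hypothetical helper, not from the repo: accumulate the sum of squares of
    // an fp16 row stored as half2, in float, as rms_norm_row_product_kernel does.
    #include <cuda_fp16.h>

    __device__ float sum_squares(const half2* x, int n2)
    {
        float acc = 0.0f;
        for (int i = 0; i < n2; i++)
        {
            half2 v = x[i];
            float m0 = __low2float(v);    // low lane, already converted to float
            float m1 = __high2float(v);   // high lane
            acc = fmaf(m0, m0, acc);
            acc = fmaf(m1, m1, acc);
        }
        return acc;
    }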
4 changes: 2 additions & 2 deletions exllama_ext/hip_compat.cuh

@@ -8,8 +8,8 @@ __device__ __forceinline__ __half __compat_hrcp(__half x) {
 }

 __device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
-    return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
-                      static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
+    return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half2_raw>(x).data.x)),
+                      static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half2_raw>(x).data.y))};
 }

 #define hrcp __compat_hrcp

Review comment on this change: Do you know where I can read about this __builtin_amdgcn_rcph?
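The `__half2_raw` cast is how the shim reaches the underlying `_Float16` storage of a HIP `__half2` without going through `.x`/`.y`. A hypothetical alternative (not what the PR does) that instead stays with the lane intrinsics used in the other files, assuming `__halves2half2` from hip_fp16.h:

    // Hypothetical alternative formulation, not part of the PR: take the
    // per-lane reciprocal via __compat_hrcp (defined earlier in this header)
    // and repack with __halves2half2, avoiding the __half2_raw cast.
    __device__ __forceinline__ __half2 __compat_h2rcp_alt(__half2 x) {
        return __halves2half2(__compat_hrcp(__low2half(x)),
                              __compat_hrcp(__high2half(x)));
    }

The raw-cast version in the PR keeps the existing `_Float16_2` construction and one builtin call per lane; the variant above trades that for the same intrinsics used elsewhere in the diff.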
4 changes: 1 addition & 3 deletions model_init.py

@@ -1,7 +1,6 @@
 from model import ExLlama, ExLlamaCache, ExLlamaConfig
 from tokenizer import ExLlamaTokenizer
 import argparse, sys, os, glob
-from torch import version as torch_version

 def add_args(parser):


@@ -28,13 +27,12 @@ def add_args(parser):
     parser.add_argument("-mmnh2", "--matmul_no_half2", action = "store_true", help = "Don't use half2 in Q4 matmul kernel")
     parser.add_argument("-snh2", "--silu_no_half2", action = "store_true", help = "Don't use half2 in SiLU kernel")
     parser.add_argument("-nh2", "--no_half2", action = "store_true", help = "(All of the above) disable half2 in all kernels")
-    parser.add_argument("-fh2", "--force_half2", action = "store_true", help = "Force enable half2 even if unsupported")
     parser.add_argument("-cs", "--concurrent_streams", action = "store_true", help = "Use concurrent CUDA streams")


 def post_parse(args):

-    if args.no_half2 or torch_version.hip and not args.force_half2:
+    if args.no_half2:
         args.rmsnorm_no_half2 = True
         args.rope_no_half2 = True
         args.matmul_no_half2 = True