Skip to content

Commit

Permalink
Merge pull request #35 from zhuangqh/support-sd
Browse files Browse the repository at this point in the history
add more cublasLt, cublas implement
  • Loading branch information
n-eiling authored Dec 4, 2023
2 parents 72570fb + d3c545d commit b1d0b24
Show file tree
Hide file tree
Showing 16 changed files with 942 additions and 24 deletions.
8 changes: 5 additions & 3 deletions cpu/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ SRC_SERVER = $(RPC_XDR) \
mt-memcpy.c \
cpu-elf2.c \
cpu-server-nvml.c \
cpu-server-cudnn.c
cpu-server-cudnn.c \
cpu-server-cublaslt.c

SRC_SERVER_LIB = server-library.c
SRC_SERVER_EXE = server-exe.c
Expand All @@ -62,7 +63,8 @@ SRC_CLIENT = $(RPC_XDR) \
cpu-elf2.c \
cpu-client-nvml.c \
cpu-client-cudnn.c \
cpu-client-cublas.c
cpu-client-cublas.c \
cpu-client-cublaslt.c

# cpu-client-driver-hidden.c \
Expand Down Expand Up @@ -110,7 +112,7 @@ ifdef WITH_IB
CC_FLAGS += -DWITH_IB=$(WITH_IB)
endif

SERVER_LD_FLAGS = $(LD_FLAGS) -lcudart -lcusolver -lcuda -lcublas -lrt -lpthread -lnvidia-ml -lcudnn
SERVER_LD_FLAGS = $(LD_FLAGS) -lcudart -lcusolver -lcuda -lcublas -lrt -lpthread -lnvidia-ml -lcudnn -lcublasLt
SERVER_BIN_LD_FLAGS = $(SERVER_LD_FLAGS) -Wl,--unresolved-symbols=ignore-in-object-files
CLIENT_LD_FLAGS = $(LD_FLAGS)

Expand Down
177 changes: 172 additions & 5 deletions cpu/cpu-client-cublas.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ cublasStatus_t cublasCreate_v2(cublasHandle_t* handle)
clnt_perror (clnt, "call failed");
}
if (result.err == 0) {
*handle = (void*)result.ptr_result_u.ptr;
*handle = (cublasHandle_t)result.ptr_result_u.ptr;
}
return result.err;
}
Expand Down Expand Up @@ -93,7 +93,28 @@ DEF_FN(cublasStatus_t, cublasGetPointerMode_v2, cublasHandle_t, handle, cublasPo
DEF_FN(cublasStatus_t, cublasSetPointerMode_v2, cublasHandle_t, handle, cublasPointerMode_t, mode);
DEF_FN(cublasStatus_t, cublasGetAtomicsMode, cublasHandle_t, handle, cublasAtomicsMode_t*, mode);
DEF_FN(cublasStatus_t, cublasSetAtomicsMode, cublasHandle_t, handle, cublasAtomicsMode_t, mode);
DEF_FN(cublasStatus_t, cublasGetMathMode, cublasHandle_t, handle, cublasMath_t*, mode);

/**
 * Client-side replacement for cublasGetMathMode: asks the RPC peer for the
 * math mode of a cuBLAS handle via rpc_cublasgetmathmode_1.
 *
 * @param handle cuBLAS handle; forwarded to the RPC as an opaque ptr value.
 * @param mode   Output location; written only when the RPC succeeds and
 *               reports err == 0.
 * @return The status reported by the RPC peer,
 *         CUBLAS_STATUS_INVALID_VALUE if mode is a null pointer, or
 *         CUBLAS_STATUS_INTERNAL_ERROR if the RPC itself failed.
 */
cublasStatus_t cublasGetMathMode(cublasHandle_t handle, cublasMath_t *mode)
{
#ifdef WITH_API_CNT
    api_call_cnt++;
#endif //WITH_API_CNT

    int_result result;
    enum clnt_stat retval_1;

    if (!mode) {
        return CUBLAS_STATUS_INVALID_VALUE;
    }

    retval_1 = rpc_cublasgetmathmode_1(
        (ptr)handle,
        &result, clnt
    );
    if (retval_1 != RPC_SUCCESS) {
        clnt_perror (clnt, "call failed");
        /* result was never filled in; reading it would be undefined. */
        return CUBLAS_STATUS_INTERNAL_ERROR;
    }
    if (result.err == 0) {
        *mode = (cublasMath_t)result.int_result_u.data;
    }
    return (cublasStatus_t)result.err;
}

cublasStatus_t cublasSetMathMode(cublasHandle_t handle, cublasMath_t mode)
{
#ifdef WITH_API_CNT
Expand Down Expand Up @@ -605,7 +626,6 @@ cublasStatus_t cublasSgemmEx(cublasHandle_t handle,


DEF_FN(cublasStatus_t, cublasSgemmEx_64, cublasHandle_t, handle, cublasOperation_t, transa, cublasOperation_t, transb, int64_t, m, int64_t, n, int64_t, k, const float*, alpha, const void*, A, cudaDataType, Atype, int64_t, lda, const void*, B, cudaDataType, Btype, int64_t, ldb, const float*, beta, void*, C, cudaDataType, Ctype, int64_t, ldc);
DEF_FN(cublasStatus_t, cublasGemmEx, cublasHandle_t, handle, cublasOperation_t, transa, cublasOperation_t, transb, int, m, int, n, int, k, const void*, alpha, const void*, A, cudaDataType, Atype, int, lda, const void*, B, cudaDataType, Btype, int, ldb, const void*, beta, void*, C, cudaDataType, Ctype, int, ldc, cublasComputeType_t, computeType, cublasGemmAlgo_t, algo);
DEF_FN(cublasStatus_t, cublasGemmEx_64, cublasHandle_t, handle, cublasOperation_t, transa, cublasOperation_t, transb, int64_t, m, int64_t, n, int64_t, k, const void*, alpha, const void*, A, cudaDataType, Atype, int64_t, lda, const void*, B, cudaDataType, Btype, int64_t, ldb, const void*, beta, void*, C, cudaDataType, Ctype, int64_t, ldc, cublasComputeType_t, computeType, cublasGemmAlgo_t, algo);
DEF_FN(cublasStatus_t, cublasCgemmEx, cublasHandle_t, handle, cublasOperation_t, transa, cublasOperation_t, transb, int, m, int, n, int, k, const cuComplex*, alpha, const void*, A, cudaDataType, Atype, int, lda, const void*, B, cudaDataType, Btype, int, ldb, const cuComplex*, beta, void*, C, cudaDataType, Ctype, int, ldc);
DEF_FN(cublasStatus_t, cublasCgemmEx_64, cublasHandle_t, handle, cublasOperation_t, transa, cublasOperation_t, transb, int64_t, m, int64_t, n, int64_t, k, const cuComplex*, alpha, const void*, A, cudaDataType, Atype, int64_t, lda, const void*, B, cudaDataType, Btype, int64_t, ldb, const cuComplex*, beta, void*, C, cudaDataType, Ctype, int64_t, ldc);
Expand Down Expand Up @@ -691,7 +711,6 @@ DEF_FN(cublasStatus_t, cublasCgemm3mBatched, cublasHandle_t, handle, cublasOpera
DEF_FN(cublasStatus_t, cublasCgemm3mBatched_64, cublasHandle_t, handle, cublasOperation_t, transa, cublasOperation_t, transb, int64_t, m, int64_t, n, int64_t, k, const cuComplex*, alpha, const cuComplex* const*, Aarray, int64_t, lda, const cuComplex* const*, Barray, int64_t, ldb, const cuComplex*, beta, cuComplex* const*, Carray, int64_t, ldc, int64_t, batchCount);
DEF_FN(cublasStatus_t, cublasZgemmBatched, cublasHandle_t, handle, cublasOperation_t, transa, cublasOperation_t, transb, int, m, int, n, int, k, const cuDoubleComplex*, alpha, const cuDoubleComplex* const*, Aarray, int, lda, const cuDoubleComplex* const*, Barray, int, ldb, const cuDoubleComplex*, beta, cuDoubleComplex* const*, Carray, int, ldc, int, batchCount);
DEF_FN(cublasStatus_t, cublasZgemmBatched_64, cublasHandle_t, handle, cublasOperation_t, transa, cublasOperation_t, transb, int64_t, m, int64_t, n, int64_t, k, const cuDoubleComplex*, alpha, const cuDoubleComplex* const*, Aarray, int64_t, lda, const cuDoubleComplex* const*, Barray, int64_t, ldb, const cuDoubleComplex*, beta, cuDoubleComplex* const*, Carray, int64_t, ldc, int64_t, batchCount);
DEF_FN(cublasStatus_t, cublasSgemmStridedBatched, cublasHandle_t, handle, cublasOperation_t, transa, cublasOperation_t, transb, int, m, int, n, int, k, const float*, alpha, const float*, A, int, lda, long long int, strideA, const float*, B, int, ldb, long long int, strideB, const float*, beta, float*, C, int, ldc, long long int, strideC, int, batchCount);
DEF_FN(cublasStatus_t, cublasSgemmStridedBatched_64, cublasHandle_t, handle, cublasOperation_t, transa, cublasOperation_t, transb, int64_t, m, int64_t, n, int64_t, k, const float*, alpha, const float*, A, int64_t, lda, long long int, strideA, const float*, B, int64_t, ldb, long long int, strideB, const float*, beta, float*, C, int64_t, ldc, long long int, strideC, int64_t, batchCount);
DEF_FN(cublasStatus_t, cublasDgemmStridedBatched, cublasHandle_t, handle, cublasOperation_t, transa, cublasOperation_t, transb, int, m, int, n, int, k, const double*, alpha, const double*, A, int, lda, long long int, strideA, const double*, B, int, ldb, long long int, strideB, const double*, beta, double*, C, int, ldc, long long int, strideC, int, batchCount);
DEF_FN(cublasStatus_t, cublasDgemmStridedBatched_64, cublasHandle_t, handle, cublasOperation_t, transa, cublasOperation_t, transb, int64_t, m, int64_t, n, int64_t, k, const double*, alpha, const double*, A, int64_t, lda, long long int, strideA, const double*, B, int64_t, ldb, long long int, strideB, const double*, beta, double*, C, int64_t, ldc, long long int, strideC, int64_t, batchCount);
Expand All @@ -703,7 +722,6 @@ DEF_FN(cublasStatus_t, cublasZgemmStridedBatched, cublasHandle_t, handle, cublas
DEF_FN(cublasStatus_t, cublasZgemmStridedBatched_64, cublasHandle_t, handle, cublasOperation_t, transa, cublasOperation_t, transb, int64_t, m, int64_t, n, int64_t, k, const cuDoubleComplex*, alpha, const cuDoubleComplex*, A, int64_t, lda, long long int, strideA, const cuDoubleComplex*, B, int64_t, ldb, long long int, strideB, const cuDoubleComplex*, beta, cuDoubleComplex*, C, int64_t, ldc, long long int, strideC, int64_t, batchCount);
DEF_FN(cublasStatus_t, cublasGemmBatchedEx, cublasHandle_t, handle, cublasOperation_t, transa, cublasOperation_t, transb, int, m, int, n, int, k, const void*, alpha, const void* const*, Aarray, cudaDataType, Atype, int, lda, const void* const*, Barray, cudaDataType, Btype, int, ldb, const void*, beta, void* const*, Carray, cudaDataType, Ctype, int, ldc, int, batchCount, cublasComputeType_t, computeType, cublasGemmAlgo_t, algo);
DEF_FN(cublasStatus_t, cublasGemmBatchedEx_64, cublasHandle_t, handle, cublasOperation_t, transa, cublasOperation_t, transb, int64_t, m, int64_t, n, int64_t, k, const void*, alpha, const void* const*, Aarray, cudaDataType, Atype, int64_t, lda, const void* const*, Barray, cudaDataType, Btype, int64_t, ldb, const void*, beta, void* const*, Carray, cudaDataType, Ctype, int64_t, ldc, int64_t, batchCount, cublasComputeType_t, computeType, cublasGemmAlgo_t, algo);
DEF_FN(cublasStatus_t, cublasGemmStridedBatchedEx, cublasHandle_t, handle, cublasOperation_t, transa, cublasOperation_t, transb, int, m, int, n, int, k, const void*, alpha, const void*, A, cudaDataType, Atype, int, lda, long long int, strideA, const void*, B, cudaDataType, Btype, int, ldb, long long int, strideB, const void*, beta, void*, C, cudaDataType, Ctype, int, ldc, long long int, strideC, int, batchCount, cublasComputeType_t, computeType, cublasGemmAlgo_t, algo);
DEF_FN(cublasStatus_t, cublasGemmStridedBatchedEx_64, cublasHandle_t, handle, cublasOperation_t, transa, cublasOperation_t, transb, int64_t, m, int64_t, n, int64_t, k, const void*, alpha, const void*, A, cudaDataType, Atype, int64_t, lda, long long int, strideA, const void*, B, cudaDataType, Btype, int64_t, ldb, long long int, strideB, const void*, beta, void*, C, cudaDataType, Ctype, int64_t, ldc, long long int, strideC, int64_t, batchCount, cublasComputeType_t, computeType, cublasGemmAlgo_t, algo);
DEF_FN(cublasStatus_t, cublasSgeam, cublasHandle_t, handle, cublasOperation_t, transa, cublasOperation_t, transb, int, m, int, n, const float*, alpha, const float*, A, int, lda, const float*, beta, const float*, B, int, ldb, float*, C, int, ldc);
DEF_FN(cublasStatus_t, cublasSgeam_64, cublasHandle_t, handle, cublasOperation_t, transa, cublasOperation_t, transb, int64_t, m, int64_t, n, const float*, alpha, const float*, A, int64_t, lda, const float*, beta, const float*, B, int64_t, ldb, float*, C, int64_t, ldc);
Expand Down Expand Up @@ -761,3 +779,152 @@ DEF_FN(cublasStatus_t, cublasSgetrsBatched, cublasHandle_t, handle, cublasOperat
DEF_FN(cublasStatus_t, cublasDgetrsBatched, cublasHandle_t, handle, cublasOperation_t, trans, int, n, int, nrhs, const double* const*, Aarray, int, lda, const int*, devIpiv, double* const*, Barray, int, ldb, int*, info, int, batchSize);
DEF_FN(cublasStatus_t, cublasCgetrsBatched, cublasHandle_t, handle, cublasOperation_t, trans, int, n, int, nrhs, const cuComplex* const*, Aarray, int, lda, const int*, devIpiv, cuComplex* const*, Barray, int, ldb, int*, info, int, batchSize);
DEF_FN(cublasStatus_t, cublasZgetrsBatched, cublasHandle_t, handle, cublasOperation_t, trans, int, n, int, nrhs, const cuDoubleComplex* const*, Aarray, int, lda, const int*, devIpiv, cuDoubleComplex* const*, Barray, int, ldb, int*, info, int, batchSize);

/**
 * Client-side replacement for cublasGemmStridedBatchedEx: forwards all
 * arguments to rpc_cublasgemmstridedbatchedex_1.
 *
 * Pointers (handle, A, B, C) are sent as opaque ptr values; enum arguments
 * are sent as int. The alpha and beta scalars are dereferenced on the client
 * and sent by value as float.
 *
 * NOTE(review): alpha/beta are always read as float here. The cuBLAS
 * documentation specifies other scalar types for some compute types, and
 * device pointer mode would make these client-side dereferences invalid --
 * confirm callers only use host-resident float scalars.
 *
 * @return The status reported by the RPC peer,
 *         CUBLAS_STATUS_INVALID_VALUE if alpha or beta is a null pointer, or
 *         CUBLAS_STATUS_INTERNAL_ERROR if the RPC itself failed.
 */
cublasStatus_t cublasGemmStridedBatchedEx(cublasHandle_t handle,
                                          cublasOperation_t transa,
                                          cublasOperation_t transb,
                                          int m,
                                          int n,
                                          int k,
                                          const void *alpha,
                                          const void *A,
                                          cudaDataType_t Atype,
                                          int lda,
                                          long long int strideA,
                                          const void *B,
                                          cudaDataType_t Btype,
                                          int ldb,
                                          long long int strideB,
                                          const void *beta,
                                          void *C,
                                          cudaDataType_t Ctype,
                                          int ldc,
                                          long long int strideC,
                                          int batchCount,
                                          cublasComputeType_t computeType,
                                          cublasGemmAlgo_t algo)
{
#ifdef WITH_API_CNT
    api_call_cnt++;
#endif //WITH_API_CNT
    int result;
    enum clnt_stat retval_1;

    /* The scalars are dereferenced below, so they must be valid pointers. */
    if (!alpha || !beta) {
        return CUBLAS_STATUS_INVALID_VALUE;
    }

    retval_1 = rpc_cublasgemmstridedbatchedex_1(
        (ptr)handle,
        (int)transa,
        (int)transb,
        m, n, k,
        *(const float *)alpha,
        (ptr)A,
        (int)Atype,
        lda,
        strideA,
        (ptr)B,
        (int)Btype,
        ldb,
        strideB,
        *(const float *)beta,
        (ptr)C,
        (int)Ctype,
        ldc,
        strideC,
        batchCount,
        (int)computeType,
        (int)algo,
        &result, clnt
    );
    if (retval_1 != RPC_SUCCESS) {
        clnt_perror (clnt, "call failed");
        /* result was never filled in; returning it would be undefined. */
        return CUBLAS_STATUS_INTERNAL_ERROR;
    }
    return (cublasStatus_t)result;
}


/**
 * Client-side replacement for cublasGemmEx: forwards all arguments to
 * rpc_cublasgemmex_1.
 *
 * Pointers (handle, A, B, C) are sent as opaque ptr values; enum arguments
 * are sent as int. The alpha and beta scalars are dereferenced on the client
 * and sent by value as float.
 *
 * NOTE(review): alpha/beta are always read as float here. The cuBLAS
 * documentation specifies other scalar types for some compute types, and
 * device pointer mode would make these client-side dereferences invalid --
 * confirm callers only use host-resident float scalars.
 *
 * @return The status reported by the RPC peer,
 *         CUBLAS_STATUS_INVALID_VALUE if alpha or beta is a null pointer, or
 *         CUBLAS_STATUS_INTERNAL_ERROR if the RPC itself failed.
 */
cublasStatus_t cublasGemmEx(cublasHandle_t handle,
                            cublasOperation_t transa,
                            cublasOperation_t transb,
                            int m,
                            int n,
                            int k,
                            const void *alpha,
                            const void *A,
                            cudaDataType_t Atype,
                            int lda,
                            const void *B,
                            cudaDataType_t Btype,
                            int ldb,
                            const void *beta,
                            void *C,
                            cudaDataType_t Ctype,
                            int ldc,
                            cublasComputeType_t computeType,
                            cublasGemmAlgo_t algo)
{
#ifdef WITH_API_CNT
    api_call_cnt++;
#endif //WITH_API_CNT
    int result;
    enum clnt_stat retval_1;

    /* The scalars are dereferenced below, so they must be valid pointers. */
    if (!alpha || !beta) {
        return CUBLAS_STATUS_INVALID_VALUE;
    }

    retval_1 = rpc_cublasgemmex_1(
        (ptr)handle,
        (int)transa,
        (int)transb,
        m, n, k,
        *(const float *)alpha,
        (ptr)A, (int)Atype, lda,
        (ptr)B, (int)Btype, ldb,
        *(const float *)beta,
        (ptr)C, (int)Ctype, ldc,
        (int)computeType, (int)algo,
        &result, clnt);
    if (retval_1 != RPC_SUCCESS) {
        clnt_perror (clnt, "call failed");
        /* result was never filled in; returning it would be undefined. */
        return CUBLAS_STATUS_INTERNAL_ERROR;
    }
    return (cublasStatus_t)result;
}


/**
 * Client-side replacement for cublasSgemmStridedBatched: forwards all
 * arguments to rpc_cublasgemmstridedbatched_1.
 *
 * Pointers (handle, A, B, C) are sent as opaque ptr values; the transpose
 * enums are sent as int. The alpha and beta scalars are dereferenced on the
 * client and sent by value.
 *
 * NOTE(review): dereferencing alpha/beta here presumes they are host
 * pointers -- confirm callers do not use device pointer mode.
 *
 * @return The status reported by the RPC peer,
 *         CUBLAS_STATUS_INVALID_VALUE if alpha or beta is a null pointer, or
 *         CUBLAS_STATUS_INTERNAL_ERROR if the RPC itself failed.
 */
cublasStatus_t cublasSgemmStridedBatched(cublasHandle_t handle,
                                         cublasOperation_t transa,
                                         cublasOperation_t transb,
                                         int m, int n, int k,
                                         const float *alpha,
                                         const float *A, int lda,
                                         long long int strideA,
                                         const float *B, int ldb,
                                         long long int strideB,
                                         const float *beta,
                                         float *C, int ldc,
                                         long long int strideC,
                                         int batchCount)
{
#ifdef WITH_API_CNT
    api_call_cnt++;
#endif //WITH_API_CNT

    int result;
    enum clnt_stat retval_1;

    /* The scalars are dereferenced below, so they must be valid pointers. */
    if (!alpha || !beta) {
        return CUBLAS_STATUS_INVALID_VALUE;
    }

    retval_1 = rpc_cublasgemmstridedbatched_1(
        (ptr)handle,
        (int)transa,
        (int)transb,
        m, n, k,
        *alpha,
        (ptr)A,
        lda,
        strideA,
        (ptr)B,
        ldb,
        strideB,
        *beta,
        (ptr)C,
        ldc,
        strideC,
        batchCount,
        &result, clnt
    );
    if (retval_1 != RPC_SUCCESS) {
        clnt_perror (clnt, "call failed");
        /* result was never filled in; returning it would be undefined. */
        return CUBLAS_STATUS_INTERNAL_ERROR;
    }
    return (cublasStatus_t)result;
}
Loading

0 comments on commit b1d0b24

Please sign in to comment.