Skip to content

Commit

Permalink
[Fix](mluOpXgetrf):resolve bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Chuancysun committed Jan 3, 2025
1 parent 73ccad1 commit 7cc42eb
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 124 deletions.
2 changes: 2 additions & 0 deletions kernels/xgetrf/cnrtMemcpy2D_union1.mlu
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ __mlu_entry__ void MLUKernelcnrtMemcpy2D(int batch, int m, int n, float *dA,
__memcpy(dB, dA, batch * 2 * m * n * sizeof(float), GDRAM2GDRAM);
} else if (mode == 3) {
__gdramset(dA, 1, 0);
} else if (mode == 4) {
__memcpy(dB, dA, 1 * m * n * sizeof(int), GDRAM2GDRAM);
}
}
}
Expand Down
92 changes: 51 additions & 41 deletions kernels/xgetrf/scal_ger_union1.mlu
Original file line number Diff line number Diff line change
Expand Up @@ -477,14 +477,13 @@ __mlu_func__ void MLUKernelCPivot(int m, int n, int step, int batch, int M_size,
r_src_nram = r_orig + start;
i_src_nram = i_orig + start;

float absmax =
r_src_nram[0] * r_src_nram[0] + i_src_nram[0] * i_src_nram[0];
float absmax = FABS(r_src_nram[0]) + FABS(i_src_nram[0]);
int max_gb_pos =
0 + (m_per_core - len) + tx * mp_ + k * MAX_M_SIZE_COMPLEX_PIVOT;
int max_pos = 0;

__bang_square(r_src_nram, r_src_nram, len);
__bang_square(i_src_nram, i_src_nram, len);
__bang_abs(r_src_nram, r_src_nram, len);
__bang_abs(i_src_nram, i_src_nram, len);
__bang_add(r_src_nram, r_src_nram, i_src_nram, len);
__bang_argmax(r_src_nram, r_src_nram, len);
absmax = r_src_nram[0];
Expand All @@ -509,12 +508,15 @@ __mlu_func__ void MLUKernelCPivot(int m, int n, int step, int batch, int M_size,
__sync_cluster();
}
// Swap rows based on the pivoting results to maintain correctness for complex matrices
__mlu_func__ void MLUKernelCKSwap(
int m, int n, int step, int batch, int M_size, int N_size, float *d_rA,
float *d_iA, int lda, int stride_a, float *r_orig, float *i_orig,
int ld_orig, int m_per_core, float *k_max, int *k_max_idx, int *k_max_gbidx,
float *r_k_max_arr, float *i_k_max_arr, int tx, int taskdim, int batch_id,
int mp, int mp_, int *dipiv, int *dipiv2, int *info, int gbstep) {
__mlu_func__ void MLUKernelCKSwap(int m, int n, int step, int batch, int M_size,
int N_size, float *d_rA, float *d_iA, int lda,
int stride_a, float *r_orig, float *i_orig,
int ld_orig, int m_per_core, float *k_max,
int *k_max_idx, int *k_max_gbidx,
float *r_k_max_arr, float *i_k_max_arr,
int tx, int taskdim, int batch_id, int mp,
int mp_, int *dipiv, int *dipiv2, int *pivot,
int *info, int J, int gbstep) {
float *rA = d_rA;
float *iA = d_iA;

Expand All @@ -538,6 +540,7 @@ __mlu_func__ void MLUKernelCKSwap(
int temp = dipiv[step];
dipiv[step] = dipiv[k_max_gbidx[0]];
dipiv[k_max_gbidx[0]] = temp;
pivot[step] = k_max_gbidx[0] + 1 + J + gbstep; // 1-indexing

int temp2 = dipiv2[step];
dipiv2[step] = dipiv2[k_max_gbidx[0]];
Expand All @@ -553,7 +556,8 @@ __mlu_func__ void MLUKernelKSwap(int m, int n, int step, int batch, int M_size,
float *k_max, int *k_max_idx, int *k_max_gbidx,
float *k_max_arr, int tx, int taskdim,
int batch_id, int mp, int mp_, int *dipiv,
int *dipiv2, int *info, int gbstep) {
int *dipiv2, int *pivot, int *info, int J,
int gbstep) {
float *A = dA;

if (m_per_core > 0) {
Expand All @@ -568,6 +572,7 @@ __mlu_func__ void MLUKernelKSwap(int m, int n, int step, int batch, int M_size,
int temp = dipiv[step];
dipiv[step] = dipiv[k_max_gbidx[0]];
dipiv[k_max_gbidx[0]] = temp;
pivot[step] = k_max_gbidx[0] + 1 + J + gbstep;

int temp2 = dipiv2[step];
dipiv2[step] = dipiv2[k_max_gbidx[0]];
Expand Down Expand Up @@ -653,7 +658,7 @@ __mlu_func__ void MLUKernelPivotSwap2(
int ld_orig, int len_extra, int m_e, float *max_arr, float *max_arr_vec,
int *max_idx, int *max_gbidx, float *shared_y, int tx, int taskdim,
int batch_id, int m_per_core, int m_per_core_, int *dipiv, int *dipiv2,
int *info, int gbstep) {
int *pivot, int *info, int J, int gbstep) {
float *src_nram = L1;
float *src_nram1 = orig;
float *sram_buffer =
Expand Down Expand Up @@ -761,6 +766,7 @@ __mlu_func__ void MLUKernelPivotSwap2(
int temp = dipiv[cur_row];
dipiv[cur_row] = dipiv[max_row];
dipiv[max_row] = temp;
pivot[cur_row] = max_row + 1 + J + gbstep;

int temp2 = dipiv2[cur_row];
dipiv2[cur_row] = dipiv2[max_row];
Expand All @@ -775,8 +781,8 @@ __mlu_func__ void MLUKernelPivotSwap(
int lda, int stride_a, float *workspace, float *L1, int ldL1, float *orig,
int ld_orig, int len_extra, int m_e, float *max_arr, float *max_arr_vec,
int *max_idx, float *shared_y, int tx, int taskdim, int batch_id,
int m_per_core, int m_per_core_, int *dipiv, int *dipiv2, int *info,
int gbstep) {
int m_per_core, int m_per_core_, int *dipiv, int *dipiv2, int *pivot,
int *info, int J, int gbstep) {
float *src_nram = L1;
float *src_nram1 = orig;
float *sram_buffer =
Expand Down Expand Up @@ -873,6 +879,7 @@ __mlu_func__ void MLUKernelPivotSwap(
int temp = dipiv[cur_row];
dipiv[cur_row] = dipiv[max_row];
dipiv[max_row] = temp;
pivot[cur_row] = max_row + 1 + J + gbstep;

int temp2 = dipiv2[cur_row];
dipiv2[cur_row] = dipiv2[max_row];
Expand All @@ -888,7 +895,8 @@ __mlu_func__ void MLUKernelCPivotSwap2(
int len_extra, int m_e, float *max_arr, float *r_max_arr_vec,
float *i_max_arr_vec, int *max_idx, int *max_gbidx, float *r_shared_y,
float *i_shared_y, int tx, int taskdim, int batch_id, int m_per_core,
int m_per_core_, int *dipiv, int *dipiv2, int *info, int gbstep) {
int m_per_core_, int *dipiv, int *dipiv2, int *pivot, int *info, int J,
int gbstep) {
float *r_src_nram = r_L1;
float *i_src_nram = i_L1;
float *r_src_nram1 = r_orig;
Expand Down Expand Up @@ -936,13 +944,12 @@ __mlu_func__ void MLUKernelCPivotSwap2(
r_src_nram = r_L1 + start + step + step * ldL1;
i_src_nram = i_L1 + start + step + step * ldL1;

float absmax =
r_src_nram[0] * r_src_nram[0] + i_src_nram[0] * i_src_nram[0];
float absmax = FABS(r_src_nram[0]) + FABS(i_src_nram[0]);
int max_gb_pos = 0 + (m_per_core - len) + tx * m_per_core_;
int max_pos = 0;

__bang_square(r_src_nram1, r_src_nram, len);
__bang_square(i_src_nram1, i_src_nram, len);
__bang_abs(r_src_nram1, r_src_nram, len);
__bang_abs(i_src_nram1, i_src_nram, len);
__bang_add(r_src_nram1, r_src_nram1, i_src_nram1, len);
__bang_argmax(r_src_nram1, r_src_nram1, len);
absmax = r_src_nram1[0];
Expand Down Expand Up @@ -1021,6 +1028,7 @@ __mlu_func__ void MLUKernelCPivotSwap2(
int temp = dipiv[cur_row];
dipiv[cur_row] = dipiv[max_row];
dipiv[max_row] = temp;
pivot[cur_row] = max_row + 1 + J + gbstep;

int temp2 = dipiv2[cur_row];
dipiv2[cur_row] = dipiv2[max_row];
Expand All @@ -1036,7 +1044,7 @@ __mlu_func__ void MLUKernelCPivotSwap(
int len_extra, int m_e, float *max_arr, float *r_max_arr_vec,
float *i_max_arr_vec, int *max_idx, float *r_shared_y, float *i_shared_y,
int tx, int taskdim, int batch_id, int m_per_core, int m_per_core_,
int *dipiv, int *dipiv2, int *info, int gbstep) {
int *dipiv, int *dipiv2, int *pivot, int *info, int J, int gbstep) {
float *r_src_nram = r_L1;
float *i_src_nram = i_L1;
float *r_src_nram1 = r_orig;
Expand Down Expand Up @@ -1082,13 +1090,12 @@ __mlu_func__ void MLUKernelCPivotSwap(
r_src_nram = r_L1 + start + step + step * ldL1;
i_src_nram = i_L1 + start + step + step * ldL1;

float absmax =
r_src_nram[0] * r_src_nram[0] + i_src_nram[0] * i_src_nram[0];
float absmax = FABS(r_src_nram[0]) + FABS(i_src_nram[0]);
int max_gb_pos = 0 + (m_per_core - len) + tx * m_per_core_;
int max_pos = 0;

__bang_square(r_src_nram1, r_src_nram, len);
__bang_square(i_src_nram1, i_src_nram, len);
__bang_abs(r_src_nram1, r_src_nram, len);
__bang_abs(i_src_nram1, i_src_nram, len);
__bang_add(r_src_nram1, r_src_nram1, i_src_nram1, len);
__bang_argmax(r_src_nram1, r_src_nram1, len);
absmax = r_src_nram1[0];
Expand Down Expand Up @@ -1158,6 +1165,7 @@ __mlu_func__ void MLUKernelCPivotSwap(
int temp = dipiv[cur_row];
dipiv[cur_row] = dipiv[max_row];
dipiv[max_row] = temp;
pivot[cur_row] = max_row + 1 + J + gbstep;

int temp2 = dipiv2[cur_row];
dipiv2[cur_row] = dipiv2[max_row];
Expand Down Expand Up @@ -1212,8 +1220,8 @@ __mlu_global__ void MLUKernelScal_ger_pivot(int batch, int M_size, int N_size,
int ib, int J, int m, int n,
int step, float *dA, int lda,
int stride_a, float *workspace,
int *dipiv, int *dipiv2, int *info,
int gbstep, int mode) {
int *dipiv, int *dipiv2, int *pivot,
int *info, int gbstep, int mode) {
int id, batch_id, tx, taskdim;
if (batch > 1) {
id = taskId;
Expand Down Expand Up @@ -1319,7 +1327,7 @@ __mlu_global__ void MLUKernelScal_ger_pivot(int batch, int M_size, int N_size,
MLUKernelKSwap(m - gbj, ib - STEP, STEP, batch, M_size, N_size, A, lda,
stride_a, orig, 1, m_e, k_max, k_max_idx, k_max_gbidx,
k_max_arr, tx, taskdim, batch_id, mp, mp_, dipiv, dipiv2,
info, gbstep);
pivot, info, J, gbstep);

__memcpy(shared_y, A + STEP + STEP * lda, (ib - STEP) * sizeof(float),
GDRAM2NRAM);
Expand Down Expand Up @@ -1419,14 +1427,15 @@ __mlu_global__ void MLUKernelScal_ger_pivot(int batch, int M_size, int N_size,
A, lda, stride_a, workspace, L1, ld_L1, orig, n,
len_extra, m_e, max_arr, max_arr_vec, max_idx,
max_gbidx, shared_y, tx, taskdim, batch_id,
m_per_core, mp_, dipiv, dipiv2, info, gbstep);
m_per_core, mp_, dipiv, dipiv2, pivot, info, J,
gbstep);
} else {
// Handle pivoting and swapping for smaller-scale matrices
MLUKernelPivotSwap(m - gbj, ib - STEP, STEP, batch, M_size, N_size,
A, lda, stride_a, workspace, L1, ld_L1, orig, n,
len_extra, m_e, max_arr, max_arr_vec, max_idx,
shared_y, tx, taskdim, batch_id, m_per_core, mp_,
dipiv, dipiv2, info, gbstep);
dipiv, dipiv2, pivot, info, J, gbstep);
}

MLUSubKernelScal_ger_pivot(m - gbj, ib - STEP, STEP, J, taskdim,
Expand Down Expand Up @@ -1539,7 +1548,7 @@ __mlu_func__ void MLUSubKernelCcal_ger_pivot(
__mlu_global__ void MLUKernelCcal_ger_pivot(
int batch, int M_size, int N_size, int ib, int J, int m, int n, int step,
float *d_rA, float *d_iA, int lda, int stride_a, float *workspace,
int *dipiv, int *dipiv2, int *info, int gbstep, int mode) {
int *dipiv, int *dipiv2, int *pivot, int *info, int gbstep, int mode) {
int id, batch_id, tx, taskdim;
if (batch > 1) {
id = taskId;
Expand Down Expand Up @@ -1673,7 +1682,7 @@ __mlu_global__ void MLUKernelCcal_ger_pivot(
MLUKernelCKSwap(m - gbj, ib - STEP, STEP, batch, M_size, N_size, rA, iA,
lda, stride_a, r_orig, i_orig, 1, m_e, k_max, k_max_idx,
k_max_gbidx, r_k_max_arr, i_k_max_arr, tx, taskdim,
batch_id, mp, mp_, dipiv, dipiv2, info, gbstep);
batch_id, mp, mp_, dipiv, dipiv2, pivot, info, J, gbstep);

__memcpy(r_shared_y, rA + STEP + STEP * lda, (ib - STEP) * sizeof(float),
GDRAM2NRAM);
Expand Down Expand Up @@ -1803,18 +1812,19 @@ __mlu_global__ void MLUKernelCcal_ger_pivot(
if (gbj < M_size) {
if (m > MAX_M_SIZE_COMPLEX_PIVOT * TaskUnion1) {
// Handle pivoting and swapping for large-scale complex matrices
MLUKernelCPivotSwap2(
m - gbj, ib - STEP, STEP, batch, M_size, N_size, r_L1, i_L1,
ld_L1, r_orig, i_orig, n, len_extra, m_e, max_arr, r_max_arr_vec,
i_max_arr_vec, max_idx, max_gbidx, r_shared_y, i_shared_y, tx,
taskdim, batch_id, m_per_core, mp_, dipiv, dipiv2, info, gbstep);
MLUKernelCPivotSwap2(m - gbj, ib - STEP, STEP, batch, M_size, N_size,
r_L1, i_L1, ld_L1, r_orig, i_orig, n, len_extra,
m_e, max_arr, r_max_arr_vec, i_max_arr_vec,
max_idx, max_gbidx, r_shared_y, i_shared_y, tx,
taskdim, batch_id, m_per_core, mp_, dipiv,
dipiv2, pivot, info, J, gbstep);
} else {
// Handle pivoting and swapping for smaller-scale complex matrices
MLUKernelCPivotSwap(
m - gbj, ib - STEP, STEP, batch, M_size, N_size, r_L1, i_L1,
ld_L1, r_orig, i_orig, n, len_extra, m_e, max_arr, r_max_arr_vec,
i_max_arr_vec, max_idx, r_shared_y, i_shared_y, tx, taskdim,
batch_id, m_per_core, mp_, dipiv, dipiv2, info, gbstep);
batch_id, m_per_core, mp_, dipiv, dipiv2, pivot, info, J, gbstep);
}

MLUSubKernelCcal_ger_pivot(
Expand Down Expand Up @@ -1847,15 +1857,15 @@ mluOpStatus_t MLUOP_WIN_API KernelScal_ger(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
mluOpDataType_t d_type, int batch, int M_size, int N_size, int ib, int J,
int m, int n, int step, float *dA, int lda, int stride_a, float *workspace,
int *dipiv, int *dipiv2, int *info, int gbstep, int mode) {
int *dipiv, int *dipiv2, int *pivot, int *info, int gbstep, int mode) {
if (mode == 0) {
KERNEL_CHECK(MLUKernelScal_ger<<<k_dim, k_type, queue>>>(
batch, M_size, N_size, ib, J, m, n, step, dA, lda, stride_a, dipiv,
info, gbstep, mode));
} else {
KERNEL_CHECK(MLUKernelScal_ger_pivot<<<k_dim, k_type, queue>>>(
batch, M_size, N_size, ib, J, m, n, step, dA, lda, stride_a, workspace,
dipiv, dipiv2, info, gbstep, mode));
dipiv, dipiv2, pivot, info, gbstep, mode));
}
return MLUOP_STATUS_SUCCESS;
}
Expand All @@ -1865,15 +1875,15 @@ KernelCcal_ger(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
mluOpDataType_t d_type, int batch, int M_size, int N_size,
int ib, int J, int m, int n, int step, float *d_rA, float *d_iA,
int lda, int stride_a, float *workspace, int *dipiv, int *dipiv2,
int *info, int gbstep, int mode) {
int *pivot, int *info, int gbstep, int mode) {
if (mode == 0) {
KERNEL_CHECK(MLUKernelCcal_ger<<<k_dim, k_type, queue>>>(
batch, M_size, N_size, ib, J, m, n, step, d_rA, d_iA, lda, stride_a,
dipiv, info, gbstep, mode));
} else {
KERNEL_CHECK(MLUKernelCcal_ger_pivot<<<k_dim, k_type, queue>>>(
batch, M_size, N_size, ib, J, m, n, step, d_rA, d_iA, lda, stride_a,
workspace, dipiv, dipiv2, info, gbstep, mode));
workspace, dipiv, dipiv2, pivot, info, gbstep, mode));
}
return MLUOP_STATUS_SUCCESS;
}
23 changes: 13 additions & 10 deletions kernels/xgetrf/xgetrf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,22 @@ mluOpStatus_t MLUOP_WIN_API mluOpGetXgetrfWorkspaceSize(
return MLUOP_STATUS_SUCCESS;
}

mluOpStatus_t MLUOP_WIN_API mluOpXgetrf(
mluOpHandle_t handle, const mluOpTensorDescriptor_t x_desc, const void *x,
const mluOpTensorDescriptor_t y_desc, void *y, void *workspace,
const mluOpTensorDescriptor_t dipiv_desc, int *dipiv, int *info, int mode) {
mluOpStatus_t MLUOP_WIN_API
mluOpXgetrf(mluOpHandle_t handle, const mluOpTensorDescriptor_t x_desc,
const void *x, const mluOpTensorDescriptor_t y_desc, void *y,
void *workspace, const mluOpTensorDescriptor_t pivots_desc,
int *pivots, int *info, int mode) {
/* parameter check*/
size_t m, n;

int batch;
mluOpDataType_t dtype = x_desc->dtype;
PARAM_CHECK("mluOpXgetrf", x_desc != NULL);
PARAM_CHECK("mluOpXgetrf", y_desc != NULL);
PARAM_CHECK("mluOpXgetrf", dipiv_desc != NULL);
PARAM_CHECK("mluOpXgetrf", pivots_desc != NULL);
PARAM_CHECK("mluOpXgetrf", x != NULL);
PARAM_CHECK("mluOpXgetrf", y != NULL);
PARAM_CHECK("mluOpXgetrf", dipiv != NULL);
PARAM_CHECK("mluOpXgetrf", pivots != NULL);
PARAM_CHECK("mluOpXgetrf",
x_desc->dim == 2 || x_desc->dim == 3 || x_desc->dim == 4);

Expand Down Expand Up @@ -128,6 +130,7 @@ mluOpStatus_t MLUOP_WIN_API mluOpXgetrf(
mluOpGetQueue(handle, &(handle->queue));

size_t ldda = n;
size_t minmn = MIN(m, n);
if (dtype == MLUOP_DTYPE_COMPLEX_FLOAT) {
transpose(handle, MLUOP_DTYPE_COMPLEX_FLOAT, batch, m, n, (float *)x,
(float *)y, handle->queue);
Expand All @@ -140,22 +143,22 @@ mluOpStatus_t MLUOP_WIN_API mluOpXgetrf(
if (mode == 0) {
if (dtype == MLUOP_DTYPE_COMPLEX_FLOAT)
xgetrf_mlu(handle, dtype, batch, m, n, (float *)y, (float *)y,
(float *)y + batch * m * ldda, ldda, dipiv, info, mode,
(float *)y + batch * m * ldda, ldda, pivots, info, mode,
workspace);
else if (dtype == MLUOP_DTYPE_FLOAT)
xgetrf_mlu(handle, dtype, batch, m, n, (float *)y, NULL, NULL, ldda,
dipiv, info, mode, workspace);
pivots, info, mode, workspace);
} else {
if (dtype == MLUOP_DTYPE_COMPLEX_FLOAT) {
for (int b = 0; b < batch; b++) {
xgetrf_mlu(handle, dtype, 1, m, n, NULL, (float *)y + b * m * n,
(float *)y + batch * m * ldda + b * m * n, ldda,
dipiv + b * m, info, mode, workspace);
pivots + b * minmn, info, mode, workspace);
}
} else if (dtype == MLUOP_DTYPE_FLOAT) {
for (int b = 0; b < batch; b++) {
xgetrf_mlu(handle, dtype, 1, m, n, (float *)y + b * m * n, NULL, NULL,
ldda, dipiv + b * m, info, mode, workspace);
ldda, pivots + b * minmn, info, mode, workspace);
}
}
}
Expand Down
Loading

0 comments on commit 7cc42eb

Please sign in to comment.