
[WIP] add new operator LU factorization #1019

Open

wants to merge 20 commits into base: master
Conversation

@Chuancysun (Collaborator) commented Apr 30, 2024

Thanks for your contribution and we appreciate it a lot. 🚀🚀

1. Motivation

Add a floating-point LU factorization operator.

2. Modification

Add an implementation of floating-point LU factorization.

3. Test Report

3.1 Modification Details

3.1.1 Accuracy Acceptance Standard

For static threshold standard details, see: MLU-OPS™ Accuracy Acceptance Standard.

  • static threshold
    • diff1
      • float32 mlu diff1 <= 1e-5
      • [*] float32 mlu diff1 <= 3e-3
      • float16 mlu diff1 <= 3e-3
    • diff2
      • float32 mlu diff2 <= 1e-5
      • [*] float32 mlu diff2 <= 3e-3
      • float16 mlu diff2 <= 3e-3
    • diff3
      • mlu diff3 == 0
      • mlu diff3_1 == 0
      • mlu diff3_2 == 0
  • dynamic threshold
    • diff1: mlu diff1 <= max(baseline diff1 * 10, static threshold)
    • diff2: mlu diff2 <= max(baseline diff2 * 10, static threshold)
    • diff3: mlu diff3 <= max(baseline diff3 * 10, static threshold)
      • float32, threshold = 1e-5
      • float16, threshold = 1e-3
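The dynamic-threshold rule above can be stated as a one-line helper; this is a minimal sketch whose function name is hypothetical, not part of the mlu-ops API.

```c
#include <assert.h>

/* Illustrative only: the dynamic-threshold rule above, i.e. the MLU
 * diff must satisfy  mlu_diff <= max(baseline_diff * 10, static_threshold). */
static double dynamic_threshold(double baseline_diff, double static_threshold) {
    double scaled = baseline_diff * 10.0;
    return scaled > static_threshold ? scaled : static_threshold;
}
```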

3.1.2 Operator Scheme checklist

  • Supported hardware
    • [*] MLU370
    • MLU590
  • Job types
    • BLOCK
    • UNION1
    • UNION2
    • UNION4
    • [*] The operator will dynamically select the most suitable task type, for example, UNION8

3.2 Accuracy Test

3.2.1 Accuracy Test

If you have checked the following items, please tick the relevant box.

  • Data type test (e.g. float32/int8)
  • Multi-dimensional tensor test
  • Layout test
  • Different size/integer remainder end segment/alignment misalignment test
  • Zero dimensional tensor test/zero element test
  • Stability test
  • Multiple platform test
  • Gen_case module test, see: Gencase-User-Guide-zh
  • Nan/INF tests
  • Bug fix tests
  • For memory leak check details, see: GTest-User-Guide-zh
  • For code coverage check details, see: GTest-User-Guide-zh
  • For I/O calculation efficiency check details, see: MLU-OPS™-Performance-Acceptance-Standard

3.3 Performance Test

Platform: MLU370

----------- case0 -----------
case0
[Op name ]: sgetrf
[Shape ]: input.shape=[256,256], output.shape=[256,256]
[Data type ]: float32
[MLU Hardware Time ]: 6460 (us)
[MLU Interface Time ]: 15336.7 (us)
[MLU IO Efficiency ]: 0.00026419
[MLU Compute Efficiency ]: 9.90712e-06
[MLU Workspace Size ]: -1 (Bytes)
[MLU Kernel Name(s) ]: {}
[MLU TheoryOps ]: 65536 (Ops)
[MLU TheoryIOs ]: 524288 (Bytes)
[MLU ComputeForce ]: 1.024e+12 (op/s)
[MLU IoBandWidth ]: 307.2 (GB/s)
[GPU Hardware Time ]: -1 (us)
[GPU IO Efficiency ]: -1
[GPU Compute Efficiency ]: -1
[GPU Workspace Size ]: -1 (Bytes)
[Diffs]:
[output]
DIFF1: 1.798500e-04
DIFF2: 7.016698e-04
[^ OK ] ../../test/mlu_op_gtest/pb_gtest/src/zoo/sgetrf/test_case/case0.prototxt
[ OK ] sgetrf/TestSuite.mluOp/0 (36 ms)
[----------] 1 test from sgetrf/TestSuite (36 ms total)

[----------] Global test environment tear-down
[ SUMMARY ] Total 1 cases of 1 op(s).
ALL PASSED.
[==========] 1 test case from 1 test suite ran. (3727 ms total)
[ PASSED ] 1 test case.

3.4 Summary Analysis

Please give a brief overview here, if you want to note and summarize the content.

* permit persons to whom the Software is furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
Collaborator:

Please add the proto file and the operator description in docs/bangc-docs/user_guide/9_operators/index.rst. For reference, see https://github.com/Cambricon/mlu-ops/pull/662/files#diff-7f0a558d8f985a4ebd89cd6674a4bf1a91549ddcc6e708a897f351cb2006f0e8

Author:

The operator design document has been added.

nb = get_sgetrf_native_nb(m, n);

float *workspace;
cnrtMalloc((void **)&workspace, nb * nb * sizeof(float));
Collaborator:

Generally the workspace is passed in by the user: the operator provides a GetWorkspaceSize-style interface so the user can allocate the corresponding space. For reference, see https://github.com/Cambricon/mlu-ops/blob/master/mlu_op.h#L3716

Author:

Fixed.
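For context, the pattern the reviewer points to is query-then-allocate: the operator exposes a workspace-size query, the caller allocates, and the buffer is passed back into the compute call. A minimal host-side sketch in plain C; the function name and the nb = 32 blocking factor are assumptions for illustration, not the PR's actual interface.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical stand-in for a GetWorkspaceSize-style query. The real
 * mlu-ops interface also takes a handle and tensor descriptors; here
 * only the size computation is modeled. */
static int get_sgetrf2_workspace_size(int m, int n, size_t *size) {
    if (size == NULL || m <= 0 || n <= 0) return -1;  /* bad parameter */
    int nb = 32;  /* illustrative blocking factor */
    *size = (size_t)nb * nb * sizeof(float);
    return 0;
}
/* The caller then allocates (cnrtMalloc on device) and passes the buffer
 * into the compute API; the operator itself never calls cnrtMalloc. */
```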

mlu_op.h Outdated
typedef enum
{
MLUOP_STATUS_SUCCESS = 0, /*!< The operation is successfully completed. */
MLUOP_STATUS_NOT_INITIALIZED = 1,
Collaborator:

Suggest reverting the other changes in the header file and keeping only the additions from this submission.

Author:

Fixed.

{
temp += mul_result[k];
}
temp = temp * -1.0 * diag_element;
Collaborator:

Scalar computation here is too slow. Can't this be vectorized with BANG C interface calls?

Author:

This code sums all elements of a vector; I could not find a BANG function in the documentation that implements this.

Collaborator:

You could try reduce_sum or sumpool.

Author:

After checking the documentation, the semantics of those two functions do not match the operator's logic.
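For reference, the operation in question is a full sum reduction of a vector. Independent of which BANG intrinsic fits, the usual way to make it vectorizable is to fold the vector in halves so that each pass is a plain element-wise add, which vector interfaces do cover. A host-side C sketch of the idea, not MLU code:

```c
#include <assert.h>
#include <stddef.h>

/* Scalar accumulation, the pattern in the reviewed kernel. */
static float sum_scalar(const float *v, size_t n) {
    float acc = 0.0f;
    for (size_t i = 0; i < n; ++i) acc += v[i];
    return acc;
}

/* Pairwise folding: add the upper half onto the lower half until one
 * element remains. Each pass is an element-wise vector add, which is
 * the shape a vector intrinsic can implement; only about log2(n)
 * passes are needed instead of n scalar additions. */
static float sum_folded(float *v, size_t n) {
    while (n > 1) {
        size_t half = n / 2;
        for (size_t i = 0; i < half; ++i)
            v[i] += v[n - half + i];  /* lower half += upper half */
        n -= half;
    }
    return v[0];
}
```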

mluOpDataType_t dtype = x_desc->dtype;
PARAM_CHECK("mluOpSgetrf2", x_desc != NULL);

PARAM_CHECK("mluOpSgetrf2",
Collaborator:

The parameter checks are missing zero-element and large-tensor handling; for reference, see

mlu-ops/kernels/abs/abs.cpp

Author:

Fixed.

mlu_op.h Outdated
*
* @par Scale Limitation
* - The dimension of input tensor must be either 2, 3 or 4.
* Considering the size of the GDRAM, the space occupied by the input matrix should not exceed 7GB.
Collaborator:

If this is only caused by GDRAM capacity and is not a limitation of the operator itself, it can be removed. In other words, if GDRAM space were unlimited, would the operator itself have no scale limitation?

Author:

Fixed.

PARAM_CHECK("mluOpSgetrf2", x_desc->dims[1] > 0);

/* sgetrf parameter conversion */
int m, n, batch = 1;
Collaborator:

int is used here; if a single dimension exceeds INT_MAX, an overflow occurs.

Author:

If a single dimension exceeds INT_MAX, an out-of-memory error is already reported when the operator is called.

mlu_op.h Outdated
*
* @par Data Layout
* - The supported combinations of data types are shown below:
* - size_t(size)
Collaborator:

Mark the data layout here as None as well; the workspace interface needs no extra description.

Author:

Fixed.

mlu_op.h Outdated
* - size_t(size)
*
* @par Scale Limitation
* - The dimension of input tensor must be either 2, 3 or 4.
Collaborator:

Mark the limitation here as None as well.

Author:

Fixed.

*
* @par Data Type
* - The supported combinations of data types are shown below:
* - float(x) - float(y)
Collaborator:

Doesn't it also support complex?

Author:

Fixed.

mlu_op.h Outdated
* Considering the size of the GDRAM, the space occupied by the input matrix should not exceed 7GB.
*
* @par API Dependency
* - The allocated extra workspace should be passed to ::mluOpSgetrf2 to perform the LU operation.
Collaborator:

The API dependency description should read like, for example: "Before calling this function to perform ::mluOpRoiPointPool3d, you need to get the size of the workspace by ::mluOpGetRoiPointPool3dWorkspaceSize." As written, it is identical to the workspace interface's description!

Author:

Fixed.

}
}
}
k_dim.x = dim_x;
Collaborator:

Extract this part into a policyFunc; see the following implementation for reference:

void policyFuncBallQuery(const mluOpHandle_t &handle,
                         const mluOpTensorDescriptor_t &desc, cnrtDim3_t *k_dim,
                         cnrtFunctionType_t *k_type) {
  size_t cluster_num = mluop::runtime::getClusterLimitCapability(handle);
  VLOG(5) << "In current device, cluster_num:" << cluster_num;
  size_t core_in_cluster = handle->core_num_per_cluster;
  VLOG(5) << "In current device, core_in_cluster:" << core_in_cluster;

  size_t total_data_num = desc->total_element_num;

  // On a core, many new_xyz data elements can be stored, but only one
  // element can be processed at a time, so a cluster can only process
  // four data elements.
  size_t needed_cluster_num =
      (total_data_num + core_in_cluster - 1) / core_in_cluster;
  *k_type = cnrtFuncTypeUnion1;
  k_dim->x = core_in_cluster;
  k_dim->y =
      needed_cluster_num > cluster_num ? cluster_num : needed_cluster_num;
  k_dim->z = 1;
}

Author:

Fixed.

@@ -793,3 +793,10 @@ mluOpLgamma

- ``x`` 为输入张量。

.. Sgetrf2::
Collaborator:

Change `.. Sgetrf2::` to `.. _sgetrf2:`.

Author:

Fixed.
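For clarity on the reviewer's point: in reStructuredText, `.. Sgetrf2::` parses as a directive, whereas a cross-reference label must be written as `.. _sgetrf2:` and placed before the section title, e.g.:

```rst
.. _sgetrf2:

mluOpSgetrf2
---------------
```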

@@ -793,3 +793,10 @@ mluOpLgamma

- ``x`` 为输入张量。

.. Sgetrf2::
Collaborator:

Please also add this new content to the update-history chapter of the manual.

Author:

Could you explain this in more detail?


mluOpSgetrf2
---------------
Performs LU factorization, decomposing a matrix into a lower triangular matrix (L) and an upper triangular matrix (U); the parameter ``mode`` specifies whether pivoting is performed.
Collaborator (@AndyQiao0828, Nov 22, 2024):

In the Chinese text, change 参数mode用来 to 参数 mode 用来 (add spaces around ``mode``).

Author:

What does this mean?

---------------
Performs LU factorization, decomposing a matrix into a lower triangular matrix (L) and an upper triangular matrix (U); the parameter ``mode`` specifies whether pivoting is performed.

The operator takes 7 inputs: handle is the operation handle, and x_desc and x respectively describe and provide the input matrix; two outputs: y_desc and y respectively describe and store the output matrix; in addition, a parameter mode specifies whether to pivot (a value of 0 selects the no-pivot mode), ipiv denotes the permutation matrix, and a workspace temporarily stores data during the computation.
Collaborator:

If there is a formula, it should be added.

Suggested change: wrap the parameter names in literal markup, i.e. ``handle``, ``x_desc``, ``x``, ``y_desc``, ``y``, ``mode``, ``ipiv``, and ``workspace``.

Also: of the "7 inputs", only three are mentioned (handle, x_desc, and x)? Are mode, ipiv, and workspace also inputs?

Author:

Fixed.
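To make the documented behavior concrete, a minimal reference implementation of the no-pivot mode (mode = 0) in plain C: the matrix is overwritten with U on and above the diagonal and the unit-lower-triangular multipliers of L strictly below it, in the usual getrf packed form. This only illustrates the math, not the MLU kernel.

```c
#include <assert.h>

/* In-place Doolittle LU factorization without pivoting on a row-major
 * n x n matrix. Returns -1 on a zero pivot (which pivoting would avoid). */
static int lu_nopivot(float *a, int n) {
    for (int k = 0; k < n; ++k) {
        if (a[k * n + k] == 0.0f) return -1;
        for (int i = k + 1; i < n; ++i) {
            a[i * n + k] /= a[k * n + k];          /* L multiplier l_ik */
            for (int j = k + 1; j < n; ++j)        /* trailing update   */
                a[i * n + j] -= a[i * n + k] * a[k * n + j];
        }
    }
    return 0;
}
```

For A = [[4,3],[6,3]] this yields L = [[1,0],[1.5,1]] and U = [[4,3],[0,-1.5]], packed into one matrix.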

@@ -14523,6 +14523,132 @@ mluOpLgamma(mluOpHandle_t handle,
const mluOpTensorDescriptor_t y_desc,
void *y);

/*!
Collaborator:

Which group do these interfaces belong to? Add it as the other operators do, e.g.:
// Group:Lgamma

Author:

Fixed.

mlu_op.h Outdated
Comment on lines 14598 to 14599
* INTEGER array, dimension (m);
* The pivot indices; row i of the matrix was interchanged with row IPIV(i)
Collaborator:

Suggested change:
* An integer array, dimension (m);
* The pivot indices; row i of the matrix was interchanged with row IPIV(i).

This needs improvement: explain the meaning of ipiv in complete sentences. What is the relationship between "dimension" and the pivot indices here?

Author:

Fixed.

* INTEGER array, dimension (m);
* The pivot indices; row i of the matrix was interchanged with row IPIV(i)
*
* @param[out] info
Collaborator:

State the meaning of info at the very beginning of the parameter description.

Author:

Fixed.

mlu_op.h Outdated
* to solve a system of equations.
*
* @param[in] mode
* option to perform operation with pivoting/no pivoting versions
Collaborator:

Suggested change:
* Option to perform the operation with pivoting/no pivoting versions

Also, what modes are available?

Author:

Fixed in the code.

mlu_op.h Outdated
* - The data layout of y should be MLUOP_LAYOUT_ARRAY.
*
* @par Scale Limitation
* - The dimension of input tensor must be either 2, 3 or 4.
Collaborator:

Drop "either": "The dimension of input tensor must be 2, 3 or 4."

Author:

Fixed.

mlu_op.h Outdated
* @param[out] info
* - = 0: successful exit
* - < 0: if INFO = -i, the i-th argument had an illegal value
* or another error occured, such as memory allocation failed.
Collaborator:

Typo: "occured" should be "occurred".

Author:

What does this mean?

@@ -0,0 +1,41 @@
op_name: "test_sgetrf2"
input {
id: "input"
Collaborator:

On MLU590 the test accuracy exceeds the threshold, and for cases produced by the generator, MLU590 accuracy also exceeds the threshold.

Author:

Could you provide the matrix input sizes and other details, and the output?

Collaborator (@shunshen93, Nov 22, 2024):

A [22,29,206] case with DTYPE_COMPLEX_FLOAT times out.

Author:

Does [22,29,206] correspond to a matrix size of [batch,m,n] or [m,n,batch]? I tested both cases and neither showed a problem.

Collaborator:

Did you run it with the complex type?

Author:

Reproduced, and the code has been fixed.

Collaborator:
op_name: "test_xgetrf"
input {
  id: "input"
  shape: {
    dims: 1
    dims: 22
    dims: 129
    dims: 206
 
  }
  layout: LAYOUT_ARRAY
  dtype: DTYPE_COMPLEX_FLOAT
  random_data: {
    seed: 25
    upper_bound: 10.0
    lower_bound: -10.0
    distribution: UNIFORM
  }
}
output {
  id: "output"
  shape: {
    dims: 1
    dims: 22
    dims: 129
    dims: 206
   
  }
  layout: LAYOUT_ARRAY
  dtype: DTYPE_COMPLEX_FLOAT
}
output {
  id: "output2"
  shape {
    dims: 1
    dims: 22
    dims: 129
  }
  layout: LAYOUT_ARRAY
  dtype: DTYPE_INT32
}
xgetrf_param{
    mode: 0
}
test_param: {
  error_func: DIFF1
  error_func: DIFF2
  error_threshold: 0.003
  error_threshold: 0.003
  baseline_device: CPU
}

Please test this case.

Author:

This case has been reproduced and fixed. How do I disable the nan/inf check on the MLU side, and what is the acceptance standard for cases containing nan/inf?

Collaborator:

The MLU needs to support the nan/inf check. The nan/inf results must match the competing product's. If the results differ, the reason must be explained, for example an algorithmic difference or some other cause.

@ArtIntAI (Collaborator):

Additionally, for tensor properties that users can observe, as listed below, please test what is supported and, for what is not supported, add parameter interception as other operators do:
1. large tensor (a single dimension exceeds 2G elements, or the product of all dimensions exceeds 2G elements)
2. inplace: the input and output tensor addresses are identical
3. stride: if unsupported, report an error during parameter checking
4. broadcasting: if unsupported, report an error during parameter checking
5. whether accuracy matches the GPU when inputs and outputs contain nan/inf
6. zero-element input tensor, i.e. some dimension is 0


if (dtype == MLUOP_DTYPE_COMPLEX_FLOAT) {
if (batch > 1) {
k_type = CNRT_FUNC_TYPE_UNION8;
Collaborator:

The board does not necessarily have this task type; it is recommended to set it as follows:

*k_type = mluop::runtime::getJobLimitCapabilityCnrtFuncType(handle);

int task_type = mluop::runtime::getJobLimitCapability(handle);
Author:

Fixed.

transpose(handle, MLUOP_DTYPE_COMPLEX_FLOAT, batch, m, n, (float *)x,
(float *)y, handle->queue);
} else {
cnrtMemcpy((float *)y, (float *)x, batch * m * n * sizeof(float),
Collaborator:

Using cnrtMemcpy, cnrtMemset, and cnrtQueueSync is discouraged; they cause problems for mlu_graph at the upper layer.
Replace cnrtMemcpy with the on-chip __memcpy, and cnrtMemset with on-chip data initialization.
cnrtQueueSync can be removed: within the same queue, kernel launches (via <<<>>>) are serialized.

Author:

Fixed.

output {
id: "output2"
shape {
dims: 1024
Collaborator:

The shape of output2 should be [1,1,1024].

cpu_fp32_output_[1][i] = pivots[i];
}

if (tensor_desc_[0].tensor->dtype == MLUOP_DTYPE_FLOAT) {
Collaborator:

Add the CPU test logic.

}

void XgetrfExecutor::cpuCompute() {
auto count = parser_->input(0)->shape_count;
Collaborator (@shunshen93, Jan 6, 2025):

lu_factor works as follows:
P, L, U = torch.linalg.lu(A)
LU, P = torch.linalg.lu_factor(A)
where L and U are the lower and upper triangular parts of LU.

Following the SVD test methodology, the tests need to:

  1. Verify reconstruction: the error between the MLU's P * L * U and the GPU's P * L * U satisfies the dynamic threshold (L and U are obtained from LU).
  2. Verify that the error between the MLU's LU matrix and the baseline's LU matrix satisfies the dynamic threshold.
  3. Verify that the error between the MLU's P matrix and the baseline's P matrix satisfies the dynamic threshold.

So the output needs to store three matrices: P, LU, and P*LU.
The dynamic thresholds compare diff1, diff2, and diff4.
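The reconstruction check in item 1 can be sketched without torch: factor with partial pivoting, apply the recorded row swaps to the original A, and compare against L*U. The helper names below are made up for illustration; the real harness compares MLU output against a CPU/GPU baseline under the dynamic threshold.

```c
#include <assert.h>
#include <math.h>
#include <stdlib.h>

/* getrf-style LU with partial pivoting on an n x n row-major matrix.
 * On return a holds the packed L\U and ipiv the pivot rows
 * (0-based: row k was interchanged with row ipiv[k]). */
static void lu_factor(float *a, int *ipiv, int n) {
    for (int k = 0; k < n; ++k) {
        int p = k;                               /* pick the largest pivot */
        for (int i = k + 1; i < n; ++i)
            if (fabsf(a[i * n + k]) > fabsf(a[p * n + k])) p = i;
        ipiv[k] = p;
        for (int j = 0; j < n; ++j) {            /* swap rows k and p */
            float t = a[k * n + j]; a[k * n + j] = a[p * n + j]; a[p * n + j] = t;
        }
        for (int i = k + 1; i < n; ++i) {
            a[i * n + k] /= a[k * n + k];
            for (int j = k + 1; j < n; ++j)
                a[i * n + j] -= a[i * n + k] * a[k * n + j];
        }
    }
}

/* Reconstruction residual: max |(P*A)[i][j] - (L*U)[i][j]|. */
static float lu_residual(const float *a0, const float *lu, const int *ipiv, int n) {
    float *pa = malloc((size_t)n * n * sizeof(float));
    if (pa == NULL) return 1e30f;
    for (int i = 0; i < n * n; ++i) pa[i] = a0[i];
    for (int k = 0; k < n; ++k)                  /* apply the same swaps to A */
        for (int j = 0; j < n; ++j) {
            float t = pa[k * n + j];
            pa[k * n + j] = pa[ipiv[k] * n + j];
            pa[ipiv[k] * n + j] = t;
        }
    float worst = 0.0f;
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j) {
            float s = 0.0f;                      /* (L*U)[i][j] from packed LU */
            for (int k = 0; k < n; ++k) {
                float l = (k < i) ? lu[i * n + k] : (k == i ? 1.0f : 0.0f);
                float u = (k <= j) ? lu[k * n + j] : 0.0f;
                s += l * u;
            }
            float d = fabsf(s - pa[i * n + j]);
            if (d > worst) worst = d;
        }
    free(pa);
    return worst;
}
```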
