diff --git a/docs/design_docs/fft/fft.md b/docs/design_docs/fft/fft.md
new file mode 100644
index 000000000..6d2e242ee
--- /dev/null
+++ b/docs/design_docs/fft/fft.md
@@ -0,0 +1,648 @@
+# FFT1d Operator Design Document
+
+
+* #### Basic document information
+| Operator name |
+| ------------------------------------------ |
+| FFT (RFFT1d, IRFFT1d, FFT1d, IFFT1d) |
+
+* #### Content description
+
+This document is the design document for the `FFT` operator, covering requirement analysis, interface design, implementation design, performance-optimization notes, and implementation planning.
+
+* #### Operator requirement checklist
+
+Information to be `provided` by the requirement owner:
+
+- Operator interface description: implement the DFT algorithms for RFFT1d, IRFFT1d, FFT1d, and IFFT1d
+- Functional description: implement the Fourier transforms RFFT1d, IRFFT1d, FFT1d, and IFFT1d
+- In-place support required: yes*
+- Stride semantics support required: yes
+
+*When computing in place, the input sequence must be padded to the length of the output sequence; otherwise the result is undefined.
+
+Items to be `checked` by the requirement owner:
+
+- 1.1 Operator requirement analysis
+- 1.2 Operator functionality and usage scenarios
+- 1.3 Operator input/output parameter requirements
+- 1.4 Operator limitations
+- 1.5 Acceptance criteria
+- 2.2 Interface design
+- 3.7 Test cases (the requirement owner checks that the sizes given in the requirement table are covered)
+
+## 1 Requirement Analysis
+
+### 1.1 Operator requirement analysis
+
+| Operator summary | Fourier transform over sequences; details in Section 1.2 |
+|-------------|--------------------------------------------------------------|
+| Requirement source | PyTorch/TensorFlow |
+| Target network | Conformer |
+| Input data types | half, float, complex_half, complex_float |
+| Input shape | [batches, array] |
+| Input layout | ARRAY |
+| Output data types | half, float, complex_half, complex_float |
+| Output shape | [batches, array] |
+| Output layout | ARRAY |
+| Mode (optional) | |
+| Has dim/axis-like parameters that may be negative or need special handling | dim is supported through stride semantics |
+| Has labels/index-like parameters that may be negative/out of range or need special handling | none |
+| In-place support required | yes |
+| Stride support required | yes |
+| Broadcast support required | no |
+| Return directly on zero-element input | no (array=0 is not supported; the following two cases are supported: 1. batch equals 0; 2. the input or output dim equals 0 but is padded up to array) |
+| Other special requirements (online quantization, fusion, early layout conversion, etc., optional) | none |
+| Sizes/modes prioritized in this release | rfft, irfft, fft, ifft |
+
+### 1.2 Operator functionality and usage scenarios
+
+RFFT: computes the Fourier transform of a real sequence of length N, producing a complex sequence of length N/2+1. Because the second half of the result is the complex conjugate of the first half, only the first half is output.
+
+IRFFT: the inverse of RFFT; computes the Fourier transform of a complex sequence of length N/2+1, producing a real sequence of length N. Because the second half of the input is the complex conjugate of the first half, only the first half is supplied and the rest is reconstructed from conjugate symmetry.
+
+FFT: computes the Fourier transform of a complex sequence of length N, producing a complex sequence of length N.
+
+IFFT: the inverse of FFT; computes the Fourier transform of a complex sequence of length N, producing a complex sequence of length N. The length relations are summarized in the sketch below.
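+
+The length relations of the four modes can be captured in a small sketch (helper names are illustrative, not part of the interface):
+
+```c++
+// Logical signal lengths per mode for transform size n (illustrative only).
+int rfft_output_len(int n) { return n / 2 + 1; }   // real n -> complex n/2+1
+int irfft_output_len(int n) { return n; }          // complex n/2+1 -> real n
+int fft_output_len(int n) { return n; }            // complex n -> complex n
+int ifft_output_len(int n) { return n; }           // complex n -> complex n
+```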
+
+Formula:
+```math
+X[k]=\sum_{n=0}^{N-1} W_N^{kn}x(n) \\
+
+W_N^{kn}=e^{sign \cdot ikn\frac{2\pi}{N}},
+sign= \left \{
+\begin{array}{ll}
+-1, & fft\ \text{or}\ rfft \\
+1, & ifft\ \text{or}\ irfft
+\end{array}
+\right.
+```
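+
+As a scalar reference for this formula, a minimal naive DFT sketch (illustrative only, O(N^2), no normalization; `sign` follows the convention above):
+
+```c++
+#include <cmath>
+#include <complex>
+#include <vector>
+
+// Naive DFT implementing X[k] = sum_n W_N^{kn} x(n) with
+// W_N^{kn} = e^{sign * i * 2*pi*k*n / N}; sign = -1 for fft/rfft, +1 for ifft/irfft.
+std::vector<std::complex<float>> naive_dft(
+    const std::vector<std::complex<float>> &x, int sign) {
+  const int N = static_cast<int>(x.size());
+  std::vector<std::complex<float>> X(N);
+  for (int k = 0; k < N; ++k) {
+    std::complex<float> acc(0.0f, 0.0f);
+    for (int n = 0; n < N; ++n) {
+      const float ang = sign * 2.0f * static_cast<float>(M_PI) * k * n / N;
+      acc += x[n] * std::complex<float>(std::cos(ang), std::sin(ang));
+    }
+    X[k] = acc;
+  }
+  return X;
+}
+```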
+Notes:
+
+1. Handling of nan/inf: when the input contains nan/inf, the output follows the FFT formula under IEEE 754 semantics: any number times nan is nan; 0 times inf is nan; a nonzero value times ±inf is ±inf according to sign; any number plus nan is nan; any non-nan number plus ±inf is ±inf according to sign.
+In some cases PyTorch's CPU and GPU results do not strictly follow IEEE 754 and the FFT formula, so the MLU result cannot match them. For example, if a batch's input contains one inf whose corresponding coefficient is nonzero, the GPU result is all nan while the MLU result is all inf.
+As another example, if a batch's input contains one nan whose corresponding coefficient is 0, the GPU result is 0 while the MLU result is still nan.
+
+2. Lengths N that are large primes may not be supported in this release.
+
+
+
+### 1.3 Operator input/output parameter requirements
+
+| Parameter | Semantics | Type (input/output) | Supported types | Physical layout | Size limits |
+| ----------- | ---- | ----------------- | ----------- | -------- | ---------------- |
+| handle | | input | | / | none |
+| fft1d_plan | | input | | / | large primes not supported yet |
+| input | | input | half, float, complex_half, complex_float | ARRAY | none |
+| output | | output | half, float, complex_half, complex_float | ARRAY | none |
+
+## 2 Operator Interface Design
+
+### 2.1 Reference interfaces
+
+- TensorFlow
+
+```python
+#Computes the 1-dimensional discrete Fourier Transform of a real-valued signal over the inner-most dimension of input.
+# Since the DFT of a real signal is Hermitian-symmetric, RFFT only returns the `fft_length / 2 + 1` unique components of the FFT: The zero-frequency term, followed by the `fft_length / 2` positive-frequency terms.
+# Along the axis RFFT is computed on, if `fft_length` is smaller than the corresponding dimension of input, the dimension is cropped. If it is larger, the dimension is padded with zeros
+tensorflow.signal.rfft(input_tensor, fft_length=None, name=None)
+
+tensorflow.signal.irfft(input_tensor, fft_length=None, name=None)
+
+tensorflow.signal.fft(input_tensor, name=None)
+
+tensorflow.signal.ifft(input_tensor, name=None)
+```
+
+- PyTorch
+```python
+#Computes the one dimensional Fourier transform of real-valued input.
+# The FFT of a real signal is Hermitian-symmetric, X[i] = conj(X[-i]), so the output contains only the positive frequencies below the Nyquist frequency. To compute the full output, use fft()
+# input(Tensor) - the real input tensor
+# s(Tuple[int], optional) - Signal length. If given, the input will either be zero-padded or trimmed to this length before computing the real FFT
+# dim(Tuple[int], optional) - The dimension along which to take the one dimensional real FFT.
+# norm(str, optional) - Normalization mode. For the forward transform(rfft()), these correspond to:
+# "forward" - normalized by `1/n`
+# "backward" - no normalization
+# "ortho" - normalize by `1/sqrt(n)` (making the FFT orthonormal)
+# Calling the backward transform (irfft()) with the same normalization mode will apply an overall normalization of 1/n between the two transforms. This is required to make irfft() the exact inverse.
+# Default is "backward" (no normalization).
+torch.fft.rfftn(input, s=None, dim=None, norm=None, *, out=None)->Tensor
+
+torch.fft.irfftn(input, s=None, dim=None, norm=None, *, out=None)->Tensor
+
+torch.fft.fftn(input, s=None, dim=None, norm=None, *, out=None)->Tensor
+
+torch.fft.ifftn(input, s=None, dim=None, norm=None, *, out=None)->Tensor
+```
+
+- Mathematica
+
+```mathematica
+Fourier[list]
+(*finds the discrete Fourier transform of a list of complex numbers*)
+Fourier[list, {p1, p2, ...}]
+(*returns the specified positions of the discrete Fourier transform*)
+```
+
+- Matlab
+
+```matlab
+% computes the discrete Fourier transform of X using a fast Fourier transform algorithm
+% if X is a vector, then fft(X) returns the Fourier transform of the vector
+Y = fft(X)
+% returns the n-point DFT. If no value is specified, Y is the same size as X.
+% If X is a vector, and the length of X is greater than n, then X is truncated to length n.
+Y = fft(X, n)
+% returns the Fourier transform along the dimension dim. For example, if X is a matrix, then fft(X, n, 2) returns the n-point Fourier transform of each row
+Y = fft(X, n, dim)
+```
+
+- OTFFT
+
+```c
+OTFFT::RFFT::fwd(const_double_vector x, complex_vector y);
+```
+
+- FFTW
+
+```c
+fftw_plan fftw_plan_dft_r2c_1d(int n0, double *in, fftw_complex *out, unsigned flags);
+
+fftw_plan fftw_plan_many_dft_r2c(int rank, const int *n, int howmany, double *in, const int *inembed, int istride, int idist, fftw_complex *out, const int *onembed, int ostride, int odist, unsigned flags);
+
+void fftw_execute_dft_r2c(const fftw_plan p, double *in, fftw_complex *out);
+
+void fftw_execute_dft_c2r(const fftw_plan p, fftw_complex *in, double *out);
+
+void fftw_execute_dft(const fftw_plan p, fftw_complex *in, fftw_complex *out);
+
+void fftw_destroy_plan(fftw_plan plan);
+```
+
+
+
+### 2.2 Interface design
+
+```c++
+/*! The descriptor of FFT (Fast Fourier Transform) operator that holds FFT information including
+ * the tensor descriptor of input tensor and output tensor, the rank of FFT, the FFT size on each
+ * dimension, the size of reserved space and the size of workspace.
+ *
+ * You need to call the ::mluOpCreateFFTPlan function to create a descriptor for the FFT operator, and call
+ * the ::mluOpMakeFFTPlanMany function to set the information of the FFT operator in the descriptor.
+ * Then, you need to allocate the reserved space and bind it to the FFT descriptor with ::mluOpSetFFTReserveArea.
+ * Finally, you need to destroy the plan with ::mluOpDestroyFFTPlan.
+ */
+typedef struct mluOpFFTStruct *mluOpFFTPlan_t;
+
+/*!
+ * @brief Creates a descriptor pointed by \b fft_plan for the FFT operator, and allocates memory
+ * for holding the information about the FFT operation. The information is defined in ::mluOpFFTPlan_t.
+ */
+mluOpStatus_t mluOpCreateFFTPlan(mluOpFFTPlan_t *fft_plan);
+
+/*!
+ * @brief Initializes the FFT descriptor pointed by \b fft_plan that is previously created
+ * with the ::mluOpCreateFFTPlan function, and sets the information about the
+ * tensor descriptors of input tensor and output tensor, the rank of FFT, and the FFT size on each
+ * dimension.
+ *
+ * This function also gets the size of MLU memory buffers for FFT execution, including \b reservespace_size and
+ * \b workspace_size. The size of extra workspace is based on the given information of the
+ * \b fft_plan.
+ */
+mluOpStatus_t mluOpMakeFFTPlanMany(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan,
+ const mluOpTensorDescriptor_t input_desc,
+ const mluOpTensorDescriptor_t output_desc,
+ const int rank,
+ const int n[],
+ size_t *reservespace_size,
+ size_t *workspace_size);
+/*!
+ * @brief Binds the reserve space to the \b fft_plan. The required size of the reserved space can be obtained from ::mluOpMakeFFTPlanMany.
+ */
+mluOpStatus_t mluOpSetFFTReserveArea(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan,
+ void *reservespace);
+/*!
+ * @brief Executes any FFT. For complex-to-real and real-to-complex
+ * transforms, the \b direction parameter is ignored. This function stores the Fourier coefficients
+ * in the output array. If the input and output addresses are the same, an in-place FFT
+ * is performed.
+ */
+mluOpStatus_t mluOpExecFFT(mluOpHandle_t handle,
+ const mluOpFFTPlan_t fft_plan,
+ const void *input,
+ const float scale_factor,
+ void *workspace,
+ void *output,
+ int direction);
+
+/*!
+ * @brief Destroys a FFT plan \b fft_plan that is created with the
+ * ::mluOpCreateFFTPlan function.
+ */
+mluOpStatus_t mluOpDestroyFFTPlan(mluOpFFTPlan_t fft_plan);
+```
+
+Typical framework usage. The example below assumes a 1-d rfft with batch = 2000 and n = 400:
+
+1. Create the FFT descriptor
+
+ ```c
+ mluOpFFTPlan_t fft_plan;
+ mluOpCreateFFTPlan(&fft_plan);
+ ```
+
+2. Set the parameters on the FFT descriptor and query reservespace_size and workspace_size
+
+ ```c
+ mluOpTensorDescriptor_t input_desc, output_desc;
+ mluOpDataType_t input_data_type = MLUOP_DTYPE_FLOAT;
+ mluOpDataType_t output_data_type = MLUOP_DTYPE_COMPLEX_FLOAT;
+ mluOpDataType_t execution_dtype = MLUOP_DTYPE_FLOAT;
+ const int rank = 1;
+ const int batch = 2000;
+ const int n[rank] = {400};
+ const int ndim = rank + 1;
+ const int input_dim_size[ndim] = {batch, n[0]};
+ const int input_dim_stride[ndim] = {n[0], 1};
+
+ const int output_dim_size[ndim] = {batch, n[0] / 2 + 1};
+ const int output_dim_stride[ndim] = {n[0] / 2 + 1, 1};
+
+ mluOpCreateTensorDescriptor(&input_desc);
+ mluOpCreateTensorDescriptor(&output_desc);
+ mluOpSetTensorDescriptorEx(input_desc, MLUOP_LAYOUT_ARRAY, input_data_type, ndim, input_dim_size, input_dim_stride);
+   mluOpSetTensorDescriptorOnchipDataType(input_desc, execution_dtype);
+ mluOpSetTensorDescriptorEx(output_desc, MLUOP_LAYOUT_ARRAY, output_data_type, ndim,
+ output_dim_size, output_dim_stride);
+ size_t reservespace_size;
+ size_t workspace_size;
+ mluOpMakeFFTPlanMany(handle, fft_plan, input_desc, output_desc, rank, n, &reservespace_size, &workspace_size);
+ mluOpDestroyTensorDescriptor(input_desc);
+ mluOpDestroyTensorDescriptor(output_desc);
+ ```
+
+3. Bind the reservespace pointer to the plan
+
+ ```c
+ void *reservespace;
+ cnrtMalloc(&reservespace, reservespace_size);
+   mluOpSetFFTReserveArea(handle, fft_plan, reservespace);
+ ```
+
+4. Execute the FFT; once created, the plan can be executed multiple times
+
+ ```c
+ void *workspace;
+ cnrtMalloc(&workspace, workspace_size);
+ const float scale = 1.0;
+   mluOpExecFFT(handle, fft_plan, input, scale, workspace, output, 0);
+ cnrtFree(workspace);
+   ```
+
+5. After the operator finishes, destroy the plan and free the reservespace.
+
+ ```c
+ mluOpDestroyFFTPlan(fft_plan);
+ cnrtFree(reservespace);
+ ```
+
+## 3 Implementation Design
+
+References:
+
+www.pikara.ne.jp/okojisan/otfft-en/
+
+https://users.ece.cmu.edu/~franzf/papers/fft-enc11.pdf
+
+https://zh.wikipedia.org/wiki/克罗内克积
+
+https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/tr-2008-62.pdf
+
+www.fftw.org
+
+https://docs.nvidia.com/cuda/cufft/index.html
+
+### 3.1 Implementation approach
+
+The design follows the schemes of the Franz Franchetti paper, cuFFT, OTFFT, and Microsoft's "Fast Computation of General Fourier Transforms on GPUs".
+
+Power-of-two lengths are computed with the Stockham FFT; lengths with a usable prime factorization use the six-step FFT; all remaining lengths use the Bluestein chirp-z method. When the Fourier matrix fits directly on WRAM, the O(n^2) brute-force method can be used instead. A minimal sketch of this dispatch follows.
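+
+A minimal sketch of the dispatch logic described above (enum and function names are hypothetical, not the real plan code):
+
+```c++
+// Strategy selection as described above (illustrative only).
+enum FftStrategy { DIRECT_MATMUL, STOCKHAM, SIX_STEP, BLUESTEIN };
+
+FftStrategy choose_strategy(int n, bool dft_matrix_fits_on_wram) {
+  if (dft_matrix_fits_on_wram) return DIRECT_MATMUL;  // O(n^2) brute force
+  if ((n & (n - 1)) == 0) return STOCKHAM;            // n is a power of two
+  for (int k = 2; k * k <= n; ++k)                    // composite n = k * m
+    if (n % k == 0) return SIX_STEP;
+  return BLUESTEIN;                                   // large prime length
+}
+```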
+
+#### 3.1.1 Iterative FFT: Stockham FFT Algorithm
+
+Used for the power-of-two lengths. The Stockham FFT is a variant of the iterative Cooley-Tukey formulation and can be derived from it. For length n, the Cooley-Tukey factorization is:
+```math
+\text{DFT}_n=(\text{DFT}_k \otimes I_m) T_m^n(I_k \otimes \text{DFT}_m) L_k^n,n=km,\tag{1}
+```
+where $`\otimes`$ denotes the Kronecker product, $`(\text{DFT}_k \otimes I_m)`$ is the vector-parallel matrix, $`T_m^n`$ is the diagonal twiddle-factor matrix, $`(I_k \otimes \text{DFT}_m)`$ is the block-parallel matrix, and $`L_k^n`$ is the stride permutation $`L_m^{mn}:={in+j \rightarrow jm+i, 0\le i<m, 0 \le j < n}`$. The matrices for $`k=4,m=2`$ are shown below:
+```math
+I_4 \otimes \text{DFT}_2=\left[
+\begin{array}{cccccccc}
+1&1&&&&&&\\
+1&1&&&&&&\\
+&&1&1&&&&\\
+&&1&-1&&&&\\
+&&&&1&1&&\\
+&&&&1&-1&&\\
+&&&&&&1&1\\
+&&&&&&1&-1\\
+\end{array}
+\right]
+
+\text{DFT}_2 \otimes I_4=\left[
+\begin{array}{cccccccc}
+1&&&&1&&&\\
+&1&&&&1&&\\
+&&1&&&&1&\\
+&&&1&&&&1\\
+1&&&&-1&&&\\
+&1&&&&-1&&\\
+&&1&&&&-1&\\
+&&&1&&&&-1\\
+\end{array}
+\right]
+```
+As the expansion shows, when $`n=r^l`$ equation $`(1)`$ can be unrolled into an iterative form, the radix-r Cooley-Tukey factorization:
+```math
+\text{DFT}_{r^l}=\left(\prod_{i=0}^{l-1}(I_{r^i}\otimes \text{DFT}_r\otimes I_{r^{l-i-1}})D_i^{r^l}\right)R_r^{r^l},\tag{2}
+```
+where $`D_i^{r^l}`$ is the twiddle-factor matrix of stage $`i`$,
+```math
+D_i^{r^l}=\text{DiagMatrix}\left[\exp\left(\frac{2 \pi i}{r^l}r^i \alpha \beta\right), 0 \le \alpha < r, 0 \le \beta < r^{l-i}\right], \tag{3}
+```
+and $`R_r^{r^l}`$ is the bit-reversal permutation matrix. The permutation matrix for length 8, radix 2 is:
+```math
+(x_0, x_4, x_2, x_6, x_1, x_5, x_3, x_7)^T=
+\left(
+\begin{array}{cccccccc}
+1&&&&&&&\\
+&&&&1&&&\\
+&&1&&&&&\\
+&&&&&&1&\\
+&1&&&&&&\\
+&&&&&1&&\\
+&&&1&&&&\\
+&&&&&&&1\\
+\end{array}
+\right) (x_0, x_1, x_2, x_3, x_4, x_5, x_6, x_7)^T.\tag{4}
+```
+The scattered access pattern of $`R_r^{r^l}`$ is hardware-unfriendly. Transforming equation $`(2)`$ to fold the permutation into each stage yields the Stockham FFT:
+```math
+\text{DFT}_{r^l}=\prod_{i=0}^{l-1}(\text{DFT}_r\otimes I_{r^{l-1}})D_i^{r^l}(L_r^{r^{l-i}}\otimes I_{r^i})\tag{5}.
+```
+where the twiddle-factor matrix is
+```math
+\text{DiagMatrix}\left[ W_{r^{j+2}}^{(\alpha r^j +\xi) \beta}, 0 \leq \alpha, \beta < r, 0 \leq \gamma < r^{l-j-2}, 0 \leq \xi < r^j \right] \tag{6}
+```
+
+
+Note that although the $`D_i^{r^l}`$ here uses the same symbol as in $`(2)`$, the formulas differ. Equation $`(5)`$ has a fixed per-stage structure, and its left matrix $`(\text{DFT}_r\otimes I_{r^{l-1}})`$ is vector-parallel, which suits vector programming, so we use the Stockham FFT as our iterative FFT design. A compact sketch follows.
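+
+A radix-2 Stockham sketch in the style of the OTFFT reference cited above (illustrative; `n` must be a power of two, forward transform, unnormalized):
+
+```c++
+#include <cmath>
+#include <complex>
+#include <vector>
+using cpx = std::complex<double>;
+
+// One Stockham stage per recursion level; x and y ping-pong, so the stride
+// permutations of Eq. (5) are absorbed and no bit-reversal pass is needed.
+static void fft0(int n, int s, bool eo, cpx* x, cpx* y) {
+  if (n == 1) {
+    if (eo) for (int q = 0; q < s; q++) y[q] = x[q];
+    return;
+  }
+  const int m = n / 2;
+  const double theta0 = 2 * M_PI / n;
+  for (int p = 0; p < m; p++) {
+    const cpx wp(std::cos(p * theta0), -std::sin(p * theta0));  // W_n^p
+    for (int q = 0; q < s; q++) {
+      const cpx a = x[q + s * (p + 0)];
+      const cpx b = x[q + s * (p + m)];
+      y[q + s * (2 * p + 0)] = a + b;
+      y[q + s * (2 * p + 1)] = (a - b) * wp;
+    }
+  }
+  fft0(m, 2 * s, !eo, y, x);  // this stage's output feeds the next stage
+}
+
+void stockham_fft(std::vector<cpx>& x) {
+  std::vector<cpx> work(x.size());
+  fft0(static_cast<int>(x.size()), 1, false, x.data(), work.data());
+}
+```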
+
+#### 3.1.2 Recursive FFT: Four-Step FFT Algorithm
+
+Recursive FFTs split $`n`$ into $`km`$; equation $`(1)`$ is one such recursive form. Depending on the hardware platform, the vector-parallel and block-parallel matrices can be transformed to fit. Converting the block-parallel matrix gives the four-step FFT:
+```math
+\text{DFT}_{n}=(\text{DFT}_{k}\otimes I_{m})T_{m}^{n} L_{k}^{n} (\text{DFT}_{m} \otimes I_{k}), n=km, \tag{7}
+```
+Similarly, there is the six-step FFT:
+```math
+\text{DFT}_n=L_k^n(I_m \otimes \text{DFT}_k) L_m^n T_m^n (I_k \otimes \text{DFT}_m) L_k^n, n=km, \tag{8}
+```
+Taking the six-step FFT as an example, OTFFT gives a good explanation:
+
+The six-step FFT $`F_n`$ can be viewed as a doubly nested $`G_n`$, expressed algebraically as:
+```math
+X[k_1+k_2 m]=G_n(x[p_1+p_2 k])=\frac{1}{k}\sum_{p_1=0}^{k-1}\left(\left(\frac{1}{m}\sum_{p_2=0}^{m-1}x[p_1+p_2 k]W_m^{k_1 p_2}\right)W_n^{k_1 p_1}\right)W_k^{k_2 p_1} \tag{9}
+```
+The expression is a composition of FFTs; $`X`$ is computed in the following steps:
+
+ Step 1. Transpose $`x`$:
+```math
+x[p_1+p_2 k] \rightarrow a[p_2+p_1 m] \tag{10}
+```
+ Step 2. For every $`p_1`$, apply the length-$`m`$ FFT $`F_m`$ to $`a`$:
+```math
+a[p_2+p_1 m] \rightarrow b[k_1+p_1 m] \tag{11}
+```
+ Step 3. Multiply by the twiddle factor $`W_n^{k_1 p_1}`$:
+```math
+b[k_1+p_1 m] \rightarrow b[k_1+p_1 m]W_n^{k_1 p_1}=c[k_1+p_1 m]\tag{12}
+```
+ Step 4. Transpose $`c`$:
+```math
+c[k_1+p_1 m] \rightarrow d[p_1+k_1 k] \tag{13}
+```
+ Step 5. For every $`k_1`$, apply the length-$`k`$ FFT $`F_k`$:
+```math
+d[p_1+k_1 k] \rightarrow e[k_2+k_1 k] \tag{14}
+```
+ Step 6. Transpose $`e`$:
+```math
+e[k_2+k_1 k] \rightarrow X[k_1+k_2 m] \tag{15}
+```
+The FFTs in steps 2 and 5 can be carried out with the iterative Fourier transform of Section 3.1.1.
+
+The six-step FFT turns one long-sequence FFT into a series of short-sequence FFTs, so it performs well on long sequences; a host-side sketch follows.
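+
+A host-side sketch of steps 1-6 for $`n=km`$ (illustrative; the row FFT below is naive, where a real implementation would use the Stockham kernel from 3.1.1, and on the MLU the transposes map to mluOpTranspose calls):
+
+```c++
+#include <cmath>
+#include <complex>
+#include <vector>
+using cpx = std::complex<double>;
+
+// Forward DFT of each contiguous row (naive placeholder for a fast FFT).
+static void fft_rows(cpx* a, int rows, int len) {
+  std::vector<cpx> t(len);
+  for (int r = 0; r < rows; ++r) {
+    cpx* row = a + r * len;
+    for (int k = 0; k < len; ++k) {
+      cpx acc(0.0, 0.0);
+      for (int j = 0; j < len; ++j) {
+        const double ang = -2.0 * M_PI * k * j / len;
+        acc += row[j] * cpx(std::cos(ang), std::sin(ang));
+      }
+      t[k] = acc;
+    }
+    for (int k = 0; k < len; ++k) row[k] = t[k];
+  }
+}
+
+void six_step_fft(std::vector<cpx>& x, int k, int m) {  // n = k * m
+  const int n = k * m;
+  std::vector<cpx> a(n);
+  for (int p1 = 0; p1 < k; p1++)          // Step 1: x[p1+p2*k] -> a[p2+p1*m]
+    for (int p2 = 0; p2 < m; p2++) a[p2 + p1 * m] = x[p1 + p2 * k];
+  fft_rows(a.data(), k, m);               // Step 2: k row FFTs of length m
+  for (int p1 = 0; p1 < k; p1++)          // Step 3: twiddle by W_n^{k1*p1}
+    for (int k1 = 0; k1 < m; k1++) {
+      const double ang = -2.0 * M_PI * k1 * p1 / n;
+      a[k1 + p1 * m] *= cpx(std::cos(ang), std::sin(ang));
+    }
+  for (int p1 = 0; p1 < k; p1++)          // Step 4: a[k1+p1*m] -> x[p1+k1*k]
+    for (int k1 = 0; k1 < m; k1++) x[p1 + k1 * k] = a[k1 + p1 * m];
+  fft_rows(x.data(), m, k);               // Step 5: m row FFTs of length k
+  for (int k1 = 0; k1 < m; k1++)          // Step 6: x[k2+k1*k] -> a[k1+k2*m]
+    for (int k2 = 0; k2 < k; k2++) a[k1 + k2 * m] = x[k2 + k1 * k];
+  x = a;                                  // result is X[k1 + k2*m]
+}
+```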
+
+#### 3.1.3 General FFT: Bluestein Chirp-z Algorithm
+
+The Bluestein chirp-z algorithm does not require the length to be composite. Rewriting the exponent product $`kj`$ as $`kj=(k^2+j^2-(k-j)^2)/2`$ gives:
+```math
+\large X[k]=e^{-\frac{\pi i}{n} k^2}\sum_{j=0}^{n-1}\left(x_j e^{-\frac{\pi i}{n}j^2}\right)e^{\frac{\pi i}{n}(k-j)^2},0\le k < n \tag{16}
+```
+The expression has the form of a convolution of two vectors followed by a scale:
+```math
+\large
+\begin{align}
+a_j&=x_j e^{-\frac{\pi i}{n}j^2}\\
+b_j&=e^{\frac{\pi i}{n}j^2}
+\end{align} \tag{17}
+```
+
+```math
+X_k=b_k^{*}\left(\sum_{j=0}^{n-1}a_j b_{k-j}\right), \tag{18}
+```
+
+The convolution of $`a_j`$ and $`b_j`$ can be computed by zero-padding to a power $`r^l`$ and using the fast Fourier transform, so the algorithm is also $`O(n\log n)`$. However, because it requires three Fourier transforms, it is slower than the Cooley-Tukey algorithm. A sketch follows.
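+
+A sketch of Eqs. (16)-(18) (illustrative; `stockham_fft` is the power-of-two FFT from the 3.1.1 sketch, and the inverse is obtained by conjugation):
+
+```c++
+#include <cmath>
+#include <complex>
+#include <vector>
+using cpx = std::complex<double>;
+void stockham_fft(std::vector<cpx>& x);  // power-of-two FFT (3.1.1 sketch)
+
+static void ifft_pow2(std::vector<cpx>& v) {
+  // ifft(v) = conj(fft(conj(v))) / L
+  for (auto& z : v) z = std::conj(z);
+  stockham_fft(v);
+  const double L = static_cast<double>(v.size());
+  for (auto& z : v) z = std::conj(z) / L;
+}
+
+std::vector<cpx> bluestein_fft(const std::vector<cpx>& x) {
+  const int n = static_cast<int>(x.size());
+  int L = 1;
+  while (L < 2 * n - 1) L <<= 1;  // zero-pad the convolution to 2^l >= 2n-1
+  std::vector<cpx> a(L), b(L);
+  for (int j = 0; j < n; j++) {
+    const double ang = M_PI * static_cast<double>(j) * j / n;  // pi*j^2/n
+    const cpx w(std::cos(ang), -std::sin(ang));                // e^{-i pi j^2/n}
+    a[j] = x[j] * w;             // a_j = x_j e^{-i pi j^2 / n}, Eq. (17)
+    b[j] = std::conj(w);         // b_j = e^{+i pi j^2 / n}
+    if (j > 0) b[L - j] = b[j];  // wrap b for the circular convolution
+  }
+  stockham_fft(a);
+  stockham_fft(b);
+  for (int i = 0; i < L; i++) a[i] *= b[i];  // convolve in frequency domain
+  ifft_pow2(a);
+  std::vector<cpx> X(n);
+  for (int k = 0; k < n; k++) {
+    const double ang = M_PI * static_cast<double>(k) * k / n;
+    X[k] = a[k] * cpx(std::cos(ang), -std::sin(ang));  // scale by b_k^*, Eq. (18)
+  }
+  return X;
+}
+```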
+
+
+#### 3.1.4 Brute force: Direct DFT Algorithm
+
+Setting aside the FFT optimizations above, $`\text{DFT}_n`$ can be written directly as an explicit matrix:
+```math
+\text{DFT}_n= [W_n^{ij}, i\in[0,n),j\in[0,n)], \tag{19}
+```
+The algorithm has $`O(n^2)`$ complexity, but on processors with matrix-compute units its small-size performance is no worse than the Cooley-Tukey FFT.
+
+$`\text{DFT}_8`$ is shown below as an example:
+```math
+\text{DFT}_8= \left[
+\begin{array}{cccccccc}
+W_8^0&W_8^0&W_8^0&W_8^0&W_8^0&W_8^0&W_8^0&W_8^0&\\
+W_8^0&W_8^1&W_8^2&W_8^3&W_8^4&W_8^5&W_8^6&W_8^7&\\
+W_8^0&W_8^2&W_8^4&W_8^6&W_8^8&W_8^{10}&W_8^{12}&W_8^{14}&\\
+W_8^0&W_8^3&W_8^6&W_8^9&W_8^{12}&W_8^{15}&W_8^{18}&W_8^{21}&\\
+W_8^0&W_8^4&W_8^8&W_8^{12}&W_8^{16}&W_8^{20}&W_8^{24}&W_8^{28}&\\
+W_8^0&W_8^5&W_8^{10}&W_8^{15}&W_8^{20}&W_8^{25}&W_8^{30}&W_8^{35}&\\
+W_8^0&W_8^6&W_8^{12}&W_8^{18}&W_8^{24}&W_8^{30}&W_8^{36}&W_8^{42}&\\
+W_8^0&W_8^7&W_8^{14}&W_8^{21}&W_8^{28}&W_8^{35}&W_8^{42}&W_8^{49}&\\
+\end{array}
+\right], \tag{20}
+```
+Splitting the real and imaginary parts of the $`\text{DFT}_8`$ matrix into separate rows, and exploiting the conjugate symmetry of the real-input Fourier transform, the FFT can be written as a real matrix:
+```math
+\left[
+\begin{array}{c}
+ReX_0\\
+ImX_0\\
+ReX_1\\
+ImX_1\\
+ReX_2\\
+ImX_2\\
+ReX_3\\
+ImX_3\\
+ReX_4\\
+ImX_4\\
+\end{array}
+\right]
+=
+\left[
+\begin{array}{cccccccc}
+Re(W_8^0&W_8^0&W_8^0&W_8^0&W_8^0&W_8^0&W_8^0&W_8^0)&\\
+Im(W_8^0&W_8^0&W_8^0&W_8^0&W_8^0&W_8^0&W_8^0&W_8^0)&\\
+Re(W_8^0&W_8^1&W_8^2&W_8^3&W_8^4&W_8^5&W_8^6&W_8^7)&\\
+Im(W_8^0&W_8^1&W_8^2&W_8^3&W_8^4&W_8^5&W_8^6&W_8^7)&\\
+Re(W_8^0&W_8^2&W_8^4&W_8^6&W_8^8&W_8^{10}&W_8^{12}&W_8^{14})&\\
+Im(W_8^0&W_8^2&W_8^4&W_8^6&W_8^8&W_8^{10}&W_8^{12}&W_8^{14})&\\
+Re(W_8^0&W_8^3&W_8^6&W_8^9&W_8^{12}&W_8^{15}&W_8^{18}&W_8^{21})&\\
+Im(W_8^0&W_8^3&W_8^6&W_8^9&W_8^{12}&W_8^{15}&W_8^{18}&W_8^{21})&\\
+Re(W_8^0&W_8^4&W_8^8&W_8^{12}&W_8^{16}&W_8^{20}&W_8^{24}&W_8^{28})&\\
+Im(W_8^0&W_8^4&W_8^8&W_8^{12}&W_8^{16}&W_8^{20}&W_8^{24}&W_8^{28})&
+
+\end{array}
+\right]
+\left[
+\begin{array}{c}
+x_0\\
+x_1\\
+x_2\\
+x_3\\
+x_4\\
+x_5\\
+x_6\\
+x_7\\
+\end{array}
+\right], \tag{21}
+```
+The implementation can then follow the matmul design: a single instruction completes the FFT of one short sequence, with no restriction on the length. Measured on an X4 accelerator, a float32 $`(2659, 400) * (402, 400)^T`$ product takes about 70 us, while cuFFT on a T4 takes about 200 us at the same size.
+
+In short-time Fourier transform (STFT) scenarios, $`n`$ is usually short, so this scheme covers the vast majority of STFT use cases. A host-side sketch of the matrix generation follows.
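+
+A sketch of how the real-valued matrix of Eq. (21) could be generated on the host for an RFFT of length n (illustrative; the actual kernel builds it on device from a sin-cos table):
+
+```c++
+#include <cmath>
+#include <vector>
+
+// Build the [2*(n/2+1), n] real RFFT matrix of Eq. (21): for each output
+// frequency k there is one Re row and one Im row of W_n^{kj} (forward sign).
+std::vector<float> build_rfft_matrix(int n) {
+  const int rows = n / 2 + 1;
+  std::vector<float> mat(2 * rows * n);
+  for (int k = 0; k < rows; k++) {
+    for (int j = 0; j < n; j++) {
+      const double ang = -2.0 * M_PI * k * j / n;  // argument of W_n^{kj}
+      mat[(2 * k + 0) * n + j] = static_cast<float>(std::cos(ang));  // Re
+      mat[(2 * k + 1) * n + j] = static_cast<float>(std::sin(ang));  // Im
+    }
+  }
+  return mat;  // multiply by a real input column for interleaved Re/Im output
+}
+```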
+
+### 3.2 Invocation scheme design
+
+The currently implemented scheme is as follows:
+
+- For small sizes, use the brute-force method with matmul-style logic (only supported up to n[0] <= 4096)
+  - The pipeline is:
+    - Generate the DFT matrix, equation (21)
+      - On the 300 series, generate a sin-cos lookup table with surpass, then build the DFT matrix with the vaa instruction; each taskId handles one row of the DFT matrix.
+    - Generate the quantization parameters of the DFT matrix (optional)
+    - Make the input tensor contiguous (optional)
+    - Pad/crop the input tensor (optional)
+    - Generate the quantization parameters of the input tensor (optional)
+    - Matmul produces the output
+    - Apply the output stride semantics (optional)
+
+- For sizes of the form $`n=2^m \cdot l`$, use the Cooley-Tukey FFT (see the paper "FFT algorithms for vector computers", Parallel Computing 1 (1984) 45-63), as follows (a sketch of the step0 decomposition follows this list):
+
+  - step0: findLimit computes the coefficients m, l, and s; m and l are the decomposition coefficients of the array described above, and s bounds the sub-graph that fits in NRAM at once, i.e. (2^s) * l;
+  - step1: generate the DFT matrix [l, l], as above;
+  - step2: reorder the input tensor: [batch, n, c] -> [batch, l, 2^m, c] -> [c, batch, 2^m, l], where batch is the batch size of the input sequences, n is the sequence length, and c is the element width: c=1 for real input, c=2 for complex input;
+  - step3: matmul the input tensor with the matrix along the l dimension;
+  - step4: permute sub-graph data from the matmul result into NRAM and run preprocessing such as combining real and imaginary parts;
+  - step5: merge sub-graphs layer by layer, starting from a single l, for the first s layers (to meet the accuracy requirements, when the input type is half or complex_half the merge computation and accumulation are widened to float); each layer's output feeds the next layer until the sub-graph limit of 2^s l is reached, and the merged data is written back to workspace in one pass; the sub-graphs are split evenly across the cores;
+  - step6: merge the remaining m-s layers; since a complete sub-graph no longer fits, part of a sub-graph is processed at a time, and the partitions are split evenly across the cores;
+  - step7: transpose the result into the required layout (the c dimension moves from the highest to the lowest position) and write it back to output;
+
+  step0 is implemented on the host by the findLimit function; step1, step2, and step3 are the sub-sequence splitting steps, implemented by calling kernels such as mluOpTranspose and mluOpMatmul; step4-step7 are the sub-sequence merging steps, implemented in one new kernel to improve efficiency and reduce redundant IO.
+
+- For sizes where one 2^(m-1) * align_size * 29 block fits on chip, the Stockham algorithm is used instead.
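+
+A minimal sketch of the step0 decomposition $`n = l \cdot 2^m`$ (illustrative; the real findLimit also derives s from the NRAM capacity):
+
+```c++
+// Split n into n = l * 2^m with l odd (hypothetical helper, not findLimit itself).
+void decompose_pow2(int n, int& m, int& l) {
+  m = 0;
+  l = n;
+  while (l > 1 && l % 2 == 0) {  // peel off factors of two
+    l /= 2;
+    ++m;
+  }
+  // invariant here: n == l * (1 << m)
+}
+```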
+
+### 3.3 Work split (task split, multi-core split)
+
+1. For small sizes the matmul scheme is invoked and the multi-core split is delegated to matmul.
+2. For sizes above 4096, the multi-core split of step0, step2, and step3 follows the invoked kernels; for step4-step7 the sub-graphs that fit in NRAM are divided evenly among the MLU cores (a minimal sketch of the even split follows this list).
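+
+A minimal sketch of the even split (illustrative; on device, `taskId`/`taskDim` are the BANG built-ins):
+
+```c++
+// Divide `total` work items across task_dim cores; core `task_id` gets
+// the half-open range [start, start + cnt). Remainders go to low task ids.
+void partition(int total, int task_id, int task_dim, int& start, int& cnt) {
+  const int base = total / task_dim;
+  const int rem = total % task_dim;
+  cnt = base + (task_id < rem ? 1 : 0);
+  start = task_id * base + (task_id < rem ? task_id : rem);
+}
+```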
+
+### 3.5 Theoretical performance
+
+After completing sections 3.1 through 3.3 above, a rough theoretical performance estimate can be given. Not every operator needs an elaborate formula, but there must be a clear expectation of the efficiency the final implementation should reach.
+
+### 3.6 Maintainability
+
+1. Add the necessary log messages to the BANG C code, such as the input sizes, data types, and layouts, plus any variable whose bad value would make the program core dump, e.g. the data_size of IO instructions and the dim x/y/z values. Such information helps locate problems quickly.
+
+2. Give every function and variable name adequate comments.
+
+3. Avoid magic numbers; replace fixed constants with shared macros where possible (shared macro documentation to be provided).
+
+### 3.7 Test case design
+
+- Sizes used by the operator in networks, from the requirement list:
+  [2495, 400]
+
+- Boundary cases:
+  [1, 4096]
+
+  [300, 1]
+
+  [5, 0], n=[100]
+
+  input stride
+
+  output stride
+
+  input stride + output stride
+
+  input stride + inplace
+
+  output stride + inplace
+
+  input stride + output stride + inplace
+
+More cases can be added as needed. After development is complete, attach a link to the test report.
+
+### 3.8 Parameter checks
+
+
+- Checks the operator performs (a small sketch follows this list):
+
+  1. Null-pointer checks for handle, plan, and descriptors;
+
+  2. rank is 1, 2, or 3;
+
+  3. Input/output dimension checks;
+
+  4. Input/output data type checks, separately for r2c, c2r, and c2c;
+
+  5. batch size check;
+
+  6. execution dtype check;
+
+  7. Input/output stride checks;
+
+  8. Signal length checks: rfft: output = n / 2 + 1; irfft: output = n; fft: output = n;
+
+  9. Null-pointer checks for input/output: if the input has nonzero elements, the input pointer must not be null; the output pointer must never be null;
+
+  10. 2-d and 3-d FFT are not supported yet;
+
+  11. c2c and c2r FFT are not supported yet;
+
+  12. r2c with n[0] > 4096 is not supported yet.
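+
+A small illustrative sketch of a few of these checks (parameter names are hypothetical; the real checks live in the plan-creation and execution paths):
+
+```c++
+// Fool-proofing sketch for items 1, 2, 8 and 12 above (illustrative only).
+mluOpStatus_t checkFFTParams(const void *handle, const void *plan, int rank,
+                             int n, int output_len, bool is_r2c) {
+  if (handle == nullptr || plan == nullptr) return MLUOP_STATUS_BAD_PARAM;
+  if (rank < 1 || rank > 3) return MLUOP_STATUS_BAD_PARAM;              // item 2
+  if (is_r2c && output_len != n / 2 + 1) return MLUOP_STATUS_BAD_PARAM;  // item 8
+  if (is_r2c && n > 4096) return MLUOP_STATUS_NOT_SUPPORTED;             // item 12
+  return MLUOP_STATUS_SUCCESS;
+}
+```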
+
diff --git a/docs/user_guide/9_operators/index.rst b/docs/user_guide/9_operators/index.rst
index 2d122b3ca..03fe2b9a2 100755
--- a/docs/user_guide/9_operators/index.rst
+++ b/docs/user_guide/9_operators/index.rst
@@ -729,3 +729,22 @@ mluOpSyncBatchnormBackwardReduce
mluOpSyncBatchNormBackwardElemt
---------------------------------
该算子用来计算输入的梯度,与 :ref:`sync_batchnorm_backward_reduce` 共同实现了sync_batchnorm_backward。
+
+.. _execFFT:
+
+mluOpExecFFT
+------------
+Performs the Fourier transform of a sequence of length N.
+
+The computation is:
+
+.. math::
+
+ y = DFT_{N} x
+
+where:
+
+- ``x`` is the input signal.
+- ``y`` is the output signal.
+- :math:`DFT_{N}` is the transform matrix of the length-N Fourier transform.
+
diff --git a/kernels/fft/README.md b/kernels/fft/README.md
new file mode 100644
index 000000000..8504c06dd
--- /dev/null
+++ b/kernels/fft/README.md
@@ -0,0 +1,52 @@
+# Introduction
+
+This directory collects all code related to the FFT operator. It supports the operator's four modes: rfft (r2c), irfft (c2r), fft (c2c), and ifft (c2c). For performance, different sizes dispatch to different algorithms; there are three: DFT, FFT_cooley-tukey, and FFT_stockham. The code is organized as follows:
+
+## Directory layout and description
+
+1. The directory tree is as follows:
+ ├── c2c_fft
+ │ ├── c2c_fft.h
+ │ └── c2c_fft_host.cpp
+ ├── common
+ │ ├── fft_basic_ops.cpp
+ │ ├── fft_basic_ops.h
+ │ ├── fft_common_kernels.h
+ │ └── fft_common_kernels.mlu
+ ├── fft.h
+ ├── fft.cpp
+ ├── fft_optm_device
+ │ ├── fft_cooley-tukey_ux_device.mlu
+ │ └── fft_stockham_u1_device.mlu
+ ├── irfft
+ │ ├── irfft.h
+ │ └── irfft_host.cpp
+ └── rfft
+ ├── rfft.h
+ └── rfft_host.cpp
+
+2. fft.h and fft.cpp:
+  * fft.h: defines the basic structures (modes, strategies, addresses, etc.) and declares the global functions;
+  * fft.cpp: defines the public entry points called by users, e.g. strategy initialization, workspace initialization, host-function selection, and basic parameter checks; every mode enters through this file first and is then dispatched to the host code of that mode;
+
+3. The common folder:
+  * fft_basic_ops.h: FFT execution also uses other interfaces such as transpose, quantization, and matmul; the wrappers around those calls are declared here, along with shared helpers such as the findLimit function;
+  * fft_basic_ops.cpp: implements the interfaces declared in fft_basic_ops.h;
+  * fft_common_kernels.h: generating the W matrix is usually expensive; in network training it is generated only once, at the first iteration. The interfaces for pre-generating the W matrix are declared here;
+  * fft_common_kernels.mlu: implements the interfaces declared in fft_common_kernels.h;
+
+4. The rfft folder:
+  * rfft.h: declares the rfft strategy function, workspace allocation function, execution function, and so on;
+  * rfft_host.cpp: the concrete rfft host implementation; based on the strategy function's result it selects the DFT, FFT_cooley-tukey, or FFT_stockham algorithm;
+
+5. The irfft folder:
+  * Same structure as rfft, with the declarations and implementations for irfft;
+
+6. The c2c_fft folder:
+  * Same structure as rfft, with the declarations and implementations for fft and ifft; since the two differ only by a constant factor, they share one folder.
+
+7. The fft_optm_device folder:
+  * fft_cooley-tukey_ux_device.mlu: optimized device kernel based on the Cooley-Tukey algorithm;
+  * fft_stockham_u1_device.mlu: optimized device kernel based on the Stockham algorithm;
+  * Note: the DFT path is implemented with kernels such as cnnlTranspose and cnnlMatmul through the wrappers in fft_basic_ops.cpp, so it has no dedicated device kernel.
+
diff --git a/kernels/fft/c2c_fft/c2c_fft.h b/kernels/fft/c2c_fft/c2c_fft.h
new file mode 100644
index 000000000..b44cb9c42
--- /dev/null
+++ b/kernels/fft/c2c_fft/c2c_fft.h
@@ -0,0 +1,38 @@
+/*************************************************************************
+ * Copyright (C) [2024] by Cambricon, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+#ifndef KERNELS_FFT_C2C_FFT_C2C_FFT_H_
+#define KERNELS_FFT_C2C_FFT_C2C_FFT_H_
+
+#include <string>
+#include "kernels/fft/fft.h"
+
+mluOpStatus_t makeFFT1dPolicy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan);
+
+mluOpStatus_t setFFT1dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
+ const std::string api);
+
+mluOpStatus_t execFFT1d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan,
+ const void *input, const float scale_factor,
+ void *workspace, void *output, int direction);
+
+#endif // KERNELS_FFT_C2C_FFT_C2C_FFT_H_
diff --git a/kernels/fft/c2c_fft/c2c_fft_host.cpp b/kernels/fft/c2c_fft/c2c_fft_host.cpp
new file mode 100644
index 000000000..f523f64bb
--- /dev/null
+++ b/kernels/fft/c2c_fft/c2c_fft_host.cpp
@@ -0,0 +1,1261 @@
+/*************************************************************************
+ * Copyright (C) [2024] by Cambricon, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+
+#include "kernels/fft/c2c_fft/c2c_fft.h"
+#include <algorithm>
+#include <string>
+
+#define DIRECTION 2 // FORWARD and BACKWARD
+
+static mluOpStatus_t selectFFT1dStrategy(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan) {
+ const std::string make_plan_api = "[selectFFT1dStrategy]";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ /* there are plenty of algorithms for FFT, depending on the fft length.
+ * Iterative FFT:
+   * Stockham FFT, Cooley-Tukey FFT, Pease FFT, Korn-Lambiotte FFT
+ * Recursive FFT:
+ * Recursive Cooley-Tukey FFT, Four-step FFT, Six-step FFT, Multicore FFT,
+ * SIMD short vector FFT. General FFT: chirp-Z Bluestein FFT.
+ */
+ // select Stockham FFT, Cooley-Tukey FFT or MATMUL strategy logic
+ fft_plan->fft_strategy = CNFFT_FUNC_MATMUL;
+ status = selectFFTStrategy(handle, fft_plan, make_plan_api);
+ return status;
+}
+
+/*
+ * Make the policy of FFT1d.
+ */
+mluOpStatus_t makeFFT1dPolicy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan) {
+ std::string api = "[mluOpMakeFFTPlanMany]";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ INTERNAL_CHECK(api,
+ selectFFT1dStrategy(handle, fft_plan) == MLUOP_STATUS_SUCCESS);
+
+ mluOpDataType_t in_c_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_r_dtype = (in_c_dtype == MLUOP_DTYPE_COMPLEX_HALF)
+ ? MLUOP_DTYPE_HALF
+ : MLUOP_DTYPE_FLOAT;
+ mluOpDataType_t in_e_dtype = fft_plan->execution_dtype;
+ size_t in_c_dtype_size = mluOpDataTypeBytes(in_c_dtype);
+ size_t in_r_dtype_size = mluOpDataTypeBytes(in_r_dtype);
+ int batch = fft_plan->batch;
+ int n = fft_plan->n[0];
+
+ switch (fft_plan->fft_strategy) {
+ case CNFFT_FUNC_MATMUL: {
+ if (n > FFT_L_LIMIT) {
+ LOG(ERROR) << "[mluOpMakeFFTPlanMany]: FFT1d CNFFT_FUNC_MATMUL "
+ << "length > 4096 is not supported currently.";
+ return MLUOP_STATUS_NOT_SUPPORTED;
+ }
+
+ // Matmul Input : 2 * [batch, n]
+ // Matmul Matrix : 2 * 2 * [n, n] (forward and backward)
+ // Matmul Result : 4 * [batch, n]
+ int dft_mat_times = COMPLEX;
+ int dim0 = n;
+ int dim1 = n;
+ int dft_mat_num = DIRECTION * dft_mat_times * dim0 * dim1;
+
+ // reservespace size allocation
+ fft_plan->reservespace_size = 0;
+ fft_plan->reservespace_size +=
+ dft_mat_num * mluOpDataTypeBytes(in_r_dtype);
+ if (fftIsIntDtype(in_e_dtype)) {
+ fft_plan->reservespace_size += sizeof(int32_t) + sizeof(float);
+ size_t required_size = 0;
+ status = fftGetQuantizeParamWorkspaceSize(
+ handle, required_size, dft_mat_num, in_r_dtype, in_e_dtype, api);
+ fft_plan->reservespace_size += required_size;
+ }
+
+ /* CNFFT_FUNC_MATMUL :
+ -------------------------
+ | input |
+ -------------------------
+ |
+ | input contiguous
+ \|/
+ -------------------------
+ | input_contiguous |
+ -------------------------
+ |
+ | input pad
+ \|/
+ -------------------------
+ | input_pad |
+ -------------------------
+ |
+ | input trans: batch * n * 2 --> 2 * batch * n
+ \|/
+ -------------------------
+ | input_re |
+ | input_im |
+ -------------------------
+ |
+ | matmul
+ | optensor(re_mul_re - im_mul_im, re_mul_im + im_mul_re)
+ \|/
+ -------------------------
+ | matmul_re_mul_re | (matmul_re)
+ | matmul_re_mul_im | (matmul_im)
+ | matmul_im_mul_re |
+ | matmul_im_mul_im |
+ -------------------------
+ |
+ | output trans: 2 * batch * n --> batch * n * 2
+ \|/
+ -------------------------
+ | output_contiguous |
+ -------------------------
+ |
+ | output contiguous
+ \|/
+ -------------------------
+ | output |
+ -------------------------
+ */
+      // workspace size allocation
+ fft_plan->matmul_addrs.internal_workspace_size = 0;
+ fft_plan->workspace_size = 0;
+
+ // input contiguous
+ size_t input_size = in_c_dtype_size * fft_plan->inum;
+ fft_plan->workspace_size +=
+ fft_plan->is_input_contiguous ? 0 : input_size;
+
+ // input pad
+ bool need_pad = (fft_plan->inembed[0] != n);
+ int padded_input_num = batch * n;
+ size_t padded_input_size = in_c_dtype_size * padded_input_num;
+ fft_plan->workspace_size += need_pad ? padded_input_size : 0;
+
+ // input trans and workspace
+ size_t transed_input_size = padded_input_size;
+ fft_plan->workspace_size += transed_input_size;
+ // input trans workspace: batch * n * 2 --> 2 * batch * n
+ const int in_trans_dim_num = 2;
+ int in_trans_input_dims[in_trans_dim_num] = {padded_input_num, COMPLEX};
+ int in_trans_permute[in_trans_dim_num] = {1, 0};
+ size_t in_trans_workspace_size = 0;
+ status = fftGetTransposeWorkspaceSize(
+ handle, in_trans_workspace_size, in_trans_dim_num,
+ in_trans_input_dims, in_trans_permute, in_r_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size =
+ std::max(fft_plan->matmul_addrs.internal_workspace_size,
+ in_trans_workspace_size);
+
+ // input quantize param and workspace
+ if (fftIsIntDtype(in_e_dtype)) {
+ fft_plan->workspace_size += sizeof(int32_t) + sizeof(float);
+ size_t input_quant_workspace_size = 0;
+ status = fftGetQuantizeParamWorkspaceSize(
+ handle, input_quant_workspace_size, COMPLEX * padded_input_num,
+ in_r_dtype, in_e_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size =
+ std::max(fft_plan->matmul_addrs.internal_workspace_size,
+ input_quant_workspace_size);
+ }
+
+ // matmul output
+ const int matmul_times =
+ 4; // real mul real, real mul imag, imag mul real, imag mul imag
+ int per_matmul_output_num = batch * n;
+ size_t per_matmul_output_size = in_r_dtype_size * per_matmul_output_num;
+ size_t matmul_output_size = matmul_times * per_matmul_output_size;
+ fft_plan->workspace_size += matmul_output_size;
+ // matmul workspace
+ size_t matmul_workspace_size = 0;
+ status = fftGetQuantizeMatMulWorkspaceSize(
+ handle, matmul_workspace_size, batch, dim1, dim0, false, true,
+ in_e_dtype, in_e_dtype, in_r_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size =
+ std::max(fft_plan->matmul_addrs.internal_workspace_size,
+ matmul_workspace_size);
+ // optensor workspace
+ size_t optensor_workspace_size = 0;
+ status =
+ fftGetOptensorWorkspaceSize(handle, optensor_workspace_size,
+ per_matmul_output_num, in_r_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size =
+ std::max(fft_plan->matmul_addrs.internal_workspace_size,
+ optensor_workspace_size);
+
+ // output trans workspace: 2 * batch * n --> batch * n * 2
+ const int out_trans_dim_num = 2;
+ int out_trans_input_dims[out_trans_dim_num] = {COMPLEX,
+ per_matmul_output_num};
+ int out_trans_permute[out_trans_dim_num] = {1, 0};
+ size_t out_trans_workspace_size = 0;
+ status = fftGetTransposeWorkspaceSize(
+ handle, out_trans_workspace_size, out_trans_dim_num,
+ out_trans_input_dims, out_trans_permute, in_r_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size =
+ std::max(fft_plan->matmul_addrs.internal_workspace_size,
+ out_trans_workspace_size);
+
+ // output contiguous
+ size_t output_size =
+ mluOpDataTypeBytes(fft_plan->output_dtype) * fft_plan->onum;
+ fft_plan->workspace_size +=
+ fft_plan->is_output_contiguous ? 0 : output_size;
+
+ // internal_workspace
+ fft_plan->workspace_size +=
+ fft_plan->matmul_addrs.internal_workspace_size;
+ VLOG(5) << "internal workspace size: "
+ << fft_plan->matmul_addrs.internal_workspace_size;
+ VLOG(5) << "total workspace size: " << fft_plan->workspace_size;
+ }; break;
+ case CNFFT_FUNC_COOLEY_TUKEY:
+ case CNFFT_FUNC_STOCKHAM: {
+ int L = fft_plan->L;
+ int m = (1 << fft_plan->m);
+ if (L > FFT_L_LIMIT) {
+ LOG(ERROR) << "[mluOpMakeFFTPlanMany]: FFT1d CNFFT_FUNC_COOLEY_TUKEY "
+ << "n = L * 2^m and L > 4096 is not supported currently.";
+ return MLUOP_STATUS_NOT_SUPPORTED;
+ }
+
+ // Matmul Input : 2 * [batch, 2^m, L]
+ // Matmul Matrix : 2 * 2 * [L, L] (forward and backward)
+ // Matmul Result : 4 * [batch, 2^m, L]
+ int dft_mat_times = COMPLEX;
+ int dim0 = L;
+ int dim1 = L;
+ int dft_mat_num = DIRECTION * dft_mat_times * dim0 * dim1;
+
+ // reservespace size allocation
+ fft_plan->reservespace_size = 0;
+ fft_plan->reservespace_size += dft_mat_num * in_r_dtype_size;
+ if (fftIsIntDtype(in_e_dtype)) {
+ fft_plan->reservespace_size += sizeof(int32_t) + sizeof(float);
+ size_t required_size = 0;
+ status = fftGetQuantizeParamWorkspaceSize(
+ handle, required_size, dft_mat_num, in_r_dtype, in_e_dtype, api);
+ fft_plan->reservespace_size += required_size;
+ }
+
+ /* CNFFT_FUNC_COOLEY_TUKEY :
+ -------------------------
+ | input |
+ -------------------------
+ |
+ | input contiguous
+ \|/
+ -------------------------
+ | input_contiguous |
+ -------------------------
+ |
+ | input pad
+ \|/
+ -------------------------
+ | input_pad |
+ -------------------------
+ |
+ | input trans: batch * n * 2 --> 2 * batch * n
+ \|/
+ -------------------------
+ | input_transed |
+ -------------------------
+ |
+ | input trans: 2 * batch * L * 2^m --> 2 * batch * 2^m * L
+ \|/
+ -------------------------
+ | input_re |
+ | input_im |
+ -------------------------
+ |
+ | matmul
+ | optensor(re_mul_re - im_mul_im, re_mul_im + im_mul_re)
+ \|/
+ -------------------------
+ | matmul_re_mul_re | (matmul_re)
+ | matmul_re_mul_im | (matmul_im)
+ | matmul_im_mul_re |
+ | matmul_im_mul_im |
+ -------------------------
+ |
+ | output merge
+ \|/
+ -------------------------
+ | output_contiguous |
+ -------------------------
+ |
+ | output contiguous
+ \|/
+ -------------------------
+ | output |
+ -------------------------
+ */
+      // workspace size allocation
+ fft_plan->matmul_addrs.internal_workspace_size = 0;
+ fft_plan->workspace_size = 0;
+
+ // input contiguous
+ size_t input_size = in_c_dtype_size * fft_plan->inum;
+ fft_plan->workspace_size +=
+ fft_plan->is_input_contiguous ? 0 : input_size;
+
+ // input pad
+ bool need_pad = (fft_plan->inembed[0] != n);
+ int padded_input_num = fft_plan->batch * n;
+ size_t padded_input_size = in_c_dtype_size * padded_input_num;
+ fft_plan->workspace_size += need_pad ? padded_input_size : 0;
+
+ // input trans and workspace
+ const int trans_times = 2;
+ size_t transed_input_size = trans_times * padded_input_size;
+ fft_plan->workspace_size += transed_input_size;
+ // input trans workspace
+ // 1st transpose: batch * n * 2 --> 2 * batch * n
+ const int in_trans_1st_dim_num = 2;
+ int in_trans_1st_input_dims[in_trans_1st_dim_num] = {padded_input_num,
+ COMPLEX};
+ int in_trans_1st_permute[in_trans_1st_dim_num] = {1, 0};
+ size_t in_trans_1st_workspace_size = 0;
+ status = fftGetTransposeWorkspaceSize(
+ handle, in_trans_1st_workspace_size, in_trans_1st_dim_num,
+ in_trans_1st_input_dims, in_trans_1st_permute, in_r_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size =
+ std::max(fft_plan->matmul_addrs.internal_workspace_size,
+ in_trans_1st_workspace_size);
+ // 2nd transpose: 2 * batch * L * 2^m --> 2 * batch * 2^m * L
+ const int in_trans_2nd_dim_num = 3;
+ int in_trans_2nd_input_dims[in_trans_2nd_dim_num] = {COMPLEX * batch, L,
+ m};
+ int in_trans_2nd_permute[in_trans_2nd_dim_num] = {0, 2, 1};
+ size_t in_trans_2nd_workspace_size = 0;
+ status = fftGetTransposeWorkspaceSize(
+ handle, in_trans_2nd_workspace_size, in_trans_2nd_dim_num,
+ in_trans_2nd_input_dims, in_trans_2nd_permute, in_r_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size =
+ std::max(fft_plan->matmul_addrs.internal_workspace_size,
+ in_trans_2nd_workspace_size);
+
+ // input quantize param and workspace
+ if (fftIsIntDtype(in_e_dtype)) {
+ fft_plan->workspace_size += sizeof(int32_t) + sizeof(float);
+ size_t input_quant_workspace_size = 0;
+ fftGetQuantizeParamWorkspaceSize(handle, input_quant_workspace_size,
+ COMPLEX * padded_input_num, in_r_dtype,
+ in_e_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size =
+ std::max(fft_plan->matmul_addrs.internal_workspace_size,
+ input_quant_workspace_size);
+ }
+
+ // matmul output
+ const int matmul_times =
+ 4; // real mul real, real mul imag, imag mul real, imag mul imag
+ int per_matmul_output_num = batch * n;
+ size_t per_matmul_output_size = in_r_dtype_size * per_matmul_output_num;
+ fft_plan->workspace_size += matmul_times * per_matmul_output_size;
+ // matmul workspace
+ size_t matmul_workspace_size = 0;
+ status = fftGetQuantizeMatMulWorkspaceSize(
+ handle, matmul_workspace_size, batch * m, L, L, false, true,
+ in_e_dtype, in_e_dtype, in_r_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size =
+ std::max(fft_plan->matmul_addrs.internal_workspace_size,
+ matmul_workspace_size);
+ // optensor workspace
+ size_t optensor_workspace_size = 0;
+ status =
+ fftGetOptensorWorkspaceSize(handle, optensor_workspace_size,
+ per_matmul_output_num, in_r_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size =
+ std::max(fft_plan->matmul_addrs.internal_workspace_size,
+ optensor_workspace_size);
+
+ // output merge workspace
+ size_t merge_workspace_size =
+ COMPLEX * in_r_dtype_size * per_matmul_output_num;
+ fft_plan->matmul_addrs.internal_workspace_size = std::max(
+ fft_plan->matmul_addrs.internal_workspace_size, merge_workspace_size);
+
+ // output contiguous
+ size_t output_size =
+ mluOpDataTypeBytes(fft_plan->output_dtype) * fft_plan->onum;
+ fft_plan->workspace_size +=
+ fft_plan->is_output_contiguous ? 0 : output_size;
+
+ // internal_workspace
+ fft_plan->workspace_size +=
+ fft_plan->matmul_addrs.internal_workspace_size;
+ VLOG(5) << "internal workspace size: "
+ << fft_plan->matmul_addrs.internal_workspace_size;
+ VLOG(5) << "total workspace size: " << fft_plan->workspace_size;
+ }; break;
+ default: {
+ status = MLUOP_STATUS_NOT_SUPPORTED;
+ return status;
+ }
+ }
+ return status;
+}
+
+static void configureFFT1dMatmulReserveAddrs(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan) {
+ size_t dft_mat_size = 0;
+ // forward real, forward imag, backward real, backward imag
+ const int dft_mat_times = DIRECTION * COMPLEX;
+ mluOpDataType_t in_c_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_r_dtype = (in_c_dtype == MLUOP_DTYPE_COMPLEX_HALF)
+ ? MLUOP_DTYPE_HALF
+ : MLUOP_DTYPE_FLOAT;
+ size_t in_r_dtype_size = mluOpDataTypeBytes(in_r_dtype);
+ int n = fft_plan->n[0];
+
+ switch (fft_plan->fft_strategy) {
+ case CNFFT_FUNC_MATMUL: {
+ // Matmul Matrix : 2 * 2 * [n, n] (forward and backward)
+ size_t per_dft_mat_size = n * n * in_r_dtype_size;
+ dft_mat_size = dft_mat_times * per_dft_mat_size;
+ fft_plan->matmul_addrs.dft_matrix_addr = fft_plan->reservespace_addr;
+ fft_plan->matmul_addrs.dft_re_matrix_addr =
+ fft_plan->matmul_addrs.dft_matrix_addr;
+ fft_plan->matmul_addrs.dft_im_matrix_addr =
+ (uint8_t *)fft_plan->matmul_addrs.dft_matrix_addr + per_dft_mat_size;
+ fft_plan->matmul_addrs.ifft_dft_matrix_addr =
+ (uint8_t *)fft_plan->matmul_addrs.dft_im_matrix_addr +
+ per_dft_mat_size;
+ fft_plan->matmul_addrs.ifft_dft_re_matrix_addr =
+ fft_plan->matmul_addrs.ifft_dft_matrix_addr;
+ fft_plan->matmul_addrs.ifft_dft_im_matrix_addr =
+          (uint8_t *)fft_plan->matmul_addrs.ifft_dft_matrix_addr +
+ per_dft_mat_size;
+ }; break;
+ case CNFFT_FUNC_COOLEY_TUKEY:
+ case CNFFT_FUNC_STOCKHAM: {
+ // Matmul Matrix : 2 * 2 * [L, L] (forward and backward)
+ int L = fft_plan->L;
+ size_t per_dft_mat_size = L * L * in_r_dtype_size;
+ dft_mat_size = dft_mat_times * per_dft_mat_size;
+ fft_plan->matmul_addrs.dft_matrix_addr = fft_plan->reservespace_addr;
+ fft_plan->matmul_addrs.dft_re_matrix_addr =
+ fft_plan->matmul_addrs.dft_matrix_addr;
+ fft_plan->matmul_addrs.dft_im_matrix_addr =
+ (uint8_t *)fft_plan->matmul_addrs.dft_matrix_addr + per_dft_mat_size;
+ fft_plan->matmul_addrs.ifft_dft_matrix_addr =
+ (uint8_t *)fft_plan->matmul_addrs.dft_im_matrix_addr +
+ per_dft_mat_size;
+ fft_plan->matmul_addrs.ifft_dft_re_matrix_addr =
+ fft_plan->matmul_addrs.ifft_dft_matrix_addr;
+ fft_plan->matmul_addrs.ifft_dft_im_matrix_addr =
+ (uint8_t *)fft_plan->matmul_addrs.ifft_dft_matrix_addr +
+ per_dft_mat_size;
+ }; break;
+ default: {
+ break;
+ }
+ }
+ if (fftIsIntDtype(fft_plan->execution_dtype)) {
+ fft_plan->matmul_addrs.dft_pos_addr =
+ (uint8_t *)fft_plan->reservespace_addr + dft_mat_size;
+ fft_plan->matmul_addrs.dft_scale_addr =
+ (uint8_t *)fft_plan->matmul_addrs.dft_pos_addr + sizeof(int32_t);
+ fft_plan->matmul_addrs.dft_quantize_workspace_addr =
+ (uint8_t *)fft_plan->matmul_addrs.dft_scale_addr + sizeof(float);
+ fft_plan->matmul_addrs.dft_quantize_workspace_size =
+ fft_plan->reservespace_size - dft_mat_size - sizeof(int32_t) -
+ sizeof(float);
+ } else {
+ fft_plan->matmul_addrs.dft_pos_addr = nullptr;
+ fft_plan->matmul_addrs.dft_scale_addr = nullptr;
+ fft_plan->matmul_addrs.dft_quantize_workspace_addr = nullptr;
+ fft_plan->matmul_addrs.dft_quantize_workspace_size = 0;
+ }
+}
+
+mluOpStatus_t setFFT1dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
+ const std::string api) {
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ configureFFT1dMatmulReserveAddrs(handle, fft_plan);
+
+ mluOpDataType_t in_c_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_r_dtype = (in_c_dtype == MLUOP_DTYPE_COMPLEX_HALF)
+ ? MLUOP_DTYPE_HALF
+ : MLUOP_DTYPE_FLOAT;
+ mluOpDataType_t in_e_dtype = fft_plan->execution_dtype;
+ const int n = fft_plan->n[0];
+ const int dft_mat_times =
+ DIRECTION *
+ COMPLEX; // forward real, forward imag, backward real, backward imag
+
+ const unsigned int cluster_number =
+ mluop::runtime::getClusterLimitCapability(handle);
+ const unsigned int core_dim = handle->core_num_per_cluster;
+ cnrtDim3_t k_dim = {core_dim, cluster_number, 1};
+ cnrtFunctionType_t k_type = CNRT_FUNC_TYPE_BLOCK;
+
+ switch (fft_plan->fft_strategy) {
+ case CNFFT_FUNC_MATMUL: {
+ // Matmul Matrix : 2 * 2 * [n, n] (forward and backward)
+ int dft_mat_num = dft_mat_times * n * n;
+ kernelC2CFFTDFTMatrix(k_dim, k_type, handle->queue, fft_plan, in_r_dtype,
+ n);
+ status = fftQuantizePositionScale(
+ handle, dft_mat_num, in_r_dtype, in_e_dtype,
+ fft_plan->matmul_addrs.dft_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.dft_quantize_workspace_addr,
+ fft_plan->matmul_addrs.dft_quantize_workspace_size, api);
+ INTERNAL_CHECK("[mluOpSetFFTReserveArea]",
+ status == MLUOP_STATUS_SUCCESS);
+ }; break;
+ case CNFFT_FUNC_COOLEY_TUKEY:
+ case CNFFT_FUNC_STOCKHAM: {
+ // Matmul Matrix : 2 * 2 * [L, L] (forward and backward)
+ int L = fft_plan->L;
+ int dft_mat_num = dft_mat_times * L * L;
+ kernelC2CFFTDFTMatrix(k_dim, k_type, handle->queue, fft_plan, in_r_dtype,
+ L);
+ status = fftQuantizePositionScale(
+ handle, dft_mat_num, in_r_dtype, in_e_dtype,
+ fft_plan->matmul_addrs.dft_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.dft_quantize_workspace_addr,
+ fft_plan->matmul_addrs.dft_quantize_workspace_size, api);
+ INTERNAL_CHECK("[mluOpSetFFTReserveArea]",
+ status == MLUOP_STATUS_SUCCESS);
+ }; break;
+ default: {
+ status = MLUOP_STATUS_NOT_SUPPORTED;
+ }
+ }
+ return status;
+}
+
+static void configureFFT1dMatmulWorkspaceAddrs(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan,
+ void *input, void *workspace,
+ void *output) {
+ VLOG(5) << "Into configure FFT1d Matmul Workspace Addrs";
+ size_t workspace_cur_offset = 0;
+ size_t workspace_cur_offset_to_end = 0;
+ size_t workspace_total_size = fft_plan->workspace_size;
+ void *workspace_end = (uint8_t *)workspace + workspace_total_size;
+
+ mluOpDataType_t in_c_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_r_dtype = (in_c_dtype == MLUOP_DTYPE_COMPLEX_HALF)
+ ? MLUOP_DTYPE_HALF
+ : MLUOP_DTYPE_FLOAT;
+ mluOpDataType_t in_e_dtype = fft_plan->execution_dtype;
+ size_t in_c_dtype_size = mluOpDataTypeBytes(in_c_dtype);
+ size_t in_r_dtype_size = mluOpDataTypeBytes(in_r_dtype);
+ int batch = fft_plan->batch;
+ int n = fft_plan->n[0];
+
+ // input contiguous
+ size_t input_size = in_c_dtype_size * fft_plan->inum;
+ if (!fft_plan->is_input_contiguous) {
+ fft_plan->matmul_addrs.input_contiguous_addr =
+ (uint8_t *)workspace + workspace_cur_offset;
+ workspace_cur_offset += input_size;
+ } else {
+ fft_plan->matmul_addrs.input_contiguous_addr = input;
+ }
+
+ // input pad
+ bool need_pad = (fft_plan->inembed[0] != n);
+ int padded_input_num = batch * n;
+ size_t padded_input_size = in_c_dtype_size * padded_input_num;
+ if (need_pad) {
+ fft_plan->matmul_addrs.input_pad_addr =
+ (uint8_t *)workspace + workspace_cur_offset;
+ workspace_cur_offset += padded_input_size;
+ } else {
+ fft_plan->matmul_addrs.input_pad_addr =
+ fft_plan->matmul_addrs.input_contiguous_addr;
+ }
+
+ if (fft_plan->fft_strategy == CNFFT_FUNC_MATMUL) {
+ // input trans: batch * n * 2 --> 2 * batch * n
+ size_t transed_input_size = padded_input_size;
+ fft_plan->matmul_addrs.input_re_addr =
+ (uint8_t *)workspace + workspace_cur_offset;
+ fft_plan->matmul_addrs.input_im_addr = (uint8_t *)workspace +
+ workspace_cur_offset +
+ transed_input_size / COMPLEX;
+ workspace_cur_offset += transed_input_size;
+ } else if (fft_plan->fft_strategy == CNFFT_FUNC_COOLEY_TUKEY) {
+ // input 1st trans: batch * n * 2 --> 2 * batch * n
+ size_t per_transed_input_size = padded_input_size;
+ fft_plan->matmul_addrs.input_transed_addr =
+ (uint8_t *)workspace + workspace_cur_offset;
+ workspace_cur_offset += per_transed_input_size;
+ // input 2nd trans: 2 * batch * L * 2^m --> 2 * batch * 2^m * L
+ fft_plan->matmul_addrs.input_re_addr =
+ (uint8_t *)workspace + workspace_cur_offset;
+ fft_plan->matmul_addrs.input_im_addr = (uint8_t *)workspace +
+ workspace_cur_offset +
+ per_transed_input_size / COMPLEX;
+ workspace_cur_offset += per_transed_input_size;
+ }
+
+ // input quantize
+ if (fftIsIntDtype(in_e_dtype)) {
+ fft_plan->matmul_addrs.input_pos_addr =
+ (uint8_t *)workspace + workspace_cur_offset;
+ workspace_cur_offset += sizeof(int32_t);
+ fft_plan->matmul_addrs.input_scale_addr =
+ (uint8_t *)workspace + workspace_cur_offset;
+ workspace_cur_offset += sizeof(float);
+ } else {
+ fft_plan->matmul_addrs.input_pos_addr = nullptr;
+ fft_plan->matmul_addrs.input_scale_addr = nullptr;
+ }
+
+ // internal workspace
+ workspace_cur_offset_to_end += fft_plan->matmul_addrs.internal_workspace_size;
+ fft_plan->matmul_addrs.internal_workspace_addr =
+ (uint8_t *)workspace_end - workspace_cur_offset_to_end;
+
+ // output contiguous
+ size_t output_size =
+ mluOpDataTypeBytes(fft_plan->output_dtype) * fft_plan->onum;
+ if (!fft_plan->is_output_contiguous) {
+ workspace_cur_offset_to_end += output_size;
+ fft_plan->matmul_addrs.output_contiguous_addr =
+ (uint8_t *)workspace_end - workspace_cur_offset_to_end;
+ } else {
+ fft_plan->matmul_addrs.output_contiguous_addr = output;
+ }
+
+ // matmul output
+ int per_matmul_output_num = batch * n;
+ size_t per_matmul_output_size = in_r_dtype_size * per_matmul_output_num;
+ if (fft_plan->fft_strategy == CNFFT_FUNC_MATMUL) {
+ workspace_cur_offset_to_end += per_matmul_output_size;
+ fft_plan->matmul_addrs.matmul_im_mul_im_addr =
+ (uint8_t *)workspace_end - workspace_cur_offset_to_end;
+ workspace_cur_offset_to_end += per_matmul_output_size;
+ fft_plan->matmul_addrs.matmul_im_mul_re_addr =
+ (uint8_t *)workspace_end - workspace_cur_offset_to_end;
+ workspace_cur_offset_to_end += per_matmul_output_size;
+ fft_plan->matmul_addrs.matmul_re_mul_im_addr =
+ (uint8_t *)workspace_end - workspace_cur_offset_to_end;
+ workspace_cur_offset_to_end += per_matmul_output_size;
+ fft_plan->matmul_addrs.matmul_re_mul_re_addr =
+ (uint8_t *)workspace_end - workspace_cur_offset_to_end;
+ } else if (fft_plan->fft_strategy == CNFFT_FUNC_COOLEY_TUKEY ||
+ fft_plan->fft_strategy == CNFFT_FUNC_STOCKHAM) {
+ workspace_cur_offset_to_end += per_matmul_output_size;
+ fft_plan->matmul_addrs.matmul_im_mul_im_addr =
+ (uint8_t *)workspace_end - workspace_cur_offset_to_end;
+ workspace_cur_offset_to_end += per_matmul_output_size;
+ fft_plan->matmul_addrs.matmul_im_mul_re_addr =
+ (uint8_t *)workspace_end - workspace_cur_offset_to_end;
+ workspace_cur_offset_to_end += per_matmul_output_size;
+ fft_plan->matmul_addrs.matmul_re_mul_im_addr =
+ (uint8_t *)workspace_end - workspace_cur_offset_to_end;
+ workspace_cur_offset_to_end += per_matmul_output_size;
+ fft_plan->matmul_addrs.matmul_re_mul_re_addr =
+ (uint8_t *)workspace_end - workspace_cur_offset_to_end;
+ }
+}
+
+// input : in input
+// output : in input_contiguous_addr
+static mluOpStatus_t makeFFT1dContiguousInput(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan,
+ const void *input) {
+ std::string api = "[mluOpExecFFT]";
+ VLOG(5) << "into makeFFT1dContiguousInput";
+ auto status = MLUOP_STATUS_SUCCESS;
+ if (!fft_plan->is_input_contiguous) {
+ VLOG(5) << "launch mluOpContiguous for fft1d input";
+ mluOpTensorDescriptor_t input_desc;
+ status = mluOpCreateTensorDescriptor(&input_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ const int in_dim_num = 2;
+ int64_t dims[in_dim_num] = {fft_plan->batch, fft_plan->inembed[0]};
+ int64_t strides[in_dim_num] = {fft_plan->idist, fft_plan->istride};
+ status = mluOpSetTensorDescriptorEx_v2(input_desc, MLUOP_LAYOUT_ARRAY,
+ fft_plan->input_dtype, in_dim_num,
+ dims, strides);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = mluOpContiguous(handle, input_desc, input,
+ fft_plan->matmul_addrs.input_contiguous_addr);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = mluOpDestroyTensorDescriptor(input_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ }
+ return status;
+}
+
+// input : in input_contiguous_addr
+// output : in input_pad_addr
+static mluOpStatus_t padFFT1dContiguousInput(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan) {
+ std::string api = "[mluOpExecFFT]";
+ VLOG(5) << "into padFFT1dContiguousInput";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+
+ mluOpDataType_t in_c_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_r_dtype = (in_c_dtype == MLUOP_DTYPE_COMPLEX_HALF)
+ ? MLUOP_DTYPE_HALF
+ : MLUOP_DTYPE_FLOAT;
+ int batch = fft_plan->batch;
+ int n = fft_plan->n[0];
+ bool need_pad = (fft_plan->inembed[0] != n);
+ if (need_pad) {
+ VLOG(5) << "launch cnnlOpPad for input pad";
+ mluOpTensorDescriptor_t input_desc, padded_input_desc;
+ status = mluOpCreateTensorDescriptor(&input_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpCreateTensorDescriptor(&padded_input_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ const int in_dim_num = 2;
+ int64_t dims[in_dim_num] = {batch, fft_plan->inembed[0] * COMPLEX};
+ status = mluOpSetTensorDescriptor_v2(input_desc, MLUOP_LAYOUT_ARRAY,
+ in_r_dtype, in_dim_num, dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ int64_t padded_dims[in_dim_num] = {batch, n * COMPLEX};
+ status = mluOpSetTensorDescriptor_v2(padded_input_desc, MLUOP_LAYOUT_ARRAY,
+ in_r_dtype, in_dim_num, padded_dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ const int pad_dim_num = 4;
+ int paddings[pad_dim_num] = {0, 0, 0, (n - fft_plan->inembed[0]) * COMPLEX};
+ uint64_t padding_value = 0x00000000;
+
+ DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
+ cnnl_handle); // convert to cnnl_handle
+
+ // convert to cnnl_tensor_descriptor
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(input_desc, cnnl_input_desc);
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(padded_input_desc,
+ cnnl_padded_input_desc);
+ CALL_CNNL(cnnlPad(cnnl_handle, cnnl_input_desc,
+ fft_plan->matmul_addrs.input_contiguous_addr, paddings,
+ &padding_value, cnnl_padded_input_desc,
+ fft_plan->matmul_addrs.input_pad_addr));
+
+ // destroy cnnl descriptor
+ VLOG(5) << "c2cfft cnnlOpPad end";
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_input_desc);
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_padded_input_desc);
+
+ DESTROY_CNNL_HANDLE(cnnl_handle);
+ }
+ return status;
+}
+
+/* CNFFT_FUNC_MATMUL:
+ -------------------------
+ | input_pad |
+ -------------------------
+ |
+ | input trans: batch * n * 2 --> 2 * batch * n
+ \|/
+ -------------------------
+ | input_re |
+ | input_im |
+ -------------------------
+
+ CNFFT_FUNC_COOLEY_TUKEY:
+ -------------------------
+ | input_pad |
+ -------------------------
+ |
+ | input trans: batch * n * 2 --> 2 * batch * n
+ \|/
+ -------------------------
+ | input_transed |
+ -------------------------
+ |
+ | input trans: 2 * batch * L * 2^m --> 2 * batch * 2^m * L
+ \|/
+ -------------------------
+ | input_re |
+ | input_im |
+ -------------------------
+*/
+static mluOpStatus_t transposeFFT1dPaddedInput(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan) {
+ std::string api = "[mluOpExecFFT]";
+ VLOG(5) << "into transposeFFT1dPaddedInput";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+
+ mluOpDataType_t in_c_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_r_dtype = (in_c_dtype == MLUOP_DTYPE_COMPLEX_HALF)
+ ? MLUOP_DTYPE_HALF
+ : MLUOP_DTYPE_FLOAT;
+ int batch = fft_plan->batch;
+ int n = fft_plan->n[0];
+
+ if (fft_plan->fft_strategy == CNFFT_FUNC_MATMUL) {
+ // transpose: batch * n * 2 --> 2 * batch * n
+ VLOG(5) << "launch mluOpTranspose for input CNFFT_FUNC_MATMUL";
+ int padded_input_num = batch * n;
+ const int trans_dim_num = 2;
+ int trans_input_dims[trans_dim_num] = {padded_input_num, COMPLEX};
+ int trans_output_dims[trans_dim_num] = {COMPLEX, padded_input_num};
+ int trans_permute[trans_dim_num] = {1, 0};
+
+ status =
+ fftTranspose(handle, trans_dim_num, trans_input_dims, trans_output_dims,
+ trans_permute, fft_plan->matmul_addrs.input_pad_addr,
+ fft_plan->matmul_addrs.input_re_addr, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ } else if (fft_plan->fft_strategy == CNFFT_FUNC_COOLEY_TUKEY) {
+ VLOG(5) << "launch mluOpTranspose for input CNFFT_FUNC_COOLEY_TUKEY";
+ int L = fft_plan->L;
+ int m = (1 << fft_plan->m);
+
+ // 1st transpose: batch * n * 2 --> 2 * batch * n
+ int padded_input_num = batch * n;
+ const int trans_dim_num = 2;
+ int trans_input_dims[trans_dim_num] = {padded_input_num, COMPLEX};
+ int trans_output_dims[trans_dim_num] = {COMPLEX, padded_input_num};
+ int trans_permute[trans_dim_num] = {1, 0};
+
+ status =
+ fftTranspose(handle, trans_dim_num, trans_input_dims, trans_output_dims,
+ trans_permute, fft_plan->matmul_addrs.input_pad_addr,
+ fft_plan->matmul_addrs.input_transed_addr, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+
+ // 2nd transpose: 2 * batch * L * 2^m --> 2 * batch * 2^m * L
+ const int trans_2nd_dim_num = 3;
+ int trans_2nd_input_dims[trans_2nd_dim_num] = {COMPLEX * batch, L, m};
+ int trans_2nd_output_dims[trans_2nd_dim_num] = {COMPLEX * batch, m, L};
+ int trans_2nd_permute[trans_2nd_dim_num] = {0, 2, 1};
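+    // Each row of length n = L * 2^m is viewed as [L, 2^m] and transposed to
+    // [2^m, L], so every length-L sub-sequence (a stride-2^m decimation of
+    // the padded input) becomes contiguous for the L x L DFT matmuls.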
+
+    status = fftTranspose(handle, trans_2nd_dim_num, trans_2nd_input_dims,
+                          trans_2nd_output_dims, trans_2nd_permute,
+                          fft_plan->matmul_addrs.input_transed_addr,
+                          fft_plan->matmul_addrs.input_re_addr, in_r_dtype,
+                          fft_plan->matmul_addrs.internal_workspace_addr,
+                          fft_plan->matmul_addrs.internal_workspace_size, api);
+    INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ }
+ return status;
+}
+
+// input : in input_re_addr (planar re/im parts after the transposes above)
+// output : quantize position/scale in input_pos_addr and input_scale_addr
+static mluOpStatus_t quantizeFFT1dPaddedInput(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan) {
+ std::string api = "[mluOpExecFFT]";
+ VLOG(5) << "into quantizeFFT1dPaddedInput";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+
+ mluOpDataType_t in_c_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_r_dtype = (in_c_dtype == MLUOP_DTYPE_COMPLEX_HALF)
+ ? MLUOP_DTYPE_HALF
+ : MLUOP_DTYPE_FLOAT;
+ mluOpDataType_t in_e_dtype = fft_plan->execution_dtype;
+ int padded_input_num = fft_plan->batch * fft_plan->n[0];
+
+ status = fftQuantizePositionScale(
+ handle, COMPLEX * padded_input_num, in_r_dtype, in_e_dtype,
+ fft_plan->matmul_addrs.input_re_addr,
+ fft_plan->matmul_addrs.input_pos_addr,
+ fft_plan->matmul_addrs.input_scale_addr,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+
+ return status;
+}
+
+/* CNFFT_FUNC_MATMUL and CNFFT_FUNC_COOLEY_TUKEY:
+ -------------------------
+ | input_re |
+ | input_im |
+ -------------------------
+ |
+ | matmul
+ | optensor(re_mul_re - im_mul_im, re_mul_im + im_mul_re)
+ \|/
+ -------------------------
+ | matmul_re_mul_re | (matmul_re)
+ | matmul_re_mul_im | (matmul_im)
+ | matmul_im_mul_re |
+ | matmul_im_mul_im |
+ -------------------------
+*/
+static mluOpStatus_t computeFFT1dMatmulResult(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan,
+ const float scale_factor,
+ int direction) {
+ std::string api = "[mluOpExecFFT]";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+
+ mluOpDataType_t in_c_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_r_dtype = (in_c_dtype == MLUOP_DTYPE_COMPLEX_HALF)
+ ? MLUOP_DTYPE_HALF
+ : MLUOP_DTYPE_FLOAT;
+ mluOpDataType_t in_e_dtype = fft_plan->execution_dtype;
+ int batch = fft_plan->batch;
+ int n = fft_plan->n[0];
+
+ void *dft_re_matrix_addr =
+ (direction == 0) ? fft_plan->matmul_addrs.dft_re_matrix_addr
+ : fft_plan->matmul_addrs.ifft_dft_re_matrix_addr;
+ void *dft_im_matrix_addr =
+ (direction == 0) ? fft_plan->matmul_addrs.dft_im_matrix_addr
+ : fft_plan->matmul_addrs.ifft_dft_im_matrix_addr;
+
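+  // Complex multiply expands into four real matmuls per the identity
+  // (re + i*im)(c + i*s) = (re*c - im*s) + i*(re*s + im*c); the partial
+  // products are combined by the two fftOptensor calls (CNFFT_FUNC_MATMUL)
+  // or inside the merge kernel (CNFFT_FUNC_COOLEY_TUKEY).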
+ if (fft_plan->fft_strategy == CNFFT_FUNC_MATMUL) {
+ VLOG(5) << "into computeFFT1dMatmulResult CNFFT_FUNC_MATMUL";
+ // input real matmul dft real
+ status = fftQuantMatMul(
+ handle, batch, n, n, fft_plan->matmul_addrs.input_re_addr,
+ fft_plan->matmul_addrs.input_pos_addr,
+ fft_plan->matmul_addrs.input_scale_addr, dft_re_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.matmul_re_mul_re_addr, false, true, scale_factor,
+ 0.0, in_e_dtype, in_e_dtype, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // input imag matmul dft imag
+ status = fftQuantMatMul(
+ handle, batch, n, n, fft_plan->matmul_addrs.input_im_addr,
+ fft_plan->matmul_addrs.input_pos_addr,
+ fft_plan->matmul_addrs.input_scale_addr, dft_im_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.matmul_im_mul_im_addr, false, true, scale_factor,
+ 0.0, in_e_dtype, in_e_dtype, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // input real matmul dft imag
+ status = fftQuantMatMul(
+ handle, batch, n, n, fft_plan->matmul_addrs.input_re_addr,
+ fft_plan->matmul_addrs.input_pos_addr,
+ fft_plan->matmul_addrs.input_scale_addr, dft_im_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.matmul_re_mul_im_addr, false, true, scale_factor,
+ 0.0, in_e_dtype, in_e_dtype, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // input imag matmul dft real
+ status = fftQuantMatMul(
+ handle, batch, n, n, fft_plan->matmul_addrs.input_im_addr,
+ fft_plan->matmul_addrs.input_pos_addr,
+ fft_plan->matmul_addrs.input_scale_addr, dft_re_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.matmul_im_mul_re_addr, false, true, scale_factor,
+ 0.0, in_e_dtype, in_e_dtype, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // real mul real sub imag mul imag
+ int per_matmul_output_num = batch * n;
+ status = fftOptensor(handle, per_matmul_output_num,
+ fft_plan->matmul_addrs.matmul_re_mul_re_addr,
+ fft_plan->matmul_addrs.matmul_im_mul_im_addr,
+ fft_plan->matmul_addrs.matmul_re_mul_re_addr, 1.0, 1.0,
+ 0.0, in_r_dtype, CNNL_OP_TENSOR_SUB,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // real mul imag add imag mul real
+ status = fftOptensor(handle, per_matmul_output_num,
+ fft_plan->matmul_addrs.matmul_re_mul_im_addr,
+ fft_plan->matmul_addrs.matmul_im_mul_re_addr,
+ fft_plan->matmul_addrs.matmul_re_mul_im_addr, 1.0, 1.0,
+ 0.0, in_r_dtype, CNNL_OP_TENSOR_ADD,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ } else if (fft_plan->fft_strategy == CNFFT_FUNC_COOLEY_TUKEY) {
+ int L = fft_plan->L;
+    int m = (1 << fft_plan->m);  // note: local m holds 2^(fft_plan->m)
+
+ // input real matmul dft real
+ status = fftQuantMatMul(
+ handle, batch * m, L, L, fft_plan->matmul_addrs.input_re_addr,
+ fft_plan->matmul_addrs.input_pos_addr,
+ fft_plan->matmul_addrs.input_scale_addr, dft_re_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.matmul_re_mul_re_addr, false, true, scale_factor,
+ 0.0, in_e_dtype, in_e_dtype, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // input imag matmul dft imag
+ status = fftQuantMatMul(
+ handle, batch * m, L, L, fft_plan->matmul_addrs.input_im_addr,
+ fft_plan->matmul_addrs.input_pos_addr,
+ fft_plan->matmul_addrs.input_scale_addr, dft_im_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.matmul_im_mul_im_addr, false, true, scale_factor,
+ 0.0, in_e_dtype, in_e_dtype, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // input real matmul dft imag
+ status = fftQuantMatMul(
+ handle, batch * m, L, L, fft_plan->matmul_addrs.input_re_addr,
+ fft_plan->matmul_addrs.input_pos_addr,
+ fft_plan->matmul_addrs.input_scale_addr, dft_im_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.matmul_re_mul_im_addr, false, true, scale_factor,
+ 0.0, in_e_dtype, in_e_dtype, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // input imag matmul dft real
+ status = fftQuantMatMul(
+ handle, batch * m, L, L, fft_plan->matmul_addrs.input_im_addr,
+ fft_plan->matmul_addrs.input_pos_addr,
+ fft_plan->matmul_addrs.input_scale_addr, dft_re_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.matmul_im_mul_re_addr, false, true, scale_factor,
+ 0.0, in_e_dtype, in_e_dtype, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ } else if (fft_plan->fft_strategy == CNFFT_FUNC_STOCKHAM) {
+ int L = fft_plan->L;
+    int m = (1 << fft_plan->m);  // note: local m holds 2^(fft_plan->m)
+
+ // origin: in_trans[batch, 2^m, L] * W_real[L, L] -> IN_real[batch, 2^m, L]
+ // in_trans[batch, 2^m, L] * W_imag[L, L] -> IN_imag[batch, 2^m, L]
+ // update: W[c*L, L] * in[batch, L, 2^m] -> out[batch, c*L, 2^m]
+    status = fftBatchMatMulBcast(
+        handle, 2 * L, L, m * 2, batch, dft_re_matrix_addr,
+        fft_plan->matmul_addrs.dft_pos_addr,
+        fft_plan->matmul_addrs.dft_scale_addr,
+        fft_plan->matmul_addrs.input_pad_addr,
+        fft_plan->matmul_addrs.input_pos_addr,
+        fft_plan->matmul_addrs.input_scale_addr,
+        fft_plan->matmul_addrs.matmul_re_mul_re_addr, false, false,
+        scale_factor, 0.0, in_e_dtype, in_e_dtype, in_r_dtype,
+        fft_plan->matmul_addrs.internal_workspace_addr,
+        fft_plan->matmul_addrs.internal_workspace_size, api);
+    INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ }
+
+ return status;
+}
+
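+// Default launch policy: a UNION1 job whose x-dim spans the cores of one
+// cluster and whose y-dim spans every available cluster.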
+static mluOpStatus_t policyFunc(mluOpHandle_t handle, cnrtDim3_t *k_dim,
+ cnrtFunctionType_t *k_type) {
+ *k_type = CNRT_FUNC_TYPE_UNION1;
+ k_dim->x = handle->core_num_per_cluster;
+ k_dim->y = mluop::runtime::getClusterLimitCapability(handle);
+ k_dim->z = 1;
+ return MLUOP_STATUS_SUCCESS;
+}
+
+// only for CNFFT_FUNC_COOLEY_TUKEY and CNFFT_FUNC_STOCKHAM
+// input : matmul real result in matmul_re_mul_re_addr
+// matmul imag result in matmul_re_mul_im_addr
+// workspace: internal_workspace_addr
+// output : output real result in output_contiguous_addr
+mluOpStatus_t mergeFFT1dOutput(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
+ const float scale_factor, int direction) {
+ std::string api = "[mluOpExecFFT]";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ if (fft_plan->fft_strategy == CNFFT_FUNC_COOLEY_TUKEY) {
+ VLOG(5) << "launch merge fft1d output";
+    // TODO(niyuming) launch merge kernel
+ int core_num = handle->core_num_per_cluster;
+ cnrtFunctionType_t k_type = CNRT_FUNC_TYPE_UNION1;
+ int task_type = mluop::runtime::getJobLimitCapability(handle);
+ int task_num = 1;
+
+ switch (task_type) {
+ default:
+ task_num = core_num;
+ break;
+ case (int)CNRT_FUNC_TYPE_UNION2:
+ task_num = core_num * 2;
+ break;
+ case (int)CNRT_FUNC_TYPE_UNION4:
+ task_num = core_num * 4;
+ break;
+ case (int)CNRT_FUNC_TYPE_UNION8:
+ task_num = core_num * 8;
+ break;
+ case (int)CNRT_FUNC_TYPE_UNION16:
+ task_num = core_num * 16;
+ break;
+ }
+
+ unsigned int dimx = task_num;
+ cnrtDim3_t k_dim = {dimx, 1, 1};
+ k_type = (cnrtFunctionType_t)dimx;
+ kernelFFTCooleyTukey(k_dim, k_type, handle->queue, fft_plan, direction,
+ FFT_IFFT);
+ } else if (fft_plan->fft_strategy == CNFFT_FUNC_STOCKHAM) {
+ VLOG(5) << "launch mrege rfft1d output";
+ cnrtDim3_t k_dim;
+ cnrtFunctionType_t k_type;
+ policyFunc(handle, &k_dim, &k_type);
+ kernelFFTStockham(k_dim, k_type, handle->queue, fft_plan, direction,
+ scale_factor, FFT_IFFT);
+ }
+ return status;
+}
+
+// only for CNFFT_FUNC_MATMUL
+// input : matmul real result in matmul_re_mul_re_addr
+// matmul imag result in matmul_re_mul_im_addr
+// output : output complex result in output_contiguous_addr
+static mluOpStatus_t transposeFFT1dOutput(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan) {
+ std::string api = "[mluOpExecFFT]";
+ VLOG(5) << "into transposeFFT1dOutput";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+
+ if (fft_plan->fft_strategy == CNFFT_FUNC_MATMUL) {
+ VLOG(5) << "launch mluOpTranspose";
+ mluOpDataType_t in_c_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_r_dtype = (in_c_dtype == MLUOP_DTYPE_COMPLEX_HALF)
+ ? MLUOP_DTYPE_HALF
+ : MLUOP_DTYPE_FLOAT;
+ int batch = fft_plan->batch;
+ int n = fft_plan->n[0];
+
+ int output_num = batch * n;
+ const int trans_dim_num = 2;
+ int trans_input_dims[trans_dim_num] = {COMPLEX, output_num};
+ int trans_output_dims[trans_dim_num] = {output_num, COMPLEX};
+ int trans_permute[trans_dim_num] = {1, 0};
+
+ status = fftTranspose(
+ handle, trans_dim_num, trans_input_dims, trans_output_dims,
+ trans_permute, fft_plan->matmul_addrs.matmul_re_mul_re_addr,
+ fft_plan->matmul_addrs.output_contiguous_addr, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ }
+ return status;
+}
+
+static mluOpStatus_t makeFFT1dContiguousOutput(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan,
+ void *output) {
+ std::string api = "[mluOpExecFFT]";
+ VLOG(5) << "into makeFFT1dContiguousOutput";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ if (!fft_plan->is_output_contiguous) {
+ VLOG(5) << "launch copy with stride";
+ mluOpDataType_t out_c_dtype = fft_plan->output_dtype;
+ // create tensor desc
+ mluOpTensorDescriptor_t copy_src_desc, copy_dst_desc;
+    status = mluOpCreateTensorDescriptor(&copy_src_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+    status = mluOpCreateTensorDescriptor(&copy_dst_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // set up tensor desc
+ const int out_dim_num = 2;
+ int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->onembed[0]};
+ int64_t strides[out_dim_num] = {fft_plan->odist, fft_plan->ostride};
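+    // odist is the distance between consecutive batches and ostride between
+    // consecutive elements; the strided copy below scatters the contiguous
+    // result into the caller's layout.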
+ status = mluOpSetTensorDescriptor_v2(copy_src_desc, MLUOP_LAYOUT_ARRAY,
+ out_c_dtype, out_dim_num, dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status =
+ mluOpSetTensorDescriptorEx_v2(copy_dst_desc, MLUOP_LAYOUT_ARRAY,
+ out_c_dtype, out_dim_num, dims, strides);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ void *copy_src_addr = fft_plan->matmul_addrs.output_contiguous_addr;
+ DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
+ cnnl_handle); // convert to cnnl_handle
+ // convert to cnnl_tensor_descriptor
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(copy_src_desc,
+ cnnl_copy_src_desc);
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(copy_dst_desc,
+ cnnl_copy_dst_desc);
+
+ CALL_CNNL(cnnlCopy(cnnl_handle, cnnl_copy_src_desc, copy_src_addr,
+ cnnl_copy_dst_desc, output));
+
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_copy_src_desc);
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_copy_dst_desc);
+ DESTROY_CNNL_HANDLE(cnnl_handle);
+ }
+ return status;
+}
+
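+// Execution pipeline for FFT1d. Helpers that do not apply to the chosen
+// strategy return without launching work:
+//   contiguous input -> pad -> transpose -> quantize -> matmul ->
+//   merge (Cooley-Tukey / Stockham) -> transpose output (matmul strategy) ->
+//   contiguous output.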
+mluOpStatus_t execFFT1d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan,
+ const void *input, const float scale_factor,
+ void *workspace, void *output, int direction) {
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ std::string api = "[mluOpExecFFT]";
+ configureFFT1dMatmulWorkspaceAddrs(handle, fft_plan, (void *)input, workspace,
+ output);
+
+ status = makeFFT1dContiguousInput(handle, fft_plan, input);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = padFFT1dContiguousInput(handle, fft_plan);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = transposeFFT1dPaddedInput(handle, fft_plan);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = quantizeFFT1dPaddedInput(handle, fft_plan);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = computeFFT1dMatmulResult(handle, fft_plan, scale_factor, direction);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = mergeFFT1dOutput(handle, fft_plan, scale_factor, direction);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = transposeFFT1dOutput(handle, fft_plan);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = makeFFT1dContiguousOutput(handle, fft_plan, output);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ return status;
+}
diff --git a/kernels/fft/common/fft_basic_ops.cpp b/kernels/fft/common/fft_basic_ops.cpp
new file mode 100644
index 000000000..3c5ee9116
--- /dev/null
+++ b/kernels/fft/common/fft_basic_ops.cpp
@@ -0,0 +1,727 @@
+/*************************************************************************
+ * Copyright (C) [2024] by Cambricon, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+
+#include "fft_basic_ops.h"
+
+bool fftIsIntDtype(const mluOpDataType_t dtype) {
+ if (dtype == MLUOP_DTYPE_INT8 || dtype == MLUOP_DTYPE_INT16) {
+ return true;
+ } else {
+ return false;
+ }
+}
+
+bool fftIsFloatDtype(const mluOpDataType_t dtype) {
+ if (dtype == MLUOP_DTYPE_HALF || dtype == MLUOP_DTYPE_FLOAT) {
+ return true;
+ } else {
+ return false;
+ }
+}
+
+mluOpStatus_t fftGetQuantizeParamWorkspaceSize(mluOpHandle_t handle,
+ size_t &required_size,
+ int array_length,
+ mluOpDataType_t data_type,
+ mluOpDataType_t compute_type,
+ const std::string api) {
+  mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+  required_size = 0;  // no extra workspace when input needs no quantization
+ if (data_type != compute_type) {
+ // create descriptor
+ mluOpTensorDescriptor_t input_desc;
+ status = mluOpCreateTensorDescriptor(&input_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // set descriptor
+ int64_t input_dims[1] = {array_length};
+ status = mluOpSetTensorDescriptor_v2(input_desc, MLUOP_LAYOUT_ARRAY,
+ data_type, 1, input_dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
+ cnnl_handle); // convert to cnnl_handle
+ // convert to cnnl_tensor_descriptor
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(input_desc, cnnl_input_desc);
+
+ // get quantize param workspace
+ CALL_CNNL(cnnlGetQuantizeParamWorkspaceSize(cnnl_handle, cnnl_input_desc,
+ &required_size));
+
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_input_desc);
+ DESTROY_CNNL_HANDLE(cnnl_handle);
+ }
+ return status;
+}
+
+mluOpStatus_t fftQuantizePositionScale(mluOpHandle_t handle, int array_length,
+ mluOpDataType_t data_type,
+ mluOpDataType_t compute_type,
+ const void *input, void *position,
+ void *scale, void *workspace,
+ size_t workspace_size,
+ const std::string api) {
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ if (data_type != compute_type) {
+ // create descriptor
+ mluOpTensorDescriptor_t quant_desc;
+ status = mluOpCreateTensorDescriptor(&quant_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // set descriptor
+ int64_t quant_dims[1] = {array_length};
+ status = mluOpSetTensorDescriptor_v2(quant_desc, MLUOP_LAYOUT_ARRAY,
+ data_type, 1, quant_dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
+ cnnl_handle); // convert to cnnl_handle
+ // convert to cnnl_tensor_descriptor
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(quant_desc, cnnl_quant_desc);
+
+ // get quantize param
+ int bit_width;
+ mluop::castDtypeToBitwidth(compute_type, &bit_width);
+ cnnlQuantizeMode_t mode = CNNL_QUANTIZE_POSITION_SCALE;
+ CALL_CNNL(cnnlQuantizeParam(cnnl_handle, mode, cnnl_quant_desc, input,
+ bit_width, workspace, workspace_size, position,
+ scale, nullptr));
+
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_quant_desc);
+ DESTROY_CNNL_HANDLE(cnnl_handle);
+ }
+ return status;
+}
+
+mluOpStatus_t fftGetQuantizeMatMulWorkspaceSize(
+ mluOpHandle_t handle, size_t &workspace_size, int m, int k, int n,
+ bool is_trans_a, bool is_trans_b, mluOpDataType_t a_compute_type,
+ mluOpDataType_t b_compute_type, mluOpDataType_t data_type,
+ const std::string api) {
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ int trans_a_int = (int)is_trans_a;
+ int trans_b_int = (int)is_trans_b;
+
+ // create descriptor
+ mluOpTensorDescriptor_t a_desc = nullptr;
+ mluOpTensorDescriptor_t b_desc = nullptr;
+ mluOpTensorDescriptor_t c_desc = nullptr;
+ status = mluOpCreateTensorDescriptor(&a_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpCreateTensorDescriptor(&b_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpCreateTensorDescriptor(&c_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // set descriptor
+ int64_t a_dims[2];
+ int64_t b_dims[2];
+ int64_t c_dims[2] = {m, n};
+ if (is_trans_a) {
+ a_dims[0] = k;
+ a_dims[1] = m;
+ } else {
+ a_dims[0] = m;
+ a_dims[1] = k;
+ }
+ if (is_trans_b) {
+ b_dims[0] = n;
+ b_dims[1] = k;
+ } else {
+ b_dims[0] = k;
+ b_dims[1] = n;
+ }
+ status = mluOpSetTensorDescriptor_v2(a_desc, MLUOP_LAYOUT_ARRAY, data_type, 2,
+ a_dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpSetTensorDescriptorOnchipDataType(a_desc, a_compute_type);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpSetTensorDescriptor_v2(b_desc, MLUOP_LAYOUT_ARRAY, data_type, 2,
+ b_dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpSetTensorDescriptorOnchipDataType(b_desc, b_compute_type);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY, data_type, 2,
+ c_dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ if (fftIsIntDtype(a_compute_type) && fftIsIntDtype(b_compute_type) &&
+ c_desc->dtype == MLUOP_DTYPE_HALF) {
+ status = mluOpSetTensorDescriptorOnchipDataType(c_desc, MLUOP_DTYPE_FLOAT);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ } else {
+ status = mluOpSetTensorDescriptorOnchipDataType(c_desc, data_type);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ }
+
+ DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
+ cnnl_handle); // convert to cnnl_handle
+
+ // convert to cnnl_tensor_descriptor
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(a_desc, cnnl_a_desc);
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(b_desc, cnnl_b_desc);
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_c_desc);
+
+ // get matmul workspace
+ if (fftIsIntDtype(a_compute_type) && fftIsIntDtype(b_compute_type)) {
+ cnnlMatMulDescriptor_t matmul_desc;
+ CALL_CNNL(cnnlMatMulDescCreate(&matmul_desc));
+ CALL_CNNL(cnnlSetMatMulDescAttr(matmul_desc, CNNL_MATMUL_DESC_COMPUTE_TYPE,
+ &data_type, sizeof(int32_t)));
+ CALL_CNNL(cnnlSetMatMulDescAttr(matmul_desc, CNNL_MATMUL_DESC_TRANSA,
+ &trans_a_int, sizeof(int32_t)));
+ CALL_CNNL(cnnlSetMatMulDescAttr(matmul_desc, CNNL_MATMUL_DESC_TRANSB,
+ &trans_b_int, sizeof(int32_t)));
+
+ cnnlMatMulAlgo_t matmul_algo;
+ CALL_CNNL(cnnlMatMulAlgoCreate(&matmul_algo));
+ cnnlMatMulPreference_t preference = CNNL_MATMUL_FASTEST;
+ CALL_CNNL(cnnlGetQuantizeMatMulAlgorithm(
+ cnnl_handle, matmul_desc, cnnl_a_desc, cnnl_b_desc, cnnl_c_desc,
+ preference, &matmul_algo));
+
+ CALL_CNNL(cnnlGetQuantizeMatMulWorkspaceSize(
+ cnnl_handle, matmul_desc, cnnl_a_desc, cnnl_b_desc, cnnl_c_desc,
+ matmul_algo, &workspace_size));
+
+ CALL_CNNL(cnnlMatMulDescDestroy(matmul_desc));
+ CALL_CNNL(cnnlMatMulAlgoDestroy(matmul_algo));
+ } else {
+ workspace_size = 0; // mluOpMatmul doesn't need workspace.
+ }
+
+ // destroy cnnl descriptor
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_a_desc);
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_b_desc);
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_c_desc);
+
+ DESTROY_CNNL_HANDLE(cnnl_handle);
+ return status;
+}
+
+mluOpStatus_t fftQuantMatMul(mluOpHandle_t handle, int m, int k, int n,
+ void *a_ptr, void *a_pos, void *a_scale,
+ void *b_ptr, void *b_pos, void *b_scale,
+ void *c_ptr, bool is_trans_a, bool is_trans_b,
+ float alpha, float beta,
+ mluOpDataType_t a_compute_type,
+ mluOpDataType_t b_compute_type,
+ mluOpDataType_t data_type, void *workspace,
+ size_t workspace_size, const std::string api) {
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ int trans_a_int = (int)is_trans_a;
+ int trans_b_int = (int)is_trans_b;
+
+ // create descriptor
+ mluOpTensorDescriptor_t a_desc = nullptr;
+ mluOpTensorDescriptor_t b_desc = nullptr;
+ mluOpTensorDescriptor_t c_desc = nullptr;
+ status = mluOpCreateTensorDescriptor(&a_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpCreateTensorDescriptor(&b_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpCreateTensorDescriptor(&c_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // set descriptor
+ int64_t a_dims[2];
+ int64_t b_dims[2];
+ int64_t c_dims[2] = {m, n};
+ if (is_trans_a) {
+ a_dims[0] = k;
+ a_dims[1] = m;
+ } else {
+ a_dims[0] = m;
+ a_dims[1] = k;
+ }
+ if (is_trans_b) {
+ b_dims[0] = n;
+ b_dims[1] = k;
+ } else {
+ b_dims[0] = k;
+ b_dims[1] = n;
+ }
+ status = mluOpSetTensorDescriptor_v2(a_desc, MLUOP_LAYOUT_ARRAY, data_type, 2,
+ a_dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpSetTensorDescriptorOnchipDataType(a_desc, a_compute_type);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpSetTensorDescriptor_v2(b_desc, MLUOP_LAYOUT_ARRAY, data_type, 2,
+ b_dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpSetTensorDescriptorOnchipDataType(b_desc, b_compute_type);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY, data_type, 2,
+ c_dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ if (fftIsIntDtype(a_compute_type) && fftIsIntDtype(b_compute_type) &&
+ c_desc->dtype == MLUOP_DTYPE_HALF) {
+ status = mluOpSetTensorDescriptorOnchipDataType(c_desc, MLUOP_DTYPE_FLOAT);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ } else {
+ status = mluOpSetTensorDescriptorOnchipDataType(c_desc, data_type);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ }
+
+ DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
+ cnnl_handle); // convert to cnnl_handle
+
+ // convert to cnnl_tensor_descriptor
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(a_desc, cnnl_a_desc);
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(b_desc, cnnl_b_desc);
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_c_desc);
+
+ // compute matmul result
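+  // int8/int16 onchip dtypes take the quantized matmul path using the
+  // position/scale parameters computed earlier; half/float fall through to a
+  // plain cnnlMatMul with alpha/beta applied directly.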
+ if (fftIsIntDtype(a_compute_type) && fftIsIntDtype(b_compute_type)) {
+ cnnlMatMulDescriptor_t matmul_desc;
+ CALL_CNNL(cnnlMatMulDescCreate(&matmul_desc));
+ CALL_CNNL(cnnlSetMatMulDescAttr(matmul_desc, CNNL_MATMUL_DESC_COMPUTE_TYPE,
+ &data_type, sizeof(int32_t)));
+ CALL_CNNL(cnnlSetMatMulDescAttr(matmul_desc, CNNL_MATMUL_DESC_TRANSA,
+ &trans_a_int, sizeof(int32_t)));
+ CALL_CNNL(cnnlSetMatMulDescAttr(matmul_desc, CNNL_MATMUL_DESC_TRANSB,
+ &trans_b_int, sizeof(int32_t)));
+
+ cnnlMatMulAlgo_t matmul_algo;
+ CALL_CNNL(cnnlMatMulAlgoCreate(&matmul_algo));
+ cnnlMatMulPreference_t preference = CNNL_MATMUL_FASTEST;
+ CALL_CNNL(cnnlGetQuantizeMatMulAlgorithm(
+ cnnl_handle, matmul_desc, cnnl_a_desc, cnnl_b_desc, cnnl_c_desc,
+ preference, &matmul_algo));
+
+ const float one = 1.0;
+ const float zero = 0.0;
+ CALL_CNNL(cnnlQuantizeMatMul(
+ cnnl_handle, matmul_desc, &one, cnnl_a_desc, a_ptr, a_pos, a_scale,
+ nullptr, cnnl_b_desc, b_ptr, b_pos, b_scale, nullptr, &zero,
+ cnnl_c_desc, c_ptr, matmul_algo, workspace, workspace_size));
+
+ if ((alpha != 1.0) || (beta != 0.0)) {
+ CALL_CNNL(cnnlTransform_v2(cnnl_handle, CNNL_POINTER_MODE_HOST, &alpha,
+ cnnl_c_desc, c_ptr, &beta, cnnl_c_desc,
+ c_ptr));
+ }
+
+ CALL_CNNL(cnnlMatMulDescDestroy(matmul_desc));
+ CALL_CNNL(cnnlMatMulAlgoDestroy(matmul_algo));
+ } else {
+ c_desc->onchip_dtype = MLUOP_DTYPE_FLOAT;
+ CALL_CNNL(cnnlMatMul(cnnl_handle, is_trans_a, is_trans_b, &alpha,
+ cnnl_a_desc, a_ptr, cnnl_b_desc, b_ptr, &beta,
+ cnnl_c_desc, c_ptr));
+ }
+
+ // destroy cnnl descriptor
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_a_desc);
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_b_desc);
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_c_desc);
+
+ DESTROY_CNNL_HANDLE(cnnl_handle);
+
+ return status;
+}
+
+mluOpStatus_t fftBatchMatMulBcast(
+ mluOpHandle_t handle,
+    int m,  // e.g. 2 * L = 750
+    int k,  // e.g. L = 375
+    int n,  // e.g. 2^m = 128
+ int batch, void *a_ptr, void *a_pos, void *a_scale, void *b_ptr,
+ void *b_pos, void *b_scale, void *c_ptr, bool is_trans_a, bool is_trans_b,
+ float alpha, float beta, mluOpDataType_t a_compute_type,
+ mluOpDataType_t b_compute_type, mluOpDataType_t data_type, void *workspace,
+ size_t workspace_size, const std::string api) {
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ int trans_a_int = (int)is_trans_a;
+ int trans_b_int = (int)is_trans_b;
+
+ // create descriptor
+ mluOpTensorDescriptor_t a_desc = nullptr;
+ mluOpTensorDescriptor_t b_desc = nullptr;
+ mluOpTensorDescriptor_t c_desc = nullptr;
+ status = mluOpCreateTensorDescriptor(&a_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpCreateTensorDescriptor(&b_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpCreateTensorDescriptor(&c_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // set descriptor
+ int a_dims[2];
+ int b_dims[3] = {batch, k, n};
+ int c_dims[3] = {batch, m, n};
+ if (is_trans_a) {
+ a_dims[0] = k;
+ a_dims[1] = m;
+ } else {
+ a_dims[0] = m;
+ a_dims[1] = k;
+ }
+ if (is_trans_b) {
+ b_dims[1] = n;
+ b_dims[2] = k;
+ } else {
+ b_dims[1] = k;
+ b_dims[2] = n;
+ }
+ status = mluOpSetTensorDescriptor(a_desc, MLUOP_LAYOUT_ARRAY, data_type, 2,
+ a_dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpSetTensorDescriptorOnchipDataType(a_desc, a_compute_type);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpSetTensorDescriptor(b_desc, MLUOP_LAYOUT_ARRAY, data_type, 3,
+ b_dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpSetTensorDescriptorOnchipDataType(b_desc, b_compute_type);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpSetTensorDescriptor(c_desc, MLUOP_LAYOUT_ARRAY, data_type, 3,
+ c_dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ c_desc->onchip_dtype = MLUOP_DTYPE_FLOAT;
+
+ DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
+ cnnl_handle); // convert to cnnl_handle
+
+ // convert to cnnl_tensor_descriptor
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(a_desc, cnnl_a_desc);
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(b_desc, cnnl_b_desc);
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_c_desc);
+
+ CALL_CNNL(cnnlBatchMatMulBCast(cnnl_handle, is_trans_a, is_trans_b,
+ cnnl_a_desc, a_ptr, cnnl_b_desc, b_ptr, NULL,
+ 0, cnnl_c_desc, c_ptr));
+
+  // destroy cnnl descriptor
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_a_desc);
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_b_desc);
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_c_desc);
+
+ DESTROY_CNNL_HANDLE(cnnl_handle);
+
+ return status;
+}
+
+mluOpStatus_t fftGetTransposeWorkspaceSize(mluOpHandle_t handle,
+ size_t &workspace_size, int dim_num,
+ int ori_dims[], int permute[],
+ mluOpDataType_t data_type,
+ const std::string api) {
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+
+ // create descriptor
+ mluOpTensorDescriptor_t input_desc = nullptr;
+ status = mluOpCreateTensorDescriptor(&input_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ cnnlTransposeDescriptor_t trans_desc = nullptr;
+ CALL_CNNL(cnnlCreateTransposeDescriptor(&trans_desc));
+
+ DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
+ cnnl_handle); // convert to cnnl_handle
+ // set descriptor
+ status = mluOpSetTensorDescriptor(input_desc, MLUOP_LAYOUT_ARRAY, data_type,
+ dim_num, ori_dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(input_desc, cnnl_input_desc);
+
+ CALL_CNNL(cnnlSetTransposeDescriptor(trans_desc, dim_num, permute));
+
+ // get workspace
+ CALL_CNNL(cnnlGetTransposeWorkspaceSize(cnnl_handle, cnnl_input_desc,
+ trans_desc, &workspace_size));
+
+ // destroy descriptor
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_input_desc);
+  CALL_CNNL(cnnlDestroyTransposeDescriptor(trans_desc));
+
+  DESTROY_CNNL_HANDLE(cnnl_handle);
+
+  return status;
+}
+
+mluOpStatus_t fftTranspose(mluOpHandle_t handle, int dim_num, int ori_dims[],
+ int transed_dims[], int permute[], void *ori_ptr,
+ void *transed_ptr, mluOpDataType_t data_type,
+ void *workspace, size_t workspace_size,
+ const std::string api) {
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+
+ // create descriptor
+ mluOpTensorDescriptor_t input_desc = nullptr;
+ mluOpTensorDescriptor_t transed_input_desc = nullptr;
+ status = mluOpCreateTensorDescriptor(&input_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpCreateTensorDescriptor(&transed_input_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // set descriptor
+ status = mluOpSetTensorDescriptor(input_desc, MLUOP_LAYOUT_ARRAY, data_type,
+ dim_num, ori_dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpSetTensorDescriptor(transed_input_desc, MLUOP_LAYOUT_ARRAY,
+ data_type, dim_num, transed_dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
+ cnnl_handle); // convert to cnnl_handle
+ // convert to cnnl_tensor_descriptor
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(input_desc, cnnl_input_desc);
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(transed_input_desc,
+ cnnl_transed_input_desc);
+
+ // compute transpose
+ cnnlTransposeDescriptor_t trans_desc = nullptr;
+ CALL_CNNL(cnnlCreateTransposeDescriptor(&trans_desc));
+ CALL_CNNL(cnnlSetTransposeDescriptor(trans_desc, dim_num, permute));
+
+ CALL_CNNL(cnnlTranspose_v2(cnnl_handle, trans_desc, cnnl_input_desc, ori_ptr,
+ cnnl_transed_input_desc, transed_ptr, workspace,
+ workspace_size));
+
+ // destroy descriptor
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_input_desc);
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_transed_input_desc);
+ CALL_CNNL(cnnlDestroyTransposeDescriptor(trans_desc));
+
+ DESTROY_CNNL_HANDLE(cnnl_handle);
+
+ return status;
+}
+
+mluOpStatus_t fftGetOptensorWorkspaceSize(mluOpHandle_t handle,
+ size_t &workspace_size, int elem_num,
+ mluOpDataType_t data_type,
+ const std::string api) {
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+
+ // create descriptor
+ mluOpTensorDescriptor_t in1_desc = nullptr;
+ mluOpTensorDescriptor_t in2_desc = nullptr;
+ mluOpTensorDescriptor_t out_desc = nullptr;
+ status = mluOpCreateTensorDescriptor(&in1_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpCreateTensorDescriptor(&in2_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpCreateTensorDescriptor(&out_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // set descriptor
+ int64_t dims[1] = {elem_num};
+ status = mluOpSetTensorDescriptor_v2(in1_desc, MLUOP_LAYOUT_ARRAY, data_type,
+ 1, dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpSetTensorDescriptor_v2(in2_desc, MLUOP_LAYOUT_ARRAY, data_type,
+ 1, dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpSetTensorDescriptor_v2(out_desc, MLUOP_LAYOUT_ARRAY, data_type,
+ 1, dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
+ cnnl_handle); // convert to cnnl_handle
+
+ // convert to cnnl_tensor_descriptor
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(in1_desc, cnnl_in1_desc);
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(in2_desc, cnnl_in2_desc);
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(out_desc, cnnl_out_desc);
+
+ // get workspace
+ CALL_CNNL(cnnlGetOpTensorWorkspaceSize(cnnl_handle, cnnl_in1_desc,
+ cnnl_in2_desc, cnnl_out_desc,
+ &workspace_size));
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // destroy descriptor
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_in1_desc);
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_in2_desc);
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_out_desc);
+
+ DESTROY_CNNL_HANDLE(cnnl_handle);
+
+ return status;
+}
+
+mluOpStatus_t fftOptensor(mluOpHandle_t handle, int elem_num, void *in1_ptr,
+ void *in2_ptr, void *out_ptr, float alpha1,
+ float alpha2, float beta, mluOpDataType_t data_type,
+ cnnlOpTensorDesc_t op_type, void *workspace,
+ size_t workspace_size, const std::string api) {
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+
+ // create descriptor
+ mluOpTensorDescriptor_t in1_desc = nullptr;
+ mluOpTensorDescriptor_t in2_desc = nullptr;
+ mluOpTensorDescriptor_t out_desc = nullptr;
+ status = mluOpCreateTensorDescriptor(&in1_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpCreateTensorDescriptor(&in2_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpCreateTensorDescriptor(&out_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // set descriptor
+ int64_t dims[1] = {elem_num};
+ status = mluOpSetTensorDescriptor_v2(in1_desc, MLUOP_LAYOUT_ARRAY, data_type,
+ 1, dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpSetTensorDescriptor_v2(in2_desc, MLUOP_LAYOUT_ARRAY, data_type,
+ 1, dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpSetTensorDescriptor_v2(out_desc, MLUOP_LAYOUT_ARRAY, data_type,
+ 1, dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
+ cnnl_handle); // convert to cnnl_handle
+
+ // convert to cnnl_tensor_descriptor
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(in1_desc, cnnl_in1_desc);
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(in2_desc, cnnl_in2_desc);
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(out_desc, cnnl_out_desc);
+ // compute optensor
+ cnnlOpTensorDescriptor_t opTensor_desc = nullptr;
+ CALL_CNNL(cnnlCreateOpTensorDescriptor(&opTensor_desc));
+ CALL_CNNL(cnnlSetOpTensorDescriptor(opTensor_desc, op_type,
+ (cnnlDataType_t)data_type,
+ CNNL_NOT_PROPAGATE_NAN));
+
+ CALL_CNNL(cnnlOpTensor(cnnl_handle, opTensor_desc, &alpha1, cnnl_in1_desc,
+ in1_ptr, &alpha2, cnnl_in2_desc, in2_ptr, workspace,
+ workspace_size, &beta, cnnl_out_desc, out_ptr));
+
+ // destroy descriptor
+ CALL_CNNL(cnnlDestroyOpTensorDescriptor(opTensor_desc));
+
+ // destroy cnnl descriptor
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_in1_desc);
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_in2_desc);
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_out_desc);
+
+ DESTROY_CNNL_HANDLE(cnnl_handle);
+
+ return status;
+}
+
+// Split n into L * 2^m: strip factors of 2 until L is odd, then widen L back
+// to at least 64 for IO efficiency (see the second loop below).
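+// E.g. n = 1024 first gives L = 1, m = 10; widening then yields L = 64,
+// m = 4 (1024 = 64 * 2^4).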
+static void initBasicParam(const int n, int &L, int &m) {
+ // split n into 2^m * L
+ m = 0;
+ L = n;
+ while (1) {
+ int rem = L % 2;
+ if (rem != 0) {
+ break;
+ }
+ m++;
+ L = L / 2;
+ }
+
+  // when L is smaller than 64, increase L to ensure IO efficiency.
+ while (L < 64 && m > 1) {
+ L = L * 2;
+ m--;
+ }
+}
+
+static bool findStockham(mluOpHandle_t handle, int &L, int &m, int &L_sub,
+ bool &find_stockham) {
+ if (find_stockham) {
+ int NFU_ALIGN_NUM = NFU_ALIGN_SIZE / sizeof(float);
+ int max_nram_size = handle->nram_size + REM_FOR_STACK - 32 * 1024;
+ L_sub = PAD_UP(L, NFU_ALIGN_NUM);
+ int L_tmp = L;
+ int m_tmp = m;
+ // one calculation requires 14 copies of space as follows:
+ // input(4): y_in_r, z_in_r, y_in_i, z_in_i,
+ // output(4): x_out1_r, x_out2_r, x_out1_i, x_out2_i,
+ // w matrix(6): w_r, w_i, wz_rr, wz_ri, wz_ir, wz_ii
+ // 2 represents ping_pong for pipeline.
+ // 1 represents a public space stores the incremental sequence shared by
+ // ping_pong.
+ int cal_unit_tmp = 14 * 2 + 1;
+ size_t cal_unit_once_tmp =
+ L_sub * pow(2, m_tmp - 1) * cal_unit_tmp * sizeof(float);
+ while (cal_unit_once_tmp > max_nram_size) {
+ if (L_sub >= NFU_ALIGN_NUM * 2) {
+ L_sub -= NFU_ALIGN_NUM;
+ } else if (m_tmp > 1) {
+ L_tmp = L_tmp * 2;
+ m_tmp--;
+ } else {
+ break;
+ }
+ cal_unit_once_tmp =
+ L_sub * pow(2, m_tmp - 1) * cal_unit_tmp * sizeof(float);
+ }
+ if (cal_unit_once_tmp < max_nram_size && L_tmp <= 4096) {
+ L = L_tmp;
+ m = m_tmp;
+ L_sub = std::min(L, PAD_UP(L_sub, NFU_ALIGN_NUM));
+ VLOG(5) << "m: " << m << ", L: " << L << ", L_sub: " << L_sub;
+ return true;
+ }
+ }
+ return false;
+}
+
+static bool findCooleyTukey(mluOpHandle_t handle, int &L, int &m, int &s) {
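+  // 14 matches the per-point buffer breakdown listed in findStockham
+  // (4 input + 4 output + 6 w-matrix buffers), without ping-pong here.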
+ int cal_unit = 14;
+ int cal_unit_once =
+ PAD_UP(L, NFU_ALIGN_SIZE / sizeof(float)) * cal_unit * sizeof(float);
+
+ // calculate s
+ s = m;
+ if (cal_unit_once <= handle->nram_size) {
+ // split m
+ for (int i = m; i >= 0; i--) {
+ size_t space_use = pow(2, i) * cal_unit_once;
+ if (space_use < handle->nram_size) {
+ s = i;
+ break;
+ }
+ }
+  } else {
+    m = 0;
+    return false;  // even a single length-L row exceeds NRAM
+  }
+ if (s == m) {
+ s--;
+ }
+
+ VLOG(5) << "m: " << m << ", L: " << L << ", s: " << s;
+
+ return true;
+}
+
+// Find the most suitable parameters for Cooley-Tukey or Stockham algorithm.
+int findFFTOptLimit(mluOpHandle_t handle, const int n, int &m, int &L, int &s,
+ int &L_sub, bool &find_stockham) {
+ initBasicParam(n, L, m);
+
+ int flag;
+ flag = findStockham(handle, L, m, L_sub, find_stockham);
+ if (flag) {
+ return 0;
+ }
+
+ flag = findCooleyTukey(handle, L, m, s);
+ return flag;
+}
diff --git a/kernels/fft/common/fft_basic_ops.h b/kernels/fft/common/fft_basic_ops.h
new file mode 100644
index 000000000..8d28d179b
--- /dev/null
+++ b/kernels/fft/common/fft_basic_ops.h
@@ -0,0 +1,103 @@
+/*************************************************************************
+ * Copyright (C) [2024] by Cambricon, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+
+#ifndef KERNELS_FFT_COMMON_FFT_BASIC_OPS_H_
+#define KERNELS_FFT_COMMON_FFT_BASIC_OPS_H_
+
+#include <algorithm>
+#include <string>
+#include "core/tensor.h"
+#include "core/context.h"
+#include "core/tool.h"
+#include "kernels/kernel.h"
+#include "kernels/utils/cnnl_helper.h"
+#include "mlu_op.h"
+
+bool fftIsIntDtype(const mluOpDataType_t dtype);
+
+bool fftIsFloatDtype(const mluOpDataType_t dtype);
+
+mluOpStatus_t fftGetQuantizeParamWorkspaceSize(mluOpHandle_t handle,
+ size_t &required_size,
+ int array_length,
+ mluOpDataType_t data_type,
+ mluOpDataType_t compute_type,
+ const std::string api);
+
+mluOpStatus_t fftQuantizePositionScale(
+ mluOpHandle_t handle, int array_length, mluOpDataType_t data_type,
+ mluOpDataType_t compute_type, const void *input, void *position,
+ void *scale, void *workspace, size_t workspace_size, const std::string api);
+
+mluOpStatus_t fftGetQuantizeMatMulWorkspaceSize(
+ mluOpHandle_t handle, size_t &workspace_size, int m, int k, int n,
+ bool is_trans_a, bool is_trans_b, mluOpDataType_t a_compute_type,
+ mluOpDataType_t b_compute_type, mluOpDataType_t data_type,
+ const std::string api);
+
+mluOpStatus_t fftQuantMatMul(mluOpHandle_t handle, int m, int k, int n,
+ void *a_ptr, void *a_pos, void *a_scale,
+ void *b_ptr, void *b_pos, void *b_scale,
+ void *c_ptr, bool is_trans_a, bool is_trans_b,
+ float alpha, float beta,
+ mluOpDataType_t a_compute_type,
+ mluOpDataType_t b_compute_type,
+ mluOpDataType_t data_type, void *workspace,
+ size_t workspace_size, const std::string api);
+
+mluOpStatus_t fftBatchMatMulBcast(mluOpHandle_t handle, int m, int k, int n,
+ int batch, void *a_ptr, void *a_pos,
+ void *a_scale, void *b_ptr, void *b_pos,
+ void *b_scale, void *c_ptr, bool is_trans_a,
+ bool is_trans_b, float alpha, float beta,
+ mluOpDataType_t a_compute_type,
+ mluOpDataType_t b_compute_type,
+ mluOpDataType_t data_type, void *workspace,
+ size_t workspace_size, const std::string api);
+
+mluOpStatus_t fftGetTransposeWorkspaceSize(mluOpHandle_t handle,
+ size_t &workspace_size, int dim_num,
+ int ori_dims[], int permute[],
+ mluOpDataType_t data_type,
+ const std::string api);
+
+mluOpStatus_t fftTranspose(mluOpHandle_t handle, int dim_num, int ori_dims[],
+ int transed_dims[], int permute[], void *ori_ptr,
+ void *transed_ptr, mluOpDataType_t data_type,
+ void *workspace, size_t workspace_size,
+ const std::string api);
+
+mluOpStatus_t fftGetOptensorWorkspaceSize(mluOpHandle_t handle,
+ size_t &workspace_size, int elem_num,
+ mluOpDataType_t data_type,
+ const std::string api);
+
+mluOpStatus_t fftOptensor(mluOpHandle_t handle, int elem_num, void *in1_ptr,
+ void *in2_ptr, void *out_ptr, float alpha1,
+ float alpha2, float beta, mluOpDataType_t data_type,
+ cnnlOpTensorDesc_t op_type, void *workspace,
+ size_t workspace_size, const std::string api);
+
+int findFFTOptLimit(mluOpHandle_t handle, const int n, int &m, int &L, int &s,
+ int &L_sub, bool &find_stockham);
+#endif // KERNELS_FFT_COMMON_FFT_BASIC_OPS_H_
diff --git a/kernels/fft/common/fft_common_kernels.h b/kernels/fft/common/fft_common_kernels.h
new file mode 100644
index 000000000..16e5d2b3b
--- /dev/null
+++ b/kernels/fft/common/fft_common_kernels.h
@@ -0,0 +1,44 @@
+/*************************************************************************
+ * Copyright (C) [2024] by Cambricon, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+#ifndef KERNELS_FFT_COMMON_FFT_COMMON_KERNELS_H_
+#define KERNELS_FFT_COMMON_FFT_COMMON_KERNELS_H_
+
+#include "kernels/fft/fft.h"
+
+mluOpStatus_t MLUOP_WIN_API kernelGenerateRFFTHalfDFTMatrix(
+ cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
+ mluOpFFTPlan_t fft_plan, mluOpDataType_t in_r_dtype, int n);
+
+mluOpStatus_t MLUOP_WIN_API kernelGenerateRFFTFullDFTMatrix(
+ cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
+ mluOpFFTPlan_t fft_plan, mluOpDataType_t in_r_dtype, int row, int n);
+
+mluOpStatus_t MLUOP_WIN_API kernelGenerateIRFFTHalfDFTMatrix(
+ cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
+ mluOpFFTPlan_t fft_plan, mluOpDataType_t in_r_dtype, int n);
+
+mluOpStatus_t MLUOP_WIN_API kernelGenerateIRFFTFullDFTMatrix(
+ cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
+ mluOpFFTPlan_t fft_plan, mluOpDataType_t in_r_dtype, int n);
+
+#endif // KERNELS_FFT_COMMON_FFT_COMMON_KERNELS_H_
diff --git a/kernels/fft/common/fft_common_kernels.mlu b/kernels/fft/common/fft_common_kernels.mlu
new file mode 100644
index 000000000..164ea745c
--- /dev/null
+++ b/kernels/fft/common/fft_common_kernels.mlu
@@ -0,0 +1,637 @@
+/*************************************************************************
+ * Copyright (C) [2024] by Cambricon, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+
+#include <algorithm>
+#include "kernels/debug.h"
+#include "kernels/kernel.h"
+#include "kernels/utils/common.h"
+#include "kernels/fft/fft.h"
+#include "kernels/fft/common/fft_common_kernels.h"
+
+#define PAD_SIZE 64
+
+__nram__ char nram_buffer[MAX_NRAM_SIZE];
+
+/*
+ convert function
+ */
+__mlu_func__ void convert(float *dst, float *src, int length) {
+ if (src == dst) {
+ return;
+ } else {
+ __memcpy(dst, src, length * sizeof(float), NRAM2NRAM, 0, 0, 0);
+ }
+}
+
+__mlu_func__ void convert(half *dst, float *src, int length) {
+ __mluop_float2half(dst, src, length);
+}
+
+/*
+ mod function: input % n
+ [input] src_addr: input data
+ [input] temp_addr: temp space to store middle max data
+ [input] n: input mod n
+ [input] len: data number
+ [output] src_addr: output data
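+  note: there is no vector mod inst, so the loop repeatedly subtracts
+  n * (src >= n); the iteration count grows with max(src) / n.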
+ */
+template <typename T>
+__mlu_func__ void __mluop_mod(T *src_addr, T *temp_addr, T n, int len) {
+ T array_max;
+ __bang_argmax(temp_addr, src_addr, len);
+ array_max = temp_addr[0];
+ while (array_max >= n) {
+ __bang_ge_scalar(temp_addr, src_addr, n, len);
+ __bang_mul_scalar(temp_addr, temp_addr, n, len);
+ __bang_sub(src_addr, src_addr, temp_addr, len);
+ __bang_argmax(temp_addr, src_addr, len);
+ array_max = temp_addr[0];
+ }
+}
+
+/*
+ generate sin and cos vector function:
+ [input] src_addr: input data
+ [input] deal_size: input data number(don't need align)
+ [output] sin_addr: input data sin result
+ [output] cos_addr: input data cos result
+ */
+__mlu_func__ void genSinCosVec(float *src_addr, float *sin_addr,
+ float *cos_addr, int deal_size) {
+#if __BANG_ARCH__ >= 372
+ __bang_sin(sin_addr, src_addr, deal_size);
+ __bang_cos(cos_addr, src_addr, deal_size);
+#else
+ for (int i = 0; i < deal_size; i++) {
+ sin_addr[i] = sinf(src_addr[i]);
+ cos_addr[i] = cosf(src_addr[i]);
+ }
+#endif
+}
+
+/*
+ generate select offset vector function:
+  bang_arch >= 372 would scale offsets by sizeof(float) for the gather inst;
+  the scalar int conversion below serves all arch versions here.
+ [input] offset_addr: offset data in float32
+ [input] deal_size: offset data number(don't need align)
+ [output] offset_int_addr: offset data in int32
+ */
+__mlu_func__ void genSelectOffsetVec(float *offset_addr,
+ int32_t *offset_int_addr, int deal_size) {
+ for (int i = 0; i < deal_size; i++) {
+ offset_int_addr[i] = (int)(offset_addr[i]);
+ }
+}
+
+/*
+ select data function:
+  bang_arch >= 372 could use the gather inst (offsets scaled by sizeof(float));
+  the scalar loop below is the portable fallback used for all arch versions.
+ [input] src_addr: input data to be selected
+ [input] offset_int_addr: offset data to select data in int32
+ [input] deal_size: offset data number(don't need align)
+ [output] dst_addr: selected data
+ */
+__mlu_func__ void selectVec(float *src_addr, int32_t *offset_int_addr,
+ float *dst_addr, int deal_size) {
+ for (auto i = 0; i < deal_size; i++) {
+ dst_addr[i] = src_addr[offset_int_addr[i]];
+ }
+}
+
+/*
+ generate rfft DFT matrix function: rfft result is FFT_HALF(n)
+ [input] n: fft length
+ [output] output: generated rfft matrix data
+ Matrix size: [FFT_HALF(n) * 2 * n]
+ Data:
+ forward: -2.0 * M_PI / n
+ cos00 cos01 ... cos0(n-1)
+ sin00 sin01 ... sin0(n-1)
+ cos10 cos11 ... cos1(n-1)
+ sin10 sin11 ... sin1(n-1)
+ ... ... ... ...
+ cos(FFT_HALF(n)-1)0 cos(FFT_HALF(n)-1)1 ... cos(FFT_HALF(n)-1)(n-1)
+ sin(FFT_HALF(n)-1)0 sin(FFT_HALF(n)-1)1 ... sin(FFT_HALF(n)-1)(n-1)
+ */
+template <typename DT>
+__mlu_func__ void generateRFFTHalfDFTMatrixImpl(int n, void *output) {
+ int deal_size = std::min(MAX_NRAM_SIZE >> 5, n);
+ deal_size = PAD_UP(deal_size, PAD_SIZE);
+ const int row = FFT_HALF(n);
+ const int col = n;
+ int pad_col = PAD_UP(col, PAD_SIZE);
+
+ float *inc_addr = (float *)nram_buffer;
+ float *cos_addr = inc_addr + deal_size;
+ float *sin_addr = cos_addr + deal_size;
+ float *offset_addr = sin_addr + deal_size;
+ int32_t *offset_int_addr = (int32_t *)offset_addr;
+ float *temp_addr = offset_addr + deal_size;
+ float *row_addr = temp_addr;
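+  // temp_addr and row_addr alias on purpose: the argmax/mod scratch space is
+  // reused to stage each selected row before it is copied out to GDRAM.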
+
+ // generate 0 to n indices
+ __mluop_get_indices(inc_addr, (float)0.0, deal_size);
+
+ // generate sin and cos vectors
+ const float scale = -2.0 * M_PI / n;
+ __memcpy(offset_addr, inc_addr, deal_size * sizeof(float), NRAM2NRAM);
+ __bang_mul_scalar(offset_addr, offset_addr, scale, deal_size);
+
+ genSinCosVec(offset_addr, sin_addr, cos_addr, deal_size);
+
+ for (auto row_i = taskId; row_i < row; row_i += taskDim) {
+ // generate offsets
+ __memcpy(offset_addr, inc_addr, deal_size * sizeof(float), NRAM2NRAM);
+ __bang_mul_scalar(offset_addr, offset_addr, (float)row_i, deal_size);
+ __mluop_mod(offset_addr, temp_addr, (float)n, deal_size);
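+    // W[row_i][j] depends only on (row_i * j) mod n, so each row is gathered
+    // from the n precomputed sin/cos samples instead of recomputing them.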
+
+ genSelectOffsetVec(offset_addr, offset_int_addr, pad_col);
+
+ // select cos result
+ selectVec(cos_addr, offset_int_addr, row_addr, col);
+ convert((DT *)row_addr, row_addr, deal_size);
+
+ // save cos result
+ DT *dst_addr = (DT *)output + 2 * row_i * col;
+ __memcpy(dst_addr, row_addr, col * sizeof(DT), NRAM2GDRAM);
+
+ // select sin result
+ selectVec(sin_addr, offset_int_addr, row_addr, col);
+ convert((DT *)row_addr, row_addr, deal_size);
+
+ // save sin result
+ dst_addr = (DT *)output + (2 * row_i + 1) * col;
+ __memcpy(dst_addr, row_addr, col * sizeof(DT), NRAM2GDRAM);
+ }
+}
+
+__mlu_global__ void generateRFFTHalfDFTMatrix(mluOpDataType_t data_type, int n,
+ void *output) {
+ switch (data_type) {
+ case (MLUOP_DTYPE_HALF): {
+      generateRFFTHalfDFTMatrixImpl<half>(n, output);
+ break;
+ };
+ case (MLUOP_DTYPE_FLOAT): {
+      generateRFFTHalfDFTMatrixImpl<float>(n, output);
+ break;
+ };
+ default: {
+ MLULOG("Not Implemented.");
+ }
+ }
+}
+
+/*
+ generate rfft DFT matrix function: rfft result is n
+ [input] n: fft length
+ [output] output: generated rfft matrix data
+ Matrix size: [2 * n * n]
+ Data:
+ forward: -2.0 * M_PI / n
+ cos00 cos01 ... cos0(n-1)
+ cos10 cos11 ... cos1(n-1)
+ ... ... ... ...
+ cos(n-1)0 cos(n-1)1 ... cos(n-1)(n-1)
+ sin00 sin01 ... sin0(n-1)
+ sin10 sin11 ... sin1(n-1)
+ ... ... ... ...
+ sin(n-1)0 sin(n-1)1 ... sin(n-1)(n-1)
+ */
+template <typename DT>
+__mlu_func__ void generateRFFTFullDFTMatrixImpl(int row, int n, void *output) {
+ int deal_size = std::min(MAX_NRAM_SIZE >> 5, n);
+ deal_size = PAD_UP(deal_size, PAD_SIZE);
+ const int col = n;
+ int pad_col = PAD_UP(col, PAD_SIZE);
+
+ float *inc_addr = (float *)nram_buffer;
+ float *cos_addr = inc_addr + deal_size;
+ float *sin_addr = cos_addr + deal_size;
+ float *offset_addr = sin_addr + deal_size;
+ int32_t *offset_int_addr = (int32_t *)offset_addr;
+ float *temp_addr = offset_addr + deal_size;
+ float *row_addr = temp_addr;
+
+ // generate 0 to n indices
+ __mluop_get_indices(inc_addr, (float)0.0, deal_size);
+
+ // generate sin and cos vectors
+ const float scale = -2.0 * M_PI / n;
+ __memcpy(offset_addr, inc_addr, deal_size * sizeof(float), NRAM2NRAM);
+ __bang_mul_scalar(offset_addr, offset_addr, scale, deal_size);
+
+ genSinCosVec(offset_addr, sin_addr, cos_addr, deal_size);
+
+ for (auto row_i = taskId; row_i < row; row_i += taskDim) {
+ // generate offsets
+ __memcpy(offset_addr, inc_addr, deal_size * sizeof(float), NRAM2NRAM);
+ __bang_mul_scalar(offset_addr, offset_addr, (float)row_i, deal_size);
+ __mluop_mod(offset_addr, temp_addr, (float)n, deal_size);
+
+ genSelectOffsetVec(offset_addr, offset_int_addr, pad_col);
+
+ // select cos result
+ selectVec(cos_addr, offset_int_addr, row_addr, col);
+ convert((DT *)row_addr, row_addr, deal_size);
+
+ // save cos result
+ DT *dst_addr = (DT *)output + row_i * col;
+ __memcpy(dst_addr, row_addr, col * sizeof(DT), NRAM2GDRAM);
+
+ // select sin result
+ selectVec(sin_addr, offset_int_addr, row_addr, col);
+ convert((DT *)row_addr, row_addr, deal_size);
+
+ // save sin result
+ dst_addr = (DT *)output + (row_i + row) * col;
+ __memcpy(dst_addr, row_addr, col * sizeof(DT), NRAM2GDRAM);
+ }
+}
+
+__mlu_global__ void generateRFFTFullDFTMatrix(mluOpDataType_t data_type,
+ int row, int n, void *output) {
+ switch (data_type) {
+ case (MLUOP_DTYPE_HALF): {
+ generateRFFTFullDFTMatrixImpl<half>(row, n, output);
+ break;
+ };
+ case (MLUOP_DTYPE_FLOAT): {
+ generateRFFTFullDFTMatrixImpl<float>(row, n, output);
+ break;
+ };
+ default: {
+ MLULOG("Not Implemented.");
+ }
+ }
+}
+
+/*
+ generate irfft DFT matrix function: irfft input is FFT_HALF(n)
+ [input] n: fft length
+ [output] output: generated irfft matrix data
+ Matrix size: [2 * n * FFT_HALF(n)]
+ Data:
+ backward: 2.0 * M_PI / n
+ cos_coeff: [ 1, 2, 2, ..., 2, 1]
+ sin_coeff: [-1, -2, -2, ..., -2, -1]
+ cos00 cos01 ... cos0(FFT_HALF(n)-1)
+ cos10 cos11 ... cos1(FFT_HALF(n)-1)
+ ... ... ... ...
+ cos(n-1)0 cos(n-1)1 ... cos(n-1)(FFT_HALF(n)-1)
+ sin00 sin01 ... sin0(FFT_HALF(n)-1)
+ sin10 sin11 ... sin1(FFT_HALF(n)-1)
+ ... ... ... ...
+ sin(n-1)0 sin(n-1)1 ... sin(n-1)(FFT_HALF(n)-1)
+ */
+template <typename DT>
+__mlu_func__ void generateIRFFTHalfDFTMatrixImpl(int n, void *output) {
+ int deal_size = std::min(MAX_NRAM_SIZE >> 5, n);
+ deal_size = PAD_UP(deal_size, PAD_SIZE);
+ const int row = n;
+ const int col = FFT_HALF(n);
+ int pad_col = PAD_UP(col, PAD_SIZE);
+
+ float *inc_addr = (float *)nram_buffer;
+ float *cos_addr = inc_addr + deal_size;
+ float *sin_addr = cos_addr + deal_size;
+ float *cos_coeff_addr = sin_addr + deal_size;
+ float *sin_coeff_addr = cos_coeff_addr + deal_size;
+ float *offset_addr = sin_coeff_addr + deal_size;
+ int32_t *offset_int_addr = (int32_t *)offset_addr;
+ float *temp_addr = offset_addr + deal_size;
+ float *row_addr = temp_addr;
+
+ // generate 0 to n indices
+ __mluop_get_indices(inc_addr, (float)0.0, deal_size);
+
+ // generate sin and cos coefficient vectors
+ __bang_write_value((float *)cos_coeff_addr, deal_size, (float)2.0);
+ __bang_write_value((float *)sin_coeff_addr, deal_size, (float)-2.0);
+ cos_coeff_addr[0] = 1.0;
+ sin_coeff_addr[0] = -1.0;
+ cos_coeff_addr[(n + 1) / 2] = 1.0;
+ sin_coeff_addr[(n + 1) / 2] = -1.0;
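+ // The DC bin (k = 0) and, for even n, the Nyquist bin (k = n / 2) have no
+ // conjugate twin in the half-spectrum input, so they are weighted once;
+ // every other bin is folded with its conjugate and weighted twice.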
+
+ // generate sin and cos vectors
+ const float scale = 2.0 * M_PI / n;
+ __memcpy(offset_addr, inc_addr, deal_size * sizeof(float), NRAM2NRAM);
+ __bang_mul_scalar(offset_addr, offset_addr, scale, deal_size);
+
+ genSinCosVec(offset_addr, sin_addr, cos_addr, deal_size);
+
+ for (auto row_i = taskId; row_i < row; row_i += taskDim) {
+ // generate offsets
+ __memcpy(offset_addr, inc_addr, pad_col * sizeof(float), NRAM2NRAM);
+ __bang_mul_scalar(offset_addr, offset_addr, (float)row_i, pad_col);
+ __mluop_mod(offset_addr, temp_addr, (float)n, pad_col);
+
+ genSelectOffsetVec(offset_addr, offset_int_addr, pad_col);
+
+ // select cos result
+ selectVec(cos_addr, offset_int_addr, row_addr, col);
+ __bang_mul(row_addr, row_addr, cos_coeff_addr, pad_col);
+ convert((DT *)row_addr, row_addr, pad_col);
+
+ // save cos result
+ DT *dst_addr = (DT *)output + row_i * col;
+ __memcpy(dst_addr, row_addr, col * sizeof(DT), NRAM2GDRAM);
+
+ // select sin result
+ selectVec(sin_addr, offset_int_addr, row_addr, col);
+ __bang_mul(row_addr, row_addr, sin_coeff_addr, pad_col);
+ convert((DT *)row_addr, row_addr, pad_col);
+
+ // save sin result
+ dst_addr = (DT *)output + (row_i + row) * col;
+ __memcpy(dst_addr, row_addr, col * sizeof(DT), NRAM2GDRAM);
+ }
+}
+
+__mlu_global__ void generateIRFFTHalfDFTMatrix(mluOpDataType_t data_type, int n,
+ void *output) {
+ switch (data_type) {
+ case (MLUOP_DTYPE_HALF): {
+ generateIRFFTHalfDFTMatrixImpl<half>(n, output);
+ break;
+ };
+ case (MLUOP_DTYPE_FLOAT): {
+ generateIRFFTHalfDFTMatrixImpl<float>(n, output);
+ break;
+ };
+ default: {
+ MLULOG("Not Implemented.");
+ }
+ }
+}
+
+/*
+ generate irfft DFT matrix function: irfft input is padded to n
+ [input] n: fft length
+ [output] output: generated irfft matrix data
+ Matrix size: [2 * n * n]
+ Data:
+ backward: 2.0 * M_PI / n
+ cos00 cos01 ... cos0(n-1)
+ cos10 cos11 ... cos1(n-1)
+ ... ... ... ...
+ cos(n-1)0 cos(n-1)1 ... cos(n-1)(n-1)
+ sin00 sin01 ... sin0(n-1)
+ sin10 sin11 ... sin1(n-1)
+ ... ... ... ...
+ sin(n-1)0 sin(n-1)1 ... sin(n-1)(n-1)
+ */
+template <typename DT>
+__mlu_func__ void generateIRFFTFullDFTMatrixImpl(int n, void *output) {
+ int deal_size = std::min(MAX_NRAM_SIZE >> 5, n);
+ deal_size = PAD_UP(deal_size, PAD_SIZE);
+ const int row = n;
+ const int col = n;
+ int pad_col = PAD_UP(col, PAD_SIZE);
+
+ float *inc_addr = (float *)nram_buffer;
+ float *cos_addr = inc_addr + deal_size;
+ float *sin_addr = cos_addr + deal_size;
+ float *offset_addr = sin_addr + deal_size;
+ int32_t *offset_int_addr = (int32_t *)offset_addr;
+ float *temp_addr = offset_addr + deal_size;
+ float *row_addr = temp_addr;
+
+ // generate 0 to n indices
+ __mluop_get_indices(inc_addr, (float)0.0, deal_size);
+
+ // generate sin and cos vectors
+ const float scale = 2.0 * M_PI / n;
+ __memcpy(offset_addr, inc_addr, deal_size * sizeof(float), NRAM2NRAM);
+ __bang_mul_scalar(offset_addr, offset_addr, scale, deal_size);
+
+ genSinCosVec(offset_addr, sin_addr, cos_addr, deal_size);
+
+ for (auto row_i = taskId; row_i < row; row_i += taskDim) {
+ // generate offsets
+ __memcpy(offset_addr, inc_addr, pad_col * sizeof(float), NRAM2NRAM);
+ __bang_mul_scalar(offset_addr, offset_addr, (float)row_i, pad_col);
+ __mluop_mod(offset_addr, temp_addr, (float)n, pad_col);
+
+ genSelectOffsetVec(offset_addr, offset_int_addr, pad_col);
+
+ // select cos result
+ selectVec(cos_addr, offset_int_addr, row_addr, col);
+ convert((DT *)row_addr, row_addr, pad_col);
+
+ // save cos result
+ DT *dst_addr = (DT *)output + row_i * col;
+ __memcpy(dst_addr, row_addr, col * sizeof(DT), NRAM2GDRAM);
+
+ // select sin result
+ selectVec(sin_addr, offset_int_addr, row_addr, col);
+ convert((DT *)row_addr, row_addr, pad_col);
+
+ // save sin result
+ dst_addr = (DT *)output + (row_i + row) * col;
+ __memcpy(dst_addr, row_addr, col * sizeof(DT), NRAM2GDRAM);
+ }
+}
+
+__mlu_global__ void generateIRFFTFullDFTMatrix(mluOpDataType_t data_type, int n,
+ void *output) {
+ switch (data_type) {
+ case (MLUOP_DTYPE_HALF): {
+ generateIRFFTFullDFTMatrixImpl<half>(n, output);
+ break;
+ };
+ case (MLUOP_DTYPE_FLOAT): {
+ generateIRFFTFullDFTMatrixImpl<float>(n, output);
+ break;
+ };
+ default: {
+ MLULOG("Not Implemented.");
+ }
+ }
+}
+
+/*
+ generate c2c fft DFT forward and backward matrix function:
+ [input] n: fft length
+ [output] output: generated fft/ifft matrix data
+ Matrix size: [2 * 2 * n * n]
+ Data:
+ forward: -2.0 * M_PI / n
+ cos00 cos01 ... cos0(n-1)
+ cos10 cos11 ... cos1(n-1)
+ ... ... ... ...
+ cos(n-1)0 cos(n-1)1 ... cos(n-1)(n-1)
+ sin00 sin01 ... sin0(n-1)
+ sin10 sin11 ... sin1(n-1)
+ ... ... ... ...
+ sin(n-1)0 sin(n-1)1 ... sin(n-1)(n-1)
+ backward: 2.0 * M_PI / n
+ cos00 cos01 ... cos0(n-1)
+ cos10 cos11 ... cos1(n-1)
+ ... ... ... ...
+ cos(n-1)0 cos(n-1)1 ... cos(n-1)(n-1)
+ sin00 sin01 ... sin0(n-1)
+ sin10 sin11 ... sin1(n-1)
+ ... ... ... ...
+ sin(n-1)0 sin(n-1)1 ... sin(n-1)(n-1)
+ */
+template <typename DT>
+__mlu_func__ void generateC2CFFTDFTMatrixImpl(int n, void *output) {
+ int deal_size = std::min(MAX_NRAM_SIZE >> 5, n);
+ deal_size = PAD_UP(deal_size, PAD_SIZE);
+ const int row = n;
+ const int col = n;
+ int pad_col = PAD_UP(col, PAD_SIZE);
+
+ float *inc_addr = (float *)nram_buffer;
+ float *forward_cos_addr = inc_addr + deal_size;
+ float *forward_sin_addr = forward_cos_addr + deal_size;
+ float *backward_cos_addr = forward_sin_addr + deal_size;
+ float *backward_sin_addr = backward_cos_addr + deal_size;
+ float *offset_addr = backward_sin_addr + deal_size;
+ int32_t *offset_int_addr = (int32_t *)offset_addr;
+ float *temp_addr = offset_addr + deal_size;
+ float *row_addr = temp_addr;
+
+ // generate 0 to n indices
+ __mluop_get_indices(inc_addr, (float)0.0, deal_size);
+
+ // generate sin and cos vectors
+ const float forward_scale = -2.0 * M_PI / n;
+ __memcpy(offset_addr, inc_addr, deal_size * sizeof(float), NRAM2NRAM);
+ __bang_mul_scalar(offset_addr, offset_addr, forward_scale, deal_size);
+
+ genSinCosVec(offset_addr, forward_sin_addr, forward_cos_addr, deal_size);
+
+ const float backward_scale = 2.0 * M_PI / n;
+ __memcpy(offset_addr, inc_addr, deal_size * sizeof(float), NRAM2NRAM);
+ __bang_mul_scalar(offset_addr, offset_addr, backward_scale, deal_size);
+
+ genSinCosVec(offset_addr, backward_sin_addr, backward_cos_addr, deal_size);
+
+ for (auto row_i = taskId; row_i < row; row_i += taskDim) {
+ // generate offsets
+ __memcpy(offset_addr, inc_addr, pad_col * sizeof(float), NRAM2NRAM);
+ __bang_mul_scalar(offset_addr, offset_addr, (float)row_i, pad_col);
+ __mluop_mod(offset_addr, temp_addr, (float)n, pad_col);
+
+ genSelectOffsetVec(offset_addr, offset_int_addr, pad_col);
+
+ // select forward cos result
+ selectVec(forward_cos_addr, offset_int_addr, row_addr, col);
+ convert((DT *)row_addr, row_addr, pad_col);
+
+ // save forward cos result
+ DT *dst_addr = (DT *)output + row_i * col;
+ __memcpy(dst_addr, row_addr, col * sizeof(DT), NRAM2GDRAM);
+
+ // select forward sin result
+ selectVec(forward_sin_addr, offset_int_addr, row_addr, col);
+ convert((DT *)row_addr, row_addr, pad_col);
+
+ // save forward sin result
+ dst_addr = (DT *)output + (row_i + row) * col;
+ __memcpy(dst_addr, row_addr, col * sizeof(DT), NRAM2GDRAM);
+
+ // select backward cos result
+ selectVec(backward_cos_addr, offset_int_addr, row_addr, col);
+ convert((DT *)row_addr, row_addr, pad_col);
+
+ // save backward cos result
+ dst_addr = (DT *)output + (row_i + 2 * row) * col;
+ __memcpy(dst_addr, row_addr, col * sizeof(DT), NRAM2GDRAM);
+
+ // select backward sin result
+ selectVec(backward_sin_addr, offset_int_addr, row_addr, col);
+ convert((DT *)row_addr, row_addr, pad_col);
+
+ // save backward sin result
+ dst_addr = (DT *)output + (row_i + 3 * row) * col;
+ __memcpy(dst_addr, row_addr, col * sizeof(DT), NRAM2GDRAM);
+ }
+}
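+
+// The four n x n blocks are stacked as [fwd_cos; fwd_sin; bwd_cos; bwd_sin],
+// so an execution kernel can address the forward cos/sin pair starting at
+// element offset 0 and the backward pair starting at offset 2 * n * n.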
+
+__mlu_global__ void generateC2CFFTDFTMatrix(mluOpDataType_t data_type, int n,
+ void *output) {
+ switch (data_type) {
+ case (MLUOP_DTYPE_HALF): {
+ generateC2CFFTDFTMatrixImpl<half>(n, output);
+ break;
+ };
+ case (MLUOP_DTYPE_FLOAT): {
+ generateC2CFFTDFTMatrixImpl<float>(n, output);
+ break;
+ };
+ default: {
+ MLULOG("Not Implemented.");
+ }
+ }
+}
+
+mluOpStatus_t MLUOP_WIN_API kernelC2CFFTDFTMatrix(
+ cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
+ mluOpFFTPlan_t fft_plan, mluOpDataType_t in_r_dtype, int n) {
+ VLOG(5) << "Launch Kernel generateC2CFFTDFTMatrix<>>";
+ KERNEL_CHECK((generateC2CFFTDFTMatrix<<>>(
+ in_r_dtype, n, fft_plan->matmul_addrs.dft_matrix_addr)));
+ return MLUOP_STATUS_SUCCESS;
+}
+
+mluOpStatus_t MLUOP_WIN_API kernelGenerateRFFTHalfDFTMatrix(
+ cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
+ mluOpFFTPlan_t fft_plan, mluOpDataType_t in_r_dtype, int n) {
+ VLOG(5) << "Launch Kernel generateRFFTHalfDFTMatrix";
+ KERNEL_CHECK((generateRFFTHalfDFTMatrix<<<k_dim, k_type, queue>>>(
+ in_r_dtype, n, fft_plan->matmul_addrs.dft_matrix_addr)));
+ return MLUOP_STATUS_SUCCESS;
+}
+
+mluOpStatus_t MLUOP_WIN_API kernelGenerateRFFTFullDFTMatrix(
+ cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
+ mluOpFFTPlan_t fft_plan, mluOpDataType_t in_r_dtype, int row, int n) {
+ VLOG(5) << "Launch Kernel generateRFFTFullDFTMatrix";
+ KERNEL_CHECK((generateRFFTFullDFTMatrix<<<k_dim, k_type, queue>>>(
+ in_r_dtype, row, n, fft_plan->matmul_addrs.dft_matrix_addr)));
+ return MLUOP_STATUS_SUCCESS;
+}
+
+mluOpStatus_t MLUOP_WIN_API kernelGenerateIRFFTHalfDFTMatrix(
+ cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
+ mluOpFFTPlan_t fft_plan, mluOpDataType_t in_r_dtype, int n) {
+ VLOG(5) << "Launch Kernel generateIRFFTHalfDFTMatrix";
+ KERNEL_CHECK((generateIRFFTHalfDFTMatrix<<<k_dim, k_type, queue>>>(
+ in_r_dtype, n, fft_plan->matmul_addrs.dft_matrix_addr)));
+ return MLUOP_STATUS_SUCCESS;
+}
+
+mluOpStatus_t MLUOP_WIN_API kernelGenerateIRFFTFullDFTMatrix(
+ cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
+ mluOpFFTPlan_t fft_plan, mluOpDataType_t in_r_dtype, int n) {
+ VLOG(5) << "Launch Kernel kernelGenerateIRFFTFullDFTMatrix";
+ KERNEL_CHECK((generateIRFFTFullDFTMatrix<<<k_dim, k_type, queue>>>(
+ in_r_dtype, n, fft_plan->matmul_addrs.dft_matrix_addr)));
+ return MLUOP_STATUS_SUCCESS;
+}
diff --git a/kernels/fft/fft.cpp b/kernels/fft/fft.cpp
new file mode 100644
index 000000000..9133fefdc
--- /dev/null
+++ b/kernels/fft/fft.cpp
@@ -0,0 +1,573 @@
+/*************************************************************************
+ * Copyright (C) [2024] by Cambricon, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+#include <string>
+#include "kernels/fft/fft.h"
+#include "kernels/fft/rfft/rfft.h"
+#include "kernels/fft/irfft/irfft.h"
+#include "kernels/fft/c2c_fft/c2c_fft.h"
+
+// Maybe using a common function would be a better choice?
+static inline bool supportFloatConv(mluOpHandle_t handle) {
+ switch (handle->arch) {
+ case MLUOP_MLU370:
+ return true;
+ default:
+ return true;
+ }
+}
+
+// Determine whether an optimized strategy (CNFFT_FUNC_STOCKHAM or
+// CNFFT_FUNC_COOLEY_TUKEY) can be used. If so, select the optimal
+// strategy and compute the corresponding parameters.
+mluOpStatus_t selectFFTStrategy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
+ const std::string make_plan_api) {
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ fft_plan->fft_strategy = CNFFT_FUNC_MATMUL;
+ // The basic conditions for entering the optimization.
+ if (fft_plan->n[0] > 4096) {
+ bool find_stockham = false;
+ // CNFFT_FUNC_STOCKHAM optimization currently has more restrictions, as
+ // follows:
+ if (handle->arch >= 300 &&
+ (fft_plan->execution_dtype == MLUOP_DTYPE_HALF ||
+ fft_plan->execution_dtype == MLUOP_DTYPE_FLOAT)) {
+ find_stockham = true;
+ }
+ // strategy_status: 0 means select CNFFT_FUNC_STOCKHAM, 1 means select
+ // CNFFT_FUNC_COOLEY_TUKEY,
+ // -1 means still select CNFFT_FUNC_MATMUL.
+ int strategy_status =
+ findFFTOptLimit(handle, fft_plan->n[0], fft_plan->m, fft_plan->L,
+ fft_plan->s, fft_plan->L_sub, find_stockham);
+ if (strategy_status == 1) {
+ fft_plan->fft_strategy = CNFFT_FUNC_COOLEY_TUKEY;
+ } else if (strategy_status == 0) {
+ fft_plan->fft_strategy = CNFFT_FUNC_STOCKHAM;
+ }
+ }
+ return status;
+}
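+
+// Illustrative examples of the selection above (the final choice is still up
+// to findFFTOptLimit): n[0] = 4096 fails the `> 4096` gate and keeps
+// CNFFT_FUNC_MATMUL; n[0] = 8192 = 1 * 2^13 on an arch >= 300 with a
+// half/float execution dtype becomes a candidate for CNFFT_FUNC_STOCKHAM or
+// CNFFT_FUNC_COOLEY_TUKEY.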
+
+mluOpStatus_t MLUOP_WIN_API mluOpCreateFFTPlan(mluOpFFTPlan_t *fft_plan) {
+ mluOpFFTStruct *ts = new (std::nothrow) mluOpFFTStruct();
+ if (ts == nullptr) {
+ LOG(ERROR) << "[mluOpCreateFFTPlan]: alloc failed";
+ return MLUOP_STATUS_ALLOC_FAILED;
+ }
+ *fft_plan = ts;
+ return MLUOP_STATUS_SUCCESS;
+}
+
+/**
+ * This function
+ * 1. receives parameters from the user;
+ * 2. checks the validity of the parameters;
+ * 3. picks out the supported circumstances;
+ * 4. decides which strategy should be used.
+ */
+mluOpStatus_t MLUOP_WIN_API mluOpMakeFFTPlanMany(
+ mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
+ mluOpTensorDescriptor_t input_desc, mluOpTensorDescriptor_t output_desc,
+ const int rank, const int *n, size_t *reservespace_size,
+ size_t *workspace_size) {
+ // bad param check
+ const std::string make_plan_api = "[mluOpMakeFFTPlanMany]";
+ // plan NULL check
+ PARAM_CHECK_NE(make_plan_api, handle, NULL);
+ if (fft_plan == NULL) {
+ LOG(ERROR) << make_plan_api + ": plan is not allocated.";
+ return MLUOP_STATUS_NOT_INITIALIZED;
+ }
+ PARAM_CHECK_NE(make_plan_api, input_desc, NULL);
+ PARAM_CHECK_NE(make_plan_api, output_desc, NULL);
+ PARAM_CHECK_NE(make_plan_api, n, NULL);
+ PARAM_CHECK_NE(make_plan_api, reservespace_size, NULL);
+ PARAM_CHECK_NE(make_plan_api, workspace_size, NULL);
+
+ // plan rank can only be 1, 2, 3
+ if (rank < 1 || rank > FFT_DIM_MAX) {
+ LOG(ERROR) << make_plan_api + ": invalid rank, should be 1, 2 or 3. Now is "
+ << rank << ".";
+ return MLUOP_STATUS_BAD_PARAM;
+ }
+ for (auto i = 0; i < rank; i++) {
+ if (n[i] <= 0) {
+ LOG(ERROR)
+ << make_plan_api +
+ ": fourier transform length should be greater than 0. Now n["
+ << i << "] is " << n[i] << ".";
+ return MLUOP_STATUS_BAD_PARAM;
+ }
+ }
+ fft_plan->rank = rank;
+ for (auto i = 0; i < rank; i++) {
+ fft_plan->n[i] = n[i];
+ }
+
+ // dimension check
+ fft_plan->idim = input_desc->dim;
+ fft_plan->odim = output_desc->dim;
+ fft_plan->inum = mluOpGetTensorElementNum(input_desc);
+ fft_plan->onum = mluOpGetTensorElementNum(output_desc);
+ PARAM_CHECK_GT(make_plan_api, input_desc->dim, 0);
+ PARAM_CHECK_EQ(make_plan_api, fft_plan->idim, fft_plan->odim,
+ ": input and output dimension mismatch.");
+
+ if (!(fft_plan->idim == rank || fft_plan->idim == rank + 1)) {
+ LOG(ERROR) << make_plan_api << ": invalid input dimension, should be "
+ << rank << " or " << rank + 1 << ". Now is " << fft_plan->idim
+ << ".";
+ return MLUOP_STATUS_BAD_PARAM;
+ }
+
+ // batch check
+ if (fft_plan->idim == rank) {
+ fft_plan->batch = 1;
+ } else { // idim == rank + 1
+ fft_plan->batch = input_desc->dims[0];
+ PARAM_CHECK_EQ(make_plan_api, fft_plan->batch, output_desc->dims[0],
+ ": batch size mismatch.");
+ }
+
+ // The FFT Struct is designed after cufftXtMakePlanMany.
+ // An element of coordinates [z, y, x] in signal number b in the batch will
+ // be associated with the following addresses in the memory
+ // 1-D:
+ // input[b * idist + x * istride]
+ // output[b * odist + x * ostride]
+ // 2-D:
+ // input[b * idist + (x * inembed[1] + y) * istride]
+ // output[b * odist + (x * onembed[1] + y) * ostride]
+ // 3-D:
+ // input[b * idist + ((x * inembed[1] + y) * inembed[2] + z) * istride]
+ // output[b * odist + ((x * onembed[1] + y) * onembed[2] + z) * ostride]
+ // Thus, the cufft and fftw advanced data layouts are a subset of the mluOp
+ // advanced data layout with tensor dim strides; 2-D and 3-D need extra care.
+ // Stride check: if an in-place fft is adopted, `istride` should be equal
+ // to `ostride`.
+ fft_plan->istride = input_desc->strides[fft_plan->idim - 1];
+ fft_plan->ostride = output_desc->strides[fft_plan->odim - 1];
+
+ PARAM_CHECK_GE(make_plan_api, fft_plan->istride, 0,
+ ": input stride should be greater than or equal to 0.");
+ PARAM_CHECK_GE(make_plan_api, fft_plan->ostride, 0,
+ ": output stride should be greater than or equal to 0.");
+
+ for (auto i = 0; i < fft_plan->rank; i++) {
+ fft_plan->inembed[i] = input_desc->dims[fft_plan->idim - rank + i];
+ fft_plan->onembed[i] = output_desc->dims[fft_plan->odim - rank + i];
+ }
+ if (fft_plan->idim == rank + 1) {
+ fft_plan->idist = input_desc->strides[0];
+ fft_plan->odist = output_desc->strides[0];
+ } else { // batch == 1
+ fft_plan->idist = mluOpGetTensorElementNum(input_desc) / fft_plan->batch;
+ fft_plan->odist = mluOpGetTensorElementNum(output_desc) / fft_plan->batch;
+ }
+ fft_plan->is_input_contiguous = !mluop::ifNeedTensorStrideProcess(input_desc);
+ fft_plan->is_output_contiguous =
+ !mluop::ifNeedTensorStrideProcess(output_desc);
+
+ // dtype check
+ mluOpDataType_t input_dtype = input_desc->dtype;
+ mluOpDataType_t output_dtype = output_desc->dtype;
+ const mluOpDataType_t f_c_dtype = MLUOP_DTYPE_COMPLEX_FLOAT;
+ const mluOpDataType_t f_r_dtype = MLUOP_DTYPE_FLOAT;
+ const mluOpDataType_t hf_c_dtype = MLUOP_DTYPE_COMPLEX_HALF;
+ const mluOpDataType_t hf_r_dtype = MLUOP_DTYPE_HALF;
+ if (input_dtype == hf_r_dtype && output_dtype == hf_c_dtype) {
+ fft_plan->fft_type = CNFFT_HALF2COMPLEX_HALF;
+ } else if (input_dtype == hf_c_dtype && output_dtype == hf_c_dtype) {
+ fft_plan->fft_type = CNFFT_COMPLEX_HALF2COMPLEX_HALF;
+ } else if (input_dtype == hf_c_dtype && output_dtype == hf_r_dtype) {
+ fft_plan->fft_type = CNFFT_COMPLEX_HALF2HALF;
+ } else if (input_dtype == f_r_dtype && output_dtype == f_c_dtype) {
+ fft_plan->fft_type = CNFFT_FLOAT2COMPLEX_FLOAT;
+ } else if (input_dtype == f_c_dtype && output_dtype == f_c_dtype) {
+ fft_plan->fft_type = CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT;
+ } else if (input_dtype == f_c_dtype && output_dtype == f_r_dtype) {
+ fft_plan->fft_type = CNFFT_COMPLEX_FLOAT2FLOAT;
+ } else {
+ LOG(ERROR) << make_plan_api
+ << ": invalid data type combination. Now input data type is "
+ << mluOpGetNameOfDataType(input_dtype)
+ << ", and output data type is "
+ << mluOpGetNameOfDataType(output_dtype) << ".";
+ return MLUOP_STATUS_BAD_PARAM;
+ }
+
+ fft_plan->input_dtype = input_desc->dtype;
+ fft_plan->output_dtype = output_desc->dtype;
+ fft_plan->execution_dtype = input_desc->onchip_dtype;
+
+ VLOG(5) << "input data type: "
+ << mluOpGetNameOfDataType(fft_plan->input_dtype);
+ VLOG(5) << "output data type: "
+ << mluOpGetNameOfDataType(fft_plan->output_dtype);
+ VLOG(5) << "execution data type: "
+ << mluOpGetNameOfDataType(fft_plan->execution_dtype);
+
+ // fft length check
+ for (auto i = 0; i < fft_plan->rank - 1; i++) { // except the last dim
+ PARAM_CHECK_EQ(
+ make_plan_api, n[i], fft_plan->inembed[i],
+ ": the signal lengths of fft and input mismatch in dimention ", i, ".");
+ PARAM_CHECK_EQ(
+ make_plan_api, n[i], fft_plan->onembed[i],
+ ": the signal lengths of fft and output mismatch in dimension ", i,
+ ".");
+ }
+ switch (fft_plan->fft_type) {
+ // r2c
+ case CNFFT_HALF2COMPLEX_HALF:
+ case CNFFT_FLOAT2COMPLEX_FLOAT: {
+ PARAM_CHECK_EQ(make_plan_api, fft_plan->n[rank - 1] / 2 + 1,
+ fft_plan->onembed[rank - 1],
+ ": the signal lengths of fft and output last dimention "
+ "mismatch in R2C.");
+ }; break;
+ // c2c
+ case CNFFT_COMPLEX_HALF2COMPLEX_HALF:
+ case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT: {
+ PARAM_CHECK_EQ(make_plan_api, fft_plan->n[rank - 1],
+ fft_plan->onembed[rank - 1],
+ ": the signal lengths of fft and output last dimention "
+ "mismatch in C2C.");
+ }; break;
+ // c2r
+ case CNFFT_COMPLEX_HALF2HALF:
+ case CNFFT_COMPLEX_FLOAT2FLOAT: {
+ PARAM_CHECK_EQ(make_plan_api, fft_plan->n[rank - 1],
+ fft_plan->onembed[rank - 1],
+ ": the signal lengths of fft and output last dimention "
+ "mismatch in C2R.");
+ }; break;
+ default: {
+ LOG(ERROR) << make_plan_api << ": invalid fft type.";
+ return MLUOP_STATUS_BAD_PARAM;
+ }
+ }
+
+ mluOpDataType_t execution_dtype = fft_plan->execution_dtype;
+ switch (fft_plan->fft_type) {
+ // half
+ case CNFFT_HALF2COMPLEX_HALF:
+ case CNFFT_COMPLEX_HALF2COMPLEX_HALF:
+ case CNFFT_COMPLEX_HALF2HALF: {
+ if (supportFloatConv(handle)) {
+ if (!(execution_dtype == hf_r_dtype ||
+ execution_dtype == MLUOP_DTYPE_INT16)) {
+ LOG(ERROR) << make_plan_api << ": invalid execution dtype "
+ << mluOpGetNameOfDataType(fft_plan->execution_dtype)
+ << ".";
+ return MLUOP_STATUS_BAD_PARAM;
+ }
+ } else {
+ if (!(execution_dtype == MLUOP_DTYPE_INT16)) {
+ LOG(ERROR) << make_plan_api << ": invalid execution dtype "
+ << mluOpGetNameOfDataType(fft_plan->execution_dtype)
+ << ".";
+ return MLUOP_STATUS_BAD_PARAM;
+ }
+ }
+ }; break;
+ // float
+ case CNFFT_FLOAT2COMPLEX_FLOAT:
+ case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT:
+ case CNFFT_COMPLEX_FLOAT2FLOAT: {
+ if (supportFloatConv(handle)) {
+ if (execution_dtype != f_r_dtype) {
+ LOG(ERROR) << make_plan_api << ": invalid execution dtype "
+ << mluOpGetNameOfDataType(fft_plan->execution_dtype)
+ << ".";
+ return MLUOP_STATUS_BAD_PARAM;
+ }
+ }
+ }; break;
+ default: {
+ LOG(ERROR) << make_plan_api << ": invalid execution dtype.";
+ return MLUOP_STATUS_BAD_PARAM;
+ }
+ }
+
+ // unsupported param
+ if (fft_plan->rank != 1) {
+ LOG(ERROR)
+ << make_plan_api
+ << ": 2-dimensional and 3-dimensional FFT are not supported currently.";
+ return MLUOP_STATUS_NOT_SUPPORTED;
+ }
+
+ if (fft_plan->fft_type == CNFFT_HALF2COMPLEX_HALF ||
+ fft_plan->fft_type == CNFFT_COMPLEX_HALF2HALF ||
+ fft_plan->fft_type == CNFFT_COMPLEX_HALF2COMPLEX_HALF) {
+ if ((n[0] & (n[0] - 1)) != 0) {
+ LOG(ERROR) << make_plan_api
+ << ": the signal lengths of half-precision FFT are"
+ << " restriced to power of two only, but now is " << n[0]
+ << ".";
+ return MLUOP_STATUS_NOT_SUPPORTED;
+ }
+ }
+
+ // create input and output descriptors for gen_case,
+ // because mluOpExecFFT does not take input and output descriptors
+ mluOpTensorDescriptor_t fft_input_desc, fft_output_desc;
+ INTERNAL_CHECK(make_plan_api, mluOpCreateTensorDescriptor(&fft_input_desc) ==
+ MLUOP_STATUS_SUCCESS);
+ INTERNAL_CHECK(make_plan_api, mluOpCreateTensorDescriptor(&fft_output_desc) ==
+ MLUOP_STATUS_SUCCESS);
+ INTERNAL_CHECK(make_plan_api,
+ mluOpSetTensorDescriptorEx_v2(
+ fft_input_desc, input_desc->layout, input_desc->dtype,
+ input_desc->dim, input_desc->dims,
+ input_desc->strides) == MLUOP_STATUS_SUCCESS);
+ INTERNAL_CHECK(make_plan_api, mluOpSetTensorDescriptorOnchipDataType(
+ fft_input_desc, input_desc->onchip_dtype) ==
+ MLUOP_STATUS_SUCCESS);
+ INTERNAL_CHECK(make_plan_api,
+ mluOpSetTensorDescriptorEx_v2(
+ fft_output_desc, output_desc->layout, output_desc->dtype,
+ output_desc->dim, output_desc->dims,
+ output_desc->strides) == MLUOP_STATUS_SUCCESS);
+ fft_plan->input_desc = fft_input_desc;
+ fft_plan->output_desc = fft_output_desc;
+
+ /*
+ * decision part
+ */
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ switch (fft_plan->fft_type) {
+ // r2c
+ case CNFFT_HALF2COMPLEX_HALF:
+ case CNFFT_FLOAT2COMPLEX_FLOAT: {
+ if (rank == 1) {
+ VLOG(5) << "into make RFFT1d Policy";
+ status = makeRFFT1dPolicy(handle, fft_plan);
+ }
+ }; break;
+ case CNFFT_COMPLEX_HALF2HALF:
+ case CNFFT_COMPLEX_FLOAT2FLOAT: {
+ if (rank == 1) {
+ VLOG(5) << "into make IRFFT1d Policy";
+ status = makeIRFFT1dPolicy(handle, fft_plan);
+ }
+ }; break;
+ case CNFFT_COMPLEX_HALF2COMPLEX_HALF:
+ case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT: {
+ if (rank == 1) {
+ VLOG(5) << "into make FFT1d Policy";
+ status = makeFFT1dPolicy(handle, fft_plan);
+ }
+ }; break;
+ }
+ if (status != MLUOP_STATUS_SUCCESS) {
+ return status;
+ }
+
+ *reservespace_size = fft_plan->reservespace_size;
+ *workspace_size = fft_plan->workspace_size;
+
+ return MLUOP_STATUS_SUCCESS;
+}
+
+mluOpStatus_t MLUOP_WIN_API mluOpDestroyFFTPlan(mluOpFFTPlan_t fft_plan) {
+ const std::string destroy_api = "[mluOpDestroyFFTPlan]";
+ PARAM_CHECK_NE("[mluOpDestroyFFTPlan]", fft_plan, NULL);
+ if (fft_plan->input_desc != NULL) {
+ INTERNAL_CHECK(destroy_api,
+ mluOpDestroyTensorDescriptor(fft_plan->input_desc) ==
+ MLUOP_STATUS_SUCCESS);
+ }
+ if (fft_plan->output_desc != NULL) {
+ INTERNAL_CHECK(destroy_api,
+ mluOpDestroyTensorDescriptor(fft_plan->output_desc) ==
+ MLUOP_STATUS_SUCCESS);
+ }
+ delete fft_plan;
+ return MLUOP_STATUS_SUCCESS;
+}
+
+mluOpStatus_t MLUOP_WIN_API mluOpSetFFTReserveArea(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan,
+ void *reservespace) {
+ const std::string api = "[mluOpSetFFTReserveArea]";
+ PARAM_CHECK_NE(api, handle, NULL);
+ PARAM_CHECK_NE(api, fft_plan, NULL);
+ PARAM_CHECK_NE(api, reservespace, NULL);
+ fft_plan->reservespace_addr = reservespace;
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+
+ switch (fft_plan->fft_type) {
+ // r2c
+ case CNFFT_HALF2COMPLEX_HALF:
+ case CNFFT_FLOAT2COMPLEX_FLOAT: {
+ if (fft_plan->rank == 1) {
+ status = setRFFT1dReserveArea(handle, fft_plan, api);
+ } else {
+ status = MLUOP_STATUS_NOT_SUPPORTED;
+ }
+ }; break;
+ // c2c
+ case CNFFT_COMPLEX_HALF2COMPLEX_HALF:
+ case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT: {
+ if (fft_plan->rank == 1) {
+ status = setFFT1dReserveArea(handle, fft_plan, api);
+ } else {
+ status = MLUOP_STATUS_NOT_SUPPORTED;
+ }
+ }; break;
+ // c2r
+ case CNFFT_COMPLEX_HALF2HALF:
+ case CNFFT_COMPLEX_FLOAT2FLOAT: {
+ if (fft_plan->rank == 1) {
+ status = setIRFFT1dReserveArea(handle, fft_plan, api);
+ } else {
+ status = MLUOP_STATUS_NOT_SUPPORTED;
+ }
+ }; break;
+ }
+ return status;
+}
+
+mluOpStatus_t MLUOP_WIN_API mluOpExecFFT(
+ mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan, const void *input,
+ const float scale_factor, void *workspace, void *output, int direction) {
+ const std::string exec_api = "[mluOpExecFFT]";
+ PARAM_CHECK_NE(exec_api, handle, NULL);
+ PARAM_CHECK_NE(exec_api, fft_plan, NULL);
+ VLOG(5) << "input contiguous ? " << fft_plan->is_input_contiguous;
+ VLOG(5) << "output contiguous ? " << fft_plan->is_output_contiguous;
+
+ if (fft_plan->batch == 0) {
+ VLOG(5) << "[mluOpExecFFT] Skip zero element tensor";
+ return MLUOP_STATUS_SUCCESS;
+ }
+ // generate mluOpFFTExec prototxt start!
+ {
+ TENSOR_NUM_CHECK("[mluOpFft]",
+ mluOpGetTensorElementNum(fft_plan->input_desc),
+ LARGE_TENSOR_NUM, "");
+ TENSOR_NUM_CHECK("[mluOpFft]",
+ mluOpGetTensorElementNum(fft_plan->output_desc),
+ LARGE_TENSOR_NUM, "");
+ }
+
+ if (MLUOP_GEN_CASE_ON_NEW) {
+ GEN_CASE_START("fft");
+ GEN_CASE_HANDLE(handle);
+ GEN_CASE_DATA(true, "input", input, fft_plan->input_desc, 1, 0);
+ GEN_CASE_DATA(false, "output", output, fft_plan->output_desc, 0, 0);
+ GEN_CASE_OP_PARAM_SINGLE(0, "fft", "rank", fft_plan->rank);
+ GEN_CASE_OP_PARAM_ARRAY(1, "fft", "n", fft_plan->n, fft_plan->rank);
+ GEN_CASE_OP_PARAM_SINGLE(1, "fft", "direction", direction);
+ GEN_CASE_OP_PARAM_SINGLE(2, "fft", "scale_factor", scale_factor);
+ GEN_CASE_TEST_PARAM_NEW(true, true, false, 0.003, 0.003, 0);
+ }
+
+ if (fft_plan->workspace_size > 0) {
+ PARAM_CHECK_NE(exec_api, workspace, NULL);
+ }
+ if (fft_plan->inum > 0) {
+ PARAM_CHECK_NE(exec_api, input, NULL);
+ }
+ PARAM_CHECK_NE(exec_api, output, NULL);
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+
+ bool is_in_place = (input == output);
+ VLOG(5) << exec_api << ": in place ? " << is_in_place;
+ switch (fft_plan->fft_type) {
+ // r2c
+ case CNFFT_HALF2COMPLEX_HALF:
+ case CNFFT_FLOAT2COMPLEX_FLOAT: {
+ if ((fft_plan->idist < (fft_plan->odist * 2)) && is_in_place) {
+ LOG(ERROR)
+ << exec_api
+ << ": output overwritten may occur during an in-place "
+ "real-to-complex fft "
+ "execution, input needs to be slightly padding. Please refer to "
+ "mluOpExecFFT document for detail.";
+ status = MLUOP_STATUS_BAD_PARAM;
+ }
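+ // e.g., an in-place rfft with n = 8 writes n / 2 + 1 = 5 complex values
+ // (10 scalars), so idist must be at least 10 scalars: the length-8 input
+ // row needs 2 scalars of padding.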
+ if (fft_plan->rank == 1) {
+ status = execRFFT1d(handle, fft_plan, input, scale_factor, workspace,
+ output);
+ } else if (fft_plan->rank == 2) {
+ // TODO(who)
+ status = MLUOP_STATUS_NOT_SUPPORTED;
+ } else if (fft_plan->rank == 3) {
+ // TODO(who)
+ status = MLUOP_STATUS_NOT_SUPPORTED;
+ }
+ }; break;
+ // c2c
+ case CNFFT_COMPLEX_HALF2COMPLEX_HALF:
+ case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT: {
+ if ((fft_plan->idist < fft_plan->odist) && is_in_place) {
+ LOG(ERROR)
+ << exec_api
+ << ": output overwritten may occur during an in-place "
+ "complex-to-complex fft "
+ "execution, input needs to be slightly padding. Please refer to "
+ "mluOpExecFFT document for detail.";
+ status = MLUOP_STATUS_BAD_PARAM;
+ }
+ if (fft_plan->rank == 1) {
+ status = execFFT1d(handle, fft_plan, input, scale_factor, workspace,
+ output, direction);
+ } else if (fft_plan->rank == 2) {
+ // TODO(who)
+ status = MLUOP_STATUS_NOT_SUPPORTED;
+ } else if (fft_plan->rank == 3) {
+ // TODO(who)
+ status = MLUOP_STATUS_NOT_SUPPORTED;
+ }
+ }; break;
+ // c2r
+ case CNFFT_COMPLEX_HALF2HALF:
+ case CNFFT_COMPLEX_FLOAT2FLOAT: {
+ if (((fft_plan->idist * 2) < fft_plan->odist) && is_in_place) {
+ LOG(ERROR)
+ << exec_api
+ << ": output overwritten may occur during an in-place "
+ "complex-to-real fft "
+ "execution, input needs to be slightly padding. Please refer to "
+ "mluOpExecFFT document for detail.";
+ status = MLUOP_STATUS_BAD_PARAM;
+ }
+ if (fft_plan->rank == 1) {
+ status = execIRFFT1d(handle, fft_plan, input, scale_factor, workspace,
+ output);
+ } else if (fft_plan->rank == 2) {
+ // TODO(who)
+ status = MLUOP_STATUS_NOT_SUPPORTED;
+ } else if (fft_plan->rank == 3) {
+ // TODO(who)
+ status = MLUOP_STATUS_NOT_SUPPORTED;
+ }
+ }; break;
+ }
+
+ GEN_CASE_END();
+ return status;
+}
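+
+// Illustrative host-side call sequence for the APIs above (a sketch, not part
+// of the operator code; status checks and device allocations are elided, and
+// the forward direction value 0 is an assumption given FFT_INVERSE == 1):
+//   mluOpFFTPlan_t plan;
+//   size_t reservespace_size = 0, workspace_size = 0;
+//   mluOpCreateFFTPlan(&plan);
+//   mluOpMakeFFTPlanMany(handle, plan, input_desc, output_desc, 1 /*rank*/,
+//                        n, &reservespace_size, &workspace_size);
+//   mluOpSetFFTReserveArea(handle, plan, reservespace);  // reservespace_size bytes
+//   mluOpExecFFT(handle, plan, input, 1.0f /*scale_factor*/, workspace,
+//                output, 0 /*direction*/);
+//   mluOpDestroyFFTPlan(plan);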
diff --git a/kernels/fft/fft.h b/kernels/fft/fft.h
new file mode 100644
index 000000000..92fd652bc
--- /dev/null
+++ b/kernels/fft/fft.h
@@ -0,0 +1,237 @@
+/*************************************************************************
+ * Copyright (C) [2024] by Cambricon, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+#ifndef KERNELS_FFT_FFT_H_
+#define KERNELS_FFT_FFT_H_
+
+#include <string>
+#include "core/context.h"
+#include "core/logging.h"
+#include "core/gen_case.h"
+#include "core/runtime/device.h"
+#include "core/tensor.h"
+#include "core/type.h"
+#include "core/tool.h"
+#include "kernels/tensor_stride_process/tensor_stride_process_host.h"
+#include "kernels/tensor_stride_process/tensor_stride_process_mlu.h"
+#include "kernels/fft/common/fft_basic_ops.h"
+#include "kernels/fft/common/fft_common_kernels.h"
+#include "kernels/debug.h"
+#include "kernels/kernel.h"
+
+#ifndef FFT_DIM_MAX
+#define FFT_DIM_MAX 3
+#endif
+
+#ifndef FFT_L_LIMIT
+#define FFT_L_LIMIT 4096
+#endif
+
+#ifndef COMPLEX
+#define COMPLEX 2
+#endif
+
+#ifndef FFT_HALF
+#define FFT_HALF(x) ((x) / 2 + 1)
+#endif
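+
+// e.g., FFT_HALF(8) == 5 and FFT_HALF(7) == 4: the number of unique complex
+// bins in the rfft of a length-n real signal.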
+
+typedef enum {
+ FFT_IFFT = 0,
+ RFFT = 1,
+ IRFFT = 2,
+} FFTFlag;
+
+typedef enum {
+ CNFFT_FUNC_MATMUL =
+ 0, // direct matmul strategy, intended for multiple batches of
+ // transforms whose output size is relatively small. Its structure is
+ // suitable for tensor computing-oriented machines.
+ CNFFT_FUNC_STOCKHAM =
+ 1, // an iterative FFT algorithm for n = r^l. It is self-sorting (does
+ // not have a digit reversal permutation). Its structure is suitable
+ // for long vector computing machines.
+ CNFFT_FUNC_FOUR_STEP =
+ 2, // a recursive FFT algorithm for n = km. It is built from two stages
+ // of vector FFTs, the twiddle diagonal and a transposition. Its
+ // structure is suitable for vector computers.
+ CNFFT_FUNC_BLUESTEIN =
+ 3, // a general-purpose algorithm (i.e., n is a prime number).
+
+ CNFFT_FUNC_COOLEY_TUKEY =
+ 4, // a recursive FFT algorithm for n = 2^m * L; it saves the space
+ // occupied by the w matrix and, compared to plain DFT, reduces the
+ // time complexity from O(n^2) to O(n * log(n)).
+} FFTStrategy;
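+
+// As a factorization example (illustrative): n = 4000 decomposes as
+// L * 2^m = 125 * 2^5, so L = 125 and m = 5 in the plan fields below.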
+
+typedef enum {
+ CNFFT_HALF2COMPLEX_HALF = 0,
+ CNFFT_COMPLEX_HALF2HALF = 1,
+ CNFFT_COMPLEX_HALF2COMPLEX_HALF = 2,
+ CNFFT_FLOAT2COMPLEX_FLOAT = 3,
+ CNFFT_COMPLEX_FLOAT2FLOAT = 4,
+ CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT = 5,
+} FFTType;
+
+// struct for CNFFT_FUNC_MATMUL strategy.
+struct cnfftMatmulAddrs {
+ /* addrs set in the preprocess-stage */
+ void *dft_matrix_addr;
+ void *dft_re_matrix_addr;
+ void *dft_im_matrix_addr;
+ void *ifft_dft_matrix_addr;
+ void *ifft_dft_re_matrix_addr;
+ void *ifft_dft_im_matrix_addr;
+ void *dft_pos_addr;
+ void *dft_scale_addr;
+ size_t dft_quantize_workspace_size;
+ void *dft_quantize_workspace_addr;
+ /* addrs set in the runtime stage */
+ void *input_contiguous_addr;
+ void *input_pad_addr;
+ void *input_transed_addr;
+ void *input_reversed_addr;
+ void *input_merged_addr;
+ void *input_re_addr;
+ void *input_im_addr;
+ void *input_pos_addr;
+ void *input_scale_addr;
+ void *matmul_re_mul_re_addr;
+ void *matmul_re_mul_im_addr;
+ void *matmul_im_mul_re_addr;
+ void *matmul_im_mul_im_addr;
+ void *output_re_addr;
+ void *output_im_addr;
+ void *output_contiguous_addr;
+ void *internal_workspace_addr;
+ size_t internal_workspace_size;
+};
+
+struct mluOpFFTStruct {
+ int rank; // rank of FFT
+ int n[FFT_DIM_MAX]; // FFT lengths on each dimension
+ mluOpDataType_t input_dtype;
+ mluOpDataType_t output_dtype;
+ mluOpDataType_t execution_dtype;
+ int idim; // the dimension size of input tensor
+ int inembed[FFT_DIM_MAX]; // Pointer of size rank that indicates the storage
+ // dimensions of the input data in memory.
+ int inum; // element num of input tensor
+ int istride; // distance between two successive input elements in the
+ // innermost dimension
+ int idist; // distance between the first element of two consecutive signals
+ // in a batch of the input data
+ int odim; // the dimension size of output tensor
+ int onembed[FFT_DIM_MAX]; // Pointer of size rank that indicates the storage
+ // dimensions of the output data in memory
+ int onum; // element num of output tensor
+ int ostride; // distance between two successive output elements in the
+ // innermost dimension
+ int odist; // distance between the first element of two consecutive signals
+ // in a batch of the output data
+ int batch; // batch size for this transform
+ int L; // n = L * 2^m, L size for this transform
+ int m; // n = L * 2^m, m size for this transform
+ int s; // The size that can be put down on NRAM: L * 2^s, only used by
+ // Cooley-Tukey algorithm
+ int L_sub; // The size that can be put down on NRAM: L_sub * 2^m, only used
+ // by Stockham algorithm
+ bool is_input_contiguous;
+ bool is_output_contiguous;
+ size_t reservespace_size;
+ size_t workspace_size;
+ FFTType fft_type; // types of fft
+ FFTStrategy fft_strategy;
+ mluOpTensorDescriptor_t input_desc;
+ mluOpTensorDescriptor_t output_desc;
+ void *reservespace_addr;
+ cnfftMatmulAddrs matmul_addrs;
+};
+
+struct ParamNode {
+ int subgraph_size;
+ int L_bytes;
+ int L_align;
+ int L_align_bytes;
+ int op_size;
+ int op_size_align;
+ int op_size_align_via_L;
+ int op_size_bytes;
+ int op_size_bytes_align;
+ int op_size_align_via_L_trans;
+ int op_group_num_1_batch;
+ int op_group_num_x_batch;
+ int remain_layer_num;
+};
+
+template <typename DT>
+struct AddrNode {
+ // GDRAM Addr Info:
+ DT *wspace_r;
+ DT *wspace_i;
+
+ // NRAM Addr Info:
+ // input addr:
+ DT *y_in_r;
+ DT *z_in_r;
+ DT *y_in_i;
+ DT *z_in_i;
+ // output addr:
+ DT *x_out1_r;
+ DT *x_out2_r;
+ DT *x_out1_i;
+ DT *x_out2_i;
+ // w_matrix addr:
+ DT *w_r;
+ DT *w_i;
+ // temp addr reserved for vector generation w_matrix.
+ DT *w_tmp1;
+ DT *w_tmp2;
+ DT *w_tmp3;
+ // temp addr reserved for subgraph internal merge calculation, using the same
+ // addr with w_tmp*.
+ DT *wz_rr;
+ DT *wz_ri;
+ DT *wz_ir;
+ DT *wz_ii;
+ DT *wz_r;
+ DT *wz_i;
+};
+
+mluOpStatus_t selectFFTStrategy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
+ const std::string make_plan_api);
+
+mluOpStatus_t MLUOP_WIN_API kernelFFTCooleyTukey(cnrtDim3_t k_dim,
+ cnrtFunctionType_t k_type,
+ cnrtQueue_t queue,
+ mluOpFFTPlan_t fft_plan,
+ int direction, FFTFlag flag);
+
+mluOpStatus_t MLUOP_WIN_API
+kernelFFTStockham(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
+ cnrtQueue_t queue, mluOpFFTPlan_t fft_plan, int direction,
+ const float scale_factor, FFTFlag flag);
+
+mluOpStatus_t MLUOP_WIN_API kernelC2CFFTDFTMatrix(
+ cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
+ mluOpFFTPlan_t fft_plan, mluOpDataType_t in_r_dtype, int n);
+
+#endif // KERNELS_FFT_FFT_H_
diff --git a/kernels/fft/fft_optm_device/fft_cooley-tukey_ux_device.mlu b/kernels/fft/fft_optm_device/fft_cooley-tukey_ux_device.mlu
new file mode 100644
index 000000000..6643c9deb
--- /dev/null
+++ b/kernels/fft/fft_optm_device/fft_cooley-tukey_ux_device.mlu
@@ -0,0 +1,775 @@
+/*************************************************************************
+ * Copyright (C) [2024] by Cambricon, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+#include "mlu.h"
+#include "kernels/debug.h"
+#include "kernels/kernel.h"
+#include "kernels/utils/common.h"
+#include "kernels/fft/fft.h"
+
+#define TRANS_ALIGN_SIZE 64
+
+#define ITER_ONCHIP 0
+#define ITER_OFFCHIP 1
+
+// direction
+#define FFT_INVERSE 1
+
+#define COMPLEX_FACTOR 2
+#define YZ_FACTOR 2
+
+__nram__ char nram_buffer[MAX_NRAM_SIZE];
+
+// Generate W matrix with alignment as interval.
+// For example: when W is [L_align * s], the valid range of values is
+// [0, L - 1], [L_align, L_align + L - 1], ..., [L_align * (s - 1), L_align
+// * (s - 1) + L - 1].
+template <typename DT>
+__mlu_func__ void genWVec1(float *w_r, float *w_i, float *w_tmp1, float *w_tmp2,
+ float *w_tmp3, int L, int L_align, int L_align_bytes,
+ int n, int fft_flag, int direction) {
+ float *cos_addr = w_r;
+ float *sin_addr = w_i;
+ float *offset_addr = w_tmp1;
+ float *inc_addr = w_tmp2;
+
+ // deal with each L segment.
+ for (int i = 0; i < n / L_align; i++) {
+ float *tmp_cos_addr = cos_addr + L_align * i;
+ float *tmp_sin_addr = sin_addr + L_align * i;
+ float *tmp_offset_addr = offset_addr + L_align * i;
+ float *tmp_inc_addr = inc_addr + L_align * i;
+ float start_value = L * i;
+ __mluop_get_indices(tmp_inc_addr, start_value, L_align);
+ float scale =
+ 2.0 * M_PI / (n / L_align * L * 2.0) * (fft_flag != IRFFT ? -1 : 1);
+ scale *= ((fft_flag == FFT_IFFT && direction == FFT_INVERSE) ? -1 : 1);
+ __bang_mul_scalar(tmp_offset_addr, tmp_inc_addr, scale, L_align);
+#if __BANG_ARCH__ >= 372
+ __bang_cos(tmp_cos_addr, tmp_offset_addr, L_align);
+ __bang_sin(tmp_sin_addr, tmp_offset_addr, L_align);
+#else
+ for (int i = 0; i < L_align; i++) {
+ tmp_cos_addr[i] = cosf(tmp_offset_addr[i]);
+ tmp_sin_addr[i] = sinf(tmp_offset_addr[i]);
+ }
+#endif
+ }
+}
+
+// Generate W matrix continuously with different start values.
+template <typename DT>
+__mlu_func__ void genWVec2(float *w_r, float *w_i, float *w_tmp1, float *w_tmp2,
+ float *w_tmp3, float n_tmp, int n, int L,
+ int L_align, int ri, int op_size, int op_size_align,
+ int op_size_bytes_align, int fft_flag,
+ int direction) {
+ float *cos_addr = w_r;
+ float *sin_addr = w_i;
+ float *offset_addr = w_tmp1;
+ float *inc_addr = w_tmp2;
+ float scale = 2.0 * M_PI / (n_tmp * 2.0) * (fft_flag != IRFFT ? -1 : 1);
+ scale *= ((fft_flag == FFT_IFFT && direction == FFT_INVERSE) ? -1 : 1);
+ float start_value = ri * op_size;
+ __mluop_get_indices(inc_addr, start_value, op_size_align);
+ __bang_mul_scalar(offset_addr, inc_addr, scale, op_size_align);
+#if __BANG_ARCH__ >= 372
+ __bang_cos(cos_addr, offset_addr, op_size);
+ if (n <= 48000) {
+ __bang_sin(sin_addr, offset_addr, op_size);
+ } else {
+ __cn_vector_sin_f32(op_size, sin_addr, offset_addr);
+ }
+#else
+ for (int i = 0; i < op_size; i++) {
+ cos_addr[i] = cosf(offset_addr[i]);
+ sin_addr[i] = sinf(offset_addr[i]);
+ }
+#endif
+}
+
+template <typename DT>
+__mlu_func__ void genWSc1(float *w_r, float *w_i, int n, int fft_flag,
+ int direction, int L, int L_align) {
+ float scale =
+ 2.0 * M_PI / (n / L_align * L * 2.0) * (fft_flag != IRFFT ? -1 : 1);
+ scale *= ((fft_flag == FFT_IFFT && direction == FFT_INVERSE) ? -1 : 1);
+ for (int i = 0; i < n / L_align; i++) {
+ for (int j = 0; j < L; j++) {
+ w_r[i * L_align + j] = std::cos((i * L + j) * scale);
+ w_i[i * L_align + j] = std::sin((i * L + j) * scale);
+ }
+ }
+}
+
+template <typename DT>
+__mlu_func__ void genWSc2(float *w_r, float *w_i, float n_tmp, int ri,
+ int op_size, int op_size_align, int fft_flag,
+ int direction, int L, int L_align) {
+ float scale = 2.0 * M_PI / (n_tmp * 2.0) * (fft_flag != IRFFT ? -1 : 1);
+ scale *= ((fft_flag == FFT_IFFT && direction == FFT_INVERSE) ? -1 : 1);
+ for (int i = 0; i < op_size; i++) {
+ w_r[i] = std::cos((ri * op_size + i) * scale);
+ w_i[i] = std::sin((ri * op_size + i) * scale);
+ }
+}
+
+// pick the correct input index.
+__mlu_func__ void permute(int &ind_inner_op, int &ind_outer_op, int M) {
+ for (int i = 0; i < M; i++) {
+ ind_inner_op = 2 * ind_inner_op + ind_outer_op % 2;
+ ind_outer_op = ind_outer_op / 2;
+ }
+}
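+
+// e.g., with M = 3 and ind_outer_op = 6 (0b110), the loop leaves
+// ind_inner_op = 3 (0b011): a bit reversal of the low M bits, which is the
+// input permutation required by the radix-2 decomposition.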
+
+// Subgraph internal merge calculation as follows:
+//   x_out1 = y + wz = (y_real + i * y_imag)
+//                   + (w_real + i * w_imag) * (z_real + i * z_imag)
+//   x_out2 = y - wz = (y_real + i * y_imag)
+//                   - (w_real + i * w_imag) * (z_real + i * z_imag)
+// Note: the output of each iteration is the input of the next layer. When not
+// iterating to the last layer, that is, stage == ITER_ONCHIP, the output
+// result needs to be moved back to the input.
+template <typename DT>
+__mlu_func__ void computeOneStep(DT *wz_rr, DT *wz_ri, DT *wz_ir, DT *wz_ii,
+ DT *w_r, DT *w_i, DT *wz_r, DT *wz_i,
+ DT *x_out1_r, DT *x_out2_r, DT *x_out1_i,
+ DT *x_out2_i, DT *y_in_r, DT *z_in_r,
+ DT *y_in_i, DT *z_in_i, int op_iter_size,
+ int stage) {
+ uint32_t op_iter_size_bytes = op_iter_size * sizeof(DT);
+ if (std::is_same<DT, half>::value) {
+ __memcpy((half *)x_out1_r, (half *)z_in_r, op_iter_size_bytes, NRAM2NRAM);
+ __memcpy((half *)x_out1_i, (half *)z_in_i, op_iter_size_bytes, NRAM2NRAM);
+ __bang_half2float((float *)x_out2_r, (half *)x_out1_r, op_iter_size);
+ __bang_half2float((float *)x_out2_i, (half *)x_out1_i, op_iter_size);
+
+ // (w_real + i * w_imag) * (z_real + i * z_imag)
+ __bang_mul((float *)wz_rr, (float *)w_r, (float *)x_out2_r, op_iter_size);
+ __bang_mul((float *)wz_ri, (float *)w_r, (float *)x_out2_i, op_iter_size);
+ __bang_mul((float *)wz_ir, (float *)w_i, (float *)x_out2_r, op_iter_size);
+ __bang_mul((float *)wz_ii, (float *)w_i, (float *)x_out2_i, op_iter_size);
+
+ // wz_real = w_real * z_real - w_imag * z_imag
+ __bang_sub((float *)wz_r, (float *)wz_rr, (float *)wz_ii, op_iter_size);
+
+ // wz_imag = w_real * z_imag + w_imag * z_real
+ __bang_add((float *)wz_i, (float *)wz_ri, (float *)wz_ir, op_iter_size);
+
+ __memcpy((half *)x_out1_r, (half *)y_in_r, op_iter_size_bytes, NRAM2NRAM);
+ __memcpy((half *)x_out1_i, (half *)y_in_i, op_iter_size_bytes, NRAM2NRAM);
+ __bang_half2float((float *)x_out2_r, (half *)x_out1_r, op_iter_size);
+ __bang_half2float((float *)x_out2_i, (half *)x_out1_i, op_iter_size);
+
+ // y + wz
+ __bang_add((float *)x_out1_r, (float *)x_out2_r, (float *)wz_r,
+ op_iter_size);
+ __bang_add((float *)x_out1_i, (float *)x_out2_i, (float *)wz_i,
+ op_iter_size);
+
+ // y - wz
+ __bang_sub((float *)x_out2_r, (float *)x_out2_r, (float *)wz_r,
+ op_iter_size);
+ __bang_sub((float *)x_out2_i, (float *)x_out2_i, (float *)wz_i,
+ op_iter_size);
+
+ __mluop_float2half((half *)x_out1_r, (float *)x_out1_r, op_iter_size);
+ __mluop_float2half((half *)x_out1_i, (float *)x_out1_i, op_iter_size);
+ __mluop_float2half((half *)x_out2_r, (float *)x_out2_r, op_iter_size);
+ __mluop_float2half((half *)x_out2_i, (float *)x_out2_i, op_iter_size);
+ } else {
+ __bang_mul(wz_rr, w_r, z_in_r, op_iter_size);
+ __bang_mul(wz_ri, w_r, z_in_i, op_iter_size);
+ __bang_mul(wz_ir, w_i, z_in_r, op_iter_size);
+ __bang_mul(wz_ii, w_i, z_in_i, op_iter_size);
+
+ // wz_real = w_real * z_real - w_imag * z_imag
+ __bang_sub(wz_r, wz_rr, wz_ii, op_iter_size);
+
+ // wz_imag = w_real * z_imag + w_imag * z_real
+ __bang_add(wz_i, wz_ri, wz_ir, op_iter_size);
+
+ // y + wz
+ __bang_add(x_out1_r, y_in_r, wz_r, op_iter_size);
+ __bang_add(x_out1_i, y_in_i, wz_i, op_iter_size);
+
+ // y - wz
+ __bang_sub(x_out2_r, y_in_r, wz_r, op_iter_size);
+ __bang_sub(x_out2_i, y_in_i, wz_i, op_iter_size);
+ }
+
+ // move the output result back to the input.
+ if (stage == ITER_ONCHIP) { // iterate on chip
+ __memcpy(y_in_r, x_out1_r, op_iter_size_bytes, NRAM2NRAM);
+ __memcpy(z_in_r, x_out2_r, op_iter_size_bytes, NRAM2NRAM);
+ __memcpy(y_in_i, x_out1_i, op_iter_size_bytes, NRAM2NRAM);
+ __memcpy(z_in_i, x_out2_i, op_iter_size_bytes, NRAM2NRAM);
+ }
+}
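+
+// For the k = 0 twiddle (w = 1 + 0i) the step above degenerates to the
+// classic radix-2 butterfly: x_out1 = y + z, x_out2 = y - z.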
+
+template <typename DT>
+__mlu_func__ void computeOnchip(DT *y_in_r, DT *y_in_i, DT *x_out1_r,
+ DT *x_out1_i, DT *x_out2_r, DT *x_out2_i,
+ DT *w_tmp1, DT *w_tmp2, DT *w_tmp3, DT *w_r,
+ DT *w_i, DT *wz_rr, DT *wz_ri, DT *wz_ir,
+ DT *wz_ii, DT *wz_r, DT *wz_i, int L, int s,
+ int subgraph_size, int L_align,
+ int L_align_bytes, int fft_flag,
+ int direction) {
+ int op_iter_size = L_align;
+ for (int sub = 0; sub < subgraph_size; sub++) {
+ int unit_num_each_layer = powf(2, s - sub);
+#if 1 // generate w1 using vector operators
+ genWVec1((float *)w_r, (float *)w_i, (float *)w_tmp1, (float *)w_tmp2,
+ (float *)w_tmp3, L, L_align, L_align_bytes, op_iter_size,
+ fft_flag, direction);
+#else
+ genWSc1((float *)w_r, (float *)w_i, op_iter_size, fft_flag, direction,
+ L, L_align);
+#endif
+ for (int cnt = 0; cnt < unit_num_each_layer; cnt++) {
+ int offset = op_iter_size * YZ_FACTOR * cnt;
+ DT *y_in_r_local = y_in_r + offset;
+ DT *z_in_r_local = y_in_r_local + op_iter_size;
+ DT *y_in_i_local = y_in_i + offset;
+ DT *z_in_i_local = y_in_i_local + op_iter_size;
+ computeOneStep(wz_rr, wz_ri, wz_ir, wz_ii, w_r, w_i, wz_r, wz_i, x_out1_r,
+ x_out2_r, x_out1_i, x_out2_i, y_in_r_local, z_in_r_local,
+ y_in_i_local, z_in_i_local, op_iter_size, ITER_ONCHIP);
+ }
+ op_iter_size *= 2;
+ }
+}
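+
+// Each layer halves unit_num_each_layer and doubles op_iter_size, walking one
+// level up the radix-2 merge tree: subgraph_size layers merge 2^(s + 1)
+// segments of L_align points into larger butterflied blocks entirely on NRAM.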
+
+// Transpose to the required layout before output: [C, N] -> [N, C], C == 2 when
+// output is complex.
+template <typename DT>
+__mlu_func__ void transAndStore(DT *x_out_trans, DT *y_in_r, DT *y_in_i,
+ DT *z_in_i, void *output,
+ int basic_size_align_via_L,
+ int basic_size_bytes, int bc, int n, int ro,
+ int repeat_inner_basic_group, int ri,
+ int basic_size, // represent for y or y + z
+ int fft_flag, int stage) {
+ if (fft_flag == FFT_IFFT) {
+ int bc_offset = bc * n * COMPLEX_FACTOR;
+ int dst_offset =
+ (ro * repeat_inner_basic_group + ri) * basic_size * 2 + bc_offset;
+#if __BANG_ARCH__ >= 372
+ __bang_transpose(x_out_trans, y_in_r, COMPLEX_FACTOR,
+ basic_size_align_via_L);
+ __memcpy((DT *)output + dst_offset, x_out_trans,
+ basic_size_bytes * COMPLEX_FACTOR, NRAM2GDRAM);
+ if (stage == ITER_ONCHIP) {
+ __bang_transpose(x_out_trans, y_in_i, COMPLEX_FACTOR,
+ basic_size_align_via_L);
+ __memcpy((DT *)output + dst_offset + n, x_out_trans,
+ basic_size_bytes * COMPLEX_FACTOR, NRAM2GDRAM);
+ }
+#else
+ // For efficiency and space, we split one transpose [2, N] --> [2, x, N/x],
+ // into two as follows:
+ // trans1: [(2, x), N/x] -> [N/x, (2, x)]
+ // trans2: [(N/x, 2), x] -> [x, (N/x, 2)]
+ // Note: both dim2 and dim3 need to meet alignment requirement:
+ // TRANS_ALIGN_SIZE / (int)sizeof(DT)
+ int dim1 = COMPLEX_FACTOR;
+ int dim2 = TRANS_ALIGN_SIZE / (int)sizeof(DT);
+ int dim3 = basic_size_align_via_L / dim2;
+ DT *x_out_trans_tmp = x_out_trans + dim1 * dim2 * dim3;
+ __bang_transpose(x_out_trans_tmp, y_in_r, dim1 * dim2, dim3);
+ __bang_transpose(x_out_trans, x_out_trans_tmp, dim3 * dim1, dim2);
+ __memcpy((DT *)output + dst_offset, x_out_trans,
+ basic_size_bytes * COMPLEX_FACTOR, NRAM2GDRAM);
+ if (stage == ITER_ONCHIP) {
+ __bang_transpose(x_out_trans_tmp, y_in_i, dim1 * dim2, dim3);
+ __bang_transpose(x_out_trans, x_out_trans_tmp, dim3 * dim1, dim2);
+ __memcpy((DT *)output + dst_offset + n, x_out_trans,
+ basic_size_bytes * COMPLEX_FACTOR, NRAM2GDRAM);
+ }
+#endif
+ } else if (fft_flag == RFFT) {
+ int bc_offset = bc * (n / 2 + 1) * COMPLEX_FACTOR;
+ int dst_offset = (ro * repeat_inner_basic_group + ri) * basic_size;
+ if (stage == ITER_OFFCHIP) {
+ *((DT *)output + bc_offset + n) =
+ *((DT *)y_in_r + basic_size / YZ_FACTOR);
+ *((DT *)output + bc_offset + (n + 1)) =
+ *((DT *)y_in_i + basic_size / YZ_FACTOR);
+ }
+#if __BANG_ARCH__ >= 372
+ __bang_transpose(x_out_trans, y_in_r, COMPLEX_FACTOR,
+ basic_size_align_via_L);
+#else
+ // the principle of transpose is the same as FFT_IFFT
+ int dim1 = COMPLEX_FACTOR;
+ int dim2 = TRANS_ALIGN_SIZE / (int)sizeof(DT);
+ int dim3 = basic_size_align_via_L / dim2;
+ DT *x_out_trans_tmp = x_out_trans + dim1 * dim2 * dim3;
+ __bang_transpose(x_out_trans_tmp, y_in_r, dim1 * dim2, dim3);
+ __bang_transpose(x_out_trans, x_out_trans_tmp, dim3 * dim1, dim2);
+#endif
+ if (stage == ITER_OFFCHIP) {
+ __memcpy((DT *)output + dst_offset + bc_offset, x_out_trans,
+ basic_size_bytes, NRAM2GDRAM);
+ } else {
+ __memcpy((DT *)output + dst_offset * YZ_FACTOR + bc_offset, x_out_trans,
+ basic_size_bytes * COMPLEX_FACTOR, NRAM2GDRAM);
+ }
+ if (ro == 0 && ri == 0 && stage == ITER_ONCHIP) {
+ *((DT *)output + n + bc_offset) = *(DT *)y_in_i;
+ *((DT *)output + (n + 1) + bc_offset) = *(DT *)z_in_i;
+ }
+ } else if (fft_flag == IRFFT) {
+ int bc_offset = bc * n;
+ int dst_offset = (ro * repeat_inner_basic_group + ri) * basic_size;
+ __memcpy((DT *)output + dst_offset + bc_offset, y_in_r, basic_size_bytes,
+ NRAM2GDRAM);
+ if (stage == ITER_ONCHIP) {
+ __memcpy((DT *)output + dst_offset + bc_offset + n / 2, y_in_i,
+ basic_size_bytes, NRAM2GDRAM);
+ }
+ }
+}
+
+template <typename DT>
+__mlu_func__ void loadMultiLayer(DT *y_in_r, DT *y_in_i, DT *x_out1_r,
+ DT *x_out2_r, DT *matmul_re_mul_re_addr,
+ DT *matmul_re_mul_im_addr,
+ DT *matmul_im_mul_re_addr,
+ DT *matmul_im_mul_im_addr, int L, int m,
+ int sub, int L_num_in_op_group, int L_bytes,
+ int L_align, int bc_offset, int fft_flag) {
+ for (int ln = 0; ln < L_num_in_op_group; ln++) {
+ int ind_outer_op = sub * L_num_in_op_group + ln; // set index for each op.
+ int ind_inner_op = 0;
+ permute(ind_inner_op, ind_outer_op, m);
+ int dst_offset = ln * L_align;
+ int src_offset = L * ind_inner_op + bc_offset;
+ // y and z: x_real*w_real
+ __memcpy(y_in_r + dst_offset, matmul_re_mul_re_addr + src_offset, L_bytes,
+ GDRAM2NRAM);
+ // y and z: x_real*w_imag
+ __memcpy(y_in_i + dst_offset, matmul_re_mul_im_addr + src_offset, L_bytes,
+ GDRAM2NRAM);
+ // combine when input is: rr, ri, ir, ii
+ if (fft_flag == FFT_IFFT || fft_flag == IRFFT) {
+ // y and z: x_imag*w_real
+ __memcpy(x_out1_r + dst_offset, matmul_im_mul_re_addr + src_offset,
+ L_bytes, GDRAM2NRAM);
+ // y and z: x_imag*w_imag
+ __memcpy(x_out2_r + dst_offset, matmul_im_mul_im_addr + src_offset,
+ L_bytes, GDRAM2NRAM);
+ __bang_sub(y_in_r + dst_offset, y_in_r + dst_offset,
+ x_out2_r + dst_offset, L_align);
+ __bang_add(y_in_i + dst_offset, y_in_i + dst_offset,
+ x_out1_r + dst_offset, L_align);
+ }
+ }
+}
+
+template <typename DT>
+__mlu_func__ void storeMultiLayer(DT *y_in_r, DT *y_in_i, DT *z_in_r,
+ DT *z_in_i, DT *wspace_r, DT *wspace_i,
+ DT *w_tmp1, DT *output, int n, int L, int bc,
+ int sub, int L_num_in_op_group, int L_bytes,
+ int L_align_bytes, int op_size,
+ int op_size_bytes,
+ int op_size_align_via_L_trans, int bc_offset,
+ int remain_layer_num, int fft_flag) {
+ // output -> workspace
+ int dst_offset = sub * L_num_in_op_group * L + bc_offset;
+
+ __memcpy(wspace_r + dst_offset, y_in_r, L_bytes, NRAM2GDRAM, L_bytes,
+ L_align_bytes, L_num_in_op_group - 1);
+ __memcpy(wspace_i + dst_offset, y_in_i, L_bytes, NRAM2GDRAM, L_bytes,
+ L_align_bytes, L_num_in_op_group - 1);
+ if (remain_layer_num == 0) {
+ // reorganize NRAM align size for transpose
+ DT *z_in_r = y_in_r + op_size_align_via_L_trans;
+ DT *y_in_i = z_in_r + op_size_align_via_L_trans;
+ DT *z_in_i = y_in_i + op_size_align_via_L_trans;
+ // y_real and z_real
+ __memcpy(y_in_r, wspace_r + bc_offset, op_size_bytes * YZ_FACTOR,
+ GDRAM2NRAM);
+ // y_imag and z_imag
+ __memcpy(y_in_i, wspace_i + bc_offset, op_size_bytes * YZ_FACTOR,
+ GDRAM2NRAM);
+ transAndStore(w_tmp1, y_in_r, y_in_i, z_in_i, output,
+ op_size_align_via_L_trans * YZ_FACTOR,
+ op_size_bytes * YZ_FACTOR, bc, n, sub, 1, 0,
+ op_size * YZ_FACTOR, fft_flag, ITER_OFFCHIP);
+ }
+}
+
+template <typename DT>
+__mlu_func__ void loadLayerByLayer(
+ DT *y_in_r, DT *y_in_i, DT *z_in_r, DT *z_in_i, DT *w_r, DT *w_i,
+ DT *w_tmp1, DT *w_tmp2, DT *w_tmp3, DT *wspace_r, DT *wspace_i,
+ int y_local_offset, int z_local_offset, int n, int L, int L_align, int ri,
+ int op_size, int op_size_bytes, int op_size_align, int op_size_bytes_align,
+ int op_group_distance, int bc_offset, int fft_flag, int direction) {
+#if 1 // generate w2 using vector operators
+ genWVec2((float *)w_r, (float *)w_i, (float *)w_tmp1, (float *)w_tmp2,
+ (float *)w_tmp3, op_size * op_group_distance / 2, n, L, L_align,
+ ri, op_size, op_size_align, op_size_bytes_align, fft_flag,
+ direction);
+#else
+ genWSc2((float *)w_r, (float *)w_i, op_size * op_group_distance / 2, ri,
+ op_size, op_size_align, fft_flag, direction, L, L_align);
+#endif
+
+ // load input data
+ __memcpy(y_in_r, wspace_r + y_local_offset + bc_offset, op_size_bytes,
+ GDRAM2NRAM);
+ __memcpy(z_in_r, wspace_r + z_local_offset + bc_offset, op_size_bytes,
+ GDRAM2NRAM);
+ __memcpy(y_in_i, wspace_i + y_local_offset + bc_offset, op_size_bytes,
+ GDRAM2NRAM);
+ __memcpy(z_in_i, wspace_i + z_local_offset + bc_offset, op_size_bytes,
+ GDRAM2NRAM);
+}
+
+template <typename DT>
+__mlu_func__ void storeLayerByLayer(
+ DT *y_in_r, DT *y_in_i, DT *z_in_r, DT *z_in_i, DT *x_out1_r, DT *x_out1_i,
+ DT *x_out2_r, DT *x_out2_i, DT *w_r, DT *w_i, DT *w_tmp1, DT *wspace_r,
+ DT *wspace_i, DT *output, int y_local_offset, int z_local_offset,
+ int repeat_id, int repeat_outer_op_group, int repeat_inner_op_group, int n,
+ int bc, int L, int L_align, int ri, int ro, int op_size, int op_size_bytes,
+ int op_size_align, int op_size_bytes_align, int op_size_align_via_L_trans,
+ int bc_offset, int remain_layer_num, int layer, int fft_flag,
+ int direction) {
+ if (layer < remain_layer_num - 1) {
+ __memcpy(wspace_r + y_local_offset + bc_offset, x_out1_r, op_size_bytes,
+ NRAM2GDRAM);
+ __memcpy(wspace_r + z_local_offset + bc_offset, x_out2_r, op_size_bytes,
+ NRAM2GDRAM);
+ __memcpy(wspace_i + y_local_offset + bc_offset, x_out1_i, op_size_bytes,
+ NRAM2GDRAM);
+ __memcpy(wspace_i + z_local_offset + bc_offset, x_out2_i, op_size_bytes,
+ NRAM2GDRAM);
+ } else {
+ DT *z_in_r = y_in_r + op_size_align_via_L_trans;
+ DT *y_in_i = z_in_r + op_size_align_via_L_trans;
+ DT *z_in_i = y_in_i + op_size_align_via_L_trans;
+ __memcpy(y_in_r, x_out1_r, op_size_bytes_align, NRAM2NRAM);
+ __memcpy(z_in_r, x_out1_i, op_size_bytes_align, NRAM2NRAM);
+ __memcpy(y_in_i, x_out2_r, op_size_bytes_align, NRAM2NRAM);
+ __memcpy(z_in_i, x_out2_i, op_size_bytes_align, NRAM2NRAM);
+ transAndStore(w_tmp1, y_in_r, y_in_i, z_in_i, output,
+ op_size_align_via_L_trans, op_size_bytes, bc, n, ro,
+ repeat_inner_op_group, ri, op_size, fft_flag,
+ ITER_ONCHIP);
+ }
+}
+
+template <typename DT>
+__mlu_func__ void computeMutiLayerOnchip(
+    const AddrNode<DT> &addr, const ParamNode &param, DT *matmul_re_mul_re_addr,
+ DT *matmul_re_mul_im_addr, DT *matmul_im_mul_re_addr,
+ DT *matmul_im_mul_im_addr, DT *output, int batch, int n, int m, int l,
+ int s, int fft_flag, int direction) {
+ // load subgraph from workspace data: X[C, batch_id, 2^s * core_offset, L]
+  // -> (C x 1 x 2^s x L); each ipu core deals with 2 subgraphs
+ int repeat_remain_flag = (param.op_group_num_x_batch % taskDimX);
+ int repeat_plus_one = repeat_remain_flag > 0 ? 1 : 0;
+ int repeat_for_each_core =
+ (param.op_group_num_x_batch / taskDimX + repeat_plus_one);
+ MLULOG("[computeMutiLayerOnchip]: repeat_for_each_core: %ld\n",
+ repeat_for_each_core);
+ for (int id = 0; id < repeat_for_each_core; id++) {
+ int continue_flag_for_each_core =
+ repeat_remain_flag == 0 || (id != repeat_for_each_core - 1) ||
+ (id == repeat_for_each_core - 1 && taskId < repeat_remain_flag);
+ MLULOG(
+ "[computeMutiLayerOnchip]: taskIdX: %d, id: %ld, "
+ "continue_flag_for_each_core: %ld, ",
+ taskIdX, id, continue_flag_for_each_core);
+ MLULOG(
+ "repeat_remain_flag: %ld, repeat_plus_one: %ld, repeat_for_each_core: "
+ "%ld, ",
+ repeat_remain_flag, repeat_plus_one, repeat_for_each_core);
+ if (continue_flag_for_each_core) {
+ int id_global = id * taskDimX + taskId;
+ int bc = id_global / param.op_group_num_1_batch;
+ int sub = id_global % param.op_group_num_1_batch;
+ int bc_offset = bc * n;
+ int L_num_in_op_group =
+ param.op_size_align_via_L / param.L_align * YZ_FACTOR;
+ MLULOG(
+ "id_global: %ld, bc: %ld, sub: %ld, bc_offset: %ld, "
+ "L_num_in_op_group: %ld\n",
+ id_global, bc, sub, bc_offset, L_num_in_op_group);
+ loadMultiLayer(addr.y_in_r, addr.y_in_i, addr.x_out1_r, addr.x_out2_r,
+ matmul_re_mul_re_addr, matmul_re_mul_im_addr,
+ matmul_im_mul_re_addr, matmul_im_mul_im_addr, l, m, sub,
+ L_num_in_op_group, param.L_bytes, param.L_align, bc_offset,
+ fft_flag);
+ computeOnchip(addr.y_in_r, addr.y_in_i, addr.x_out1_r, addr.x_out1_i,
+ addr.x_out2_r, addr.x_out2_i, addr.w_tmp1, addr.w_tmp2,
+ addr.w_tmp3, addr.w_r, addr.w_i, addr.wz_rr, addr.wz_ri,
+ addr.wz_ir, addr.wz_ii, addr.wz_r, addr.wz_i, l, s,
+ param.subgraph_size, param.L_align, param.L_align_bytes,
+ fft_flag, direction);
+ storeMultiLayer(addr.y_in_r, addr.y_in_i, addr.z_in_r, addr.z_in_i,
+ addr.wspace_r, addr.wspace_i, addr.w_tmp1, output, n, l,
+ bc, sub, L_num_in_op_group, param.L_bytes,
+ param.L_align_bytes, param.op_size, param.op_size_bytes,
+ param.op_size_align_via_L_trans, bc_offset,
+ param.remain_layer_num, fft_flag);
+ }
+ }
+}
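+
+// Example of the split above: with op_group_num_x_batch = 10 and
+// taskDimX = 4, repeat_for_each_core = 3; on the last round only cores with
+// taskId < repeat_remain_flag (= 2) keep working, the rest idle.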
+
+template <typename DT>
+__mlu_func__ void computeLayerByLayer(const AddrNode<DT> &addr,
+                                      const ParamNode &param, DT *output,
+ int batch, int n, int m, int l, int s,
+ int fft_flag, int direction) {
+ for (int layer = 0; layer < param.remain_layer_num; layer++) {
+ int op_cnt_each_layer = powf(2, m - s - 1 - layer);
+ int repeat_outer_op_group = op_cnt_each_layer / 2;
+ int repeat_inner_op_group = powf(2, layer + 1);
+ int repeat_total_with_batch =
+ batch * repeat_outer_op_group * repeat_inner_op_group;
+ int repeat_remain_flag = (repeat_total_with_batch % taskDimX);
+ int repeat_plus_one = repeat_remain_flag > 0 ? 1 : 0;
+ int repeat_for_each_core =
+ (repeat_total_with_batch / taskDimX + repeat_plus_one);
+ int op_group_distance = powf(2, layer + YZ_FACTOR);
+ MLULOG("[computeLayerByLayer]: repeat_for_each_core: %ld\n",
+ repeat_for_each_core);
+ for (int repeat_id = 0; repeat_id < repeat_for_each_core; repeat_id++) {
+ int continue_flag_for_each_core =
+          // all ipu cores are used the same number of times
+          repeat_remain_flag == 0
+          // assume each core runs at most one fewer repeat than the others
+ || (repeat_id != repeat_for_each_core - 1) ||
+ (repeat_id == repeat_for_each_core - 1 &&
+ taskId < repeat_remain_flag);
+ MLULOG(
+ "[computeLayerByLayer ]: taskIdX: %ld, id: %ld, "
+ "continue_flag_for_each_core: %ld, ",
+ taskIdX, repeat_id, continue_flag_for_each_core);
+ MLULOG(
+ "repeat_remain_flag: %ld, repeat_plus_one: %ld, "
+ "repeat_for_each_core: %ld, ",
+ repeat_remain_flag, repeat_plus_one, repeat_for_each_core, layer);
+ MLULOG("layer: %ld\n", layer);
+ int id_global = repeat_id * taskDimX + taskId;
+ int bc = id_global / (repeat_outer_op_group * repeat_inner_op_group);
+ int ro = id_global % (repeat_outer_op_group * repeat_inner_op_group) /
+ repeat_inner_op_group;
+ int ri = id_global % repeat_inner_op_group;
+ int bc_offset = bc * n;
+ int y_local_offset = (ro * op_group_distance + ri) * param.op_size;
+ int z_local_offset =
+ y_local_offset + op_group_distance / 2 * param.op_size;
+ MLULOG("id_global: %ld, bc: %ld, ro: %ld, ri: %ld\n", id_global, bc, ro,
+ ri);
+ MLULOG("y_local_offset: %ld, z_local_offset: %ld, bc_offset: %ld\n",
+ y_local_offset, z_local_offset, bc_offset);
+ if (continue_flag_for_each_core) {
+ loadLayerByLayer(addr.y_in_r, addr.y_in_i, addr.z_in_r, addr.z_in_i,
+ addr.w_r, addr.w_i, addr.w_tmp1, addr.w_tmp2,
+ addr.w_tmp3, addr.wspace_r, addr.wspace_i,
+ y_local_offset, z_local_offset, n, l, param.L_align,
+ ri, param.op_size, param.op_size_bytes,
+ param.op_size_align, param.op_size_bytes_align,
+ op_group_distance, bc_offset, fft_flag, direction);
+ computeOneStep(addr.wz_rr, addr.wz_ri, addr.wz_ir, addr.wz_ii, addr.w_r,
+ addr.w_i, addr.wz_r, addr.wz_i, addr.x_out1_r,
+ addr.x_out2_r, addr.x_out1_i, addr.x_out2_i, addr.y_in_r,
+ addr.z_in_r, addr.y_in_i, addr.z_in_i,
+ param.op_size_align, ITER_OFFCHIP);
+ storeLayerByLayer(
+ addr.y_in_r, addr.y_in_i, addr.z_in_r, addr.z_in_i, addr.x_out1_r,
+ addr.x_out1_i, addr.x_out2_r, addr.x_out2_i, addr.w_r, addr.w_i,
+ addr.w_tmp1, addr.wspace_r, addr.wspace_i, output, y_local_offset,
+ z_local_offset, repeat_id, repeat_outer_op_group,
+ repeat_inner_op_group, n, bc, l, param.L_align, ri, ro,
+ param.op_size, param.op_size_bytes, param.op_size_align,
+ param.op_size_bytes_align, param.op_size_align_via_L_trans,
+ bc_offset, param.remain_layer_num, layer, fft_flag, direction);
+ }
+ }
+ __sync_all_ipu();
+ }
+}
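+
+// Example of the layer loop above: with m = 5, s = 2 (remain_layer_num = 2),
+// layer 0 has op_cnt_each_layer = 2^(5-3-0) = 4 ops (repeat_outer = 2,
+// repeat_inner = 2, op_group_distance = 4); layer 1 has 2 ops
+// (repeat_outer = 1, repeat_inner = 4, op_group_distance = 8).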
+
+// fftCooleyTukey combines subgraphs as follows:
+//
+// layer0: layer1: layer2: ...
+// subgraph0: y_in_0 -------> x_out1--(y_in_0) -------> x_out1--(y_in_0)
+// -------> x_out --> x_trans
+// z_in_0 _| |_> x_out2_| | | | |
+// | | | |
+// subgraph1: y_in_0 -------> x_out1--(z_in_0) _| |_> x_out2_| |
+// z_in_1 _| |_> x_out2_| |
+// |
+// subgraph2: y_in_0 -------> x_out1--(y_in_0) -------> x_out1--(y_in_1) _|
+// z_in_0 _| |_> x_out2_| | | |
+// | | |
+// subgraph3: y_in_0 -------> x_out1--(z_in_0) _| |_> x_out2_|
+// z_in_1 _| |_> x_out2_|
+// ...
+//
+// where: x_out1 = y_in_0 + W * z_in_0, x_out2 = y_in_0 - W * z_in_0.
+// the size of a subgraph doubles layer by layer: it equals (y_in_0) + (z_in_0).
+//
+// subgraphs that fit on chip are merged by computeMutiLayerOnchip(), which
+// writes the intermediate result back to the workspace; the remaining layers
+// are then merged by computeLayerByLayer().
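+// A quick numeric check of the butterfly (with W = exp(-i * theta)): for
+// y_in = 1, z_in = i and theta = pi / 2 (W = -i), W * z_in = 1, so
+// x_out1 = y_in + W * z_in = 2 and x_out2 = y_in - W * z_in = 0.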
+template <typename DT>
+__mlu_func__ void fftCooleyTukey(DT *matmul_re_mul_re_addr,
+ DT *matmul_re_mul_im_addr,
+ DT *matmul_im_mul_re_addr,
+ DT *matmul_im_mul_im_addr,
+ DT *internal_workspace_addr, DT *output,
+ int fft_flag, int direction, int n, int batch,
+ int L, int m, int s) {
+ MLULOG("batch: %d, n: %d, L: %d, m: %d, s: %d, fft_flag: %d, direction: %d\n",
+ batch, n, L, m, s, fft_flag, direction);
+ int align_size = NFU_ALIGN_SIZE / sizeof(DT);
+ ParamNode param;
+ // Data Info:
+ param.subgraph_size =
+ s + 1; // the size of subgraph that can be placed on NRAM
+ param.L_bytes = L * sizeof(DT);
+ param.L_align = PAD_UP(L, align_size);
+ param.L_align_bytes = param.L_align * sizeof(DT);
+ param.op_size = powf(2, s) * L;
+ param.op_size_align = PAD_UP(param.op_size, align_size);
+ param.op_size_align_via_L = powf(2, s) * param.L_align;
+
+ param.op_size_bytes = param.op_size * sizeof(DT);
+ param.op_size_bytes_align = PAD_UP(param.op_size_bytes, NFU_ALIGN_SIZE);
+ param.op_size_align_via_L_trans =
+ PAD_UP(param.op_size_align_via_L,
+ int(powf(TRANS_ALIGN_SIZE / (int)sizeof(DT), 2)));
+ param.op_group_num_1_batch = powf(2, m - (s + 1));
+ param.op_group_num_x_batch = param.op_group_num_1_batch * batch;
+ param.remain_layer_num = m - (s + 1);
+ int half_multiplier = sizeof(DT) == sizeof(half) ? 2 : 1;
+ int op_size_align_via_L_dt = param.op_size_align_via_L * half_multiplier;
+ MLULOG("subgraph_size: %d, L_bytes: %d, L_align: %d, L_align_bytes: %d",
+ param.subgraph_size, param.L_bytes, param.L_align,
+ param.L_align_bytes);
+ MLULOG(
+ "op_size: %d, op_size_align: %d, op_size_align_via_L: %d, op_size_bytes: "
+ "%d",
+ param.op_size, param.op_size_align, param.op_size_align_via_L,
+ param.op_size_bytes);
+ MLULOG(
+ "op_size_bytes_align: %d, op_size_align_via_L_trans: %d, "
+ "op_group_num_1_batch: %d",
+ param.op_size_bytes_align, param.op_size_align_via_L_trans,
+ param.op_group_num_1_batch);
+ MLULOG("op_group_num_x_batch: %d, remain_layer_num: %d\n",
+ param.op_group_num_x_batch, param.remain_layer_num);
+  AddrNode<DT> addr;
+ // GDRAM Addr Info:
+ addr.wspace_r = internal_workspace_addr;
+ addr.wspace_i = internal_workspace_addr + n * batch;
+
+ // NRAM Addr Info:
+ // input addr:
+ addr.y_in_r = (DT *)nram_buffer;
+ addr.z_in_r = addr.y_in_r + op_size_align_via_L_dt;
+ addr.y_in_i = addr.z_in_r + op_size_align_via_L_dt;
+ addr.z_in_i = addr.y_in_i + op_size_align_via_L_dt;
+ // output addr:
+ addr.x_out1_r = addr.z_in_i + op_size_align_via_L_dt;
+ addr.x_out2_r = addr.x_out1_r + op_size_align_via_L_dt;
+ addr.x_out1_i = addr.x_out2_r + op_size_align_via_L_dt;
+ addr.x_out2_i = addr.x_out1_i + op_size_align_via_L_dt;
+ // w_matrix addr:
+ addr.w_r = addr.x_out2_i + op_size_align_via_L_dt;
+ addr.w_i = addr.w_r + op_size_align_via_L_dt;
+ // temp addr reserved for vector generation w_matrix.
+ addr.w_tmp1 = addr.w_i + op_size_align_via_L_dt;
+ addr.w_tmp2 = addr.w_tmp1 + op_size_align_via_L_dt;
+ addr.w_tmp3 = addr.w_tmp2 + op_size_align_via_L_dt;
+ // temp addr reserved for subgraph internal merge calculation, using the same
+ // addr with w_tmp*.
+ addr.wz_rr = addr.w_i + op_size_align_via_L_dt;
+ addr.wz_ri = addr.wz_rr + op_size_align_via_L_dt;
+ addr.wz_ir = addr.wz_ri + op_size_align_via_L_dt;
+ addr.wz_ii = addr.wz_ir + op_size_align_via_L_dt;
+ addr.wz_r = addr.wz_rr; // using the same addr with wz_rr
+ addr.wz_i = addr.wz_ri; // using the same addr with wz_ri
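+  // NRAM footprint note: the input (4), output (4) and w (2) buffers occupy
+  // 10 units; wz_rr..wz_ii alias w_tmp1..w_tmp3 and extend one unit past
+  // them, so the peak usage is 14 units of op_size_align_via_L_dt elements.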
+ computeMutiLayerOnchip(addr, param, matmul_re_mul_re_addr,
+ matmul_re_mul_im_addr, matmul_im_mul_re_addr,
+ matmul_im_mul_im_addr, output, batch, n, m, L, s,
+ fft_flag, direction);
+ __sync_all_ipu();
+ computeLayerByLayer(addr, param, output, batch, n, m, L, s, fft_flag,
+ direction);
+}
+
+__mlu_global__ void MLUKernelFFTCooleyTukey(
+ void *matmul_re_mul_re_addr, void *matmul_re_mul_im_addr,
+ void *matmul_im_mul_re_addr, void *matmul_im_mul_im_addr,
+ void *internal_workspace_addr, void *output, int fft_flag, int direction,
+ int n, int batch, int L, int m, int s, int dtype_size) {
+ if (coreId == 0x80) return;
+ switch (dtype_size) {
+ default: {
+ MLULOG("mluOpFFT Not Implemented.");
+ }
+ case (MLUOP_DTYPE_COMPLEX_FLOAT):
+ case (MLUOP_DTYPE_FLOAT): {
+ MLULOG("MLUOP_DTYPE_COMPLEX_FLOAT: MLUOP_DTYPE_FLOAT\n");
+      fftCooleyTukey<float>(
+ (float *)matmul_re_mul_re_addr, (float *)matmul_re_mul_im_addr,
+ (float *)matmul_im_mul_re_addr, (float *)matmul_im_mul_im_addr,
+ (float *)internal_workspace_addr, (float *)output, fft_flag,
+ direction, n, batch, L, m, s);
+ }; break;
+ case (MLUOP_DTYPE_COMPLEX_HALF):
+ case (MLUOP_DTYPE_HALF): {
+ MLULOG("MLUOP_DTYPE_COMPLEX_HALF: MLUOP_DTYPE_HALF\n");
+      fftCooleyTukey<half>(
+ (half *)matmul_re_mul_re_addr, (half *)matmul_re_mul_im_addr,
+ (half *)matmul_im_mul_re_addr, (half *)matmul_im_mul_im_addr,
+ (half *)internal_workspace_addr, (half *)output, fft_flag, direction,
+ n, batch, L, m, s);
+ }; break;
+ }
+}
+
+mluOpStatus_t MLUOP_WIN_API kernelFFTCooleyTukey(cnrtDim3_t k_dim,
+ cnrtFunctionType_t k_type,
+ cnrtQueue_t queue,
+ mluOpFFTPlan_t fft_plan,
+ int direction, FFTFlag flag) {
+  VLOG(5) << "Launch Kernel MLUKernelFFTCooleyTukey<<<k_dim, k_type, queue>>>";
+  KERNEL_CHECK((MLUKernelFFTCooleyTukey<<<k_dim, k_type, queue>>>(
+ fft_plan->matmul_addrs.matmul_re_mul_re_addr,
+ fft_plan->matmul_addrs.matmul_re_mul_im_addr,
+ fft_plan->matmul_addrs.matmul_im_mul_re_addr,
+ fft_plan->matmul_addrs.matmul_im_mul_im_addr,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.output_contiguous_addr, flag, direction,
+ fft_plan->n[0], fft_plan->batch, fft_plan->L, fft_plan->m, fft_plan->s,
+ fft_plan->output_dtype)));
+ return MLUOP_STATUS_SUCCESS;
+}
diff --git a/kernels/fft/fft_optm_device/fft_stockham_u1_device.mlu b/kernels/fft/fft_optm_device/fft_stockham_u1_device.mlu
new file mode 100644
index 000000000..f819b87cc
--- /dev/null
+++ b/kernels/fft/fft_optm_device/fft_stockham_u1_device.mlu
@@ -0,0 +1,736 @@
+/*************************************************************************
+ * Copyright (C) [2024] by Cambricon, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+#include "mlu.h"
+#include "kernels/debug.h"
+#include "kernels/kernel.h"
+#include "kernels/utils/common.h"
+#include "kernels/fft/fft.h"
+
+// direction: 1 means IFFT, used to distinguish FFT and IFFT.
+#define FFT_INVERSE 1
+
+// Two split dimension (batch && L) traversal orders can be selected via
+// "L_FIRST". By default "L" is traversed first, then "batch".
+#define L_FIRST 1
+
+__nram__ char nram_buffer[MAX_NRAM_SIZE + REM_FOR_STACK - 32 * 1024];
+
+// Generate w vector.
+template <typename DT>
+__mlu_func__ void genWSc1_opt(DT* w_r, DT* w_i, DT* tmp, DT* seq_addr,
+ const int& L, const int& L_sub, const int& part,
+ const int& unit_size, float scale, int n) {
+ float inc_value = part * L_sub;
+ int size_tmp_bytes = L_sub * unit_size;
+ scale = scale / unit_size;
+ __bang_add_scalar(tmp, seq_addr, inc_value, size_tmp_bytes);
+ __bang_mul_scalar(tmp, tmp, scale, size_tmp_bytes);
+
+#if __BANG_ARCH__ >= 372
+ __bang_cos((float*)w_r, (float*)tmp, size_tmp_bytes);
+ if (n <= 48000) {
+ __bang_sin((float*)w_i, (float*)tmp, size_tmp_bytes);
+ } else {
+    // __cn_vector_sin_f32 has higher precision; the threshold for n was
+    // determined by actual tests.
+ __cn_vector_sin_f32(size_tmp_bytes, (float*)w_i, (float*)tmp);
+ }
+#endif
+}
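+
+// Note on the math above: with the caller passing scale = ±M_PI / L, each
+// generated entry is w = cos(theta) + i * sin(theta) with
+// theta = ±pi * (seq + part * L_sub) / (L * unit_size),
+// i.e. the twiddle factor exp(±2*pi*i*k / N) of a merged block of length
+// N = 2 * L * unit_size.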
+
+// Load input data from GDRAM to NRAM. The data source(src_in) is the
+// calculation result of "mluOpBatchMatmulBcast". Different "fft_flag" has
+// different layout, as follows:
+//   RFFT:     src_in[batch, 2, L, powf(2, m)] =
+//             w[batch, 2, (L/2 + 1), l] * in_ori[L, powf(2, m)]
+//   IRFFT:    src_in[2, 2, batch, L, powf(2, m)] =
+//             w[2, batch, L, l] * in_ori[2, L, powf(2, m)]
+//   FFT_IFFT: src_in[batch, 2, L, powf(2, m), 2] =
+//             w[batch, 2, L, l] * in_ori[L, powf(2, m), 2]
+// "2," represents the real and imag part size of "w" or "src_in". According to
+// the multicore splitting method, each "load" takes data block [1, *, L_sub,
+// powf(2, m), *], using "memcpy_async" with stride. When input type is float,
+// the address is used as defined by name: in(y_in_r, z_in_r, ...), ... When
+// input type is half, the address interval is the same as float. The reason is
+// that we perform bit width promotion calculations, considering the accuracy.
+// Some temporarily unused space is used to ensure operations such as
+// "half2float" in the "compute()" function.
+template <typename DT>
+__mlu_func__ void load(DT* y_in_r, DT* y_in_i, DT* z_in_r, DT* z_in_i,
+ DT* x_out1_r, DT* x_out1_i, DT* x_out2_r, DT* x_out2_i,
+ DT* wz_rr, DT* wz_ir, DT* matmul_re_mul_re_addr,
+ const int& n, const int& L, const int& L_sub,
+ const int& part_num, const int& pow_2_m,
+ const int& pow_2_m_half, const int& batch_x_part,
+ const int& fft_flag, const int& batch,
+ const int& op_size_align_via_L_dt,
+ const int& ping_pong) {
+#if L_FIRST
+ int b = batch_x_part / part_num;
+ int part = batch_x_part % part_num;
+#else
+ int part = batch_x_part / batch;
+ int b = batch_x_part % batch;
+#endif
+ int pingpong_offset = batch_x_part % 2 == 0 ? 0 : ping_pong;
+ int L_re = L % L_sub;
+ int L_deal = part < (part_num - 1) ? L_sub : (L_re != 0 ? L_re : L_sub);
+
+ if (sizeof(DT) == sizeof(half)) {
+ if (fft_flag == RFFT) {
+ // Save the real part of RFFT, data shape is [1, 1, L_sub, powf(2, m)].
+ x_out1_r = y_in_i;
+ // Save the imag part of RFFT, data shape is the same as above.
+ x_out1_i = wz_ir;
+ } else if (fft_flag == IRFFT) {
+ // Save the real * real part of IRFFT, data shape is [1, 1, 1, L_sub,
+ // powf(2, m)].
+ x_out1_r = x_out2_r;
+ // Save the real * imag part of IRFFT, data shape is the same as above.
+ x_out1_i = x_out2_i;
+ // Save the imag * real part of IRFFT, data shape is the same as above.
+ y_in_r = z_in_r;
+ // Save the imag * imag part of IRFFT, data shape is the same as above.
+ y_in_i = z_in_i;
+ } else if (fft_flag == FFT_IFFT) {
+ // Save the real * (real + imag) part of FFT_IFFT, data shape is
+ // [1, 1, L_sub, powf(2, m), 2].
+ y_in_r = y_in_i;
+ // Save the imag * (real + imag) part of FFT_IFFT, data shape is the same
+ // as above.
+ wz_rr = wz_ir;
+ }
+ }
+ x_out1_r += pingpong_offset;
+ x_out1_i += pingpong_offset;
+ y_in_r += pingpong_offset;
+ y_in_i += pingpong_offset;
+
+ if (fft_flag == RFFT) {
+ int src_offset = L_sub * pow_2_m * part + b * n * 2;
+ int data_size_bytes = pow_2_m * sizeof(DT);
+ int total_data_size_bytes = L_deal * data_size_bytes;
+ int distance_bytes = int((char*)x_out1_i - (char*)x_out1_r);
+ if (part < part_num / 2 || part_num == 1) {
+ __memcpy_async(x_out1_r, matmul_re_mul_re_addr + src_offset,
+ total_data_size_bytes, GDRAM2NRAM, distance_bytes,
+ n * sizeof(DT), 1);
+ } else {
+      // According to conjugate symmetry, only the first L/2+1 set of data is
+      // calculated by "mluOpBatchMatmulBcast"; the second half of the data
+      // needs to be calculated according to the coordinate mapping. This is
+      // the reason why memcpy's src_str appears negative below.
+ int ind_fwd = part * L_sub + L % 2;
+ int src_offset = b * n * 2 + (L - ind_fwd + L % 2) * pow_2_m;
+ __memcpy_async(x_out1_r, // dst addr
+ matmul_re_mul_re_addr + src_offset, // src addr
+ data_size_bytes, // size
+ GDRAM2NRAM, // direction
+ data_size_bytes, // dst_stride o0
+ L_deal - 1, // dst_segnum o1
+ distance_bytes, // dst_stride o1
+ 1, // dst_segnum o2
+ -data_size_bytes, // src_stride i0
+ L_deal - 1, // src_segnum i1
+ n * sizeof(DT), // src_stride i1
+ 1); // src_segnum i2
+ }
+ } else if (fft_flag == IRFFT) {
+ int total_data_size_bytes = L_deal * pow_2_m * sizeof(DT);
+ DT* x[4] = {x_out1_r, x_out1_i, y_in_r, y_in_i};
+ for (int addr_i = 0; addr_i < 4; addr_i++) {
+ int complex_in = addr_i / 2;
+ int complex_w = addr_i % 2;
+ __memcpy_async(x[addr_i],
+ matmul_re_mul_re_addr +
+ complex_in * batch * 2 * L * pow_2_m +
+ b * 2 * L * pow_2_m + complex_w * L * pow_2_m +
+ part * L_sub * pow_2_m,
+ total_data_size_bytes, GDRAM2NRAM);
+ }
+ } else if (fft_flag == FFT_IFFT) {
+ wz_rr += pingpong_offset;
+ int src_offset = b * 2 * n * 2 + part * L_sub * pow_2_m * 2;
+ int total_data_size_bytes = L_deal * pow_2_m * 2 * sizeof(DT);
+ __memcpy_async(y_in_r, matmul_re_mul_re_addr + src_offset,
+ total_data_size_bytes, GDRAM2NRAM);
+ __memcpy_async(wz_rr, matmul_re_mul_re_addr + src_offset + n * 2,
+ total_data_size_bytes, GDRAM2NRAM);
+ }
+}
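+
+// Note on the negative src_stride above (RFFT case): for real input the
+// spectrum is conjugate-symmetric, X[N - k] = conj(X[k]), so the second half
+// of the rows is read backwards from the first half, and its imag part is
+// negated later in preProcessRFFT().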
+
+template <typename DT, typename YT>
+__mlu_func__ void preProcessRFFT(YT* y_in_r, YT* y_in_i, YT* x_out1_r,
+ YT* x_out1_i, YT* wz_ir, const int& L_sub,
+ const int& part_num, const int& pow_2_m,
+ const int& part) {
+ if (sizeof(DT) == sizeof(half)) {
+ if (part >= part_num / 2 && part_num > 1) {
+      // According to conjugate symmetry, the second half of the imag part
+      // needs to be multiplied by -1.
+ __bang_mul_scalar((DT*)wz_ir, (DT*)wz_ir, -1.0, pow_2_m * L_sub);
+ }
+ // Transpose L_sub to the lowest dimension for easy vector operations.
+ __bang_transpose((DT*)x_out1_r, (DT*)y_in_i, L_sub, pow_2_m);
+ __bang_transpose((DT*)x_out1_i, (DT*)wz_ir, L_sub, pow_2_m);
+    // Convert to float, prepare for bit width promotion calculation.
+ __bang_half2float((float*)y_in_r, (half*)x_out1_r, L_sub * pow_2_m);
+ __bang_half2float((float*)y_in_i, (half*)x_out1_i, L_sub * pow_2_m);
+ } else {
+ if (part >= part_num / 2 && part_num > 1) {
+      // According to conjugate symmetry, the second half of the imag part
+      // needs to be multiplied by -1.
+ __bang_mul_scalar(x_out1_i, x_out1_i, -1.0, pow_2_m * L_sub);
+ }
+ // Transpose L_sub to the lowest dimension for easy vector operations.
+ __bang_transpose(y_in_r, x_out1_r, L_sub, pow_2_m);
+ __bang_transpose(y_in_i, x_out1_i, L_sub, pow_2_m);
+ }
+}
+
+template <typename DT, typename YT>
+__mlu_func__ void preProcessFFT_IFFT(YT* y_in_r, YT* y_in_i, YT* z_in_r,
+ YT* x_out1_r, YT* x_out1_i, YT* wz_rr,
+ YT* wz_ri, YT* wz_ir, const int& L_sub,
+ const int& pow_2_m) {
+ if (sizeof(DT) == sizeof(half)) {
+ // Transpose L_sub to the lowest dimension for easy vector operations.
+ __bang_transpose((DT*)y_in_r, (DT*)y_in_i, L_sub * pow_2_m, 2);
+ __bang_transpose((DT*)wz_rr, (DT*)wz_ir, L_sub * pow_2_m, 2);
+ // Compute the real part: src_in(real * real) - src_in(imag * imag).
+ __bang_sub((DT*)y_in_r, (DT*)y_in_r, (DT*)wz_ri, L_sub * pow_2_m);
+    // Compute the imag part: src_in(real * imag) + src_in(imag * real).
+ __bang_add((DT*)wz_rr, (DT*)wz_rr, (DT*)z_in_r, L_sub * pow_2_m);
+ // Transpose L_sub to the lowest dimension for easy vector operations.
+ __bang_transpose((DT*)y_in_i, (DT*)y_in_r, L_sub, pow_2_m);
+ __bang_transpose((DT*)wz_ir, (DT*)wz_rr, L_sub, pow_2_m);
+    // Convert to float, prepare for bit width promotion calculation.
+ __bang_half2float((float*)y_in_r, (half*)y_in_i, L_sub * pow_2_m);
+ __bang_half2float((float*)y_in_i, (half*)wz_ir, L_sub * pow_2_m);
+ } else {
+    // Transpose the real and imag parts to the highest dimension for easy
+    // vector operations.
+ __bang_transpose(x_out1_r, y_in_r, L_sub * pow_2_m, 2);
+ __bang_transpose(y_in_r, wz_rr, L_sub * pow_2_m, 2);
+ // Compute the real part: src_in(real * real) - src_in(imag * imag).
+ __bang_sub(x_out1_r, x_out1_r, y_in_i, L_sub * pow_2_m);
+    // Compute the imag part: src_in(real * imag) + src_in(imag * real).
+ __bang_add(x_out1_i, x_out1_i, y_in_r, L_sub * pow_2_m);
+ // Transpose L_sub to the lowest dimension for easy vector operations.
+ __bang_transpose(y_in_r, x_out1_r, L_sub, pow_2_m);
+ __bang_transpose(y_in_i, x_out1_i, L_sub, pow_2_m);
+ }
+}
+
+template <typename DT, typename YT>
+__mlu_func__ void preProcessIRFFT(YT* y_in_r, YT* y_in_i, YT* z_in_r,
+ YT* z_in_i, YT* x_out1_r, YT* x_out1_i,
+ YT* x_out2_r, YT* x_out2_i, YT* wz_ir,
+ const int& L_sub, const int& pow_2_m) {
+ if (sizeof(DT) == sizeof(half)) {
+ // Compute the real part: src_in(real * real) - src_in(imag * imag).
+ __bang_sub((DT*)x_out2_r, (DT*)x_out2_r, (DT*)z_in_i, L_sub * pow_2_m);
+    // Compute the imag part: src_in(real * imag) + src_in(imag * real).
+ __bang_add((DT*)x_out2_i, (DT*)x_out2_i, (DT*)z_in_r, L_sub * pow_2_m);
+ // Transpose L_sub to the lowest dimension for easy vector operations.
+ __bang_transpose((DT*)z_in_r, (DT*)x_out2_r, L_sub, pow_2_m);
+ __bang_transpose((DT*)wz_ir, (DT*)x_out2_i, L_sub, pow_2_m);
+    // Convert to float, prepare for bit width promotion calculation.
+ __bang_half2float((float*)y_in_r, (half*)z_in_r, L_sub * pow_2_m);
+ __bang_half2float((float*)y_in_i, (half*)wz_ir, L_sub * pow_2_m);
+ } else {
+ // Compute the real part: src_in(real * real) - src_in(imag * imag).
+ __bang_sub(x_out1_r, x_out1_r, y_in_i, L_sub * pow_2_m);
+    // Compute the imag part: src_in(real * imag) + src_in(imag * real).
+ __bang_add(x_out1_i, x_out1_i, y_in_r, L_sub * pow_2_m);
+ // Transpose L_sub to the lowest dimension for easy vector operations.
+ __bang_transpose(y_in_r, x_out1_r, L_sub, pow_2_m);
+ __bang_transpose(y_in_i, x_out1_i, L_sub, pow_2_m);
+ }
+}
+
+// Perform preprocessing for "compute()" function, including merging of real and
+// imag parts, transposition and data type conversion, etc.
+template <typename DT, typename YT>
+__mlu_func__ void preProcess(YT* y_in_r, YT* y_in_i, YT* z_in_r, YT* z_in_i,
+ YT* x_out1_r, YT* x_out1_i, YT* x_out2_r,
+ YT* x_out2_i, YT* wz_rr, YT* wz_ri, YT* wz_ir,
+ const int& fft_flag, const int& L_sub,
+ const int& part_num, const int& pow_2_m,
+ const int& part) {
+ if (fft_flag == RFFT) {
+    preProcessRFFT<DT>((float*)y_in_r, (float*)y_in_i, (float*)x_out1_r,
+                       (float*)x_out1_i, (float*)wz_ir, L_sub, part_num,
+                       pow_2_m, part);
+ } else if (fft_flag == FFT_IFFT) {
+    preProcessFFT_IFFT<DT>((float*)y_in_r, (float*)y_in_i,
+                           (float*)z_in_r, (float*)x_out1_r,
+                           (float*)x_out1_i, (float*)wz_rr,
+                           (float*)wz_ri, (float*)wz_ir, L_sub, pow_2_m);
+ } else if (fft_flag == IRFFT) {
+    preProcessIRFFT<DT>((float*)y_in_r, (float*)y_in_i, (float*)z_in_r,
+                        (float*)z_in_i, (float*)x_out1_r,
+                        (float*)x_out1_i, (float*)x_out2_r,
+                        (float*)x_out2_i, (float*)wz_ir, L_sub, pow_2_m);
+ }
+}
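+
+// Rationale for the half -> float promotion above: half keeps an 11-bit
+// significand, and rounding error grows with the log2-many merge layers, so
+// the butterflies are computed in float and converted back to half at the
+// end of "compute()".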
+
+template <typename YT>
+__mlu_func__ void computeOneLayer(YT* y_in_r, YT* y_in_i, YT* z_in_r,
+ YT* z_in_i, YT* x_out1_r, YT* x_out1_i,
+ YT* w_r, YT* w_i, YT* wz_rr, YT* wz_ri,
+ YT* wz_ir, YT* wz_ii, const int& fft_flag,
+ const int& L_sub, const int& part,
+ const int& pow_2_m_half, const int& layer_num,
+ int ln, int ln_pow2) {
+ int basic_size = L_sub * ln_pow2;
+ int group_size = basic_size * 2;
+ int basic_group_num = pow_2_m_half / ln_pow2;
+ int long_size_bytes = basic_size * basic_group_num;
+  // Compute w * z_in: real * real, real * imag, imag * real, imag * imag.
+ __bang_cycle_mul(wz_rr, z_in_r, w_r, long_size_bytes, basic_size);
+ __bang_cycle_mul(wz_ri, z_in_i, w_r, long_size_bytes, basic_size);
+ __bang_cycle_mul(wz_ir, z_in_r, w_i, long_size_bytes, basic_size);
+ __bang_cycle_mul(wz_ii, z_in_i, w_i, long_size_bytes, basic_size);
+ // Combine real and imag parts: real = real * real - imag * imag, imag = real
+ // * imag + imag * real.
+ __bang_sub(wz_rr, wz_rr, wz_ii, long_size_bytes);
+ __bang_add(wz_ri, wz_ri, wz_ir, long_size_bytes);
+
+ for (int bgn = 0; bgn < basic_group_num; bgn++) {
+ int bgn_offset = basic_size * bgn;
+ YT* y_r = y_in_r + bgn_offset;
+ YT* y_i = y_in_i + bgn_offset;
+ YT* x_r = x_out1_r + group_size * bgn;
+ YT* x_i = x_out1_i + group_size * bgn;
+ YT* wz_rr_tmp = wz_rr + bgn_offset;
+ YT* wz_ri_tmp = wz_ri + bgn_offset;
+ // Compute x_out1 = y_in + w * z_in.
+ __bang_add(x_r, y_r, wz_rr_tmp, basic_size);
+ __bang_add(x_i, y_i, wz_ri_tmp, basic_size);
+ if (fft_flag == RFFT) {
+ if (ln != layer_num - 1) {
+ // Compute x_out2 = y_in - w * z_in.
+ __bang_sub(x_r + basic_size, y_r, wz_rr_tmp, basic_size);
+ __bang_sub(x_i + basic_size, y_i, wz_ri_tmp, basic_size);
+ } else if (part == 0) {
+        // According to conjugate symmetry, the last layer does not need to
+ // calculate the second half part, except the point (n/2 + 1).
+ *((YT*)x_r + basic_size) = *((YT*)y_r) - *((YT*)wz_rr_tmp);
+ *((YT*)x_i + basic_size) = *((YT*)y_i) - *((YT*)wz_ri_tmp);
+ }
+ } else {
+ // Compute x_out2 = y_in - w * z_in.
+ __bang_sub(x_r + basic_size, y_r, wz_rr_tmp, basic_size);
+ __bang_sub(x_i + basic_size, y_i, wz_ri_tmp, basic_size);
+ }
+ }
+}
+
+// According to the merging rules of the Stockham algorithm, calculate layer by
+// layer. An example is as follows:
+//
+// layer0   |layer1      |layer2            |layer3
+// ---------|------------|------------------|-------------------------
+// {0}      |{0, 4}      |{0, 4, 2, 6}      |{0, 4, 2, 6, 1, 5, 3, 7}
+// {1}      |            |                  |
+// {2}      |{1, 5}      |                  |
+// {3}      |            |                  |
+// {4}      |{2, 6}      |{1, 5, 3, 7}      |
+// {5}      |            |                  |
+// {6}      |{3, 7}      |                  |
+// {7}      |            |                  |
+//
+// Each {*} represents a sequence of complex numbers of length l. Each time
+// the first half and the second half are merged, such as {0} and {4}, or
+// {0, 4} and {2, 6}. The first half is y_in, the second half is z_in, and the
+// output is x_out* (the first half is x_out1, the second half is x_out2). The
+// calculation formula(Butterfly Transform) is:
+// x_out1 = y_in + w * z_in
+// x_out2 = y_in - w * z_in
+// w is calculated as follows: w_k = exp(-i * k * (2 * pi / N) * flag), where k
+// is the k_th point, i is the imaginary unit, N is the total number of points,
+// and flag depends on the FFT type: 1 for RFFT and FFT, -1 for IRFFT and IFFT.
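+// For example, when merging two length-4 halves of an 8-point forward FFT
+// (N = 8), k runs over 0..3 and w_k = exp(-2 * pi * i * k / 8), i.e.
+// 1, (1 - i)/sqrt(2), -i, -(1 + i)/sqrt(2).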
+template <typename DT, typename YT>
+__mlu_func__ void compute(YT* y_in_r, YT* y_in_i, YT* z_in_r, YT* z_in_i,
+ YT* x_out1_r, YT* x_out1_i, YT* x_out2_r,
+ YT* x_out2_i, YT* w_r, YT* w_i, YT* wz_rr, YT* wz_ri,
+ YT* wz_ir, YT* wz_ii, YT* seq_addr,
+ const int& fft_flag, const int& direction,
+ const int& n, const int& L, const int& L_sub,
+ const int& part_num, const int& pow_2_m,
+ const int& pow_2_m_half, const int& layer_num,
+ const int& op_size_align_via_L_dt, float scale,
+ const float scale_factor, const int& batch_x_part,
+ const int& batch, int ping_pong) {
+#if L_FIRST
+ int part = batch_x_part % part_num;
+#else
+ int part = batch_x_part / batch;
+#endif
+ if (sizeof(DT) == sizeof(half)) {
+    // Because the float type is actually used for computation, the ping-pong
+    // offset in points is half of that of the half type.
+ ping_pong = ping_pong / 2;
+ }
+ int pingpong_offset = batch_x_part % 2 == 0 ? 0 : ping_pong;
+ y_in_r += pingpong_offset;
+ y_in_i += pingpong_offset;
+ z_in_r += pingpong_offset;
+ z_in_i += pingpong_offset;
+ x_out1_r += pingpong_offset;
+ x_out1_i += pingpong_offset;
+ x_out2_r += pingpong_offset;
+ x_out2_i += pingpong_offset;
+ w_r += pingpong_offset;
+ w_i += pingpong_offset;
+ wz_rr += pingpong_offset;
+ wz_ri += pingpong_offset;
+ wz_ir += pingpong_offset;
+ wz_ii += pingpong_offset;
+  preProcess<DT>((float*)y_in_r, (float*)y_in_i, (float*)z_in_r,
+                 (float*)z_in_i, (float*)x_out1_r, (float*)x_out1_i,
+                 (float*)x_out2_r, (float*)x_out2_i, (float*)wz_rr,
+                 (float*)wz_ri, (float*)wz_ir, fft_flag, L_sub, part_num,
+                 pow_2_m, part);
+
+ // Calculate layer by layer as shown in the example.
+ for (int ln = 0; ln < layer_num; ln++) {
+ int ln_pow2 = powf(2, ln);
+ // Generate w vector.
+ genWSc1_opt(w_r, w_i, wz_ii, seq_addr, L, L_sub, part, ln_pow2, scale,
+ n);
+ computeOneLayer(
+ (float*)y_in_r, (float*)y_in_i, (float*)z_in_r, (float*)z_in_i,
+ (float*)x_out1_r, (float*)x_out1_i, (float*)w_r, (float*)w_i,
+ (float*)wz_rr, (float*)wz_ri, (float*)wz_ir, (float*)wz_ii, fft_flag,
+ L_sub, part, pow_2_m_half, layer_num, ln, ln_pow2);
+
+ // In order to avoid the data movement, the addr of input and output are
+ // exchanged here.
+ YT* tmp_y_r = y_in_r;
+ YT* tmp_y_i = y_in_i;
+ YT* tmp_z_r = z_in_r;
+ YT* tmp_z_i = z_in_i;
+ y_in_r = x_out1_r;
+ y_in_i = x_out1_i;
+ z_in_r = x_out2_r;
+ z_in_i = x_out2_i;
+ x_out1_r = tmp_y_r;
+ x_out1_i = tmp_y_i;
+ x_out2_r = tmp_z_r;
+ x_out2_i = tmp_z_i;
+ }
+
+ if (fft_flag != IRFFT) {
+    // Transpose to the output data format: the real and imag parts are at
+    // the lowest dimension: [c, 2^m * L_sub] -> [2^m * L_sub, c]
+ __bang_transpose(x_out1_r, y_in_r, pow_2_m * 2, L_sub);
+ __bang_transpose(y_in_r, x_out1_r, L_sub * 2, pow_2_m);
+ }
+ if (scale_factor != 1.0) {
+ __bang_mul_scalar(y_in_r, y_in_r, scale_factor, L_sub * 2 * pow_2_m);
+ }
+ if (sizeof(DT) == sizeof(half)) {
+ __mluop_float2half((half*)y_in_r, (float*)y_in_r, L_sub * 2 * pow_2_m);
+ }
+}
+
+// Store the calculation result to output. The differences between RFFT, IRFFT
+// and FFT_IFFT can be seen in the description of the "load()" function.
+template <typename DT>
+__mlu_func__ void store(DT* output, DT* y_in_r, DT* x_out1_r,
+ const int& pow_2_m, const int& pow_2_m_half,
+ const int& m, const int& L, const int& L_sub,
+ const int& part_num, const int& n, const int& out_n,
+ const int& batch_x_part, const int& batch,
+ const int& fft_flag, const int& ping_pong) {
+#if L_FIRST
+ int b = batch_x_part / part_num;
+ int part = batch_x_part % part_num;
+#else
+ int part = batch_x_part / batch;
+ int b = batch_x_part % batch;
+#endif
+ int pingpong_offset = batch_x_part % 2 == 0 ? 0 : ping_pong;
+ int L_re = L % L_sub;
+ int L_deal = part < (part_num - 1) ? L_sub : (L_re != 0 ? L_re : L_sub);
+ int dst_offset = part * L_sub * 2;
+ DT* out_nram = m % 2 == 0 ? y_in_r : x_out1_r;
+ out_nram += pingpong_offset;
+ if (fft_flag == RFFT) {
+ int output_block = pow_2_m_half - 1;
+ __memcpy_async(output + dst_offset + b * out_n * 2, out_nram,
+ L_deal * sizeof(DT) * 2, NRAM2GDRAM, L * sizeof(DT) * 2,
+ L_sub * sizeof(DT) * 2, output_block);
+ if (part == 0) {
+ int dst_one_point_offset = b * out_n * 2 + n;
+ int src_one_point_offset = pow_2_m * L_sub;
+ *(output + dst_one_point_offset) = *(out_nram + src_one_point_offset);
+ *(output + dst_one_point_offset + 1) =
+ *(out_nram + src_one_point_offset + 1);
+ }
+ } else if (fft_flag == IRFFT) {
+ int dst_offset = part * L_sub;
+ int output_block = pow_2_m - 1;
+ __memcpy_async(output + dst_offset + b * out_n, out_nram,
+ L_deal * sizeof(DT), NRAM2GDRAM, L * sizeof(DT),
+ L_sub * sizeof(DT), output_block);
+
+ } else if (fft_flag == FFT_IFFT) {
+ int output_block = pow_2_m - 1;
+ __memcpy_async(output + dst_offset + b * out_n * 2, out_nram,
+ L_deal * sizeof(DT) * 2, NRAM2GDRAM, L * sizeof(DT) * 2,
+ L_sub * sizeof(DT) * 2, output_block);
+ }
+}
+
+// Generate an incremental sequence according to the following rules:
+//   1. the sequence length is L_sub * pow_2_m_half, i.e. pow_2_m_half groups,
+//      each group has L_sub numbers.
+//   2. the init_value of each group is 0, L, L*2, ..., L*(pow_2_m_half - 1).
+// For the FFT algorithm, a step called many times is the vector operation
+// W * Z, where computing W requires two steps:
+//   1. generate an incremental sequence.
+//   2. perform sin && cos operations with scale on the incremental sequence.
+// The sequence generated by step 1 can be reused, therefore we save it in
+// seq_addr.
+__mlu_func__ void generateIncSequence(float* seq_addr, float* tmp_addr, int L,
+ int L_sub, int pow_2_m_half) {
+ __mluop_get_indices((float*)seq_addr, (float)0.0,
+ PAD_UP(L_sub, NFU_ALIGN_SIZE));
+  // Reduce the number of calls to "__mluop_get_indices", which is relatively
+  // slow, by using a "for loop" with "__bang_add_scalar".
+ for (size_t i = 1; i < pow_2_m_half; i++) {
+ int offset = i * L_sub;
+ int init_value = i * L;
+ __bang_add_scalar((float*)seq_addr + offset, (float*)seq_addr, init_value,
+ L_sub);
+ }
+}
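+
+// Example: with L = 5, L_sub = 2 and pow_2_m_half = 4, seq_addr holds
+// [0, 1, 5, 6, 10, 11, 15, 16] (four groups with init values 0, L, 2L, 3L).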
+
+// Onchip iterative calculation of Stockham algorithm. It is divided into three
+// steps:
+//    1. Load input data. RFFT, IRFFT and FFT_IFFT are processed differently
+// because of different data
+// characteristics. See the "load()" function for details.
+// 2. Compute data. Before the calculation, the data is put into a suitable
+// format through
+// "compute stream transpose", and then, the calculation is carried out
+// layer by layer according to the Stockham rules. Finally, through
+// "transpose", the real and imag parts that were calculated separately
+// are mixed. See the "compute()" function for details. (In order to
+// ensure the accuracy, the HALF type is calculated with a bit width
+// increase processing: HALF->FLOAT)
+// 3. Store output data. See the "store()" function for details.
+template <typename DT>
+__mlu_func__ void computeMutiLayerOnchip(
+    const AddrNode<DT>& addr, DT* matmul_re_mul_re_addr, DT* output,
+ DT* seq_addr, int batch, int n, int m, int L, int fft_flag, int direction,
+ int op_size_align_via_L_dt, int pow_2_m, int pow_2_m_half, int L_sub,
+ const float scale_factor, int ping_pong) {
+ // Generate an incremental sequence
+ generateIncSequence((float*)seq_addr, (float*)addr.y_in_r, L, L_sub,
+ pow_2_m_half);
+ // Calculate the fixed part of W scale.
+ float scale = M_PI / L;
+ scale *=
+ (fft_flag == RFFT || (fft_flag == FFT_IFFT && direction != FFT_INVERSE))
+ ? -1
+ : 1;
+  // When RFFT, using conjugate symmetry, do "BatchMatmulBcast" only on half of
+  // the data, so the input n also becomes half:
+  // int in_n = fft_flag == RFFT ? int(PAD_UP(L, L_sub) / 2 + 1) * pow_2_m : n;
+ int in_n = fft_flag == RFFT ? int(PAD_UP(L / 2, L_sub) + 1) * pow_2_m : n;
+ in_n = L <= L_sub ? n : in_n;
+  // out_n is obtained in the same way as in_n, the difference is that no
+  // alignment is performed.
+ int out_n = fft_flag == RFFT ? n / 2 + 1 : n;
+  // Input_size = batch * L * powf(2, m), NRAM can deal with at least one
+  // "powf(2, m)" at a time. Split "batch" and "L" between multi-cores: "batch"
+  // is processed one at a time; according to the limit of NRAM size, "L" can
+  // be split into "part_num" parts.
+ int part_num = (L / L_sub) + (L % L_sub > 0 ? 1 : 0);
+ // "total_num" blocks need to be processed.
+ int total_num = part_num * batch;
+ int repeat_num = total_num / taskDim;
+ int remain_num = total_num % taskDim;
+ if (repeat_num > 0 || taskId < remain_num) {
+    // Each core needs to process "t_len" blocks; the "remain_num" extra
+    // blocks are assigned one each to the first "remain_num" cores.
+ int t_len = repeat_num + ((remain_num > 0 && taskId < remain_num) ? 1 : 0);
+ // Calculate the offset of the block at each core.
+ int t_start = taskId - remain_num <= 0
+ ? taskId * (repeat_num + 1)
+ : (remain_num * (repeat_num + 1) +
+ (taskId - remain_num) * repeat_num);
+ int t_end = (t_start + t_len);
+ MLULOG("taskId: %d, taskDim: %d\n", taskId, taskDim);
+ MLULOG(
+ "scale: %d, in_n: %d, out_n: %d, part_num: %d, total_num: %d, "
+ "repeat_num: %d, "
+ "remain_num: %d, t_len: %d, t_start: %d, t_end: %d\n",
+ scale, in_n, out_n, part_num, total_num, repeat_num, remain_num, t_len,
+ t_start, t_end);
+
+    // Execute the three-stage pipeline operation (load: GDRAM2NRAM, compute,
+    // store: NRAM2GDRAM) as follows:
+    //   L1
+    //   -----------------sync
+    //   C1 L2
+    //   -----------------sync
+    //   S1 C2 L3
+    //   -----------------sync
+    //   S2 C3
+    //   -----------------sync
+    //   S3
+    //   ...
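+    // With t_len blocks per core, t spans t_len + 2 iterations so that the
+    // pipeline fills and drains: e.g. t_len = 3 gives 5 iterations staged as
+    // L1 | C1 L2 | S1 C2 L3 | S2 C3 | S3, selected by the guards below.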
+ for (int t = t_start; t < t_end + 2; t++) {
+ // Store output data.
+ if (t >= t_start + 2) {
+ store(output, addr.y_in_r, addr.x_out1_r, pow_2_m, pow_2_m_half, m, L,
+ L_sub, part_num, n, out_n, t - 2, batch, fft_flag, ping_pong);
+ }
+ // Compute data layer by layer according to the Stockham rules.
+ if (t >= t_start + 1 && t < t_end + 1) {
+ compute(
+ (float*)addr.y_in_r, (float*)addr.y_in_i, (float*)addr.z_in_r,
+ (float*)addr.z_in_i, (float*)addr.x_out1_r, (float*)addr.x_out1_i,
+ (float*)addr.x_out2_r, (float*)addr.x_out2_i, (float*)addr.w_r,
+ (float*)addr.w_i, (float*)addr.wz_rr, (float*)addr.wz_ri,
+ (float*)addr.wz_ir, (float*)addr.wz_ii, (float*)seq_addr, fft_flag,
+ direction, n, L, L_sub, part_num, pow_2_m, pow_2_m_half, m,
+ op_size_align_via_L_dt, scale, scale_factor, t - 1, batch,
+ ping_pong);
+ }
+ // Load input data.
+ if (t < t_end) {
+ load(addr.y_in_r, addr.y_in_i, addr.z_in_r, addr.z_in_i, addr.x_out1_r,
+ addr.x_out1_i, addr.x_out2_r, addr.x_out2_i, addr.wz_rr,
+ addr.wz_ir, matmul_re_mul_re_addr, in_n, L, L_sub, part_num,
+ pow_2_m, pow_2_m_half, t, fft_flag, batch, op_size_align_via_L_dt,
+ ping_pong);
+ }
+ __sync();
+ }
+ }
+}
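+
+// Example of the block split above: total_num = 10 blocks on taskDim = 4
+// cores gives repeat_num = 2 and remain_num = 2; cores 0 and 1 process 3
+// blocks each (t_start = 0 and 3), cores 2 and 3 process 2 blocks each
+// (t_start = 6 and 8).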
+
+// Divide the space size and call the onchip iterative calculation of Stockham
+// algorithm.
+template <typename DT>
+__mlu_func__ void fftStockham(DT* matmul_re_mul_re_addr, DT* output,
+ int fft_flag, int direction, int n, int batch,
+ int L, int m, int L_sub,
+ const float scale_factor) {
+ MLULOG(
+ "batch: %d, n: %d, l: %d, m: %d, L_sub: %d, fft_flag: %d, direction: "
+ "%d\n",
+ batch, n, L, m, L_sub, fft_flag, direction);
+ int pow_2_m = powf(2, m);
+ // Number of L_sub processed by a src input at a time.
+ int pow_2_m_half = pow_2_m / 2;
+  // Double the number of interval points for the half type, because bit width
+  // lifting (half -> float) is required to ensure the accuracy.
+ int half_multiplier = sizeof(DT) == sizeof(half) ? 2 : 1;
+  // The length of a float input vector, such as "z_in_r" in "w_r * z_in_r"
+  // below.
+ int op_size_align_via_L_dt = pow_2_m_half * L_sub * half_multiplier;
+
+ // NRAM Addr Info: "_r" represents the real part of the complex vector, "_i"
+ // represents the imag part of the complex vector. The complex process is as
+ // follows:
+ // x_out1 = y_in + w * z_in
+ // x_out2 = y_in - w * z_in
+  AddrNode<DT> addr;
+
+ // Input vector addr.
+ addr.y_in_r = (DT*)nram_buffer;
+ addr.z_in_r = addr.y_in_r + op_size_align_via_L_dt;
+ addr.y_in_i = addr.z_in_r + op_size_align_via_L_dt;
+ addr.z_in_i = addr.y_in_i + op_size_align_via_L_dt;
+
+ // Output vector addr.
+ addr.x_out1_r = addr.z_in_i + op_size_align_via_L_dt;
+ addr.x_out2_r = addr.x_out1_r + op_size_align_via_L_dt;
+ addr.x_out1_i = addr.x_out2_r + op_size_align_via_L_dt;
+ addr.x_out2_i = addr.x_out1_i + op_size_align_via_L_dt;
+
+ // W vector addr.
+ addr.w_r = addr.x_out2_i + op_size_align_via_L_dt;
+ addr.w_i = addr.w_r + op_size_align_via_L_dt;
+ addr.wz_rr = addr.w_i + op_size_align_via_L_dt;
+ addr.wz_ri = addr.wz_rr + op_size_align_via_L_dt;
+ addr.wz_ir = addr.wz_ri + op_size_align_via_L_dt;
+ addr.wz_ii = addr.wz_ir + op_size_align_via_L_dt;
+
+ // From "addr.y_in_r" to "addr.wz_ii", each ping_pong needs 14 spaces for
+ // three-stage pipeline operation.
+ int ping_pong = op_size_align_via_L_dt * 14;
+ // The public space stores the incremental sequence shared by ping_pong.
+ DT* seq_addr = (DT*)nram_buffer + ping_pong * 2;
+
+ computeMutiLayerOnchip(addr, matmul_re_mul_re_addr, output, seq_addr, batch,
+ n, m, L, fft_flag, direction, op_size_align_via_L_dt,
+ pow_2_m, pow_2_m_half, L_sub, scale_factor, ping_pong);
+}
+
+__mlu_global__ void MLUKernelFFTStockham(void* matmul_re_mul_re_addr,
+ void* output, int fft_flag,
+ int direction, int n, int batch, int L,
+ int m, int L_sub, int dtype_size,
+ const float scale_factor) {
+ if (__is_mpu()) return;
+ switch (dtype_size) {
+ default: {
+ MLULOG("mluOpFFT Not Implemented.");
+ }
+ case (MLUOP_DTYPE_COMPLEX_FLOAT):
+ case (MLUOP_DTYPE_FLOAT): {
+ MLULOG("MLUOP_DTYPE_COMPLEX_FLOAT: MLUOP_DTYPE_FLOAT\n");
+      fftStockham<float>((float*)matmul_re_mul_re_addr, (float*)output,
+                         fft_flag, direction, n, batch, L, m, L_sub,
+                         scale_factor);
+ }; break;
+ case (MLUOP_DTYPE_COMPLEX_HALF):
+ case (MLUOP_DTYPE_HALF): {
+ MLULOG("MLUOP_DTYPE_COMPLEX_HALF: MLUOP_DTYPE_HALF\n");
+      fftStockham<half>((half*)matmul_re_mul_re_addr, (half*)output, fft_flag,
+                        direction, n, batch, L, m, L_sub, scale_factor);
+ }; break;
+ }
+}
+
+mluOpStatus_t MLUOP_WIN_API
+kernelFFTStockham(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
+ cnrtQueue_t queue, mluOpFFTPlan_t fft_plan, int direction,
+ const float scale_factor, FFTFlag flag) {
+  VLOG(5) << "Launch Kernel MLUKernelFFTStockham<<<k_dim, k_type, queue>>>";
+  KERNEL_CHECK((MLUKernelFFTStockham<<<k_dim, k_type, queue>>>(
+ fft_plan->matmul_addrs.matmul_re_mul_re_addr,
+ fft_plan->matmul_addrs.output_contiguous_addr, flag,
+ direction, // direction, -1 means invalid(only FFT_IFFT use).
+ fft_plan->n[0], fft_plan->batch, fft_plan->L, fft_plan->m,
+ fft_plan->L_sub, fft_plan->output_dtype, scale_factor)));
+ return MLUOP_STATUS_SUCCESS;
+}
diff --git a/kernels/fft/irfft/irfft.h b/kernels/fft/irfft/irfft.h
new file mode 100644
index 000000000..977cf52d2
--- /dev/null
+++ b/kernels/fft/irfft/irfft.h
@@ -0,0 +1,39 @@
+/*************************************************************************
+ * Copyright (C) [2024] by Cambricon, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+#ifndef KERNELS_FFT_IRFFT_IRFFT_H_
+#define KERNELS_FFT_IRFFT_IRFFT_H_
+
+#include <string>
+#include "kernels/fft/fft.h"
+
+mluOpStatus_t makeIRFFT1dPolicy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan);
+
+mluOpStatus_t setIRFFT1dReserveArea(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan,
+ const std::string api);
+
+mluOpStatus_t execIRFFT1d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan,
+ const void *input, const float scale_factor,
+ void *workspace, void *output);
+
+#endif // KERNELS_FFT_IRFFT_IRFFT_H_
diff --git a/kernels/fft/irfft/irfft_host.cpp b/kernels/fft/irfft/irfft_host.cpp
new file mode 100644
index 000000000..efebd111e
--- /dev/null
+++ b/kernels/fft/irfft/irfft_host.cpp
@@ -0,0 +1,1351 @@
+/*************************************************************************
+ * Copyright (C) [2024] by Cambricon, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+
+#include <algorithm>
+#include <string>
+#include "kernels/fft/irfft/irfft.h"
+#include "kernels/fft/common/fft_common_kernels.h"
+
+static mluOpStatus_t selectIRFFT1dStrategy(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan) {
+ const std::string make_plan_api = "[selectIRFFT1dStrategy]";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ /* there are plenty of algorithms for FFT, depending on the fft length.
+ * Iterative FFT:
+ * Stockham FFT, Cooley-Tukey FFT, peaseFFT, Kron-Lambiotte FFT
+ * Recursive FFT:
+ * Recursive Cooley-Tukey FFT, Four-step FFT, Six-step FFT, Multicore FFT,
+ * SIMD short vector FFT. General FFT: chirp-Z Bluestein FFT.
+ */
+ // select Stockham FFT, Cooley-Tukey FFT or MATMUL strategy logic
+ fft_plan->fft_strategy = CNFFT_FUNC_MATMUL;
+ status = selectFFTStrategy(handle, fft_plan, make_plan_api);
+ return status;
+}
+
+/*
+ * Make the policy of IRFFT1d.
+ */
+mluOpStatus_t makeIRFFT1dPolicy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan) {
+ std::string api = "[mluOpMakeFFTPlanMany]";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ INTERNAL_CHECK(
+ api, selectIRFFT1dStrategy(handle, fft_plan) == MLUOP_STATUS_SUCCESS);
+
+ mluOpDataType_t in_c_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_r_dtype = (in_c_dtype == MLUOP_DTYPE_COMPLEX_HALF)
+ ? MLUOP_DTYPE_HALF
+ : MLUOP_DTYPE_FLOAT;
+ mluOpDataType_t in_e_dtype = fft_plan->execution_dtype;
+ size_t in_c_dtype_size = mluOpDataTypeBytes(in_c_dtype);
+ size_t in_r_dtype_size = mluOpDataTypeBytes(in_r_dtype);
+ int batch = fft_plan->batch;
+ int n = fft_plan->n[0];
+
+ switch (fft_plan->fft_strategy) {
+ case CNFFT_FUNC_MATMUL: {
+ if (n > FFT_L_LIMIT) {
+ LOG(ERROR) << "[mluOpMakeFFTPlanMany]: IRFFT1d CNFFT_FUNC_MATMUL "
+ << "length > 4096 is not supported currently.";
+ return MLUOP_STATUS_NOT_SUPPORTED;
+ }
+
+ // Matmul Input : 2 * [batch, (n / 2 + 1)]
+ // Matmul Matrix : 2 * [n, (n / 2 + 1)]
+ // Matmul Result : 2 * [batch, n]
+ int dft_mat_times = COMPLEX;
+ int dim0 = n;
+ int dim1 = FFT_HALF(n);
+ int dft_mat_num = dft_mat_times * dim0 * dim1;
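+
+      // Conceptually, the IRFFT output is real, so it can be obtained from
+      // two real matmuls combined by one op_tensor (a sketch; signs and
+      // conjugate symmetry are folded into the stored DFT matrices):
+      //   out[b][t] = sum_k Re(X[b][k]) * W1[t][k] + Im(X[b][k]) * W2[t][k]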
+
+ // reservespace size allocation
+ fft_plan->reservespace_size = 0;
+ fft_plan->reservespace_size +=
+ dft_mat_num * mluOpDataTypeBytes(in_r_dtype);
+ if (fftIsIntDtype(in_e_dtype)) {
+ fft_plan->reservespace_size += sizeof(int32_t) + sizeof(float);
+ size_t required_size = 0;
+ status = fftGetQuantizeParamWorkspaceSize(
+ handle, required_size, dft_mat_num, in_r_dtype, in_e_dtype, api);
+ fft_plan->reservespace_size += required_size;
+ }
+
+ /* CNFFT_FUNC_MATMUL :
+ -------------------------
+ | input |
+ -------------------------
+ |
+ | input contiguous
+ \|/
+ -------------------------
+ | input_contiguous |
+ -------------------------
+ |
+ | input pad
+ \|/
+ -------------------------
+ | input_pad |
+ -------------------------
+ |
+ | input trans: batch * (n / 2 + 1) * 2 --> 2 * batch * (n /
+ 2 + 1)
+ \|/
+ -------------------------
+ | input_re |
+ | input_im |
+ -------------------------
+ |
+ | matmul
+ \|/
+ -------------------------
+ | matmul_re_mul_re |
+ | matmul_im_mul_im |(reuse output_contiguous)
+ -------------------------
+ |
+ | op_tensor
+ \|/
+ -------------------------
+ | output_contiguous |
+ -------------------------
+ |
+ | output contiguous
+ \|/
+ -------------------------
+ | output |
+ -------------------------
+ */
+      // workspace size allocation
+ fft_plan->matmul_addrs.internal_workspace_size = 0;
+ fft_plan->workspace_size = 0;
+
+ // input contiguous
+ size_t input_size = in_c_dtype_size * fft_plan->inum;
+ fft_plan->workspace_size +=
+ fft_plan->is_input_contiguous ? 0 : input_size;
+
+ // input pad
+ bool need_pad = (fft_plan->inembed[0] != FFT_HALF(n));
+ int padded_input_num = batch * FFT_HALF(n);
+ size_t padded_input_size = in_c_dtype_size * padded_input_num;
+ fft_plan->workspace_size += need_pad ? padded_input_size : 0;
+
+ // input trans and workspace
+ size_t transed_input_size = padded_input_size;
+ fft_plan->workspace_size += transed_input_size;
+ // input trans workspace: batch * (n / 2 + 1) * 2 --> 2 * batch * (n / 2 +
+ // 1)
+ const int trans_dim_num = 2;
+ int trans_input_dims[trans_dim_num] = {padded_input_num, COMPLEX};
+ int trans_permute[trans_dim_num] = {1, 0};
+ size_t trans_workspace_size = 0;
+ status = fftGetTransposeWorkspaceSize(handle, trans_workspace_size,
+ trans_dim_num, trans_input_dims,
+ trans_permute, in_r_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size = std::max(
+ fft_plan->matmul_addrs.internal_workspace_size, trans_workspace_size);
+
+ // input quantize param and workspace
+ if (fftIsIntDtype(in_e_dtype)) {
+ fft_plan->workspace_size += sizeof(int32_t) + sizeof(float);
+ size_t input_quant_workspace_size = 0;
+ status = fftGetQuantizeParamWorkspaceSize(
+ handle, input_quant_workspace_size, COMPLEX * padded_input_num,
+ in_r_dtype, in_e_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size =
+ std::max(fft_plan->matmul_addrs.internal_workspace_size,
+ input_quant_workspace_size);
+ }
+
+      // matmul output (reuse output_contiguous)
+ int matmul_times = COMPLEX;
+ int per_matmul_output_num = batch * n;
+ size_t per_matmul_output_size = in_r_dtype_size * per_matmul_output_num;
+ fft_plan->workspace_size += (matmul_times - 1) * per_matmul_output_size;
+ // matmul workspace
+ size_t matmul_workspace_size = 0;
+ status = fftGetQuantizeMatMulWorkspaceSize(
+ handle, matmul_workspace_size, batch, dim1, dim0, false, true,
+ in_e_dtype, in_e_dtype, in_r_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size =
+ std::max(fft_plan->matmul_addrs.internal_workspace_size,
+ matmul_workspace_size);
+ // optensor workspace
+ size_t optensor_workspace_size = 0;
+ status =
+ fftGetOptensorWorkspaceSize(handle, optensor_workspace_size,
+ per_matmul_output_num, in_r_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size =
+ std::max(fft_plan->matmul_addrs.internal_workspace_size,
+ optensor_workspace_size);
+
+ // output contiguous
+ size_t output_size =
+ mluOpDataTypeBytes(fft_plan->output_dtype) * fft_plan->onum;
+ fft_plan->workspace_size +=
+ fft_plan->is_output_contiguous ? 0 : output_size;
+
+ // internal_workspace
+ fft_plan->workspace_size +=
+ fft_plan->matmul_addrs.internal_workspace_size;
+ VLOG(5) << "internal workspace size: "
+ << fft_plan->matmul_addrs.internal_workspace_size;
+ VLOG(5) << "total workspace size: " << fft_plan->workspace_size;
+ }; break;
+ case CNFFT_FUNC_COOLEY_TUKEY:
+ case CNFFT_FUNC_STOCKHAM: {
+ int L = fft_plan->L;
+ int m = (1 << fft_plan->m);
+ if (L > FFT_L_LIMIT) {
+ LOG(ERROR) << "[mluOpMakeFFTPlanMany]: IRFFT1d CNFFT_FUNC_COOLEY_TUKEY "
+ << "n = L * 2^m and L > 4096 is not supported currently.";
+ return MLUOP_STATUS_NOT_SUPPORTED;
+ }
+
+ // Matmul Input : 2 * [batch, 2^m, L]
+ // Matmul Matrix : 2 * [L, L]
+ // Matmul Result : 4 * [batch, 2^m, L]
+ int dft_mat_times = COMPLEX;
+ int dim0 = L;
+ int dim1 = L;
+ int dft_mat_num = dft_mat_times * dim0 * dim1;
+
+ // reservespace size allocation
+ fft_plan->reservespace_size = 0;
+ fft_plan->reservespace_size += dft_mat_num * in_r_dtype_size;
+ if (fftIsIntDtype(in_e_dtype)) {
+ fft_plan->reservespace_size += sizeof(int32_t) + sizeof(float);
+ size_t required_size = 0;
+ status = fftGetQuantizeParamWorkspaceSize(
+ handle, required_size, dft_mat_num, in_r_dtype, in_e_dtype, api);
+ fft_plan->reservespace_size += required_size;
+ }
+
+ /* CNFFT_FUNC_COOLEY_TUKEY :
+ -------------------------
+ | input |
+ -------------------------
+ |
+ | input contiguous
+ \|/
+ -------------------------
+ | input_contiguous |
+ -------------------------
+ |
+ | input pad
+ \|/
+ -------------------------
+ | input_pad |
+ -------------------------
+ |
+ | input trans: batch * (n / 2 + 1) * 2 --> 2 * batch * (n / 2 + 1)
+ \|/
+ -------------------------
+ | input_transed_re |
+ | input_transed_im |
+ -------------------------
+ |
+ | stridedslice
+ | optensor(im mul -1)
+ \|/
+ -------------------------
+ | input_reversed_re |
+ | input_reversed_im |
+ -------------------------
+ |
+ | concat
+ \|/
+ -------------------------
+ | input_merged_re |
+ | input_merged_im |
+ -------------------------
+ |
+ | input trans: 2 * batch * L * 2^m --> 2 * batch * 2^m * L
+ \|/
+ -------------------------
+ | input_re |
+ | input_im |
+ -------------------------
+ |
+ | matmul
+ | optensor(re_mul_re - im_mul_im, re_mul_im + im_mul_re)
+ \|/
+ -------------------------
+ | matmul_re_mul_re | (matmul_re)
+ | matmul_re_mul_im | (matmul_im)
+ | matmul_im_mul_re |
+ | matmul_im_mul_im |
+ -------------------------
+ |
+ | output merge
+ \|/
+ -------------------------
+ | output_contiguous |
+ -------------------------
+ |
+ | output contiguous
+ \|/
+ -------------------------
+ | output |
+ -------------------------
+ */
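+ // Illustrative note: the four matmul buffers feed the standard
+ // complex-multiply recombination of the merge step,
+ //   (a + ib) * (c + id) = (a*c - b*d) + i(a*d + b*c),
+ // which is why re_mul_re, re_mul_im, im_mul_re and im_mul_im must all be
+ // live at the same time (matmul_times = 4 below).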
+ // workspace size allocation
+ fft_plan->matmul_addrs.internal_workspace_size = 0;
+ fft_plan->workspace_size = 0;
+
+ // input contiguous
+ size_t input_size = in_c_dtype_size * fft_plan->inum;
+ fft_plan->workspace_size +=
+ fft_plan->is_input_contiguous ? 0 : input_size;
+
+ // input pad
+ bool need_pad = (fft_plan->inembed[0] != FFT_HALF(n));
+ int padded_input_num = batch * FFT_HALF(n);
+ size_t padded_input_size = in_c_dtype_size * padded_input_num;
+ fft_plan->workspace_size += need_pad ? padded_input_size : 0;
+
+ // input merge (transed_input and reversed_input reuse input_re)
+ int merged_input_num = batch * n;
+ size_t merged_input_size = in_c_dtype_size * merged_input_num;
+ fft_plan->workspace_size += merged_input_size;
+ // input merge workspace:
+ // transpose workspace: batch * (n / 2 + 1) * 2 --> 2 * batch * (n / 2 + 1)
+ // concat workspace: concat does not need workspace now
+ const int trans_1st_dim_num = 2;
+ int trans_1st_input_dims[trans_1st_dim_num] = {padded_input_num, COMPLEX};
+ int trans_1st_permute[trans_1st_dim_num] = {1, 0};
+ size_t trans_1st_workspace_size = 0;
+ status = fftGetTransposeWorkspaceSize(
+ handle, trans_1st_workspace_size, trans_1st_dim_num,
+ trans_1st_input_dims, trans_1st_permute, in_r_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size =
+ std::max(fft_plan->matmul_addrs.internal_workspace_size,
+ trans_1st_workspace_size);
+
+ // input trans
+ int transed_input_num = batch * n;
+ size_t transed_input_size = in_c_dtype_size * transed_input_num;
+ fft_plan->workspace_size += transed_input_size;
+ // input trans workspace: 2 * batch * L * 2^m --> 2 * batch * 2^m * L
+ const int trans_2nd_dim_num = 3;
+ int trans_2nd_input_dims[trans_2nd_dim_num] = {COMPLEX * batch, L, m};
+ int trans_2nd_permute[trans_2nd_dim_num] = {0, 2, 1};
+ size_t trans_2nd_workspace_size = 0;
+ status = fftGetTransposeWorkspaceSize(
+ handle, trans_2nd_workspace_size, trans_2nd_dim_num,
+ trans_2nd_input_dims, trans_2nd_permute, in_r_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size =
+ std::max(fft_plan->matmul_addrs.internal_workspace_size,
+ trans_2nd_workspace_size);
+
+ // input quantize param and workspace
+ if (fftIsIntDtype(in_e_dtype)) {
+ fft_plan->workspace_size += sizeof(int32_t) + sizeof(float);
+ size_t input_quant_workspace_size = 0;
+ status = fftGetQuantizeParamWorkspaceSize(
+ handle, input_quant_workspace_size, COMPLEX * padded_input_num,
+ in_r_dtype, in_e_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size =
+ std::max(fft_plan->matmul_addrs.internal_workspace_size,
+ input_quant_workspace_size);
+ }
+
+ // matmul output
+ const int matmul_times =
+ 4; // real mul real, real mul imag, imag mul real, imag mul imag
+ int per_matmul_output_num = batch * n;
+ size_t per_matmul_output_size = in_r_dtype_size * per_matmul_output_num;
+ size_t matmul_output_size = matmul_times * per_matmul_output_size;
+ fft_plan->workspace_size += matmul_output_size;
+ // matmul workspace
+ size_t matmul_workspace_size = 0;
+ status = fftGetQuantizeMatMulWorkspaceSize(
+ handle, matmul_workspace_size, batch * m, L, L, false, true,
+ in_e_dtype, in_e_dtype, in_r_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size =
+ std::max(fft_plan->matmul_addrs.internal_workspace_size,
+ matmul_workspace_size);
+ // optensor workspace
+ size_t optensor_workspace_size = 0;
+ status =
+ fftGetOptensorWorkspaceSize(handle, optensor_workspace_size,
+ per_matmul_output_num, in_r_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size =
+ std::max(fft_plan->matmul_addrs.internal_workspace_size,
+ optensor_workspace_size);
+
+ // output merge workspace
+ size_t merge_workspace_size =
+ COMPLEX * in_r_dtype_size * per_matmul_output_num;
+ fft_plan->matmul_addrs.internal_workspace_size = std::max(
+ fft_plan->matmul_addrs.internal_workspace_size, merge_workspace_size);
+
+ // output contiguous
+ size_t output_size =
+ mluOpDataTypeBytes(fft_plan->output_dtype) * fft_plan->onum;
+ fft_plan->workspace_size +=
+ fft_plan->is_output_contiguous ? 0 : output_size;
+
+ // internal_workspace
+ fft_plan->workspace_size +=
+ fft_plan->matmul_addrs.internal_workspace_size;
+ VLOG(5) << "internal workspace size: "
+ << fft_plan->matmul_addrs.internal_workspace_size;
+ VLOG(5) << "total workspace size: " << fft_plan->workspace_size;
+ }; break;
+ default: {
+ status = MLUOP_STATUS_NOT_SUPPORTED;
+ return status;
+ }
+ }
+ return status;
+}
+
+static void configureIRFFT1dMatmulReserveAddrs(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan) {
+ size_t dft_mat_size = 0;
+ const int dft_mat_times = COMPLEX;
+
+ mluOpDataType_t in_c_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_r_dtype = (in_c_dtype == MLUOP_DTYPE_COMPLEX_HALF)
+ ? MLUOP_DTYPE_HALF
+ : MLUOP_DTYPE_FLOAT;
+ size_t in_r_dtype_size = mluOpDataTypeBytes(in_r_dtype);
+ int n = fft_plan->n[0];
+
+ switch (fft_plan->fft_strategy) {
+ case CNFFT_FUNC_MATMUL: {
+ // Matmul Matrix : 2 * [n, (n / 2 + 1)]
+ int dim0 = n;
+ int dim1 = FFT_HALF(n);
+ size_t per_dft_mat_size = dim0 * dim1 * in_r_dtype_size;
+ dft_mat_size = dft_mat_times * per_dft_mat_size;
+ fft_plan->matmul_addrs.dft_matrix_addr = fft_plan->reservespace_addr;
+ fft_plan->matmul_addrs.dft_re_matrix_addr = fft_plan->reservespace_addr;
+ fft_plan->matmul_addrs.dft_im_matrix_addr =
+ (uint8_t *)fft_plan->reservespace_addr + per_dft_mat_size;
+ }; break;
+ case CNFFT_FUNC_COOLEY_TUKEY:
+ case CNFFT_FUNC_STOCKHAM: {
+ // Matmul Matrix : 2 * [L, L]
+ int L = fft_plan->L;
+ int dim0 = L;
+ int dim1 = L;
+ size_t per_dft_mat_size = dim0 * dim1 * in_r_dtype_size;
+ dft_mat_size = dft_mat_times * per_dft_mat_size;
+ fft_plan->matmul_addrs.dft_matrix_addr = fft_plan->reservespace_addr;
+ fft_plan->matmul_addrs.dft_re_matrix_addr = fft_plan->reservespace_addr;
+ fft_plan->matmul_addrs.dft_im_matrix_addr =
+ (uint8_t *)fft_plan->reservespace_addr + per_dft_mat_size;
+ }; break;
+ default: {
+ break;
+ }
+ }
+ if (fftIsIntDtype(fft_plan->execution_dtype)) {
+ fft_plan->matmul_addrs.dft_pos_addr =
+ (uint8_t *)fft_plan->reservespace_addr + dft_mat_size;
+ fft_plan->matmul_addrs.dft_scale_addr =
+ (uint8_t *)fft_plan->matmul_addrs.dft_pos_addr + sizeof(int32_t);
+ fft_plan->matmul_addrs.dft_quantize_workspace_addr =
+ (uint8_t *)fft_plan->matmul_addrs.dft_scale_addr + sizeof(float);
+ fft_plan->matmul_addrs.dft_quantize_workspace_size =
+ fft_plan->reservespace_size - dft_mat_size - sizeof(int32_t) -
+ sizeof(float);
+ } else {
+ fft_plan->matmul_addrs.dft_pos_addr = nullptr;
+ fft_plan->matmul_addrs.dft_scale_addr = nullptr;
+ fft_plan->matmul_addrs.dft_quantize_workspace_addr = nullptr;
+ fft_plan->matmul_addrs.dft_quantize_workspace_size = 0;
+ }
+}
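+// Reservespace layout set up above (illustrative, matching the offsets
+// computed in configureIRFFT1dMatmulReserveAddrs):
+//   [dft_re_matrix | dft_im_matrix | dft_pos (int32_t) | dft_scale (float) |
+//    dft_quantize_workspace]
+// The pos/scale/quantize fields exist only for int execution dtypes.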
+
+mluOpStatus_t setIRFFT1dReserveArea(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan,
+ const std::string api) {
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ configureIRFFT1dMatmulReserveAddrs(handle, fft_plan);
+
+ mluOpDataType_t in_c_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_r_dtype = (in_c_dtype == MLUOP_DTYPE_COMPLEX_HALF)
+ ? MLUOP_DTYPE_HALF
+ : MLUOP_DTYPE_FLOAT;
+ mluOpDataType_t in_e_dtype = fft_plan->execution_dtype;
+ int n = fft_plan->n[0];
+ const int dft_mat_times = COMPLEX;
+
+ const unsigned int cluster_number =
+ mluop::runtime::getClusterLimitCapability(handle);
+ const unsigned int core_dim = handle->core_num_per_cluster;
+ cnrtDim3_t k_dim = {core_dim, cluster_number, 1};
+ cnrtFunctionType_t k_type = CNRT_FUNC_TYPE_BLOCK;
+
+ switch (fft_plan->fft_strategy) {
+ case CNFFT_FUNC_MATMUL: {
+ // Matmul Matrix : 2 * [n, (n / 2 + 1)]
+ int dim0 = n;
+ int dim1 = (n / 2 + 1);
+ int dft_mat_num = dft_mat_times * dim0 * dim1;
+ kernelGenerateIRFFTHalfDFTMatrix(k_dim, k_type, handle->queue, fft_plan,
+ in_r_dtype, n);
+ status = fftQuantizePositionScale(
+ handle, dft_mat_num, in_r_dtype, in_e_dtype,
+ fft_plan->matmul_addrs.dft_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.dft_quantize_workspace_addr,
+ fft_plan->matmul_addrs.dft_quantize_workspace_size, api);
+ INTERNAL_CHECK("[mluOpSetFFTReserveArea]",
+ status == MLUOP_STATUS_SUCCESS);
+ }; break;
+ case CNFFT_FUNC_COOLEY_TUKEY:
+ case CNFFT_FUNC_STOCKHAM: {
+ // Matmul Matrix : 2 * [L, L]
+ int L = fft_plan->L;
+ int dim0 = L;
+ int dim1 = L;
+ int dft_mat_num = dft_mat_times * dim0 * dim1;
+ kernelGenerateIRFFTFullDFTMatrix(k_dim, k_type, handle->queue, fft_plan,
+ in_r_dtype, L);
+
+ status = fftQuantizePositionScale(
+ handle, dft_mat_num, in_r_dtype, in_e_dtype,
+ fft_plan->matmul_addrs.dft_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.dft_quantize_workspace_addr,
+ fft_plan->matmul_addrs.dft_quantize_workspace_size, api);
+ INTERNAL_CHECK("[mluOpSetFFTReserveArea]",
+ status == MLUOP_STATUS_SUCCESS);
+ }; break;
+ default: {
+ status = MLUOP_STATUS_NOT_SUPPORTED;
+ }
+ }
+ return status;
+}
+
+static void configureIRFFT1dMatmulWorkspaceAddrs(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan,
+ void *input, void *workspace,
+ void *output) {
+ VLOG(5) << "Into configure IRFFT1d Matmul Workspace Addrs";
+ size_t workspace_cur_offset = 0;
+ size_t workspace_cur_offset_to_end = 0;
+ size_t workspace_total_size = fft_plan->workspace_size;
+ void *workspace_end = (uint8_t *)workspace + workspace_total_size;
+
+ mluOpDataType_t in_c_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_r_dtype = (in_c_dtype == MLUOP_DTYPE_COMPLEX_HALF)
+ ? MLUOP_DTYPE_HALF
+ : MLUOP_DTYPE_FLOAT;
+ mluOpDataType_t in_e_dtype = fft_plan->execution_dtype;
+ size_t in_c_dtype_size = mluOpDataTypeBytes(in_c_dtype);
+ size_t in_r_dtype_size = mluOpDataTypeBytes(in_r_dtype);
+ int batch = fft_plan->batch;
+ int n = fft_plan->n[0];
+
+ // input contiguous
+ size_t input_size = in_c_dtype_size * fft_plan->inum;
+ if (!fft_plan->is_input_contiguous) {
+ fft_plan->matmul_addrs.input_contiguous_addr =
+ (uint8_t *)workspace + workspace_cur_offset;
+ workspace_cur_offset += input_size;
+ } else {
+ fft_plan->matmul_addrs.input_contiguous_addr = input;
+ }
+
+ // input pad
+ bool need_pad = (fft_plan->inembed[0] != FFT_HALF(n));
+ int padded_input_num = batch * FFT_HALF(n);
+ size_t padded_input_size = in_c_dtype_size * padded_input_num;
+ if (need_pad) {
+ fft_plan->matmul_addrs.input_pad_addr =
+ (uint8_t *)workspace + workspace_cur_offset;
+ workspace_cur_offset += padded_input_size;
+ } else {
+ fft_plan->matmul_addrs.input_pad_addr =
+ fft_plan->matmul_addrs.input_contiguous_addr;
+ }
+
+ if (fft_plan->fft_strategy == CNFFT_FUNC_MATMUL) {
+ // input trans: batch * (n / 2 + 1) * 2 --> 2 * batch * (n / 2 + 1)
+ size_t transed_input_size = padded_input_size;
+ fft_plan->matmul_addrs.input_re_addr =
+ (uint8_t *)workspace + workspace_cur_offset;
+ fft_plan->matmul_addrs.input_im_addr =
+ (uint8_t *)fft_plan->matmul_addrs.input_re_addr +
+ transed_input_size / COMPLEX;
+ workspace_cur_offset += transed_input_size;
+ } else if (fft_plan->fft_strategy == CNFFT_FUNC_COOLEY_TUKEY ||
+ fft_plan->fft_strategy == CNFFT_FUNC_STOCKHAM) {
+ // input merge (transed_input and reversed_input reuse input_re)
+ // 1st input trans: batch * (n / 2 + 1) * 2 --> 2 * batch * (n / 2 + 1)
+ size_t transed_1st_input_size = padded_input_size;
+ fft_plan->matmul_addrs.input_transed_addr =
+ (uint8_t *)workspace + workspace_cur_offset;
+ workspace_cur_offset += transed_1st_input_size;
+
+ // input reverse(stridedslice)
+ int reversed_input_num = batch * (n - FFT_HALF(n));
+ size_t reversed_input_size = in_c_dtype_size * reversed_input_num;
+ fft_plan->matmul_addrs.input_reversed_addr =
+ (uint8_t *)workspace + workspace_cur_offset;
+ workspace_cur_offset += reversed_input_size;
+
+ // input merge
+ int merged_input_num = batch * n;
+ size_t merged_input_size = in_c_dtype_size * merged_input_num;
+ fft_plan->matmul_addrs.input_merged_addr =
+ (uint8_t *)workspace + workspace_cur_offset;
+ workspace_cur_offset += merged_input_size;
+
+ // input_re reuse transed_input and reversed_input
+ // 2nd input trans: 2 * batch * L * 2^m --> 2 * batch * 2^m * L
+ size_t transed_2nd_input_size = merged_input_size;
+ fft_plan->matmul_addrs.input_re_addr =
+ (uint8_t *)fft_plan->matmul_addrs.input_transed_addr;
+ fft_plan->matmul_addrs.input_im_addr =
+ (uint8_t *)fft_plan->matmul_addrs.input_re_addr +
+ transed_2nd_input_size / COMPLEX;
+ }
+
+ // input quantize
+ if (fftIsIntDtype(in_e_dtype)) {
+ fft_plan->matmul_addrs.input_pos_addr =
+ (uint8_t *)workspace + workspace_cur_offset;
+ workspace_cur_offset += sizeof(int32_t);
+ fft_plan->matmul_addrs.input_scale_addr =
+ (uint8_t *)workspace + workspace_cur_offset;
+ workspace_cur_offset += sizeof(float);
+ } else {
+ fft_plan->matmul_addrs.input_pos_addr = nullptr;
+ fft_plan->matmul_addrs.input_scale_addr = nullptr;
+ }
+
+ // internal workspace
+ workspace_cur_offset_to_end += fft_plan->matmul_addrs.internal_workspace_size;
+ fft_plan->matmul_addrs.internal_workspace_addr =
+ (uint8_t *)workspace_end - workspace_cur_offset_to_end;
+
+ // output contiguous
+ size_t output_size =
+ mluOpDataTypeBytes(fft_plan->output_dtype) * fft_plan->onum;
+ if (!fft_plan->is_output_contiguous) {
+ workspace_cur_offset_to_end += output_size;
+ fft_plan->matmul_addrs.output_contiguous_addr =
+ (uint8_t *)workspace_end - workspace_cur_offset_to_end;
+ } else {
+ fft_plan->matmul_addrs.output_contiguous_addr = output;
+ }
+
+ // matmul output
+ int per_matmul_output_num = batch * n;
+ size_t per_matmul_output_size = in_r_dtype_size * per_matmul_output_num;
+ if (fft_plan->fft_strategy == CNFFT_FUNC_MATMUL) {
+ // matmul_im_mul_im reuses output_contiguous
+ fft_plan->matmul_addrs.matmul_im_mul_im_addr =
+ fft_plan->matmul_addrs.output_contiguous_addr;
+ workspace_cur_offset_to_end += per_matmul_output_size;
+ fft_plan->matmul_addrs.matmul_re_mul_re_addr =
+ (uint8_t *)workspace_end - workspace_cur_offset_to_end;
+ } else if (fft_plan->fft_strategy == CNFFT_FUNC_COOLEY_TUKEY ||
+ fft_plan->fft_strategy == CNFFT_FUNC_STOCKHAM) {
+ workspace_cur_offset_to_end += per_matmul_output_size;
+ fft_plan->matmul_addrs.matmul_im_mul_im_addr =
+ (uint8_t *)workspace_end - workspace_cur_offset_to_end;
+ workspace_cur_offset_to_end += per_matmul_output_size;
+ fft_plan->matmul_addrs.matmul_im_mul_re_addr =
+ (uint8_t *)workspace_end - workspace_cur_offset_to_end;
+ workspace_cur_offset_to_end += per_matmul_output_size;
+ fft_plan->matmul_addrs.matmul_re_mul_im_addr =
+ (uint8_t *)workspace_end - workspace_cur_offset_to_end;
+ workspace_cur_offset_to_end += per_matmul_output_size;
+ fft_plan->matmul_addrs.matmul_re_mul_re_addr =
+ (uint8_t *)workspace_end - workspace_cur_offset_to_end;
+ }
+}
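+// Illustrative note: the function above carves the workspace from both ends:
+// offsets growing from the front hold the input-side buffers, while offsets
+// growing from the back (measured from workspace_end) hold the matmul
+// outputs, the optional non-contiguous output buffer, and the shared
+// internal workspace:
+//   front: input_contiguous | input_pad | transed / reversed / merged inputs
+//   back:  ... | matmul outputs | output_contiguous | internal_workspace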
+
+// input : in input
+// output : in input_contiguous_addr
+static mluOpStatus_t makeIRFFT1dContiguousInput(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan,
+ const void *input) {
+ std::string api = "[mluOpExecFFT]";
+ VLOG(5) << "into makeIRFFT1dContiguousInput";
+ auto status = MLUOP_STATUS_SUCCESS;
+ if (!fft_plan->is_input_contiguous) {
+ VLOG(5) << "launch mluOpContiguous for irfft1d input";
+ mluOpTensorDescriptor_t input_desc;
+ status = mluOpCreateTensorDescriptor(&input_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ const int in_dim_num = 2;
+ int64_t dims[in_dim_num] = {fft_plan->batch, fft_plan->inembed[0]};
+ int64_t strides[in_dim_num] = {fft_plan->idist, fft_plan->istride};
+ status = mluOpSetTensorDescriptorEx_v2(input_desc, MLUOP_LAYOUT_ARRAY,
+ fft_plan->input_dtype, in_dim_num,
+ dims, strides);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = mluOpContiguous(handle, input_desc, input,
+ fft_plan->matmul_addrs.input_contiguous_addr);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = mluOpDestroyTensorDescriptor(input_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ }
+ return status;
+}
+
+// input : in input_contiguous_addr
+// output : in input_pad_addr
+static mluOpStatus_t padIRFFT1dContiguousInput(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan) {
+ std::string api = "[mluOpExecFFT]";
+ VLOG(5) << "into padIRFFT1dContiguousInput";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+
+ mluOpDataType_t in_c_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_r_dtype = (in_c_dtype == MLUOP_DTYPE_COMPLEX_HALF)
+ ? MLUOP_DTYPE_HALF
+ : MLUOP_DTYPE_FLOAT;
+ int batch = fft_plan->batch;
+ int n = fft_plan->n[0];
+ bool need_pad = (fft_plan->inembed[0] != FFT_HALF(n));
+ if (need_pad) {
+ VLOG(5) << "launch cnnlOpPad for input pad";
+ mluOpTensorDescriptor_t input_desc, padded_input_desc;
+ status = mluOpCreateTensorDescriptor(&input_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpCreateTensorDescriptor(&padded_input_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ const int in_dim_num = 2;
+ int64_t dims[in_dim_num] = {batch, fft_plan->inembed[0] * COMPLEX};
+ status = mluOpSetTensorDescriptor_v2(input_desc, MLUOP_LAYOUT_ARRAY,
+ in_r_dtype, in_dim_num, dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ int64_t padded_dims[in_dim_num] = {batch, FFT_HALF(n) * COMPLEX};
+ status = mluOpSetTensorDescriptor_v2(padded_input_desc, MLUOP_LAYOUT_ARRAY,
+ in_r_dtype, in_dim_num, padded_dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ const int pad_dim_num = 4;
+ int paddings[pad_dim_num] = {
+ 0, 0, 0, (FFT_HALF(n) - fft_plan->inembed[0]) * COMPLEX};
+ uint64_t padding_value = 0x00000000;
+
+ DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
+ cnnl_handle); // convert to cnnl_handle
+
+ // convert to cnnl_tensor_descriptor
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(input_desc, cnnl_input_desc);
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(padded_input_desc,
+ cnnl_padded_input_desc);
+ CALL_CNNL(cnnlPad(cnnl_handle, cnnl_input_desc,
+ fft_plan->matmul_addrs.input_contiguous_addr, paddings,
+ &padding_value, cnnl_padded_input_desc,
+ fft_plan->matmul_addrs.input_pad_addr));
+
+ // destroy cnnl descriptor
+ VLOG(5) << "irfft cnnlOpPad end";
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_input_desc);
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_padded_input_desc);
+
+ DESTROY_CNNL_HANDLE(cnnl_handle);
+ }
+ return status;
+}
+
+/* only for CNFFT_FUNC_COOLEY_TUKEY and CNFFT_FUNC_STOCKHAM:
+ -------------------------
+ | input_pad |
+ -------------------------
+ |
+ | input trans: batch * (n / 2 + 1) * 2 --> 2 * batch * (n / 2 + 1)
+ \|/
+ -------------------------
+ | input_transed_re |
+ | input_transed_im |
+ -------------------------
+ |
+ | stridedslice
+ | optensor(im mul -1)
+ \|/
+ -------------------------
+ | input_reversed_re |
+ | input_reversed_im |
+ -------------------------
+ |
+ | concat
+ \|/
+ -------------------------
+ | input_merged_re |
+ | input_merged_im |
+ -------------------------
+*/
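+/* Worked example of the merge (even n = 8): the stored half spectrum has
+   n / 2 + 1 = 5 bins [X0, X1, X2, X3, X4]; Hermitian symmetry supplies the
+   missing bins as conjugates of the mirrored interior bins, so the full
+   spectrum becomes
+     [X0, X1, X2, X3, X4, conj(X3), conj(X2), conj(X1)].
+   stridedslice extracts the mirrored slice, the optensor negates its
+   imaginary part (the conjugate), and concat appends it. */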
+static mluOpStatus_t mergeIRFFT1dInput(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan) {
+ std::string api = "[mluOpExecFFT]";
+ VLOG(5) << "into mergeIRFFT1dInput";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ if (fft_plan->fft_strategy == CNFFT_FUNC_COOLEY_TUKEY ||
+ fft_plan->fft_strategy == CNFFT_FUNC_STOCKHAM) {
+ mluOpDataType_t in_c_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_r_dtype = (in_c_dtype == MLUOP_DTYPE_COMPLEX_HALF)
+ ? MLUOP_DTYPE_HALF
+ : MLUOP_DTYPE_FLOAT;
+ size_t in_r_dtype_size = mluOpDataTypeBytes(in_r_dtype);
+ int batch = fft_plan->batch;
+ int n = fft_plan->n[0];
+
+ // 1st transpose: batch * (n / 2 + 1) * 2 --> 2 * batch * (n / 2 + 1)
+ VLOG(5) << "launch mluOpTranspose for input";
+ int padded_input_num = batch * FFT_HALF(n);
+ const int trans_dim_num = 2;
+ int trans_input_dims[trans_dim_num] = {padded_input_num, COMPLEX};
+ int trans_output_dims[trans_dim_num] = {COMPLEX, padded_input_num};
+ int trans_permute[trans_dim_num] = {1, 0};
+
+ status =
+ fftTranspose(handle, trans_dim_num, trans_input_dims, trans_output_dims,
+ trans_permute, fft_plan->matmul_addrs.input_pad_addr,
+ fft_plan->matmul_addrs.input_transed_addr, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // stridedslice (odd n): [a, b, c, d] --> [d, c, b]
+ // stridedslice (even n): [a, b, c, d, e] --> [d, c, b]
+ VLOG(5) << "launch mluOpStridedSlice for input";
+ mluOpTensorDescriptor_t ss_input_desc, ss_output_desc;
+ status = mluOpCreateTensorDescriptor(&ss_input_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpCreateTensorDescriptor(&ss_output_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ const int ss_dim_num = 2;
+ int64_t ss_in_dims[ss_dim_num] = {COMPLEX * batch, FFT_HALF(n)};
+ status = mluOpSetTensorDescriptor_v2(ss_input_desc, MLUOP_LAYOUT_ARRAY,
+ in_r_dtype, ss_dim_num, ss_in_dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ int64_t ss_out_dims[ss_dim_num] = {COMPLEX * batch, (n - FFT_HALF(n))};
+ status = mluOpSetTensorDescriptor_v2(ss_output_desc, MLUOP_LAYOUT_ARRAY,
+ in_r_dtype, ss_dim_num, ss_out_dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ int dim1_begin = (n % 2) ? -1 : -2;
+ int dim1_end = -FFT_HALF(n);
+ int begin[ss_dim_num] = {0, dim1_begin};
+ int end[ss_dim_num] = {COMPLEX * batch, dim1_end};
+ int stride[ss_dim_num] = {1, -1};
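+ // Index note: for even n the last stored bin is the (real) Nyquist bin
+ // and is not mirrored, so the reversed slice starts at -2; for odd n
+ // there is no Nyquist bin and it starts at -1. Both cases walk back to
+ // index 1 (the DC bin is never mirrored), yielding n - (n / 2 + 1)
+ // elements per batch.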
+
+ void *ss_input_addr = fft_plan->matmul_addrs.input_transed_addr;
+ void *ss_output_addr = fft_plan->matmul_addrs.input_reversed_addr;
+
+ DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
+ cnnl_handle); // convert to cnnl_handle
+
+ // convert to cnnl_tensor_descriptor
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(ss_input_desc,
+ cnnl_ss_input_desc);
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(ss_output_desc,
+ cnnl_ss_output_desc);
+ CALL_CNNL(cnnlStridedSlice(cnnl_handle, cnnl_ss_input_desc, ss_input_addr,
+ begin, end, stride, cnnl_ss_output_desc,
+ ss_output_addr));
+
+ // reversed input imag part mul -1
+ int reversed_input_num = batch * (n - FFT_HALF(n));
+ void *input_reversed_re_addr =
+ (uint8_t *)fft_plan->matmul_addrs.input_reversed_addr;
+ void *input_reversed_im_addr =
+ (uint8_t *)fft_plan->matmul_addrs.input_reversed_addr +
+ in_r_dtype_size * reversed_input_num;
+
+ status = fftOptensor(handle, reversed_input_num, input_reversed_im_addr,
+ input_reversed_re_addr, input_reversed_im_addr, -1.0,
+ 0.0, 0.0, in_r_dtype, CNNL_OP_TENSOR_ADD,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+
+ // concat (odd n): [a, b, c, d] + [d, c, b] --> [a, b, c, d, d, c, b]
+ // concat (even n): [a, b, c, d, e] + [d, c, b] --> [a, b, c, d, e, d, c, b]
+ VLOG(5) << "launch mluOpConcat for input";
+ mluOpTensorDescriptor_t concat_output_desc;
+ status = mluOpCreateTensorDescriptor(&concat_output_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ const int concat_dim_num = 2;
+ int64_t concat_out_dims[concat_dim_num] = {COMPLEX * batch, n};
+ status = mluOpSetTensorDescriptor_v2(concat_output_desc, MLUOP_LAYOUT_ARRAY,
+ in_r_dtype, concat_dim_num,
+ concat_out_dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // convert to cnnl_tensor_descriptor
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(concat_output_desc,
+ cnnl_concat_output_desc);
+ const int concat_num = 2;
+ const int concat_axis = -1;
+ cnnlTensorDescriptor_t concat_in_descs[concat_num] = {cnnl_ss_input_desc,
+ cnnl_ss_output_desc};
+
+ void *concat_in_addrs[concat_num] = {ss_input_addr, ss_output_addr};
+ void *concat_out_addr = fft_plan->matmul_addrs.input_merged_addr;
+ CALL_CNNL(cnnlConcat(cnnl_handle, concat_num, concat_axis, concat_in_descs,
+ concat_in_addrs,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size,
+ cnnl_concat_output_desc, concat_out_addr));
+ VLOG(5) << "launch mluOpConcat end";
+
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_ss_input_desc);
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_ss_output_desc);
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_concat_output_desc);
+
+ DESTROY_CNNL_HANDLE(cnnl_handle);
+ }
+ return status;
+}
+
+/* CNFFT_FUNC_MATMUL:
+ -------------------------
+ | input_pad |
+ -------------------------
+ |
+ | input trans: batch * (n / 2 + 1) * 2 --> 2 * batch * (n / 2 + 1)
+ \|/
+ -------------------------
+ | input_re |
+ | input_im |
+ -------------------------
+
+ CNFFT_FUNC_COOLEY_TUKEY:
+ -------------------------
+ | input_merged_re |
+ | input_merged_im |
+ -------------------------
+ |
+ | input trans: 2 * batch * L * 2^m --> 2 * batch * 2^m * L
+ \|/
+ -------------------------
+ | input_re |
+ | input_im |
+ -------------------------
+*/
+static mluOpStatus_t transposeIRFFT1dPaddedInput(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan) {
+ std::string api = "[mluOpExecFFT]";
+ VLOG(5) << "into transposeIRFFT1dPaddedInput";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+
+ mluOpDataType_t in_c_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_r_dtype = (in_c_dtype == MLUOP_DTYPE_COMPLEX_HALF)
+ ? MLUOP_DTYPE_HALF
+ : MLUOP_DTYPE_FLOAT;
+ int batch = fft_plan->batch;
+ int n = fft_plan->n[0];
+
+ if (fft_plan->fft_strategy == CNFFT_FUNC_MATMUL) {
+ // transpose: batch * (n / 2 + 1) * 2 --> 2 * batch * (n / 2 + 1)
+ VLOG(5) << "launch mluOpTranspose for input MATMUL";
+ int padded_input_num = batch * FFT_HALF(n);
+ const int trans_dim_num = 2;
+ int trans_input_dims[trans_dim_num] = {padded_input_num, COMPLEX};
+ int trans_output_dims[trans_dim_num] = {COMPLEX, padded_input_num};
+ int trans_permute[trans_dim_num] = {1, 0};
+
+ status =
+ fftTranspose(handle, trans_dim_num, trans_input_dims, trans_output_dims,
+ trans_permute, fft_plan->matmul_addrs.input_pad_addr,
+ fft_plan->matmul_addrs.input_re_addr, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ } else if (fft_plan->fft_strategy == CNFFT_FUNC_COOLEY_TUKEY) {
+ VLOG(5) << "launch mluOpTranspose for input COOLEY_TUKEY";
+ int L = fft_plan->L;
+ int m = (1 << fft_plan->m);
+
+ // 2nd transpose: 2 * batch * L * 2^m --> 2 * batch * 2^m * L
+ const int trans_dim_num = 3;
+ int trans_input_dims[trans_dim_num] = {COMPLEX * batch, L, m};
+ int trans_output_dims[trans_dim_num] = {COMPLEX * batch, m, L};
+ int trans_permute[trans_dim_num] = {0, 2, 1};
+
+ status =
+ fftTranspose(handle, trans_dim_num, trans_input_dims, trans_output_dims,
+ trans_permute, fft_plan->matmul_addrs.input_merged_addr,
+ fft_plan->matmul_addrs.input_re_addr, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ }
+ return status;
+}
+
+// input : in input_pad_addr
+// output : in input_pos_addr and input_scale_addr
+static mluOpStatus_t quantizeIRFFT1dPaddedInput(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan) {
+ std::string api = "[mluOpExecFFT]";
+ VLOG(5) << "into quantizeIRFFT1dPaddedInput";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+
+ mluOpDataType_t in_c_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_r_dtype = (in_c_dtype == MLUOP_DTYPE_COMPLEX_HALF)
+ ? MLUOP_DTYPE_HALF
+ : MLUOP_DTYPE_FLOAT;
+ mluOpDataType_t in_e_dtype = fft_plan->execution_dtype;
+ int padded_input_num = fft_plan->batch * FFT_HALF(fft_plan->n[0]);
+
+ status = fftQuantizePositionScale(
+ handle, COMPLEX * padded_input_num, in_r_dtype, in_e_dtype,
+ fft_plan->matmul_addrs.input_pad_addr,
+ fft_plan->matmul_addrs.input_pos_addr,
+ fft_plan->matmul_addrs.input_scale_addr,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+
+ return status;
+}
+
+/* CNFFT_FUNC_MATMUL:
+ -------------------------
+ | input_re |
+ | input_im |
+ -------------------------
+ |
+ | matmul
+ \|/
+ -------------------------
+ | matmul_re_mul_re |
+ | matmul_im_mul_im |(reuse output_contiguous)
+ -------------------------
+ |
+ | op_tensor
+ \|/
+ -------------------------
+ | output_contiguous |
+ -------------------------
+
+ CNFFT_FUNC_COOLEY_TUKEY:
+ -------------------------
+ | input_re |
+ | input_im |
+ -------------------------
+ |
+ | matmul
+ | optensor(re_mul_re - im_mul_im, re_mul_im + im_mul_re)
+ \|/
+ -------------------------
+ | matmul_re_mul_re | (matmul_re)
+ | matmul_re_mul_im | (matmul_im)
+ | matmul_im_mul_re |
+ | matmul_im_mul_im |
+ -------------------------
+*/
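+/* Illustrative sketch of the MATMUL branch below: with the conjugate-symmetry
+   and sign factors folded into the pre-generated DFT matrices, the real
+   output reduces per batch to
+     out = dft_re_matrix @ in_re + dft_im_matrix @ in_im,
+   i.e. two quantized GEMMs followed by a single optensor add. */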
+static mluOpStatus_t computeIRFFT1dMatmulResult(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan,
+ const float scale_factor) {
+ std::string api = "[mluOpExecFFT]";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+
+ mluOpDataType_t in_c_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_r_dtype = (in_c_dtype == MLUOP_DTYPE_COMPLEX_HALF)
+ ? MLUOP_DTYPE_HALF
+ : MLUOP_DTYPE_FLOAT;
+ mluOpDataType_t in_e_dtype = fft_plan->execution_dtype;
+ int batch = fft_plan->batch;
+ int n = fft_plan->n[0];
+
+ if (fft_plan->fft_strategy == CNFFT_FUNC_MATMUL) {
+ VLOG(5) << "into computeIRFFT1dMatmulResult CNFFT_FUNC_MATMUL";
+ // input real matmul dft real
+ status = fftQuantMatMul(
+ handle, batch, FFT_HALF(n), n, fft_plan->matmul_addrs.input_re_addr,
+ fft_plan->matmul_addrs.input_pos_addr,
+ fft_plan->matmul_addrs.input_scale_addr,
+ fft_plan->matmul_addrs.dft_re_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.matmul_re_mul_re_addr, false, true, scale_factor,
+ 0.0, in_e_dtype, in_e_dtype, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // input imag matmul dft imag
+ status = fftQuantMatMul(
+ handle, batch, FFT_HALF(n), n, fft_plan->matmul_addrs.input_im_addr,
+ fft_plan->matmul_addrs.input_pos_addr,
+ fft_plan->matmul_addrs.input_scale_addr,
+ fft_plan->matmul_addrs.dft_im_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.matmul_im_mul_im_addr, false, true, scale_factor,
+ 0.0, in_e_dtype, in_e_dtype, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // real mul real add imag mul imag
+ int per_matmul_output_num = batch * n;
+ status = fftOptensor(handle, per_matmul_output_num,
+ fft_plan->matmul_addrs.matmul_re_mul_re_addr,
+ fft_plan->matmul_addrs.matmul_im_mul_im_addr,
+ fft_plan->matmul_addrs.output_contiguous_addr, 1.0,
+ 1.0, 0.0, in_r_dtype, CNNL_OP_TENSOR_ADD,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ } else if (fft_plan->fft_strategy == CNFFT_FUNC_COOLEY_TUKEY) {
+ VLOG(5) << "into computeIRFFT1dMatmulResult CNFFT_FUNC_COOLEY_TUKEY";
+ int L = fft_plan->L;
+ int m = (1 << fft_plan->m);
+
+ // input real matmul dft real
+ status = fftQuantMatMul(
+ handle, batch * m, L, L, fft_plan->matmul_addrs.input_re_addr,
+ fft_plan->matmul_addrs.input_pos_addr,
+ fft_plan->matmul_addrs.input_scale_addr,
+ fft_plan->matmul_addrs.dft_re_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.matmul_re_mul_re_addr, false, true, scale_factor,
+ 0.0, in_e_dtype, in_e_dtype, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // input imag matmul dft imag
+ status = fftQuantMatMul(
+ handle, batch * m, L, L, fft_plan->matmul_addrs.input_im_addr,
+ fft_plan->matmul_addrs.input_pos_addr,
+ fft_plan->matmul_addrs.input_scale_addr,
+ fft_plan->matmul_addrs.dft_im_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.matmul_im_mul_im_addr, false, true, scale_factor,
+ 0.0, in_e_dtype, in_e_dtype, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // input real matmul dft imag
+ status = fftQuantMatMul(
+ handle, batch * m, L, L, fft_plan->matmul_addrs.input_re_addr,
+ fft_plan->matmul_addrs.input_pos_addr,
+ fft_plan->matmul_addrs.input_scale_addr,
+ fft_plan->matmul_addrs.dft_im_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.matmul_re_mul_im_addr, false, true, scale_factor,
+ 0.0, in_e_dtype, in_e_dtype, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // input imag matmul dft real
+ status = fftQuantMatMul(
+ handle, batch * m, L, L, fft_plan->matmul_addrs.input_im_addr,
+ fft_plan->matmul_addrs.input_pos_addr,
+ fft_plan->matmul_addrs.input_scale_addr,
+ fft_plan->matmul_addrs.dft_re_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.matmul_im_mul_re_addr, false, true, scale_factor,
+ 0.0, in_e_dtype, in_e_dtype, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ } else if (fft_plan->fft_strategy == CNFFT_FUNC_STOCKHAM) {
+ int L = fft_plan->L;
+ int m = (1 << fft_plan->m);
+
+ // W[2 * L, L] * in[batch * 2, L, 2^m] -> out[batch, 2, 2, L, 2^m]
+ status = fftBatchMatMulBcast(
+ handle, 2 * L, L, m, batch * 2,
+ fft_plan->matmul_addrs.dft_re_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.input_merged_addr,
+ fft_plan->matmul_addrs.input_pos_addr,
+ fft_plan->matmul_addrs.input_scale_addr,
+ fft_plan->matmul_addrs.matmul_re_mul_re_addr, false, false,
+ scale_factor, 0.0, in_e_dtype, in_e_dtype, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ }
+
+ return status;
+}
+
+static mluOpStatus_t policyFunc(mluOpHandle_t handle, cnrtDim3_t *k_dim,
+ cnrtFunctionType_t *k_type) {
+ *k_type = CNRT_FUNC_TYPE_UNION1;
+ k_dim->x = handle->core_num_per_cluster;
+ k_dim->y = mluop::runtime::getClusterLimitCapability(handle);
+ k_dim->z = 1;
+ return MLUOP_STATUS_SUCCESS;
+}
+
+// only for CNFFT_FUNC_COOLEY_TUKEY and CNFFT_FUNC_STOCKHAM
+// input : matmul real result in matmul_re_mul_re_addr
+// matmul imag result in matmul_re_mul_im_addr
+// workspace: internal_workspace_addr
+// output : output real result in output_contiguous_addr
+mluOpStatus_t mergeIRFFT1dOutput(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
+ const float scale_factor) {
+ std::string api = "[mluOpExecFFT]";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ VLOG(5) << "launch merge irfft1d output";
+ if (fft_plan->fft_strategy == CNFFT_FUNC_COOLEY_TUKEY) {
+ int core_num = handle->core_num_per_cluster;
+ cnrtFunctionType_t k_type = CNRT_FUNC_TYPE_UNION1;
+ int task_type = mluop::runtime::getJobLimitCapability(handle);
+ int task_num = 1;
+
+ switch (task_type) {
+ default:
+ task_num = core_num;
+ break;
+ case (int)CNRT_FUNC_TYPE_UNION2:
+ task_num = core_num * 2;
+ break;
+ case (int)CNRT_FUNC_TYPE_UNION4:
+ task_num = core_num * 4;
+ break;
+ case (int)CNRT_FUNC_TYPE_UNION8:
+ task_num = core_num * 8;
+ break;
+ case (int)CNRT_FUNC_TYPE_UNION16:
+ task_num = core_num * 16;
+ break;
+ }
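+ // task_num maps the job-limit capability to a core count: with core_num
+ // cores per cluster, UNION2 --> 2 * core_num, UNION4 --> 4 * core_num,
+ // and so on; the default branch keeps one cluster's worth of cores.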
+
+ unsigned int dimx = task_num;
+ cnrtDim3_t k_dim = {dimx, 1, 1};
+ k_type = (cnrtFunctionType_t)dimx;
+ kernelFFTCooleyTukey(k_dim, k_type, handle->queue, fft_plan, -1, IRFFT);
+ // direction: -1 means invalid (only used by FFT/IFFT)
+ } else if (fft_plan->fft_strategy == CNFFT_FUNC_STOCKHAM) {
+ cnrtDim3_t k_dim;
+ cnrtFunctionType_t k_type;
+ policyFunc(handle, &k_dim, &k_type);
+ kernelFFTStockham(k_dim, k_type, handle->queue, fft_plan, -1, scale_factor,
+ IRFFT);
+ // direction: -1 means invalid (only used by FFT/IFFT).
+ }
+ return status;
+}
+
+// input : in output_contiguous_addr
+// output : in output
+static mluOpStatus_t makeIRFFT1dContiguousOutput(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan,
+ void *output) {
+ std::string api = "[mluOpExecFFT]";
+ VLOG(5) << "into makeIRFFT1dContiguousOutput";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ if (!fft_plan->is_output_contiguous) {
+ VLOG(5) << "launch copy with stride";
+ mluOpDataType_t out_r_dtype = fft_plan->output_dtype;
+ // create tensor desc
+ mluOpTensorDescriptor_t copy_src_desc, copy_dst_desc;
+ status = mluOpCreateTensorDescriptor(&copy_src_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpCreateTensorDescriptor(&copy_dst_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // set up tensor desc
+ const int out_dim_num = 2;
+ int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->onembed[0]};
+ int64_t strides[out_dim_num] = {fft_plan->odist, fft_plan->ostride};
+ status = mluOpSetTensorDescriptor_v2(copy_src_desc, MLUOP_LAYOUT_ARRAY,
+ out_r_dtype, out_dim_num, dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status =
+ mluOpSetTensorDescriptorEx_v2(copy_dst_desc, MLUOP_LAYOUT_ARRAY,
+ out_r_dtype, out_dim_num, dims, strides);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // copy
+ void *copy_src_addr = fft_plan->matmul_addrs.output_contiguous_addr;
+ DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
+ cnnl_handle); // convert to cnnl_handle
+ // convert to cnnl_tensor_descriptor
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(copy_src_desc,
+ cnnl_copy_src_desc);
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(copy_dst_desc,
+ cnnl_copy_dst_desc);
+
+ CALL_CNNL(cnnlCopy(cnnl_handle, cnnl_copy_src_desc, copy_src_addr,
+ cnnl_copy_dst_desc, output));
+
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_copy_src_desc);
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_copy_dst_desc);
+ DESTROY_CNNL_HANDLE(cnnl_handle);
+ }
+ return status;
+}
+
+mluOpStatus_t execIRFFT1d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan,
+ const void *input, const float scale_factor,
+ void *workspace, void *output) {
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+
+ std::string api = "[mluOpExecFFT]";
+ configureIRFFT1dMatmulWorkspaceAddrs(handle, fft_plan, (void *)input,
+ workspace, output);
+
+ status = makeIRFFT1dContiguousInput(handle, fft_plan, input);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = padIRFFT1dContiguousInput(handle, fft_plan);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = mergeIRFFT1dInput(handle, fft_plan);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = transposeIRFFT1dPaddedInput(handle, fft_plan);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = quantizeIRFFT1dPaddedInput(handle, fft_plan);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = computeIRFFT1dMatmulResult(handle, fft_plan, scale_factor);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = mergeIRFFT1dOutput(handle, fft_plan, scale_factor);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = makeIRFFT1dContiguousOutput(handle, fft_plan, output);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ return status;
+}
diff --git a/kernels/fft/rfft/rfft.h b/kernels/fft/rfft/rfft.h
new file mode 100644
index 000000000..7754e2c1e
--- /dev/null
+++ b/kernels/fft/rfft/rfft.h
@@ -0,0 +1,39 @@
+/*************************************************************************
+ * Copyright (C) [2024] by Cambricon, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+#ifndef KERNELS_FFT_RFFT_RFFT_H_
+#define KERNELS_FFT_RFFT_RFFT_H_
+
+#include <string>  // for std::string (assumed include; bracketed name lost)
+#include "kernels/fft/fft.h"
+
+mluOpStatus_t makeRFFT1dPolicy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan);
+
+mluOpStatus_t setRFFT1dReserveArea(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan,
+ const std::string api);
+
+mluOpStatus_t execRFFT1d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan,
+ const void *input, const float scale_factor,
+ void *workspace, void *output);
+
+#endif // KERNELS_FFT_RFFT_RFFT_H_
diff --git a/kernels/fft/rfft/rfft_host.cpp b/kernels/fft/rfft/rfft_host.cpp
new file mode 100644
index 000000000..afe42e7f7
--- /dev/null
+++ b/kernels/fft/rfft/rfft_host.cpp
@@ -0,0 +1,917 @@
+/*************************************************************************
+ * Copyright (C) [2024] by Cambricon, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+
+#include "kernels/fft/rfft/rfft.h"
+#include <algorithm>  // for std::max (assumed include; bracketed name lost)
+#include <string>     // for std::string (assumed include; bracketed name lost)
+
+static mluOpStatus_t selectRFFT1dStrategy(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan) {
+ const std::string make_plan_api = "[selectRFFT1dStrategy]";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ /* There are plenty of FFT algorithms to choose from, depending on the FFT
+ * length.
+ * Iterative FFT:
+ *   Stockham FFT, Cooley-Tukey FFT, Pease FFT, Korn-Lambiotte FFT
+ * Recursive FFT:
+ *   recursive Cooley-Tukey FFT, four-step FFT, six-step FFT, multicore FFT,
+ *   SIMD short-vector FFT.
+ * General FFT: Bluestein's (chirp-z) FFT.
+ */
+ // select Four-Step FFT or MATMUL strategy logic
+ fft_plan->fft_strategy = CNFFT_FUNC_MATMUL;
+ status = selectFFTStrategy(handle, fft_plan, make_plan_api);
+ return status;
+}
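+// Illustrative note: both four-step strategies factor n as n = L * 2^m
+// (cf. the "n = L * 2^m and L > 4096" checks in makeRFFT1dPolicy); e.g.
+// n = 3072 gives L = 3, m = 10, so the dense DFT matmul only ever sees
+// [L, L] matrices while the 2^m factor is handled by the merge kernels.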
+
+/*
+ * Make the policy of RFFT1d.
+ */
+mluOpStatus_t makeRFFT1dPolicy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan) {
+ std::string api = "[mluOpMakeFFTPlanMany]";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ INTERNAL_CHECK(
+ api, selectRFFT1dStrategy(handle, fft_plan) == MLUOP_STATUS_SUCCESS);
+
+ mluOpDataType_t in_r_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_e_dtype = fft_plan->execution_dtype;
+ size_t in_r_dtype_size = mluOpDataTypeBytes(in_r_dtype);
+ int batch = fft_plan->batch;
+ int n = fft_plan->n[0];
+
+ switch (fft_plan->fft_strategy) {
+ case CNFFT_FUNC_MATMUL: {
+ if (n > FFT_L_LIMIT) {
+ LOG(ERROR) << "[mluOpMakeFFTPlanMany]: RFFT1d CNFFT_FUNC_MATMUL "
+ << "length > 4096 is not supported currently.";
+ return MLUOP_STATUS_NOT_SUPPORTED;
+ }
+
+ // Matmul Input : [batch, n]
+ // Matmul Matrix : [(n / 2 + 1), 2, n]
+ // Matmul Result : [batch, (n / 2 + 1), 2]
+ int dim0 = FFT_HALF(n);
+ int dim1 = COMPLEX; // complex
+ int dim2 = n;
+ int dft_mat_num = dim0 * dim1 * dim2;
+
+ // reservespace size allocation
+ fft_plan->reservespace_size = 0;
+ fft_plan->reservespace_size += dft_mat_num * in_r_dtype_size;
+ if (fftIsIntDtype(in_e_dtype)) {
+ fft_plan->reservespace_size += sizeof(int32_t) + sizeof(float);
+ size_t required_size = 0;
+ status = fftGetQuantizeParamWorkspaceSize(
+ handle, required_size, dft_mat_num, in_r_dtype, in_e_dtype, api);
+ fft_plan->reservespace_size += required_size;
+ }
+
+ /* CNFFT_FUNC_MATMUL :
+ -------------------------
+ | input |
+ -------------------------
+ |
+ | input contiguous
+ \|/
+ -------------------------
+ | input_contiguous |
+ -------------------------
+ |
+ | input pad
+ \|/
+ -------------------------
+ | input_pad |
+ -------------------------
+ |
+ | matmul
+ \|/
+ -------------------------
+ | output_contiguous |
+ -------------------------
+ |
+ | output contiguous
+ \|/
+ -------------------------
+ | output |
+ -------------------------
+ */
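+ // Illustrative shape check for the single matmul above (trans_b = true):
+ //   input_pad [batch, n] x dft_matrix [(n / 2 + 1) * 2, n]^T
+ //     --> output [batch, (n / 2 + 1) * 2],
+ // i.e. one GEMM emits the interleaved (re, im) pairs directly, so this
+ // path needs no transpose or merge step.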
+ // workspace size allocation
+ fft_plan->matmul_addrs.internal_workspace_size = 0;
+ fft_plan->workspace_size = 0;
+
+ // input contiguous
+ size_t input_size = in_r_dtype_size * fft_plan->inum;
+ fft_plan->workspace_size +=
+ fft_plan->is_input_contiguous ? 0 : input_size;
+
+ // input pad
+ bool need_pad = (fft_plan->inembed[0] != n);
+ int padded_input_num = batch * n;
+ size_t padded_input_size = in_r_dtype_size * padded_input_num;
+ fft_plan->workspace_size += need_pad ? padded_input_size : 0;
+
+ // input quantize param and workspace
+ if (fftIsIntDtype(in_e_dtype)) {
+ fft_plan->workspace_size += sizeof(int32_t) + sizeof(float);
+ size_t input_quant_workspace_size = 0;
+ status = fftGetQuantizeParamWorkspaceSize(
+ handle, input_quant_workspace_size, padded_input_num, in_r_dtype,
+ in_e_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size =
+ std::max(fft_plan->matmul_addrs.internal_workspace_size,
+ input_quant_workspace_size);
+ }
+
+ // matmul workspace
+ size_t matmul_workspace_size = 0;
+ status = fftGetQuantizeMatMulWorkspaceSize(
+ handle, matmul_workspace_size, batch, dim2, dim0 * dim1, false, true,
+ in_e_dtype, in_e_dtype, in_r_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size =
+ std::max(fft_plan->matmul_addrs.internal_workspace_size,
+ matmul_workspace_size);
+
+ // output contiguous
+ int padded_output_num = batch * FFT_HALF(n);
+ size_t padded_output_size =
+ mluOpDataTypeBytes(fft_plan->output_dtype) * padded_output_num;
+ fft_plan->workspace_size +=
+ fft_plan->is_output_contiguous ? 0 : padded_output_size;
+
+ // internal_workspace
+ fft_plan->workspace_size +=
+ fft_plan->matmul_addrs.internal_workspace_size;
+ VLOG(5) << "internal workspace size: "
+ << fft_plan->matmul_addrs.internal_workspace_size;
+ VLOG(5) << "total workspace size: " << fft_plan->workspace_size;
+ }; break;
+ case CNFFT_FUNC_COOLEY_TUKEY:
+ case CNFFT_FUNC_STOCKHAM: {
+ int L = fft_plan->L;
+ int m = (1 << fft_plan->m);
+ if (L > FFT_L_LIMIT) {
+ LOG(ERROR) << "[mluOpMakeFFTPlanMany]: RFFT1d CNFFT_FUNC_COOLEY_TUKEY "
+ << "n = L * 2^m and L > 4096 is not supported currently.";
+ return MLUOP_STATUS_NOT_SUPPORTED;
+ }
+
+ // Matmul Input : [batch, 2^m, L]
+ // Matmul Matrix : 2 * [L, L]
+ // Matmul Result : 2 * [batch, 2^m, L]
+ int dft_mat_times = COMPLEX;
+ int dim0 = L;
+ int dim1 = L;
+ int dft_mat_num = dft_mat_times * dim0 * dim1;
+
+ // reservespace size allocation
+ fft_plan->reservespace_size = 0;
+ fft_plan->reservespace_size += dft_mat_num * in_r_dtype_size;
+ if (fftIsIntDtype(in_e_dtype)) {
+ fft_plan->reservespace_size += sizeof(int32_t) + sizeof(float);
+ size_t required_size = 0;
+ status = fftGetQuantizeParamWorkspaceSize(
+ handle, required_size, dft_mat_num, in_r_dtype, in_e_dtype, api);
+ fft_plan->reservespace_size += required_size;
+ }
+
+ /* CNFFT_FUNC_COOLEY_TUKEY :
+ -------------------------
+ | input |
+ -------------------------
+ |
+ | input contiguous
+ \|/
+ -------------------------
+ | input_contiguous |
+ -------------------------
+ |
+ | input pad
+ \|/
+ -------------------------
+ | input_pad |
+ -------------------------
+ |
+ | input trans: batch * L * 2^m --> batch * 2^m * L
+ \|/
+ -------------------------
+ | input_transed |
+ -------------------------
+ |
+ | matmul
+ \|/
+ -------------------------
+ | matmul_re_mul_re |
+ | matmul_re_mul_im |
+ -------------------------
+ |
+ | output merge
+ \|/
+ -------------------------
+ | output_contiguous |
+ -------------------------
+ |
+ | output contiguous
+ \|/
+ -------------------------
+ | output |
+ -------------------------
+ */
+ // workspace size allocation
+ fft_plan->matmul_addrs.internal_workspace_size = 0;
+ fft_plan->workspace_size = 0;
+
+ // input contiguous
+ size_t input_size = in_r_dtype_size * fft_plan->inum;
+ fft_plan->workspace_size +=
+ fft_plan->is_input_contiguous ? 0 : input_size;
+
+ // input pad
+ bool need_pad = (fft_plan->inembed[0] != n);
+ int padded_input_num = batch * n;
+ size_t padded_input_size = in_r_dtype_size * padded_input_num;
+ fft_plan->workspace_size += need_pad ? padded_input_size : 0;
+
+ // input trans
+ size_t transed_input_size = padded_input_size;
+ fft_plan->workspace_size += transed_input_size;
+ // input trans workspace: batch * L * 2^m --> batch * 2^m * L
+ const int trans_dim_num = 3;
+ int trans_input_dims[trans_dim_num] = {batch, L, m};
+ int trans_permute[trans_dim_num] = {0, 2, 1};
+ size_t trans_workspace_size = 0;
+ status = fftGetTransposeWorkspaceSize(handle, trans_workspace_size,
+ trans_dim_num, trans_input_dims,
+ trans_permute, in_r_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size = std::max(
+ fft_plan->matmul_addrs.internal_workspace_size, trans_workspace_size);
+
+ // input quantize param and workspace
+ if (fftIsIntDtype(in_e_dtype)) {
+ fft_plan->workspace_size += sizeof(int32_t) + sizeof(float);
+ size_t input_quant_workspace_size = 0;
+ status = fftGetQuantizeParamWorkspaceSize(
+ handle, input_quant_workspace_size, padded_input_num, in_r_dtype,
+ in_e_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size =
+ std::max(fft_plan->matmul_addrs.internal_workspace_size,
+ input_quant_workspace_size);
+ }
+
+ // matmul output
+ int matmul_times = COMPLEX; // real and imag
+ int per_matmul_output_num = batch * n;
+ size_t matmul_output_size =
+ matmul_times * in_r_dtype_size * per_matmul_output_num;
+ fft_plan->workspace_size += matmul_output_size;
+ // matmul workspace
+ size_t matmul_workspace_size = 0;
+ status = fftGetQuantizeMatMulWorkspaceSize(
+ handle, matmul_workspace_size, batch * m, L, L, false, true,
+ in_e_dtype, in_e_dtype, in_r_dtype, api);
+ fft_plan->matmul_addrs.internal_workspace_size =
+ std::max(fft_plan->matmul_addrs.internal_workspace_size,
+ matmul_workspace_size);
+
+ // output merge workspace
+ size_t merge_workspace_size = matmul_output_size;
+ fft_plan->matmul_addrs.internal_workspace_size = std::max(
+ fft_plan->matmul_addrs.internal_workspace_size, merge_workspace_size);
+
+ // output contiguous
+ size_t output_size =
+ mluOpDataTypeBytes(fft_plan->output_dtype) * fft_plan->onum;
+ fft_plan->workspace_size +=
+ fft_plan->is_output_contiguous ? 0 : output_size;
+
+ // internal_workspace
+ fft_plan->workspace_size +=
+ fft_plan->matmul_addrs.internal_workspace_size;
+ VLOG(5) << "internal workspace size: "
+ << fft_plan->matmul_addrs.internal_workspace_size;
+ VLOG(5) << "total workspace size: " << fft_plan->workspace_size;
+ }; break;
+
+ default: {
+ status = MLUOP_STATUS_NOT_SUPPORTED;
+ return status;
+ }
+ }
+ return status;
+}
+
+static void configureRFFT1dMatmulReserveAddrs(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan) {
+ size_t dft_mat_size = 0;
+ mluOpDataType_t in_r_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_e_dtype = fft_plan->execution_dtype;
+ size_t in_r_dtype_size = mluOpDataTypeBytes(in_r_dtype);
+ int n = fft_plan->n[0];
+
+ switch (fft_plan->fft_strategy) {
+ case CNFFT_FUNC_MATMUL: {
+ // Matmul Matrix : [(n / 2 + 1), 2, n]
+ int dim0 = FFT_HALF(n);
+ int dim1 = COMPLEX;
+ int dim2 = n;
+ dft_mat_size = dim0 * dim1 * dim2 * in_r_dtype_size;
+ fft_plan->matmul_addrs.dft_matrix_addr = fft_plan->reservespace_addr;
+ }; break;
+ case CNFFT_FUNC_COOLEY_TUKEY:
+ case CNFFT_FUNC_STOCKHAM: {
+ // Matmul Matrix : 2 * [L, L]
+ int L = fft_plan->L;
+ int dft_mat_times = COMPLEX;
+ size_t per_dft_mat_size = L * L * in_r_dtype_size;
+ dft_mat_size = dft_mat_times * per_dft_mat_size;
+ fft_plan->matmul_addrs.dft_matrix_addr = fft_plan->reservespace_addr;
+ fft_plan->matmul_addrs.dft_re_matrix_addr = fft_plan->reservespace_addr;
+ fft_plan->matmul_addrs.dft_im_matrix_addr =
+ (uint8_t *)fft_plan->reservespace_addr + per_dft_mat_size;
+ }; break;
+ default: {
+ break;
+ }
+ }
+ if (fftIsIntDtype(in_e_dtype)) {
+ fft_plan->matmul_addrs.dft_pos_addr =
+ (uint8_t *)fft_plan->reservespace_addr + dft_mat_size;
+ fft_plan->matmul_addrs.dft_scale_addr =
+ (uint8_t *)fft_plan->matmul_addrs.dft_pos_addr + sizeof(int32_t);
+ fft_plan->matmul_addrs.dft_quantize_workspace_addr =
+ (uint8_t *)fft_plan->matmul_addrs.dft_scale_addr + sizeof(float);
+ fft_plan->matmul_addrs.dft_quantize_workspace_size =
+ fft_plan->reservespace_size - dft_mat_size - sizeof(int32_t) -
+ sizeof(float);
+ } else {
+ fft_plan->matmul_addrs.dft_pos_addr = nullptr;
+ fft_plan->matmul_addrs.dft_scale_addr = nullptr;
+ fft_plan->matmul_addrs.dft_quantize_workspace_addr = nullptr;
+ fft_plan->matmul_addrs.dft_quantize_workspace_size = 0;
+ }
+}
+
+mluOpStatus_t setRFFT1dReserveArea(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan,
+ const std::string api) {
+ VLOG(5) << "setRFFT1dReserveArea";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ configureRFFT1dMatmulReserveAddrs(handle, fft_plan);
+
+ mluOpDataType_t in_r_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_e_dtype = fft_plan->execution_dtype;
+ int n = fft_plan->n[0];
+
+ const unsigned int cluster_number =
+ mluop::runtime::getClusterLimitCapability(handle);
+ const unsigned int core_dim = handle->core_num_per_cluster;
+ cnrtDim3_t k_dim = {core_dim, cluster_number, 1};
+ cnrtFunctionType_t k_type = CNRT_FUNC_TYPE_BLOCK;
+
+ switch (fft_plan->fft_strategy) {
+ case CNFFT_FUNC_MATMUL: {
+ // Matmul Matrix : [(n / 2 + 1), 2, n]
+ int dim0 = FFT_HALF(n);
+ int dim1 = COMPLEX;
+ int dim2 = n;
+ int dft_mat_num = dim0 * dim1 * dim2;
+ kernelGenerateRFFTHalfDFTMatrix(k_dim, k_type, handle->queue, fft_plan,
+ in_r_dtype, n);
+ status = fftQuantizePositionScale(
+ handle, dft_mat_num, in_r_dtype, in_e_dtype,
+ fft_plan->matmul_addrs.dft_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.dft_quantize_workspace_addr,
+ fft_plan->matmul_addrs.dft_quantize_workspace_size, api);
+ INTERNAL_CHECK("[mluOpSetFFTReserveArea]",
+ status == MLUOP_STATUS_SUCCESS);
+ }; break;
+ case CNFFT_FUNC_COOLEY_TUKEY: {
+ // Matmul Matrix : 2 * [L, L]
+ int L = fft_plan->L;
+ int dft_mat_times = COMPLEX;
+ int dft_mat_num = dft_mat_times * L * L;
+ kernelGenerateRFFTFullDFTMatrix(k_dim, k_type, handle->queue, fft_plan,
+ in_r_dtype, L, L);
+ status = fftQuantizePositionScale(
+ handle, dft_mat_num, in_r_dtype, in_e_dtype,
+ fft_plan->matmul_addrs.dft_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.dft_quantize_workspace_addr,
+ fft_plan->matmul_addrs.dft_quantize_workspace_size, api);
+ INTERNAL_CHECK("[mluOpSetFFTReserveArea]",
+ status == MLUOP_STATUS_SUCCESS);
+ }; break;
+ case CNFFT_FUNC_STOCKHAM: {
+ // Matmul Matrix : 2 * [L, L]
+ int L = fft_plan->L;
+ int row = L <= fft_plan->L_sub ? L : (PAD_UP(L / 2, fft_plan->L_sub) + 1);
+ int dft_mat_times = COMPLEX;
+ int dft_mat_num = dft_mat_times * L * L;
+ VLOG(5) << "CNFFT_FUNC_STOCKHAM generateRFFTFullDFTMatrix";
+ kernelGenerateRFFTFullDFTMatrix(k_dim, k_type, handle->queue, fft_plan,
+ in_r_dtype, row, L);
+ status = fftQuantizePositionScale(
+ handle, dft_mat_num, in_r_dtype, in_e_dtype,
+ fft_plan->matmul_addrs.dft_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.dft_quantize_workspace_addr,
+ fft_plan->matmul_addrs.dft_quantize_workspace_size, api);
+ INTERNAL_CHECK("[mluOpSetFFTReserveArea]",
+ status == MLUOP_STATUS_SUCCESS);
+ }; break;
+ default: {
+ status = MLUOP_STATUS_NOT_SUPPORTED;
+ }
+ }
+ return status;
+}
+
+static void configureRFFT1dMatmulWorkspaceAddrs(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan,
+ void *input, void *workspace,
+ void *output) {
+ VLOG(5) << "Into configure RFFT1d Matmul Workspace Addrs";
+ size_t workspace_cur_offset = 0;
+ size_t workspace_cur_offset_to_end = 0;
+ size_t workspace_total_size = fft_plan->workspace_size;
+ void *workspace_end = (uint8_t *)workspace + workspace_total_size;
+
+ mluOpDataType_t in_r_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_e_dtype = fft_plan->execution_dtype;
+ size_t in_r_dtype_size = mluOpDataTypeBytes(in_r_dtype);
+ int batch = fft_plan->batch;
+ int n = fft_plan->n[0];
+
+ // input contiguous
+ size_t input_size = in_r_dtype_size * fft_plan->inum;
+ if (!fft_plan->is_input_contiguous) {
+ fft_plan->matmul_addrs.input_contiguous_addr =
+ (uint8_t *)workspace + workspace_cur_offset;
+ workspace_cur_offset += input_size;
+ } else {
+ fft_plan->matmul_addrs.input_contiguous_addr = input;
+ }
+
+ // input pad
+ bool need_pad = (fft_plan->inembed[0] != n);
+ int padded_input_num = batch * n;
+ size_t padded_input_size = in_r_dtype_size * padded_input_num;
+ if (need_pad) {
+ fft_plan->matmul_addrs.input_pad_addr =
+ (uint8_t *)workspace + workspace_cur_offset;
+ workspace_cur_offset += padded_input_size;
+ } else {
+ fft_plan->matmul_addrs.input_pad_addr =
+ fft_plan->matmul_addrs.input_contiguous_addr;
+ }
+
+ // input trans
+ if (fft_plan->fft_strategy == CNFFT_FUNC_COOLEY_TUKEY ||
+ fft_plan->fft_strategy == CNFFT_FUNC_STOCKHAM) {
+ fft_plan->matmul_addrs.input_transed_addr =
+ (uint8_t *)workspace + workspace_cur_offset;
+ workspace_cur_offset += padded_input_size;
+ } else {
+ fft_plan->matmul_addrs.input_transed_addr =
+ fft_plan->matmul_addrs.input_pad_addr;
+ }
+
+ // input quantize
+ if (fftIsIntDtype(in_e_dtype)) {
+ fft_plan->matmul_addrs.input_pos_addr =
+ (uint8_t *)workspace + workspace_cur_offset;
+ workspace_cur_offset += sizeof(int32_t);
+ fft_plan->matmul_addrs.input_scale_addr =
+ (uint8_t *)workspace + workspace_cur_offset;
+ workspace_cur_offset += sizeof(float);
+ } else {
+ fft_plan->matmul_addrs.input_pos_addr = nullptr;
+ fft_plan->matmul_addrs.input_scale_addr = nullptr;
+ }
+
+ // internal workspace
+ workspace_cur_offset_to_end += fft_plan->matmul_addrs.internal_workspace_size;
+ fft_plan->matmul_addrs.internal_workspace_addr =
+ (uint8_t *)workspace_end - workspace_cur_offset_to_end;
+
+ // output contiguous
+ size_t output_size =
+ mluOpDataTypeBytes(fft_plan->output_dtype) * fft_plan->onum;
+ if (!fft_plan->is_output_contiguous) {
+ workspace_cur_offset_to_end += output_size;
+ fft_plan->matmul_addrs.output_contiguous_addr =
+ (uint8_t *)workspace_end - workspace_cur_offset_to_end;
+ } else {
+ fft_plan->matmul_addrs.output_contiguous_addr = output;
+ }
+
+ // matmul output
+ if (fft_plan->fft_strategy == CNFFT_FUNC_COOLEY_TUKEY ||
+ fft_plan->fft_strategy == CNFFT_FUNC_STOCKHAM) {
+ int per_matmul_output_num = batch * n;
+ size_t per_matmul_output_size = in_r_dtype_size * per_matmul_output_num;
+ workspace_cur_offset_to_end += per_matmul_output_size;
+ fft_plan->matmul_addrs.matmul_re_mul_im_addr =
+ (uint8_t *)workspace_end - workspace_cur_offset_to_end;
+ workspace_cur_offset_to_end += per_matmul_output_size;
+ fft_plan->matmul_addrs.matmul_re_mul_re_addr =
+ (uint8_t *)workspace_end - workspace_cur_offset_to_end;
+ } else {
+ fft_plan->matmul_addrs.matmul_re_mul_im_addr = nullptr;
+ fft_plan->matmul_addrs.matmul_re_mul_re_addr = nullptr;
+ }
+}
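+// Workspace layout produced above (sketch): early-stage buffers (contiguous
+// input, padded input, transposed input, quantize position/scale) are carved
+// from the front of the workspace, while late-stage buffers (shared internal
+// workspace, non-contiguous output, matmul re*re / re*im outputs) are laid
+// out backwards from the end.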
+
+// input : in input
+// output : in input_contiguous_addr
+static mluOpStatus_t makeRFFT1dContiguousInput(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan,
+ const void *input) {
+ std::string api = "[mluOpExecFFT]";
+ VLOG(5) << "into makeRFFT1dContiguousInput";
+ auto status = MLUOP_STATUS_SUCCESS;
+ if (!fft_plan->is_input_contiguous) {
+ VLOG(5) << "launch mluOpContiguous";
+ mluOpTensorDescriptor_t input_desc;
+ status = mluOpCreateTensorDescriptor(&input_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ const int in_dim_num = 2;
+ int64_t dims[in_dim_num] = {fft_plan->batch, fft_plan->inembed[0]};
+ int64_t strides[in_dim_num] = {fft_plan->idist, fft_plan->istride};
+ status = mluOpSetTensorDescriptorEx_v2(input_desc, MLUOP_LAYOUT_ARRAY,
+ fft_plan->input_dtype, in_dim_num,
+ dims, strides);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = mluOpContiguous(handle, input_desc, input,
+ fft_plan->matmul_addrs.input_contiguous_addr);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = mluOpDestroyTensorDescriptor(input_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ }
+ return status;
+}
+
+// input : in input_contiguous_addr
+// output : in input_pad_addr
+static mluOpStatus_t padRFFT1dContiguousInput(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan) {
+ std::string api = "[mluOpExecFFT]";
+ VLOG(5) << "into padRFFT1dContiguousInput";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+
+ mluOpDataType_t in_r_dtype = fft_plan->input_dtype;
+ int batch = fft_plan->batch;
+ int n = fft_plan->n[0];
+ bool need_pad = (fft_plan->inembed[0] != n);
+ if (need_pad) {
+ mluOpTensorDescriptor_t input_desc, padded_input_desc;
+ status = mluOpCreateTensorDescriptor(&input_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status = mluOpCreateTensorDescriptor(&padded_input_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ const int in_dim_num = 2;
+ int64_t dims[in_dim_num] = {batch, fft_plan->inembed[0]};
+ status = mluOpSetTensorDescriptor_v2(input_desc, MLUOP_LAYOUT_ARRAY,
+ in_r_dtype, in_dim_num, dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ int64_t padded_dims[in_dim_num] = {batch, n};
+ status = mluOpSetTensorDescriptor_v2(padded_input_desc, MLUOP_LAYOUT_ARRAY,
+ in_r_dtype, in_dim_num, padded_dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ const int pad_dim_num = 4;
+ int paddings[pad_dim_num] = {0, 0, 0, n - fft_plan->inembed[0]};
+ uint64_t padding_value = 0x00000000;
+ DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
+ cnnl_handle); // convert to cnnl_handle
+
+ // convert to cnnl_tensor_descriptor
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(input_desc, cnnl_input_desc);
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(padded_input_desc,
+ cnnl_padded_input_desc);
+ CALL_CNNL(cnnlPad(cnnl_handle, cnnl_input_desc,
+ fft_plan->matmul_addrs.input_contiguous_addr, paddings,
+ &padding_value, cnnl_padded_input_desc,
+ fft_plan->matmul_addrs.input_pad_addr));
+
+ // destroy cnnl descriptor
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_input_desc);
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_padded_input_desc);
+
+ DESTROY_CNNL_HANDLE(cnnl_handle);
+ }
+ return status;
+}
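+// Example for the pad above: with inembed[0] = 300 and n = 400, paddings
+// becomes {0, 0, 0, 100}, so each batch row gains 100 trailing zeros and the
+// matmul stage sees full length-n signals.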
+
+// only for CNFFT_FUNC_COOLEY_TUKEY
+// batch * L * 2^m --> batch * 2^m * L
+// input : in input_pad_addr
+// output : in input_transed_addr
+static mluOpStatus_t transposeRFFT1dPaddedInput(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan) {
+ std::string api = "[mluOpExecFFT]";
+ VLOG(5) << "into transposeRFFT1dPaddedInput";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ if (fft_plan->fft_strategy == CNFFT_FUNC_COOLEY_TUKEY) {
+ VLOG(5) << "launch mluOpTranspose";
+
+ mluOpDataType_t in_r_dtype = fft_plan->input_dtype;
+ int batch = fft_plan->batch;
+ int L = fft_plan->L;
+ int m = (1 << fft_plan->m);
+
+ const int trans_dim_num = 3;
+ int trans_input_dims[trans_dim_num] = {batch, L, m};
+ int trans_output_dims[trans_dim_num] = {batch, m, L};
+ int trans_permute[trans_dim_num] = {0, 2, 1};
+
+ status =
+ fftTranspose(handle, trans_dim_num, trans_input_dims, trans_output_dims,
+ trans_permute, fft_plan->matmul_addrs.input_pad_addr,
+ fft_plan->matmul_addrs.input_transed_addr, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ }
+ return status;
+}
+
+// input : in input_pad_addr
+// output : in input_pos_addr and input_scale_addr
+static mluOpStatus_t quantizeRFFT1dPaddedInput(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan) {
+ std::string api = "[mluOpExecFFT]";
+ VLOG(5) << "into quantizeRFFT1dPaddedInput";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+
+ mluOpDataType_t in_r_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_e_dtype = fft_plan->execution_dtype;
+ int padded_num = fft_plan->batch * fft_plan->n[0];
+
+ status = fftQuantizePositionScale(
+ handle, padded_num, in_r_dtype, in_e_dtype,
+ fft_plan->matmul_addrs.input_pad_addr,
+ fft_plan->matmul_addrs.input_pos_addr,
+ fft_plan->matmul_addrs.input_scale_addr,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+
+ return status;
+}
+
+// CNFFT_FUNC_MATMUL
+// input : in input_pad_addr
+// output : in output_contiguous_addr
+// CNFFT_FUNC_COOLEY_TUKEY
+// input : in input_transed_addr
+// output : input real matmul dft real result in matmul_re_mul_re_addr
+// input real matmul dft imag result in matmul_re_mul_im_addr
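+//
+// Math note: with W_N^{kn} = exp(-2*pi*i*k*n/N) and a real input x,
+//   X[k] = sum_n x[n] * cos(2*pi*k*n/N) - i * sum_n x[n] * sin(2*pi*k*n/N),
+// so the RFFT reduces to two real matmuls of the input against the cos
+// (real) and sin (imag) DFT matrices, whose results are merged afterwards.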
+static mluOpStatus_t computeRFFT1dMatmulResult(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan,
+ const float scale_factor) {
+ std::string api = "[mluOpExecFFT]";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+
+ mluOpDataType_t in_r_dtype = fft_plan->input_dtype;
+ mluOpDataType_t in_e_dtype = fft_plan->execution_dtype;
+ int batch = fft_plan->batch;
+ int n = fft_plan->n[0];
+
+ if (fft_plan->fft_strategy == CNFFT_FUNC_MATMUL) {
+ VLOG(5) << "into CNFFT_FUNC_MATMUL";
+ status = fftQuantMatMul(
+ handle, batch, n, FFT_HALF(n) * COMPLEX,
+ fft_plan->matmul_addrs.input_pad_addr,
+ fft_plan->matmul_addrs.input_pos_addr,
+ fft_plan->matmul_addrs.input_scale_addr,
+ fft_plan->matmul_addrs.dft_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.output_contiguous_addr, false, true,
+ scale_factor, 0.0, in_e_dtype, in_e_dtype, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ } else if (fft_plan->fft_strategy == CNFFT_FUNC_COOLEY_TUKEY) {
+ VLOG(5) << "into CNFFT_FUNC_COOLEY_TUKEY";
+ int L = fft_plan->L;
+ int m = (1 << fft_plan->m);
+ // input real matmul dft real
+ status = fftQuantMatMul(
+ handle, batch * m, L, L, fft_plan->matmul_addrs.input_transed_addr,
+ fft_plan->matmul_addrs.input_pos_addr,
+ fft_plan->matmul_addrs.input_scale_addr,
+ fft_plan->matmul_addrs.dft_re_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.matmul_re_mul_re_addr, false, true, scale_factor,
+ 0.0, in_e_dtype, in_e_dtype, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // input re matmul dft imag
+ status = fftQuantMatMul(
+ handle, batch * m, L, L, fft_plan->matmul_addrs.input_transed_addr,
+ fft_plan->matmul_addrs.input_pos_addr,
+ fft_plan->matmul_addrs.input_scale_addr,
+ fft_plan->matmul_addrs.dft_im_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.matmul_re_mul_im_addr, false, true, scale_factor,
+ 0.0, in_e_dtype, in_e_dtype, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ } else if (fft_plan->fft_strategy == CNFFT_FUNC_STOCKHAM) {
+ VLOG(5) << "into CNFFT_FUNC_STOCKHAM";
+ int L = fft_plan->L;
+ int m = (1 << fft_plan->m);
+
+ // origin: in_trans[batch, 2^m, L] * W_real[L, L] -> IN_real[batch, 2^m, L]
+ // in_trans[batch, 2^m, L] * W_imag[L, L] -> IN_imag[batch, 2^m, L]
+ // update: W[c*L, L] * in[batch, L, 2^m] -> out[batch, c*L, 2^m]
+ status = fftBatchMatMulBcast(
+ handle,
+ L <= fft_plan->L_sub ? (2 * L)
+ : (2 * (PAD_UP(L / 2, fft_plan->L_sub) + 1)),
+ L, m, batch, fft_plan->matmul_addrs.dft_re_matrix_addr,
+ fft_plan->matmul_addrs.dft_pos_addr,
+ fft_plan->matmul_addrs.dft_scale_addr,
+ fft_plan->matmul_addrs.input_pad_addr,
+ fft_plan->matmul_addrs.input_pos_addr,
+ fft_plan->matmul_addrs.input_scale_addr,
+ fft_plan->matmul_addrs.matmul_re_mul_re_addr, false, false,
+ scale_factor, 0.0, in_e_dtype, in_e_dtype, in_r_dtype,
+ fft_plan->matmul_addrs.internal_workspace_addr,
+ fft_plan->matmul_addrs.internal_workspace_size, api);
+ }
+
+ return status;
+}
+
+static mluOpStatus_t policyFunc(mluOpHandle_t handle, cnrtDim3_t *k_dim,
+ cnrtFunctionType_t *k_type) {
+ *k_type = CNRT_FUNC_TYPE_UNION1;
+ k_dim->x = handle->core_num_per_cluster;
+ k_dim->y = mluop::runtime::getClusterLimitCapability(handle);
+ k_dim->z = 1;
+ return MLUOP_STATUS_SUCCESS;
+}
+
+// only for CNFFT_FUNC_COOLEY_TUKEY and CNFFT_FUNC_STOCKHAM
+// input : input real matmul dft real result in matmul_re_mul_re_addr
+// input real matmul dft imag result in matmul_re_mul_im_addr
+// output : output complex result in output_contiguous_addr
+static mluOpStatus_t mergeRFFT1dOutput(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan,
+ const float scale_factor) {
+ std::string api = "[mluOpExecFFT]";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ if (fft_plan->fft_strategy == CNFFT_FUNC_COOLEY_TUKEY) {
+ VLOG(5) << "launch merge rfft1d output";
+ int core_num = handle->core_num_per_cluster;
+ cnrtFunctionType_t k_type = CNRT_FUNC_TYPE_UNION1;
+ int task_type = mluop::runtime::getJobLimitCapability(handle);
+ int task_num = 1;
+
+ switch (task_type) {
+ default:
+ task_num = core_num;
+ break;
+ case (int)CNRT_FUNC_TYPE_UNION2:
+ task_num = core_num * 2;
+ break;
+ case (int)CNRT_FUNC_TYPE_UNION4:
+ task_num = core_num * 4;
+ break;
+ case (int)CNRT_FUNC_TYPE_UNION8:
+ task_num = core_num * 8;
+ break;
+ case (int)CNRT_FUNC_TYPE_UNION16:
+ task_num = core_num * 16;
+ break;
+ }
+ unsigned int dimx = task_num;
+ cnrtDim3_t k_dim = {dimx, 1, 1};
+ k_type = (cnrtFunctionType_t)dimx;
+ kernelFFTCooleyTukey(k_dim, k_type, handle->queue, fft_plan, -1, RFFT);
+    // direction: -1 means invalid (only used by FFT/IFFT)
+ } else if (fft_plan->fft_strategy == CNFFT_FUNC_STOCKHAM) {
+ VLOG(5) << "launch mrege four-step rfft1d output";
+ cnrtDim3_t k_dim;
+ cnrtFunctionType_t k_type;
+ policyFunc(handle, &k_dim, &k_type);
+ kernelFFTStockham(k_dim, k_type, handle->queue, fft_plan, -1, scale_factor,
+ RFFT);
+    // direction: -1 means invalid (only used by FFT/IFFT).
+ }
+ return status;
+}
+
+// input : in output_contiguous_addr
+// output : in output
+static mluOpStatus_t makeRFFT1dContiguousOutput(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan,
+ void *output) {
+ std::string api = "[mluOpExecFFT]";
+ VLOG(5) << "into makeRFFT1dContiguousOutput";
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ if (!fft_plan->is_output_contiguous) {
+ VLOG(5) << "launch copy with stride";
+ mluOpDataType_t out_c_dtype = fft_plan->output_dtype;
+ // create tensor desc
+ mluOpTensorDescriptor_t copy_src_desc, copy_dst_desc;
+    status = mluOpCreateTensorDescriptor(&copy_src_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+    status = mluOpCreateTensorDescriptor(&copy_dst_desc);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // set up tensor desc
+ const int out_dim_num = 2;
+ int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->onembed[0]};
+ int64_t strides[out_dim_num] = {fft_plan->odist, fft_plan->ostride};
+ status = mluOpSetTensorDescriptor_v2(copy_src_desc, MLUOP_LAYOUT_ARRAY,
+ out_c_dtype, out_dim_num, dims);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ status =
+ mluOpSetTensorDescriptorEx_v2(copy_dst_desc, MLUOP_LAYOUT_ARRAY,
+ out_c_dtype, out_dim_num, dims, strides);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ // copy
+ void *copy_src_addr = fft_plan->matmul_addrs.output_contiguous_addr;
+
+ DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
+ cnnl_handle); // convert to cnnl_handle
+ // convert to cnnl_tensor_descriptor
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(copy_src_desc,
+ cnnl_copy_src_desc);
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(copy_dst_desc,
+ cnnl_copy_dst_desc);
+
+ CALL_CNNL(cnnlCopy(cnnl_handle, cnnl_copy_src_desc, copy_src_addr,
+ cnnl_copy_dst_desc, output));
+
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_copy_src_desc);
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_copy_dst_desc);
+ DESTROY_CNNL_HANDLE(cnnl_handle);
+ }
+ return status;
+}
+
+mluOpStatus_t execRFFT1d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan,
+ const void *input, const float scale_factor,
+ void *workspace, void *output) {
+ mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+ std::string api = "[mluOpExecFFT]";
+ configureRFFT1dMatmulWorkspaceAddrs(handle, fft_plan, (void *)input,
+ workspace, output);
+
+ status = makeRFFT1dContiguousInput(handle, fft_plan, input);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = padRFFT1dContiguousInput(handle, fft_plan);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = transposeRFFT1dPaddedInput(handle, fft_plan);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = quantizeRFFT1dPaddedInput(handle, fft_plan);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = computeRFFT1dMatmulResult(handle, fft_plan, scale_factor);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = mergeRFFT1dOutput(handle, fft_plan, scale_factor);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+
+ status = makeRFFT1dContiguousOutput(handle, fft_plan, output);
+ INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
+ return status;
+}
diff --git a/kernels/tensor_stride_process/tensor_stride_process_host.cpp b/kernels/tensor_stride_process/tensor_stride_process_host.cpp
index 222ac5829..0099a111a 100644
--- a/kernels/tensor_stride_process/tensor_stride_process_host.cpp
+++ b/kernels/tensor_stride_process/tensor_stride_process_host.cpp
@@ -474,13 +474,20 @@ static vector<int64_t> getDefaultStride(int64_t *dims, int dim) {
mluOpStatus_t MLUOP_WIN_API
mluOpContiguous(mluOpHandle_t handle, const mluOpTensorDescriptor_t input_desc,
const void *input, void *output) {
+ auto default_stride = getDefaultStride(input_desc->dims, input_desc->dim);
+ mluOpTensorDescriptor_t temp_desc = nullptr;
+ mluOpCreateTensorDescriptor(&temp_desc);
+ mluOpSetTensorDescriptorEx_v2(temp_desc, input_desc->layout,
+ input_desc->dtype, input_desc->dim,
+ input_desc->dims, default_stride.data());
DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, cnnl_handle);
DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(input_desc, cnnl_input_desc);
- DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(input_desc, cnnl_output_desc);
+ DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(temp_desc, cnnl_temp_desc);
CALL_CNNL(
- cnnlCopy(cnnl_handle, cnnl_input_desc, input, cnnl_output_desc, output));
+ cnnlCopy(cnnl_handle, cnnl_input_desc, input, cnnl_temp_desc, output));
DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_input_desc);
- DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_output_desc);
+ DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_temp_desc);
DESTROY_CNNL_HANDLE(cnnl_handle);
+ mluOpDestroyTensorDescriptor(temp_desc);
return MLUOP_STATUS_SUCCESS;
-}
\ No newline at end of file
+}
diff --git a/kernels/utils/common.h b/kernels/utils/common.h
index 3532b0c2b..c6fbce664 100644
--- a/kernels/utils/common.h
+++ b/kernels/utils/common.h
@@ -685,4 +685,31 @@ __mlu_func__ void __mluop_get_stage_indices_tfuse(int *dst_nram, int length) {
#endif
}
+/***************************************************************************
+ * MLUOPS FUNC: __mluop_get_indices.
+ * param "dst" is the NRAM buffer that holds the final result.
+ * param "start_index" is the smallest integer to be generated.
+ * param "len" is the total number of integers to be generated.
+ * Note:
+ * Generates the indices [start_index, start_index + len - 1] in NRAM on
+ * MLU590, MLU300 and other platforms that support the necessary instructions.
+ * len does not need to be aligned to any number.
+ * dst only supports NRAM.
+ * This function currently only supports float type indices.
+ * *************************************************************************/
+__mlu_vector__ void __mluop_get_indices(float *dst, float start_index,
+ uint32_t len) {
+  vv_float r_out, r_dim;
+  unsigned BlockDim = __vv_get_length() / sizeof(float);
+ __asm__ volatile("index.vvr.f32 %[dst], %[base], 1;\n\t"
+ : [ dst ] "+r"(r_out)
+ : [ base ] "r"(start_index));
+ __vv_move(r_dim, BlockDim);
+ int repeat = DIV_UP(len, BlockDim);
+ for (int iter = 0; iter < repeat; iter++) {
+ __vv_store(dst + iter * BlockDim, r_out);
+ __vv_add(r_out, r_out, r_dim);
+ }
+}
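+// Usage sketch (assumes idx_nram is an NRAM float buffer rounded up to a
+// whole number of vector blocks): __mluop_get_indices(idx_nram, 0.0f, 128)
+// fills idx_nram with 0.0, 1.0, ..., 127.0.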
+
#endif // KERNELS_UTILS_COMMON_H_
diff --git a/mlu_op.h b/mlu_op.h
index ff26f52e9..f4d90acad 100644
--- a/mlu_op.h
+++ b/mlu_op.h
@@ -13253,6 +13253,315 @@ mluOpDCNBackwardData(mluOpHandle_t handle,
const mluOpTensorDescriptor_t grad_mask_desc,
void *grad_mask);
+/*!
+ * @brief The descriptor of FFT (Fast Fourier Transform) operation that holds FFT information including
+ * the tensor descriptor of input tensor and output tensor, the rank of FFT, the FFT size on each
+ * dimension, the size of reserved space and the size of workspace.
+ *
+ * You need to call the ::mluOpCreateFFTPlan function to create a descriptor for the FFT operation, and call
+ * the ::mluOpMakeFFTPlanMany function to set the information of the FFT operation to the descriptor.
+ * Then, you need to allocate the reserved space and set the space to the FFT descriptor by ::mluOpSetFFTReserveArea.
+ * At the end, you need to destroy the FFT plan with the ::mluOpDestroyFFTPlan function.
+ */
+typedef struct mluOpFFTStruct *mluOpFFTPlan_t;
+
+// Group: FFT
+/*!
+ * @brief Creates a descriptor pointed by \p fft_plan for the FFT operation, and allocates memory
+ * for holding the information about the FFT operation. The information is defined in ::mluOpFFTPlan_t.
+ *
+ * @param[out] fft_plan
+ * Pointer to the FFT descriptor that holds information about the FFT operation.
+ *
+ * @par Return
+ * - ::MLUOP_STATUS_SUCCESS, ::MLUOP_STATUS_ALLOC_FAILED
+ *
+ * @par Data Type
+ * - None.
+ *
+ * @par Data Layout
+ * - None.
+ *
+ * @par Scale Limitation
+ * - None.
+ *
+ * @par API Dependency
+ * - After calling this function, you can call the ::mluOpMakeFFTPlanMany function to initialize and set the
+ * information to the created descriptor.
+ * - You need to call the ::mluOpDestroyFFTPlan to destroy the descriptor.
+ *   Otherwise, a memory leak may occur.
+ *
+ * @par Note
+ * - This function only supports 1D FFT currently. 2D FFT and 3D FFT
+ * will be supported in the future.
+ * - When the data type of input is float or complex_float, the 1D FFT length should be equal to:
+ * length = \f$base * 2^{m}\f$, and the base should be less than or equal to 4096.
+ * - When the data type of input is half or complex_half, the 1D FFT length should be equal to:
+ * length = \f$2^{m}\f$.
+ *
+ * @par Example.
+ * - None.
+ *
+ * @par Reference.
+ * - None.
+ */
+mluOpStatus_t MLUOP_WIN_API
+mluOpCreateFFTPlan(mluOpFFTPlan_t *fft_plan);
+
+// Group: FFT
+/*!
+ * @brief Initializes the FFT descriptor pointed by \p fft_plan that is previously created
+ * with the ::mluOpCreateFFTPlan function, and sets the information about the
+ * tensor descriptors of input tensor and output tensor, the rank of FFT, and the FFT size on each
+ * dimension.
+ * This function also gets the size of MLU memory buffers for FFT execution, including \p reservespace_size and
+ * \p workspace_size. The size of extra workspace is based on the given information of the
+ * \p fft_plan.
+ *
+ * @param[in] handle
+ * Handle to a Cambricon MLUOP context that is used to manage MLU devices and queues
+ * in the FFT operation. For detailed information, see ::mluOpHandle_t.
+ * @param[in,out] fft_plan
+ * The descriptor of FFT. For detailed information, see ::mluOpFFTPlan_t.
+ * @param[in] input_desc
+ * The descriptor of input signals. For detailed information,
+ * see ::mluOpTensorDescriptor_t.
+ * @param[in] output_desc
+ * The descriptor of output signals. For detailed information,
+ * see ::mluOpTensorDescriptor_t.
+ * @param[in] rank
+ * The dimensionality of the FFT operation. It can be 1D, 2D or 3D.
+ * @param[in] n
+ * An array of size \p rank describing the FFT size of each dimension. n[0]
+ * is the size of the outermost dimension and n[rank - 1] is the innermost dimension
+ * of the FFT operation. If n[i] is greater than the size of the input on dimension i, the input
+ * signal will be zero-padded on that dimension. If n[i] is smaller, the input signal is trimmed
+ * on dimension i.
+ * @param[out] reservespace_size
+ * The size of the extra reserved space in bytes that needs to be used in FFT operation.
+ * @param[out] workspace_size
+ * The size of the extra workspace in bytes that needs to be used in FFT operation.
+ *
+ * @par Return
+ * - ::MLUOP_STATUS_SUCCESS, ::MLUOP_STATUS_BAD_PARAM, ::MLUOP_STATUS_NOT_SUPPORTED, ::MLUOP_STATUS_NOT_INITIALIZED
+ *
+ * @par Data Type
+ * - The supported data types of \p input and \p output tensors are as follows:
+ * - real-to-complex FFT:
+ * - half(input offchip)-complex_half(output offchip)-int16(input onchip)
+ * - half(input offchip)-complex_half(output offchip)-half(input onchip)
+ * - float(input offchip)-complex_float(output offchip)-float(input onchip)
+ * - complex-to-real FFT:
+ * - complex_half(input offchip)-half(output offchip)-int16(input onchip)
+ * - complex_half(input offchip)-half(output offchip)-half(input onchip)
+ * - complex_float(input offchip)-float(output offchip)-float(input onchip)
+ * - complex-to-complex FFT:
+ * - complex_half(input offchip)-complex_half(output offchip)-int16(input onchip)
+ * - complex_half(input offchip)-complex_half(output offchip)-half(input onchip)
+ * - complex_float(input offchip)-complex_float(output offchip)-float(input onchip)
+ *
+ * @par Data Layout
+ * - None.
+ *
+ * @par Scale Limitation
+ * - None.
+ *
+ * @par API Dependency
+ * - Before calling this function, you need to call the ::mluOpCreateFFTPlan function to
+ * create an FFT descriptor, call the ::mluOpSetTensorDescriptor or
+ * ::mluOpSetTensorDescriptorEx function to set the input and output tensor descriptor,
+ * and then call the ::mluOpSetTensorDescriptorOnchipDataType to set the onchip data type
+ * of input tensor descriptor.
+ *
+ * @par Note
+ * - The advanced data layout parameters, including (i/o)nembed, (i/o)stride, and (i/o)dist, are set through
+ * ::mluOpSetTensorDescriptorEx. If stride information is not needed, you can set the simple data layout
+ * through ::mluOpSetTensorDescriptor.
+ * - The dimension size of input or output should be equal to \p rank or \p rank + 1. In the former case,
+ * the batch size is considered as 1. Otherwise, the outermost dimension is the batch size.
+ * - For real-to-complex FFTs, the innermost dimension of the output array differs from the FFT length.
+ * For an x-length 1D real-to-complex FFT, the output is x/2 + 1 complex numbers (the non-redundant outputs).
+ * For an N-D real-to-complex FFT with n=[z, y, x], the output shape will be [z, y, x/2+1].
+ * - For complex-to-real FFTs, the input tensor only holds the non-redundant part of the Fourier coefficients,
+ * and the output tensor stores the real output values.
+ * - When n[0] is greater than 4096, the data type of input only supports float or complex_float.
+ *
+ * @par Example.
+ * - None.
+ *
+ * @par Reference.
+ * - None.
+ */
+mluOpStatus_t MLUOP_WIN_API
+mluOpMakeFFTPlanMany(mluOpHandle_t handle,
+ mluOpFFTPlan_t fft_plan,
+ const mluOpTensorDescriptor_t input_desc,
+ const mluOpTensorDescriptor_t output_desc,
+ const int rank,
+ const int n[],
+ size_t *reservespace_size,
+ size_t *workspace_size);
+
+// Group:FFT
+/*!
+ * @brief Binds the \p reservespace to the \p fft_plan. The size of reserved space can be derived
+ * through ::mluOpMakeFFTPlanMany.
+ *
+ * @param[in] handle
+ * Handle to a Cambricon MLUOP context that is used to manage MLU devices and queues in the
+ * ::mluOpExecFFT. For detailed information, see ::mluOpHandle_t.
+ * @param[in, out] fft_plan
+ * The descriptor of FFT. For detailed information, see ::mluOpFFTPlan_t.
+ * @param[in] reservespace
+ * Pointer to the MLU memory that is used as an extra memory space for saving
+ * intermediate results of FFT operation.
+ *
+ * @par Return
+ * - ::MLUOP_STATUS_SUCCESS, ::MLUOP_STATUS_BAD_PARAM, ::MLUOP_STATUS_INTERNAL_ERROR
+ *
+ * @par Data Type
+ * - None.
+ *
+ * @par Data Layout
+ * - None.
+ *
+ * @par Scale Limitation
+ * - None.
+ *
+ * @par API Dependency
+ * - Before calling this function, you need to call the ::mluOpCreateFFTPlan function
+ * to create an FFT descriptor, call the ::mluOpMakeFFTPlanMany function to set the
+ * FFT descriptor and get the size of reserved space, and then call the
+ *    cnrtMalloc function to allocate MLU memory according to the reservespace_size given.
+ *
+ * @par Note
+ * - None.
+ *
+ * @par Example.
+ * - None.
+ *
+ * @par Reference.
+ * - None.
+ */
+mluOpStatus_t MLUOP_WIN_API
+mluOpSetFFTReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan, void *reservespace);
+
+// Group:FFT
+/*!
+ * @brief Executes an FFT described by \p fft_plan. For complex-to-real and real-to-complex
+ * transforms, the \p direction parameter is ignored. This function stores the Fourier coefficients
+ * in the output array. If the addresses of input and output are the same, an in-place FFT
+ * is performed.
+ *
+ * @param[in] handle
+ * Handle to a Cambricon MLUOP context that is used to manage MLU devices and queues
+ * in the FFT execution. For detailed information, see ::mluOpHandle_t.
+ * @param[in] fft_plan
+ * The plan for FFT execution. For detailed information, see ::mluOpFFTPlan_t.
+ * @param[in] input
+ * Pointer to the MLU memory that stores the input tensor.
+ * @param[in] scale_factor
+ *   A floating-point scalar used to multiply the FFT output.
+ * @param[in, out] workspace
+ * Pointer to the MLU memory that is used as an extra workspace for the
+ * ::mluOpExecFFT.
+ * @param[out] output
+ * Pointer to the MLU memory that stores the output tensor.
+ * @param[in] direction
+ * The transform direction: 0 means FFT forward and 1 means FFT inverse.
+ * Direction is ignored for real-to-complex and complex-to-real transforms.
+ *
+ * @par Note
+ * - For in-place 1D real-to-complex FFTs, the input is a batch of n real numbers, and the
+ * output is n/2 + 1 non-redundant complex numbers. This requires a padding of input array.
+ * - For in-place N-D real-to-complex FFTs, extra padding of the real-data array on the innermost
+ * dimension is necessary to accommodate the size of the complex-data output.
+ * - When \p input contains NaN or infinity and the input onchip data type of FFT is not quantized
+ * data type, the output is computed through the FFT formula with computation rules of NaN or
+ * infinity based on IEEE 754.
+ * - When \p input contains NaN or infinity and the input onchip data type of FFT is quantized
+ * data type such as int16, the output will be unpredictable.
+ * - The \p input is recommended to be in the range of [-10, 10] with uniform
+ *   distribution for higher precision.
+ * - The \p scale_factor is recommended to be in the range of [-1, 1] to avoid exceeding
+ *   the data representation range.
+ * - Half data type of \p input is not recommended due to low precision. The first element of the
+ * FFT result is the sum of all input elements, and it is likely to overflow.
+ * - This operation is not supported on the 1V platforms.
+ *
+ * @par Return
+ * - ::MLUOP_STATUS_SUCCESS, ::MLUOP_STATUS_BAD_PARAM, ::MLUOP_STATUS_INTERNAL_ERROR
+ *
+ * @par Data Type
+ * - None.
+ *
+ * @par Data Layout
+ * - None.
+ *
+ * @par Scale Limitation
+ * - None.
+ *
+ * @par API Dependency
+ * - Before calling this function, you need to call the ::mluOpCreateFFTPlan
+ * function to create an FFT descriptor, call the ::mluOpMakeFFTPlanMany
+ * function to set the FFT descriptor and the size of reserved space and work space,
+ *   and then call the ::mluOpSetFFTReserveArea to bind the reserve space to the descriptor.
+ *
+ * @par Note
+ * - None.
+ *
+ * @par Example.
+ * - None.
+ *
+ * @par Reference.
+ * - None.
+ */
+mluOpStatus_t MLUOP_WIN_API
+mluOpExecFFT(mluOpHandle_t handle,
+ const mluOpFFTPlan_t fft_plan,
+ const void *input,
+ const float scale_factor,
+ void *workspace,
+ void *output,
+ int direction);
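+/* Typical call sequence (sketch; error checking is omitted, and allocation of
+ * the reserve/workspace buffers with cnrtMalloc by the caller is assumed):
+ *   mluOpFFTPlan_t plan;
+ *   mluOpCreateFFTPlan(&plan);
+ *   mluOpMakeFFTPlanMany(handle, plan, in_desc, out_desc, 1, n,
+ *                        &reserve_size, &work_size);
+ *   mluOpSetFFTReserveArea(handle, plan, reserve_ptr);
+ *   mluOpExecFFT(handle, plan, input, 1.0f, work_ptr, output, 0);
+ *   mluOpDestroyFFTPlan(plan);
+ */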
+
+// Group:FFT
+/*!
+ * @brief Destroys an FFT plan \p fft_plan that is created with the
+ * ::mluOpCreateFFTPlan function. The FFT plan is defined in ::mluOpFFTPlan_t and
+ * holds the information about the FFT operation.
+ *
+ * @param[in] fft_plan
+ *   The FFT plan to be destroyed.
+ *
+ * @par Return
+ * - ::MLUOP_STATUS_SUCCESS, ::MLUOP_STATUS_EXECUTION_FAILED
+ *
+ * @par Data Type
+ * - None.
+ *
+ * @par Data Layout
+ * - None.
+ *
+ * @par Scale Limitation
+ * - None.
+ *
+ * @par API Dependency
+ * - None.
+ *
+ * @par Note
+ * - You need to call this function after calling the ::mluOpExecFFT.
+ *   Otherwise, a memory leak may occur.
+ *
+ * @par Example.
+ * - None.
+ *
+ * @par Reference.
+ * - None.
+ */
+mluOpStatus_t MLUOP_WIN_API
+mluOpDestroyFFTPlan(mluOpFFTPlan_t fft_plan);
#if defined(__cplusplus)
}
#endif
diff --git a/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_ExecFFT.cpp b/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_ExecFFT.cpp
new file mode 100644
index 000000000..3833221d5
--- /dev/null
+++ b/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_ExecFFT.cpp
@@ -0,0 +1,219 @@
+/*************************************************************************
+ * Copyright (C) [2024] by Cambricon, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+#include <exception>
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "mlu_op.h"
+#include "api_test_tools.h"
+#include "core/context.h"
+#include "core/logging.h"
+#include "core/tensor.h"
+
+namespace mluopapitest {
+class fft_ExecFFT : public testing::Test {
+ public:
+ void setParam(bool handle, bool fft_plan, bool input, bool workspace,
+ bool output) {
+ if (handle) {
+ MLUOP_CHECK(mluOpCreate(&handle_));
+ }
+
+ if (fft_plan) {
+ MLUOP_CHECK(mluOpCreateFFTPlan(&fft_plan_));
+ }
+
+ if (input) {
+ size_t i_bytes = mluOpDataTypeBytes(MLUOP_DTYPE_FLOAT);
+ GTEST_CHECK(CNRT_RET_SUCCESS == cnrtMalloc(&input_, i_bytes));
+ }
+
+ if (workspace) {
+ GTEST_CHECK(CNRT_RET_SUCCESS == cnrtMalloc(&workspace_, workspace_size_));
+ }
+
+ if (output) {
+ size_t o_bytes = mluOpDataTypeBytes(MLUOP_DTYPE_COMPLEX_FLOAT);
+ GTEST_CHECK(CNRT_RET_SUCCESS == cnrtMalloc(&output_, o_bytes));
+ }
+ }
+
+ mluOpStatus_t compute() {
+ mluOpDataType_t input_data_type = MLUOP_DTYPE_FLOAT;
+ mluOpDataType_t output_data_type = MLUOP_DTYPE_COMPLEX_FLOAT;
+ mluOpDataType_t execution_dtype = MLUOP_DTYPE_FLOAT;
+ const int rank = 1;
+ const int batch = 2000;
+ const int n[rank] = {400};
+ const int ndim = rank + 1;
+ const int input_dim_size[ndim] = {batch, n[0] / 2 + 1};
+ const int input_dim_stride[ndim] = {n[0] / 2 + 1, 1};
+
+ const int output_dim_size[ndim] = {batch, n[0] / 2 + 1};
+ const int output_dim_stride[ndim] = {n[0] / 2 + 1, 1};
+
+ mluOpCreateTensorDescriptor(&input_desc_);
+ mluOpCreateTensorDescriptor(&output_desc_);
+ mluOpSetTensorDescriptorEx(input_desc_, MLUOP_LAYOUT_ARRAY, input_data_type,
+ ndim, input_dim_size, input_dim_stride);
+ mluOpSetTensorDescriptorOnchipDataType(input_desc_, execution_dtype);
+ mluOpSetTensorDescriptorEx(output_desc_, MLUOP_LAYOUT_ARRAY,
+ output_data_type, ndim, output_dim_size,
+ output_dim_stride);
+ size_t reservespaceSizeInBytes_ = 64;
+ size_t workspaceSizeInBytes_ = 64;
+ size_t *reservespace_size = &reservespaceSizeInBytes_;
+ size_t *workspace_size = &workspaceSizeInBytes_;
+
+ mluOpStatus_t status;
+ if (handle_ != nullptr && fft_plan_ != nullptr) {
+ status =
+ mluOpMakeFFTPlanMany(handle_, fft_plan_, input_desc_, output_desc_,
+ rank, n, reservespace_size, workspace_size);
+
+ if (status != MLUOP_STATUS_SUCCESS) {
+ destroy();
+ return status;
+ }
+ }
+
+ status = mluOpExecFFT(handle_, fft_plan_, input_, scale_factor_, workspace_,
+ output_, direction_);
+ destroy();
+
+ return status;
+ }
+
+ protected:
+ virtual void SetUp() {
+ handle_ = nullptr;
+ fft_plan_ = nullptr;
+ input_ = nullptr;
+ workspace_ = nullptr;
+ output_ = nullptr;
+ }
+
+ void destroy() {
+ try {
+ if (handle_) {
+ CNRT_CHECK(cnrtQueueSync(handle_->queue));
+ VLOG(4) << "Destroy handle_";
+ MLUOP_CHECK(mluOpDestroy(handle_));
+ handle_ = nullptr;
+ }
+ if (input_desc_) {
+ VLOG(4) << "Destroy input_desc_";
+ MLUOP_CHECK(mluOpDestroyTensorDescriptor(input_desc_));
+ input_desc_ = nullptr;
+ }
+ if (output_desc_) {
+ VLOG(4) << "Destroy output_desc_";
+ MLUOP_CHECK(mluOpDestroyTensorDescriptor(output_desc_));
+ output_desc_ = nullptr;
+ }
+ if (input_) {
+ VLOG(4) << "Destroy input_";
+ GTEST_CHECK(CNRT_RET_SUCCESS == cnrtFree(input_));
+ input_ = nullptr;
+ }
+ if (output_) {
+ VLOG(4) << "Destroy output_";
+ GTEST_CHECK(CNRT_RET_SUCCESS == cnrtFree(output_));
+ output_ = nullptr;
+ }
+ if (fft_plan_) {
+ VLOG(4) << "Destroy fft_plan_";
+ MLUOP_CHECK(mluOpDestroyFFTPlan(fft_plan_));
+ fft_plan_ = nullptr;
+ }
+ if (workspace_) {
+ VLOG(4) << "Destroy workspace_";
+ GTEST_CHECK(CNRT_RET_SUCCESS == cnrtFree(workspace_));
+ workspace_ = nullptr;
+ }
+ } catch (const std::exception &e) {
+      FAIL() << "MLUOPAPIGTEST: caught " << e.what() << " in fft_ExecFFT";
+ }
+ }
+
+ private:
+ mluOpHandle_t handle_ = nullptr;
+ mluOpFFTPlan_t fft_plan_ = nullptr;
+ mluOpTensorDescriptor_t input_desc_ = nullptr;
+ mluOpTensorDescriptor_t output_desc_ = nullptr;
+ void *input_ = nullptr;
+ float scale_factor_ = 0.1;
+ void *workspace_ = nullptr;
+ void *output_ = nullptr;
+ int direction_ = 0;
+ size_t workspace_size_ = 64;
+};
+
+TEST_F(fft_ExecFFT, BAD_PARAM_handle_null) {
+ try {
+ setParam(false, true, true, true, true);
+ EXPECT_TRUE(MLUOP_STATUS_BAD_PARAM == compute());
+ } catch (const std::exception &e) {
+    FAIL() << "MLUOPAPIGTEST: caught " << e.what() << " in fft_ExecFFT";
+ }
+}
+
+TEST_F(fft_ExecFFT, BAD_PARAM_fft_plan_null) {
+ try {
+ setParam(true, false, true, true, true);
+ EXPECT_TRUE(MLUOP_STATUS_BAD_PARAM == compute());
+ } catch (const std::exception &e) {
+    FAIL() << "MLUOPAPIGTEST: caught " << e.what() << " in fft_ExecFFT";
+ }
+}
+
+TEST_F(fft_ExecFFT, BAD_PARAM_input_null) {
+ try {
+ setParam(true, true, false, true, true);
+ EXPECT_TRUE(MLUOP_STATUS_BAD_PARAM == compute());
+ } catch (const std::exception &e) {
+    FAIL() << "MLUOPAPIGTEST: caught " << e.what() << " in fft_ExecFFT";
+ }
+}
+
+TEST_F(fft_ExecFFT, BAD_PARAM_workspace_null) {
+ try {
+ setParam(true, true, true, false, true);
+ EXPECT_TRUE(MLUOP_STATUS_BAD_PARAM == compute());
+ } catch (const std::exception &e) {
+    FAIL() << "MLUOPAPIGTEST: caught " << e.what() << " in fft_ExecFFT";
+ }
+}
+
+TEST_F(fft_ExecFFT, BAD_PARAM_output_null) {
+ try {
+ setParam(true, true, true, true, false);
+ EXPECT_TRUE(MLUOP_STATUS_BAD_PARAM == compute());
+ } catch (const std::exception &e) {
+    FAIL() << "MLUOPAPIGTEST: caught " << e.what() << " in fft_ExecFFT";
+ }
+}
+
+} // namespace mluopapitest
diff --git a/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_MakeFFTPlanMany.cpp b/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_MakeFFTPlanMany.cpp
new file mode 100644
index 000000000..479097664
--- /dev/null
+++ b/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_MakeFFTPlanMany.cpp
@@ -0,0 +1,196 @@
+/*************************************************************************
+ * Copyright (C) [2024] by Cambricon, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+#include <exception>
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "mlu_op.h"
+#include "api_test_tools.h"
+#include "core/context.h"
+#include "core/logging.h"
+
+namespace mluopapitest {
+class fft_MakeFFTPlanMany : public testing::Test {
+ public:
+ void setParam(bool handle, bool fft_plan, bool input_desc, bool output_desc,
+ bool reservespace_size, bool workspace_size) {
+ if (handle) {
+ MLUOP_CHECK(mluOpCreate(&handle_));
+ }
+
+ if (fft_plan) {
+ MLUOP_CHECK(mluOpCreateFFTPlan(&fft_plan_));
+ }
+
+ if (input_desc) {
+ MLUOP_CHECK(mluOpCreateTensorDescriptor(&input_desc_));
+      std::vector<int> input_dims{1, 400};
+ const int input_dim_stride[2] = {400, 1};
+ MLUOP_CHECK(mluOpSetTensorDescriptorEx(
+ input_desc_, MLUOP_LAYOUT_ARRAY, MLUOP_DTYPE_FLOAT, input_dims.size(),
+ input_dims.data(), input_dim_stride));
+ MLUOP_CHECK(mluOpSetTensorDescriptorOnchipDataType(input_desc_,
+ MLUOP_DTYPE_FLOAT));
+ }
+
+ if (output_desc) {
+ MLUOP_CHECK(mluOpCreateTensorDescriptor(&output_desc_));
+      std::vector<int> output_dims{1, 201};
+ const int output_dim_stride[2] = {201, 1};
+ MLUOP_CHECK(mluOpSetTensorDescriptorEx(
+ output_desc_, MLUOP_LAYOUT_ARRAY, MLUOP_DTYPE_COMPLEX_FLOAT,
+ output_dims.size(), output_dims.data(), output_dim_stride));
+ }
+
+ if (reservespace_size) {
+ reservespace_size_ = &reservespaceSizeInBytes_;
+ }
+
+ if (workspace_size) {
+ workspace_size_ = &workspaceSizeInBytes_;
+ }
+ }
+
+ mluOpStatus_t compute() {
+ mluOpStatus_t status =
+ mluOpMakeFFTPlanMany(handle_, fft_plan_, input_desc_, output_desc_,
+ rank, n, reservespace_size_, workspace_size_);
+ destroy();
+ return status;
+ }
+
+ protected:
+ virtual void SetUp() {
+ handle_ = nullptr;
+ fft_plan_ = nullptr;
+ input_desc_ = nullptr;
+ output_desc_ = nullptr;
+ reservespace_size_ = nullptr;
+ workspace_size_ = nullptr;
+ }
+
+ void destroy() {
+ try {
+ if (handle_) {
+ CNRT_CHECK(cnrtQueueSync(handle_->queue));
+ VLOG(4) << "Destroy handle_";
+ MLUOP_CHECK(mluOpDestroy(handle_));
+ handle_ = nullptr;
+ }
+ if (fft_plan_) {
+ VLOG(4) << "Destroy fft_plan_";
+ MLUOP_CHECK(mluOpDestroyFFTPlan(fft_plan_));
+ fft_plan_ = nullptr;
+ }
+ if (input_desc_) {
+ VLOG(4) << "Destroy input_desc_";
+ MLUOP_CHECK(mluOpDestroyTensorDescriptor(input_desc_));
+ input_desc_ = nullptr;
+ }
+ if (output_desc_) {
+ VLOG(4) << "Destroy output_desc_";
+ MLUOP_CHECK(mluOpDestroyTensorDescriptor(output_desc_));
+ output_desc_ = nullptr;
+ }
+ } catch (const std::exception &e) {
+      FAIL() << "MLUOPAPIGTEST: caught " << e.what()
+ << " in fft_MakeFFTPlanMany";
+ }
+ }
+
+ private:
+ mluOpHandle_t handle_ = nullptr;
+ mluOpFFTPlan_t fft_plan_ = nullptr;
+ mluOpTensorDescriptor_t input_desc_ = nullptr;
+ mluOpTensorDescriptor_t output_desc_ = nullptr;
+ int rank = 1;
+ int n[1] = {400};
+ size_t *reservespace_size_ = nullptr;
+ size_t *workspace_size_ = nullptr;
+ size_t reservespaceSizeInBytes_ = 64;
+ size_t workspaceSizeInBytes_ = 64;
+};
+
+TEST_F(fft_MakeFFTPlanMany, BAD_PARAM_handle_null) {
+ try {
+ setParam(false, true, true, true, true, true);
+ EXPECT_EQ(MLUOP_STATUS_BAD_PARAM, compute());
+ } catch (const std::exception &e) {
+    FAIL() << "MLUOPAPIGTEST: caught " << e.what()
+ << " in fft_MakeFFTPlanMany";
+ }
+}
+
+TEST_F(fft_MakeFFTPlanMany, BAD_PARAM_fft_plan_null) {
+ try {
+ setParam(true, false, true, true, true, true);
+ EXPECT_EQ(MLUOP_STATUS_NOT_INITIALIZED, compute());
+ } catch (const std::exception &e) {
+    FAIL() << "MLUOPAPIGTEST: caught " << e.what()
+ << " in fft_MakeFFTPlanMany";
+ }
+}
+
+TEST_F(fft_MakeFFTPlanMany, BAD_PARAM_input_desc_null) {
+ try {
+ setParam(true, true, false, true, true, true);
+ EXPECT_EQ(MLUOP_STATUS_BAD_PARAM, compute());
+ } catch (const std::exception &e) {
+    FAIL() << "MLUOPAPIGTEST: caught " << e.what()
+ << " in fft_MakeFFTPlanMany";
+ }
+}
+
+TEST_F(fft_MakeFFTPlanMany, BAD_PARAM_output_desc_null) {
+ try {
+ setParam(true, true, true, false, true, true);
+ EXPECT_EQ(MLUOP_STATUS_BAD_PARAM, compute());
+ } catch (const std::exception &e) {
+    FAIL() << "MLUOPAPIGTEST: caught " << e.what()
+ << " in fft_MakeFFTPlanMany";
+ }
+}
+
+TEST_F(fft_MakeFFTPlanMany, BAD_PARAM_reservespace_size_null) {
+ try {
+ setParam(true, true, true, true, false, true);
+ EXPECT_EQ(MLUOP_STATUS_BAD_PARAM, compute());
+ } catch (const std::exception &e) {
+    FAIL() << "MLUOPAPIGTEST: caught " << e.what()
+ << " in fft_MakeFFTPlanMany";
+ }
+}
+
+TEST_F(fft_MakeFFTPlanMany, BAD_PARAM_workspace_size_null) {
+ try {
+ setParam(true, true, true, true, true, false);
+ EXPECT_EQ(MLUOP_STATUS_BAD_PARAM, compute());
+ } catch (const std::exception &e) {
+    FAIL() << "MLUOPAPIGTEST: caught " << e.what()
+ << " in fft_MakeFFTPlanMany";
+ }
+}
+
+} // namespace mluopapitest
diff --git a/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_SetFFTReserveArea.cpp b/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_SetFFTReserveArea.cpp
new file mode 100644
index 000000000..a1610d0a1
--- /dev/null
+++ b/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_SetFFTReserveArea.cpp
@@ -0,0 +1,128 @@
+/*************************************************************************
+ * Copyright (C) [2024] by Cambricon, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+#include <exception>
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "mlu_op.h"
+#include "core/context.h"
+#include "core/logging.h"
+#include "api_test_tools.h"
+
+namespace mluopapitest {
+class fft_SetFFTReserveArea : public testing::Test {
+ public:
+ void setParam(bool handle, bool fft_plan, bool reservespace) {
+ if (handle) {
+ MLUOP_CHECK(mluOpCreate(&handle_));
+ }
+
+ if (fft_plan) {
+ MLUOP_CHECK(mluOpCreateFFTPlan(&fft_plan_));
+ }
+
+ if (reservespace) {
+ GTEST_CHECK(CNRT_RET_SUCCESS ==
+ cnrtMalloc(&reservespace_, reservespace_size));
+ }
+ }
+
+ mluOpStatus_t compute() {
+ mluOpStatus_t status =
+ mluOpSetFFTReserveArea(handle_, fft_plan_, reservespace_);
+
+ destroy();
+ return status;
+ }
+
+ protected:
+ virtual void SetUp() {
+ handle_ = nullptr;
+ fft_plan_ = nullptr;
+ reservespace_ = nullptr;
+ }
+
+ void destroy() {
+ try {
+ if (handle_) {
+ CNRT_CHECK(cnrtQueueSync(handle_->queue));
+ VLOG(4) << "Destroy handle_";
+ MLUOP_CHECK(mluOpDestroy(handle_));
+ handle_ = nullptr;
+ }
+ if (fft_plan_) {
+ VLOG(4) << "Destroy fft_plan_";
+ MLUOP_CHECK(mluOpDestroyFFTPlan(fft_plan_));
+ fft_plan_ = nullptr;
+ }
+ if (reservespace_) {
+ VLOG(4) << "Destroy reservespace_";
+ GTEST_CHECK(CNRT_RET_SUCCESS == cnrtFree(reservespace_));
+ reservespace_ = nullptr;
+ }
+ } catch (const std::exception &e) {
+      FAIL() << "MLUOPAPIGTEST: caught " << e.what()
+ << " in fft_SetFFTReserveArea";
+ }
+ }
+
+ private:
+ mluOpHandle_t handle_ = nullptr;
+ mluOpFFTPlan_t fft_plan_ = nullptr;
+ void *reservespace_ = nullptr;
+ size_t reservespace_size = 64;
+};
+
+TEST_F(fft_SetFFTReserveArea, BAD_PARAM_handle_null) {
+ try {
+ setParam(false, true, true);
+ EXPECT_EQ(MLUOP_STATUS_BAD_PARAM, compute());
+ } catch (const std::exception &e) {
+    FAIL() << "MLUOPAPIGTEST: caught " << e.what()
+ << " in fft_SetFFTReserveArea";
+ }
+}
+
+TEST_F(fft_SetFFTReserveArea, BAD_PARAM_fft_plan_null) {
+ try {
+ setParam(true, false, true);
+ EXPECT_EQ(MLUOP_STATUS_BAD_PARAM, compute());
+ } catch (const std::exception &e) {
+    FAIL() << "MLUOPAPIGTEST: caught " << e.what()
+ << " in fft_SetFFTReserveArea";
+ }
+}
+
+TEST_F(fft_SetFFTReserveArea, BAD_PARAM_reservespace_null) {
+ try {
+ setParam(true, true, false);
+ EXPECT_EQ(MLUOP_STATUS_BAD_PARAM, compute());
+ } catch (const std::exception &e) {
+    FAIL() << "MLUOPAPIGTEST: caught " << e.what()
+ << " in fft_SetFFTReserveArea";
+ }
+}
+
+} // namespace mluopapitest
diff --git a/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_general.cpp b/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_general.cpp
new file mode 100644
index 000000000..0b2ac99cc
--- /dev/null
+++ b/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_general.cpp
@@ -0,0 +1,314 @@
+/*************************************************************************
+ * Copyright (C) [2024] by Cambricon, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+#include <exception>
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "mlu_op.h"
+#include "core/context.h"
+#include "core/logging.h"
+#include "api_test_tools.h"
+
+namespace mluopapitest {
+typedef std::tuple<MLUOpTensorParam, MLUOpTensorParam, int, int,
+                   mluOpDevType_t, mluOpStatus_t>
+    FFTParams;
+
+class fft_general : public testing::TestWithParam<FFTParams> {
+ public:
+ void SetUp() {
+ target_device_ = std::get<4>(GetParam());
+ expected_status_ = std::get<5>(GetParam());
+ MLUOP_CHECK(mluOpCreate(&handle_));
+ MLUOP_CHECK(mluOpCreateFFTPlan(&fft_plan_));
+
+ MLUOpTensorParam input_params = std::get<0>(GetParam());
+ MLUOP_CHECK(mluOpCreateTensorDescriptor(&input_desc_));
+ MLUOP_CHECK(mluOpSetTensorDescriptorEx(
+ input_desc_, input_params.get_layout(), input_params.get_dtype(),
+ input_params.get_dim_nb(), input_params.get_dim_size().data(),
+ input_params.get_dim_stride().data()));
+
+ MLUOP_CHECK(mluOpSetTensorDescriptorOnchipDataType(
+ input_desc_, input_params.get_onchip_dtype()));
+
+ MLUOpTensorParam output_params = std::get<1>(GetParam());
+ MLUOP_CHECK(mluOpCreateTensorDescriptor(&output_desc_));
+
+ MLUOP_CHECK(mluOpSetTensorDescriptorEx(
+ output_desc_, output_params.get_layout(), output_params.get_dtype(),
+ output_params.get_dim_nb(), output_params.get_dim_size().data(),
+ output_params.get_dim_stride().data()));
+ n_[0] = std::get<3>(GetParam());
+ rank_ = std::get<2>(GetParam());
+ }
+
+ bool compute() {
+ if (!(target_device_ == MLUOP_UNKNOWN_DEVICE ||
+ target_device_ == handle_->arch)) {
+ destroy();
+ return true;
+ }
+
+ mluOpStatus_t status;
+ status = mluOpMakeFFTPlanMany(handle_, fft_plan_, input_desc_, output_desc_,
+ rank_, n_, &reserveSpaceSizeInBytes_,
+ &workSpaceSizeInBytes_);
+ destroy();
+
+ return expected_status_ == status;
+ }
+
+ void destroy() {
+ try {
+ if (handle_) {
+ CNRT_CHECK(cnrtQueueSync(handle_->queue));
+ VLOG(4) << "Destroy handle_";
+ MLUOP_CHECK(mluOpDestroy(handle_));
+ handle_ = nullptr;
+ }
+ if (input_desc_) {
+ VLOG(4) << "Destroy input_desc_";
+ MLUOP_CHECK(mluOpDestroyTensorDescriptor(input_desc_));
+ input_desc_ = nullptr;
+ }
+ if (output_desc_) {
+ VLOG(4) << "Destroy output_desc_";
+ MLUOP_CHECK(mluOpDestroyTensorDescriptor(output_desc_));
+ output_desc_ = nullptr;
+ }
+ if (fft_plan_) {
+ VLOG(4) << "Destroy fft_plan_";
+ MLUOP_CHECK(mluOpDestroyFFTPlan(fft_plan_));
+ fft_plan_ = nullptr;
+ }
+ if (workspace_size_) {
+ VLOG(4) << "Destroy workspace_size_";
+ GTEST_CHECK(CNRT_RET_SUCCESS == cnrtFree(workspace_size_));
+ workspace_size_ = nullptr;
+ }
+ if (reservespace_size_) {
+ VLOG(4) << "Destroy reservespace_size_";
+ GTEST_CHECK(CNRT_RET_SUCCESS == cnrtFree(reservespace_size_));
+ reservespace_size_ = nullptr;
+ }
+ } catch (const std::exception &e) {
+      FAIL() << "MLUOPAPIGTEST: caught " << e.what() << " in fft_general";
+ }
+ }
+
+ private:
+ mluOpHandle_t handle_ = nullptr;
+ mluOpFFTPlan_t fft_plan_ = nullptr;
+ mluOpTensorDescriptor_t input_desc_ = nullptr;
+ mluOpTensorDescriptor_t output_desc_ = nullptr;
+ int rank_ = 1;
+ int n_[1] = {1};
+ size_t *reservespace_size_ = nullptr;
+ size_t *workspace_size_ = nullptr;
+ size_t reserveSpaceSizeInBytes_ = 64;
+ size_t workSpaceSizeInBytes_ = 64;
+ mluOpDevType_t target_device_ = MLUOP_UNKNOWN_DEVICE;
+ mluOpStatus_t expected_status_ = MLUOP_STATUS_BAD_PARAM;
+};
+
+TEST_P(fft_general, negative) { EXPECT_TRUE(compute()); }
+
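+// A 0 in the input shape is a supported zero-element case, so plan creation
+// is expected to return MLUOP_STATUS_SUCCESS.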
+INSTANTIATE_TEST_CASE_P(
+ zero_element, fft_general,
+ testing::Combine(testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1,
+                         std::vector<int64_t>({0, 1}), std::vector<int64_t>({1, 1}),
+ MLUOP_DTYPE_FLOAT}),
+ testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1,
+                         std::vector<int64_t>({1, 1}), std::vector<int64_t>({1, 1})}),
+ testing::Values(1), testing::Values(1),
+ testing::Values(MLUOP_UNKNOWN_DEVICE),
+ testing::Values(MLUOP_STATUS_SUCCESS)));
+
+INSTANTIATE_TEST_CASE_P(
+    negative_2_n,  // half/complex_half requires an fft length decomposable into 2^m; n = 7 is not
+ fft_general,
+ testing::Combine(testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_HALF, 2,
+                         std::vector<int64_t>({1, 7}), std::vector<int64_t>({1, 1}),
+ MLUOP_DTYPE_HALF}),
+ testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_HALF, 2,
+                         std::vector<int64_t>({1, 7}), std::vector<int64_t>({1, 1})}),
+ testing::Values(1), testing::Values(7),
+ testing::Values(MLUOP_UNKNOWN_DEVICE),
+ testing::Values(MLUOP_STATUS_NOT_SUPPORTED)));
+
+INSTANTIATE_TEST_CASE_P(
+    negative_2_m_l,  // float/complex_float with n > 4096 requires an fft
+                     // length decomposable into 2^m * l; n = 4097 is not
+ fft_general,
+ testing::Combine(testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 2,
+                         std::vector<int64_t>({1, 4097}), std::vector<int64_t>({1, 1}),
+ MLUOP_DTYPE_FLOAT}),
+ testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 2,
+                         std::vector<int64_t>({1, 4097}),
+                         std::vector<int64_t>({1, 1})}),
+ testing::Values(1), testing::Values(4097),
+ testing::Values(MLUOP_UNKNOWN_DEVICE),
+ testing::Values(MLUOP_STATUS_NOT_SUPPORTED)));
+
+INSTANTIATE_TEST_CASE_P(
+ negative_rank_1, fft_general,
+ testing::Combine(testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1,
+                         std::vector<int64_t>({1}), std::vector<int64_t>({1}),
+ MLUOP_DTYPE_FLOAT}),
+ testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1,
+                         std::vector<int64_t>({1}), std::vector<int64_t>({1})}),
+ testing::Values(4), testing::Values(1),
+ testing::Values(MLUOP_UNKNOWN_DEVICE),
+ testing::Values(MLUOP_STATUS_BAD_PARAM)));
+
+INSTANTIATE_TEST_CASE_P(
+ negative_N_le_0, fft_general,
+ testing::Combine(testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1,
+                         std::vector<int64_t>({1}), std::vector<int64_t>({1}),
+ MLUOP_DTYPE_FLOAT}),
+ testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1,
+                         std::vector<int64_t>({1}), std::vector<int64_t>({1})}),
+ testing::Values(1), testing::Values(0, -1),
+ testing::Values(MLUOP_UNKNOWN_DEVICE),
+ testing::Values(MLUOP_STATUS_BAD_PARAM)));
+
+INSTANTIATE_TEST_CASE_P(
+ negative_batch, fft_general,
+ testing::Combine(testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 2,
+                         std::vector<int64_t>({1, 1}), std::vector<int64_t>({1, 1}),
+ MLUOP_DTYPE_FLOAT}),
+ testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 2,
+                         std::vector<int64_t>({2, 1}), std::vector<int64_t>({1, 1})}),
+ testing::Values(1), testing::Values(1),
+ testing::Values(MLUOP_UNKNOWN_DEVICE),
+ testing::Values(MLUOP_STATUS_BAD_PARAM)));
+
+INSTANTIATE_TEST_CASE_P(
+ negative_input_stride, fft_general,
+ testing::Combine(testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1,
+                         std::vector<int64_t>({1}), std::vector<int64_t>({-1}),
+ MLUOP_DTYPE_FLOAT}),
+ testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1,
+                         std::vector<int64_t>({1}), std::vector<int64_t>({1})}),
+ testing::Values(1), testing::Values(1),
+ testing::Values(MLUOP_UNKNOWN_DEVICE),
+ testing::Values(MLUOP_STATUS_BAD_PARAM)));
+
+INSTANTIATE_TEST_CASE_P(
+ negative_output_stride, fft_general,
+ testing::Combine(testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1,
+                         std::vector<int64_t>({1}), std::vector<int64_t>({1}),
+ MLUOP_DTYPE_FLOAT}),
+ testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1,
+                         std::vector<int64_t>({1}), std::vector<int64_t>({-1})}),
+ testing::Values(1), testing::Values(1),
+ testing::Values(MLUOP_UNKNOWN_DEVICE),
+ testing::Values(MLUOP_STATUS_BAD_PARAM)));
+
+INSTANTIATE_TEST_CASE_P(
+ negative_unsupported_dtype_combination, fft_general,
+ testing::Combine(testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_HALF, 1,
+                         std::vector<int64_t>({1}), std::vector<int64_t>({1}),
+ MLUOP_DTYPE_FLOAT}),
+ testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1,
+                         std::vector<int64_t>({1}), std::vector<int64_t>({1})}),
+ testing::Values(1), testing::Values(1),
+ testing::Values(MLUOP_UNKNOWN_DEVICE),
+ testing::Values(MLUOP_STATUS_BAD_PARAM)));
+
+INSTANTIATE_TEST_CASE_P(
+ negative_onchip_dtype, fft_general,
+ testing::Combine(testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1,
+                         std::vector<int64_t>({1}), std::vector<int64_t>({1}),
+ MLUOP_DTYPE_HALF}),
+ testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1,
+                         std::vector<int64_t>({1}), std::vector<int64_t>({1})}),
+ testing::Values(1), testing::Values(1),
+ testing::Values(MLUOP_UNKNOWN_DEVICE),
+ testing::Values(MLUOP_STATUS_BAD_PARAM)));
+
+// r2c: output length != n / 2 + 1
+INSTANTIATE_TEST_CASE_P(
+ negative_r2c_length, fft_general,
+ testing::Combine(testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_FLOAT, 1,
+                         std::vector<int64_t>({4}), std::vector<int64_t>({1}),
+ MLUOP_DTYPE_FLOAT}),
+ testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1,
+                         std::vector<int64_t>({1}), std::vector<int64_t>({1})}),
+ testing::Values(1), testing::Values(4),
+ testing::Values(MLUOP_UNKNOWN_DEVICE),
+ testing::Values(MLUOP_STATUS_BAD_PARAM)));
+
+// c2c: output length != n
+INSTANTIATE_TEST_CASE_P(
+ negative_c2c_length, fft_general,
+ testing::Combine(testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1,
+                         std::vector<int64_t>({1}), std::vector<int64_t>({1}),
+ MLUOP_DTYPE_FLOAT}),
+ testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1,
+                         std::vector<int64_t>({0}), std::vector<int64_t>({1})}),
+ testing::Values(1), testing::Values(1),
+ testing::Values(MLUOP_UNKNOWN_DEVICE),
+ testing::Values(MLUOP_STATUS_BAD_PARAM)));
+
+// c2r: output length != n
+INSTANTIATE_TEST_CASE_P(
+ negative_c2r_length, fft_general,
+ testing::Combine(testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1,
+                         std::vector<int64_t>({4}), std::vector<int64_t>({1}),
+ MLUOP_DTYPE_FLOAT}),
+ testing::Values(MLUOpTensorParam{
+ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_FLOAT, 1,
+                         std::vector<int64_t>({3}), std::vector<int64_t>({1})}),
+ testing::Values(1), testing::Values(4),
+ testing::Values(MLUOP_UNKNOWN_DEVICE),
+ testing::Values(MLUOP_STATUS_BAD_PARAM)));
+} // namespace mluopapitest
diff --git a/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_plan_descriptor.cpp b/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_plan_descriptor.cpp
new file mode 100644
index 000000000..3bc894c85
--- /dev/null
+++ b/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_plan_descriptor.cpp
@@ -0,0 +1,46 @@
+/*************************************************************************
+ * Copyright (C) [2024] by Cambricon, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+#include <vector>
+#include <string>
+#include <tuple>
+#include <utility>
+
+#include "gtest/gtest.h"
+#include "mlu_op.h"
+#include "core/logging.h"
+#include "api_test_tools.h"
+
+namespace mluopapitest {
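+// Destroying a null FFT plan must report MLUOP_STATUS_BAD_PARAM rather than
+// crash.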
+TEST(fft_plan_descriptor, BAD_PARAM_DestroyDesc_null) {
+ try {
+ mluOpFFTPlan_t fft_plan = nullptr;
+ mluOpStatus_t status = mluOpDestroyFFTPlan(fft_plan);
+ EXPECT_TRUE(status == MLUOP_STATUS_BAD_PARAM);
+ } catch (const std::exception &e) {
+ FAIL() << "MLUOPAPIGTEST: catched " << e.what() << " in fft_plan_descriptor"
+ << ")";
+ }
+}
+} // namespace mluopapitest
diff --git a/test/mlu_op_gtest/pb_gtest/src/zoo/fft/fft.cpp b/test/mlu_op_gtest/pb_gtest/src/zoo/fft/fft.cpp
new file mode 100644
index 000000000..d623ce2c5
--- /dev/null
+++ b/test/mlu_op_gtest/pb_gtest/src/zoo/fft/fft.cpp
@@ -0,0 +1,176 @@
+/*************************************************************************
+ * Copyright (C) [2023] by Cambricon, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+#include "fft.h"
+
+namespace mluoptest {
+
+void FftExecutor::paramCheck() {
+  GTEST_CHECK(parser_->getInputNum() == 1, "fft input number is wrong.");
+  GTEST_CHECK(parser_->getOutputNum() == 1, "fft output number is wrong.");
+}
+
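+// Plan creation reports two buffer sizes: a reserve space that is bound to
+// the plan via mluOpSetFFTReserveArea, and a scratch workspace for execution.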
+void FftExecutor::workspaceMalloc() {
+ auto input_tensor = tensor_desc_[0].tensor;
+ auto output_tensor = tensor_desc_[1].tensor;
+
+ auto fft_param = parser_->getProtoNode()->fft_param();
+ int rank = fft_param.rank();
+  std::vector<int> n;
+ for (int i = 0; i < rank; i++) {
+ n.push_back(fft_param.n(i));
+ }
+
+ MLUOP_CHECK(mluOpCreateFFTPlan(&fft_plan_));
+ MLUOP_CHECK(mluOpMakeFFTPlanMany(handle_, fft_plan_, input_tensor,
+ output_tensor, rank, n.data(),
+ &reservespace_size_, &workspace_size_));
+
+ VLOG(4) << "reserve space size: " << reservespace_size_;
+ VLOG(4) << "workspace size: " << workspace_size_;
+
+ if (reservespace_size_ > 0) {
+ GTEST_CHECK(reservespace_addr_ = mlu_runtime_.allocate(reservespace_size_));
+ workspace_.push_back(reservespace_addr_);
+ }
+ // interface_timer_.start();
+  /* Setting the reserve area is the plan-time preprocessing step that runs
+     before FFT execution. */
+ MLUOP_CHECK(mluOpSetFFTReserveArea(handle_, fft_plan_, reservespace_addr_));
+ // interface_timer_.stop();
+ if (workspace_size_ > 0) {
+ GTEST_CHECK(workspace_addr_ = mlu_runtime_.allocate(workspace_size_));
+ workspace_.push_back(workspace_addr_);
+ }
+}
+
+void FftExecutor::compute() {
+ VLOG(4) << "FftExecutor compute ";
+ auto input_dev = data_vector_[0].device_ptr;
+ auto output_dev = data_vector_[1].device_ptr;
+
+ auto fft_param = parser_->getProtoNode()->fft_param();
+ int direction = fft_param.direction();
+ float scale_factor = fft_param.scale_factor();
+
+ VLOG(4) << "call mluOpFFT";
+
+ interface_timer_.start();
+ MLUOP_CHECK(mluOpExecFFT(handle_, fft_plan_, input_dev, scale_factor,
+ workspace_addr_, output_dev, direction));
+ interface_timer_.stop();
+}
+
+void FftExecutor::workspaceFree() {
+ MLUOP_CHECK(mluOpDestroyFFTPlan(fft_plan_));
+ for (auto &addr : workspace_) {
+ mlu_runtime_.deallocate(addr);
+ }
+ workspace_.clear();
+}
+
+void FftExecutor::cpuCompute() {
+  // TODO(sunhui): use fftw? librosa? OTFFT? other third-party library.
+}
+
+int64_t FftExecutor::getTheoryOps() {
+ auto input_tensor = tensor_desc_[0].tensor;
+ auto fft_param = parser_->getProtoNode()->fft_param();
+ int rank = fft_param.rank();
+ int bc = 1;
+ if (input_tensor->dim != rank) {
+ bc = input_tensor->dims[0];
+ }
+ int n = fft_param.n(0);
+
+ int64_t ops_each_batch;
+  // Convert between LT and CT computing power: a single LT provides
+  // 4096 * 2 ops, a single CT provides 128.
+ int cp_ratio = 4096 * 2 / 128;
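+  // For example, cp_ratio = 4096 * 2 / 128 = 64, so one batch with n = 4096
+  // on the matmul path counts 4096 * 4096 * 2 / 64 theoretical ops.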
+ if (n <= 4096) {
+ // fft_plan->fft_strategy = CNFFT_FUNC_MATMUL. Mainly use LT.
+ ops_each_batch = n * n * 2 / cp_ratio;
+ } else {
+ ops_each_batch = n * int(std::log(n)) * 2;
+ // fft_plan->fft_strategy = CNFFT_FUNC_COOLEY_TUKEY or CNFFT_FUNC_STOCKHAM.
+ // Half use LT and half use CT.
+ ops_each_batch = ops_each_batch * (0.5 / cp_ratio + 0.5);
+ }
+ int64_t theory_ops = bc * ops_each_batch;
+ VLOG(4) << "getTheoryOps: " << theory_ops << " ops";
+ return theory_ops;
+}
+
+int64_t FftExecutor::getTheoryIoSize() {
+ // dtype check
+ auto input_tensor = tensor_desc_[0].tensor;
+ auto output_tensor = tensor_desc_[1].tensor;
+ mluOpDataType_t input_dtype = input_tensor->dtype;
+ mluOpDataType_t output_dtype = output_tensor->dtype;
+
+ auto fft_param = parser_->getProtoNode()->fft_param();
+ int rank = fft_param.rank();
+ int bc = 1;
+ if (input_tensor->dim != rank) {
+ bc = input_tensor->dims[0];
+ }
+ int n = fft_param.n(0);
+
+ int64_t theory_ios = 0;
+ if (n <= 4096) {
+ if (input_dtype == output_dtype) {
+ theory_ios += bc * n * 4; // matmul io
+ } else { // r2c or c2r
+ theory_ios += bc * n * 2; // matmul io
+ }
+ theory_ios += n * n * 2; // W io
+ } else {
+ if (input_dtype == output_dtype) {
+ theory_ios += bc * n * 4; // matmul io
+ theory_ios += bc * n * 4; // stockham or cooley_tukey io
+ } else { // r2c or c2r
+ theory_ios += bc * n * 2; // matmul
+ theory_ios += bc * n * 2; // stockham or cooley_tukey io
+ }
+
+ // W io
+ int n_temp = n;
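+    // Halve n while it stays even and >= 128; the remaining factor models the
+    // size of the W (twiddle-factor) table that still has to be read.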
+ while (n_temp >= 128 && n_temp % 2 == 0) {
+ n_temp = n_temp / 2;
+ }
+ theory_ios += n_temp * 2;
+ }
+ VLOG(4) << "getTheoryIoSize: " << theory_ios << " ops";
+ return theory_ios;
+}
+
+std::set<Evaluator::Formula> FftExecutor::getCriterionsUse() const {
+ return {Evaluator::DIFF1, Evaluator::DIFF2, Evaluator::DIFF3,
+ Evaluator::DIFF4};
+}
+
+} // namespace mluoptest
diff --git a/test/mlu_op_gtest/pb_gtest/src/zoo/fft/fft.h b/test/mlu_op_gtest/pb_gtest/src/zoo/fft/fft.h
new file mode 100644
index 000000000..a7af31090
--- /dev/null
+++ b/test/mlu_op_gtest/pb_gtest/src/zoo/fft/fft.h
@@ -0,0 +1,53 @@
+/*************************************************************************
+ * Copyright (C) [2023] by Cambricon, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+#ifndef TEST_MLU_OP_GTEST_SRC_ZOO_FFT_FFT_H_
+#define TEST_MLU_OP_GTEST_SRC_ZOO_FFT_FFT_H_
+#include <vector>
+#include <set>
+#include "executor.h"
+
+namespace mluoptest {
+
+class FftExecutor : public Executor {
+ public:
+ FftExecutor() {}
+ ~FftExecutor() {}
+
+ void paramCheck() override;
+ void workspaceMalloc() override;
+ void compute() override;
+ void cpuCompute() override;
+ void workspaceFree() override;
+ int64_t getTheoryOps() override;
+ int64_t getTheoryIoSize() override;
+  std::set<Evaluator::Formula> getCriterionsUse() const override;
+
+ private:
+ mluOpFFTPlan_t fft_plan_;
+ size_t reservespace_size_ = 0, workspace_size_ = 0;
+ void *reservespace_addr_ = nullptr;
+ void *workspace_addr_ = nullptr;
+};
+
+} // namespace mluoptest
+#endif // TEST_MLU_OP_GTEST_SRC_ZOO_FFT_FFT_H_
diff --git a/test/mlu_op_gtest/pb_gtest/src/zoo/fft/test_case/fft_0.prototxt b/test/mlu_op_gtest/pb_gtest/src/zoo/fft/test_case/fft_0.prototxt
new file mode 100644
index 000000000..25181705b
--- /dev/null
+++ b/test/mlu_op_gtest/pb_gtest/src/zoo/fft/test_case/fft_0.prototxt
@@ -0,0 +1,82 @@
+device: GPU
+op_name: "fft"
+input {
+ id: "input1"
+ shape {
+ dims: 2
+ dims: 6
+ dim_stride: 6
+ dim_stride: 1
+ }
+ layout: LAYOUT_ARRAY
+ dtype: DTYPE_FLOAT
+ value_h: "c0c2dbda"
+ value_h: "4071d388"
+ value_h: "bf69fea5"
+ value_h: "c11f2dfd"
+ value_h: "c0c5e5e2"
+ value_h: "c091e8c5"
+ value_h: "406d762c"
+ value_h: "3d8a27b5"
+ value_h: "bffaa8cc"
+ value_h: "bf981f7e"
+ value_h: "c01e21f5"
+ value_h: "3fae1fc8"
+ random_data {
+ distribution: UNIFORM
+ lower_bound_double: -10
+ upper_bound_double: 10
+ }
+ onchip_dtype: DTYPE_FLOAT
+}
+output {
+ id: "output1"
+ shape {
+ dims: 2
+ dims: 4
+ dim_stride: 4
+ dim_stride: 1
+ }
+ layout: LAYOUT_ARRAY
+ dtype: DTYPE_COMPLEX_FLOAT
+ value_h: "c1bf5723"
+ value_h: "0"
+ value_h: "40e0937c"
+ value_h: "c13c9082"
+ value_h: "c14192bc"
+ value_h: "c02a0abc"
+ value_h: "c01d4d28"
+ value_h: "0"
+ value_h: "bef5766c"
+ value_h: "0"
+ value_h: "40fa78d5"
+ value_h: "3f2d00db"
+ value_h: "4080b885"
+ value_h: "3fc8226c"
+ value_h: "bf75464a"
+ value_h: "0"
+ random_data {
+ distribution: UNIFORM
+ lower_bound_double: -1
+ upper_bound_double: 1
+ }
+ thresholds {
+ evaluation_threshold: 1e-05
+ evaluation_threshold: 1e-05
+ evaluation_threshold_imag: 1e-05
+ evaluation_threshold_imag: 1e-05
+ }
+}
+evaluation_criterion: DIFF1
+evaluation_criterion: DIFF2
+supported_mlu_platform: MLU370
+handle_param {
+ round_mode: ROUND_OFF_ZERO
+}
+fft_param {
+ rank: 1
+ n: 6
+ direction: 0
+ scale_factor: 1
+}