From db260048f575491ac8c9f68f273876c2dfc0b5c4 Mon Sep 17 00:00:00 2001 From: Aleksandar Zecevic Date: Fri, 20 Dec 2024 19:27:18 +0100 Subject: [PATCH] Added E2E PermuteOp support (#1505) Added support for PermuteOp, since it's already supported in [ttnn](https://docs.tenstorrent.com/tt-metal/latest/ttnn/ttnn/api/ttnn.permute.html#ttnn.permute). `stablehlo.transpose` and `ttir.permute` have the same semantics, so decomposition of `stablehlo.transpose` into series of `ttir.transpose`es is not needed anymore. Closes https://github.com/tenstorrent/tt-mlir/issues/652 --- include/ttmlir/Dialect/TTIR/IR/TTIROps.td | 27 ++ include/ttmlir/Dialect/TTNN/IR/TTNNOps.td | 23 ++ include/ttmlir/Target/TTNN/program.fbs | 9 + include/ttmlir/Utils.h | 80 ++++- .../StableHLOToTTIRPatterns.cpp | 74 +---- .../TTIRToTTIRDecomposition.cpp | 309 +++++++----------- lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp | 21 +- lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp | 3 +- lib/Dialect/TTIR/IR/TTIROps.cpp | 37 +++ lib/Dialect/TTNN/IR/TTNNOps.cpp | 38 +++ lib/Target/TTNN/TTNNToFlatbuffer.cpp | 23 +- runtime/lib/ttnn/operations/CMakeLists.txt | 1 + .../ttnn/operations/data_movement/permute.cpp | 30 ++ .../ttnn/operations/data_movement/permute.h | 15 + runtime/lib/ttnn/program.cpp | 4 + .../unary/permute_transpose_op.mlir | 4 +- .../StableHLOToTTIR/unary/transpose_op.mlir | 5 +- .../TTIR/permute/permute_tests_negative.mlir | 33 ++ .../complex_conv_channel_first.mlir | 10 +- .../TTNN/convolution/simple_conv1d.mlir | 20 +- .../TTNN/permute/permute_tests_negative.mlir | 30 ++ .../TTNN/permute/permute_tests_positive.mlir | 32 ++ .../Dialect/TTNN/permute/simple_permute.mlir | 12 + .../Dialect/TTNN/pooling/simple_pooling.mlir | 6 +- test/ttmlir/Dialect/TTNN/simple_permute.mlir | 10 + .../StableHLO/Unary/permute_transpose_op.mlir | 8 +- .../Silicon/StableHLO/Unary/tranpose_op.mlir | 4 +- .../TTNN/complex_conv_channel_first.mlir | 10 +- .../TTNN/perf_unit/test_perf_permute.mlir | 14 + .../Silicon/TTNN/pooling/simple_pooling.mlir | 6 +- test/ttmlir/Silicon/TTNN/simple_permute.mlir | 14 + 31 files changed, 621 insertions(+), 291 deletions(-) create mode 100644 runtime/lib/ttnn/operations/data_movement/permute.cpp create mode 100644 runtime/lib/ttnn/operations/data_movement/permute.h create mode 100644 test/ttmlir/Dialect/TTIR/permute/permute_tests_negative.mlir create mode 100644 test/ttmlir/Dialect/TTNN/permute/permute_tests_negative.mlir create mode 100644 test/ttmlir/Dialect/TTNN/permute/permute_tests_positive.mlir create mode 100644 test/ttmlir/Dialect/TTNN/permute/simple_permute.mlir create mode 100644 test/ttmlir/Dialect/TTNN/simple_permute.mlir create mode 100644 test/ttmlir/Silicon/TTNN/perf_unit/test_perf_permute.mlir create mode 100644 test/ttmlir/Silicon/TTNN/simple_permute.mlir diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index c0e5a78b8..842d35362 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -1308,6 +1308,33 @@ def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> { } // ANCHOR_END: adding_an_op_matmul_ttir +def TTIR_PermuteOp : TTIR_DPSOp<"permute"> { + let summary = "Permute operation."; + let description = [{ + Permute input tensor dimensions. + + Attributes: + - `permutation` array: The permutation of the input tensor dimensions. 
+ + Example: + %a = tensor.empty() : () -> tensor<2x3x4xi32> + %output = tensor.empty() : () -> tensor<3x4x2xi32> + %0 = "ttir.permute"(%a, %output) {permutation = array} : (tensor<2x3x4xi32>, tensor<3x4x2xi32>) -> tensor<3x4x2xi32> + }]; + + let arguments = (ins AnyRankedTensor:$input, + AnyRankedTensor:$output, + DenseI64ArrayAttr:$permutation); + + let results = (outs AnyRankedTensor:$result); + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } + }]; + + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // TTIR top level generic ops //===----------------------------------------------------------------------===// diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 7567364d1..b91595947 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -1062,4 +1062,27 @@ def TTNN_MeshShardOp: TTNN_Op<"mesh_shard"> { let hasVerifier = 1; } +def TTNN_PermuteOp : TTNN_Op<"permute"> { + let summary = "Permute operation."; + let description = [{ + Permute input tensor dimensions. + + Attributes: + - `permutation` array: The permutation of the input tensor dimensions. + + Example: + %a = tensor.empty() : () -> tensor<2x3x4xi32> + %0 = "ttir.permute"(%a) {permutation = array} : (tensor<2x3x4xi32>) -> tensor<3x4x2xi32> + }]; + + let arguments = (ins AnyRankedTensor:$input, + DenseI64ArrayAttr:$permutation, + OptionalAttr:$memory_config, + DefaultValuedOptionalAttr:$pad_value); + + let results = (outs AnyRankedTensor:$result); + + let hasVerifier = 1; +} + #endif diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs index 4ba5443ad..c8c0f1ad0 100644 --- a/include/ttmlir/Target/TTNN/program.fbs +++ b/include/ttmlir/Target/TTNN/program.fbs @@ -295,6 +295,14 @@ table AllGatherOp { num_links: uint32; } +table PermuteOp { + in: tt.target.TensorRef; + permutation: [int64]; + memory_config: MemoryConfigDesc; + pad_value: float; + out: tt.target.TensorRef; +} + table ReduceScatterOp { in: tt.target.TensorRef; out: tt.target.TensorRef; @@ -343,6 +351,7 @@ union OpType { ArangeOp, UpdateCacheOp, FillCacheOp, + PermuteOp, } table Operation { diff --git a/include/ttmlir/Utils.h b/include/ttmlir/Utils.h index ec7838b2f..8d70bef11 100644 --- a/include/ttmlir/Utils.h +++ b/include/ttmlir/Utils.h @@ -5,14 +5,14 @@ #ifndef TTMLIR_UTILS_H #define TTMLIR_UTILS_H -#include -#include - #include "mlir-c/IR.h" #include "mlir/CAPI/IR.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinAttributes.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" + +#include namespace ttmlir::utils { template @@ -72,17 +72,14 @@ constexpr std::underlying_type_t enum_as_int(Enum e) { return static_cast>(e); } -template -std::string join(const llvm::SmallVector &vec, - const std::string &delimiter) { - std::ostringstream result; - for (size_t i = 0; i < vec.size(); ++i) { - result << vec[i]; - if (i != vec.size() - 1) { - result << delimiter; - } - } - return result.str(); +// Returns a string that is the concatenation of the string representations of +// Range R elements interleaved with separator. 
Example: join({1, 2, 3}, ", ") +// -> "1, 2, 3" +template +std::string join(Range &&R, llvm::StringRef separator) { + return llvm::join( + llvm::map_range(R, [](auto &v) { return llvm::Twine(v).str(); }), + separator); } // Prepacks `MlirAttribute`s stored in input array into a vector of @@ -131,6 +128,61 @@ inline bool isRankedTensor(mlir::Value v) { return mlir::isa(v.getType()); } +// Returns the element received as a parameter. Useful as a callback for +// higher-order functions. +template +inline T identity(T x) { + return x; +} + +// Returns a vector of indices `permutation` such that input[permutation[i]] == +// output[i], for all i. Assumes that input and output have the same elements. +// Example: input = [1, 2, 3], output = [3, 1, 2] -> [2, 0, 1] +template +inline llvm::SmallVector +generatePermutation(llvm::ArrayRef input, llvm::ArrayRef output) { + assert(input.size() == output.size()); + + llvm::DenseMap indices; + for (const auto [index, value] : llvm::enumerate(input)) { + indices[value] = index; + } + llvm::SmallVector permutation; + for (const T &dim : output) { + permutation.push_back(indices[dim]); + } + return permutation; +} + +// Returns a vector `output`, such that output[i] = input[permutation[i]], for +// all i. Assumes that permutation is a valid permutation of the indices of +// input. Example: input = [1, 2, 3], permutation = [2, 0, 1] -> [3, 1, 2] +template +inline llvm::SmallVector +applyPermutation(llvm::ArrayRef input, llvm::ArrayRef permutation) { + assert(input.size() == permutation.size()); + + llvm::SmallVector output(input.size()); + + llvm::transform(permutation, output.begin(), + [&](const int64_t i) { return input[i]; }); + + return output; +} + +// Returns a vector `inversePermutation`, such that +// inversePermutation[permutation[i]] = i, for all i. Assumes that permutation +// is a valid permutation of a range(0, permutation.size()). 
Example: +// permutation = [2, 0, 1] -> [1, 2, 0] +inline llvm::SmallVector +inversePermutation(llvm::ArrayRef permutation) { + llvm::SmallVector inversePermutation(permutation.size()); + for (size_t i = 0; i < permutation.size(); ++i) { + inversePermutation[permutation[i]] = i; + } + return inversePermutation; +} + } // namespace ttmlir::utils #endif diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index b541d0a3e..4eeec92dc 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -138,46 +138,16 @@ class StableHLOToTTIRTransposeOpConversionPattern matchAndRewrite(mlir::stablehlo::TransposeOp srcOp, mlir::stablehlo::TransposeOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - - auto input = Value(adaptor.getOperand()); - auto transposes = getPermutationTransposes(adaptor.getPermutation().vec()); - - for (auto transposeDims : transposes) { - auto dim0 = std::get<0>(transposeDims); - auto dim1 = std::get<1>(transposeDims); - - auto inputType = mlir::cast(input.getType()); - auto outputShape = inputType.getShape().vec(); - std::swap(outputShape[dim0], outputShape[dim1]); - - auto outputType = RankedTensorType::get( - outputShape, inputType.getElementType(), inputType.getEncoding()); - - auto outputTensor = rewriter.create( - srcOp.getLoc(), outputShape, outputType.getElementType()); - - input = rewriter.create( - srcOp.getLoc(), outputType, input, outputTensor, - rewriter.getSI32IntegerAttr(dim0), rewriter.getSI32IntegerAttr(dim1)); - } - rewriter.replaceOp(srcOp, input); + ::mlir::RankedTensorType outputType = mlir::cast( + this->getTypeConverter()->convertType(srcOp.getResult().getType())); + tensor::EmptyOp outputTensor = rewriter.create( + srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); + // stablehlo.transpose and ttir.permute have the same semantics. 
+ rewriter.replaceOpWithNewOp( + srcOp, getTypeConverter()->convertType(srcOp.getResult().getType()), + adaptor.getOperand(), outputTensor, adaptor.getPermutation()); return success(); } - -private: - std::vector> - getPermutationTransposes(std::vector permutation) const { - std::vector> transposes; - for (uint32_t i = 0; i < permutation.size(); i++) { - while (i != permutation[i]) { - transposes.push_back( - std::make_tuple(permutation[i], permutation[permutation[i]])); - std::swap(permutation[i], permutation[permutation[i]]); - } - } - - return transposes; - } }; class StableHLOToTTIRReshapeOpConversionPattern @@ -204,19 +174,6 @@ class StableHLOToTTIRReshapeOpConversionPattern adaptor.getOperand(), outputTensor, new_shape_attr); return success(); } - - LogicalResult - checkBasicLegality(mlir::stablehlo::TransposeOp &srcOp, - mlir::stablehlo::TransposeOp::Adaptor &adaptor, - ConversionPatternRewriter &rewriter) const { - - if (adaptor.getPermutation().size() != 2) { - return rewriter.notifyMatchFailure( - srcOp, "TTIR supports only two dimensional transposeOp."); - } - - return success(); - } }; class StableHLOToTTIRDotGeneralOpConversionPattern @@ -1831,13 +1788,6 @@ void addReduceOpsConversionPatterns(MLIRContext *ctx, patterns.add(typeConverter, ctx); } -void addTransposeOpsConversionPatterns(MLIRContext *ctx, - RewritePatternSet &patterns, - TypeConverter &typeConverter) { - - patterns.add(typeConverter, ctx); -} - void addMatmulOpsConversionPatterns(MLIRContext *ctx, RewritePatternSet &patterns, TypeConverter &typeConverter) { @@ -1891,6 +1841,12 @@ void addConcatOpsConversionPatterns(MLIRContext *ctx, patterns.add(typeConverter, ctx); } +void addTransposeOpConversionPattern(MLIRContext *ctx, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add(typeConverter, ctx); +} + void addReshapeOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns, TypeConverter &typeConverter) { @@ -1973,7 +1929,6 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx, addElementwiseUnaryOpsConversionPatterns(ctx, patterns, typeConverter); addElementwiseBinaryOpsConversionPatterns(ctx, patterns, typeConverter); addReduceOpsConversionPatterns(ctx, patterns, typeConverter); - addTransposeOpsConversionPatterns(ctx, patterns, typeConverter); addMatmulOpsConversionPatterns(ctx, patterns, typeConverter); addGetDimensionSizeOpsConversionPatterns(ctx, patterns, typeConverter); addTensorCreationOpsConversionPatterns(ctx, patterns, typeConverter); @@ -1982,6 +1937,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx, addReduceWindowOpConversionPattern(ctx, patterns, typeConverter); addCompareOpsConversionPatterns(ctx, patterns, typeConverter); addConcatOpsConversionPatterns(ctx, patterns, typeConverter); + addTransposeOpConversionPattern(ctx, patterns, typeConverter); addReshapeOpConversionPattern(ctx, patterns, typeConverter); addCCLOpsConversionPattern(ctx, patterns, typeConverter); addLogicalAndBitwiseOpsConversionPatterns(ctx, patterns, typeConverter); diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp index 25140b9bc..953ea2835 100644 --- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp +++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp @@ -6,6 +6,7 @@ #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" #include "ttmlir/Dialect/TTIR/IR/TTIROps.h" +#include "ttmlir/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include 
"mlir/IR/BuiltinTypes.h" @@ -17,7 +18,6 @@ #include "mlir/Transforms/DialectConversion.h" #include -#include using namespace mlir; using namespace mlir::tt; @@ -77,168 +77,49 @@ struct IndexToSliceConversionPattern // Convolution passes //===----------------------------------------------------------------------===// -using TransposeDims = std::tuple; - template using PaddingMatrix = std::array, NDims>; template static PaddingMatrix getPaddingMatrix(ArrayRef padding) { + assert(padding.size() >= 2 * NDims && + "padding must be at least 2 * NDims sized array"); + PaddingMatrix paddingMatrix; - std::vector paddingFlattened = padding.vec(); for (uint32_t i = 0; i < 2 * NDims; i += 2) { - paddingMatrix[i / 2] = {paddingFlattened[i], paddingFlattened[i + 1]}; + paddingMatrix[i / 2] = {padding[i], padding[i + 1]}; } return paddingMatrix; } -/* - * The following functions are used to generate the transpose operations needed - * to convert a convolution operation to the specific op definitions for a - * ConvNdOp for any N spatial dimensions. - * - * All convolutions will have a batch and feature dimension, and the kernel will - * have an input and output feature dimension. The spatial dimensions can be - * represented by non-negative integers. - */ -enum ConvolutionDimension { BATCH = -1, FEATURE = -2, INVALID_DIM = -3 }; - -enum ConvolutionKernelDimension { - INPUT_FEATURES = -1, - OUTPUT_FEATURES = -2, - INVALID_KERNEL_DIM = -3 -}; - -/* - * Generates a sequence of dims in which to transpose to make currentLayout - * match desiredLayout - * - * Ex: if currentLayout = [0, 1, 2, 3] and desiredLayout = [0, 2, 3, 1] - * then the function will return [(1, 2), (2, 3)] because when we swap - * currentLayout[1] with currentLayout[2] we get [0, 2, 1, 3], and then when - * we swap currentLayout[2] with currentLayout[3] we get [0, 2, 3, 1], which - * is the desired layout - */ -static std::vector -generateTransposeIndices(std::vector currentLayout, - const std::vector desiredLayout) { - std::vector transposeIndices; - for (int64_t i = 0; i < static_cast(currentLayout.size()); i++) { - if (currentLayout[i] != desiredLayout[i]) { - int64_t dim0 = i; - int64_t dim1 = std::find(currentLayout.begin(), currentLayout.end(), - desiredLayout[i]) - - currentLayout.begin(); - transposeIndices.push_back(std::make_tuple(dim0, dim1)); - std::swap(currentLayout[dim0], currentLayout[dim1]); - } - } - - return transposeIndices; -} - -/* - * This function will use a sequence of transpose indices to - * generate the actual transpose operations descrbibed by them. 
- * - * It takes an input to apply these transposes to and returns the - * result at the end of the sequence - */ -static Value generateTransposeOps(Value input, PatternRewriter &rewriter, - std::vector transposeIndices) { - for (auto [dim0, dim1] : transposeIndices) { - - auto inputType = mlir::cast(input.getType()); - auto outputShape = inputType.getShape().vec(); - std::swap(outputShape[dim0], outputShape[dim1]); - - auto dim0Attr = rewriter.getSI32IntegerAttr(dim0); - auto dim1Attr = rewriter.getSI32IntegerAttr(dim1); - - auto outputType = RankedTensorType::get( - outputShape, inputType.getElementType(), inputType.getEncoding()); - - auto dpsOutput = rewriter.create( - input.getLoc(), outputShape, outputType.getElementType()); - input = rewriter - .create(input.getLoc(), outputType, input, - dpsOutput, dim0Attr, dim1Attr) - .getResult(); - } - - return input; -} - -/* - * This function will generate the transpose indices needed to convert a - * convolution input to a desired layout. The reason for the separate - * function is to encapsulate the logic for constructuring the inputLayout - */ -static std::vector -generateConvTransposeIndices(ttir::ConvolutionOp op, - const std::vector ttnnConvolutionLayout) { - - std::vector inputLayout(ttnnConvolutionLayout.size(), - ConvolutionDimension::INVALID_DIM); - inputLayout[op.getConvolutionLayout().getInputBatchDimension()] = - ConvolutionDimension::BATCH; - inputLayout[op.getConvolutionLayout().getInputFeatureDimension()] = - ConvolutionDimension::FEATURE; - - int64_t spatialCount = 0; - for (int64_t spatialDim : - op.getConvolutionLayout().getInputSpatialDimensions()) { - inputLayout[spatialDim] = spatialCount; - spatialCount++; - } - - return generateTransposeIndices(inputLayout, ttnnConvolutionLayout); -} - -/* - * This function will generate the transpose indices needed to convert a - * convolution input to a desired layout. The reason for the separate - * function is to encapsulate the logic for constructuring the kernelLayout - */ -static std::vector generateConvKernelTransposeIndices( - ttir::ConvolutionOp op, - const std::vector ttnnConvolutionKernelLayout) { - std::vector transposeIndices; - - std::vector kernelLayout( - ttnnConvolutionKernelLayout.size(), - ConvolutionKernelDimension::INVALID_KERNEL_DIM); - kernelLayout[op.getConvolutionLayout().getKernelOutputFeatureDimension()] = - ConvolutionKernelDimension::OUTPUT_FEATURES; - kernelLayout[op.getConvolutionLayout().getKernelInputFeatureDimension()] = - ConvolutionKernelDimension::INPUT_FEATURES; - - int64_t spatialCount = 0; - for (int64_t spatialDim : - op.getConvolutionLayout().getKernelSpatialDimensions()) { - kernelLayout[spatialDim] = spatialCount; - spatialCount++; - } - - return generateTransposeIndices(kernelLayout, ttnnConvolutionKernelLayout); -} struct ConvolutionDecompositionPattern : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; + // All convolutions will have a batch and feature dimension, and the kernel + // will have an input and output feature dimension. The spatial dimensions + // can be + // represented by non-negative integers. 
+  enum ConvolutionDimension { BATCH = -1, FEATURE = -2, INVALID_DIM = -3 };
+  enum ConvolutionKernelDimension {
+    INPUT_FEATURES = -1,
+    OUTPUT_FEATURES = -2,
+    INVALID_KERNEL_DIM = -3
+  };
+
   LogicalResult
   matchAndRewrite(ttir::ConvolutionOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override = 0;
 
 protected:
-  bool isNDimensional(ttir::ConvolutionOp op, uint32_t numSpatialDims) const {
+  static bool isNDimensional(ttir::ConvolutionOp op, uint32_t numSpatialDims) {
     return op.getConvolutionLayout().getInputSpatialDimensions().size() ==
            numSpatialDims;
   }
 
-  bool isSupportedConv(ttir::ConvolutionOp op) const {
+  static bool isSupportedConv(ttir::ConvolutionOp op) {
     assert(op.getConvolutionLayout().getInputSpatialDimensions().size() ==
                op.getConvolutionLayout().getOutputSpatialDimensions().size() &&
            "Convolution input, output, and kernel must have the same number of "
@@ -249,12 +130,8 @@ struct ConvolutionDecompositionPattern
            "spatial dimensions");
 
     // Not currently supporting window reversal
-    std::vector windowReversal(op.getWindowReversal().begin(),
-                               op.getWindowReversal().end());
-    for (bool reversed : windowReversal) {
-      if (reversed) {
-        return false;
-      }
+    if (llvm::any_of(op.getWindowReversal(), ttmlir::utils::identity)) {
+      return false;
     }
 
     // Not currently support batch groups
@@ -264,6 +141,53 @@ struct ConvolutionDecompositionPattern
 
     return true;
   }
+
+  // This function will generate the permutation needed to convert a
+  // convolution input to a desired layout. The reason for the separate
+  // function is to encapsulate the logic for constructing the inputLayout.
+  static llvm::SmallVector
+  generateConvPermutation(ttir::ConvolutionOp op,
+                          llvm::ArrayRef ttnnConvolutionLayout) {
+
+    llvm::SmallVector inputLayout(ttnnConvolutionLayout.size(),
+                                  ConvolutionDimension::INVALID_DIM);
+    inputLayout[op.getConvolutionLayout().getInputBatchDimension()] =
+        ConvolutionDimension::BATCH;
+    inputLayout[op.getConvolutionLayout().getInputFeatureDimension()] =
+        ConvolutionDimension::FEATURE;
+
+    for (const auto [spatialCount, spatialDim] : llvm::enumerate(
+             op.getConvolutionLayout().getInputSpatialDimensions())) {
+      inputLayout[spatialDim] = spatialCount;
+    }
+
+    return ttmlir::utils::generatePermutation(llvm::ArrayRef(inputLayout),
+                                              ttnnConvolutionLayout);
+  }
+
+  // This function will generate the permutation needed to convert a
+  // convolution kernel to a desired layout. The reason for the separate
+  // function is to encapsulate the logic for constructing the kernelLayout. 
+ static llvm::SmallVector generateConvKernelPermutation( + ttir::ConvolutionOp op, + llvm::ArrayRef ttnnConvolutionKernelLayout) { + + llvm::SmallVector kernelLayout( + ttnnConvolutionKernelLayout.size(), + ConvolutionKernelDimension::INVALID_KERNEL_DIM); + kernelLayout[op.getConvolutionLayout().getKernelOutputFeatureDimension()] = + ConvolutionKernelDimension::OUTPUT_FEATURES; + kernelLayout[op.getConvolutionLayout().getKernelInputFeatureDimension()] = + ConvolutionKernelDimension::INPUT_FEATURES; + + for (const auto [spatialCount, spatialDim] : llvm::enumerate( + op.getConvolutionLayout().getKernelSpatialDimensions())) { + kernelLayout[spatialDim] = spatialCount; + } + + return ttmlir::utils::generatePermutation(llvm::ArrayRef(kernelLayout), + ttnnConvolutionKernelLayout); + } }; // A decomposition pattern that matches to a ttir.convolution op that does 1D @@ -274,12 +198,12 @@ struct ConvolutionDecompositionPattern struct Legalize1DConvolutionPattern : public ConvolutionDecompositionPattern { public: using ConvolutionDecompositionPattern::ConvolutionDecompositionPattern; - constexpr static uint32_t numSpatialDims = 1; + constexpr static uint32_t NUM_SPATIAL_DIMS = 1; LogicalResult matchAndRewrite(ttir::ConvolutionOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!(isSupportedConv(op) && isNDimensional(op, numSpatialDims))) { + if (!(isSupportedConv(op) && isNDimensional(op, NUM_SPATIAL_DIMS))) { return failure(); } @@ -299,8 +223,7 @@ struct Legalize1DConvolutionPattern : public ConvolutionDecompositionPattern { conv2dOutputShape.push_back(1); auto DPSConv2dOutput = rewriter.create( op->getLoc(), conv2dOutputShape, outputType.getElementType()); - auto conv2dOutputType = - mlir::cast(DPSConv2dOutput.getType()); + RankedTensorType conv2dOutputType = DPSConv2dOutput.getType(); auto inputType = mlir::cast(adaptor.getInput().getType()); llvm::ArrayRef inputShape = inputType.getShape(); @@ -418,11 +341,12 @@ struct Legalize1DConvolutionPattern : public ConvolutionDecompositionPattern { return rewriter.getDenseBoolArrayAttr(newDenseArray); } }; + struct ConvolutionToConv2dPattern : public ConvolutionDecompositionPattern { public: using ConvolutionDecompositionPattern::ConvolutionDecompositionPattern; - constexpr static uint32_t numSpatialDims = 2; + constexpr static uint32_t NUM_SPATIAL_DIMS = 2; constexpr static uint32_t SPATIAL_DIM_HEIGHT = 0; constexpr static uint32_t SPATIAL_DIM_WIDTH = 1; @@ -444,7 +368,7 @@ struct ConvolutionToConv2dPattern : public ConvolutionDecompositionPattern { LogicalResult matchAndRewrite(ttir::ConvolutionOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!(isSupportedConv(op) && isNDimensional(op, numSpatialDims))) { + if (!(isSupportedConv(op) && isNDimensional(op, NUM_SPATIAL_DIMS))) { return failure(); } @@ -460,7 +384,8 @@ struct ConvolutionToConv2dPattern : public ConvolutionDecompositionPattern { // Padding is a list of 2-tuples, the order of the 2-tuples is in // most-significant spatial dimension first order For Conv2d the most // significant spatial dimension is the height, followed by the width. 
- auto paddingMatrix = getPaddingMatrix(adaptor.getPadding()); + auto paddingMatrix = + getPaddingMatrix(adaptor.getPadding()); auto paddingTopAttr = rewriter.getSI32IntegerAttr(paddingMatrix[SPATIAL_DIM_HEIGHT][0]); auto paddingBottomAttr = @@ -473,8 +398,8 @@ struct ConvolutionToConv2dPattern : public ConvolutionDecompositionPattern { auto groupsAttr = rewriter.getSI32IntegerAttr(adaptor.getFeatureGroupCount()); - auto outputShape = op.getResult().getType().getShape().vec(); - std::vector newOutputShape = { + llvm::ArrayRef outputShape = op.getResult().getType().getShape(); + llvm::SmallVector newOutputShape{ outputShape[adaptor.getConvolutionLayout().getOutputBatchDimension()], outputShape[adaptor.getConvolutionLayout() .getOutputSpatialDimensions()[SPATIAL_DIM_HEIGHT]], @@ -488,30 +413,41 @@ struct ConvolutionToConv2dPattern : public ConvolutionDecompositionPattern { inputType.cloneWith(newOutputShape, inputType.getElementType()); auto convDPSOutput = rewriter.create( - adaptor.getInput().getLoc(), newOutputShape, - outputType.getElementType()); + op.getLoc(), newOutputShape, outputType.getElementType()); - auto transposeIndices = generateConvTransposeIndices(op, conv2dLayout); - Value input = - generateTransposeOps(adaptor.getInput(), rewriter, transposeIndices); + auto permutation = generateConvPermutation(op, conv2dLayout); + auto permuteOutputShape = + ::ttmlir::utils::applyPermutation(inputType.getShape(), permutation); + auto permuteDPSOutput = rewriter.create( + op.getLoc(), permuteOutputShape, inputType.getElementType()); + auto input = rewriter.create( + op.getLoc(), permuteDPSOutput.getType(), adaptor.getInput(), + permuteDPSOutput, permutation); - auto kernelTransposeIndices = - generateConvKernelTransposeIndices(op, conv2dKernelLayout); - Value weight = generateTransposeOps(adaptor.getWeight(), rewriter, - kernelTransposeIndices); + auto weightType = + mlir::cast(adaptor.getWeight().getType()); + auto kernelPermutation = + generateConvKernelPermutation(op, conv2dKernelLayout); + auto weightOutputShape = ::ttmlir::utils::applyPermutation( + mlir::cast(adaptor.getWeight().getType()).getShape(), + kernelPermutation); + auto weightDPSOutput = rewriter.create( + op.getLoc(), weightOutputShape, weightType.getElementType()); + auto weight = rewriter.create( + op.getLoc(), weightDPSOutput.getType(), adaptor.getWeight(), + weightDPSOutput, kernelPermutation); ttir::Conv2dOp newConv = rewriter.create( op.getLoc(), outputType, input, weight, adaptor.getBias(), convDPSOutput, strideHeightAttr, strideWidthAttr, dilationHeightAttr, dilationWidthAttr, groupsAttr, paddingLeftAttr, paddingRightAttr, paddingTopAttr, paddingBottomAttr); - // Applying the transposes in reverse order to the output will restore the - // tensor to the original layout - std::reverse(transposeIndices.begin(), transposeIndices.end()); - Value output = - generateTransposeOps(newConv.getResult(), rewriter, transposeIndices); + // Applying the inverse of permutation to the output will restore the + // tensor to the original layout. 
+ rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), newConv, adaptor.getOutput(), + ttmlir::utils::inversePermutation(permutation)); - rewriter.replaceOp(op, output); return success(); } }; @@ -795,8 +731,9 @@ struct PoolingToPool2dPattern : public OpConversionPattern { } } - auto transposeIndices = - generateTransposeIndices(currentLayout, desiredLayout); + auto permutation = ttmlir::utils::generatePermutation( + llvm::ArrayRef(currentLayout), llvm::ArrayRef(desiredLayout)); + auto inverseOfPermutation = ttmlir::utils::inversePermutation(permutation); auto kernelHeightAttr = rewriter.getSI32IntegerAttr( static_cast(op.getWindowDimensions()[spatialDims[0]])); @@ -824,16 +761,21 @@ struct PoolingToPool2dPattern : public OpConversionPattern { auto paddingRightAttr = rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatialDims[1] + 1]); - std::vector outputs; + llvm::SmallVector outputs; for (Value input : adaptor.getInputs()) { - input = generateTransposeOps(input, rewriter, transposeIndices); + RankedTensorType inputTy = mlir::cast(input.getType()); + + auto inputPermuteShape = + ::ttmlir::utils::applyPermutation(inputTy.getShape(), permutation); + auto inputDPSOutput = rewriter.create( + op.getLoc(), inputPermuteShape, inputTy.getElementType()); + input = rewriter.create(op.getLoc(), + inputDPSOutput.getType(), input, + inputDPSOutput, permutation); auto outputType = mlir::cast(op.getResult(0).getType()); - auto newOutputShape = outputType.getShape().vec(); - for (TransposeDims dims : transposeIndices) { - std::swap(newOutputShape[std::get<0>(dims)], - newOutputShape[std::get<1>(dims)]); - } + auto newOutputShape = + ::ttmlir::utils::applyPermutation(outputType.getShape(), permutation); auto newOutputType = outputType.cloneWith(newOutputShape, outputType.getElementType()); auto outputTensor = rewriter.create( @@ -846,15 +788,14 @@ struct PoolingToPool2dPattern : public OpConversionPattern { dilationHeightAttr, dilationWidthAttr, ceilModeAttr, paddingTopAttr, paddingBottomAttr, paddingLeftAttr, paddingRightAttr); - // Applying the transposes in reverse order to the output will restore the - // tensor to the original layout - std::reverse(transposeIndices.begin(), transposeIndices.end()); - Value output = - generateTransposeOps(newPool.getResult(), rewriter, transposeIndices); + // Applying the inverse of permutation to the output will restore the + // tensor to the original layout. 
+ auto reversePoolDPSOuput = rewriter.create( + op.getLoc(), outputType.getShape(), outputType.getElementType()); + Value output = rewriter.create( + op.getLoc(), reversePoolDPSOuput.getType(), newPool, + reversePoolDPSOuput, inverseOfPermutation); - // Reverse back so the proper input transposes are generated for the next - // pool - std::reverse(transposeIndices.begin(), transposeIndices.end()); outputs.push_back(output); } diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index a2b63a1bc..6461d97eb 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -26,6 +26,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" + #include using namespace mlir; @@ -1120,6 +1121,23 @@ class ScatterOpConversionPattern : public OpConversionPattern { return success(); } }; + +class PermuteOpConversionPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttir::PermuteOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, this->getTypeConverter()->convertType(op.getType()), + adaptor.getInput(), adaptor.getPermutationAttr(), + ttnn::MemoryConfigAttr(), mlir::FloatAttr()); + + return success(); + } +}; + } // namespace namespace mlir::tt { @@ -1201,7 +1219,8 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns, ArangeOpConversionPattern, UpdateCacheOpConversionPattern, FillCacheOpConversionPattern, - ScatterOpConversionPattern + ScatterOpConversionPattern, + PermuteOpConversionPattern >(typeConverter, ctx); // ANCHOR_END: op_rewriter_pattern_set // clang-format on diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp index 8aa60e002..fda64f925 100644 --- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp +++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp @@ -755,7 +755,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, patterns.add, DefaultOpConversionPattern, DefaultOpConversionPattern, - DefaultOpConversionPattern>(typeConverter, ctx); + DefaultOpConversionPattern, + DefaultOpConversionPattern>(typeConverter, ctx); // Matmul ops // diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index ef9fd29f4..52e68b811 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -22,6 +22,7 @@ #include "llvm/Support/LogicalResult.h" #include +#include #include #define GET_OP_CLASSES @@ -1579,6 +1580,42 @@ ::mlir::LogicalResult mlir::tt::ttir::ReverseOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// PermuteOp +//===----------------------------------------------------------------------===// + +// PermuteOp verification +::mlir::LogicalResult mlir::tt::ttir::PermuteOp::verify() { + llvm::ArrayRef inputShape = getInput().getType().getShape(); + const size_t inputRank = inputShape.size(); + llvm::ArrayRef resultShape = getResult().getType().getShape(); + + // Check that given attribute `permutation` is a valid permutation of the + // dimensions. 
+ llvm::ArrayRef permutation = getPermutation(); + llvm::SmallVector dimensions(inputRank); + std::iota(dimensions.begin(), dimensions.end(), 0); + if (inputRank != permutation.size() || + !std::is_permutation(permutation.begin(), permutation.end(), + dimensions.begin())) { + return emitOpError("Expected a permutation of (") + << ttmlir::utils::join(dimensions, ", ") + << "), got (" + ttmlir::utils::join(permutation, ", ") << ")"; + } + + // Check that the result shape matches the shape of input tensor after + // permutation is applied. + llvm::SmallVector expectedResultShape = + ttmlir::utils::applyPermutation(inputShape, permutation); + if (!llvm::equal(expectedResultShape, resultShape)) { + return emitOpError("Expected result shape (") + << ttmlir::utils::join(expectedResultShape, ", ") << "), got (" + << ttmlir::utils::join(resultShape, ", ") << ")"; + } + + return success(); +} + //===----------------------------------------------------------------------===// // GenericOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp index 6a0fed3e8..e3fc5a33c 100644 --- a/lib/Dialect/TTNN/IR/TTNNOps.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp @@ -10,6 +10,7 @@ #include "ttmlir/Dialect/TTNN/Utils/Utils.h" #include "ttmlir/Utils.h" +#include #include #include "mlir/Dialect/Traits.h" @@ -1272,4 +1273,41 @@ ::mlir::LogicalResult FillCacheOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// PermuteOp +//===----------------------------------------------------------------------===// + +// PermuteOp verification +::mlir::LogicalResult mlir::tt::ttnn::PermuteOp::verify() { + llvm::ArrayRef inputShape = getInput().getType().getShape(); + const size_t inputRank = inputShape.size(); + llvm::ArrayRef resultShape = getResult().getType().getShape(); + + // Check that given attribute `permutation` is a valid permutation of the + // dimensions. + llvm::ArrayRef permutation = getPermutation(); + llvm::SmallVector dimensions(inputRank); + std::iota(dimensions.begin(), dimensions.end(), 0); + if (inputRank != permutation.size() || + !std::is_permutation(permutation.begin(), permutation.end(), + dimensions.begin())) { + return emitOpError("Expected a permutation of (") + << ttmlir::utils::join(dimensions, ", ") + << "), got (" + ttmlir::utils::join(permutation, ", ") << ")"; + } + + // Check that the result shape matches the shape of input tensor after + // permutation is applied. 
+ llvm::SmallVector expectedResultShape = + ttmlir::utils::applyPermutation(inputShape, permutation); + if (!llvm::equal(expectedResultShape, resultShape)) { + return emitOpError("Expected result shape (" + + ttmlir::utils::join(expectedResultShape, ", ") + + "), got (" + ttmlir::utils::join(resultShape, ", ") + + ")"); + } + + return success(); +} + } // namespace mlir::tt::ttnn diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index db35c13b1..1e8492317 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -34,7 +34,6 @@ #include "llvm/Support/raw_ostream.h" #include -#include #include namespace mlir::tt { @@ -517,6 +516,24 @@ createOp(FlatbufferObjectCache &cache, MeshShardOp op) { cache.fbb->CreateVector(shardShape)); } +::flatbuffers::Offset<::tt::target::ttnn::PermuteOp> +createOp(FlatbufferObjectCache &cache, PermuteOp op) { + flatbuffers::Offset<::tt::target::TensorRef> input = + cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); + flatbuffers::Offset> permutation = + toFlatbuffer(cache, op.getPermutation()); + std::optional memoryConfig = + op.getMemoryConfig(); + float padValue = op.getPadValue().convertToFloat(); + auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, + kHostAllocatedAddress, kHostAllocatedSize); + return ::tt::target::ttnn::CreatePermuteOp( + *cache.fbb, input, permutation, + memoryConfig ? cache.getOrCreate(*memoryConfig, memoryConfigToFlatbuffer) + : 0, + padValue, output); +} + template ::flatbuffers::Offset createEltwiseOpParams(FlatbufferObjectCache &cache, EltwiseOp op) { @@ -1170,6 +1187,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op, return createOperation(cache, createOp(cache, fillCacheOp), debugString, locInfo); } + if (auto permuteOp = dyn_cast(op); permuteOp) { + return createOperation(cache, createOp(cache, permuteOp), debugString, + locInfo); + } llvm_unreachable("unhandled op in emitTTNNOperation"); } diff --git a/runtime/lib/ttnn/operations/CMakeLists.txt b/runtime/lib/ttnn/operations/CMakeLists.txt index 953310193..3a7f2b3ef 100644 --- a/runtime/lib/ttnn/operations/CMakeLists.txt +++ b/runtime/lib/ttnn/operations/CMakeLists.txt @@ -10,6 +10,7 @@ set(TTNN_OPS_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/creation/ones.cpp ${CMAKE_CURRENT_SOURCE_DIR}/creation/full.cpp ${CMAKE_CURRENT_SOURCE_DIR}/data_movement/concat.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/data_movement/permute.cpp ${CMAKE_CURRENT_SOURCE_DIR}/data_movement/reshape.cpp ${CMAKE_CURRENT_SOURCE_DIR}/data_movement/slice.cpp ${CMAKE_CURRENT_SOURCE_DIR}/data_movement/transpose.cpp diff --git a/runtime/lib/ttnn/operations/data_movement/permute.cpp b/runtime/lib/ttnn/operations/data_movement/permute.cpp new file mode 100644 index 000000000..70589faf0 --- /dev/null +++ b/runtime/lib/ttnn/operations/data_movement/permute.cpp @@ -0,0 +1,30 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "operations/data_movement/permute.h" + +#include "tt/runtime/detail/logger.h" +#include "tt/runtime/ttnn/operations/utils.h" + +#include + +namespace tt::runtime::ttnn::operations::data_movement { +void run(const ::tt::target::ttnn::PermuteOp *op, ProgramContext &context) { + ProgramTensorPool &tensorPool = context.getTensorPool(); + + const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id()); + DEBUG_ASSERT(in.is_allocated()); + + std::vector permutation(op->permutation()->begin(), + op->permutation()->end()); + 
std::optional memoryConfig = + op->memory_config() ? std::make_optional(utils::createMemoryConfig( + op->memory_config(), op->out())) + : std::nullopt; + float padValue = op->pad_value(); + + ::ttnn::Tensor out = ::ttnn::permute(in, permutation, memoryConfig, padValue); + tensorPool.insert_or_assign(op->out()->global_id(), out); +} +} // namespace tt::runtime::ttnn::operations::data_movement diff --git a/runtime/lib/ttnn/operations/data_movement/permute.h b/runtime/lib/ttnn/operations/data_movement/permute.h new file mode 100644 index 000000000..d583aec33 --- /dev/null +++ b/runtime/lib/ttnn/operations/data_movement/permute.h @@ -0,0 +1,15 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef RUNTIME_LIB_TTNN_OPERATIONS_DATA_MOVEMENT_PERMUTE_H +#define RUNTIME_LIB_TTNN_OPERATIONS_DATA_MOVEMENT_PERMUTE_H + +#include "tt/runtime/ttnn/types.h" +#include "ttmlir/Target/TTNN/program_generated.h" + +namespace tt::runtime::ttnn::operations::data_movement { +void run(const ::tt::target::ttnn::PermuteOp *op, ProgramContext &context); +} // namespace tt::runtime::ttnn::operations::data_movement + +#endif diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp index 5d1bf6fd0..b12a41b23 100644 --- a/runtime/lib/ttnn/program.cpp +++ b/runtime/lib/ttnn/program.cpp @@ -9,6 +9,7 @@ #include "operations/creation/full.h" #include "operations/creation/ones.h" #include "operations/data_movement/concat.h" +#include "operations/data_movement/permute.h" #include "operations/data_movement/reshape.h" #include "operations/data_movement/slice.h" #include "operations/data_movement/transpose.h" @@ -202,6 +203,9 @@ void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) { case ::tt::target::ttnn::OpType::ConcatOp: { return operations::data_movement::run(op->type_as_ConcatOp(), context); } + case ::tt::target::ttnn::OpType::PermuteOp: { + return operations::data_movement::run(op->type_as_PermuteOp(), context); + } case ::tt::target::ttnn::OpType::ReshapeOp: { return operations::data_movement::run(op->type_as_ReshapeOp(), context); } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/unary/permute_transpose_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/unary/permute_transpose_op.mlir index f12db0ae7..8440146da 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/unary/permute_transpose_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/unary/permute_transpose_op.mlir @@ -2,8 +2,8 @@ // RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s module { func.func @main(%arg0: tensor<1x32x64x128xf32>) -> tensor<1x128x32x64xf32> { - // CHECK: %[[C:.*]] = "ttir.transpose"[[C:.*]] - // CHECK: %[[C:.*]] = "ttir.transpose"[[C:.*]] + // CHECK: "ttir.permute" + // CHECK-SAME: permutation = array %0 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x32x64x128xf32>) -> tensor<1x128x32x64xf32> return %0 : tensor<1x128x32x64xf32> } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/unary/transpose_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/unary/transpose_op.mlir index 398746dda..7bdcf9129 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/unary/transpose_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/unary/transpose_op.mlir @@ -3,8 +3,9 @@ module @jit_transpose attributes {} { func.func public @test_transpose(%arg0: tensor<64x128xf32>) -> tensor<128x64xf32> { %0 = stablehlo.transpose %arg0, dims = [1,0] : (tensor<64x128xf32>) -> tensor<128x64xf32> - // CHECK: %[[C:.*]] = tensor.empty[[C:.*]] - // CHECK: 
%[[C:.*]] = "ttir.transpose"[[C:.*]] + // CHECK: tensor.empty + // CHECK: "ttir.permute" + // CHECK-SAME: permutation = array return %0 : tensor<128x64xf32> } } diff --git a/test/ttmlir/Dialect/TTIR/permute/permute_tests_negative.mlir b/test/ttmlir/Dialect/TTIR/permute/permute_tests_negative.mlir new file mode 100644 index 000000000..a3c64e31b --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/permute/permute_tests_negative.mlir @@ -0,0 +1,33 @@ +// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s +// Negative tests for permute operation + +// Verfiy that given attribute `permutation` is a valid permutation of the dimensions. +module { + func.func @permute_non_valid_permutation(%arg0: tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> { + // CHECK: error: 'ttir.permute' op Expected a permutation of (0, 1, 2), got (0, 1, 0) + %0 = tensor.empty() : tensor<16x32x64xbf16> + %1 = "ttir.permute"(%arg0, %0) <{permutation = array}> : (tensor<16x32x64xbf16>, tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> + return %1 : tensor<16x32x64xbf16> + } +} + +// ----- +module { + func.func @permute_subset_permutation(%arg0: tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> { + // CHECK: error: 'ttir.permute' op Expected a permutation of (0, 1, 2), got (0, 1) + %0 = tensor.empty() : tensor<16x32x64xbf16> + %1 = "ttir.permute"(%arg0, %0) <{permutation = array}> : (tensor<16x32x64xbf16>, tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> + return %1 : tensor<16x32x64xbf16> + } +} + +// Verify that the result shape matches the shape of the input tensor after permutation is applied. +// ----- +module { + func.func @permute_non_valid_shape(%arg0: tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> { + // CHECK: error: 'ttir.permute' op Expected result shape (16, 64, 32), got (16, 32, 64) + %0 = tensor.empty() : tensor<16x32x64xbf16> + %1 = "ttir.permute"(%arg0, %0) <{permutation = array}> : (tensor<16x32x64xbf16>, tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> + return %1 : tensor<16x32x64xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTNN/convolution/complex_conv_channel_first.mlir b/test/ttmlir/Dialect/TTNN/convolution/complex_conv_channel_first.mlir index f4633b7a8..76dd7ce76 100644 --- a/test/ttmlir/Dialect/TTNN/convolution/complex_conv_channel_first.mlir +++ b/test/ttmlir/Dialect/TTNN/convolution/complex_conv_channel_first.mlir @@ -2,9 +2,11 @@ module @jit_convolution { func.func public @test_NCHW_IOHW_to_NHWC_OIHW_conv2d(%arg0: tensor<1x3x100x100xbf16>, %arg1: tensor<7x3x3x3xbf16>) -> tensor<1x7x100x100xbf16> { %0 = tensor.empty() : tensor<1x7x100x100xbf16> - // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]] - // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]] - // CHECK: %[[C:.*]] = "ttnn.conv2d"[[C:.*]] + // CHECK: "ttnn.permute" + // CHECK-SAME: permutation = array + // CHECK: "ttnn.conv2d" + // CHECK: "ttnn.permute" + // CHECK-SAME: permutation = array %1 = "ttir.convolution"(%arg0, %arg1, %0) <{ batch_group_count = 1 : i64, convolution_layout = #ttir, window_strides = array }> : (tensor<1x3x100x100xbf16>, tensor<7x3x3x3xbf16>, tensor<1x7x100x100xbf16>) -> tensor<1x7x100x100xbf16> - // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]] - // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]] return %1 : tensor<1x7x100x100xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/convolution/simple_conv1d.mlir b/test/ttmlir/Dialect/TTNN/convolution/simple_conv1d.mlir index bd86d0218..8719ab3cc 100644 --- a/test/ttmlir/Dialect/TTNN/convolution/simple_conv1d.mlir +++ b/test/ttmlir/Dialect/TTNN/convolution/simple_conv1d.mlir @@ -2,13 +2,19 @@ 
module { func.func @main(%arg0: tensor<1x256x512xf32>, %arg1: tensor<1024x256x1xf32>, %arg2: tensor<1024xf32>) -> tensor<1x1024x512xf32> { %0 = tensor.empty() : tensor<1x1024x512xf32> - // CHECK: [[VAL0:%[0-9]+]] = "ttnn.reshape"(%{{.*}}) <{shape = [1 : i32, 256 : i32, 512 : i32, 1 : i32]}> : (tensor<[[TENSOR_SHAPE0:[0-9]+x[0-9]+x[0-9]+xf32]], #{{.*}}) -> tensor<[[TENSOR_SHAPE1:[0-9]+x[0-9]+x[0-9]+x1xf32]], #{{.*}}> - // CHECK: [[VAL1:%[0-9]+]] = "ttnn.reshape"(%{{.*}}) <{shape = [1024 : i32, 256 : i32, 1 : i32, 1 : i32]}> : (tensor<[[TENSOR_SHAPE2:[0-9]+x[0-9]+x[0-9]+xf32]], #{{.*}}>) -> tensor<[[TENSOR_SHAPE3:[0-9]+x[0-9]+x[0-9]+x1xf32]], #{{.*}}> - // CHECK: [[VAL2:%[0-9]+]] = "ttnn.transpose"([[VAL0]]) <{dim0 = 1 : si32, dim1 = 2 : si32}> : (tensor<[[TENSOR_SHAPE1]], #{{.*}}>) -> tensor<[[TENSOR_SHAPE4:[0-9]+x[0-9]+x[0-9]+x1xf32]], #{{.*}}> - // CHECK: [[VAL3:%[0-9]+]] = "ttnn.transpose"([[VAL2]]) <{dim0 = 2 : si32, dim1 = 3 : si32}> : (tensor<[[TENSOR_SHAPE4]], #{{.*}}>) -> tensor<[[TENSOR_SHAPE5:[0-9]+x[0-9]+x[0-9]+x[0-9]+xf32]], #{{.*}}> - // CHECK: [[VAL4:%[0-9]+]] = "ttnn.reshape"([[VAL3]]) <{shape = [1 : i32, 1 : i32, 512 : i32, 256 : i32]}> : (tensor<[[TENSOR_SHAPE5]], #{{.*}}>) -> tensor<[[TENSOR_SHAPE6:[0-9]+x[0-9]+x[0-9]+x[0-9]+xf32]], #{{.*}}> - // CHECK: [[VAL5:%[0-9]+]] = "ttnn.conv2d"([[VAL4]], %10, %{{[0-9]+}}, %{{[0-9]+}}) - // CHECK: (tensor<[[TENSOR_SHAPE6]], #{{.*}}>, tensor<1024x256x1x1xf32, #{{.*}}>, tensor<1x1x512x1024xf32, #{{.*}}>, !tt.device<#device>) -> tensor<1x1x512x1024xf32, #{{.*}}> + // CHECK: "ttnn.reshape" + // CHECK-SAME: shape = [1 : i32, 256 : i32, 512 : i32, 1 : i32] + // CHECK: "ttnn.reshape" + // CHECK-SAME: shape = [1024 : i32, 256 : i32, 1 : i32, 1 : i32] + // CHECK: "ttnn.permute" + // CHECK-SAME: permutation = array + // CHECK: "ttnn.permute" + // CHECK-SAME: permutation = array + // CHECK: "ttnn.conv2d" + // CHECK: "ttnn.permute" + // CHECK-SAME: permutation = array + // CHECK: "ttnn.reshape" + // CHECK-SAME: shape = [1 : i32, 1024 : i32, 512 : i32] %1 = "ttir.convolution"(%arg0, %arg1, %0) <{batch_group_count = 1 : i64, convolution_layout = #ttir, feature_group_count = 1 : i64, input_dilation = array, padding = array, weight_dilation = array, window_reversal = array, window_strides = array}> : (tensor<1x256x512xf32>, tensor<1024x256x1xf32>, tensor<1x1024x512xf32>) -> tensor<1x1024x512xf32> // CHECK: return %{{.*}} : tensor<1x1024x512xf32, #ttnn_layout3> return %1 : tensor<1x1024x512xf32> diff --git a/test/ttmlir/Dialect/TTNN/permute/permute_tests_negative.mlir b/test/ttmlir/Dialect/TTNN/permute/permute_tests_negative.mlir new file mode 100644 index 000000000..f7a83c65e --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/permute/permute_tests_negative.mlir @@ -0,0 +1,30 @@ +// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s +// Negative tests for permute operation + +// Verfiy that given attribute `permutation` is a valid permutation of the dimensions. 
+module { + func.func @permute_non_valid_permutation(%arg0: tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> { + // CHECK: error: 'ttnn.permute' op Expected a permutation of (0, 1, 2), got (0, 1, 0) + %0 = "ttnn.permute"(%arg0) <{permutation = array}> : (tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> + return %0 : tensor<16x32x64xbf16> + } +} + +// ----- +module { + func.func @permute_subset_permutation(%arg0: tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> { + // CHECK: error: 'ttnn.permute' op Expected a permutation of (0, 1, 2), got (0, 1) + %0 = "ttnn.permute"(%arg0) <{permutation = array}> : (tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> + return %0 : tensor<16x32x64xbf16> + } +} + +// Verify that the result shape matches the shape of the input tensor after permutation is applied. +// ----- +module { + func.func @permute_non_valid_shape(%arg0: tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> { + // CHECK: error: 'ttnn.permute' op Expected result shape (16, 64, 32), got (16, 32, 64) + %0 = "ttnn.permute"(%arg0) <{permutation = array}> : (tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> + return %0 : tensor<16x32x64xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTNN/permute/permute_tests_positive.mlir b/test/ttmlir/Dialect/TTNN/permute/permute_tests_positive.mlir new file mode 100644 index 000000000..54988e880 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/permute/permute_tests_positive.mlir @@ -0,0 +1,32 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +module { + func.func @permute_identity(%arg0: tensor<8x32x64x128xf32>) -> tensor<8x32x64x128xf32> { + %0 = tensor.empty() : tensor<8x32x64x128xf32> + // CHECK: "ttnn.permute" + // CHECK-SAME: permutation = array + // CHECK-SAME: tensor<8x32x64x128xf32 + // CHECK-SAME: tensor<8x32x64x128xf32 + %1 = "ttir.permute"(%arg0, %0) <{permutation = array}> : (tensor<8x32x64x128xf32>, tensor<8x32x64x128xf32>) -> tensor<8x32x64x128xf32> + return %1 : tensor<8x32x64x128xf32> + } + + func.func @permute_general(%arg0: tensor<8x32x64x128xf32>) -> tensor<64x8x128x32xf32> { + %0 = tensor.empty() : tensor<64x8x128x32xf32> + // CHECK: "ttnn.permute" + // CHECK-SAME: permutation = array + // CHECK-SAME: tensor<8x32x64x128xf32 + // CHECK-SAME: tensor<64x8x128x32xf32 + %1 = "ttir.permute"(%arg0, %0) <{permutation = array}> : (tensor<8x32x64x128xf32>, tensor<64x8x128x32xf32>) -> tensor<64x8x128x32xf32> + return %1 : tensor<64x8x128x32xf32> + } + + func.func @permute_1d(%arg0: tensor<32xf32>) -> tensor<32xf32> { + %0 = tensor.empty() : tensor<32xf32> + // CHECK: "ttnn.permute" + // CHECK-SAME: permutation = array + // CHECK-SAME: tensor<32xf32 + // CHECK-SAME: tensor<32xf32 + %1 = "ttir.permute"(%arg0, %0) <{permutation = array}> : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> + return %1 : tensor<32xf32> + } +} diff --git a/test/ttmlir/Dialect/TTNN/permute/simple_permute.mlir b/test/ttmlir/Dialect/TTNN/permute/simple_permute.mlir new file mode 100644 index 000000000..f657d55d1 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/permute/simple_permute.mlir @@ -0,0 +1,12 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +module { + func.func @permute(%arg0: tensor<1x4x32x64xf32>) -> tensor<4x32x64x1xf32> { + %0 = tensor.empty() : tensor<4x32x64x1xf32> + // CHECK: "ttnn.permute" + // CHECK-SAME: permutation = array + // CHECK-SAME: tensor<1x4x32x64xf32 + // CHECK-SAME: tensor<4x32x64x1xf32 + %1 = "ttir.permute"(%arg0, %0) <{permutation = array}> : (tensor<1x4x32x64xf32>, tensor<4x32x64x1xf32>) -> tensor<4x32x64x1xf32> + 
return %1 : tensor<4x32x64x1xf32> + } +} diff --git a/test/ttmlir/Dialect/TTNN/pooling/simple_pooling.mlir b/test/ttmlir/Dialect/TTNN/pooling/simple_pooling.mlir index 2f4b65ce6..626d56c83 100644 --- a/test/ttmlir/Dialect/TTNN/pooling/simple_pooling.mlir +++ b/test/ttmlir/Dialect/TTNN/pooling/simple_pooling.mlir @@ -2,7 +2,11 @@ module attributes {} { func.func @forward(%arg0: tensor<1x32x128x128xbf16>) -> tensor<1x32x64x64xbf16> { %0 = tensor.empty() : tensor<1x32x64x64xbf16> - // CHECK: %[[C:.*]] = "ttnn.max_pool2d"[[C:.*]] + // CHECK: "ttnn.permute" + // CHECK-SAME: permutation = array + // CHECK: "ttnn.max_pool2d" + // CHECK: "ttnn.permute" + // CHECK-SAME: permutation = array %1 = "ttir.pooling"(%arg0, %0) <{ operandSegmentSizes = array, pooling_method = #ttir, diff --git a/test/ttmlir/Dialect/TTNN/simple_permute.mlir b/test/ttmlir/Dialect/TTNN/simple_permute.mlir new file mode 100644 index 000000000..753c2ab49 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/simple_permute.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +module { + func.func @forward(%arg0: tensor<1x4x32x64xf32>) -> tensor<4x32x64x1xf32> { + %0 = tensor.empty() : tensor<4x32x64x1xf32> + // CHECK: "ttnn.permute" + // CHECK-SAME: permutation = array + %1 = "ttir.permute"(%arg0, %0) <{permutation = array}> : (tensor<1x4x32x64xf32>, tensor<4x32x64x1xf32>) -> tensor<4x32x64x1xf32> + return %1 : tensor<4x32x64x1xf32> + } +} diff --git a/test/ttmlir/Silicon/StableHLO/Unary/permute_transpose_op.mlir b/test/ttmlir/Silicon/StableHLO/Unary/permute_transpose_op.mlir index c61fc05aa..b0017de94 100644 --- a/test/ttmlir/Silicon/StableHLO/Unary/permute_transpose_op.mlir +++ b/test/ttmlir/Silicon/StableHLO/Unary/permute_transpose_op.mlir @@ -9,13 +9,9 @@ module { func.func public @test_permute_transpose(%arg0: tensor<1x32x64x128xf32>) -> tensor<1x128x32x64xf32> { // CHECK-LABEL: func.func public @test_permute_transpose - // CHECK: %[[VAL:[0-9]+]] = "ttnn.transpose" - // CHECK-SAME: {dim0 = 3 : si32, dim1 = 2 : si32} + // CHECK: %[[VAL:[0-9]+]] = "ttnn.permute" + // CHECK-SAME: permutation = array // CHECK-SAME: tensor<1x32x64x128xf32, - // CHECK-SAME: -> tensor<1x32x128x64xf32 - // CHECK: "ttnn.transpose"(%[[VAL]]) - // CHECK-SAME: {dim0 = 2 : si32, dim1 = 1 : si32} - // CHECK-SAME: tensor<1x32x128x64xf32, // CHECK-SAME: -> tensor<1x128x32x64xf32, %0 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x32x64x128xf32>) -> tensor<1x128x32x64xf32> return %0 : tensor<1x128x32x64xf32> diff --git a/test/ttmlir/Silicon/StableHLO/Unary/tranpose_op.mlir b/test/ttmlir/Silicon/StableHLO/Unary/tranpose_op.mlir index 54795a063..b14601c05 100644 --- a/test/ttmlir/Silicon/StableHLO/Unary/tranpose_op.mlir +++ b/test/ttmlir/Silicon/StableHLO/Unary/tranpose_op.mlir @@ -9,8 +9,8 @@ module @jit_transpose attributes {} { func.func public @test_transpose(%arg0: tensor<64x128xf32>) -> tensor<128x64xf32> { // CHECK-LABEL: func.func public @test_transpose - // CHECK: ttnn.transpose - // CHECK-SAME: {dim0 = 1 : si32, dim1 = 0 : si32} + // CHECK: "ttnn.permute" + // CHECK-SAME: permutation = array // CHECK-SAME: tensor<64x128xf32, // CHECK-SAME: -> tensor<128x64xf32, %0 = stablehlo.transpose %arg0, dims = [1,0] : (tensor<64x128xf32>) -> tensor<128x64xf32> diff --git a/test/ttmlir/Silicon/TTNN/complex_conv_channel_first.mlir b/test/ttmlir/Silicon/TTNN/complex_conv_channel_first.mlir index ca773e978..b51fa140d 100644 --- a/test/ttmlir/Silicon/TTNN/complex_conv_channel_first.mlir +++ 
b/test/ttmlir/Silicon/TTNN/complex_conv_channel_first.mlir @@ -4,9 +4,11 @@ module @jit_convolution { func.func public @test_NCHW_IOHW_to_NHWC_OIHW_conv2d(%arg0: tensor<1x3x100x100xbf16>, %arg1: tensor<7x3x3x3xbf16>) -> tensor<1x7x100x100xbf16> { %0 = tensor.empty() : tensor<1x7x100x100xbf16> - // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]] - // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]] - // CHECK: %[[C:.*]] = "ttnn.conv2d"[[C:.*]] + // CHECK: "ttnn.permute" + // CHECK-SAME: permutation = array + // CHECK: "ttnn.conv2d" + // CHECK: "ttnn.permute" + // CHECK-SAME: permutation = array %1 = "ttir.convolution"(%arg0, %arg1, %0) <{ batch_group_count = 1 : i64, convolution_layout = #ttir, window_strides = array }> : (tensor<1x3x100x100xbf16>, tensor<7x3x3x3xbf16>, tensor<1x7x100x100xbf16>) -> tensor<1x7x100x100xbf16> - // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]] - // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]] return %1 : tensor<1x7x100x100xbf16> } } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_permute.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_permute.mlir new file mode 100644 index 000000000..6d7898cd8 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_permute.mlir @@ -0,0 +1,14 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +module { + func.func @permute(%arg0: tensor<1x4x32x64xf32>) -> tensor<4x32x64x1xf32> { + %0 = tensor.empty() : tensor<4x32x64x1xf32> + // CHECK: "ttnn.permute" + // CHECK-SAME: permutation = array + // CHECK-SAME: tensor<1x4x32x64xf32 + // CHECK-SAME: tensor<4x32x64x1xf32 + %1 = "ttir.permute"(%arg0, %0) <{permutation = array}> : (tensor<1x4x32x64xf32>, tensor<4x32x64x1xf32>) -> tensor<4x32x64x1xf32> + return %1 : tensor<4x32x64x1xf32> + } +} diff --git a/test/ttmlir/Silicon/TTNN/pooling/simple_pooling.mlir b/test/ttmlir/Silicon/TTNN/pooling/simple_pooling.mlir index 7c4d62660..8b68ad5ec 100644 --- a/test/ttmlir/Silicon/TTNN/pooling/simple_pooling.mlir +++ b/test/ttmlir/Silicon/TTNN/pooling/simple_pooling.mlir @@ -4,7 +4,11 @@ module attributes {} { func.func @forward(%arg0: tensor<1x32x128x128xbf16>) -> tensor<1x32x64x64xbf16> { %0 = tensor.empty() : tensor<1x32x64x64xbf16> - // CHECK: %[[C:.*]] = "ttnn.max_pool2d"[[C:.*]] + // CHECK: "ttnn.permute" + // CHECK-SAME: permutation = array + // CHECK: "ttnn.max_pool2d" + // CHECK: "ttnn.permute" + // CHECK-SAME: permutation = array %1 = "ttir.pooling"(%arg0, %0) <{ operandSegmentSizes = array, pooling_method = #ttir, diff --git a/test/ttmlir/Silicon/TTNN/simple_permute.mlir b/test/ttmlir/Silicon/TTNN/simple_permute.mlir new file mode 100644 index 000000000..6d7898cd8 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/simple_permute.mlir @@ -0,0 +1,14 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +module { + func.func @permute(%arg0: tensor<1x4x32x64xf32>) -> tensor<4x32x64x1xf32> { + %0 = tensor.empty() : tensor<4x32x64x1xf32> + // CHECK: "ttnn.permute" + // CHECK-SAME: permutation = array + // CHECK-SAME: tensor<1x4x32x64xf32 + // CHECK-SAME: tensor<4x32x64x1xf32 + %1 = "ttir.permute"(%arg0, %0) <{permutation = array}> : (tensor<1x4x32x64xf32>, tensor<4x32x64x1xf32>) -> tensor<4x32x64x1xf32> + return %1 : tensor<4x32x64x1xf32> + } +}
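
For illustration only (not part of the patch): a minimal sketch of how the permutation helpers added to `include/ttmlir/Utils.h` compose, assuming the tt-mlir build environment and LLVM's ADT headers. The NCHW/NHWC layout vectors and the `main` driver below are hypothetical example inputs, not code from the repository.

```cpp
// Illustrative sketch only -- not part of the patch.
#include "ttmlir/Utils.h"

#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"

#include <cstdint>

int main() {
  // Desired output layout (NHWC) expressed against an NCHW input, where
  // output[i] == input[permutation[i]].
  llvm::SmallVector<int64_t> nchw{0, 1, 2, 3};
  llvm::SmallVector<int64_t> nhwc{0, 2, 3, 1};

  // permutation[i] is the input dimension that lands at output position i.
  llvm::SmallVector<int64_t> permutation = ttmlir::utils::generatePermutation(
      llvm::ArrayRef(nchw), llvm::ArrayRef(nhwc)); // [0, 2, 3, 1]

  // Applying the permutation to a shape yields the permuted shape.
  llvm::SmallVector<int64_t> shape{1, 3, 100, 100}; // NCHW
  llvm::SmallVector<int64_t> permuted = ttmlir::utils::applyPermutation(
      llvm::ArrayRef(shape), llvm::ArrayRef(permutation)); // [1, 100, 100, 3]

  // The inverse permutation restores the original layout.
  llvm::SmallVector<int64_t> inverse =
      ttmlir::utils::inversePermutation(permutation); // [0, 3, 1, 2]
  llvm::SmallVector<int64_t> restored = ttmlir::utils::applyPermutation(
      llvm::ArrayRef(permuted), llvm::ArrayRef(inverse)); // [1, 3, 100, 100]

  llvm::outs() << ttmlir::utils::join(restored, ", ") << "\n";
  return 0;
}
```

This is the same round trip the conv2d and pool2d decompositions in the patch perform: permute the input into the layout the 2D op expects, run the op, then apply the inverse permutation to restore the caller's original layout.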