From e6b8599d034b638e6b4dd090db6bc21c572bd428 Mon Sep 17 00:00:00 2001
From: Aleksandar Zecevic
Date: Fri, 10 Jan 2025 14:57:21 +0000
Subject: [PATCH 1/9] ttir.matmul transpose input

- renamed the inputs to lhs and rhs
- refactored the verification
---
 include/ttmlir/Dialect/TTIR/IR/TTIROps.td     |   4 +-
 lib/Dialect/TTIR/IR/TTIROps.cpp               |  87 ++++++++-------
 .../TTIR/matmul/matmul_tests_negative.mlir    | 101 ++++++++++++------
 .../TTIR/matmul/matmul_tests_positive.mlir    |  87 ---------------
 4 files changed, 117 insertions(+), 162 deletions(-)
 delete mode 100644 test/ttmlir/Dialect/TTIR/matmul/matmul_tests_positive.mlir

diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index 839bd81d9..eec3d4a38 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -1325,7 +1325,9 @@ def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> {
   let arguments = (ins AnyRankedTensor:$a,
                        AnyRankedTensor:$b,
-                       AnyRankedTensor:$output);
+                       AnyRankedTensor:$output,
+                       DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
+                       DefaultValuedAttr<BoolAttr, "false">:$transpose_b);
 
   let results = (outs AnyRankedTensor:$result);

diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp
index 73daad713..d4c08b964 100644
--- a/lib/Dialect/TTIR/IR/TTIROps.cpp
+++ b/lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -1187,56 +1187,58 @@ ::mlir::LogicalResult mlir::tt::ttir::MatmulOp::verify() {
   llvm::SmallVector<int64_t> inputAShape(inputAType.getShape());
   llvm::SmallVector<int64_t> inputBShape(inputBType.getShape());
 
-  // Verify that the input A is at least 1D tensor
+  // Verify that the input A is at least 1D tensor.
   if (inputAType.getRank() < 1) {
     return emitOpError("Input A must be at least a 1D tensor");
   }
 
-  // Verify that the input B is at least 1D tensor
+  // Verify that the input B is at least 1D tensor.
   if (inputBType.getRank() < 1) {
     return emitOpError("Input B must be at least a 1D tensor");
   }
 
-  // If input A is a vector (1D tensor), 1 is prepended to its dimension for the
-  // purpose of the matrix multiply. After the matrix multiply, the prepended
-  // dimension is removed.
+  // If input A is a vector (1D tensor), 1 is prepended to its dimensions for
+  // the purpose of the matrix multiplication. After the matrix multiplication,
+  // the prepended dimension is removed. Otherwise, check if the LHS needs to
+  // be transposed.
   if (inputAType.getRank() == 1) {
     inputAShape.insert(inputAShape.begin(), 1);
+  } else if (getTransposeA()) {
+    std::swap(inputAShape[inputAShape.size() - 1],
+              inputAShape[inputAShape.size() - 2]);
   }
 
-  // If input B is a vector (1D tensor), a 1 is appended to its dimension for
-  // the purpose of the matrix-vector product and removed after.
+  // If input B is a vector (1D tensor), a 1 is appended to its dimensions for
+  // the purpose of the matrix-vector product and removed afterwards.
+  // Otherwise, check if the RHS needs to be transposed.
   if (inputBType.getRank() == 1) {
     inputBShape.push_back(1);
+  } else if (getTransposeB()) {
+    std::swap(inputBShape[inputBShape.size() - 1],
+              inputBShape[inputBShape.size() - 2]);
   }
 
-  // Verify that the input A and input B has matching inner dimensions
+  // Verify that input A and input B have matching inner dimensions.
   if (inputAShape[inputAShape.size() - 1] !=
       inputBShape[inputBShape.size() - 2]) {
-    return emitOpError(
-        "Input A[-1](" + std::to_string(inputAShape[inputAShape.size() - 1]) +
-        ") and B[-2](" + std::to_string(inputBShape[inputBShape.size() - 2]) +
-        ") must have matching inner dimensions");
+    return emitOpError("Input A[-1](")
+           << inputAShape[inputAShape.size() - 1] << ") and B[-2]("
+           << inputBShape[inputBShape.size() - 2]
+           << ") must have matching inner dimensions";
   }
 
   llvm::SmallVector<int64_t> expectedOutputShape;
   // Verify that the batch dimensions are broadcast compatible and construct the
-  // expected output shape
+  // expected output shape. If either input A or input B is at most a 2D
+  // tensor, the batch dimensions are trivially broadcast compatible.
   if (inputAShape.size() > 2 || inputBShape.size() > 2) {
-    llvm::SmallVector<int64_t> inputABatchDims, inputBBatchDims;
-
-    if (inputAShape.size() > 2) {
-      inputABatchDims.insert(inputABatchDims.begin(), inputAShape.begin(),
-                             inputAShape.end() - 2);
-    }
-
-    if (inputBShape.size() > 2) {
-      inputBBatchDims.insert(inputBBatchDims.begin(), inputBShape.begin(),
-                             inputBShape.end() - 2);
-    }
+    llvm::SmallVector<int64_t> inputABatchDims(inputAShape.begin(),
+                                               inputAShape.end() - 2);
+    llvm::SmallVector<int64_t> inputBBatchDims(inputBShape.begin(),
+                                               inputBShape.end() - 2);
 
     // Verify that the batch dimensions of input A and B are broadcast
-    // compatible
+    // compatible.
     llvm::SmallVector<int64_t> broadcastedShape;
     if (!OpTrait::util::getBroadcastedShape(inputABatchDims, inputBBatchDims,
                                             broadcastedShape)) {
@@ -1248,10 +1250,8 @@ ::mlir::LogicalResult mlir::tt::ttir::MatmulOp::verify() {
               ") are not broadcast compatible");
     }
 
-    // Insert the broadcasted batch dimensions in the expected output shape
-    expectedOutputShape.insert(expectedOutputShape.begin(),
-                               broadcastedShape.begin(),
-                               broadcastedShape.end());
+    // Insert the broadcasted batch dimensions in the expected output shape.
+    expectedOutputShape = std::move(broadcastedShape);
   }
 
   // Insert the input A and B inner dimensions in expected output shape
@@ -1277,26 +1277,25 @@ ::mlir::LogicalResult mlir::tt::ttir::MatmulOp::verify() {
     return emitOpError("Scalar output must be a 1D tensor of size 1");
   }
 
-    return llvm::success();
+    return success();
   }
 
-  // Verify that the output shape is correct
+  // Verify that the output shape is correct.
   if (outputShape.size() != expectedOutputShape.size()) {
-    return emitOpError("Output shape rank(" +
-                       std::to_string(outputShape.size()) +
-                       ") must match the expected output shape rank(" +
-                       std::to_string(expectedOutputShape.size()) + ")");
+    return emitOpError("Output shape rank(")
+           << outputShape.size()
+           << ") must match the expected output shape rank("
+           << expectedOutputShape.size() << ")";
   }
 
-  // Verify each dim of the output shape
-  for (size_t i = 0; i < outputShape.size(); i++) {
-    if (outputShape[i] != expectedOutputShape[i]) {
-      return emitOpError(
-          "Output shape dimension[" + std::to_string(i) + "](" +
-          std::to_string(outputShape[i]) +
-          ") doesn't match the expected output shape dimension[" +
-          std::to_string(i) + "](" + std::to_string(expectedOutputShape[i]) +
-          ")");
+  // Verify each dim of the output shape.
+  for (auto [index, outputDim, expectedDim] : llvm::zip(
+           llvm::seq(outputShape.size()), outputShape, expectedOutputShape)) {
+    if (outputDim != expectedDim) {
+      return emitOpError("Output shape dimension[")
+             << index << "](" << outputDim
+             << ") doesn't match the expected output shape dimension[" << index
+             << "](" << expectedDim << ")";
     }
   }

diff --git a/test/ttmlir/Dialect/TTIR/matmul/matmul_tests_negative.mlir b/test/ttmlir/Dialect/TTIR/matmul/matmul_tests_negative.mlir
index d6c20b0ac..cd4c3a08d 100644
--- a/test/ttmlir/Dialect/TTIR/matmul/matmul_tests_negative.mlir
+++ b/test/ttmlir/Dialect/TTIR/matmul/matmul_tests_negative.mlir
@@ -1,9 +1,9 @@
 // RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s
-// Negative tests for matmul operation
+// Negative tests for matmul operation.
 
-// Verify that the parsing fails if either of operands is a scalar
-module attributes {} {
-  func.func @matmul_negative_1d_1d_inner_dimension_missmatch(%arg0: tensor<bf16>, %arg1: tensor<64xbf16>) -> tensor<1xbf16> {
+// Verify that the parsing fails if either of the operands is a scalar.
+module {
+  func.func @matmul_negative_0d_1d_input_scalar(%arg0: tensor<bf16>, %arg1: tensor<64xbf16>) -> tensor<1xbf16> {
     // CHECK: error: 'ttir.matmul' op Input A must be at least a 1D tensor
     %0 = tensor.empty() : tensor<1xbf16>
     %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<bf16>, tensor<64xbf16>, tensor<1xbf16>) -> tensor<1xbf16>
@@ -12,8 +12,8 @@ module attributes {} {
 }
 
 // -----
-module attributes {} {
-  func.func @matmul_negative_1d_1d_inner_dimension_missmatch(%arg0: tensor<128xbf16>, %arg1: tensor<bf16>) -> tensor<1xbf16> {
+module {
+  func.func @matmul_negative_1d_0d_input_scalar(%arg0: tensor<128xbf16>, %arg1: tensor<bf16>) -> tensor<1xbf16> {
     // CHECK: error: 'ttir.matmul' op Input B must be at least a 1D tensor
     %0 = tensor.empty() : tensor<1xbf16>
     %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<bf16>, tensor<1xbf16>) -> tensor<1xbf16>
@@ -21,10 +21,10 @@ module attributes {} {
   }
 }
 
-// Verify that the parsing fails if the output is a scalar
+// Verify that the parsing fails if the output is a scalar.
 // -----
-module attributes {} {
-  func.func @matmul_negative_1d_1d_inner_dimension_missmatch(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<bf16> {
+module {
+  func.func @matmul_negative_1d_1d_output_scalar(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<bf16> {
     // CHECK: error: 'ttir.matmul' op Scalar output is not supported, output must be at least a 1D tensor
     %0 = tensor.empty() : tensor<bf16>
     %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<128xbf16>, tensor<bf16>) -> tensor<bf16>
@@ -33,8 +33,8 @@ module attributes {} {
 }
 
 // -----
-module attributes {} {
-  func.func @matmul_negative_1d_1d_inner_dimension_missmatch(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<2xbf16> {
+module {
+  func.func @matmul_negative_1d_1d_nonone_output(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<2xbf16> {
     // CHECK: error: 'ttir.matmul' op Scalar output must be a 1D tensor of size 1
     %0 = tensor.empty() : tensor<2xbf16>
     %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<128xbf16>, tensor<2xbf16>) -> tensor<2xbf16>
@@ -42,10 +42,10 @@ module attributes {} {
   }
 }
 
-// Inner dimension mismatch tests
+// Inner dimension mismatch tests.
// ----- -module attributes {} { - func.func @matmul_negative_1d_1d_inner_dimension_missmatch(%arg0: tensor<128xbf16>, %arg1: tensor<64xbf16>) -> tensor<1xbf16> { +module { + func.func @matmul_negative_1d_1d_inner_dimension_mismatch(%arg0: tensor<128xbf16>, %arg1: tensor<64xbf16>) -> tensor<1xbf16> { // CHECK: error: 'ttir.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions %0 = tensor.empty() : tensor<1xbf16> %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<64xbf16>, tensor<1xbf16>) -> tensor<1xbf16> @@ -54,8 +54,8 @@ module attributes {} { } // ----- -module attributes {} { -func.func @matmul_negative_1d_2d_inner_dimension_missmatch(%arg0: tensor<64xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> { +module { +func.func @matmul_negative_1d_2d_inner_dimension_mismatch(%arg0: tensor<64xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> { // CHECK: error: 'ttir.matmul' op Input A[-1](64) and B[-2](128) must have matching inner dimensions %0 = tensor.empty() : tensor<64xbf16> %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> @@ -64,8 +64,8 @@ func.func @matmul_negative_1d_2d_inner_dimension_missmatch(%arg0: tensor<64xbf16 } // ----- -module attributes {} { - func.func @matmul_negative_2d_1d_inner_dimension_missmatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64xbf16>) -> tensor<64xbf16> { +module { + func.func @matmul_negative_2d_1d_inner_dimension_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64xbf16>) -> tensor<64xbf16> { // CHECK: error: 'ttir.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions %0 = tensor.empty() : tensor<64xbf16> %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> @@ -74,8 +74,8 @@ module attributes {} { } // ----- -module attributes {} { - func.func @matmul_negative_2d_2d_inner_dimension_missmatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x64xbf16> { +module { + func.func @matmul_negative_2d_2d_inner_dimension_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x64xbf16> { // CHECK: error: 'ttir.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions %0 = tensor.empty() : tensor<64x64xbf16> %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> @@ -84,8 +84,28 @@ module attributes {} { } // ----- -module attributes {} { - func.func @matmul_negative_nd_nd_inner_dimension_missmatch(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<1x64x128xbf16>) -> tensor<7x64x64xbf16> { +module { + func.func @matmul_negative_2d_transpose_2d_inner_dimension_mismatch(%arg0: tensor<128x64xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<128x128xbf16> { + // CHECK: error: 'ttir.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions + %0 = tensor.empty() : tensor<128x128xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) <{transpose_a = true}> : (tensor<128x64xbf16>, tensor<64x128xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16> + return %1 : tensor<128x128xbf16> + } +} + +// ----- +module { + func.func @matmul_negative_2d_2d_transpose_inner_dimension_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: error: 'ttir.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions + %0 = tensor.empty() : tensor<64x64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) <{transpose_b = 
true}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } +} + +// ----- +module { + func.func @matmul_negative_nd_nd_inner_dimension_mismatch(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<1x64x128xbf16>) -> tensor<7x64x64xbf16> { // CHECK: error: 'ttir.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions %0 = tensor.empty() : tensor<7x64x64xbf16> %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<7x64x128xbf16>, tensor<1x64x128xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> @@ -93,9 +113,9 @@ module attributes {} { } } -// Batch dimension mismatch tests +// Batch dimension mismatch tests. // ----- -module attributes {} { +module { func.func @matmul_negative_nd_nd_same_rank_batch_broadcast_incompatible_1(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<2x128x64xbf16>) -> tensor<7x64x64xbf16> { // CHECK: error: 'ttir.matmul' op Batch dimensions of input A(7) and B(2) are not broadcast compatible %0 = tensor.empty() : tensor<7x64x64xbf16> @@ -105,7 +125,7 @@ module attributes {} { } // ----- -module attributes {} { +module { func.func @matmul_negative_nd_nd_same_rank_batch_broadcast_incompatible_2(%arg0: tensor<2x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<7x7x64x64xbf16> { // CHECK: error: 'ttir.matmul' op Batch dimensions of input A(2,7) and B(7,1) are not broadcast compatible %0 = tensor.empty() : tensor<7x64x64xbf16> @@ -115,7 +135,7 @@ module attributes {} { } // ----- -module attributes {} { +module { func.func @matmul_negative_nd_nd_different_rank_batch_broadcast_incompatible(%arg0: tensor<12x2x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<12x7x7x64x64xbf16> { // CHECK: error: 'ttir.matmul' op Batch dimensions of input A(12,2,7) and B(7,1) are not broadcast compatible %0 = tensor.empty() : tensor<12x7x7x64x64xbf16> @@ -124,10 +144,10 @@ module attributes {} { } } -// Output shape mismatch tests +// Output shape mismatch tests. 
// ----- -module attributes {} { - func.func @matmul_negative_2d_2d_inner_dimension_missmatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> { +module { + func.func @matmul_negative_2d_2d_inner_dimension_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> { // CHECK: error: 'ttir.matmul' op Output shape rank(1) must match the expected output shape rank(2) %0 = tensor.empty() : tensor<64xbf16> %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> @@ -136,7 +156,7 @@ module attributes {} { } // ----- -module attributes {} { +module { func.func @matmul_negative_2d_2d_inner_dimension_missmatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x128xbf16> { // CHECK: error: 'ttir.matmul' op Output shape dimension[1](128) doesn't match the expected output shape dimension[1](64) %0 = tensor.empty() : tensor<64x128xbf16> @@ -144,3 +164,24 @@ module attributes {} { return %1 : tensor<64x128xbf16> } } + + +// ----- +module { + func.func @matmul_negative_2d_transpose_2d_output_shape_mismatch(%arg0: tensor<128x64xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<128x128xbf16> { + // CHECK: error: 'ttir.matmul' op Output shape dimension[0](128) doesn't match the expected output shape dimension[0](64) + %0 = tensor.empty() : tensor<128x128xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) <{transpose_a = true}> : (tensor<128x64xbf16>, tensor<128x64xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16> + return %1 : tensor<128x128xbf16> + } +} + +// ----- +module { + func.func @matmul_negative_2d_2d_transpose_output_shape_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<128x128xbf16> { + // CHECK: error: 'ttir.matmul' op Output shape dimension[0](128) doesn't match the expected output shape dimension[0](64) + %0 = tensor.empty() : tensor<128x128xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) <{transpose_b = true}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16> + return %1 : tensor<128x128xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTIR/matmul/matmul_tests_positive.mlir b/test/ttmlir/Dialect/TTIR/matmul/matmul_tests_positive.mlir deleted file mode 100644 index cfc77c0fb..000000000 --- a/test/ttmlir/Dialect/TTIR/matmul/matmul_tests_positive.mlir +++ /dev/null @@ -1,87 +0,0 @@ -// RUN: ttmlir-opt %s | FileCheck %s -module attributes {} { - func.func @matmul_1d_1d(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<1xbf16> { - %0 = tensor.empty() : tensor<1xbf16> - // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<128xbf16>, tensor<1xbf16>) -> tensor<1xbf16> - return %1 : tensor<1xbf16> - } - - func.func @matmul_1d_2d(%arg0: tensor<128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> { - %0 = tensor.empty() : tensor<64xbf16> - // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> - return %1 : tensor<64xbf16> - } - - func.func @matmul_2d_1d(%arg0: tensor<64x128xbf16>, %arg1: tensor<128xbf16>) -> tensor<64xbf16> { - %0 = tensor.empty() : tensor<64xbf16> - // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128xbf16>, tensor<64xbf16>) -> tensor<64xbf16> - return %1 : tensor<64xbf16> - } - - func.func @matmul_2d_2d(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> 
tensor<64x64xbf16> { - %0 = tensor.empty() : tensor<64x64xbf16> - // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> - return %1 : tensor<64x64xbf16> - } - - func.func @matmul_1d_nd(%arg0: tensor<128xbf16>, %arg1: tensor<12x7x128x64xbf16>) -> tensor<12x7x64xbf16> { - %0 = tensor.empty() : tensor<12x7x64xbf16> - // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<12x7x128x64xbf16>, tensor<12x7x64xbf16>) -> tensor<12x7x64xbf16> - return %1 : tensor<12x7x64xbf16> - } - - func.func @matmul_nd_1d(%arg0: tensor<12x7x128x64xbf16>, %arg1: tensor<64xbf16>) -> tensor<12x7x128xbf16> { - %0 = tensor.empty() : tensor<12x7x128xbf16> - // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<12x7x128x64xbf16>, tensor<64xbf16>, tensor<12x7x128xbf16>) -> tensor<12x7x128xbf16> - return %1 : tensor<12x7x128xbf16> - } - - func.func @matmul_2d_nd(%arg0: tensor<64x128xbf16>, %arg1: tensor<12x7x128x64xbf16>) -> tensor<12x7x64x64xbf16> { - %0 = tensor.empty() : tensor<12x7x64x64xbf16> - // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<12x7x128x64xbf16>, tensor<12x7x64x64xbf16>) -> tensor<12x7x64x64xbf16> - return %1 : tensor<12x7x64x64xbf16> - } - - func.func @matmul_nd_2d(%arg0: tensor<12x7x128x64xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<12x7x128x128xbf16> { - %0 = tensor.empty() : tensor<12x7x128x128xbf16> - // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<12x7x128x64xbf16>, tensor<64x128xbf16>, tensor<12x7x128x128xbf16>) -> tensor<12x7x128x128xbf16> - return %1 : tensor<12x7x128x128xbf16> - } - - // matmul nd - nd tests - func.func @matmul_nd_nd_same_rank_same_dims(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<7x128x64xbf16>) -> tensor<7x64x64xbf16> { - %0 = tensor.empty() : tensor<7x64x64xbf16> - // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<7x64x128xbf16>, tensor<7x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> - return %1 : tensor<7x64x64xbf16> - } - - func.func @matmul_nd_nd_same_rank_broadcastable_dims_1(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<1x128x64xbf16>) -> tensor<7x64x64xbf16> { - %0 = tensor.empty() : tensor<7x64x64xbf16> - // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<7x64x128xbf16>, tensor<1x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> - return %1 : tensor<7x64x64xbf16> - } - - func.func @matmul_nd_nd_same_rank_broadcastable_dims_2(%arg0: tensor<1x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<7x7x64x64xbf16> { - %0 = tensor.empty() : tensor<7x7x64x64xbf16> - // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<7x7x64x64xbf16>) -> tensor<7x7x64x64xbf16> - return %1 : tensor<7x7x64x64xbf16> - } - - func.func @matmul_nd_nd_different_rank_broadcastable_dims_2(%arg0: tensor<12x1x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<12x7x7x64x64xbf16> { - %0 = tensor.empty() : tensor<12x7x7x64x64xbf16> - // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<12x1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<12x7x7x64x64xbf16>) -> tensor<12x7x7x64x64xbf16> - return %1 : tensor<12x7x7x64x64xbf16> - } -} From 
a265c1565b9b25659429da643c6cc824231d47c6 Mon Sep 17 00:00:00 2001
From: Aleksandar Zecevic
Date: Fri, 10 Jan 2025 15:05:29 +0000
Subject: [PATCH 2/9] ttnn.matmul transpose attrs

- matmul DPS runtime fix
- TODO: add tests with transpose
---
 include/ttmlir/Dialect/TTNN/IR/TTNNOps.td     |  5 +-
 include/ttmlir/Target/TTNN/program.fbs        |  7 +-
 lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp      |  3 +-
 lib/Dialect/TTNN/IR/TTNNOps.cpp               | 87 +++++++++----------
 lib/Target/TTNN/TTNNToFlatbuffer.cpp          |  7 +-
 runtime/lib/ttnn/operations/matmul/matmul.cpp | 28 +++---
 6 files changed, 69 insertions(+), 68 deletions(-)

diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index ba2484ac5..490aa9d0d 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -815,7 +815,10 @@ def TTNN_MatmulOp : TTNN_NamedDPSOp<"matmul", > {
   let arguments = (ins AnyRankedTensor:$a,
                        AnyRankedTensor:$b,
-                       AnyRankedTensor:$output);
+                       AnyRankedTensor:$output,
+                       DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
+                       DefaultValuedAttr<BoolAttr, "false">:$transpose_b);
+
   let results = (outs AnyRankedTensor:$result);
 
   let extraClassDeclaration = [{

diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs
index b56cdb39a..0f1ce94bb 100644
--- a/include/ttmlir/Target/TTNN/program.fbs
+++ b/include/ttmlir/Target/TTNN/program.fbs
@@ -241,9 +241,12 @@ table LinearOp {
 // ANCHOR: adding_an_op_matmul_fbs
 table MatmulOp {
-  in0: tt.target.TensorRef;
-  in1: tt.target.TensorRef;
+  a: tt.target.TensorRef;
+  b: tt.target.TensorRef;
   out: tt.target.TensorRef;
+  transpose_a: bool;
+  transpose_b: bool;
+
 }
 // ANCHOR_END: adding_an_op_matmul_fbs

diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index 2e84eb347..fa4bb2083 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -785,7 +785,8 @@ class MatmulOpConversionPattern : public OpConversionPattern<ttir::MatmulOp> {
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<ttnn::MatmulOp>(
         op, this->getTypeConverter()->convertType(op.getType()), adaptor.getA(),
-        adaptor.getB(), adaptor.getOutput());
+        adaptor.getB(), adaptor.getOutput(), adaptor.getTransposeA(),
+        adaptor.getTransposeB());
     return success();
   }
 };

diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp
index eccb1e9ba..56cd8c8c5 100644
--- a/lib/Dialect/TTNN/IR/TTNNOps.cpp
+++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -881,56 +881,58 @@ ::mlir::LogicalResult mlir::tt::ttnn::MatmulOp::verify() {
   llvm::SmallVector<int64_t> inputAShape(inputAType.getShape());
   llvm::SmallVector<int64_t> inputBShape(inputBType.getShape());
 
-  // Verify that the input A is at least 1D tensor
+  // Verify that the input A is at least 1D tensor.
   if (inputAType.getRank() < 1) {
     return emitOpError("Input A must be at least a 1D tensor");
   }
 
-  // Verify that the input B is at least 1D tensor
+  // Verify that the input B is at least 1D tensor.
   if (inputBType.getRank() < 1) {
     return emitOpError("Input B must be at least a 1D tensor");
   }
 
-  // If input A is a vector (1D tensor), 1 is prepended to its dimension for the
-  // purpose of the matrix multiply. After the matrix multiply, the prepended
-  // dimension is removed.
+  // If input A is a vector (1D tensor), 1 is prepended to its dimensions for
+  // the purpose of the matrix multiplication. After the matrix multiplication,
+  // the prepended dimension is removed. Otherwise, check if the LHS needs to
+  // be transposed.
   if (inputAType.getRank() == 1) {
     inputAShape.insert(inputAShape.begin(), 1);
+  } else if (getTransposeA()) {
+    std::swap(inputAShape[inputAShape.size() - 1],
+              inputAShape[inputAShape.size() - 2]);
   }
 
-  // If input B is a vector (1D tensor), a 1 is appended to its dimension for
-  // the purpose of the matrix-vector product and removed after.
+  // If input B is a vector (1D tensor), a 1 is appended to its dimensions for
+  // the purpose of the matrix-vector product and removed afterwards.
+  // Otherwise, check if the RHS needs to be transposed.
   if (inputBType.getRank() == 1) {
     inputBShape.push_back(1);
+  } else if (getTransposeB()) {
+    std::swap(inputBShape[inputBShape.size() - 1],
+              inputBShape[inputBShape.size() - 2]);
   }
 
-  // Verify that the input A and input B has matching inner dimensions
+  // Verify that input A and input B have matching inner dimensions.
   if (inputAShape[inputAShape.size() - 1] !=
       inputBShape[inputBShape.size() - 2]) {
-    return emitOpError(
-        "Input A[-1](" + std::to_string(inputAShape[inputAShape.size() - 1]) +
-        ") and B[-2](" + std::to_string(inputBShape[inputBShape.size() - 2]) +
-        ") must have matching inner dimensions");
+    return emitOpError("Input A[-1](")
+           << inputAShape[inputAShape.size() - 1] << ") and B[-2]("
+           << inputBShape[inputBShape.size() - 2]
+           << ") must have matching inner dimensions";
   }
 
   llvm::SmallVector<int64_t> expectedOutputShape;
   // Verify that the batch dimensions are broadcast compatible and construct the
-  // expected output shape
+  // expected output shape. If either input A or input B is at most a 2D
+  // tensor, the batch dimensions are trivially broadcast compatible.
   if (inputAShape.size() > 2 || inputBShape.size() > 2) {
-    llvm::SmallVector<int64_t> inputABatchDims, inputBBatchDims;
-
-    if (inputAShape.size() > 2) {
-      inputABatchDims.insert(inputABatchDims.begin(), inputAShape.begin(),
-                             inputAShape.end() - 2);
-    }
-
-    if (inputBShape.size() > 2) {
-      inputBBatchDims.insert(inputBBatchDims.begin(), inputBShape.begin(),
-                             inputBShape.end() - 2);
-    }
+    llvm::SmallVector<int64_t> inputABatchDims(inputAShape.begin(),
+                                               inputAShape.end() - 2);
+    llvm::SmallVector<int64_t> inputBBatchDims(inputBShape.begin(),
+                                               inputBShape.end() - 2);
 
     // Verify that the batch dimensions of input A and B are broadcast
-    // compatible
+    // compatible.
     llvm::SmallVector<int64_t> broadcastedShape;
     if (!OpTrait::util::getBroadcastedShape(inputABatchDims, inputBBatchDims,
                                             broadcastedShape)) {
@@ -942,10 +944,8 @@ ::mlir::LogicalResult mlir::tt::ttnn::MatmulOp::verify() {
              ") are not broadcast compatible");
    }
 
-    // Insert the broadcasted batch dimensions in the expected output shape
-    expectedOutputShape.insert(expectedOutputShape.begin(),
-                               broadcastedShape.begin(),
-                               broadcastedShape.end());
+    // Insert the broadcasted batch dimensions in the expected output shape.
+    expectedOutputShape = std::move(broadcastedShape);
   }
 
   // Insert the input A and B inner dimensions in expected output shape
@@ -971,26 +971,25 @@ ::mlir::LogicalResult mlir::tt::ttnn::MatmulOp::verify() {
     return emitOpError("Scalar output must be a 1D tensor of size 1");
   }
 
-    return llvm::success();
+    return success();
   }
 
-  // Verify that the output shape is correct
+  // Verify that the output shape is correct.
   if (outputShape.size() != expectedOutputShape.size()) {
-    return emitOpError("Output shape rank(" +
-                       std::to_string(outputShape.size()) +
-                       ") must match the expected output shape rank(" +
-                       std::to_string(expectedOutputShape.size()) + ")");
+    return emitOpError("Output shape rank(")
+           << outputShape.size()
+           << ") must match the expected output shape rank("
+           << expectedOutputShape.size() << ")";
   }
 
-  // Verify each dim of the output shape
-  for (size_t i = 0; i < outputShape.size(); i++) {
-    if (outputShape[i] != expectedOutputShape[i]) {
-      return emitOpError(
-          "Output shape dimension[" + std::to_string(i) + "](" +
-          std::to_string(outputShape[i]) +
-          ") doesn't match the expected output shape dimension[" +
-          std::to_string(i) + "](" + std::to_string(expectedOutputShape[i]) +
-          ")");
+  // Verify each dim of the output shape.
+  for (auto [index, outputDim, expectedDim] : llvm::zip(
+           llvm::seq(outputShape.size()), outputShape, expectedOutputShape)) {
+    if (outputDim != expectedDim) {
+      return emitOpError("Output shape dimension[")
+             << index << "](" << outputDim
+             << ") doesn't match the expected output shape dimension[" << index
+             << "](" << expectedDim << ")";
     }
   }

diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
index 055566c24..94a6a2e19 100644
--- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp
+++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -433,13 +433,14 @@ createOp(FlatbufferObjectCache &cache, LinearOp op) {
 // ANCHOR: adding_an_op_matmul_serialize_to_binary
 ::flatbuffers::Offset<::tt::target::ttnn::MatmulOp>
 createOp(FlatbufferObjectCache &cache, MatmulOp op) {
-  auto in0 =
+  auto a =
       cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getA()));
-  auto in1 =
+  auto b =
       cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getB()));
   auto output = cache.at<::tt::target::TensorRef>(
       getOperandThroughDPSOps(op.getResult()));
-  return ::tt::target::ttnn::CreateMatmulOp(*cache.fbb, in0, in1, output);
+  return ::tt::target::ttnn::CreateMatmulOp(
+      *cache.fbb, a, b, output, op.getTransposeA(), op.getTransposeB());
 }
 // ANCHOR_END: adding_an_op_matmul_serialize_to_binary

diff --git a/runtime/lib/ttnn/operations/matmul/matmul.cpp b/runtime/lib/ttnn/operations/matmul/matmul.cpp
index 5cfcc6b8d..73d18aff8 100644
--- a/runtime/lib/ttnn/operations/matmul/matmul.cpp
+++ b/runtime/lib/ttnn/operations/matmul/matmul.cpp
@@ -3,36 +3,30 @@
 // SPDX-License-Identifier: Apache-2.0
 
 #include "operations/matmul/matmul.h"
+
 #include "tt/runtime/detail/logger.h"
 #include "tt/runtime/detail/ttnn.h"
 #include "tt/runtime/ttnn/operations/utils.h"
 #include "tt/runtime/ttnn/utils.h"
+
 #include <optional>
 
 namespace tt::runtime::ttnn::operations::matmul {
 // ANCHOR: adding_an_op_matmul_runtime_operations
 void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) {
   ProgramTensorPool &tensorPool = context.getTensorPool();
-  const ::ttnn::Tensor &lhs = tensorPool.at(op->in0()->global_id());
-  const ::ttnn::Tensor &rhs = tensorPool.at(op->in1()->global_id());
+  const ::ttnn::Tensor &lhs = tensorPool.at(op->a()->global_id());
+  const ::ttnn::Tensor &rhs = tensorPool.at(op->b()->global_id());
+  const ::ttnn::Tensor &out = tensorPool.at(op->out()->global_id());
   DEBUG_ASSERT(lhs.is_allocated());
   DEBUG_ASSERT(rhs.is_allocated());
-  ::ttnn::DataType outputDataType = utils::getDataType(op->out());
-  ::tt::tt_metal::MemoryConfig outputMemoryConfig =
-      ::tt::runtime::ttnn::utils::createMemoryConfig(op->out());
-
-  const std::optional<::tt::tt_metal::MemoryConfig> memoryConfig =
-      std::make_optional(outputMemoryConfig);
-
-  const std::optional<::ttnn::DataType> dtype =
-      std::make_optional(outputDataType);
-
-  ::ttnn::Tensor out = ::ttnn::matmul(
-      lhs, rhs, /*transposeA*/ false, /*transposeB*/ false, memoryConfig, dtype,
-      /*programConfig*/ std::nullopt, /*activation*/ std::nullopt,
-      /*computeKernelConfig*/ std::nullopt, /*coreGrid*/ std::nullopt);
+  DEBUG_ASSERT(out.is_allocated());
 
-  tensorPool.insert_or_assign(op->out()->global_id(), out);
+  ::ttnn::matmul(lhs, rhs, op->transpose_a(), op->transpose_b(),
+                 /*memory_config=*/std::nullopt, /*dtype=*/std::nullopt,
+                 /*program_config=*/std::nullopt, /*activation=*/std::nullopt,
+                 /*compute_kernel_config=*/std::nullopt,
+                 /*core_grid=*/std::nullopt, /*output_tile=*/std::nullopt, out);
 }
 // ANCHOR_END: adding_an_op_matmul_runtime_operations

From 0a70432428ea0d7d388521594eb6cf278966c753 Mon Sep 17 00:00:00 2001
From: Aleksandar Zecevic
Date: Fri, 10 Jan 2025 15:15:21 +0000
Subject: [PATCH 3/9] Updated matmul tests
---
 .../TTNN/matmul/matmul_tests_negative.mlir    | 101 ++++++++++++------
 .../TTNN/matmul/matmul_tests_positive.mlir    |  50 ++++++---
 .../Dialect/TTNN/matmul/simple_matmul.mlir    |   5 +-
 .../Silicon/TTNN/matmul/llama_matmul.mlir     |   4 +-
 .../Silicon/TTNN/matmul/simple_matmul.mlir    |  19 +++-
 .../TTNN/perf_unit/test_perf_matmul.mlir      |   5 +-
 6 files changed, 129 insertions(+), 55 deletions(-)

diff --git a/test/ttmlir/Dialect/TTNN/matmul/matmul_tests_negative.mlir b/test/ttmlir/Dialect/TTNN/matmul/matmul_tests_negative.mlir
index 7ca7efeec..576777cec 100644
--- a/test/ttmlir/Dialect/TTNN/matmul/matmul_tests_negative.mlir
+++ b/test/ttmlir/Dialect/TTNN/matmul/matmul_tests_negative.mlir
@@ -1,9 +1,9 @@
 // RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s
-// Negative tests for matmul operation
+// Negative tests for matmul operation.
 
-// Verify that the parsing fails if either of operands is a scalar
-module attributes {} {
-  func.func @matmul_negative_1d_1d_inner_dimension_missmatch(%arg0: tensor<bf16>, %arg1: tensor<64xbf16>) -> tensor<1xbf16> {
+// Verify that the parsing fails if either of the operands is a scalar.
+module {
+  func.func @matmul_negative_0d_1d_input_scalar(%arg0: tensor<bf16>, %arg1: tensor<64xbf16>) -> tensor<1xbf16> {
     // CHECK: error: 'ttnn.matmul' op Input A must be at least a 1D tensor
     %0 = tensor.empty() : tensor<1xbf16>
     %1 = "ttnn.matmul"(%arg0, %arg1, %0) : (tensor<bf16>, tensor<64xbf16>, tensor<1xbf16>) -> tensor<1xbf16>
@@ -12,8 +12,8 @@ module attributes {} {
 }
 
 // -----
-module attributes {} {
-  func.func @matmul_negative_1d_1d_inner_dimension_missmatch(%arg0: tensor<128xbf16>, %arg1: tensor<bf16>) -> tensor<1xbf16> {
+module {
+  func.func @matmul_negative_1d_0d_input_scalar(%arg0: tensor<128xbf16>, %arg1: tensor<bf16>) -> tensor<1xbf16> {
     // CHECK: error: 'ttnn.matmul' op Input B must be at least a 1D tensor
     %0 = tensor.empty() : tensor<1xbf16>
     %1 = "ttnn.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<bf16>, tensor<1xbf16>) -> tensor<1xbf16>
@@ -21,10 +21,10 @@ module attributes {} {
   }
 }
 
-// Verify that the parsing fails if the output is a scalar
+// Verify that the parsing fails if the output is a scalar.
 // -----
-module attributes {} {
-  func.func @matmul_negative_1d_1d_inner_dimension_missmatch(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<bf16> {
+module {
+  func.func @matmul_negative_1d_1d_output_scalar(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<bf16> {
     // CHECK: error: 'ttnn.matmul' op Scalar output is not supported, output must be at least a 1D tensor
     %0 = tensor.empty() : tensor<bf16>
     %1 = "ttnn.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<128xbf16>, tensor<bf16>) -> tensor<bf16>
@@ -33,8 +33,8 @@ module attributes {} {
 }
 
 // -----
-module attributes {} {
-  func.func @matmul_negative_1d_1d_inner_dimension_missmatch(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<2xbf16> {
+module {
+  func.func @matmul_negative_1d_1d_nonone_output(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<2xbf16> {
     // CHECK: error: 'ttnn.matmul' op Scalar output must be a 1D tensor of size 1
     %0 = tensor.empty() : tensor<2xbf16>
     %1 = "ttnn.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<128xbf16>, tensor<2xbf16>) -> tensor<2xbf16>
@@ -42,10 +42,10 @@ module attributes {} {
   }
 }
 
-// Inner dimension mismatch tests
+// Inner dimension mismatch tests.
 // -----
-module attributes {} {
-  func.func @matmul_negative_1d_1d_inner_dimension_missmatch(%arg0: tensor<128xbf16>, %arg1: tensor<64xbf16>) -> tensor<1xbf16> {
+module {
+  func.func @matmul_negative_1d_1d_inner_dimension_mismatch(%arg0: tensor<128xbf16>, %arg1: tensor<64xbf16>) -> tensor<1xbf16> {
     // CHECK: error: 'ttnn.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions
     %0 = tensor.empty() : tensor<1xbf16>
     %1 = "ttnn.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<64xbf16>, tensor<1xbf16>) -> tensor<1xbf16>
@@ -54,8 +54,8 @@ module attributes {} {
 }
 
 // -----
-module attributes {} {
-func.func @matmul_negative_1d_2d_inner_dimension_missmatch(%arg0: tensor<64xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> {
+module {
+func.func @matmul_negative_1d_2d_inner_dimension_mismatch(%arg0: tensor<64xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> {
     // CHECK: error: 'ttnn.matmul' op Input A[-1](64) and B[-2](128) must have matching inner dimensions
     %0 = tensor.empty() : tensor<64xbf16>
     %1 = "ttnn.matmul"(%arg0, %arg1, %0) : (tensor<64xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16>
@@ -64,8 +64,8 @@ func.func @matmul_negative_1d_2d_inner_dimension_missmatch(%arg0: tensor<64xbf16
 }
 
 // -----
-module attributes {} {
-  func.func @matmul_negative_2d_1d_inner_dimension_missmatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64xbf16>) -> tensor<64xbf16> {
+module {
+  func.func @matmul_negative_2d_1d_inner_dimension_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64xbf16>) -> tensor<64xbf16> {
     // CHECK: error: 'ttnn.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions
     %0 = tensor.empty() : tensor<64xbf16>
     %1 = "ttnn.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<64xbf16>, tensor<64xbf16>) -> tensor<64xbf16>
@@ -74,8 +74,8 @@ module attributes {} {
 }
 
 // -----
-module attributes {} {
-  func.func @matmul_negative_2d_2d_inner_dimension_missmatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x64xbf16> {
+module {
+  func.func @matmul_negative_2d_2d_inner_dimension_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x64xbf16> {
     // CHECK: error: 'ttnn.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions
     %0 = tensor.empty() : tensor<64x64xbf16>
     %1 = "ttnn.matmul"(%arg0, %arg1, %0) :
(tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> @@ -84,8 +84,28 @@ module attributes {} { } // ----- -module attributes {} { - func.func @matmul_negative_nd_nd_inner_dimension_missmatch(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<1x64x128xbf16>) -> tensor<7x64x64xbf16> { +module { + func.func @matmul_negative_2d_transpose_2d_inner_dimension_mismatch(%arg0: tensor<128x64xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<128x128xbf16> { + // CHECK: error: 'ttnn.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions + %0 = tensor.empty() : tensor<128x128xbf16> + %1 = "ttnn.matmul"(%arg0, %arg1, %0) <{transpose_a = true}> : (tensor<128x64xbf16>, tensor<64x128xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16> + return %1 : tensor<128x128xbf16> + } +} + +// ----- +module { + func.func @matmul_negative_2d_2d_transpose_inner_dimension_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: error: 'ttnn.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions + %0 = tensor.empty() : tensor<64x64xbf16> + %1 = "ttnn.matmul"(%arg0, %arg1, %0) <{transpose_b = true}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } +} + +// ----- +module { + func.func @matmul_negative_nd_nd_inner_dimension_mismatch(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<1x64x128xbf16>) -> tensor<7x64x64xbf16> { // CHECK: error: 'ttnn.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions %0 = tensor.empty() : tensor<7x64x64xbf16> %1 = "ttnn.matmul"(%arg0, %arg1, %0) : (tensor<7x64x128xbf16>, tensor<1x64x128xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> @@ -93,9 +113,9 @@ module attributes {} { } } -// Batch dimension mismatch tests +// Batch dimension mismatch tests. // ----- -module attributes {} { +module { func.func @matmul_negative_nd_nd_same_rank_batch_broadcast_incompatible_1(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<2x128x64xbf16>) -> tensor<7x64x64xbf16> { // CHECK: error: 'ttnn.matmul' op Batch dimensions of input A(7) and B(2) are not broadcast compatible %0 = tensor.empty() : tensor<7x64x64xbf16> @@ -105,7 +125,7 @@ module attributes {} { } // ----- -module attributes {} { +module { func.func @matmul_negative_nd_nd_same_rank_batch_broadcast_incompatible_2(%arg0: tensor<2x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<7x7x64x64xbf16> { // CHECK: error: 'ttnn.matmul' op Batch dimensions of input A(2,7) and B(7,1) are not broadcast compatible %0 = tensor.empty() : tensor<7x64x64xbf16> @@ -115,7 +135,7 @@ module attributes {} { } // ----- -module attributes {} { +module { func.func @matmul_negative_nd_nd_different_rank_batch_broadcast_incompatible(%arg0: tensor<12x2x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<12x7x7x64x64xbf16> { // CHECK: error: 'ttnn.matmul' op Batch dimensions of input A(12,2,7) and B(7,1) are not broadcast compatible %0 = tensor.empty() : tensor<12x7x7x64x64xbf16> @@ -124,10 +144,10 @@ module attributes {} { } } -// Output shape mismatch tests +// Output shape mismatch tests. 
// ----- -module attributes {} { - func.func @matmul_negative_2d_2d_inner_dimension_missmatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> { +module { + func.func @matmul_negative_2d_2d_inner_dimension_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> { // CHECK: error: 'ttnn.matmul' op Output shape rank(1) must match the expected output shape rank(2) %0 = tensor.empty() : tensor<64xbf16> %1 = "ttnn.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> @@ -136,7 +156,7 @@ module attributes {} { } // ----- -module attributes {} { +module { func.func @matmul_negative_2d_2d_inner_dimension_missmatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x128xbf16> { // CHECK: error: 'ttnn.matmul' op Output shape dimension[1](128) doesn't match the expected output shape dimension[1](64) %0 = tensor.empty() : tensor<64x128xbf16> @@ -144,3 +164,24 @@ module attributes {} { return %1 : tensor<64x128xbf16> } } + + +// ----- +module { + func.func @matmul_negative_2d_transpose_2d_output_shape_mismatch(%arg0: tensor<128x64xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<128x128xbf16> { + // CHECK: error: 'ttnn.matmul' op Output shape dimension[0](128) doesn't match the expected output shape dimension[0](64) + %0 = tensor.empty() : tensor<128x128xbf16> + %1 = "ttnn.matmul"(%arg0, %arg1, %0) <{transpose_a = true}> : (tensor<128x64xbf16>, tensor<128x64xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16> + return %1 : tensor<128x128xbf16> + } +} + +// ----- +module { + func.func @matmul_negative_2d_2d_transpose_output_shape_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<128x128xbf16> { + // CHECK: error: 'ttnn.matmul' op Output shape dimension[0](128) doesn't match the expected output shape dimension[0](64) + %0 = tensor.empty() : tensor<128x128xbf16> + %1 = "ttnn.matmul"(%arg0, %arg1, %0) <{transpose_b = true}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16> + return %1 : tensor<128x128xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTNN/matmul/matmul_tests_positive.mlir b/test/ttmlir/Dialect/TTNN/matmul/matmul_tests_positive.mlir index a62e53211..aa18d0cbb 100644 --- a/test/ttmlir/Dialect/TTNN/matmul/matmul_tests_positive.mlir +++ b/test/ttmlir/Dialect/TTNN/matmul/matmul_tests_positive.mlir @@ -1,87 +1,109 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -module attributes {} { +module { func.func @matmul_1d_1d(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<1xbf16> { %0 = tensor.empty() : tensor<1xbf16> - // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] + // CHECK: "ttnn.matmul" %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<128xbf16>, tensor<1xbf16>) -> tensor<1xbf16> return %1 : tensor<1xbf16> } func.func @matmul_1d_2d(%arg0: tensor<128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> { %0 = tensor.empty() : tensor<64xbf16> - // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] + // CHECK: "ttnn.matmul" %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> return %1 : tensor<64xbf16> } func.func @matmul_2d_1d(%arg0: tensor<64x128xbf16>, %arg1: tensor<128xbf16>) -> tensor<64xbf16> { %0 = tensor.empty() : tensor<64xbf16> - // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] + // CHECK: "ttnn.matmul" %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128xbf16>, tensor<64xbf16>) -> tensor<64xbf16> 
return %1 : tensor<64xbf16> } func.func @matmul_2d_2d(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x64xbf16> { %0 = tensor.empty() : tensor<64x64xbf16> - // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] + // CHECK: "ttnn.matmul" %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> return %1 : tensor<64x64xbf16> } func.func @matmul_1d_nd(%arg0: tensor<128xbf16>, %arg1: tensor<12x7x128x64xbf16>) -> tensor<12x7x64xbf16> { %0 = tensor.empty() : tensor<12x7x64xbf16> - // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] + // CHECK: "ttnn.matmul" %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<12x7x128x64xbf16>, tensor<12x7x64xbf16>) -> tensor<12x7x64xbf16> return %1 : tensor<12x7x64xbf16> } func.func @matmul_nd_1d(%arg0: tensor<12x7x128x64xbf16>, %arg1: tensor<64xbf16>) -> tensor<12x7x128xbf16> { %0 = tensor.empty() : tensor<12x7x128xbf16> - // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] + // CHECK: "ttnn.matmul" %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<12x7x128x64xbf16>, tensor<64xbf16>, tensor<12x7x128xbf16>) -> tensor<12x7x128xbf16> return %1 : tensor<12x7x128xbf16> } func.func @matmul_2d_nd(%arg0: tensor<64x128xbf16>, %arg1: tensor<12x7x128x64xbf16>) -> tensor<12x7x64x64xbf16> { %0 = tensor.empty() : tensor<12x7x64x64xbf16> - // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] + // CHECK: "ttnn.matmul" %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<12x7x128x64xbf16>, tensor<12x7x64x64xbf16>) -> tensor<12x7x64x64xbf16> return %1 : tensor<12x7x64x64xbf16> } func.func @matmul_nd_2d(%arg0: tensor<12x7x128x64xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<12x7x128x128xbf16> { %0 = tensor.empty() : tensor<12x7x128x128xbf16> - // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] + // CHECK: "ttnn.matmul" %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<12x7x128x64xbf16>, tensor<64x128xbf16>, tensor<12x7x128x128xbf16>) -> tensor<12x7x128x128xbf16> return %1 : tensor<12x7x128x128xbf16> } - // matmul nd - nd tests + // Matmul nd - nd tests. 
func.func @matmul_nd_nd_same_rank_same_dims(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<7x128x64xbf16>) -> tensor<7x64x64xbf16> { %0 = tensor.empty() : tensor<7x64x64xbf16> - // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] + // CHECK: "ttnn.matmul" %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<7x64x128xbf16>, tensor<7x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> return %1 : tensor<7x64x64xbf16> } func.func @matmul_nd_nd_same_rank_broadcastable_dims_1(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<1x128x64xbf16>) -> tensor<7x64x64xbf16> { %0 = tensor.empty() : tensor<7x64x64xbf16> - // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] + // CHECK: "ttnn.matmul" %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<7x64x128xbf16>, tensor<1x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> return %1 : tensor<7x64x64xbf16> } func.func @matmul_nd_nd_same_rank_broadcastable_dims_2(%arg0: tensor<1x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<7x7x64x64xbf16> { %0 = tensor.empty() : tensor<7x7x64x64xbf16> - // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] + // CHECK: "ttnn.matmul" %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<7x7x64x64xbf16>) -> tensor<7x7x64x64xbf16> return %1 : tensor<7x7x64x64xbf16> } func.func @matmul_nd_nd_different_rank_broadcastable_dims_2(%arg0: tensor<12x1x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<12x7x7x64x64xbf16> { %0 = tensor.empty() : tensor<12x7x7x64x64xbf16> - // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] + // CHECK: "ttnn.matmul" %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<12x1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<12x7x7x64x64xbf16>) -> tensor<12x7x7x64x64xbf16> return %1 : tensor<12x7x7x64x64xbf16> } + + // Matmul with transposed inputs tests. 
+ func.func @matmul_2d_tranpose_2d(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<128x128xbf16> { + %0 = tensor.empty() : tensor<128x128xbf16> + // CHECK: "ttnn.matmul" + %1 = "ttir.matmul"(%arg0, %arg1, %0) <{transpose_a = true}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16> + return %1 : tensor<128x128xbf16> + } + + func.func @matmul_2d_2d_transpose(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x64xbf16> { + %0 = tensor.empty() : tensor<64x64xbf16> + // CHECK: "ttnn.matmul" + %1 = "ttir.matmul"(%arg0, %arg1, %0) <{transpose_b = true}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } + + func.func @matmul_2d_tranpose_2d_transpose(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<128x128xbf16> { + %0 = tensor.empty() : tensor<128x128xbf16> + // CHECK: "ttnn.matmul" + %1 = "ttir.matmul"(%arg0, %arg1, %0) <{transpose_a = true, transpose_b = true}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16> + return %1 : tensor<128x128xbf16> + } } diff --git a/test/ttmlir/Dialect/TTNN/matmul/simple_matmul.mlir b/test/ttmlir/Dialect/TTNN/matmul/simple_matmul.mlir index 87db65078..c8285f0a1 100644 --- a/test/ttmlir/Dialect/TTNN/matmul/simple_matmul.mlir +++ b/test/ttmlir/Dialect/TTNN/matmul/simple_matmul.mlir @@ -1,9 +1,8 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -// CHECK: #[[TILED_LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #dram>, > -module attributes {} { +module { func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> { %0 = tensor.empty() : tensor<64x96xbf16> - // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] + // CHECK: "ttnn.matmul" %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> return %1 : tensor<64x96xbf16> } diff --git a/test/ttmlir/Silicon/TTNN/matmul/llama_matmul.mlir b/test/ttmlir/Silicon/TTNN/matmul/llama_matmul.mlir index e777cd55a..f3bcba8f1 100644 --- a/test/ttmlir/Silicon/TTNN/matmul/llama_matmul.mlir +++ b/test/ttmlir/Silicon/TTNN/matmul/llama_matmul.mlir @@ -1,10 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -module attributes {} { +module { func.func @forward(%arg0: tensor<1x11x2048xf32>, %arg1: tensor<2048x128256xf32>) -> tensor<1x11x128256xf32> { %0 = tensor.empty() : tensor<1x11x128256xf32> - // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] + // CHECK: "ttnn.matmul" %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<1x11x2048xf32>, tensor<2048x128256xf32>, tensor<1x11x128256xf32>) -> tensor<1x11x128256xf32> return %1 : tensor<1x11x128256xf32> } diff --git a/test/ttmlir/Silicon/TTNN/matmul/simple_matmul.mlir b/test/ttmlir/Silicon/TTNN/matmul/simple_matmul.mlir index f221001bb..e76b0b3d2 100644 --- a/test/ttmlir/Silicon/TTNN/matmul/simple_matmul.mlir +++ b/test/ttmlir/Silicon/TTNN/matmul/simple_matmul.mlir @@ -1,12 +1,25 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -// CHECK: #[[TILED_LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, 
memref<2x4x!tt.tile<32x32, bf16>, #dram>, > -module attributes {} { +module { func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> { %0 = tensor.empty() : tensor<64x96xbf16> - // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] + // CHECK: "ttnn.matmul" %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> return %1 : tensor<64x96xbf16> } + + func.func @matmul_transpose_first(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<128x128xbf16> { + %0 = tensor.empty() : tensor<128x128xbf16> + // CHECK: "ttnn.matmul" + %1 = "ttir.matmul"(%arg0, %arg1, %0) <{transpose_a = true}>: (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16> + return %1 : tensor<128x128xbf16> + } + + func.func @matmul_transpose_second(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x64xbf16> { + %0 = tensor.empty() : tensor<64x64xbf16> + // CHECK: "ttnn.matmul" + %1 = "ttir.matmul"(%arg0, %arg1, %0) <{transpose_b = true}>: (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_matmul.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_matmul.mlir index f221001bb..7ce93f024 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_matmul.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_matmul.mlir @@ -1,11 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -// CHECK: #[[TILED_LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #dram>, > -module attributes {} { +module { func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> { %0 = tensor.empty() : tensor<64x96xbf16> - // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] + // CHECK: "ttnn.matmul" %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> return %1 : tensor<64x96xbf16> } From daa5ed0f78e6c6b93255791286887b5e89d03d10 Mon Sep 17 00:00:00 2001 From: Aleksandar Zecevic Date: Fri, 10 Jan 2025 15:28:08 +0000 Subject: [PATCH 4/9] ttir.linear tranpose attrs - TODO: add ttir negative tests with transpose --- include/ttmlir/Dialect/TTIR/IR/TTIROps.td | 4 +- lib/Dialect/TTIR/IR/TTIROps.cpp | 87 +++++++++---------- .../TTIR/linear/linear_tests_negative.mlir | 14 +-- 3 files changed, 53 insertions(+), 52 deletions(-) diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index eec3d4a38..04448c4f3 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -1305,7 +1305,9 @@ def TTIR_LinearOp : TTIR_DPSOp<"linear"> { let arguments = (ins AnyRankedTensor:$a, AnyRankedTensor:$b, Optional:$bias, - AnyRankedTensor:$output); + AnyRankedTensor:$output, + DefaultValuedAttr:$transpose_a, + DefaultValuedAttr:$transpose_b); let results = (outs AnyRankedTensor:$result); diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index d4c08b964..0a2d792ec 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -1046,43 +1046,45 @@ ::mlir::LogicalResult mlir::tt::ttir::LinearOp::verify() { return emitOpError("Input B must be at least a 1D 
tensor"); } - // If input A is a vector (1D tensor), 1 is prepended to its dimension for the - // purpose of the matrix multiplication. After the matrix multiplication, the - // prepended dimension is removed. + // If input A is a vector (1D tensor), 1 is prepended to its dimensions for + // the purpose of the matrix multiplication. After the matrix multiplication, + // the prepended dimension is removed. Otherwise, check if the LHS needs to be + // transposed. if (inputAType.getRank() == 1) { inputAShape.insert(inputAShape.begin(), 1); + } else if (getTransposeA()) { + std::swap(inputAShape[inputAShape.size() - 1], + inputAShape[inputAShape.size() - 2]); } - // If input B is a vector (1D tensor), a 1 is appended to its dimension for - // the purpose of the matrix-vector product and removed afterwards. + // If input B is a vector (1D tensor), a 1 is appended to its dimensions for + // the purpose of the matrix-vector product and removed afterwards. Otherwise, + // check if the RHS needs to be transposed. if (inputBType.getRank() == 1) { inputBShape.push_back(1); + } else if (getTransposeB()) { + std::swap(inputBShape[inputBShape.size() - 1], + inputBShape[inputBShape.size() - 2]); } // Verify that the input A and input B has matching inner dimensions. if (inputAShape[inputAShape.size() - 1] != inputBShape[inputBShape.size() - 2]) { - return emitOpError( - "Input A[-1](" + std::to_string(inputAShape[inputAShape.size() - 1]) + - ") and B[-2](" + std::to_string(inputBShape[inputBShape.size() - 2]) + - ") must have matching inner dimensions"); + return emitOpError("Input A[-1](") + << inputAShape[inputAShape.size() - 1] << ") and B[-2](" + << inputBShape[inputBShape.size() - 2] + << ") must have matching inner dimensions"; } llvm::SmallVector expectedOutputShape; // Verify that the batch dimensions are broadcast compatible and construct the - // expected output shape. + // expected output shape. If either of input A or input B is at most 2D + // tensors, the batch dimensions are trivially broadcast compatible. if (inputAShape.size() > 2 || inputBShape.size() > 2) { - llvm::SmallVector inputABatchDims, inputBBatchDims; - - if (inputAShape.size() > 2) { - inputABatchDims.insert(inputABatchDims.begin(), inputAShape.begin(), - inputAShape.end() - 2); - } - - if (inputBShape.size() > 2) { - inputBBatchDims.insert(inputBBatchDims.begin(), inputBShape.begin(), - inputBShape.end() - 2); - } + llvm::SmallVector inputABatchDims(inputAShape.begin(), + inputAShape.end() - 2); + llvm::SmallVector inputBBatchDims(inputBShape.begin(), + inputBShape.end() - 2); // Verify that the batch dimensions of input A and B are broadcast // compatible. @@ -1098,12 +1100,10 @@ ::mlir::LogicalResult mlir::tt::ttir::LinearOp::verify() { } // Insert the broadcasted batch dimensions in the expected output shape. - expectedOutputShape.insert(expectedOutputShape.begin(), - broadcastedShape.begin(), - broadcastedShape.end()); + expectedOutputShape = std::move(broadcastedShape); } - // Insert the input A and B inner dimensions in expected output shape. + // Insert the input A and B inner dimensions in expected output shape // Consider the case where input A and B are vectors. In that case, // the dimension 1 is ommited from the output shape. if (inputAType.getRank() > 1) { @@ -1116,21 +1116,21 @@ ::mlir::LogicalResult mlir::tt::ttir::LinearOp::verify() { if (biasType) { // Verify that the input bias is at least 1D tensor. 
-    if (biasType.value().getRank() < 1) {
+    if (biasType->getRank() < 1) {
       return emitOpError("Bias must be at least a 1D tensor");
     }

-    llvm::SmallVector<int64_t> biasShape(biasType.value().getShape());
+    llvm::SmallVector<int64_t> biasShape(biasType->getShape());

     // Verify that the dimensions of the matmul of A and B are broadcast
     // compatible with input bias.
     llvm::SmallVector<int64_t> matmulShape = expectedOutputShape;
     if (!OpTrait::util::getBroadcastedShape(matmulShape, biasShape,
                                             expectedOutputShape)) {
-      return emitOpError("Bias shape(" + ttmlir::utils::join(biasShape, ",") +
-                         ") is not broadcast compatible with the matmul output "
-                         "shape(" +
-                         ttmlir::utils::join(matmulShape, ",") + ")");
+      return emitOpError("Bias shape(")
+             << ttmlir::utils::join(biasShape, ",")
+             << ") is not broadcast compatible with the matmul output shape("
+             << ttmlir::utils::join(matmulShape, ",") << ")";
     }
   }

@@ -1149,23 +1149,22 @@ ::mlir::LogicalResult mlir::tt::ttir::LinearOp::verify() {
     return success();
   }

-  // Verify that the output shape dimension count is correct.
+  // Verify that the output shape is correct.
   if (outputShape.size() != expectedOutputShape.size()) {
-    return emitOpError("Output shape rank(" +
-                       std::to_string(outputShape.size()) +
-                       ") must match the expected output shape rank(" +
-                       std::to_string(expectedOutputShape.size()) + ")");
+    return emitOpError("Output shape rank(")
+           << outputShape.size()
+           << ") must match the expected output shape rank("
+           << expectedOutputShape.size() << ")";
   }

   // Verify each dim of the output shape.
-  for (size_t i = 0; i < outputShape.size(); i++) {
-    if (outputShape[i] != expectedOutputShape[i]) {
-      return emitOpError(
-          "Output shape dimension[" + std::to_string(i) + "](" +
-          std::to_string(outputShape[i]) +
-          ") doesn't match the expected output shape dimension[" +
-          std::to_string(i) + "](" + std::to_string(expectedOutputShape[i]) +
-          ")");
+  for (auto [index, outputDim, expectedDim] : llvm::zip(
+           llvm::seq(outputShape.size()), outputShape, expectedOutputShape)) {
+    if (outputDim != expectedDim) {
+      return emitOpError("Output shape dimension[")
+             << index << "](" << outputDim
+             << ") doesn't match the expected output shape dimension[" << index
+             << "](" << expectedDim << ")";
     }
   }

diff --git a/test/ttmlir/Dialect/TTIR/linear/linear_tests_negative.mlir b/test/ttmlir/Dialect/TTIR/linear/linear_tests_negative.mlir
index 0154deff9..a67d9bef0 100644
--- a/test/ttmlir/Dialect/TTIR/linear/linear_tests_negative.mlir
+++ b/test/ttmlir/Dialect/TTIR/linear/linear_tests_negative.mlir
@@ -1,7 +1,7 @@
 // RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s
-// Negative tests for linear operation
+// Negative tests for linear operation.

-// Verify that the parsing fails if either of operands is a scalar
+// Verify that the parsing fails if either of the operands is a scalar.
 module {
   func.func @linear_negative_1d_1d_scalar_a(%arg0: tensor<bf16>, %arg1: tensor<64xbf16>) -> tensor<1xbf16> {
     // CHECK: error: 'ttir.linear' op Input A must be at least a 1D tensor
@@ -31,7 +31,7 @@ module {
   }
 }

-// Verifty that the parsing fails if the output is a scalar
+// Verify that the parsing fails if the output is a scalar.
 // -----
 module {
   func.func @linear_negative_1d_1d_scalar_output(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<bf16> {
@@ -52,7 +52,7 @@ module {
   }
 }

-// Inner dimension mismatch tests
+// Inner dimension mismatch tests.
 // -----
 module {
   func.func @linear_negative_1d_1d_inner_dimension_mismatch(%arg0: tensor<128xbf16>, %arg1: tensor<64xbf16>) -> tensor<1xbf16> {
@@ -103,7 +103,7 @@ module {
   }
 }

-// Batch dimension mismatch tests
+// Batch dimension mismatch tests.
 // -----
 module {
   func.func @linear_negative_nd_nd_same_rank_batch_broadcast_incompatible_1(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<2x128x64xbf16>) -> tensor<7x64x64xbf16> {
@@ -134,7 +134,7 @@ module {
   }
 }

-// Bias shape mismatch tests
+// Bias shape mismatch tests.
 // -----
 module {
   func.func @linear_negative_matmul_bias_broadcast_incompatible(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<2x64xbf16>) -> tensor<64x64xbf16> {
@@ -155,7 +155,7 @@ module {
   }
 }

-// Output shape mismatch tests
+// Output shape mismatch tests.
 // -----
 module {
   func.func @linear_negative_2d_2d_output_shape_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> {

From 3a1bc296bafd327a6f63f477b11ed8fea3504830 Mon Sep 17 00:00:00 2001
From: Aleksandar Zecevic
Date: Fri, 10 Jan 2025 15:36:32 +0000
Subject: [PATCH 5/9] ttnn.linear transpose attrs

- TODO: add linear tests with transpose
---
 include/ttmlir/Dialect/TTNN/IR/TTNNOps.td     |  5 +-
 include/ttmlir/Target/TTNN/program.fbs        |  6 +-
 lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp      |  3 +-
 lib/Dialect/TTNN/IR/TTNNOps.cpp               | 87 +++++++++----------
 lib/Target/TTNN/TTNNToFlatbuffer.cpp          |  7 +-
 runtime/lib/ttnn/operations/matmul/matmul.cpp | 28 ++----
 6 files changed, 66 insertions(+), 70 deletions(-)

diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index 490aa9d0d..40eb46568 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -798,7 +798,10 @@ def TTNN_LinearOp : TTNN_NamedDPSOp<"linear"> {
   let arguments = (ins AnyRankedTensor:$a,
                        AnyRankedTensor:$b,
                        Optional<AnyRankedTensor>:$bias,
-                       AnyRankedTensor:$output);
+                       AnyRankedTensor:$output,
+                       DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
+                       DefaultValuedAttr<BoolAttr, "false">:$transpose_b);
+
   let results = (outs AnyRankedTensor:$result);

   let extraClassDeclaration = [{
diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs
index 0f1ce94bb..71d4597f1 100644
--- a/include/ttmlir/Target/TTNN/program.fbs
+++ b/include/ttmlir/Target/TTNN/program.fbs
@@ -233,10 +233,12 @@ table SliceOp {
 }

 table LinearOp {
-  in0: tt.target.TensorRef;
-  in1: tt.target.TensorRef;
+  a: tt.target.TensorRef;
+  b: tt.target.TensorRef;
   bias: tt.target.TensorRef;
   out: tt.target.TensorRef;
+  transpose_a: bool;
+  transpose_b: bool;
 }

 // ANCHOR: adding_an_op_matmul_fbs
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index fa4bb2083..6789a6edb 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -770,7 +770,8 @@ class LinearOpConversionPattern : public OpConversionPattern<ttir::LinearOp> {
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<ttnn::LinearOp>(
         op, this->getTypeConverter()->convertType(op.getType()), adaptor.getA(),
-        adaptor.getB(), adaptor.getBias(), adaptor.getOutput());
+        adaptor.getB(), adaptor.getBias(), adaptor.getOutput(),
+        adaptor.getTransposeA(), adaptor.getTransposeB());
     return success();
   }
 };
diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp
index 56cd8c8c5..8d2bb98d4 100644
--- a/lib/Dialect/TTNN/IR/TTNNOps.cpp
+++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -740,43 +740,45 @@ ::mlir::LogicalResult mlir::tt::ttnn::LinearOp::verify() {
return emitOpError("Input B must be at least a 1D tensor"); } - // If input A is a vector (1D tensor), 1 is prepended to its dimension for the - // purpose of the matrix multiplication. After the matrix multiplication, the - // prepended dimension is removed. + // If input A is a vector (1D tensor), 1 is prepended to its dimensions for + // the purpose of the matrix multiplication. After the matrix multiplication, + // the prepended dimension is removed. Otherwise, check if the LHS needs to be + // transposed. if (inputAType.getRank() == 1) { inputAShape.insert(inputAShape.begin(), 1); + } else if (getTransposeA()) { + std::swap(inputAShape[inputAShape.size() - 1], + inputAShape[inputAShape.size() - 2]); } - // If input B is a vector (1D tensor), a 1 is appended to its dimension for - // the purpose of the matrix-vector product and removed afterwards. + // If input B is a vector (1D tensor), a 1 is appended to its dimensions for + // the purpose of the matrix-vector product and removed afterwards. Otherwise, + // check if the RHS needs to be transposed. if (inputBType.getRank() == 1) { inputBShape.push_back(1); + } else if (getTransposeB()) { + std::swap(inputBShape[inputBShape.size() - 1], + inputBShape[inputBShape.size() - 2]); } // Verify that the input A and input B has matching inner dimensions. if (inputAShape[inputAShape.size() - 1] != inputBShape[inputBShape.size() - 2]) { - return emitOpError( - "Input A[-1](" + std::to_string(inputAShape[inputAShape.size() - 1]) + - ") and B[-2](" + std::to_string(inputBShape[inputBShape.size() - 2]) + - ") must have matching inner dimensions"); + return emitOpError("Input A[-1](") + << inputAShape[inputAShape.size() - 1] << ") and B[-2](" + << inputBShape[inputBShape.size() - 2] + << ") must have matching inner dimensions"; } llvm::SmallVector expectedOutputShape; // Verify that the batch dimensions are broadcast compatible and construct the - // expected output shape. + // expected output shape. If either of input A or input B is at most 2D + // tensors, the batch dimensions are trivially broadcast compatible. if (inputAShape.size() > 2 || inputBShape.size() > 2) { - llvm::SmallVector inputABatchDims, inputBBatchDims; - - if (inputAShape.size() > 2) { - inputABatchDims.insert(inputABatchDims.begin(), inputAShape.begin(), - inputAShape.end() - 2); - } - - if (inputBShape.size() > 2) { - inputBBatchDims.insert(inputBBatchDims.begin(), inputBShape.begin(), - inputBShape.end() - 2); - } + llvm::SmallVector inputABatchDims(inputAShape.begin(), + inputAShape.end() - 2); + llvm::SmallVector inputBBatchDims(inputBShape.begin(), + inputBShape.end() - 2); // Verify that the batch dimensions of input A and B are broadcast // compatible. @@ -792,12 +794,10 @@ ::mlir::LogicalResult mlir::tt::ttnn::LinearOp::verify() { } // Insert the broadcasted batch dimensions in the expected output shape. - expectedOutputShape.insert(expectedOutputShape.begin(), - broadcastedShape.begin(), - broadcastedShape.end()); + expectedOutputShape = std::move(broadcastedShape); } - // Insert the input A and B inner dimensions in expected output shape. + // Insert the input A and B inner dimensions in expected output shape // Consider the case where input A and B are vectors. In that case, // the dimension 1 is ommited from the output shape. if (inputAType.getRank() > 1) { @@ -810,21 +810,21 @@ ::mlir::LogicalResult mlir::tt::ttnn::LinearOp::verify() { if (biasType) { // Verify that the input bias is at least 1D tensor. 
-    if (biasType.value().getRank() < 1) {
+    if (biasType->getRank() < 1) {
       return emitOpError("Bias must be at least a 1D tensor");
     }

-    llvm::SmallVector<int64_t> biasShape(biasType.value().getShape());
+    llvm::SmallVector<int64_t> biasShape(biasType->getShape());

     // Verify that the dimensions of the matmul of A and B are broadcast
     // compatible with input bias.
     llvm::SmallVector<int64_t> matmulShape = expectedOutputShape;
     if (!OpTrait::util::getBroadcastedShape(matmulShape, biasShape,
                                             expectedOutputShape)) {
-      return emitOpError("Bias shape(" + ttmlir::utils::join(biasShape, ",") +
-                         ") is not broadcast compatible with the matmul output "
-                         "shape(" +
-                         ttmlir::utils::join(matmulShape, ",") + ")");
+      return emitOpError("Bias shape(")
+             << ttmlir::utils::join(biasShape, ",")
+             << ") is not broadcast compatible with the matmul output shape("
+             << ttmlir::utils::join(matmulShape, ",") << ")";
     }
   }

@@ -843,23 +843,22 @@ ::mlir::LogicalResult mlir::tt::ttnn::LinearOp::verify() {
     return success();
   }

-  // Verify that the output shape dimension count is correct.
+  // Verify that the output shape is correct.
   if (outputShape.size() != expectedOutputShape.size()) {
-    return emitOpError("Output shape rank(" +
-                       std::to_string(outputShape.size()) +
-                       ") must match the expected output shape rank(" +
-                       std::to_string(expectedOutputShape.size()) + ")");
+    return emitOpError("Output shape rank(")
+           << outputShape.size()
+           << ") must match the expected output shape rank("
+           << expectedOutputShape.size() << ")";
   }

   // Verify each dim of the output shape.
-  for (size_t i = 0; i < outputShape.size(); i++) {
-    if (outputShape[i] != expectedOutputShape[i]) {
-      return emitOpError(
-          "Output shape dimension[" + std::to_string(i) + "](" +
-          std::to_string(outputShape[i]) +
-          ") doesn't match the expected output shape dimension[" +
-          std::to_string(i) + "](" + std::to_string(expectedOutputShape[i]) +
-          ")");
+  for (auto [index, outputDim, expectedDim] : llvm::zip(
+           llvm::seq(outputShape.size()), outputShape, expectedOutputShape)) {
+    if (outputDim != expectedDim) {
+      return emitOpError("Output shape dimension[")
+             << index << "](" << outputDim
+             << ") doesn't match the expected output shape dimension[" << index
+             << "](" << expectedDim << ")";
     }
   }

diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
index 94a6a2e19..d35cbd5fb 100644
--- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp
+++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -417,9 +417,9 @@ createOp(FlatbufferObjectCache &cache, OnesOp op) {

 ::flatbuffers::Offset<::tt::target::ttnn::LinearOp>
 createOp(FlatbufferObjectCache &cache, LinearOp op) {
-  auto in0 =
+  auto a =
       cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getA()));
-  auto in1 =
+  auto b =
       cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getB()));
   auto bias = op.getODSOperands(2).empty() ?
       flatbuffers::Offset<::tt::target::TensorRef>()
@@ -427,7 +427,8 @@ createOp(FlatbufferObjectCache &cache, LinearOp op) {
           getOperandThroughDPSOps(op.getBias()));
   auto output = cache.at<::tt::target::TensorRef>(
       getOperandThroughDPSOps(op.getResult()));
-  return ::tt::target::ttnn::CreateLinearOp(*cache.fbb, in0, in1, bias, output);
+  return ::tt::target::ttnn::CreateLinearOp(
+      *cache.fbb, a, b, bias, output, op.getTransposeA(), op.getTransposeB());
 }

 // ANCHOR: adding_an_op_matmul_serialize_to_binary
diff --git a/runtime/lib/ttnn/operations/matmul/matmul.cpp b/runtime/lib/ttnn/operations/matmul/matmul.cpp
index 73d18aff8..37e6bd671 100644
--- a/runtime/lib/ttnn/operations/matmul/matmul.cpp
+++ b/runtime/lib/ttnn/operations/matmul/matmul.cpp
@@ -32,31 +32,21 @@ void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) {

 void run(const ::tt::target::ttnn::LinearOp *op, ProgramContext &context) {
   ProgramTensorPool &tensorPool = context.getTensorPool();
-  const ::ttnn::Tensor &lhs = tensorPool.at(op->in0()->global_id());
-  const ::ttnn::Tensor &rhs = tensorPool.at(op->in1()->global_id());
+  const ::ttnn::Tensor &lhs = tensorPool.at(op->a()->global_id());
+  const ::ttnn::Tensor &rhs = tensorPool.at(op->b()->global_id());
   std::optional<::ttnn::Tensor> bias =
       op->bias() ? std::make_optional(tensorPool.at(op->bias()->global_id()))
                  : std::nullopt;
-
+  const ::ttnn::Tensor &out = tensorPool.at(op->out()->global_id());
   DEBUG_ASSERT(lhs.is_allocated());
   DEBUG_ASSERT(rhs.is_allocated());
   DEBUG_ASSERT(!bias || bias->is_allocated());
+  DEBUG_ASSERT(out.is_allocated());

-  ::ttnn::DataType outputDataType = utils::getDataType(op->out());
-  ::tt::tt_metal::MemoryConfig outputMemoryConfig =
-      ::tt::runtime::ttnn::utils::createMemoryConfig(op->out());
-
-  const std::optional<::tt::tt_metal::MemoryConfig> memoryConfig =
-      std::make_optional(outputMemoryConfig);
-
-  const std::optional<::ttnn::DataType> dtype =
-      std::make_optional(outputDataType);
-
-  ::ttnn::Tensor out = ::ttnn::linear(
-      lhs, rhs, bias, /*transposeA*/ false, /*transposeB*/ false, memoryConfig,
-      dtype, /*programConfig*/ std::nullopt, /*activation*/ std::nullopt,
-      /*computeKernelConfig*/ std::nullopt, /*coreGrid*/ std::nullopt);
-
-  tensorPool.insert_or_assign(op->out()->global_id(), out);
+  ::ttnn::linear(lhs, rhs, bias, op->transpose_a(), op->transpose_b(),
+                 /*memory_config=*/std::nullopt, /*dtype=*/std::nullopt,
+                 /*program_config=*/std::nullopt, /*activation=*/std::nullopt,
+                 /*compute_kernel_config=*/std::nullopt,
+                 /*core_grid=*/std::nullopt, /*output_tile=*/std::nullopt, out);
 }

 } // namespace tt::runtime::ttnn::operations::matmul

From 4c2965baea2e59ea6718e652e198f59cd071e642 Mon Sep 17 00:00:00 2001
From: Aleksandar Zecevic
Date: Fri, 10 Jan 2025 15:57:42 +0000
Subject: [PATCH 6/9] Linear tests with transpose

---
 .../TTIR/linear/linear_tests_negative.mlir | 50 +++++++++++++++++++
 .../TTNN/linear/linear_tests_positive.mlir | 44 ++++++++++++++++
 2 files changed, 94 insertions(+)

diff --git a/test/ttmlir/Dialect/TTIR/linear/linear_tests_negative.mlir b/test/ttmlir/Dialect/TTIR/linear/linear_tests_negative.mlir
index a67d9bef0..499ab8325 100644
--- a/test/ttmlir/Dialect/TTIR/linear/linear_tests_negative.mlir
+++ b/test/ttmlir/Dialect/TTIR/linear/linear_tests_negative.mlir
@@ -93,6 +93,26 @@
   }
 }

+// -----
+module {
+  func.func @linear_negative_2d_transpose_2d_inner_dimension_mismatch(%arg0: tensor<128x64xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<128x128xbf16> {
+    // CHECK: error: 'ttir.linear' op Input A[-1](128) and B[-2](64) must have matching inner dimensions
+    %0 = tensor.empty() : tensor<128x128xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %0) <{transpose_a = true}> : (tensor<128x64xbf16>, tensor<64x128xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16>
+    return %1 : tensor<128x128xbf16>
+  }
+}
+
+// -----
+module {
+  func.func @linear_negative_2d_2d_transpose_inner_dimension_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x64xbf16> {
+    // CHECK: error: 'ttir.linear' op Input A[-1](128) and B[-2](64) must have matching inner dimensions
+    %0 = tensor.empty() : tensor<64x64xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %0) <{transpose_b = true}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
+    return %1 : tensor<64x64xbf16>
+  }
+}
+
 // -----
 module {
   func.func @linear_negative_nd_nd_inner_dimension_mismatch(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<1x64x128xbf16>) -> tensor<7x64x64xbf16> {
@@ -134,6 +154,16 @@ module {
   }
 }

+// -----
+module {
+  func.func @linear_negative_nd_nd_transpose_bias_broadcast_incompatible(%arg0: tensor<3x64x128xbf16>, %arg1: tensor<64x128xbf16>, %bias: tensor<2x64x64xbf16>) -> tensor<3x64x64xbf16> {
+    // CHECK: error: 'ttir.linear' op Bias shape(2,64,64) is not broadcast compatible with the matmul output shape(3,64,64)
+    %0 = tensor.empty() : tensor<3x64x64xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{transpose_b = true}> : (tensor<3x64x128xbf16>, tensor<64x128xbf16>, tensor<2x64x64xbf16>, tensor<3x64x64xbf16>) -> tensor<3x64x64xbf16>
+    return %1 : tensor<3x64x64xbf16>
+  }
+}
+
 // Bias shape mismatch tests.
 // -----
 module {
   func.func @linear_negative_matmul_bias_broadcast_incompatible(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<2x64xbf16>) -> tensor<64x64xbf16> {
@@ -175,3 +205,23 @@ module {
     return %1 : tensor<64x128xbf16>
   }
 }
+
+// -----
+module {
+  func.func @linear_negative_2d_transpose_2d_output_shape_mismatch(%arg0: tensor<128x64xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<128x128xbf16> {
+    // CHECK: error: 'ttir.linear' op Output shape dimension[0](128) doesn't match the expected output shape dimension[0](64)
+    %0 = tensor.empty() : tensor<128x128xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %0) <{transpose_a = true}> : (tensor<128x64xbf16>, tensor<128x64xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16>
+    return %1 : tensor<128x128xbf16>
+  }
+}
+
+// -----
+module {
+  func.func @linear_negative_2d_2d_transpose_output_shape_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<128x128xbf16> {
+    // CHECK: error: 'ttir.linear' op Output shape dimension[0](128) doesn't match the expected output shape dimension[0](64)
+    %0 = tensor.empty() : tensor<128x128xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %0) <{transpose_b = true}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16>
+    return %1 : tensor<128x128xbf16>
+  }
+}
diff --git a/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir b/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir
index ef0a6729e..1cbf118f6 100644
--- a/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir
+++ b/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir
@@ -212,4 +212,48 @@ module {
     %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<3x64x128xbf16>, tensor<4x3x128x32xbf16>, tensor<14x4x3x64x32xbf16>, tensor<14x4x3x64x32xbf16>) -> tensor<14x4x3x64x32xbf16>
     return %1 : tensor<14x4x3x64x32xbf16>
   }
+
+  // Linear with transposed inputs tests.
+  func.func @linear_2d_transpose_2d_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>, %bias: tensor<128x128xbf16>) -> tensor<128x128xbf16> {
+    // CHECK: "ttnn.empty"
+    // CHECK-SAME: tensor<128x128xbf16
+    %0 = tensor.empty() : tensor<128x128xbf16>
+    // CHECK: "ttnn.linear"
+    // CHECK-SAME: transpose_a = true
+    // CHECK-SAME: transpose_b = false
+    // CHECK-SAME: tensor<64x128xbf16
+    // CHECK-SAME: tensor<64x128xbf16
+    // CHECK-SAME: tensor<128x128xbf16
+    // CHECK-SAME: tensor<128x128xbf16
+    %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{transpose_a = true}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<128x128xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16>
+    return %1 : tensor<128x128xbf16>
+  }
+  func.func @linear_2d_2d_transpose_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> {
+    // CHECK: "ttnn.empty"
+    // CHECK-SAME: tensor<64x64xbf16
+    %0 = tensor.empty() : tensor<64x64xbf16>
+    // CHECK: "ttnn.linear"
+    // CHECK-SAME: transpose_a = false
+    // CHECK-SAME: transpose_b = true
+    // CHECK-SAME: tensor<64x128xbf16
+    // CHECK-SAME: tensor<64x128xbf16
+    // CHECK-SAME: tensor<64x64xbf16
+    // CHECK-SAME: tensor<64x64xbf16
+    %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{transpose_b = true}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
+    return %1 : tensor<64x64xbf16>
+  }
+  func.func @linear_2d_transpose_2d_transpose(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<128x128xbf16>) -> tensor<128x128xbf16> {
+    // CHECK: "ttnn.empty"
+    // CHECK-SAME: tensor<128x128xbf16
+    %0 = tensor.empty() : tensor<128x128xbf16>
+    // CHECK: "ttnn.linear"
+    // CHECK-SAME: transpose_a = true
+    // CHECK-SAME: transpose_b = true
+    // CHECK-SAME: tensor<64x128xbf16
+    // CHECK-SAME: tensor<128x64xbf16
+    // CHECK-SAME: tensor<128x128xbf16
+    // CHECK-SAME: tensor<128x128xbf16
+    %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{transpose_a = true, transpose_b = true}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<128x128xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16>
+    return %1 : tensor<128x128xbf16>
+  }
 }

From 462eb40afa27eff736de9e9675cbf3f69535e6c2 Mon Sep 17 00:00:00 2001
From: Aleksandar Zecevic
Date: Fri, 10 Jan 2025 16:01:21 +0000
Subject: [PATCH 7/9] Fusing transpose into matmul transpose_a/transpose_b attributes

---
 include/ttmlir/Dialect/TTIR/IR/TTIROps.td    |  2 +
 lib/Dialect/TTIR/IR/TTIROps.cpp              | 57 +++++++++++++++++++
 .../canonicalize/matmul_op_canonicalize.mlir | 36 ++++++++++++
 3 files changed, 95 insertions(+)
 create mode 100644 test/ttmlir/Dialect/TTIR/canonicalize/matmul_op_canonicalize.mlir

diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index 04448c4f3..6b56a995b 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -1338,6 +1338,8 @@ def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> {
   }];

   let hasVerifier = 1;
+
+  let hasCanonicalizer = 1;
 }
 // ANCHOR_END: adding_an_op_matmul_ttir

diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp
index 0a2d792ec..9c9d2823b 100644
--- a/lib/Dialect/TTIR/IR/TTIROps.cpp
+++ b/lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
@@ -1302,6 +1303,62 @@ ::mlir::LogicalResult mlir::tt::ttir::MatmulOp::verify() {
 }
 // ANCHOR_END: adding_an_op_matmul_ttir_verify

+// If value is defined by TransposeOp with transpose dimensions
+// (rank - 2, rank - 1), return the input of the TransposeOp, otherwise return
+// std::nullopt. This is used for canonicalization of MatmulOp and LinearOp.
+static std::optional<mlir::TypedValue<mlir::RankedTensorType>>
+getTransposeOpOperand(mlir::TypedValue<mlir::RankedTensorType> value) {
+  auto producerOp = value.getDefiningOp<ttir::TransposeOp>();
+  if (!producerOp) {
+    return std::nullopt;
+  }
+
+  int64_t rank = value.getType().getRank();
+  // TODO (azecevic): Change llvm::SmallSet comparison to direct comparison when
+  // TransposeOp canonicalization is merged.
+  if (rank < 2 || llvm::SmallSet<int64_t, 2>{rank - 2, rank - 1} !=
+                      llvm::SmallSet<int64_t, 2>{producerOp.getDim0(),
+                                                 producerOp.getDim1()}) {
+    return std::nullopt;
+  }
+
+  return producerOp.getInput();
+}
+
+// MatmulOp canonicalization
+void mlir::tt::ttir::MatmulOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
+  // matmul(transpose(a), b, transpose_a, transpose_b) -> matmul(a, b,
+  // !transpose_a, transpose_b)
+  patterns.add(+[](ttir::MatmulOp op, mlir::PatternRewriter &rewriter) {
+    auto inputACanonical = getTransposeOpOperand(op.getA());
+    if (!inputACanonical) {
+      return mlir::failure();
+    }
+
+    rewriter.replaceOpWithNewOp<ttir::MatmulOp>(
+        op, op.getType(), *inputACanonical, op.getB(), op.getOutput(),
+        !op.getTransposeA(), op.getTransposeB());
+
+    return mlir::success();
+  });
+
+  // matmul(a, transpose(b), transpose_a, transpose_b) -> matmul(a, b,
+  // transpose_a, !transpose_b)
+  patterns.add(+[](ttir::MatmulOp op, mlir::PatternRewriter &rewriter) {
+    auto inputBCanonical = getTransposeOpOperand(op.getB());
+    if (!inputBCanonical) {
+      return mlir::failure();
+    }
+
+    rewriter.replaceOpWithNewOp<ttir::MatmulOp>(
+        op, op.getType(), op.getA(), *inputBCanonical, op.getOutput(),
+        op.getTransposeA(), !op.getTransposeB());
+
+    return mlir::success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // AllocOp
 //===----------------------------------------------------------------------===//
diff --git a/test/ttmlir/Dialect/TTIR/canonicalize/matmul_op_canonicalize.mlir b/test/ttmlir/Dialect/TTIR/canonicalize/matmul_op_canonicalize.mlir
new file mode 100644
index 000000000..1a480848d
--- /dev/null
+++ b/test/ttmlir/Dialect/TTIR/canonicalize/matmul_op_canonicalize.mlir
@@ -0,0 +1,36 @@
+// RUN: ttmlir-opt -canonicalize %s | FileCheck %s
+module {
+  func.func @matmul_canonicalize_lhs(%arg0: tensor<64x128xbf16>) -> tensor<128x128xbf16> {
+    %0 = tensor.empty() : tensor<128x64xbf16>
+    %1 = "ttir.transpose"(%arg0, %0) <{dim0 = 0 : si32, dim1 = 1 : si32}> : (tensor<64x128xbf16>, tensor<128x64xbf16>) -> tensor<128x64xbf16>
+    %2 = tensor.empty() : tensor<128x128xbf16>
+    // CHECK-NOT: "ttir.transpose"
+    // CHECK: "ttir.matmul"
+    // CHECK-SAME: transpose_a = true
+    // CHECK-SAME: transpose_b = false
+    %3 = "ttir.matmul"(%1, %arg0, %2) : (tensor<128x64xbf16>, tensor<64x128xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16>
+    return %3 : tensor<128x128xbf16>
+  }
+  func.func @matmul_canonicalize_rhs(%arg0: tensor<64x128xbf16>) -> tensor<64x64xbf16> {
+    %0 = tensor.empty() : tensor<128x64xbf16>
+    %1 = "ttir.transpose"(%arg0, %0) <{dim0 = 0 : si32, dim1 = 1 : si32}> : (tensor<64x128xbf16>, tensor<128x64xbf16>) -> tensor<128x64xbf16>
+    %2 = tensor.empty() : tensor<64x64xbf16>
+    // CHECK-NOT: "ttir.transpose"
+    // CHECK: "ttir.matmul"
+    // CHECK-SAME: transpose_a = false
+    // CHECK-SAME: transpose_b = true
%3 = "ttir.matmul"(%arg0, %1, %2) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %3 : tensor<64x64xbf16> + } + func.func @matmul_double_transpose_lhs(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x64xbf16> { + %0 = tensor.empty() : tensor<128x64xbf16> + %1 = "ttir.transpose"(%arg0, %0) <{dim0 = 0 : si32, dim1 = 1 : si32}> : (tensor<64x128xbf16>, tensor<128x64xbf16>) -> tensor<128x64xbf16> + %2 = tensor.empty() : tensor<64x64xbf16> + // CHECK-NOT: "ttir.transpose" + // CHECK: "ttir.matmul" + // CHECK-SAME: transpose_a = false + // CHECK-SAME: transpose_b = false + %3 = "ttir.matmul"(%1, %arg1, %2) <{transpose_a = true}> : (tensor<128x64xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %3 : tensor<64x64xbf16> + } +} From 8ac3bbd7b722de665d37e61b2936171566c66fdc Mon Sep 17 00:00:00 2001 From: Aleksandar Zecevic Date: Fri, 10 Jan 2025 16:15:27 +0000 Subject: [PATCH 8/9] Fusing transpose into linear transpose_input parameter --- include/ttmlir/Dialect/TTIR/IR/TTIROps.td | 2 + lib/Dialect/TTIR/IR/TTIROps.cpp | 78 +++++++++++++------ .../canonicalize/linear_op_canonicalize.mlir | 38 +++++++++ 3 files changed, 96 insertions(+), 22 deletions(-) create mode 100644 test/ttmlir/Dialect/TTIR/canonicalize/linear_op_canonicalize.mlir diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index 6b56a995b..fd74e2556 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -1316,6 +1316,8 @@ def TTIR_LinearOp : TTIR_DPSOp<"linear"> { }]; let hasVerifier = 1; + + let hasCanonicalizer = 1; } // ANCHOR: adding_an_op_matmul_ttir diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index 9c9d2823b..750fe05cb 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -1172,6 +1172,62 @@ ::mlir::LogicalResult mlir::tt::ttir::LinearOp::verify() { return success(); } +// If value is defined by TransposeOp with transpose dimensions +// (rank - 2, rank - 1), return the input of the TransposeOp, otherwise return +// std::nullopt. This is used for canonicalization of MatmulOp and LinearOp. +static std::optional> +getTransposeOpOperand(mlir::TypedValue value) { + auto producerOp = value.getDefiningOp(); + if (!producerOp) { + return std::nullopt; + } + + int64_t rank = value.getType().getRank(); + // TODO (azecevic): Change llvm::SmallSet comparison to direct comparison when + // TransposeOp canonicalization is merged. 
+  if (rank < 2 || llvm::SmallSet<int64_t, 2>{rank - 2, rank - 1} !=
+                      llvm::SmallSet<int64_t, 2>{producerOp.getDim0(),
+                                                 producerOp.getDim1()}) {
+    return std::nullopt;
+  }
+
+  return producerOp.getInput();
+}
+
+// LinearOp canonicalization
+void mlir::tt::ttir::LinearOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
+  // linear(transpose(a), b, bias, transpose_a, transpose_b) -> linear(a, b,
+  // bias, !transpose_a, transpose_b)
+  patterns.add(+[](ttir::LinearOp op, mlir::PatternRewriter &rewriter) {
+    auto inputACanonical = getTransposeOpOperand(op.getA());
+    if (!inputACanonical) {
+      return mlir::failure();
+    }
+
+    rewriter.replaceOpWithNewOp<ttir::LinearOp>(
+        op, op.getType(), *inputACanonical, op.getB(), op.getBias(),
+        op.getOutput(), !op.getTransposeA(), op.getTransposeB());
+
+    return mlir::success();
+  });
+
+  // linear(a, transpose(b), bias, transpose_a, transpose_b) -> linear(a, b,
+  // bias, transpose_a, !transpose_b)
+  patterns.add(+[](ttir::LinearOp op, mlir::PatternRewriter &rewriter) {
+    auto inputBCanonical = getTransposeOpOperand(op.getB());
+    if (!inputBCanonical) {
+      return mlir::failure();
+    }
+
+    rewriter.replaceOpWithNewOp<ttir::LinearOp>(
+        op, op.getType(), op.getA(), *inputBCanonical, op.getBias(),
+        op.getOutput(), op.getTransposeA(), !op.getTransposeB());
+
+    return mlir::success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // MatmulOp
 //===----------------------------------------------------------------------===//
@@ -1303,28 +1359,6 @@ ::mlir::LogicalResult mlir::tt::ttir::MatmulOp::verify() {
 }
 // ANCHOR_END: adding_an_op_matmul_ttir_verify

-// If value is defined by TransposeOp with transpose dimensions
-// (rank - 2, rank - 1), return the input of the TransposeOp, otherwise return
-// std::nullopt. This is used for canonicalization of MatmulOp and LinearOp.
-static std::optional<mlir::TypedValue<mlir::RankedTensorType>>
-getTransposeOpOperand(mlir::TypedValue<mlir::RankedTensorType> value) {
-  auto producerOp = value.getDefiningOp<ttir::TransposeOp>();
-  if (!producerOp) {
-    return std::nullopt;
-  }
-
-  int64_t rank = value.getType().getRank();
-  // TODO (azecevic): Change llvm::SmallSet comparison to direct comparison when
-  // TransposeOp canonicalization is merged.
-  if (rank < 2 || llvm::SmallSet<int64_t, 2>{rank - 2, rank - 1} !=
-                      llvm::SmallSet<int64_t, 2>{producerOp.getDim0(),
-                                                 producerOp.getDim1()}) {
-    return std::nullopt;
-  }
-
-  return producerOp.getInput();
-}
-
 // MatmulOp canonicalization
 void mlir::tt::ttir::MatmulOp::getCanonicalizationPatterns(
     mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
diff --git a/test/ttmlir/Dialect/TTIR/canonicalize/linear_op_canonicalize.mlir b/test/ttmlir/Dialect/TTIR/canonicalize/linear_op_canonicalize.mlir
new file mode 100644
index 000000000..da0ec8952
--- /dev/null
+++ b/test/ttmlir/Dialect/TTIR/canonicalize/linear_op_canonicalize.mlir
@@ -0,0 +1,38 @@
+// RUN: ttmlir-opt -canonicalize %s | FileCheck %s
+module {
+  func.func @linear_canonicalize_lhs(%arg0: tensor<64x128xbf16>, %bias: tensor<128x128xbf16>) -> tensor<128x128xbf16> {
+    %0 = tensor.empty() : tensor<128x64xbf16>
+    %1 = "ttir.transpose"(%arg0, %0) <{dim0 = 0 : si32, dim1 = 1 : si32}> : (tensor<64x128xbf16>, tensor<128x64xbf16>) -> tensor<128x64xbf16>
+    %2 = tensor.empty() : tensor<128x128xbf16>
+    // CHECK-NOT: "ttir.transpose"
+    // CHECK: "ttir.linear"
+    // CHECK-SAME: transpose_a = true
+    // CHECK-SAME: transpose_b = false
+    %3 = "ttir.linear"(%1, %arg0, %bias, %2) : (tensor<128x64xbf16>, tensor<64x128xbf16>, tensor<128x128xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16>
+    return %3 : tensor<128x128xbf16>
+  }
+
+  func.func @linear_canonicalize_rhs(%arg0: tensor<64x128xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> {
+    %0 = tensor.empty() : tensor<128x64xbf16>
+    %1 = "ttir.transpose"(%arg0, %0) <{dim0 = 0 : si32, dim1 = 1 : si32}> : (tensor<64x128xbf16>, tensor<128x64xbf16>) -> tensor<128x64xbf16>
+    %2 = tensor.empty() : tensor<64x64xbf16>
+    // CHECK-NOT: "ttir.transpose"
+    // CHECK: "ttir.linear"
+    // CHECK-SAME: transpose_a = false
+    // CHECK-SAME: transpose_b = true
+    %3 = "ttir.linear"(%arg0, %1, %bias, %2) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
+    return %3 : tensor<64x64xbf16>
+  }
+
+  func.func @linear_double_transpose_lhs(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> {
+    %0 = tensor.empty() : tensor<128x64xbf16>
+    %1 = "ttir.transpose"(%arg0, %0) <{dim0 = 0 : si32, dim1 = 1 : si32}> : (tensor<64x128xbf16>, tensor<128x64xbf16>) -> tensor<128x64xbf16>
+    %2 = tensor.empty() : tensor<64x64xbf16>
+    // CHECK-NOT: "ttir.transpose"
+    // CHECK: "ttir.linear"
+    // CHECK-SAME: transpose_a = false
+    // CHECK-SAME: transpose_b = false
+    %3 = "ttir.linear"(%1, %arg1, %bias, %2) <{transpose_a = true}> : (tensor<128x64xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
+    return %3 : tensor<64x64xbf16>
+  }
+}

From d6937d96e87ac16f936dbcf00462482081b3b179 Mon Sep 17 00:00:00 2001
From: Aleksandar Zecevic
Date: Fri, 10 Jan 2025 16:29:34 +0000
Subject: [PATCH 9/9] LinearOp silicon tests with transpose

---
 .../Silicon/TTNN/matmul/simple_matmul.mlir  |  4 +--
 test/ttmlir/Silicon/TTNN/simple_linear.mlir | 31 +++++++++++++++++++
 2 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/test/ttmlir/Silicon/TTNN/matmul/simple_matmul.mlir b/test/ttmlir/Silicon/TTNN/matmul/simple_matmul.mlir
index e76b0b3d2..28942a48d 100644
--- a/test/ttmlir/Silicon/TTNN/matmul/simple_matmul.mlir
+++ b/test/ttmlir/Silicon/TTNN/matmul/simple_matmul.mlir
@@ -9,14 +9,14 @@ module {
     return %1 : tensor<64x96xbf16>
   }

-  func.func @matmul_transpose_first(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<128x128xbf16> {
+  func.func @matmul_transpose_lhs(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<128x128xbf16> {
     %0 = tensor.empty() : tensor<128x128xbf16>
     // CHECK: "ttnn.matmul"
     %1 = "ttir.matmul"(%arg0, %arg1, %0) <{transpose_a = true}>: (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16>
     return %1 : tensor<128x128xbf16>
   }

-  func.func @matmul_transpose_second(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x64xbf16> {
+  func.func @matmul_transpose_rhs(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x64xbf16> {
     %0 = tensor.empty() : tensor<64x64xbf16>
     // CHECK: "ttnn.matmul"
     %1 = "ttir.matmul"(%arg0, %arg1, %0) <{transpose_b = true}>: (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
     return %1 : tensor<64x64xbf16>
   }
 }
diff --git a/test/ttmlir/Silicon/TTNN/simple_linear.mlir b/test/ttmlir/Silicon/TTNN/simple_linear.mlir
index b65bf99db..3385432d9 100644
--- a/test/ttmlir/Silicon/TTNN/simple_linear.mlir
+++ b/test/ttmlir/Silicon/TTNN/simple_linear.mlir
@@ -29,4 +29,35 @@ module {
     %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
     return %1 : tensor<64x64xbf16>
   }
+
+  func.func @linear_transpose_lhs(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>, %bias: tensor<128x128xbf16>) -> tensor<128x128xbf16> {
+    // CHECK: "ttnn.empty"
+    // CHECK-SAME: tensor<128x128xbf16
+    %0 = tensor.empty() : tensor<128x128xbf16>
+    // CHECK: "ttnn.linear"
+    // CHECK-SAME: transpose_a = true
+    // CHECK-SAME: transpose_b = false
+    // CHECK-SAME: tensor<64x128xbf16
+    // CHECK-SAME: tensor<64x128xbf16
+    // CHECK-SAME: tensor<128x128xbf16
+    // CHECK-SAME: tensor<128x128xbf16
+    %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{transpose_a = true}>: (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<128x128xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16>
+    return %1 : tensor<128x128xbf16>
+  }
+
+  func.func @linear_transpose_rhs(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> {
+    // CHECK: "ttnn.empty"
+    // CHECK-SAME: tensor<64x64xbf16
+    %0 = tensor.empty() : tensor<64x64xbf16>
+    // CHECK: "ttnn.linear"
+    // CHECK-SAME: transpose_a = false
+    // CHECK-SAME: transpose_b = true
+    // CHECK-SAME: tensor<64x128xbf16
+    // CHECK-SAME: tensor<64x128xbf16
+    // CHECK-SAME: tensor<64x64xbf16
+    // CHECK-SAME: tensor<64x64xbf16
+    %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{transpose_b = true}>: (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
+    return %1 : tensor<64x64xbf16>
+  }
 }