From b595a453a23e088fcadf09f7ebbc951d84da8731 Mon Sep 17 00:00:00 2001
From: Stefan Djordjevic
Date: Thu, 12 Dec 2024 13:49:37 +0000
Subject: [PATCH] Adding support for data type workarounds

---
 include/ttmlir/Dialect/TTNN/IR/TTNNOps.td     |  3 ++
 .../ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h  | 40 ++++++++++++----
 .../Dialect/TTNN/Utils/TransformUtils.h       |  3 +-
 include/ttmlir/Dialect/TTNN/Utils/Utils.h     |  7 ++-
 lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp      |  1 -
 lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp       | 26 ++++++++++
 .../TTNN/Transforms/TTNNWorkarounds.cpp       | 39 ++++++++++-----
 lib/Dialect/TTNN/Utils/TransformUtils.cpp     | 17 ++++---
 lib/Dialect/TTNN/Utils/Utils.cpp              | 10 +++-
 .../Workarounds/embedding_workaround.mlir     | 47 +++++++++++++++++++
 10 files changed, 164 insertions(+), 29 deletions(-)
 create mode 100644 test/ttmlir/Dialect/TTNN/Transforms/Workarounds/embedding_workaround.mlir

diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index 7567364d1..da76804b6 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -618,6 +618,9 @@ def TTNN_EmbeddingOp : TTNN_NamedDPSOp<"embedding"> {
     let extraClassDeclaration = [{
       MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
+      wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
+        return wa::TTNNOperandsWorkaroundsFactory::createEmbeddingOpOperandsWorkarounds();
+      }
     }];

     let hasVerifier = 1;
diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h b/include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h
index 9e07a0315..f102c0242 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h
@@ -17,6 +17,7 @@ namespace mlir::tt::ttnn::wa {
 using TensorLayoutWorkaround = std::optional<Layout>;
 using TensorBufferTypeWorkaround = std::optional<BufferType>;
 using TensorMemoryLayoutWorkaround = std::optional<TensorMemoryLayout>;
+using TensorDataTypeWorkaround = std::optional<DataType>;

 // Struct that encapsulates operand workarounds.
 // It contains tensor layout, tensor buffer type and tensor memory layout
@@ -31,35 +32,47 @@ struct TTNNOperandWorkarounds {
   // Tensor memory layout workaround.
   TensorMemoryLayoutWorkaround tensorMemoryLayoutWorkaround;

+  // Tensor data type workaround.
+  TensorDataTypeWorkaround tensorDataTypeWorkaround;
+
+  // Default constructor.
   TTNNOperandWorkarounds() = default;

   // Constructor that takes tensor layout, tensor buffer type and tensor memory.
   TTNNOperandWorkarounds(
       TensorLayoutWorkaround tensorLayoutWorkaround,
       TensorBufferTypeWorkaround tensorBufferTypeWorkaround,
-      TensorMemoryLayoutWorkaround tensorMemoryLayoutWorkaround)
+      TensorMemoryLayoutWorkaround tensorMemoryLayoutWorkaround,
+      TensorDataTypeWorkaround tensorDataTypeWorkaround)
       : tensorLayoutWorkaround(tensorLayoutWorkaround),
         tensorBufferTypeWorkaround(tensorBufferTypeWorkaround),
-        tensorMemoryLayoutWorkaround(tensorMemoryLayoutWorkaround) {}
+        tensorMemoryLayoutWorkaround(tensorMemoryLayoutWorkaround),
+        tensorDataTypeWorkaround(tensorDataTypeWorkaround) {}

   // Constructor that takes tensor layout workaround and sets the other
   // workarounds to nullopt.
   TTNNOperandWorkarounds(TensorLayoutWorkaround tensorLayoutWorkaround)
       : TTNNOperandWorkarounds(tensorLayoutWorkaround, std::nullopt,
-                               std::nullopt) {}
+                               std::nullopt, std::nullopt) {}

   // Constructor that takes tensor buffer type workaround and sets the other
   // workarounds to nullopt.
   TTNNOperandWorkarounds(TensorBufferTypeWorkaround tensorBufferTypeWorkaround)
       : TTNNOperandWorkarounds(std::nullopt, tensorBufferTypeWorkaround,
-                               std::nullopt) {}
+                               std::nullopt, std::nullopt) {}

   // Constructor that takes tensor memory layout workaround and sets the other
   // workarounds to nullopt.
   TTNNOperandWorkarounds(
       TensorMemoryLayoutWorkaround tensorMemoryLayoutWorkaround)
       : TTNNOperandWorkarounds(std::nullopt, std::nullopt,
-                               tensorMemoryLayoutWorkaround) {}
+                               tensorMemoryLayoutWorkaround, std::nullopt) {}
+
+  // Constructor that takes tensor data type workaround and sets the other
+  // workarounds to nullopt.
+  TTNNOperandWorkarounds(TensorDataTypeWorkaround tensorDataTypeWorkaround)
+      : TTNNOperandWorkarounds(std::nullopt, std::nullopt, std::nullopt,
+                               tensorDataTypeWorkaround) {}

   // Operand workarounds factory methods.
   static TTNNOperandWorkarounds createEmptyTTNNOperandWorkarounds();
@@ -68,7 +81,8 @@ struct TTNNOperandWorkarounds {
   bool operator==(const TTNNOperandWorkarounds &rhs) const {
     return tensorLayoutWorkaround == rhs.tensorLayoutWorkaround &&
            tensorBufferTypeWorkaround == rhs.tensorBufferTypeWorkaround &&
-           tensorMemoryLayoutWorkaround == rhs.tensorMemoryLayoutWorkaround;
+           tensorMemoryLayoutWorkaround == rhs.tensorMemoryLayoutWorkaround &&
+           tensorDataTypeWorkaround == rhs.tensorDataTypeWorkaround;
   }

   // Inequality operator.
@@ -79,7 +93,7 @@ struct TTNNOperandWorkarounds {
   // Returns true if any of the workarounds is set.
   bool hasAnyWorkaround() const {
     return tensorLayoutWorkaround || tensorBufferTypeWorkaround ||
-           tensorMemoryLayoutWorkaround;
+           tensorMemoryLayoutWorkaround || tensorDataTypeWorkaround;
   }
 };

@@ -103,6 +117,9 @@ struct BufferTypeWorkaroundResult : public WorkaroundResult<BufferType> {};
 struct MemoryLayoutWorkaroundResult
     : public WorkaroundResult<std::optional<TensorMemoryLayout>> {};

+// Data type workaround result struct.
+struct DataTypeWorkaroundResult : public WorkaroundResult<DataType> {};
+
 // Struct that encapsulates the result of applying the workarounds.
 // It contains the target tensor layout, buffer type and tensor memory layout
 // results and a flag indicating whether the workarounds were applied.
@@ -116,11 +133,15 @@ struct WorkaroundResults {
   // Tensor memory layout workaround result.
   MemoryLayoutWorkaroundResult tensorMemoryLayoutResult;

+  // Tensor data type workaround result.
+  DataTypeWorkaroundResult tensorDataTypeResult;
+
   // Returns true if any of the workarounds were applied.
   bool isModified() const {
     return tensorLayoutResult.isModified() ||
            tensorBufferTypeResult.isModified() ||
-           tensorMemoryLayoutResult.isModified();
+           tensorMemoryLayoutResult.isModified() ||
+           tensorDataTypeResult.isModified();
   }
 };

@@ -194,6 +215,9 @@ class TTNNOperandsWorkaroundsFactory {
 public:
   // Create workarounds for max_pool2d op operands.
   static TTNNOperandsWorkarounds createMaxPool2DOpOperandsWorkarounds();
+
+  // Create workarounds for embedding op operands.
+  static TTNNOperandsWorkarounds createEmbeddingOpOperandsWorkarounds();
 };

 } // namespace mlir::tt::ttnn::wa
diff --git a/include/ttmlir/Dialect/TTNN/Utils/TransformUtils.h b/include/ttmlir/Dialect/TTNN/Utils/TransformUtils.h
index f491f2ed5..b3df649ff 100644
--- a/include/ttmlir/Dialect/TTNN/Utils/TransformUtils.h
+++ b/include/ttmlir/Dialect/TTNN/Utils/TransformUtils.h
@@ -23,7 +23,8 @@ createToLayoutOp(mlir::Operation *op, mlir::TypedValue<RankedTensorType> inputValue,
                  PatternRewriter &rewriter, Layout targetTensorLayout,
                  BufferType targetTensorBufferType,
-                 std::optional<TensorMemoryLayout> targetTensorMemoryLayout);
+                 std::optional<TensorMemoryLayout> targetTensorMemoryLayout,
+                 DataType targetTensorDataType);
 } // namespace mlir::tt::ttnn::utils

 #endif
diff --git a/include/ttmlir/Dialect/TTNN/Utils/Utils.h b/include/ttmlir/Dialect/TTNN/Utils/Utils.h
index 71dc98b7f..e4b854263 100644
--- a/include/ttmlir/Dialect/TTNN/Utils/Utils.h
+++ b/include/ttmlir/Dialect/TTNN/Utils/Utils.h
@@ -39,11 +39,16 @@ toTTMemorySpace(const mlir::tt::ttnn::BufferType bufferType);
 mlir::Type createRowMajorTypeFromDtype(::mlir::MLIRContext *context,
                                        DataType dtype);

-// Helper method to create a RankedTensorType with the given encoding
+// Helper method to create a RankedTensorType with the given encoding.
 RankedTensorType
 createRankedTensorTypeWithEncoding(RankedTensorType tensorType,
                                    ttnn::TTNNLayoutAttr encoding);

+// Helper method to create a RankedTensorType with the given element type.
+RankedTensorType
+createRankedTensorTypeWithElementType(RankedTensorType tensorType,
+                                      Type elementType);
+
 // Return the L1 memory usage of the output tensor of the given op.
 // Used within L1 interleaved policies.
 //
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index a2b63a1bc..7ad62397c 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -250,7 +250,6 @@ class ToLayoutOpConversionPattern
     // operands.
     for (mlir::Operation *user : op.getResult().getUsers()) {
       if (isa(user) || isa(user) ||
-          isa(user) ||
           (isa(user) &&
            (user->getOperand(0) == op || user->getOperand(1) == op))) {
         return true;
diff --git a/lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp b/lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp
index 848d80a3e..8a9786ca6 100644
--- a/lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp
+++ b/lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp
@@ -54,6 +54,11 @@ WorkaroundResults applyWorkarounds(const TTNNOperandWorkarounds &workaround,
   results.tensorMemoryLayoutResult.previousValue =
       inputLayoutAttr.getMemLayoutOpt();

+  results.tensorDataTypeResult.targetValue =
+      workaround.tensorDataTypeWorkaround.value_or(
+          inputLayoutAttr.getDataType());
+  results.tensorDataTypeResult.previousValue = inputLayoutAttr.getDataType();
+
   return results;
 }

@@ -87,4 +92,25 @@ TTNNOperandsWorkaroundsFactory::createMaxPool2DOpOperandsWorkarounds() {
       .addInputOperandWorkaround(rowMajorLayoutWorkaround)
       .addOutputOperandWorkaround(rowMajorLayoutWorkaround);
 }
+
+// Factory method to create a set of workarounds for the embedding op
+// operands. The embedding op expects the input operand to be in row-major
+// layout and the weight operand to use the bf16 data type.
+// Since the output of the embedding op follows the same format as the weight
+// operand, the same workaround is applied to the output operand.
+// Metal issue for the input operand workaround:
+// https://github.com/tenstorrent/tt-metal/issues/14915
+// Metal issue for the weight operand workaround: to be added.
+TTNNOperandsWorkarounds
+TTNNOperandsWorkaroundsFactory::createEmbeddingOpOperandsWorkarounds() {
+  // Create input and weight workarounds.
+  TTNNOperandWorkarounds inputWorkaround =
+      TTNNOperandWorkarounds(Layout::RowMajor);
+  TTNNOperandWorkarounds weightWorkaround =
+      TTNNOperandWorkarounds(DataType::BFloat16);
+  return TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds(0, 0)
+      .addInputOperandWorkaround(inputWorkaround)
+      .addInputOperandWorkaround(weightWorkaround)
+      .addInputOperandWorkaround(weightWorkaround)
+      .addOutputOperandWorkaround(weightWorkaround);
+}
 } // namespace mlir::tt::ttnn::wa
diff --git a/lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp b/lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp
index 2c0c48dbc..6e21df7a1 100644
--- a/lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp
+++ b/lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp
@@ -23,9 +23,9 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/raw_ostream.h"

 #include
 #include

@@ -57,7 +57,8 @@ static void revertOutputLayout(wa::TTNNWorkaroundInterface &op,
         op.getOperation(), newOpResult, rewriter,
         workaroundResults.tensorLayoutResult.previousValue,
         workaroundResults.tensorBufferTypeResult.previousValue,
-        workaroundResults.tensorMemoryLayoutResult.previousValue);
+        workaroundResults.tensorMemoryLayoutResult.previousValue,
+        workaroundResults.tensorDataTypeResult.previousValue);

   // Replace the new output result with the casted output result.
   rewriter.replaceUsesWithIf(
@@ -93,7 +94,8 @@ static bool workaroundInputOperand(
         op.getOperation(), inputValue, rewriter,
         inputWorkaroundResults.tensorLayoutResult.targetValue,
         inputWorkaroundResults.tensorBufferTypeResult.targetValue,
-        inputWorkaroundResults.tensorMemoryLayoutResult.targetValue);
+        inputWorkaroundResults.tensorMemoryLayoutResult.targetValue,
+        inputWorkaroundResults.tensorDataTypeResult.targetValue);

   // Insert to layout op between the current op and the input operand
   // to convert the input operand to the desired tensor layout, buffer type.
@@ -136,7 +138,7 @@ workaroundOutputOperand(mlir::TypedValue<RankedTensorType> opResult,
   Type elementType = utils::getElementType(
       rewriter.getContext(),
       outputWorkaroundResults.tensorLayoutResult.targetValue,
-      opResultLayoutAttr.getDataType());
+      outputWorkaroundResults.tensorDataTypeResult.targetValue);

   // Get the input operand type.
   RankedTensorType opResultType =
@@ -150,16 +152,24 @@ workaroundOutputOperand(mlir::TypedValue<RankedTensorType> opResult,
               *outputWorkaroundResults.tensorMemoryLayoutResult.targetValue)
         : nullptr;

-  // Create the new output result type with the updated tensor layout, buffer
-  // type and memory layout.
+  // Create the new output layout attribute with the updated tensor layout,
+  // buffer type, memory layout and data type.
+  TTNNLayoutAttr newOutputLayoutAttr =
+      opResultLayoutAttr.withElementType(rewriter.getContext(), elementType)
+          .withBufferType(
+              rewriter.getContext(),
+              outputWorkaroundResults.tensorBufferTypeResult.targetValue)
+          .withMemoryLayout(rewriter.getContext(), outputMemLayoutAttr);
+
+  // Create the new output result type with the updated data type and layout.
   RankedTensorType newOutputResultType =
       ttnn::utils::createRankedTensorTypeWithEncoding(
-          opResultType,
-          opResultLayoutAttr.withElementType(rewriter.getContext(), elementType)
-              .withBufferType(
+          ttnn::utils::createRankedTensorTypeWithElementType(
+              opResultType,
+              ttnn::utils::createRowMajorTypeFromDtype(
                   rewriter.getContext(),
-                  outputWorkaroundResults.tensorBufferTypeResult.targetValue)
-              .withMemoryLayout(rewriter.getContext(), outputMemLayoutAttr));
+                  outputWorkaroundResults.tensorDataTypeResult.targetValue)),
+          newOutputLayoutAttr);

   // Update the type of result with applied workarounds.
   rewriter.modifyOpInPlace(op, [&]() {
@@ -175,6 +185,13 @@ workaroundOutputOperand(mlir::TypedValue<RankedTensorType> opResult,
     op->setAttr("layout", updatedLayoutAttr);
   }

+  if (outputWorkaroundResults.tensorDataTypeResult.isModified() &&
+      op->getAttrDictionary().get("dtype")) {
+    DataTypeAttr updatedDataTypeAttr = rewriter.getAttr<DataTypeAttr>(
+        outputWorkaroundResults.tensorDataTypeResult.targetValue);
+    op->setAttr("dtype", updatedDataTypeAttr);
+  }
+
   if ((outputWorkaroundResults.tensorBufferTypeResult.isModified() ||
        outputWorkaroundResults.tensorMemoryLayoutResult.isModified()) &&
       op->getAttrDictionary().get("memory_config")) {
diff --git a/lib/Dialect/TTNN/Utils/TransformUtils.cpp b/lib/Dialect/TTNN/Utils/TransformUtils.cpp
index ed4b318ec..11f0f68b1 100644
--- a/lib/Dialect/TTNN/Utils/TransformUtils.cpp
+++ b/lib/Dialect/TTNN/Utils/TransformUtils.cpp
@@ -39,12 +39,13 @@ ToLayoutOp
 createToLayoutOp(Operation *op, mlir::TypedValue<RankedTensorType> inputValue,
                  PatternRewriter &rewriter, Layout targetTensorLayout,
                  BufferType targetTensorBufferType,
-                 std::optional<TensorMemoryLayout> targetTensorMemoryLayout) {
+                 std::optional<TensorMemoryLayout> targetTensorMemoryLayout,
+                 DataType targetTensorDataType) {
   TTNNLayoutAttr inputLayoutAttr = getLayoutAttrFromTensor(inputValue);

   // Create element type based on tensor layout.
   Type elementType = getElementType(rewriter.getContext(), targetTensorLayout,
-                                    inputLayoutAttr.getDataType());
+                                    targetTensorDataType);

   // Create tensor memory layout attribute.
   ttnn::TensorMemoryLayoutAttr outputMemLayoutAttr =
@@ -63,10 +64,14 @@ createToLayoutOp(Operation *op, mlir::TypedValue<RankedTensorType> inputValue,
           .withBufferType(rewriter.getContext(), targetTensorBufferType)
           .withMemoryLayout(rewriter.getContext(), outputMemLayoutAttr);

-  // Create the output result type with the new encoding.
+  // Create the output result type with the new data type and encoding.
   RankedTensorType toLayoutOpResultType =
-      ttnn::utils::createRankedTensorTypeWithEncoding(inputToLayoutOpType,
-                                                      toLayoutOpResultEncoding);
+      ttnn::utils::createRankedTensorTypeWithEncoding(
+          ttnn::utils::createRankedTensorTypeWithElementType(
+              inputToLayoutOpType,
+              utils::createRowMajorTypeFromDtype(rewriter.getContext(),
+                                                 targetTensorDataType)),
+          toLayoutOpResultEncoding);

   // Create the output memory config attribute.
   ttnn::MemoryConfigAttr outputMemConfigAttr = ttnn::MemoryConfigAttr::get(
@@ -88,7 +93,7 @@ createToLayoutOp(Operation *op, mlir::TypedValue<RankedTensorType> inputValue,
   return rewriter.create<ttnn::ToLayoutOp>(
       op->getLoc(), toLayoutOpResultType, inputValue,
       LayoutAttr::get(rewriter.getContext(), targetTensorLayout),
-      DataTypeAttr::get(rewriter.getContext(), inputLayoutAttr.getDataType()),
+      DataTypeAttr::get(rewriter.getContext(), targetTensorDataType),
       outputMemConfigAttr, deviceValue);
 }
 } // namespace mlir::tt::ttnn::utils
diff --git a/lib/Dialect/TTNN/Utils/Utils.cpp b/lib/Dialect/TTNN/Utils/Utils.cpp
index 90091b1ff..cb4646ffa 100644
--- a/lib/Dialect/TTNN/Utils/Utils.cpp
+++ b/lib/Dialect/TTNN/Utils/Utils.cpp
@@ -112,7 +112,7 @@ Type createRowMajorTypeFromDtype(::mlir::MLIRContext *context, DataType dtype) {
   }
 }

-// Helper method to create a RankedTensorType with the given encoding
+// Helper method to create a RankedTensorType with the given encoding.
 RankedTensorType
 createRankedTensorTypeWithEncoding(RankedTensorType tensorType,
                                    ttnn::TTNNLayoutAttr encoding) {
@@ -120,6 +120,14 @@ createRankedTensorTypeWithEncoding(RankedTensorType tensorType,
                                tensorType.getElementType(), encoding);
 }

+// Helper method to create a RankedTensorType with the given element type.
+RankedTensorType
+createRankedTensorTypeWithElementType(RankedTensorType tensorType,
+                                      Type elementType) {
+  return RankedTensorType::get(tensorType.getShape(), elementType,
+                               tensorType.getEncoding());
+}
+
 uint64_t getOpOutputL1Usage(TTNNLayoutAttr opLayout) {
   // In case the opLayout is not in L1 memory space, L1 memory usage is 0.
   //
diff --git a/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/embedding_workaround.mlir b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/embedding_workaround.mlir
new file mode 100644
index 000000000..1aa9f55ac
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/embedding_workaround.mlir
@@ -0,0 +1,47 @@
+// RUN: ttmlir-opt --ttnn-workaround --canonicalize %s | FileCheck %s
+#device = #tt.device (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]>
+#dram = #ttnn.buffer_type<dram>
+#system_desc = #tt.system_desc<[{role = host, target_triple = "x86_64-pc-linux-gnu"}], [{arch = , grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 1024, erisc_l1_unreserved_base = 1024, dram_unreserved_base = 1024, dram_unreserved_end = 1073741824, physical_cores = {worker = [ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 1x0, 1x1, 1x2, 1x3, 1x4, 1x5, 1x6, 1x7, 2x0, 2x1, 2x2, 2x3, 2x4, 2x5, 2x6, 2x7, 3x0, 3x1, 3x2, 3x3, 3x4, 3x5, 3x6, 3x7, 4x0, 4x1, 4x2, 4x3, 4x4, 4x5, 4x6, 4x7, 5x0, 5x1, 5x2, 5x3, 5x4, 5x5, 5x6, 5x7, 6x0, 6x1, 6x2, 6x3, 6x4, 6x5, 6x6, 6x7, 7x0, 7x1, 7x2, 7x3, 7x4, 7x5, 7x6, 7x7] dram = [ 8x0, 9x0, 10x0, 8x1, 9x1, 10x1, 8x2, 9x2, 10x2, 8x3, 9x3, 10x3]}, supported_data_types = [, , , , , , , , , , , ], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 32}], [0], [3 : i32], [ 0x0x0x0]>
+#system_memory = #ttnn.buffer_type<system_memory>
+#ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<32x32xf32, #system_memory>>
+#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<512x128xf32, #system_memory>>
+#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<1024x128xf32, #system_memory>>
+#ttnn_layout3 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<1x1x!tt.tile<32x32, f32>, #dram>, >
+#ttnn_layout4 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<16x4x!tt.tile<32x32, f32>, #dram>, >
+#ttnn_layout5 = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<32x4x!tt.tile<32x32, f32>, #dram>, >
+module attributes {tt.device = #device, tt.system_desc = #system_desc} {
+  func.func @forward(%arg0: tensor<32x32xf32, #ttnn_layout>, %arg1: tensor<512x128xf32, #ttnn_layout1>) -> tensor<32x32x128xf32, #ttnn_layout2> {
+    %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device>
+    // CHECK: %[[DEVICE_OP:.*]] = "ttnn.get_device"
+    %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<1x1>>, >}> : (tensor<32x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xf32, #ttnn_layout3>
+    %2 = "ttnn.to_layout"(%arg1, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<16x4>>, >}> : (tensor<512x128xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<512x128xf32, #ttnn_layout4>
+    %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<32x4>>, >, shape = #ttnn.shape<32x32x128>}> : (!tt.device<#device>) -> tensor<32x32x128xf32, #ttnn_layout5>
+    // CHECK: %[[EMPTY_OP:.*]] = "ttnn.empty"(%[[DEVICE_OP]])
+    // Check that the input operand is transformed into the row-major layout.
+    // CHECK-NEXT: %[[TO_LAYOUT_INPUT:.*]] = "ttnn.to_layout"
+    // CHECK-SAME: layout = #ttnn.layout
+    // CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<32x32>>, >
+    // CHECK-SAME: -> tensor<32x32xf32
+    // Check that the data type of the weight operand is transformed into bf16.
+    // CHECK-NEXT: %[[TO_LAYOUT_WEIGHTS:.*]] = "ttnn.to_layout"
+    // CHECK-SAME: dtype = #tt.supportedDataTypes
+    // CHECK-SAME: layout = #ttnn.layout
+    // CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<16x4>>, >
+    // CHECK-SAME: -> tensor<512x128xbf16
+    // Check that the data type of the output operand is transformed into bf16.
+    // CHECK-NEXT: %[[TO_LAYOUT_OUTPUT_DPS:.*]] = "ttnn.to_layout"(%[[EMPTY_OP]], %[[DEVICE_OP]])
+    // CHECK-SAME: dtype = #tt.supportedDataTypes
+    // CHECK-SAME: layout = #ttnn.layout
+    // CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<32x4>>, >
+    // CHECK-SAME: -> tensor<32x32x128xbf16
+    %4 = "ttnn.embedding"(%1, %2, %3) : (tensor<32x32xf32, #ttnn_layout3>, tensor<512x128xf32, #ttnn_layout4>, tensor<32x32x128xf32, #ttnn_layout5>) -> tensor<32x32x128xf32, #ttnn_layout5>
+    // CHECK-NEXT: %[[EMBEDDING_OP:.*]] = "ttnn.embedding"(%[[TO_LAYOUT_INPUT]], %[[TO_LAYOUT_WEIGHTS]], %[[TO_LAYOUT_OUTPUT_DPS]])
+    // Check that the output operand is transformed back into the f32 data type.
+ // CHECK-NEXT: "ttnn.to_layout"(%[[EMBEDDING_OP]]) + // CHECK-SAME: dtype = #tt.supportedDataTypes + // CHECK-SAME: layout = #ttnn.layout + // CHECK-SAME: memory_config = #ttnn.memory_config<#system_memory, <<1024x128>>> + %5 = "ttnn.to_layout"(%4) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#system_memory, <<1024x128>>>}> : (tensor<32x32x128xf32, #ttnn_layout5>) -> tensor<32x32x128xf32, #ttnn_layout2> + return %5 : tensor<32x32x128xf32, #ttnn_layout2> + } +}