From 4639695f87d0adaf7c44de2e28e610656a951d3c Mon Sep 17 00:00:00 2001
From: Nikola Obradovic
Date: Thu, 31 Oct 2024 13:46:14 +0000
Subject: [PATCH] [Optimizer] Fix mem reconfig/reshard.

---
 .../ttmlir/Dialect/TT/Utils/OverrideParams.h  |   2 +-
 .../Dialect/TTNN/Pipelines/TTNNPipelines.h    |  10 +-
 .../Dialect/TTNN/Transforms/Optimizer.h       |   4 +-
 .../TTNN/Analysis/DFShardingPolicy.cpp        |   5 +-
 lib/Dialect/TTNN/Transforms/Optimizer.cpp     | 179 ++++++++----------
 ...le_add_with_loc_input_layout_override.mlir |  10 +-
 .../TTNN/test_override_reshard_edges.mlir     |  32 ++--
 7 files changed, 111 insertions(+), 131 deletions(-)

diff --git a/include/ttmlir/Dialect/TT/Utils/OverrideParams.h b/include/ttmlir/Dialect/TT/Utils/OverrideParams.h
index b80f73940..ed7967c07 100644
--- a/include/ttmlir/Dialect/TT/Utils/OverrideParams.h
+++ b/include/ttmlir/Dialect/TT/Utils/OverrideParams.h
@@ -63,7 +63,7 @@ struct InputLayoutOverrideParser
   static void print(llvm::raw_ostream &os,
                     const llvm::StringMap &value) {
-    os << "insert-reshard=";
+    os << "insert-memreconfig=";
     size_t count = 0;
     for (const auto &entry : value) {
       os << entry.getKey() << "=";
diff --git a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
index 4eb64291a..9988bbcc1 100644
--- a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
+++ b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
@@ -31,17 +31,17 @@ struct TTIRToTTNNBackendPipelineOptions
   //
   // Full Example: "op1=0,op2=0:1"
   //
-  // This will insert one TTIR_ToLayoutOps responsible for resharding the op1's
-  // first operand and two TTIR_ToLayoutOps responsible for resharding the op2's
-  // first and second operand.
+  // This will insert one memory reconfig op responsible for resharding the
+  // op1's first operand and two memory reconfig ops responsible for resharding
+  // the op2's first and second operand.
   //
   // Note: This option is only valid if optimizerPassEnabled is true.
   //
   Option, InputLayoutOverrideParser>
       overrideInputLayout{
          *this, "insert-memreconfig",
          llvm::cl::desc(
              "Manually insert memory reconfig op for specific op's operand."),
          llvm::cl::init(llvm::StringMap())};

  // Option to override output layout for specific ops.
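As a quick illustration of the override grammar documented above, the following is a standalone sketch only (the pass itself parses this with InputLayoutOverrideParser, and op1/op2 are placeholder op names): a string such as "op1=0,op2=0:1" maps each named op to the operand indices that should receive a memory reconfig op.

// Standalone sketch of the documented "insert-memreconfig" override grammar:
//   <op-name>=<operand-idx>[:<operand-idx>...][,<op-name>=...]
// Illustration only; not the project's InputLayoutOverrideParser.
#include <cstdint>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

static std::map<std::string, std::vector<int64_t>>
parseMemReconfigOverrides(const std::string &spec) {
  std::map<std::string, std::vector<int64_t>> overrides;
  std::stringstream specStream(spec);
  std::string entry;
  // Entries are comma-separated: "op1=0,op2=0:1".
  while (std::getline(specStream, entry, ',')) {
    size_t eq = entry.find('=');
    if (eq == std::string::npos)
      continue; // Malformed entry; the real parser reports an error instead.
    std::string opName = entry.substr(0, eq);
    std::stringstream idxStream(entry.substr(eq + 1));
    std::string idx;
    // Operand indices are colon-separated: "0:1".
    while (std::getline(idxStream, idx, ':'))
      overrides[opName].push_back(std::stoll(idx));
  }
  return overrides;
}

int main() {
  // "op1=0,op2=0:1" => reshard op1's operand 0 and op2's operands 0 and 1.
  for (const auto &[op, idxs] : parseMemReconfigOverrides("op1=0,op2=0:1")) {
    std::cout << op << ":";
    for (int64_t i : idxs)
      std::cout << " " << i;
    std::cout << "\n";
  }
  return 0;
}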
diff --git a/include/ttmlir/Dialect/TTNN/Transforms/Optimizer.h b/include/ttmlir/Dialect/TTNN/Transforms/Optimizer.h
index 013b83960..064495f31 100644
--- a/include/ttmlir/Dialect/TTNN/Transforms/Optimizer.h
+++ b/include/ttmlir/Dialect/TTNN/Transforms/Optimizer.h
@@ -102,9 +102,9 @@ class TTNNOptimizerBase : public ::mlir::OperationPass<::mlir::ModuleOp> {
   ::mlir::Pass::Option, mlir::tt::InputLayoutOverrideParser>
       overrideInputLayout{
-          *this, "insert-reshard",
           ::llvm::cl::desc(
-              "Manually insert reshard for specific op's operand."),
+          *this, "insert-memreconfig",
           ::llvm::cl::desc(
+              "Manually insert memory reconfig op for specific op's operand."),
           ::llvm::cl::init(llvm::StringMap())};
   ::mlir::Pass::Option, mlir::tt::OutputLayoutOverrideParser>
diff --git a/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp b/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
index 6136ce277..7a7470ad3 100644
--- a/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
+++ b/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
@@ -35,7 +35,7 @@ void DFShardingPolicy::run(
       //
       if (l1ChainConfigs->back().isEmpty()) {
         for (auto *op : scheduleableOps) {
-          if (isa(op)) {
+          if (isa(op)) {
            currentOp = op;
            break;
          }
@@ -52,8 +52,7 @@ void DFShardingPolicy::run(
       // Skip starting sharding chain if currentOp is a memory management op.
       //
-      if (l1ChainConfigs->back().isEmpty() &&
-          isa(currentOp)) {
+      if (l1ChainConfigs->back().isEmpty() && isa(currentOp)) {
        currentOp = nullptr;
        continue;
      }
diff --git a/lib/Dialect/TTNN/Transforms/Optimizer.cpp b/lib/Dialect/TTNN/Transforms/Optimizer.cpp
index 6c30a98ee..c7499964f 100644
--- a/lib/Dialect/TTNN/Transforms/Optimizer.cpp
+++ b/lib/Dialect/TTNN/Transforms/Optimizer.cpp
@@ -115,6 +115,12 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase {
         return;
       }
+      // Skip empty ops. Handled via DPS op output operand update.
+      //
+      if (isa(op)) {
+        return;
+      }
+
       if (!isa(op->getResult(0).getType())) {
         return;
       }
@@ -149,7 +155,7 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase {
         EmptyOp emptyOp =
             mlir::cast(op->getOperands().back().getDefiningOp());
-        emptyOp.setMemoryConfigAttr(ttnn::MemoryConfigAttr::get(
+        emptyOp.setMemoryConfigAttr(MemoryConfigAttr::get(
            op->getContext(),
            TensorMemoryLayoutAttr::get(op->getContext(), tensorMemoryLayout),
@@ -159,29 +165,6 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase {
                 ShapeAttr::get(op->getContext(),
                                ttLayoutAttr.getMemref().getShape()))));
       }
-      // TODO (nobradovic): Other memory management ops after lowering to
-      // TTNN will need to be special handled as well. Depends on ttnn
-      // layout attr refactor and lowering.
-      //
-      else if (isa(op)) {
-        BufferType bufferType =
-            utils::toTTNNBufferType(ttLayoutAttr.getMemorySpace());
-        TensorMemoryLayout tensorMemoryLayout =
-            utils::toTTNNTensorMemoryLayout(ttLayoutAttr.getMemLayout());
-        // Update the device op with the new tensor type.
-        //
-        ttnn::ToLayoutOp toLayoutOp = llvm::cast(op);
-        toLayoutOp.setMemoryConfigAttr(ttnn::MemoryConfigAttr::get(
-            op->getContext(),
-            ttnn::TensorMemoryLayoutAttr::get(op->getContext(),
-                                              tensorMemoryLayout),
-            ttnn::BufferTypeAttr::get(op->getContext(), bufferType),
-            ttnn::ShardSpecAttr::get(
-                op->getContext(),
-                ttnn::ShapeAttr::get(
-                    op->getContext(),
-                    ttLayoutAttr.getMemref().getShape()))));
-      }
    }
  });
@@ -233,6 +216,19 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase {
    assert(overrideInputLayout.size() == overrideReshardEdges.size());
  }
+  mlir::TypedValue
+  getDeviceOpValue(Operation *contextOp) {
+    Block *block = contextOp->getBlock();
+    mlir::TypedValue deviceOpResult;
+    for (auto &op : block->getOperations()) {
+      if (GetDeviceOp deviceOp = dyn_cast(op)) {
+        deviceOpResult = deviceOp.getResult();
+        break;
+      }
+    }
+    return deviceOpResult;
+  }
+
  void processMemReconfigEdges(const std::unordered_set &memReconfigEdges) {
    // Insert memory reconfig ops here based on results of memory layout
    // analysis.
@@ -242,86 +238,75 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase {
      Operation *producerOp = edge.producerOp;
      Operation *consumerOp = edge.consumerOp;
+      tt::LayoutAttr consumerOpOutputLayout = mlir::cast(
+          mlir::cast(consumerOp->getResult(0).getType())
+              .getEncoding());
+
+      RankedTensorType producerOpTensorType =
+          mlir::cast(producerOp->getResult(0).getType());
+      llvm::ArrayRef producerOpTensorShape =
+          producerOpTensorType.getShape();
+      tt::LayoutAttr producerOpLayout =
+          mlir::cast(producerOpTensorType.getEncoding());
+
+      // TODO(nobradovic): Match memory space and layout of consumer op.
+      // This actually needs to be properly resolved based on op type, output
+      // layout and other inputs.
+      //
+      RankedTensorType newTensorType = RankedTensorType::get(
+          producerOpTensorShape, producerOpTensorType.getElementType(),
+          producerOpLayout
+              .withElementType(consumerOp->getContext(),
+                               consumerOpOutputLayout.getElementType())
+              .withMemorySpace(consumerOp->getContext(),
+                               consumerOpOutputLayout.getMemorySpace())
+              .withMemoryLayout(consumerOp->getContext(),
+                                consumerOpOutputLayout.getMemLayout())
+              .withGrid(consumerOp->getContext(), producerOpTensorType,
                        consumerOpOutputLayout.getGrid()));
+
+      BufferType outputBufferType =
+          utils::toTTNNBufferType(consumerOpOutputLayout.getMemorySpace());
+      TensorMemoryLayout outputTensorMemoryLayout =
+          utils::toTTNNTensorMemoryLayout(
+              consumerOpOutputLayout.getMemLayout());
+      MemRefType outputMemref = consumerOpOutputLayout.getMemref();
+
+      MemoryConfigAttr outputMemConfigAttr = MemoryConfigAttr::get(
+          consumerOp->getContext(),
+          TensorMemoryLayoutAttr::get(consumerOp->getContext(),
+                                      outputTensorMemoryLayout),
+          BufferTypeAttr::get(consumerOp->getContext(), outputBufferType),
+          ShardSpecAttr::get(consumerOp->getContext(),
+                             ShapeAttr::get(consumerOp->getContext(),
                                            outputMemref.getShape())));
+
      // If producerOp is a toLayoutOp, adjust its output layout(update
      // inplace) to reflect consumerOp's output layout. If producerOp is not a
      // toLayoutOp, insert a toLayoutOp in between producerOp
      // and consumerOp.
      //
-      if (isa(producerOp)) {
-        ttnn::ToLayoutOp toLayoutOp = llvm::cast(producerOp);
-        tt::LayoutAttr consumerOpOutputLayout = mlir::cast(
-            mlir::cast(consumerOp->getResult(0).getType())
-                .getEncoding());
-
-        RankedTensorType toLayoutOpTensorType =
-            mlir::cast(toLayoutOp.getResult().getType());
-        llvm::ArrayRef toLayoutOpTensorShape =
-            toLayoutOpTensorType.getShape();
-        tt::LayoutAttr toLayoutOpLayout =
-            mlir::cast(toLayoutOpTensorType.getEncoding());
-
-        // TODO(nobradovic): Match memory space and layout of consumer op. This
-        // actually needs to be properly resolved based on op type, output
-        // layout and other inputs.
-        //
-        RankedTensorType newTensorType = RankedTensorType::get(
-            toLayoutOpTensorShape, toLayoutOpTensorType.getElementType(),
-            toLayoutOpLayout
-                .withElementType(toLayoutOp->getContext(),
-                                 consumerOpOutputLayout.getElementType())
-                .withMemorySpace(toLayoutOp.getContext(),
-                                 consumerOpOutputLayout.getMemorySpace())
-                .withMemoryLayout(toLayoutOp.getContext(),
-                                  consumerOpOutputLayout.getMemLayout())
-                .withGrid(toLayoutOp.getContext(), toLayoutOpTensorType,
-                          consumerOpOutputLayout.getGrid()));
-
+      if (isa(producerOp)) {
+        ToLayoutOp toLayoutOp = llvm::cast(producerOp);
+        toLayoutOp.setMemoryConfigAttr(outputMemConfigAttr);
        toLayoutOp.getResult().setType(newTensorType);
+      } else {
+        OpBuilder builder(consumerOp);
+
+        DataTypeAttr outputDataType =
+            DataTypeAttr::get(consumerOp->getContext(),
+                              utils::getDataTypeFromMemRef(outputMemref));
+        Layout outputLayoutEnum = utils::getLayoutFromMemRef(outputMemref);
+        LayoutAttr outputLayout =
+            LayoutAttr::get(consumerOp->getContext(), outputLayoutEnum);
+        Operation *memoryReconfigOp = builder.create(
+            consumerOp->getLoc(), newTensorType, producerOp->getResult(0),
+            outputLayout, outputDataType, outputMemConfigAttr,
+            getDeviceOpValue(consumerOp));
+
+        consumerOp->setOperand(edge.operandIndex,
                               memoryReconfigOp->getResult(0));
      }
-      // TODO (nobradovic): Memory layout reconfig needs to be reimplemented for
-      // TTNN dialect.
-      // else {
-      //   tt::LayoutAttr consumerOpOutputLayout = mlir::cast(
-      //       mlir::cast(consumerOp->getResult(0).getType())
-      //           .getEncoding());
-
-      //   RankedTensorType producerOpTensorType =
-      //       mlir::cast(producerOp->getResult(0).getType());
-      //   llvm::ArrayRef producerOpTensorShape =
-      //       producerOpTensorType.getShape();
-      //   tt::LayoutAttr producerOpLayout =
-      //       mlir::cast(producerOpTensorType.getEncoding());

-      //   // TODO(nobradovic): Match memory space and layout of consumer op.
-      //   This
-      //   // actually needs to be properly resolved based on op type, output
-      //   // layout and other inputs.
-      //   //
-      //   RankedTensorType newTensorType = RankedTensorType::get(
-      //       producerOpTensorShape, producerOpTensorType.getElementType(),
-      //       producerOpLayout
-      //           .withElementType(consumerOp->getContext(),
-      //                            consumerOpOutputLayout.getElementType())
-      //           .withMemorySpace(consumerOp->getContext(),
-      //                            consumerOpOutputLayout.getMemorySpace())
-      //           .withMemoryLayout(consumerOp->getContext(),
-      //                             consumerOpOutputLayout.getMemLayout())
-      //           .withGrid(consumerOp->getContext(), producerOpTensorType,
-      //                     consumerOpOutputLayout.getGrid()));

-      //   OpBuilder builder(consumerOp);

-      //   mlir::tensor::EmptyOp emptyOp = builder.create(
-      //       consumerOp->getLoc(), producerOpTensorShape,
-      //       producerOpTensorType.getElementType(),
-      //       mlir::cast(newTensorType.getEncoding()));

-      //   Operation *toLayoutOp = builder.create(
-      //       consumerOp->getLoc(), newTensorType, producerOp->getResult(0),
-      //       emptyOp);

-      //   consumerOp->setOperand(edge.operandIndex, toLayoutOp->getResult(0));
-      // }
    }
  }
};
diff --git a/test/ttmlir/Dialect/TTNN/multiple_add_with_loc_input_layout_override.mlir b/test/ttmlir/Dialect/TTNN/multiple_add_with_loc_input_layout_override.mlir
index 96d205192..fb2eaa465 100644
--- a/test/ttmlir/Dialect/TTNN/multiple_add_with_loc_input_layout_override.mlir
+++ b/test/ttmlir/Dialect/TTNN/multiple_add_with_loc_input_layout_override.mlir
@@ -1,16 +1,16 @@
-// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memreconfig-enabled=true insert-reshard=add_0_1_2=0" %s | FileCheck %s
-// UNSUPPORTED: true
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memreconfig-enabled=true insert-memreconfig=add_0_1_2=0 override-output-layout=add_1_2=1x1:dram:interleaved" %s | FileCheck %s
 #any_device = #tt.operand_constraint
 #loc = loc("test_ops.py:17_0_0":0:0)
 module attributes {} {
   func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> tensor<1x32x32xf32> {
     // CHECK: #[[L1_:.*]] = #tt.memory_space
-    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #l1_>, block_sharded>
+    // CHECK-DAG: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #l1_>, block_sharded>
+    // CHECK-DAG: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #dram>, interleaved>
     %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5)
-    // CHECK: %[[C:.*]] = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
+    // CHECK: %[[C:.*]] = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_2]]>
     %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5)
     %2 = tensor.empty() : tensor<1x32x32xf32> loc(#loc6)
-    // CHECK: %{{.*}} = "ttnn.to_layout"(%[[C]], %0) {{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
+    // CHECK: %{{.*}} = "ttnn.to_memory_config"(%[[C]]) {{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
     %3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6)
     %4 = tensor.empty() : tensor<1x32x32xf32> loc(#loc7)
     %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc7)
diff --git a/test/ttmlir/Dialect/TTNN/test_override_reshard_edges.mlir b/test/ttmlir/Dialect/TTNN/test_override_reshard_edges.mlir
index 2004e4d62..34eb9bdc5 100644
--- a/test/ttmlir/Dialect/TTNN/test_override_reshard_edges.mlir
+++ b/test/ttmlir/Dialect/TTNN/test_override_reshard_edges.mlir
@@ -1,9 +1,8 @@
-// RUN: ttmlir-opt --ttnn-optimizer="memory-layout-analysis-enabled=true memreconfig-enabled=true insert-reshard=add_0_1_2=0" %s | FileCheck %s
-// UNSUPPORTED: true
+// RUN: ttmlir-opt --ttnn-optimizer="memory-layout-analysis-enabled=true memreconfig-enabled=true insert-memreconfig=add_0_1_2=0 override-output-layout=add_1_2=1x1:dram:interleaved" %s | FileCheck %s
 #device = #tt.device (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]>
 #dram = #tt.memory_space
 #system = #tt.memory_space
-#system_desc = #tt.system_desc<[{arch = , grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 1024, erisc_l1_unreserved_base = 1024, dram_unreserved_base = 1024, dram_unreserved_end = 1073741824, physical_cores = {worker = [ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 1x0, 1x1, 1x2, 1x3, 1x4, 1x5, 1x6, 1x7, 2x0, 2x1, 2x2, 2x3, 2x4, 2x5, 2x6, 2x7, 3x0, 3x1, 3x2, 3x3, 3x4, 3x5, 3x6, 3x7, 4x0, 4x1, 4x2, 4x3, 4x4, 4x5, 4x6, 4x7, 5x0, 5x1, 5x2, 5x3, 5x4, 5x5, 5x6, 5x7, 6x0, 6x1, 6x2, 6x3, 6x4, 6x5, 6x6, 6x7, 7x0, 7x1, 7x2, 7x3, 7x4, 7x5, 7x6, 7x7] dram = [ 8x0, 9x0, 10x0, 8x1, 9x1, 10x1, 8x2, 9x2, 10x2, 8x3, 9x3, 10x3]}, supported_data_types = [, , , , , , , , , , , ], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32]}], [0], [3 : i32], [ 0x0x0x0]>
+#system_desc = #tt.system_desc<[{arch = , grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 1024, erisc_l1_unreserved_base = 1024, dram_unreserved_base = 1024, dram_unreserved_end = 1073741824, physical_cores = {worker = [ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 1x0, 1x1, 1x2, 1x3, 1x4, 1x5, 1x6, 1x7, 2x0, 2x1, 2x2, 2x3, 2x4, 2x5, 2x6, 2x7, 3x0, 3x1, 3x2, 3x3, 3x4, 3x5, 3x6, 3x7, 4x0, 4x1, 4x2, 4x3, 4x4, 4x5, 4x6, 4x7, 5x0, 5x1, 5x2, 5x3, 5x4, 5x5, 5x6, 5x7, 6x0, 6x1, 6x2, 6x3, 6x4, 6x5, 6x6, 6x7, 7x0, 7x1, 7x2, 7x3, 7x4, 7x5, 7x6, 7x7] dram = [ 8x0, 9x0, 10x0, 8x1, 9x1, 10x1, 8x2, 9x2, 10x2, 8x3, 9x3, 10x3]}, supported_data_types = [, , , , , , , , , , , ], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 32}], [0], [3 : i32], [ 0x0x0x0]>
 #layout = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #system>>
 #layout1 = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #dram>, interleaved>
 module attributes {tt.device = #device, tt.system_desc = #system_desc} {
@@ -12,22 +11,19 @@ module attributes {tt.device = #device, tt.system_desc = #system_desc} {
     // CHECK: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #l1_>, block_sharded>
     // CHECK: #[[LAYOUT_3:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #dram>, interleaved>
     %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device>
-    %1 = "ttnn.to_layout"(%arg0, %0) <{layout = #ttnn.layout}> : (tensor<1x32x32xf32, #layout>, !tt.device<#device>) -> tensor<1x32x32xf32, #layout1>
-    %2 = "ttnn.to_device"(%1, %0) <{memory_config = #ttnn.memory_config<, >}> : (tensor<1x32x32xf32, #layout1>, !tt.device<#device>) -> tensor<1x32x32xf32, #layout1>
-    %3 = "ttnn.to_layout"(%arg1, %0) <{layout = #ttnn.layout}> : (tensor<1x32x32xf32, #layout>, !tt.device<#device>) -> tensor<1x32x32xf32, #layout1>
-    %4 = "ttnn.to_device"(%3, %0) <{memory_config = #ttnn.memory_config<, >}> : (tensor<1x32x32xf32, #layout1>, !tt.device<#device>) -> tensor<1x32x32xf32, #layout1>
-    %5 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, >, shape = #ttnn.shape<1x32x32>}> : (!tt.device<#device>) -> tensor<1x32x32xf32, #layout1> loc(#loc1)
-    // CHECK: %[[C:.*]] = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_2]]>
-    %6 = "ttnn.add"(%2, %4, %5) <{operandSegmentSizes = array}> : (tensor<1x32x32xf32, #layout1>, tensor<1x32x32xf32, #layout1>, tensor<1x32x32xf32, #layout1>) -> tensor<1x32x32xf32, #layout1> loc(#loc1)
-    %7 = "ttnn.to_layout"(%arg0, %0) <{layout = #ttnn.layout}> : (tensor<1x32x32xf32, #layout>, !tt.device<#device>) -> tensor<1x32x32xf32, #layout1>
-    %8 = "ttnn.to_device"(%7, %0) <{memory_config = #ttnn.memory_config<, >}> : (tensor<1x32x32xf32, #layout1>, !tt.device<#device>) -> tensor<1x32x32xf32, #layout1>
-    %9 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, >, shape = #ttnn.shape<1x32x32>}> : (!tt.device<#device>) -> tensor<1x32x32xf32, #layout1> loc(#loc2)
+    %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<32x32>>>}> : (tensor<1x32x32xf32, #layout>, !tt.device<#device>) -> tensor<1x32x32xf32, #layout1>
+    %2 = "ttnn.to_layout"(%arg1, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<32x32>>>}> : (tensor<1x32x32xf32, #layout>, !tt.device<#device>) -> tensor<1x32x32xf32, #layout1>
+    %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<32x32>>>, shape = #ttnn.shape<1x32x32>}> : (!tt.device<#device>) -> tensor<1x32x32xf32, #layout1> loc(#loc1)
+    // CHECK: %[[C:.*]] = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
+    %4 = "ttnn.add"(%1, %2, %3) <{operandSegmentSizes = array}> : (tensor<1x32x32xf32, #layout1>, tensor<1x32x32xf32, #layout1>, tensor<1x32x32xf32, #layout1>) -> tensor<1x32x32xf32, #layout1> loc(#loc1)
+    %5 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<32x32>>>}> : (tensor<1x32x32xf32, #layout>, !tt.device<#device>) -> tensor<1x32x32xf32, #layout1>
+    %6 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<32x32>>>, shape = #ttnn.shape<1x32x32>}> : (!tt.device<#device>) -> tensor<1x32x32xf32, #layout1> loc(#loc2)
     // CHECK: %{{.*}} = "ttnn.to_layout"(%[[C]], %0) {{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_2]]>
-    %10 = "ttnn.add"(%6, %8, %9) <{operandSegmentSizes = array}> : (tensor<1x32x32xf32, #layout1>, tensor<1x32x32xf32, #layout1>, tensor<1x32x32xf32, #layout1>) -> tensor<1x32x32xf32, #layout1> loc(#loc2)
-    %11 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, >, shape = #ttnn.shape<1x32x32>}> : (!tt.device<#device>) -> tensor<1x32x32xf32, #layout1> loc(#loc3)
-    %12 = "ttnn.relu"(%10, %11) <{operandSegmentSizes = array}> : (tensor<1x32x32xf32, #layout1>, tensor<1x32x32xf32, #layout1>) -> tensor<1x32x32xf32, #layout1> loc(#loc3)
-    %13 = "ttnn.to_memory_config"(%12, %0) : (tensor<1x32x32xf32, #layout1>, !tt.device<#device>) -> tensor<1x32x32xf32, #layout>
-    return %13 : tensor<1x32x32xf32, #layout>
+    %7 = "ttnn.add"(%4, %6, %6) <{operandSegmentSizes = array}> : (tensor<1x32x32xf32, #layout1>, tensor<1x32x32xf32, #layout1>, tensor<1x32x32xf32, #layout1>) -> tensor<1x32x32xf32, #layout1> loc(#loc2)
+    %8 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<32x32>>>, shape = #ttnn.shape<1x32x32>}> : (!tt.device<#device>) -> tensor<1x32x32xf32, #layout1> loc(#loc3)
+    %9 = "ttnn.relu"(%7, %8) <{operandSegmentSizes = array}> : (tensor<1x32x32xf32, #layout1>, tensor<1x32x32xf32, #layout1>) -> tensor<1x32x32xf32, #layout1> loc(#loc3)
+    %10 = "ttnn.to_layout"(%9) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<32x32>>>}> : (tensor<1x32x32xf32, #layout1>) -> tensor<1x32x32xf32, #layout>
+    return %10 : tensor<1x32x32xf32, #layout>
   }
 }
 #loc1 = loc("add_1_2")