[Optimizer] Fix ttnn.ToLayout mem config optimizer override. (#1130)
nobradovictt authored Nov 4, 2024
1 parent f2ccda0 commit d73456b
Showing 3 changed files with 26 additions and 2 deletions.
24 changes: 24 additions & 0 deletions lib/Dialect/TTNN/Transforms/Optimizer.cpp
@@ -165,6 +165,30 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
ShapeAttr::get(op->getContext(),
ttLayoutAttr.getMemref().getShape()))));
}
// TODO(mtopalovic): Temp workaround for generic ToLayoutOp. Align
// MemoryConfigAttr with the layout attribute of its output tensor. This
// redundant info should be removed or made consistent as part of the temp
// ToLayoutOp decomposition pass.
//
else if (isa<ttnn::ToLayoutOp>(op)) {
BufferType bufferType =
utils::toTTNNBufferType(ttLayoutAttr.getMemorySpace());
TensorMemoryLayout tensorMemoryLayout =
utils::toTTNNTensorMemoryLayout(ttLayoutAttr.getMemLayout());
// Align the ToLayoutOp's memory config attribute with the output tensor layout.
//
ttnn::ToLayoutOp toLayoutOp = llvm::cast<ttnn::ToLayoutOp>(op);
toLayoutOp.setMemoryConfigAttr(ttnn::MemoryConfigAttr::get(
op->getContext(),
ttnn::TensorMemoryLayoutAttr::get(op->getContext(),
tensorMemoryLayout),
ttnn::BufferTypeAttr::get(op->getContext(), bufferType),
ttnn::ShardSpecAttr::get(
op->getContext(),
ttnn::ShapeAttr::get(
op->getContext(),
ttLayoutAttr.getMemref().getShape()))));
}
}
});
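
The new else-if branch duplicates the MemoryConfigAttr construction already performed by the preceding branch (which appears to handle ttnn.ToDeviceOp, judging by the to_device check in the updated test below). As a sketch only, not part of this commit, the two call sites could share a small helper; the name createMemoryConfigFromLayout is hypothetical, the tt::LayoutAttr parameter type is assumed from context, and the headers and using-declarations of Optimizer.cpp are taken as given:

// Hypothetical helper, not in this commit: build a ttnn::MemoryConfigAttr that
// mirrors a layout attribute, using only the constructors the diff above uses.
static ttnn::MemoryConfigAttr
createMemoryConfigFromLayout(MLIRContext *ctx, tt::LayoutAttr ttLayoutAttr) {
  // Map the TT memory space / memory layout to their TTNN counterparts.
  BufferType bufferType =
      utils::toTTNNBufferType(ttLayoutAttr.getMemorySpace());
  TensorMemoryLayout tensorMemoryLayout =
      utils::toTTNNTensorMemoryLayout(ttLayoutAttr.getMemLayout());
  // Assemble the memory config from the layout, buffer type, and shard shape.
  return ttnn::MemoryConfigAttr::get(
      ctx, ttnn::TensorMemoryLayoutAttr::get(ctx, tensorMemoryLayout),
      ttnn::BufferTypeAttr::get(ctx, bufferType),
      ttnn::ShardSpecAttr::get(
          ctx, ttnn::ShapeAttr::get(ctx, ttLayoutAttr.getMemref().getShape())));
}

With such a helper, each branch would reduce to a single call, e.g. toLayoutOp.setMemoryConfigAttr(createMemoryConfigFromLayout(op->getContext(), ttLayoutAttr)).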

2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/input_layout_loc_override.mlir
@@ -10,7 +10,7 @@ module attributes {} {
func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> {
%0 = tensor.empty() : tensor<64x96xbf16> loc(#loc2)
// CHECK-DAG: %{{.*}} = "ttnn.to_device"{{.*}} loc(#[[LOC_MATMUL_IN0]])
// CHECK-DAG: %{{.*}} = "ttnn.to_device"{{.*}} -> tensor<128x96xbf16, #[[IN_1_LAYOUT]]> loc(#[[LOC_MATMUL_IN1]])
// CHECK-DAG: %{{.*}} = "ttnn.to_device"{{.*}} <{memory_config = #ttnn.memory_config<<interleaved>, <l1>, <<4x3>>>}> : {{.*}} -> tensor<128x96xbf16, #[[IN_1_LAYOUT]]> loc(#[[LOC_MATMUL_IN1]])
// CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} loc(#[[LOC_MATMUL]])
%1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> loc(#loc2)
return %1 : tensor<64x96xbf16>
2 changes: 1 addition & 1 deletion test/ttmlir/Silicon/TTNN/sharded/mnist_sharding_tiled.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% enable-optimizer=true memory-layout-analysis-enabled=true override-output-layout=matmul_1=1x8:l1:width_sharded,add_2=1x8:l1:width_sharded,add_2_in_1_layout=1x8:l1:width_sharded,relu_3=1x8:l1:width_sharded,matmul_5=1x1:l1:width_sharded,add_6=1x1:l1:width_sharded,add_6_in_1_layout=1x1:l1:width_sharded,softmax_7=1x1:l1:width_sharded" %s > %t.mlir
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% enable-optimizer=true memory-layout-analysis-enabled=true override-output-layout=matmul_1=1x8:l1:width_sharded,add_2=1x8:l1:width_sharded,relu_3=1x8:l1:width_sharded,matmul_5=1x1:l1:width_sharded,add_6=1x1:l1:width_sharded,softmax_7=1x1:l1:width_sharded" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
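
The override-output-layout value in the RUN line above is a comma-separated list of entries of the form op_name=grid:memory_space:memory_layout (for example, matmul_1=1x8:l1:width_sharded); the field meanings are inferred from the RUN line rather than stated in this commit. A purely illustrative C++ sketch of splitting one such entry into its parts, not the compiler's actual option parser:

#include <array>
#include <iostream>
#include <string>

int main() {
  // Example entry taken from the RUN line above.
  std::string entry = "matmul_1=1x8:l1:width_sharded";

  // Split "op_name=..." at the '='.
  std::string opName = entry.substr(0, entry.find('='));
  std::string rest = entry.substr(entry.find('=') + 1);

  // Split the remainder into grid, memory space, and memory layout fields.
  std::array<std::string, 3> fields;
  for (std::string &field : fields) {
    std::string::size_type colon = rest.find(':');
    field = rest.substr(0, colon);
    rest = (colon == std::string::npos) ? "" : rest.substr(colon + 1);
  }

  std::cout << opName << ": grid=" << fields[0] << ", memspace=" << fields[1]
            << ", layout=" << fields[2] << "\n";
  return 0;
}

The RUN-line change itself simply drops the add_2_in_1_layout and add_6_in_1_layout entries, which appear to be no longer needed once the ToLayout memory config follows the overridden output layout.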
