[Optimizer] Fix mem reconfig/reshard. #1116

Merged · 1 commit · Oct 31, 2024
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TT/Utils/OverrideParams.h
@@ -63,7 +63,7 @@ struct InputLayoutOverrideParser

static void print(llvm::raw_ostream &os,
const llvm::StringMap<InputLayoutOverrideParams> &value) {
os << "insert-reshard=";
os << "insert-memreconfig=";
size_t count = 0;
for (const auto &entry : value) {
os << entry.getKey() << "=";
10 changes: 5 additions & 5 deletions include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
@@ -31,17 +31,17 @@ struct TTIRToTTNNBackendPipelineOptions
//
// Full Example: "op1=0,op2=0:1"
//
// This will insert one TTIR_ToLayoutOps responsible for resharding the op1's
// first operand and two TTIR_ToLayoutOps responsible for resharding the op2's
// first and second operand.
// This will insert one memory reconfig op responsible for resharding the
// op1's first operand and two memory reconfig ops responsible for resharding
// the op2's first and second operand.
//
// Note: This option is only valid if optimizerPassEnabled is true.
//
Option<llvm::StringMap<InputLayoutOverrideParams>, InputLayoutOverrideParser>
overrideInputLayout{
*this, "insert-reshard",
*this, "insert-memreconfig",
llvm::cl::desc(
"Manually insert TTIR_ToLayoutOp for specific op's operand."),
"Manually insert memory reconfig op for specific op's operand."),
llvm::cl::init(llvm::StringMap<InputLayoutOverrideParams>())};

// Option to override output layout for specific ops.
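For reference, the renamed option is exercised end to end by the updated test at the bottom of this diff. A typical invocation looks like the following, with the pipeline string copied from that test's RUN line and input.mlir standing in as a placeholder for any TTIR module containing an op at location add_0_1_2:

ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memreconfig-enabled=true insert-memreconfig=add_0_1_2=0 override-output-layout=add_1_2=1x1:dram:interleaved" input.mlir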
4 changes: 2 additions & 2 deletions include/ttmlir/Dialect/TTNN/Transforms/Optimizer.h
@@ -102,9 +102,9 @@ class TTNNOptimizerBase : public ::mlir::OperationPass<::mlir::ModuleOp> {
::mlir::Pass::Option<llvm::StringMap<InputLayoutOverrideParams>,
mlir::tt::InputLayoutOverrideParser>
overrideInputLayout{
*this, "insert-reshard",
*this, "insert-memreconfig",
::llvm::cl::desc(
"Manually insert reshard for specific op's operand."),
"Manually insert memory reconfig op for specific op's operand."),
::llvm::cl::init(llvm::StringMap<InputLayoutOverrideParams>())};
::mlir::Pass::Option<llvm::StringMap<OutputLayoutOverrideParams>,
mlir::tt::OutputLayoutOverrideParser>
5 changes: 2 additions & 3 deletions lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
@@ -35,7 +35,7 @@ void DFShardingPolicy::run(
//
if (l1ChainConfigs->back().isEmpty()) {
for (auto *op : scheduleableOps) {
if (isa<ttnn::ToLayoutOp>(op)) {
if (isa<ToLayoutOp>(op)) {
currentOp = op;
break;
}
@@ -52,8 +52,7 @@

// Skip starting sharding chain if currentOp is a memory management op.
//
if (l1ChainConfigs->back().isEmpty() &&
isa<ttnn::ToLayoutOp>(currentOp)) {
if (l1ChainConfigs->back().isEmpty() && isa<ToLayoutOp>(currentOp)) {
currentOp = nullptr;
continue;
}
179 changes: 82 additions & 97 deletions lib/Dialect/TTNN/Transforms/Optimizer.cpp
@@ -115,6 +115,12 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
return;
}

// Skip empty ops. Handled via DPS op output operand update.
//
if (isa<EmptyOp>(op)) {
return;
}

if (!isa<RankedTensorType>(op->getResult(0).getType())) {
return;
}
@@ -149,7 +155,7 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
EmptyOp emptyOp =
mlir::cast<EmptyOp>(op->getOperands().back().getDefiningOp());

emptyOp.setMemoryConfigAttr(ttnn::MemoryConfigAttr::get(
emptyOp.setMemoryConfigAttr(MemoryConfigAttr::get(
op->getContext(),
TensorMemoryLayoutAttr::get(op->getContext(),
tensorMemoryLayout),
Expand All @@ -159,29 +165,6 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
ShapeAttr::get(op->getContext(),
ttLayoutAttr.getMemref().getShape()))));
}
// TODO (nobradovic): Other memory management ops after lowering to
// TTNN will need to be special handled as well. Depends on ttnn
// layout attr refactor and lowering.
//
else if (isa<ttnn::ToLayoutOp>(op)) {
BufferType bufferType =
utils::toTTNNBufferType(ttLayoutAttr.getMemorySpace());
TensorMemoryLayout tensorMemoryLayout =
utils::toTTNNTensorMemoryLayout(ttLayoutAttr.getMemLayout());
// Update the device op with the new tensor type.
//
ttnn::ToLayoutOp toLayoutOp = llvm::cast<ttnn::ToLayoutOp>(op);
toLayoutOp.setMemoryConfigAttr(ttnn::MemoryConfigAttr::get(
op->getContext(),
ttnn::TensorMemoryLayoutAttr::get(op->getContext(),
tensorMemoryLayout),
ttnn::BufferTypeAttr::get(op->getContext(), bufferType),
ttnn::ShardSpecAttr::get(
op->getContext(),
ttnn::ShapeAttr::get(
op->getContext(),
ttLayoutAttr.getMemref().getShape()))));
}
}
});

@@ -233,6 +216,19 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
assert(overrideInputLayout.size() == overrideReshardEdges.size());
}

mlir::TypedValue<mlir::tt::DeviceType>
getDeviceOpValue(Operation *contextOp) {
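// Walk the ops in contextOp's block and return the result of the first
// GetDeviceOp found; returns a null (default-constructed) value if the
// block contains no GetDeviceOp.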
Block *block = contextOp->getBlock();
mlir::TypedValue<mlir::tt::DeviceType> deviceOpResult;
for (auto &op : block->getOperations()) {
if (GetDeviceOp deviceOp = dyn_cast<GetDeviceOp>(op)) {
deviceOpResult = deviceOp.getResult();
break;
}
}
return deviceOpResult;
}

void
processMemReconfigEdges(const std::unordered_set<Edge> &memReconfigEdges) {
// Insert memory reconfig ops here based on results of memory layout
@@ -242,86 +238,75 @@
Operation *producerOp = edge.producerOp;
Operation *consumerOp = edge.consumerOp;

tt::LayoutAttr consumerOpOutputLayout = mlir::cast<tt::LayoutAttr>(
mlir::cast<RankedTensorType>(consumerOp->getResult(0).getType())
.getEncoding());

RankedTensorType producerOpTensorType =
mlir::cast<RankedTensorType>(producerOp->getResult(0).getType());
llvm::ArrayRef<int64_t> producerOpTensorShape =
producerOpTensorType.getShape();
tt::LayoutAttr producerOpLayout =
mlir::cast<tt::LayoutAttr>(producerOpTensorType.getEncoding());

// TODO(nobradovic): Match memory space and layout of consumer op.
// This actually needs to be properly resolved based on op type, output
// layout and other inputs.
//
RankedTensorType newTensorType = RankedTensorType::get(
producerOpTensorShape, producerOpTensorType.getElementType(),
producerOpLayout
.withElementType(consumerOp->getContext(),
consumerOpOutputLayout.getElementType())
.withMemorySpace(consumerOp->getContext(),
consumerOpOutputLayout.getMemorySpace())
.withMemoryLayout(consumerOp->getContext(),
consumerOpOutputLayout.getMemLayout())
.withGrid(consumerOp->getContext(), producerOpTensorType,
consumerOpOutputLayout.getGrid()));

BufferType outputBufferType =
utils::toTTNNBufferType(consumerOpOutputLayout.getMemorySpace());
TensorMemoryLayout outputTensorMemoryLayout =
utils::toTTNNTensorMemoryLayout(
consumerOpOutputLayout.getMemLayout());
MemRefType outputMemref = consumerOpOutputLayout.getMemref();

MemoryConfigAttr outputMemConfigAttr = MemoryConfigAttr::get(
consumerOp->getContext(),
TensorMemoryLayoutAttr::get(consumerOp->getContext(),
outputTensorMemoryLayout),
BufferTypeAttr::get(consumerOp->getContext(), outputBufferType),
ShardSpecAttr::get(consumerOp->getContext(),
ShapeAttr::get(consumerOp->getContext(),
outputMemref.getShape())));

// If producerOp is a toLayoutOp, adjust its output layout (update in
// place) to reflect consumerOp's output layout. If producerOp is not a
// toLayoutOp, insert a toLayoutOp between producerOp and consumerOp.
//
if (isa<ttnn::ToLayoutOp>(producerOp)) {
ttnn::ToLayoutOp toLayoutOp = llvm::cast<ttnn::ToLayoutOp>(producerOp);
tt::LayoutAttr consumerOpOutputLayout = mlir::cast<tt::LayoutAttr>(
mlir::cast<RankedTensorType>(consumerOp->getResult(0).getType())
.getEncoding());

RankedTensorType toLayoutOpTensorType =
mlir::cast<RankedTensorType>(toLayoutOp.getResult().getType());
llvm::ArrayRef<int64_t> toLayoutOpTensorShape =
toLayoutOpTensorType.getShape();
tt::LayoutAttr toLayoutOpLayout =
mlir::cast<tt::LayoutAttr>(toLayoutOpTensorType.getEncoding());

// TODO(nobradovic): Match memory space and layout of consumer op. This
// actually needs to be properly resolved based on op type, output
// layout and other inputs.
//
RankedTensorType newTensorType = RankedTensorType::get(
toLayoutOpTensorShape, toLayoutOpTensorType.getElementType(),
toLayoutOpLayout
.withElementType(toLayoutOp->getContext(),
consumerOpOutputLayout.getElementType())
.withMemorySpace(toLayoutOp.getContext(),
consumerOpOutputLayout.getMemorySpace())
.withMemoryLayout(toLayoutOp.getContext(),
consumerOpOutputLayout.getMemLayout())
.withGrid(toLayoutOp.getContext(), toLayoutOpTensorType,
consumerOpOutputLayout.getGrid()));

if (isa<ToLayoutOp>(producerOp)) {
ToLayoutOp toLayoutOp = llvm::cast<ToLayoutOp>(producerOp);
toLayoutOp.setMemoryConfigAttr(outputMemConfigAttr);
toLayoutOp.getResult().setType(newTensorType);
} else {
OpBuilder builder(consumerOp);

DataTypeAttr outputDataType =
DataTypeAttr::get(consumerOp->getContext(),
utils::getDataTypeFromMemRef(outputMemref));
Layout outputLayoutEnum = utils::getLayoutFromMemRef(outputMemref);
LayoutAttr outputLayout =
LayoutAttr::get(consumerOp->getContext(), outputLayoutEnum);
Operation *memoryReconfigOp = builder.create<ToLayoutOp>(
consumerOp->getLoc(), newTensorType, producerOp->getResult(0),
outputLayout, outputDataType, outputMemConfigAttr,
getDeviceOpValue(consumerOp));

consumerOp->setOperand(edge.operandIndex,
memoryReconfigOp->getResult(0));
}
// TODO (nobradovic): Memory layout reconfig needs to be reimplemented for
// TTNN dialect.
// else {
// tt::LayoutAttr consumerOpOutputLayout = mlir::cast<tt::LayoutAttr>(
// mlir::cast<RankedTensorType>(consumerOp->getResult(0).getType())
// .getEncoding());

// RankedTensorType producerOpTensorType =
// mlir::cast<RankedTensorType>(producerOp->getResult(0).getType());
// llvm::ArrayRef<int64_t> producerOpTensorShape =
// producerOpTensorType.getShape();
// tt::LayoutAttr producerOpLayout =
// mlir::cast<tt::LayoutAttr>(producerOpTensorType.getEncoding());

// // TODO(nobradovic): Match memory space and layout of consumer op.
// This
// // actually needs to be properly resolved based on op type, output
// // layout and other inputs.
// //
// RankedTensorType newTensorType = RankedTensorType::get(
// producerOpTensorShape, producerOpTensorType.getElementType(),
// producerOpLayout
// .withElementType(consumerOp->getContext(),
// consumerOpOutputLayout.getElementType())
// .withMemorySpace(consumerOp->getContext(),
// consumerOpOutputLayout.getMemorySpace())
// .withMemoryLayout(consumerOp->getContext(),
// consumerOpOutputLayout.getMemLayout())
// .withGrid(consumerOp->getContext(), producerOpTensorType,
// consumerOpOutputLayout.getGrid()));

// OpBuilder builder(consumerOp);

// mlir::tensor::EmptyOp emptyOp = builder.create<tensor::EmptyOp>(
// consumerOp->getLoc(), producerOpTensorShape,
// producerOpTensorType.getElementType(),
// mlir::cast<LayoutAttr>(newTensorType.getEncoding()));

// Operation *toLayoutOp = builder.create<ttir::ToLayoutOp>(
// consumerOp->getLoc(), newTensorType, producerOp->getResult(0),
// emptyOp);

// consumerOp->setOperand(edge.operandIndex, toLayoutOp->getResult(0));
// }
}
}
};
@@ -1,16 +1,16 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memreconfig-enabled=true insert-reshard=add_0_1_2=0" %s | FileCheck %s
// UNSUPPORTED: true
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memreconfig-enabled=true insert-memreconfig=add_0_1_2=0 override-output-layout=add_1_2=1x1:dram:interleaved" %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
#loc = loc("test_ops.py:17_0_0":0:0)
module attributes {} {
func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> tensor<1x32x32xf32> {
// CHECK: #[[L1_:.*]] = #tt.memory_space<l1>
// CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #l1_>, block_sharded>
// CHECK-DAG: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #l1_>, block_sharded>
// CHECK-DAG: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #dram>, interleaved>
%0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5)
// CHECK: %[[C:.*]] = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
// CHECK: %[[C:.*]] = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_2]]>
%1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5)
%2 = tensor.empty() : tensor<1x32x32xf32> loc(#loc6)
// CHECK: %{{.*}} = "ttnn.to_layout"(%[[C]], %0) {{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
// CHECK: %{{.*}} = "ttnn.to_memory_config"(%[[C]]) {{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
%3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6)
%4 = tensor.empty() : tensor<1x32x32xf32> loc(#loc7)
%5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc7)
Expand Down