Set TensorMemoryLayout in legal layout generation
odjuricicTT committed Sep 4, 2024
1 parent cc2d180 commit f2efe2d
Showing 2 changed files with 29 additions and 18 deletions.
45 changes: 28 additions & 17 deletions lib/Dialect/TTIR/Analysis/LegalGridAnalysis.cpp
@@ -91,35 +91,40 @@ void LegalGridAnalysis::analysisImplementation() {
       mlir::cast<RankedTensorType>(op->getResult(0).getType());
   LayoutAttr layout = mlir::cast<LayoutAttr>(tensorType.getEncoding());

-  // L1 Interleaved (same as above).
-  LayoutAttr l1Interleaved =
-      layout.withMemorySpace(op->getContext(), MemorySpace::DeviceL1);
-  if (mock_is_output_tensor_legal_for_op(op, l1Interleaved)) {
-    analysisResult.push_back(l1Interleaved);
-  }
-
   // DRAM
   // No grid is set since the tensor is not sharded.
   // TODO(odjuricic): We need to set grid here since it will be used as the
   // compute grid. (not implemented in runtime atm)
   LayoutAttr dram =
-      layout.withMemorySpace(op->getContext(), MemorySpace::DeviceDRAM);
+      layout.withMemorySpace(op->getContext(), MemorySpace::DeviceDRAM)
+          .withMemoryLayout(op->getContext(), TensorMemoryLayout::Interleaved);
   if (mock_is_output_tensor_legal_for_op(op, dram)) {
     analysisResult.push_back(dram);
   }

+  // L1 Interleaved (same as above).
+  LayoutAttr l1Interleaved =
+      layout.withMemorySpace(op->getContext(), MemorySpace::DeviceL1)
+          .withMemoryLayout(op->getContext(), TensorMemoryLayout::Interleaved);
+  if (mock_is_output_tensor_legal_for_op(op, l1Interleaved)) {
+    analysisResult.push_back(l1Interleaved);
+  }
+
   // L1 Sharded
   LayoutAttr shardedBase =
       layout.withMemorySpace(op->getContext(), MemorySpace::DeviceL1);
   std::vector<LayoutAttr> shardedResults;

   // Block Sharded
-  for (auto width = 2; width <= analysisInput.maxGrid.getShape()[0]; ++width) {
-    for (auto height = 2; height <= analysisInput.maxGrid.getShape()[1];
+  for (auto width = 1; width <= analysisInput.maxGrid.getShape()[0]; ++width) {
+    for (auto height = 1; height <= analysisInput.maxGrid.getShape()[1];
          ++height) {
-      shardedResults.push_back(shardedBase.withGrid(
-          op->getContext(), tensorType,
-          GridAttr::get(op->getContext(), {width, height})));
+      shardedResults.push_back(
+          shardedBase
+              .withGrid(op->getContext(), tensorType,
+                        GridAttr::get(op->getContext(), {width, height}))
+              .withMemoryLayout(op->getContext(),
+                                TensorMemoryLayout::BlockSharded));
     }
   }

@@ -130,15 +135,21 @@ void LegalGridAnalysis::analysisImplementation() {
   // runtime implementation on what to produce here.
   for (auto height = 2; height <= numCores; ++height) {
     shardedResults.push_back(
-        shardedBase.withGrid(op->getContext(), tensorType,
-                             GridAttr::get(op->getContext(), {height, 1})));
+        shardedBase
+            .withGrid(op->getContext(), tensorType,
+                      GridAttr::get(op->getContext(), {height, 1}))
+            .withMemoryLayout(op->getContext(),
+                              TensorMemoryLayout::HeightSharded));
   }

   // Width Sharded
   for (auto width = 2; width <= numCores; ++width) {
     shardedResults.push_back(
-        shardedBase.withGrid(op->getContext(), tensorType,
-                             GridAttr::get(op->getContext(), {1, width})));
+        shardedBase
+            .withGrid(op->getContext(), tensorType,
+                      GridAttr::get(op->getContext(), {1, width}))
+            .withMemoryLayout(op->getContext(),
+                              TensorMemoryLayout::WidthSharded));
   }

   // Filter layouts based on output tensor legality for current op.
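The pattern in this file is that every candidate LayoutAttr now carries an explicit TensorMemoryLayout in addition to its memory space and, for sharded candidates, its grid. The sketch below is illustrative only and not part of the commit: it assumes the builder calls exactly as they appear in the diff above (withMemorySpace, withGrid, withMemoryLayout, GridAttr::get) together with the surrounding file's includes and namespaces, and the helper name makeL1BlockShardedCandidate is hypothetical.

// Hypothetical helper (not in this commit) showing how the chained builders
// compose for a single block-sharded L1 candidate; assumes LayoutAttr,
// GridAttr, MemorySpace, and TensorMemoryLayout as used in the diff above.
static LayoutAttr makeL1BlockShardedCandidate(mlir::Operation *op,
                                              mlir::RankedTensorType tensorType,
                                              LayoutAttr layout, int64_t width,
                                              int64_t height) {
  // Move the tensor to L1, shard it over a width x height core grid, and
  // record the sharding scheme in the memory layout field.
  return layout.withMemorySpace(op->getContext(), MemorySpace::DeviceL1)
      .withGrid(op->getContext(), tensorType,
                GridAttr::get(op->getContext(), {width, height}))
      .withMemoryLayout(op->getContext(), TensorMemoryLayout::BlockSharded);
}

In the commit itself the L1 base layout is built once (shardedBase) and reused across all grid candidates, so only the grid and memory layout vary inside the loops.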
2 changes: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% enable-grid-set=false" %s > %t.mlir
 // RUN: FileCheck %s --input-file=%t.mlir
 // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
 #l1_block_sharded = #tt.operand_constraint<l1_block_sharded>
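As a rough, hypothetical sizing of the search space (the grid size is an assumption, not stated in the commit): on an 8x8 worker grid with numCores = 64, the loops in LegalGridAnalysis.cpp now generate 8 * 8 = 64 block-sharded candidates (the old bounds starting at 2 gave 7 * 7 = 49), plus 63 height-sharded and 63 width-sharded candidates, plus the DRAM and L1 interleaved layouts, for 64 + 63 + 63 + 2 = 192 candidate layouts before the legality filter runs. The test change only adds enable-grid-set=false to the existing pipeline invocation; the FileCheck and flatbuffer translation steps are unchanged.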
