Limit number of sharded layouts
odjuricicTT committed Aug 29, 2024
1 parent 5e196c4 commit eb0c300
Showing 1 changed file with 41 additions and 30 deletions.
71 changes: 41 additions & 30 deletions lib/Dialect/TTIR/Analysis/LegalGridAnalysis.cpp
@@ -3,7 +3,6 @@
 // SPDX-License-Identifier: Apache-2.0
 
 #include "ttmlir/Dialect/TTIR/Analysis/LegalGridAnalysis.h"
-#include <mlir/IR/Operation.h>
 
 namespace mlir::tt::ttir {
 
@@ -15,24 +14,23 @@ bool mock_is_output_tensor_legal_for_op(Operation *op, LayoutAttr layout) {
 bool tensor_shape_compatible_with_shard(Operation *op, LayoutAttr layout) {
   // These constraints are implemented separately in every TTNN op.
   // Almost nothing seems to be shared between EVERY op, so it is hard to have any
-  // logic here without the risk of discarding a valid configuration.
-  // This logic may be offloaded to the TTNN op interface.
-
-  // For now we will check if the tilized tensor dims are divisible by the grid dims.
-  // This will definitely discard some valid configurations, but it is a start.
+  // logic here without the risk of discarding a valid configuration or
+  // modeling the constraint for each op. This logic may be offloaded to the
+  // TTNN op interface.
+
+  // For now we will check if the tilized tensor dims are divisible by the grid
+  // dims. This will definitely discard some valid configurations, but it is a
+  // start.
   RankedTensorType tensorType =
       mlir::cast<RankedTensorType>(op->getResult(0).getType());
   llvm::ArrayRef<int64_t> tensorShape = tensorType.getShape();
 
   long MTiles = 1;
-  if (tensorType.getRank() == 2) {
-    MTiles = (tensorShape[0] + 31) / 32;
-  } else if (tensorType.getRank() > 2) {
-    MTiles = (tensorShape[-2] * tensorShape[-3] + 31) / 32;
+  if (tensorType.getRank() >= 2) {
+    MTiles = (tensorShape.rbegin()[1] + 31) / 32;
   }
 
-  auto KTiles = (tensorShape[-1] + 31) / 32;
+  auto KTiles = (tensorShape.back() + 31) / 32;
 
   auto gridR = layout.getGrid().getShape()[0];
   auto gridC = layout.getGrid().getShape()[1];
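For a concrete sense of the arithmetic above: tile counts come from ceiling division by the 32x32 tile size, and the visible code suggests a layout passes only if the grid dims divide those tile counts evenly. The function's return statement is cut off in this hunk, so that predicate, and the helper below, are assumptions. A minimal standalone sketch:

#include <cassert>
#include <cstdint>

// Ceiling division by the 32-wide tile size, mirroring (dim + 31) / 32 above.
std::int64_t tilesOf(std::int64_t dim) { return (dim + 31) / 32; }

int main() {
  // A 128x256 tensor tilizes to 4x8 tiles.
  std::int64_t mTiles = tilesOf(128); // 4
  std::int64_t kTiles = tilesOf(256); // 8
  // A 2x4 grid divides both tile dims evenly, so it would be kept...
  assert(mTiles % 2 == 0 && kTiles % 4 == 0);
  // ...while a 3x3 grid does not, so that layout would be discarded.
  assert(!(mTiles % 3 == 0 && kTiles % 3 == 0));
  return 0;
}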
@@ -45,6 +43,7 @@ bool LegalGridAnalysis::applyOverrides() {
   // operation.
   //
 
+  // TODO(odjuricic): We may need to infer the shard type.
   RankedTensorType tensorType =
       mlir::cast<RankedTensorType>(op->getResult(0).getType());
   LayoutAttr layout = mlir::cast<LayoutAttr>(tensorType.getEncoding());
@@ -75,32 +74,35 @@ void LegalGridAnalysis::analysisImplementation() {
   // other components.
 
   // Get output tensor type.
-  // TODO: This ignores multiple outputs...?
   RankedTensorType tensorType =
       mlir::cast<RankedTensorType>(op->getResult(0).getType());
   LayoutAttr layout = mlir::cast<LayoutAttr>(tensorType.getEncoding());
 
   // DRAM
   // No grid is set since the tensor is not sharded.
-  // TODO: Is this a viable solution or should we have a grid?
   LayoutAttr dram =
       layout.withMemorySpace(op->getContext(), MemorySpace::DeviceDRAM);
-  analysisResult.push_back(dram);
+  if (mock_is_output_tensor_legal_for_op(op, dram)) {
+    analysisResult.push_back(dram);
+  }
 
-  // L1 Interleaved (same as above)
+  // L1 Interleaved (same as above).
   LayoutAttr l1Interleaved =
       layout.withMemorySpace(op->getContext(), MemorySpace::DeviceL1);
-  analysisResult.push_back(l1Interleaved);
+  if (mock_is_output_tensor_legal_for_op(op, l1Interleaved)) {
+    analysisResult.push_back(l1Interleaved);
+  }
 
   // L1 Sharded
   LayoutAttr shardedBase =
       layout.withMemorySpace(op->getContext(), MemorySpace::DeviceL1);
+  std::vector<LayoutAttr> shardedResults;
 
   // Block Sharded
   for (auto width = 2; width <= analysisInput.maxGrid.getShape()[0]; ++width) {
     for (auto height = 2; height <= analysisInput.maxGrid.getShape()[1];
          ++height) {
-      analysisResult.push_back(shardedBase.withGrid(
+      shardedResults.push_back(shardedBase.withGrid(
           op->getContext(), tensorType,
           GridAttr::get(op->getContext(), {width, height})));
     }
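The nested loops above enumerate every block-sharded grid from 2x2 up to the full worker grid. Assuming an 8x8 maxGrid (illustrative; the real value comes from analysisInput), that alone is 7 * 7 = 49 candidate layouts, and the height- and width-sharded loops in the next hunk each add up to numCores - 1 = 63 more, roughly 175 candidates per op before filtering, which is the growth this commit caps. A standalone sketch of the count:

#include <iostream>

int main() {
  // Stand-ins for analysisInput.maxGrid.getShape(); 8x8 is an assumption.
  const int maxGridR = 8, maxGridC = 8;

  int blockSharded = 0;
  for (int width = 2; width <= maxGridR; ++width)
    for (int height = 2; height <= maxGridC; ++height)
      ++blockSharded; // one candidate LayoutAttr per (width, height) pair

  std::cout << blockSharded << "\n"; // prints 49
  return 0;
}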
@@ -109,32 +111,41 @@
   auto numCores =
       analysisInput.maxGrid.getShape()[0] * analysisInput.maxGrid.getShape()[1];
   // Height Sharded
-  // TODO: Missing affine mapping to actual grid.
-  // TODO: Can we have every shape of 1d grid? Probably not, need to check what
-  // is divisible by grid sides.
-  // TODO: Limit the number of options to some reasonable number.
-  // TODO: Put all of this into the same loop.
+  // TODO(odjuricic): Missing affine mapping to the actual grid. Need to check
+  // with the runtime implementation on what to produce here.
   for (auto height = 2; height <= numCores; ++height) {
-    analysisResult.push_back(
+    shardedResults.push_back(
         shardedBase.withGrid(op->getContext(), tensorType,
                              GridAttr::get(op->getContext(), {height, 1})));
   }
 
   // Width Sharded
   for (auto width = 2; width <= numCores; ++width) {
-    analysisResult.push_back(
+    shardedResults.push_back(
         shardedBase.withGrid(op->getContext(), tensorType,
                              GridAttr::get(op->getContext(), {1, width})));
   }
 
   // Filter layouts based on output tensor legality for the current op.
-  analysisResult.erase(
-      std::remove_if(analysisResult.begin(), analysisResult.end(),
+  shardedResults.erase(
+      std::remove_if(shardedResults.begin(), shardedResults.end(),
                      [this](LayoutAttr layout) {
-                       return !mock_is_output_tensor_legal_for_op(op, layout);
+                       return !tensor_shape_compatible_with_shard(op, layout) ||
+                              !mock_is_output_tensor_legal_for_op(op, layout);
                      }),
-      analysisResult.end());
+      shardedResults.end());
 
-  // TODO: Potentially filter out tensors that don't fit into L1 at all.
+  // Pick the largest sharded grids.
+  int MAX_SHARDED_GRIDS = 50;
+  std::sort(shardedResults.begin(), shardedResults.end(),
+            [](LayoutAttr a, LayoutAttr b) {
+              return a.getGrid().getShape()[0] * a.getGrid().getShape()[1] >
+                     b.getGrid().getShape()[0] * b.getGrid().getShape()[1];
+            });
+
+  analysisResult.insert(
+      analysisResult.end(), shardedResults.begin(),
+      shardedResults.begin() +
+          std::min(MAX_SHARDED_GRIDS, static_cast<int>(shardedResults.size())));
 }
 
 } // namespace mlir::tt::ttir
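The new tail of analysisImplementation sorts the surviving sharded layouts by core count (grid rows times columns, descending) and appends at most MAX_SHARDED_GRIDS of them to analysisResult; this is the limit the commit title refers to. A minimal standalone sketch of that select-top-N step, with (rows, cols) pairs standing in for LayoutAttr:

#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

int main() {
  using Grid = std::pair<int, int>; // (rows, cols) stand-in for GridAttr
  std::vector<Grid> shardedResults = {{2, 2}, {8, 8}, {1, 63}, {4, 4}};
  const std::size_t maxShardedGrids = 2; // the commit uses 50

  // Larger core count first, matching the comparator in the commit.
  std::sort(shardedResults.begin(), shardedResults.end(),
            [](Grid a, Grid b) {
              return a.first * a.second > b.first * b.second;
            });
  shardedResults.resize(std::min(maxShardedGrids, shardedResults.size()));
  // shardedResults now holds {8, 8} (64 cores) and {1, 63} (63 cores).
  return 0;
}

Since only the top N are kept, std::partial_sort over the first N elements would avoid ordering the discarded tail; the full std::sort used in the commit is simpler, and the candidate lists are small.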
