diff --git a/lib/Dialect/TTIR/Analysis/LegalGridAnalysis.cpp b/lib/Dialect/TTIR/Analysis/LegalGridAnalysis.cpp index 833dcac81c..5bd2851e31 100644 --- a/lib/Dialect/TTIR/Analysis/LegalGridAnalysis.cpp +++ b/lib/Dialect/TTIR/Analysis/LegalGridAnalysis.cpp @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 #include "ttmlir/Dialect/TTIR/Analysis/LegalGridAnalysis.h" -#include namespace mlir::tt::ttir { @@ -15,24 +14,23 @@ bool mock_is_output_tensor_legal_for_op(Operation *op, LayoutAttr layout) { bool tensor_shape_compatible_with_shard(Operation *op, LayoutAttr layout) { // These constraints are implemented seteratelly in every TTNN op. // Almost nothing seems to be shared between EVERY op, so is hard to have any - // logic here without the risk of discarding a valid configuraiton. - // This logic may be offloaded to the TTNN op interface. - - // For now we will check if the tilesed tensor dims are divisible by the grid dims. - // This will definitly discard possible valid configurations, but is a start. + // logic here without the risk of discarding a valid configuraiton or modeling + // the constraint for each op. This logic may be offloaded to the TTNN op + // interface. + // For now we will check if the tilised tensor dims are divisible by the grid + // dims. This will definitly discard possible valid configurations, but is a + // start. RankedTensorType tensorType = mlir::cast(op->getResult(0).getType()); llvm::ArrayRef tensorShape = tensorType.getShape(); long MTiles = 1; - if (tensorType.getRank() == 2) { - MTiles = (tensorShape[0] + 31) / 32; - } else if (tensorType.getRank() > 2) { - MTiles = (tensorShape[-2] * tensorShape[-3] + 31) / 32; + if (tensorType.getRank() >= 2) { + MTiles = (tensorShape.rbegin()[1] + 31) / 32; } - auto KTIles = (tensorShape[-1] + 31) / 32; + auto KTIles = (tensorShape.back() + 31) / 32; auto gridR = layout.getGrid().getShape()[0]; auto gridC = layout.getGrid().getShape()[1]; @@ -45,6 +43,7 @@ bool LegalGridAnalysis::applyOverrides() { // operation. // + // TODO(odjuricic): We may need to infer shard type. RankedTensorType tensorType = mlir::cast(op->getResult(0).getType()); LayoutAttr layout = mlir::cast(tensorType.getEncoding()); @@ -75,32 +74,35 @@ void LegalGridAnalysis::analysisImplementation() { // other components. // Get output tensor type. - // TODO: This ignores multiple outputs...? RankedTensorType tensorType = mlir::cast(op->getResult(0).getType()); LayoutAttr layout = mlir::cast(tensorType.getEncoding()); // DRAM // No grid is set since the tensor is not sharded. - // TODO: Is this a viable solution or should we have a grid? LayoutAttr dram = layout.withMemorySpace(op->getContext(), MemorySpace::DeviceDRAM); - analysisResult.push_back(dram); + if (mock_is_output_tensor_legal_for_op(op, dram)) { + analysisResult.push_back(dram); + } - // L1 Interleaved (same as above) + // L1 Interleaved (same as above). LayoutAttr l1Interleaved = layout.withMemorySpace(op->getContext(), MemorySpace::DeviceL1); - analysisResult.push_back(l1Interleaved); + if (mock_is_output_tensor_legal_for_op(op, l1Interleaved)) { + analysisResult.push_back(l1Interleaved); + } // L1 Sharded LayoutAttr shardedBase = layout.withMemorySpace(op->getContext(), MemorySpace::DeviceL1); + std::vector shardedResults; // Block Sharded for (auto width = 2; width <= analysisInput.maxGrid.getShape()[0]; ++width) { for (auto height = 2; height <= analysisInput.maxGrid.getShape()[1]; ++height) { - analysisResult.push_back(shardedBase.withGrid( + shardedResults.push_back(shardedBase.withGrid( op->getContext(), tensorType, GridAttr::get(op->getContext(), {width, height}))); } @@ -109,32 +111,41 @@ void LegalGridAnalysis::analysisImplementation() { auto numCores = analysisInput.maxGrid.getShape()[0] * analysisInput.maxGrid.getShape()[1]; // Height Sharded - // TODO: Missing affine mapping to actual grid. - // TODO: Can we have every shape of 1d grid? Probably not, need to check what - // is divisible by grid sides. - // TODO: Limit the number of options to some reasonable number. - // TODO: Put all of this into the same loop. + // TODO(odjuricic): Missing affine mapping to actual grid. Need to check with + // runtime implementation on what to produce here. for (auto height = 2; height <= numCores; ++height) { - analysisResult.push_back( + shardedResults.push_back( shardedBase.withGrid(op->getContext(), tensorType, GridAttr::get(op->getContext(), {height, 1}))); } // Width Sharded for (auto width = 2; width <= numCores; ++width) { - analysisResult.push_back( + shardedResults.push_back( shardedBase.withGrid(op->getContext(), tensorType, GridAttr::get(op->getContext(), {1, width}))); } // Filter layouts based on output tensor legality for current op. - analysisResult.erase( - std::remove_if(analysisResult.begin(), analysisResult.end(), + shardedResults.erase( + std::remove_if(shardedResults.begin(), shardedResults.end(), [this](LayoutAttr layout) { - return !mock_is_output_tensor_legal_for_op(op, layout); + return !tensor_shape_compatible_with_shard(op, layout) || + !mock_is_output_tensor_legal_for_op(op, layout); }), - analysisResult.end()); - - // TODO: Potetialy filter out tensors that dont fit into L1 at all. + shardedResults.end()); + + // Pick top largest sharded grids. + int MAX_SHARDED_GRIDS = 50; + std::sort(shardedResults.begin(), shardedResults.end(), + [](LayoutAttr a, LayoutAttr b) { + return a.getGrid().getShape()[0] * a.getGrid().getShape()[1] > + b.getGrid().getShape()[0] * b.getGrid().getShape()[1]; + }); + + analysisResult.insert( + analysisResult.end(), shardedResults.begin(), + shardedResults.begin() + + std::min(MAX_SHARDED_GRIDS, static_cast(shardedResults.size()))); } } // namespace mlir::tt::ttir