diff --git a/lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp b/lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp
index 3b0885741..fa48ea356 100644
--- a/lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp
+++ b/lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp
@@ -4,47 +4,131 @@
 #include "ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h"
 #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
+#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
 #include "ttmlir/Scheduler/Scheduler.h"
 
 namespace mlir::tt::ttnn {
 
+uint64_t getOpOutputLayoutUsage(
+    Operation *op,
+    llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> &legalLayouts,
+    DeviceAttr &deviceAttr) {
+  tt::LayoutAttr opLayout = legalLayouts.lookup(op).front();
+  assert(opLayout.hasInterleavedL1TensorMemoryLayout());
+
+  llvm::ArrayRef<int64_t> opOutputTensorShape =
+      mlir::cast<RankedTensorType>(op->getResult(0).getType()).getShape();
+
+  uint64_t opL1OutputUsage = deviceAttr.getLayoutSizeBytes(
+      opOutputTensorShape, opLayout, opLayout.getMemorySpace());
+  return opL1OutputUsage;
+}
+
 void L1InterleavedPolicy::run() {
   rootOp->walk([&](func::FuncOp func) {
+    DeviceAttr deviceAttr = getCurrentScopeDevice(func);
     mlir::tt::scheduler::Scheduler scheduler(&func);
     llvm::SmallVector<Operation *> scheduleableOps;
-    Operation *currentOp = nullptr;
     llvm::DenseMap<Operation *, tt::LayoutAttr> selectedOpLayout;
+    Operation *currentOp = nullptr;
 
     // TODO(fbajraktari):
-    // This is V0 implementation of L1 interleaved policy. In the current
-    // implementation we have a single L1ChainCofig per FuncOp. This implies
-    // that in case of DRAM spill we will have a disconnected chain of L1 ops.
-    // This will be fixed in V1.
     //
     l1ChainConfigs->push_back(L1ChainConfig());
     while (scheduler.hasUnscheduledOps()) {
       scheduleableOps = scheduler.getScheduleableOps();
-      currentOp = scheduleableOps[0];
+
+      // Before starting a l1 chain, schedule layout/memory management ops
+      // first until they are exhausted from schedulable ops.
+      //
+      if (l1ChainConfigs->back().isEmpty()) {
+        for (auto *op : scheduleableOps) {
+          if (isa<ToLayoutOp>(op)) {
+            currentOp = op;
+            break;
+          }
+        }
+      }
+
+      if (currentOp == nullptr) {
+        currentOp = scheduleableOps[0];
+      }
 
       // Schedule currentOp.
       //
       scheduler.scheduleOp(currentOp);
 
-      // Check if currentOp is valid l1 interleaved op.
+      // Skip starting sharding chain if currentOp is a memory management op.
       //
-      if (legalLayouts.lookup(currentOp).size() > 0) {
-        selectedOpLayout[currentOp] = legalLayouts.lookup(currentOp).front();
+      if (l1ChainConfigs->back().isEmpty() && isa<ToLayoutOp>(currentOp)) {
+        currentOp = nullptr;
+        continue;
+      }
 
-        // Add currentOp to l1 chain config.
-        //
-        OpL1MemSpec shardSpec;
-        shardSpec.op = currentOp;
+      if (scheduler.hasUnscheduledOps()) {
+        scheduleableOps = scheduler.getScheduleableOps();
 
-        // Hardcoded tensor split factor for now, until pipeline OP
-        // support is added.
+        // Check if currentOp has a valid successor.
         //
-        shardSpec.tensorSplitFactor = 1;
-        l1ChainConfigs->back().addOpL1MemSpec(std::move(shardSpec));
+        Operation *nextOp = nullptr;
+        for (auto *op : scheduleableOps) {
+          for (auto operand : op->getOperands()) {
+            if (operand.getDefiningOp() == currentOp) {
+              nextOp = op;
+              break;
+            }
+          }
+        }
+
+        if (nextOp) {
+
+          // V1: Check that currentOp is not fork/join op.
+          //
+          bool validForL1Interleaved =
+              currentOp->hasOneUse() &&
+              legalLayouts.lookup(currentOp).size() > 0 &&
+              legalLayouts.lookup(nextOp).size() > 0;
+
+          if (validForL1Interleaved) {
+            // Figure out this const based on exec data, but will be replaced
+            // with API.
+            //
+            constexpr float tensorL1UsageCap = 0.8;
+            uint64_t currentOpL1OutputUsage =
+                getOpOutputLayoutUsage(currentOp, legalLayouts, deviceAttr);
+            uint64_t nextOpL1OutputUsage =
+                getOpOutputLayoutUsage(nextOp, legalLayouts, deviceAttr);
+            bool l1UsageValid = (currentOpL1OutputUsage + nextOpL1OutputUsage) <
+                                tensorL1UsageCap * usableL1CacheSize;
+
+            if (l1UsageValid) {
+              selectedOpLayout[currentOp] =
+                  legalLayouts.lookup(currentOp).front();
+
+              // Add currentOp to l1 chain config.
+              //
+              OpL1MemSpec shardSpec;
+              shardSpec.op = currentOp;
+
+              // Hardcoded tensor split factor for now, until pipeline OP
+              // support is added.
+              //
+              shardSpec.tensorSplitFactor = 1;
+              l1ChainConfigs->back().addOpL1MemSpec(std::move(shardSpec));
+
+              // Update currentOp pointer.
+              //
+              currentOp = nextOp;
+              continue;
+            }
+          }
+        }
+
+        currentOp = nullptr;
+        if (!l1ChainConfigs->back().isEmpty()) {
+          l1ChainConfigs->back().build();
+          l1ChainConfigs->push_back(L1ChainConfig());
+        }
       }
     }
 
@@ -59,14 +143,12 @@ void L1InterleavedPolicy::run() {
     // Resolve l1 chain configs.
     //
     for (auto &l1ChainConfig : *l1ChainConfigs) {
-      l1ChainConfig.build();
       l1ChainConfig.resolve();
 
       std::unordered_set<Edge> memReconfigEdges;
      l1ChainConfig.complete(selectedOpLayout, memReconfigEdges);
     }
   });
-  llvm::errs() << "usableL1CacheSize: " << usableL1CacheSize << "\n";
 }
 
 } // namespace mlir::tt::ttnn
diff --git a/test/ttmlir/Silicon/TTNN/all_l1_interleaved_policy.mlir b/test/ttmlir/Silicon/TTNN/all_l1_interleaved_policy.mlir
index d90d93e64..f33831c29 100644
--- a/test/ttmlir/Silicon/TTNN/all_l1_interleaved_policy.mlir
+++ b/test/ttmlir/Silicon/TTNN/all_l1_interleaved_policy.mlir
@@ -7,6 +7,7 @@ module attributes {} {
     // CHECK: #[[L1_:.*]] = #tt.memory_space<l1>
     // CHECK: #[[LAYOUT_6:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x12xbf16, #l1_>, interleaved>
     // CHECK: #[[LAYOUT_7:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x4xbf16, #l1_>, interleaved>
+    // CHECK: #[[LAYOUT_8:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x4xbf16, #dram>, interleaved>
     %0 = tensor.empty() : tensor<64x96xbf16>
     // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_6]]>
     %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16>
@@ -23,7 +24,7 @@ module attributes {} {
     // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_7]]>
     %9 = "ttir.add"(%7, %arg4, %8) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x32xbf16>, tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16>
     %10 = tensor.empty() : tensor<64x32xbf16>
-    // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_7]]>
+    // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_8]]>
     %11 = "ttir.relu"(%9, %10) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16>
     return %11 : tensor<64x32xbf16>
   }
diff --git a/test/ttmlir/Silicon/TTNN/mnist_l1_interleaved.mlir b/test/ttmlir/Silicon/TTNN/mnist_l1_interleaved.mlir
index 87b96ab35..923a4f209 100644
--- a/test/ttmlir/Silicon/TTNN/mnist_l1_interleaved.mlir
+++ b/test/ttmlir/Silicon/TTNN/mnist_l1_interleaved.mlir
@@ -7,6 +7,7 @@ module @"tt-forge-graph" attributes {} {
   func.func @main(%arg0: tensor<1x784xf32> loc("MNISTLinear":4294967295:0), %arg1: tensor<1x10xf32> loc("MNISTLinear":4294967295:0), %arg2: tensor<256x10xf32> loc("MNISTLinear":4294967295:0), %arg3: tensor<1x256xf32> loc("MNISTLinear":4294967295:0), %arg4: tensor<784x256xf32> loc("MNISTLinear":4294967295:0)) -> tensor<1x10xf32> {
     // CHECK: #[[LAYOUT_6:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<1x32xf32, #l1_>, interleaved>
     // CHECK: #[[LAYOUT_7:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<1x2xf32, #l1_>, interleaved>
+    // CHECK: #[[LAYOUT_8:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<1x2xf32, #dram>, interleaved>
     %0 = tensor.empty() : tensor<1x256xf32> loc(#loc8)
     // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] -> tensor<1x256xf32, #[[LAYOUT_6]]>
     %1 = "ttir.matmul"(%arg0, %arg4, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8)
@@ -23,7 +24,7 @@ module @"tt-forge-graph" attributes {} {
     // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x10xf32, #[[LAYOUT_7]]>
     %9 = "ttir.add"(%7, %arg1, %8) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc12)
     %10 = tensor.empty() : tensor<1x10xf32> loc(#loc13)
-    // CHECK: %{{.*}} = "ttnn.softmax"{{.*}} -> tensor<1x10xf32, #[[LAYOUT_7]]>
+    // CHECK: %{{.*}} = "ttnn.softmax"{{.*}} -> tensor<1x10xf32, #[[LAYOUT_8]]>
     %11 = "ttir.softmax"(%9, %10) <{dimension = 1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc13)
     return %11 : tensor<1x10xf32> loc(#loc7)
   } loc(#loc)