V1 implementation for L1Interleaved policy (#1132)
fbajraktariTT authored Nov 4, 2024
1 parent e303453 commit 0e45561
Showing 3 changed files with 105 additions and 21 deletions.
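
In short: V0 kept a single L1ChainConfig per FuncOp, so a DRAM spill left a disconnected chain of L1 ops. V1 builds chains greedily during scheduling: an op is appended only if it has exactly one user, that user is a schedulable successor, and both ops' output tensors fit under tensorL1UsageCap (0.8) of the usable L1 size; otherwise the current chain is built and a fresh one is started.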
120 changes: 101 additions & 19 deletions lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp
@@ -4,47 +4,131 @@

#include "ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
#include "ttmlir/Scheduler/Scheduler.h"

namespace mlir::tt::ttnn {

uint64_t getOpOutputLayoutUsage(
Operation *op,
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> &legalLayouts,
DeviceAttr &deviceAttr) {
tt::LayoutAttr opLayout = legalLayouts.lookup(op).front();
assert(opLayout.hasInterleavedL1TensorMemoryLayout());

llvm::ArrayRef<int64_t> opOutputTensorShape =
mlir::cast<RankedTensorType>(op->getResult(0).getType()).getShape();

uint64_t opL1OutputUsage = deviceAttr.getLayoutSizeBytes(
opOutputTensorShape, opLayout, opLayout.getMemorySpace());
return opL1OutputUsage;
}

void L1InterleavedPolicy::run() {
rootOp->walk([&](func::FuncOp func) {
DeviceAttr deviceAttr = getCurrentScopeDevice(func);
mlir::tt::scheduler::Scheduler scheduler(&func);
llvm::SmallVector<mlir::Operation *> scheduleableOps;
Operation *currentOp = nullptr;
llvm::DenseMap<Operation *, tt::LayoutAttr> selectedOpLayout;
Operation *currentOp = nullptr;

// TODO(fbajraktari):
// This is V0 implementation of L1 interleaved policy. In the current
// implementation we have a single L1ChainCofig per FuncOp. This implies
// that in case of DRAM spill we will have a disconnected chain of L1 ops.
// This will be fixed in V1.
//
l1ChainConfigs->push_back(L1ChainConfig());
while (scheduler.hasUnscheduledOps()) {
scheduleableOps = scheduler.getScheduleableOps();
currentOp = scheduleableOps[0];

// Before starting a l1 chain, schedule layout/memory management ops
// first until they are exhausted from schedulable ops.
//
if (l1ChainConfigs->back().isEmpty()) {
for (auto *op : scheduleableOps) {
if (isa<ToLayoutOp>(op)) {
currentOp = op;
break;
}
}
}

if (currentOp == nullptr) {
currentOp = scheduleableOps[0];
}

// Schedule currentOp.
//
scheduler.scheduleOp(currentOp);

// Check if currentOp is valid l1 interleaved op.
// Skip starting sharding chain if currentOp is a memory management op.
//
if (legalLayouts.lookup(currentOp).size() > 0) {
selectedOpLayout[currentOp] = legalLayouts.lookup(currentOp).front();
if (l1ChainConfigs->back().isEmpty() && isa<ToLayoutOp>(currentOp)) {
currentOp = nullptr;
continue;
}

// Add currentOp to l1 chain config.
//
OpL1MemSpec shardSpec;
shardSpec.op = currentOp;
if (scheduler.hasUnscheduledOps()) {
scheduleableOps = scheduler.getScheduleableOps();

// Hardcoded tensor split factor for now, until pipeline OP
// support is added.
// Check if currentOp has a valid successor.
//
shardSpec.tensorSplitFactor = 1;
l1ChainConfigs->back().addOpL1MemSpec(std::move(shardSpec));
Operation *nextOp = nullptr;
for (auto *op : scheduleableOps) {
for (auto operand : op->getOperands()) {
if (operand.getDefiningOp() == currentOp) {
nextOp = op;
break;
}
}
}

if (nextOp) {

// V1: Check that currentOp is not fork/join op.
//
bool validForL1Interleaved =
currentOp->hasOneUse() &&
legalLayouts.lookup(currentOp).size() > 0 &&
legalLayouts.lookup(nextOp).size() > 0;

if (validForL1Interleaved) {
// Figure out this const based on exec data, but will be replaced
// with API.
//
constexpr float tensorL1UsageCap = 0.8;
uint64_t currentOpL1OutputUsage =
getOpOutputLayoutUsage(currentOp, legalLayouts, deviceAttr);
uint64_t nextOpL1OutputUsage =
getOpOutputLayoutUsage(nextOp, legalLayouts, deviceAttr);
bool l1UsageValid = (currentOpL1OutputUsage + nextOpL1OutputUsage) <
tensorL1UsageCap * usableL1CacheSize;

if (l1UsageValid) {
selectedOpLayout[currentOp] =
legalLayouts.lookup(currentOp).front();

// Add currentOp to l1 chain config.
//
OpL1MemSpec shardSpec;
shardSpec.op = currentOp;

// Hardcoded tensor split factor for now, until pipeline OP
// support is added.
//
shardSpec.tensorSplitFactor = 1;
l1ChainConfigs->back().addOpL1MemSpec(std::move(shardSpec));

// Update currentOp pointer.
//
currentOp = nextOp;
continue;
}
}
}

currentOp = nullptr;
if (!l1ChainConfigs->back().isEmpty()) {
l1ChainConfigs->back().build();
l1ChainConfigs->push_back(L1ChainConfig());
}
}
}

@@ -59,14 +143,12 @@ void L1InterleavedPolicy::run() {

     // Resolve l1 chain configs.
     //
     for (auto &l1ChainConfig : *l1ChainConfigs) {
-      l1ChainConfig.build();
       l1ChainConfig.resolve();

       std::unordered_set<Edge> memReconfigEdges;
       l1ChainConfig.complete(selectedOpLayout, memReconfigEdges);
     }
   });
-  llvm::errs() << "usableL1CacheSize: " << usableL1CacheSize << "\n";
 }

 } // namespace mlir::tt::ttnn
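
For intuition, the gate that decides whether a chain keeps growing reduces to a simple capacity test. The sketch below is a minimal, self-contained model of that test; the sizes are made up, standing in for what getOpOutputLayoutUsage computes via DeviceAttr::getLayoutSizeBytes, and usableL1CacheSize here is a local stand-in for the policy member of the same name.

#include <cstdint>
#include <iostream>

int main() {
  // Same constant as in the patch; noted there as a placeholder until a
  // proper API reports actual L1 availability.
  constexpr float tensorL1UsageCap = 0.8f;

  // Hypothetical values, standing in for the real queries.
  uint64_t usableL1CacheSize = 1024 * 1024;       // assumed 1 MiB L1 budget
  uint64_t currentOpL1OutputUsage = 64 * 96 * 2;  // e.g. a 64x96 bf16 output
  uint64_t nextOpL1OutputUsage = 64 * 32 * 2;     // e.g. a 64x32 bf16 output

  // The V1 gate: extend the L1 chain only while both outputs fit under the cap.
  bool l1UsageValid = (currentOpL1OutputUsage + nextOpL1OutputUsage) <
                      tensorL1UsageCap * usableL1CacheSize;
  std::cout << (l1UsageValid ? "extend L1 chain" : "close chain, spill to DRAM")
            << "\n";
  return 0;
}

With these numbers the two outputs need 16 KiB against an ~819 KiB cap, so the chain keeps growing; once the sum crosses the cap, the policy builds the current chain and starts a new one.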
3 changes: 2 additions & 1 deletion test/ttmlir/Silicon/TTNN/all_l1_interleaved_policy.mlir
@@ -7,6 +7,7 @@ module attributes {} {
   // CHECK: #[[L1_:.*]] = #tt.memory_space<l1>
   // CHECK: #[[LAYOUT_6:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x12xbf16, #l1_>, interleaved>
   // CHECK: #[[LAYOUT_7:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x4xbf16, #l1_>, interleaved>
+  // CHECK: #[[LAYOUT_8:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x4xbf16, #dram>, interleaved>
   %0 = tensor.empty() : tensor<64x96xbf16>
   // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_6]]>
   %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16>
@@ -23,7 +24,7 @@ module attributes {} {
   // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_7]]>
   %9 = "ttir.add"(%7, %arg4, %8) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x32xbf16>, tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16>
   %10 = tensor.empty() : tensor<64x32xbf16>
-  // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_7]]>
+  // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_8]]>
   %11 = "ttir.relu"(%9, %10) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16>
   return %11 : tensor<64x32xbf16>
 }
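
The relu here is the last op in the function, so it has no schedulable successor to pair with under the new validity check; its output is therefore expected in the DRAM interleaved layout (#[[LAYOUT_8]]) instead of L1 (#[[LAYOUT_7]]).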
3 changes: 2 additions & 1 deletion test/ttmlir/Silicon/TTNN/mnist_l1_interleaved.mlir
@@ -7,6 +7,7 @@ module @"tt-forge-graph" attributes {} {
   func.func @main(%arg0: tensor<1x784xf32> loc("MNISTLinear":4294967295:0), %arg1: tensor<1x10xf32> loc("MNISTLinear":4294967295:0), %arg2: tensor<256x10xf32> loc("MNISTLinear":4294967295:0), %arg3: tensor<1x256xf32> loc("MNISTLinear":4294967295:0), %arg4: tensor<784x256xf32> loc("MNISTLinear":4294967295:0)) -> tensor<1x10xf32> {
   // CHECK: #[[LAYOUT_6:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<1x32xf32, #l1_>, interleaved>
   // CHECK: #[[LAYOUT_7:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<1x2xf32, #l1_>, interleaved>
+  // CHECK: #[[LAYOUT_8:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<1x2xf32, #dram>, interleaved>
   %0 = tensor.empty() : tensor<1x256xf32> loc(#loc8)
   // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] -> tensor<1x256xf32, #[[LAYOUT_6]]>
   %1 = "ttir.matmul"(%arg0, %arg4, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8)
@@ -23,7 +24,7 @@ module @"tt-forge-graph" attributes {} {
   // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x10xf32, #[[LAYOUT_7]]>
   %9 = "ttir.add"(%7, %arg1, %8) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc12)
   %10 = tensor.empty() : tensor<1x10xf32> loc(#loc13)
-  // CHECK: %{{.*}} = "ttnn.softmax"{{.*}} -> tensor<1x10xf32, #[[LAYOUT_7]]>
+  // CHECK: %{{.*}} = "ttnn.softmax"{{.*}} -> tensor<1x10xf32, #[[LAYOUT_8]]>
   %11 = "ttir.softmax"(%9, %10) <{dimension = 1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc13)
   return %11 : tensor<1x10xf32> loc(#loc7)
 } loc(#loc)
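
Same pattern as above: softmax terminates the chain, so its expected output layout moves from L1 (#[[LAYOUT_7]]) to DRAM interleaved (#[[LAYOUT_8]]).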
