V1 implementation for L1Interleaved policy & full pipeline silicon tests
fbajraktariTT committed Nov 1, 2024
1 parent 34a2f4e commit c1276f9
Showing 3 changed files with 160 additions and 98 deletions.
120 changes: 101 additions & 19 deletions lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp
@@ -4,48 +4,132 @@

#include "ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
#include "ttmlir/Scheduler/Scheduler.h"

namespace mlir::tt::ttnn {

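// Returns the size in bytes of op's output tensor under its first legal
// layout, which is expected to be L1 interleaved.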
uint64_t getOpOutputLayoutUsage(
Operation *op,
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> &legalLayouts,
DeviceAttr &deviceAttr) {
tt::LayoutAttr opLayout = legalLayouts.lookup(op).front();
assert(opLayout.hasInterleavedL1TensorMemoryLayout());

llvm::ArrayRef<int64_t> opOutputTensorShape =
mlir::cast<RankedTensorType>(op->getResult(0).getType()).getShape();

uint64_t opL1OutputUsage = deviceAttr.getLayoutSizeBytes(
opOutputTensorShape, opLayout, opLayout.getMemorySpace());
return opL1OutputUsage;
}

void L1InterleavedPolicy::run(
const std::unordered_set<Edge> &overrideReshardEdges) {
rootOp->walk([&](func::FuncOp func) {
DeviceAttr deviceAttr = getCurrentScopeDevice(func);
mlir::tt::scheduler::Scheduler scheduler(&func);
llvm::SmallVector<mlir::Operation *> scheduleableOps;
Operation *currentOp = nullptr;
llvm::DenseMap<Operation *, tt::LayoutAttr> selectedOpLayout;
Operation *currentOp = nullptr;

// TODO(fbajraktari):
// This is the V0 implementation of the L1 interleaved policy. In the current
// implementation we have a single L1ChainConfig per FuncOp. This implies
// that in case of a DRAM spill we will have a disconnected chain of L1 ops.
// This will be fixed in V1.
//
l1ChainConfigs->push_back(L1ChainConfig());
while (scheduler.hasUnscheduledOps()) {
scheduleableOps = scheduler.getScheduleableOps();
currentOp = scheduleableOps[0];

// Before starting an L1 chain, schedule layout/memory management ops
// first until they are exhausted from the schedulable ops.
//
if (l1ChainConfigs->back().isEmpty()) {
for (auto *op : scheduleableOps) {
if (isa<ToLayoutOp>(op)) {
currentOp = op;
break;
}
}
}

if (currentOp == nullptr) {
currentOp = scheduleableOps[0];
}

// Schedule currentOp.
//
scheduler.scheduleOp(currentOp);

// Check if currentOp is a valid L1 interleaved op.
// Skip starting a sharding chain if currentOp is a memory management op.
//
if (legalLayouts.lookup(currentOp).size() > 0) {
selectedOpLayout[currentOp] = legalLayouts.lookup(currentOp).front();
if (l1ChainConfigs->back().isEmpty() && isa<ToLayoutOp>(currentOp)) {
currentOp = nullptr;
continue;
}

// Add currentOp to l1 chain config.
//
OpL1MemSpec shardSpec;
shardSpec.op = currentOp;
if (scheduler.hasUnscheduledOps()) {
scheduleableOps = scheduler.getScheduleableOps();

// Hardcoded tensor split factor for now, until pipeline OP
// support is added.
// Check if currentOp has a valid successor.
//
shardSpec.tensorSplitFactor = 1;
l1ChainConfigs->back().addOpL1MemSpec(std::move(shardSpec));
Operation *nextOp = nullptr;
for (auto *op : scheduleableOps) {
for (auto operand : op->getOperands()) {
if (operand.getDefiningOp() == currentOp) {
nextOp = op;
break;
}
}
}

if (nextOp) {

// V1: Check that currentOp is not a fork/join op.
//
bool validForL1Interleaved =
currentOp->hasOneUse() &&
legalLayouts.lookup(currentOp).size() > 0 &&
legalLayouts.lookup(nextOp).size() > 0;

if (validForL1Interleaved) {
// This constant is derived from execution data for now; it will be
// replaced with an API.
//
constexpr float tensorL1UsageCap = 0.8;
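// Grow the chain through nextOp only if both op outputs fit in L1
// simultaneously while staying under tensorL1UsageCap of the usable L1 size.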
uint64_t currentOpL1OutputUsage =
getOpOutputLayoutUsage(currentOp, legalLayouts, deviceAttr);
uint64_t nextOpL1OutputUsage =
getOpOutputLayoutUsage(nextOp, legalLayouts, deviceAttr);
bool l1UsageValid = (currentOpL1OutputUsage + nextOpL1OutputUsage) <
tensorL1UsageCap * usableL1CacheSize;

if (l1UsageValid) {
selectedOpLayout[currentOp] =
legalLayouts.lookup(currentOp).front();

// Add currentOp to l1 chain config.
//
OpL1MemSpec shardSpec;
shardSpec.op = currentOp;

// Hardcoded tensor split factor for now, until pipeline OP
// support is added.
//
shardSpec.tensorSplitFactor = 1;
l1ChainConfigs->back().addOpL1MemSpec(std::move(shardSpec));

// Update currentOp pointer.
//
currentOp = nextOp;
continue;
}
}
}

currentOp = nullptr;
if (!l1ChainConfigs->back().isEmpty()) {
l1ChainConfigs->back().build();
l1ChainConfigs->push_back(L1ChainConfig());
}
}
}

@@ -60,14 +144,12 @@ void L1InterleavedPolicy::run(
// Resolve l1 chain configs.
//
for (auto &l1ChainConfig : *l1ChainConfigs) {
l1ChainConfig.build();
l1ChainConfig.resolve();

std::unordered_set<Edge> memReconfigEdges;
l1ChainConfig.complete(selectedOpLayout, memReconfigEdges);
}
});
llvm::errs() << "usableL1CacheSize: " << usableL1CacheSize << "\n";
}

} // namespace mlir::tt::ttnn
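For readers skimming the diff, the acceptance test for extending an L1 chain boils down to a single arithmetic check. Below is a minimal standalone sketch of that check, not part of the commit: canExtendL1Chain is a hypothetical helper, the byte sizes are illustrative, and usableL1CacheSize is taken from the l1_size value in the test's system descriptor; in the real pass these come from getOpOutputLayoutUsage and the device attributes.

// Minimal sketch of the V1 chain-growth check (hypothetical helper and
// sizes; mirrors the l1UsageValid logic in L1InterleavedPolicy::run).
#include <cstdint>
#include <iostream>

// Both op outputs must fit in L1 at the same time, with tensorL1UsageCap
// leaving 20% of the budget free.
static bool canExtendL1Chain(uint64_t currentOpL1OutputUsage,
                             uint64_t nextOpL1OutputUsage,
                             uint64_t usableL1CacheSize) {
  constexpr float tensorL1UsageCap = 0.8f;
  return (currentOpL1OutputUsage + nextOpL1OutputUsage) <
         tensorL1UsageCap * usableL1CacheSize;
}

int main() {
  // Illustrative bf16 tensors from the test below: 64x96 and 64x32,
  // 2 bytes per element, against the 1499136-byte l1_size in the system
  // descriptor (reserved regions ignored for simplicity).
  uint64_t currentOpUsage = 64ull * 96 * 2; // 12288 bytes
  uint64_t nextOpUsage = 64ull * 32 * 2;    //  4096 bytes
  uint64_t usableL1CacheSize = 1499136;
  std::cout << std::boolalpha
            << canExtendL1Chain(currentOpUsage, nextOpUsage, usableL1CacheSize)
            << "\n"; // true -> the chain keeps growing
  return 0;
}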
55 changes: 19 additions & 36 deletions test/ttmlir/Silicon/TTNN/all_l1_interleaved_policy.mlir
@@ -1,48 +1,31 @@
// RUN: ttmlir-opt --ttir-load-system-desc="path=%system_desc_path%" --ttnn-optimizer="memory-layout-analysis-enabled=true memory-layout-analysis-policy=L1Interleaved" --ttnn-decompose-layouts --ttnn-deallocate %s > %t.mlir
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% enable-optimizer=true memory-layout-analysis-enabled=true memory-layout-analysis-policy=L1Interleaved" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
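// The L1Interleaved policy is expected to keep the outputs of the
// matmul/add/relu chain in L1 interleaved layouts (see the CHECK-captured
// #[[LAYOUT_...]] patterns below).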
#device = #tt.device<workerGrid = #tt.grid<8x8, (d0, d1) -> (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]>
#dram = #tt.memory_space<dram>
#system = #tt.memory_space<system>
#system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 1024, erisc_l1_unreserved_base = 1024, dram_unreserved_base = 1024, dram_unreserved_end = 1073741824, physical_cores = {worker = [ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 1x0, 1x1, 1x2, 1x3, 1x4, 1x5, 1x6, 1x7, 2x0, 2x1, 2x2, 2x3, 2x4, 2x5, 2x6, 2x7, 3x0, 3x1, 3x2, 3x3, 3x4, 3x5, 3x6, 3x7, 4x0, 4x1, 4x2, 4x3, 4x4, 4x5, 4x6, 4x7, 5x0, 5x1, 5x2, 5x3, 5x4, 5x5, 5x6, 5x7, 6x0, 6x1, 6x2, 6x3, 6x4, 6x5, 6x6, 6x7, 7x0, 7x1, 7x2, 7x3, 7x4, 7x5, 7x6, 7x7] dram = [ 8x0, 9x0, 10x0, 8x1, 9x1, 10x1, 8x2, 9x2, 10x2, 8x3, 9x3, 10x3]}, supported_data_types = [<f32>, <f16>, <bf16>, <bfp_f8>, <bfp_bf8>, <bfp_f4>, <bfp_bf4>, <bfp_f2>, <bfp_bf2>, <u32>, <u16>, <u8>], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 32}], [0], [3 : i32], [ 0x0x0x0]>
#layout = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xbf16, #system>>
#layout1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<128x96xbf16, #system>>
#layout2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x96xbf16, #system>>
#layout3 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<96x32xbf16, #system>>
#layout4 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x32xbf16, #system>>
#layout5 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x1x!tt.tile<32x32, bf16>, #dram>, interleaved>
#layout6 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x96xbf16, #dram>, interleaved>
#layout7 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x32xbf16, #dram>, interleaved>
module attributes {tt.device = #device, tt.system_desc = #system_desc} {
func.func @forward(%arg0: tensor<64x128xbf16, #layout>, %arg1: tensor<128x96xbf16, #layout1>, %arg2: tensor<64x96xbf16, #layout2>, %arg3: tensor<96x32xbf16, #layout3>, %arg4: tensor<64x32xbf16, #layout4>) -> tensor<64x32xbf16, #layout4> {
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>, %arg2: tensor<64x96xbf16>, %arg3: tensor<96x32xbf16>, %arg4: tensor<64x32xbf16>) -> tensor<64x32xbf16> {
// CHECK: #[[L1_:.*]] = #tt.memory_space<l1>
// CHECK: #[[LAYOUT_6:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x12xbf16, #l1_>, interleaved>
// CHECK: #[[LAYOUT_7:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x4xbf16, #l1_>, interleaved>
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
%1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<1x1>>>}> : (tensor<64x128xbf16, #layout>, !tt.device<#device>) -> tensor<64x128xbf16, #layout5>
%2 = "ttnn.to_layout"(%arg1, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<1x1>>>}> : (tensor<128x96xbf16, #layout1>, !tt.device<#device>) -> tensor<128x96xbf16, #layout5>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<64x96>>>, shape = #ttnn.shape<64x96>}> : (!tt.device<#device>) -> tensor<64x96xbf16, #layout6>
// CHECK: #[[LAYOUT_8:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x4xbf16, #dram>, interleaved>
%0 = tensor.empty() : tensor<64x96xbf16>
// CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_6]]>
%4 = "ttnn.matmul"(%1, %2, %3) : (tensor<64x128xbf16, #layout5>, tensor<128x96xbf16, #layout5>, tensor<64x96xbf16, #layout6>) -> tensor<64x96xbf16, #layout6>
%5 = "ttnn.to_layout"(%arg2, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<1x1>>>}> : (tensor<64x96xbf16, #layout2>, !tt.device<#device>) -> tensor<64x96xbf16, #layout5>
%6 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<64x96>>>, shape = #ttnn.shape<64x96>}> : (!tt.device<#device>) -> tensor<64x96xbf16, #layout6>
%1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16>
%2 = tensor.empty() : tensor<64x96xbf16>
// CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_6]]>
%7 = "ttnn.add"(%4, %5, %6) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x96xbf16, #layout6>, tensor<64x96xbf16, #layout5>, tensor<64x96xbf16, #layout6>) -> tensor<64x96xbf16, #layout6>
%8 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<64x96>>>, shape = #ttnn.shape<64x96>}> : (!tt.device<#device>) -> tensor<64x96xbf16, #layout6>
%3 = "ttir.add"(%1, %arg2, %2) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x96xbf16>, tensor<64x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16>
%4 = tensor.empty() : tensor<64x96xbf16>
// CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_6]]>
%9 = "ttnn.relu"(%7, %8) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x96xbf16, #layout6>, tensor<64x96xbf16, #layout6>) -> tensor<64x96xbf16, #layout6>
%10 = "ttnn.to_layout"(%arg3, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<1x1>>>}> : (tensor<96x32xbf16, #layout3>, !tt.device<#device>) -> tensor<96x32xbf16, #layout5>
%11 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<64x32>>>, shape = #ttnn.shape<64x32>}> : (!tt.device<#device>) -> tensor<64x32xbf16, #layout7>
%5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16>
%6 = tensor.empty() : tensor<64x32xbf16>
// CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_7]]>
%12 = "ttnn.matmul"(%9, %10, %11) : (tensor<64x96xbf16, #layout6>, tensor<96x32xbf16, #layout5>, tensor<64x32xbf16, #layout7>) -> tensor<64x32xbf16, #layout7>
%13 = "ttnn.to_layout"(%arg4, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<1x1>>>}> : (tensor<64x32xbf16, #layout4>, !tt.device<#device>) -> tensor<64x32xbf16, #layout5>
%14 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<64x32>>>, shape = #ttnn.shape<64x32>}> : (!tt.device<#device>) -> tensor<64x32xbf16, #layout7>
%7 = "ttir.matmul"(%5, %arg3, %6) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x96xbf16>, tensor<96x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16>
%8 = tensor.empty() : tensor<64x32xbf16>
// CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_7]]>
%15 = "ttnn.add"(%12, %13, %14) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x32xbf16, #layout7>, tensor<64x32xbf16, #layout5>, tensor<64x32xbf16, #layout7>) -> tensor<64x32xbf16, #layout7>
%16 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<64x32>>>, shape = #ttnn.shape<64x32>}> : (!tt.device<#device>) -> tensor<64x32xbf16, #layout7>
// CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_7]]>
%17 = "ttnn.relu"(%15, %16) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x32xbf16, #layout7>, tensor<64x32xbf16, #layout7>) -> tensor<64x32xbf16, #layout7>
%18 = "ttnn.to_layout"(%17) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<<none>, <system_memory>, <<64x32>>>}> : (tensor<64x32xbf16, #layout7>) -> tensor<64x32xbf16, #layout4>
return %18 : tensor<64x32xbf16, #layout4>
%9 = "ttir.add"(%7, %arg4, %8) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x32xbf16>, tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16>
%10 = tensor.empty() : tensor<64x32xbf16>
// CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_8]]>
%11 = "ttir.relu"(%9, %10) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16>
return %11 : tensor<64x32xbf16>
}
}
