Support for memRef of L1 Interleaved tensors (#1292) (#1607)
This PR fixes the incorrect `memRef` computed for L1 Interleaved layouts: the shard shape of an L1 interleaved tensor is now derived by distributing whole tiles evenly across the grid cores, instead of block-dividing each tensor dimension as is done for sharded layouts.

Closes issue: #1292
fbajraktariTT authored Dec 19, 2024
1 parent 13cf48f commit 293d226
Showing 20 changed files with 134 additions and 78 deletions.
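
The core of the change: for L1 interleaved layouts the shard shape is no longer computed by ceil-dividing each tensor dimension by the grid shape (that rule is kept for sharded and DRAM layouts); instead, the tensor's tiles are spread evenly over all grid cores. A minimal standalone sketch of the new rule (simplified names, not the exact tt-mlir API; assumes a rank-2 grid and a tensor shape already expressed in whole 32x32 tiles):

#include <cstdint>
#include <vector>

// Sketch: per-core shard shape, in tiles, of an L1 interleaved tensor.
// tiledShape: tensor shape already rounded up to whole tiles.
std::vector<int64_t>
l1InterleavedShardShape(const std::vector<int64_t> &tiledShape,
                        const std::vector<int64_t> &gridShape) {
  int64_t numTiles = 1;
  for (int64_t d : tiledShape)
    numTiles *= d;
  int64_t numCores = 1;
  for (int64_t d : gridShape)
    numCores *= d;
  // Every core receives an equal slice of the flattened tile stream,
  // rounded up, so the shard degenerates to a 1x..xN arrangement.
  std::vector<int64_t> shard(gridShape.size() - 1, 1);
  shard.push_back((numTiles + numCores - 1) / numCores);
  return shard; // e.g. {160, 160} tiles on an 8x8 grid -> {1, 400}
}

For a 5120x5120xbf16 tensor (160x160 tiles) on an 8x8 grid this yields {1, 400}, which matches the updated memrefs in the test expectations below.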
1 change: 0 additions & 1 deletion include/ttmlir/Dialect/TTNN/Analysis/BFInterleavedPolicy.h
@@ -7,7 +7,6 @@

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h"
#include <cstdint>

namespace mlir::tt::ttnn {

2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
@@ -164,6 +164,8 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> {
DataType getDataType() const;
uint64_t getElementSizeBytes() const;
int64_t getTensorSizeInBytes(ArrayRef<int64_t> tensorShape, ::mlir::tt::DeviceAttr device) const;
static llvm::SmallVector<int64_t> calculateLogicalShardShapeForSharding(ArrayRef<int64_t> tensorShape, mlir::AffineMap linear, GridAttr grid);
static llvm::SmallVector<int64_t> calculateLogicalShardShapeForL1Interleaved(ArrayRef<int64_t> tensorShape, Type elementType, mlir::AffineMap linear, GridAttr grid);
llvm::SmallVector<int64_t> getStride(ArrayRef<int64_t> logicalShape) const;
llvm::SmallVector<int64_t> getShardShape() const;
llvm::SmallVector<int64_t> getScalarShardShape() const;
3 changes: 1 addition & 2 deletions include/ttmlir/Dialect/TTNN/Utils/Utils.h
@@ -46,8 +46,7 @@ createRankedTensorTypeWithEncoding(RankedTensorType tensorType,
// Return the L1 memory usage of the output tensor of the given op.
// Used within L1 interleaved policies.
//
uint64_t getOpOutputL1Usage(Operation *op, TTNNLayoutAttr opLayout,
DeviceAttr &deviceAttr);
uint64_t getOpOutputL1Usage(TTNNLayoutAttr opLayout);

} // namespace mlir::tt::ttnn::utils

7 changes: 2 additions & 5 deletions lib/Dialect/TTNN/Analysis/BFInterleavedPolicy.cpp
@@ -14,7 +14,6 @@ void BFInterleavedPolicy::run() {
for (Operation &funcOp : rootOp->getRegion(0).getOps()) {
func::FuncOp func = dyn_cast<func::FuncOp>(funcOp);
mlir::tt::scheduler::Scheduler scheduler(&func);
deviceAttr = getCurrentScopeDevice(func);

// Initialize the policy.
//
@@ -53,8 +52,7 @@ void BFInterleavedPolicy::run() {
//
if (hasL1BufferType(op)) {
TTNNLayoutAttr layout = getL1InterleavedLayout(op);
uint64_t opOutputL1Usage =
utils::getOpOutputL1Usage(op, layout, deviceAttr);
uint64_t opOutputL1Usage = utils::getOpOutputL1Usage(layout);

if (currentL1Usage + opOutputL1Usage <= getAvailableL1CacheSize()) {
allocOfL1Mem = opOutputL1Usage;
@@ -92,8 +90,7 @@ void BFInterleavedPolicy::run() {
uint64_t numOfUsers = std::distance(nextOpForScheduling->user_begin(),
nextOpForScheduling->user_end());
currentL1UsagePerOp[nextOpForScheduling].l1MemUsagePerUser =
utils::getOpOutputL1Usage(nextOpForScheduling, opL1MemSpec.layout,
deviceAttr);
utils::getOpOutputL1Usage(opL1MemSpec.layout);
currentL1UsagePerOp[nextOpForScheduling].numOfUnscheduledUsers =
numOfUsers;
currentL1Usage +=
16 changes: 7 additions & 9 deletions lib/Dialect/TTNN/Analysis/GreedyL1InterleavedPolicy.cpp
@@ -130,7 +130,6 @@ GreedyL1InterleavedPolicy::OpConfig GreedyL1InterleavedPolicy::getGreedyConfig(
void GreedyL1InterleavedPolicy::run() {
for (Operation &funcOp : rootOp->getRegion(0).getOps()) {
func::FuncOp func = dyn_cast<func::FuncOp>(funcOp);
deviceAttr = getCurrentScopeDevice(func);

// Start the policy.
//
@@ -166,8 +165,8 @@ void GreedyL1InterleavedPolicy::run() {

if (op->hasOneUse() && hasL1BufferType(op)) {
L1Usage l1Usage;
l1Usage.outputL1Usage = utils::getOpOutputL1Usage(
op, getL1InterleavedLayout(op), deviceAttr);
l1Usage.outputL1Usage =
utils::getOpOutputL1Usage(getL1InterleavedLayout(op));
l1Usage.requiredL1Usage = 0;
opsL1Usage[op] = l1Usage;
}
@@ -192,8 +191,8 @@ void GreedyL1InterleavedPolicy::run() {
//
if (operandOpLayout.hasInterleavedL1TensorMemoryLayout()) {
L1Usage l1Usage;
l1Usage.outputL1Usage = utils::getOpOutputL1Usage(
operandOp, operandOpLayout, deviceAttr);
l1Usage.outputL1Usage =
utils::getOpOutputL1Usage(operandOpLayout);
l1Usage.requiredL1Usage = OpMemSpecMap[operandOp].requiredL1Usage;
opsL1Usage[operandOp] = l1Usage;
}
@@ -252,15 +251,14 @@ void GreedyL1InterleavedPolicy::run() {
std::max(intermediateRequiredL1Usage,
intermediateL1Usage +
OpMemSpecMap[operandOp].requiredL1Usage);
intermediateL1Usage += utils::getOpOutputL1Usage(
operandOp, OpMemSpecMap[operandOp].layout, deviceAttr);
intermediateL1Usage +=
utils::getOpOutputL1Usage(OpMemSpecMap[operandOp].layout);
}
}
OpMemSpecMap[op].requiredL1Usage =
std::max(intermediateRequiredL1Usage,
intermediateL1Usage +
utils::getOpOutputL1Usage(
op, OpMemSpecMap[op].layout, deviceAttr));
utils::getOpOutputL1Usage(OpMemSpecMap[op].layout));
}
}
}
17 changes: 11 additions & 6 deletions lib/Dialect/TTNN/Analysis/LegalLayoutAnalysis.cpp
@@ -228,16 +228,21 @@ void LegalLayoutAnalysis::analysisImplementation() {
TensorMemoryLayoutAttr::get(op->getContext(),
TensorMemoryLayout::Interleaved)));

// L1 Interleaved (same as above).
analysisResult.push_back(TTNNLayoutAttr::get(
op->getContext(), tensorShape, elementType, BufferType::L1,
analysisInput.maxGrid,
TensorMemoryLayoutAttr::get(op->getContext(),
TensorMemoryLayout::Interleaved)));
// L1 Interleaved - It must be tiled.
// TODO(odjuricic): Check that this is always the case.
if (elementType == tileElementType) {
analysisResult.push_back(TTNNLayoutAttr::get(
op->getContext(), tensorShape, elementType, BufferType::L1,
analysisInput.maxGrid,
TensorMemoryLayoutAttr::get(op->getContext(),
TensorMemoryLayout::Interleaved)));
}

// L1 Sharded
TTNNLayoutAttr shardedBase =
layout.withBufferType(op->getContext(), BufferType::L1)
.withMemoryLayout(op->getContext(),
TensorMemoryLayout::BlockSharded)
.withElementType(op->getContext(), elementType);

assert(analysisInput.maxGrid.getShape().size() == 2 &&
97 changes: 82 additions & 15 deletions lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp
@@ -4,13 +4,10 @@

#include <numeric>

#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"
#include "ttmlir/Utils.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir::tt::ttnn;

@@ -68,6 +65,67 @@ bool TTNNLayoutAttr::hasInterleavedDRAMTensorMemoryLayout() const {
(getMemLayout().getValue() == TensorMemoryLayout::Interleaved);
}

// Calculate the logical shape of the shard.
//
// A shard is defined as a piece of the tensor that is mapped to a single grid
// core. This function returns the shard shape for tensors with BLOCK SHARDED
// tensor memory layout.
//
// All examples assume that the tensor is mapped to an 8x8 grid.
// Example: tensor<32x32xbf16> -> {4, 4}
// Example: tensor<65x65xbf16> -> {9, 9}
//
// return The logical shard shape in case of block sharded tensor memory layout.
llvm::SmallVector<int64_t>
TTNNLayoutAttr::calculateLogicalShardShapeForSharding(
ArrayRef<int64_t> tensorShape, mlir::AffineMap linear, GridAttr grid) {
assert(linear.getNumResults() == grid.getShape().size());
mlir::SmallVector<std::int64_t> physicalShape =
ttmlir::utils::evalShape(linear, tensorShape);
mlir::SmallVector<std::int64_t> shardShape(linear.getNumResults());
for (size_t i = 0; i < linear.getNumResults(); ++i) {
shardShape[i] =
(physicalShape[i] + grid.getShape()[i] - 1) / grid.getShape()[i];
}
return shardShape;
}
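
A quick arithmetic check of the block-sharded examples above: each physical dimension is ceil-divided by the matching grid dimension, so tensor<65x65xbf16> on an 8x8 grid gives ceil(65/8) = 9 in each dimension, i.e. the {9, 9} shard.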

// Calculate the logical shape of the shard.
//
// A shard is defined as a piece of the tensor that is mapped to a single grid
// core. This function returns the shard shape for tensors with INTERLEAVED
// tensor memory layout.
//
// All examples assume that the tensor is mapped to an 8x8 grid.
// Example: tensor<1x1024xbf16> ( -> 32 tiles ) -> {1, 1}
// Example: tensor<512x512xbf16> ( -> 256 tiles ) -> {1, 4}
// Example: tensor<32x2049xbf16> ( -> 65 tiles ) -> {1, 2}
//
// return The logical shard shape in case of interleaved tensor memory layout.
llvm::SmallVector<int64_t>
TTNNLayoutAttr::calculateLogicalShardShapeForL1Interleaved(
ArrayRef<int64_t> tensorShape, mlir::Type elementType,
mlir::AffineMap linear, mlir::tt::GridAttr grid) {
assert(linear.getNumResults() == grid.getShape().size());
assert(mlir::isa<mlir::tt::TileType>(elementType));

mlir::SmallVector<std::int64_t> physicalShape =
ttmlir::utils::evalShape(linear, tensorShape);
mlir::SmallVector<std::int64_t> physicalTiledShape =
mlir::cast<mlir::tt::TileType>(elementType).getTiledShape(physicalShape);
uint64_t numOfTiles =
std::accumulate(physicalTiledShape.begin(), physicalTiledShape.end(), 1,
std::multiplies<std::int64_t>());
uint64_t numOfGridUnits =
std::accumulate(grid.getShape().begin(), grid.getShape().end(), 1,
std::multiplies<std::int64_t>());

mlir::SmallVector<std::int64_t> shardShape;
shardShape.resize(grid.getShape().size() - 1, 1);
shardShape.push_back((numOfTiles + numOfGridUnits - 1) / numOfGridUnits);
return mlir::cast<mlir::tt::TileType>(elementType).getScalarShape(shardShape);
}
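
Worked through for the second example above: tensor<512x512xbf16> tiles to 16x16 = 256 tiles; with 8x8 = 64 grid cores, ceil(256/64) = 4 tiles per core, so the tiled shard is {1, 4}. The function returns the scalar equivalent, {32, 128}, which the memref construction in TTNNLayoutAttr::get folds back into a 1x4 arrangement of !tt.tile<32x32, bf16>.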

// Get stride given tensor logical shape
llvm::SmallVector<int64_t>
TTNNLayoutAttr::getStride(ArrayRef<int64_t> logicalShape) const {
@@ -157,12 +215,12 @@ mlir::tt::DataType TTNNLayoutAttr::getDataType() const {
return elementTypeToDataType(elementType);
}

// Gets the size of shard in bytes
// Get the size of the element in bytes
//
// This function returns the size of the shard in bytes.
// Size is calculated by multiplying shard shape with element size.
// This function returns the size of a single tensor element in bytes.
// Distinction is made between scalar types and TileType.
//
// return The size of the shard in bytes.
// return The size of the element in bytes.
uint64_t TTNNLayoutAttr::getElementSizeBytes() const {
mlir::Type elementType = getElementType();
if (isTiled()) {
@@ -177,7 +235,7 @@ uint64_t TTNNLayoutAttr::getElementSizeBytes() const {
// Return the shape of the shard.
// Example: memref<2x2x!tt.tile<32x32xf32>> -> { 2, 2 }
// Example: memref<128x128xf32> -> { 128, 128 }
// Example: memref<2x3!tt.tile<32x32xf32>> -> { 2, 3 }
// Example: memref<2x3x!tt.tile<32x32xf32>> -> { 2, 3 }
//
// return The shape of the shard.
llvm::SmallVector<int64_t> TTNNLayoutAttr::getShardShape() const {
@@ -283,13 +341,13 @@ mlir::AffineMap TTNNLayoutAttr::replaceMemoryMapSymbolsWithShardShape(
"shard rank");

SmallVector<AffineExpr> symReplacements;
for (unsigned i = 0; i < physicalMemoryMap.getNumSymbols(); ++i) {
for (size_t i = 0; i < physicalMemoryMap.getNumSymbols(); ++i) {
symReplacements.push_back(
getAffineConstantExpr(shardShape[i], getContext()));
}

SmallVector<AffineExpr> dimReplacements;
for (unsigned i = 0; i < physicalMemoryMap.getNumDims(); ++i) {
for (size_t i = 0; i < physicalMemoryMap.getNumDims(); ++i) {
dimReplacements.push_back(getAffineDimExpr(i, getContext()));
}

@@ -453,14 +511,23 @@ TTNNLayoutAttr TTNNLayoutAttr::get(
Type elementType, BufferType bufferType, GridAttr grid,
TensorMemoryLayoutAttr memLayoutAttr,
ArrayRef<std::pair<std::int64_t, std::int64_t>> collapseIntervals) {

// Construct a new affine map which will be used to map from logical
// space to physical space
// space to physical space.
AffineMap linear = collapsedLinearAffineMap(
context, tensorShape, grid.getShape(), collapseIntervals);
// Calculate shard shape by evaluating the linear map with last element
// of the tensor shape and dividing it by the grid shape
mlir::SmallVector<int64_t, 4> shardShape =
calculateLogicalShardShape(tensorShape, linear, grid);

// Calculate shard shape
mlir::SmallVector<int64_t> shardShape;
if (bufferType == BufferType::L1 &&
memLayoutAttr.getValue() == TensorMemoryLayout::Interleaved) {
shardShape = TTNNLayoutAttr::calculateLogicalShardShapeForL1Interleaved(
tensorShape, elementType, linear, grid);
} else {
shardShape = TTNNLayoutAttr::calculateLogicalShardShapeForSharding(
tensorShape, linear, grid);
}

// Build memref type with the given parameters
MemRefType memRefType = buildMemRef<BufferType, BufferTypeAttr>(
context, shardShape, elementType, bufferType);
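
This dispatch is the heart of the fix: only L1 interleaved layouts take the tile-distribution path, while DRAM interleaved and all sharded layouts keep the per-dimension ceil division. Illustrative outcome for one tensor (the L1 shape appears in the updated tests; the DRAM shape is derived here by the same per-dimension rule):

// tensor<5120x5120xbf16> on an 8x8 grid, tiled bf16 elements:
//   L1 + interleaved   -> memref<1x400x!tt.tile<32x32, bf16>, #l1_>
//   DRAM + interleaved -> memref<20x20x!tt.tile<32x32, bf16>, #dram>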
14 changes: 2 additions & 12 deletions lib/Dialect/TTNN/Utils/Utils.cpp
@@ -117,24 +117,14 @@ createRankedTensorTypeWithEncoding(RankedTensorType tensorType,
tensorType.getElementType(), encoding);
}

uint64_t getOpOutputL1Usage(Operation *op, TTNNLayoutAttr opLayout,
DeviceAttr &deviceAttr) {
assert(mlir::isa<RankedTensorType>(op->getResult(0).getType()) &&
"L1 memory usage of the ops without output tensors cannot be "
"calculated.");

uint64_t getOpOutputL1Usage(TTNNLayoutAttr opLayout) {
// In case the opLayout is not in L1 memory space, L1 memory usage is 0.
//
if (opLayout.hasDRAMBufferType()) {
return 0;
}

llvm::ArrayRef<int64_t> opOutputTensorShape =
mlir::cast<RankedTensorType>(op->getResult(0).getType()).getShape();

uint64_t opL1OutputUsage =
opLayout.getTensorSizeInBytes(opOutputTensorShape, deviceAttr);
return opL1OutputUsage;
return opLayout.getShardSizeInBytes();
}

} // namespace mlir::tt::ttnn::utils
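
With the op and device arguments gone, the policies now charge each op exactly its per-core shard allocation. Sanity check against the tests below: a layout of memref<1x96x!tt.tile<32x32, bf16>, #l1_> holds 96 tiles per core, and a bf16 32x32 tile occupies 32 * 32 * 2 = 2048 bytes, so getShardSizeInBytes() comes to 96 * 2048 = 196608 bytes of L1.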
@@ -4,7 +4,7 @@ module attributes {} {
// CHECK: #[[L1_:.*]] = #ttnn.buffer_type<l1>
// CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<32x20x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
// CHECK-DAG: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<20x32x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
// CHECK-DAG: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<20x20x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
// CHECK-DAG: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x400x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
%0 = tensor.empty() : tensor<5120x8192xbf16>
// CHECK-DAG: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<5120x8192xbf16, #[[LAYOUT_6]]>
%1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<5120x8192xbf16>, tensor<5120x8192xbf16>) -> tensor<5120x8192xbf16>
@@ -2,17 +2,16 @@
module attributes {} {
func.func @forward(%arg0: tensor<6144x1024xbf16>, %arg1: tensor<1024x6144xbf16>) -> tensor<6144x6144xbf16> {
// CHECK: #[[L1_:.*]] = #ttnn.buffer_type<l1>
// CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<24x4x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
// CHECK-DAG: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<4x24x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
// CHECK-DAG: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<24x24x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
// CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x96x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
// CHECK-DAG: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<24x24x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
%0 = tensor.empty() : tensor<6144x1024xbf16>
// CHECK-DAG: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<6144x1024xbf16, #[[LAYOUT_5]]>
%1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<6144x1024xbf16>, tensor<6144x1024xbf16>) -> tensor<6144x1024xbf16>
%2 = tensor.empty() : tensor<1024x6144xbf16>
// CHECK-DAG: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<1024x6144xbf16, #[[LAYOUT_6]]>
// CHECK-DAG: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<1024x6144xbf16, #[[LAYOUT_5]]>
%3 = "ttir.relu"(%arg1, %2) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<1024x6144xbf16>, tensor<1024x6144xbf16>) -> tensor<1024x6144xbf16>
%4 = tensor.empty() : tensor<6144x6144xbf16>
// CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<6144x6144xbf16, #[[LAYOUT_7]]>
// CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<6144x6144xbf16, #[[LAYOUT_6]]>
%5 = "ttir.matmul"(%1, %3, %4) : (tensor<6144x1024xbf16>, tensor<1024x6144xbf16>, tensor<6144x6144xbf16>) -> tensor<6144x6144xbf16>
return %5 : tensor<6144x6144xbf16>
}
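
One L1 layout disappears in this test: 6144x1024 and 1024x6144 both contain 192x32 = 6144 tiles, so under the interleaved rule both relu outputs share the same {1, ceil(6144/64)} = {1, 96} tile shard (LAYOUT_5), whereas the old per-dimension calculation gave them distinct 24x4 and 4x24 shards.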
@@ -16,8 +16,8 @@
module attributes {} {
func.func @forward(%arg0: tensor<4096x5120xbf16>, %arg1: tensor<5120x1024xbf16>, %arg2: tensor<5120x1024xbf16>) -> tensor<4096x1024xbf16> {
// CHECK: #[[L1_:.*]] = #ttnn.buffer_type<l1>
// CHECK: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<16x20x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
// CHECK: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<16x4x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
// CHECK: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x320x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
// CHECK: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x64x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
%0 = tensor.empty() : tensor<4096x5120xbf16>
// CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<4096x5120xbf16, #[[LAYOUT_5]]>
%1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<4096x5120xbf16>, tensor<4096x5120xbf16>) -> tensor<4096x5120xbf16>
@@ -19,9 +19,9 @@
module attributes {} {
func.func @forward(%arg0: tensor<4096x5120xbf16>, %arg1: tensor<5120x9216xbf16>, %arg2: tensor<9216x1024xbf16>, %arg3: tensor<5120x1024xbf16>) -> tensor<4096x1024xbf16> {
// CHECK: #[[L1_:.*]] = #ttnn.buffer_type<l1>
// CHECK: #[[LAYOUT_9:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<16x20x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
// CHECK: #[[LAYOUT_9:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x320x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
// CHECK: #[[LAYOUT_10:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<16x36x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
// CHECK: #[[LAYOUT_11:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<16x4x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
// CHECK: #[[LAYOUT_11:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x64x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
%0 = tensor.empty() : tensor<4096x5120xbf16>
// CHECK-DAG: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<4096x5120xbf16, #[[LAYOUT_9]]>
%1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<4096x5120xbf16>, tensor<4096x5120xbf16>) -> tensor<4096x5120xbf16>