Skip to content

Commit

Permalink
Fix tilizing in tt-metal backend
Browse files Browse the repository at this point in the history
Until now the compiler would create a single tilize block call with the
total number of tiles to be generated. The LLK actually needs to work on
a single row of tiles at a time in order to know how much to stride.

This fix changes the lowering to the TTMetal dialect and creates the correct
number of tilize/untilize calls. In between them it pops from the input CB and pushes to the output CB.

With the SCF dialect we'll be able to insert loops, but for now these calls
will be unrolled.
  • Loading branch information
rpavlovicTT committed Aug 22, 2024
1 parent ef3bfe8 commit 160334c
Showing 1 changed file with 31 additions and 17 deletions.
48 changes: 31 additions & 17 deletions lib/Dialect/TTMetal/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,34 +274,48 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern<ttir::ToLayoutOp> {
tensixBlock->addArgument(inputCBTy, op.getLoc());
tensixBlock->addArgument(outputCBTy, op.getLoc());

int shardTileVolume = 1;
for (auto dim : (shouldTilize ? outputLayout.getMemref().getShape()
: inputLayout.getMemref().getShape())) {
shardTileVolume *= dim;
}
llvm::ArrayRef<int64_t> shardTileShape =
(shouldTilize ? outputLayout.getMemref().getShape()
: inputLayout.getMemref().getShape());

assert(shardTileShape.size() >= 2 && "Tile shape rank must be at least 2");

auto numTiles = tensixBuilder.create<arith::ConstantOp>(
op.getLoc(), tensixBuilder.getI32Type(),
tensixBuilder.getI32IntegerAttr(shardTileVolume));
// How many tiles should kernel tilize in one block.
arith::ConstantOp numTilesPerBlock =
tensixBuilder.create<arith::ConstantOp>(
op.getLoc(), tensixBuilder.getI32Type(),
tensixBuilder.getI32IntegerAttr(shardTileShape.back()));

if (shouldTilize) {
tensixBuilder.create<ttkernel::TilizeInitOp>(
op.getLoc(), tensixBlock->getArgument(0), numTiles,
op.getLoc(), tensixBlock->getArgument(0), numTilesPerBlock,
tensixBlock->getArgument(1));
} else {
tensixBuilder.create<ttkernel::UntilizeInitOp>(
op.getLoc(), tensixBlock->getArgument(0),
tensixBlock->getArgument(1));
}

if (shouldTilize) {
tensixBuilder.create<ttkernel::TilizeBlockOp>(
op.getLoc(), tensixBlock->getArgument(0), numTiles,
tensixBlock->getArgument(1));
} else {
tensixBuilder.create<ttkernel::UntilizeBlockOp>(
op.getLoc(), tensixBlock->getArgument(0), numTiles,
tensixBlock->getArgument(1));
uint64_t shardTileVolume = 1;
for (int64_t dim : shardTileShape) {
shardTileVolume *= dim;
}
const uint64_t numBlocks = shardTileVolume / shardTileShape.back();

for (uint iblock = 0; iblock < numBlocks; ++iblock) {
if (shouldTilize) {
tensixBuilder.create<ttkernel::TilizeBlockOp>(
op.getLoc(), tensixBlock->getArgument(0), numTilesPerBlock,
tensixBlock->getArgument(1));
} else {
tensixBuilder.create<ttkernel::UntilizeBlockOp>(
op.getLoc(), tensixBlock->getArgument(0), numTilesPerBlock,
tensixBlock->getArgument(1));
}
tensixBuilder.create<ttkernel::CBPopFrontOp>(
op.getLoc(), tensixBlock->getArgument(0), numTilesPerBlock);
tensixBuilder.create<ttkernel::CBPushBackOp>(
op.getLoc(), tensixBlock->getArgument(1), numTilesPerBlock);
}

tensixBuilder.create<ttkernel::ReturnOp>(op.getLoc());
Expand Down

0 comments on commit 160334c

Please sign in to comment.