Skip to content

Commit

Permalink
Address comments and polish rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
pjanevskiTT committed Sep 25, 2024
1 parent 4dba9e5 commit 55afd59
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 166 deletions.
84 changes: 42 additions & 42 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -160,48 +160,6 @@ class TTIR_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
];
}

// Base class for element-wise ops that take two input values (`lhs`, `rhs`).
// Concrete ops derive from it with their own mnemonic and optional extra
// traits; everything else is inherited from TTIR_ElementwiseOp.
class TTIR_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseOp<mnemonic, traits> {
let summary = "Eltwise binary op.";
let description = [{
Eltwise binary op.
}];

// Convenience builder: callers pass the two inputs, the output value and the
// operand constraints; the body forwards to another `build` overload using
// `out`'s type, the `{lhs, rhs}` pair, `out` itself and the constraints.
// NOTE(review): presumably that overload's parameters are (result types,
// inputs, outputs, operand_constraints) as declared on TTIR_ElementwiseOp,
// which is not visible here - confirm.
let builders =
[
OpBuilder<(ins "Value": $lhs, "Value": $rhs, "Value": $out, "ArrayAttr": $operand_constraints),
[{
build($_builder, $_state, {out.getType()}, {lhs, rhs}, out, operand_constraints);
}]>
];
}

// Element-wise binary op with mnemonic `subtract`. Adds no arguments or
// builders of its own; it only overrides the summary and description of
// TTIR_ElementwiseBinaryOp.
def TTIR_SubtractOp : TTIR_ElementwiseBinaryOp<"subtract"> {
let summary = "Eltwise subtract.";
let description = [{
Eltwise subtract operation.
}];
}

// Element-wise binary op with mnemonic `ge` (greater than or equal to). Adds
// no arguments or builders of its own; it only overrides the summary and
// description of TTIR_ElementwiseBinaryOp.
def TTIR_GreaterEqualOp : TTIR_ElementwiseBinaryOp<"ge"> {
let summary = "Eltwise greater than or equal to.";
let description = [{
Eltwise greater than or equal to operation.
}];
}

// Element-wise binary op with mnemonic `maximum`. Adds no arguments or
// builders of its own; it only overrides the summary and description of
// TTIR_ElementwiseBinaryOp. The description carries a worked example in which
// each output element is the larger of the corresponding `%lhs` / `%rhs`
// elements.
def TTIR_MaximumOp : TTIR_ElementwiseBinaryOp<"maximum"> {
let summary = "Eltwise maximum OP.";
let description = [{
Calculates maximum of input tensors' values element-wise and stores result in output tensor.

Example:
%lhs: [[3, 2, 7], [1, 4, 4]]
%rhs: [[1, 4, 2], [1, 2, 3]]
"ttir.maximum"(%lhs, %rhs, %out) -> %out: [[3, 4, 7], [1, 4, 4]]
}];
}

def TTIR_AbsOp: TTIR_ElementwiseUnaryOp<"abs"> {
let summary = "Eltwise absolute op.";
let description = [{
Expand Down Expand Up @@ -251,6 +209,48 @@ def TTIR_SigmoidOp: TTIR_ElementwiseUnaryOp<"sigmoid"> {
}];
}

// Base class for element-wise ops that take two input values (`lhs`, `rhs`).
// Concrete ops derive from it with their own mnemonic and optional extra
// traits; everything else is inherited from TTIR_ElementwiseOp.
class TTIR_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseOp<mnemonic, traits> {
let summary = "Eltwise binary op.";
let description = [{
Eltwise binary op.
}];

// Convenience builder: callers pass the two inputs, the output value and the
// operand constraints; the body forwards to another `build` overload using
// `out`'s type, the `{lhs, rhs}` pair, `out` itself and the constraints.
// NOTE(review): presumably that overload's parameters are (result types,
// inputs, outputs, operand_constraints) as declared on TTIR_ElementwiseOp,
// which is not visible here - confirm.
let builders =
[
OpBuilder<(ins "Value": $lhs, "Value": $rhs, "Value": $out, "ArrayAttr": $operand_constraints),
[{
build($_builder, $_state, {out.getType()}, {lhs, rhs}, out, operand_constraints);
}]>
];
}

// Element-wise binary op with mnemonic `subtract`. Adds no arguments or
// builders of its own; it only overrides the summary and description of
// TTIR_ElementwiseBinaryOp.
def TTIR_SubtractOp : TTIR_ElementwiseBinaryOp<"subtract"> {
let summary = "Eltwise subtract.";
let description = [{
Eltwise subtract operation.
}];
}

// Element-wise binary op with mnemonic `ge` (greater than or equal to). Adds
// no arguments or builders of its own; it only overrides the summary and
// description of TTIR_ElementwiseBinaryOp.
def TTIR_GreaterEqualOp : TTIR_ElementwiseBinaryOp<"ge"> {
let summary = "Eltwise greater than or equal to.";
let description = [{
Eltwise greater than or equal to operation.
}];
}

// Element-wise binary op with mnemonic `maximum`. Adds no arguments or
// builders of its own; it only overrides the summary and description of
// TTIR_ElementwiseBinaryOp. The description carries a worked example in which
// each output element is the larger of the corresponding `%lhs` / `%rhs`
// elements.
def TTIR_MaximumOp : TTIR_ElementwiseBinaryOp<"maximum"> {
let summary = "Eltwise maximum OP.";
let description = [{
Calculates maximum of input tensors' values element-wise and stores result in output tensor.

Example:
%lhs: [[3, 2, 7], [1, 4, 4]]
%rhs: [[1, 4, 2], [1, 2, 3]]
"ttir.maximum"(%lhs, %rhs, %out) -> %out: [[3, 4, 7], [1, 4, 4]]
}];
}

class TTIR_ReductionOp<string mnemonic, list<Trait> traits = []> : TTIR_DPSOp<mnemonic, traits> {
let summary = "Reduction op.";
let description = [{
Expand Down
8 changes: 4 additions & 4 deletions include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ def TTKernel_PackTileOp : TTKernel_Op<"pack_tile"> {
let arguments = (ins I32:$dst_index, TTKernel_CB:$out_cb, I32:$out_index);
}

def TTKernel_CopyTileToDstInitShort : TTKernel_Op<"copy_tile_to_dst_init_short"> {
let summary = "Perform the init short for copy tile. This does not reconfigure the unpacker data types.";
def TTKernel_CopyTileInitOp : TTKernel_Op<"copy_tile_init"> {
let summary = "Perform the init for copy tile. This does not reconfigure the unpacker data types.";
let description = [{
Must be called before copy_tile.
}];
Expand Down Expand Up @@ -276,14 +276,14 @@ def TTKernel_ExpTileOp : TTKernel_Op<"exp_tile"> {
let arguments = (ins I32:$tile_index);
}

def TTKernel_RecipTileInit : TTKernel_Op<"recip_tile_init"> {
def TTKernel_RecipTileInitOp : TTKernel_Op<"recip_tile_init"> {
let summary = "Init function for recip_tile operation. Refer to documentation for any init function.";
let description = [{
Must be called before recip_tile function.
}];
}

def TTKernel_RecipTile : TTKernel_Op<"recip_tile"> {
def TTKernel_RecipTileOp : TTKernel_Op<"recip_tile"> {
let summary = "Recip tile in the DST at specified index.";
let description = [{
Performs element-wise computation of the reciprocal on each element of a tile
Expand Down
169 changes: 61 additions & 108 deletions lib/Conversion/TTIRToTTMetal/TTIRToTTMetal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -694,9 +694,9 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {
} else if (mlir::isa<arith::MulFOp>(arithOrMathOp)) {
builder.create<ttkernel::MulTilesInitOp>(arithOrMathOp.getLoc(), inCB0,
inCB1);
} else if(mlir::isa<arith::DivFOp>(arithOrMathOp)) {
} else if (mlir::isa<arith::DivFOp>(arithOrMathOp)) {
builder.create<ttkernel::MulTilesInitFOp>(arithOrMathOp.getLoc());
}else {
} else {
llvm_unreachable("Unhandled binary op init conversion.");
}
}
Expand Down Expand Up @@ -733,13 +733,11 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {

// For all unary ops first copy tile from input CB at inCBTileIndex to DST
// register at dstTileIndex.

// We always operate on the first and only tile in DST register.
Value dstTileIndex = i32(0, builder);
auto outCBTileIndex = iterators[blockArgIteratorMapping.back()];
auto outCB = cbOperands.back();

// MATH acquires lock on DST register.
builder.create<ttkernel::TileRegsAcquireOp>(arithOrMathOp.getLoc());

builder.create<ttkernel::CopyTileOp>(arithOrMathOp.getLoc(), inCB,
Expand All @@ -753,17 +751,15 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {
llvm_unreachable("Unhandled unary op compute conversion.");
}

// MATH releases lock on DST.
builder.create<ttkernel::TileRegsCommitOp>(arithOrMathOp.getLoc());

// PACK acquires lock on DST register. Blocked until MATH releases it.
builder.create<ttkernel::TileRegsWaitOp>(arithOrMathOp.getLoc());

// Copy tile from DST at dstTileIndex to outCB at outCBTileIndex.
// outCBTileIndex increments as loops iterate, thus placing one result tile
// after another in outCB.
builder.create<ttkernel::PackTileOp>(arithOrMathOp.getLoc(), dstTileIndex, outCB,
outCBTileIndex);
builder.create<ttkernel::PackTileOp>(arithOrMathOp.getLoc(), dstTileIndex,
outCB, outCBTileIndex);

// PACK releases lock on DST.
builder.create<ttkernel::TileRegsReleaseOp>(arithOrMathOp.getLoc());
Expand All @@ -788,48 +784,37 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {

if (mlir::isa<arith::AddFOp>(arithOrMathOp)) {
Value dstIndex = i32(0, builder);
builder.create<ttkernel::TileRegsAcquireOp>(
arithOrMathOp.getLoc());
builder.create<ttkernel::AddTilesOp>(
arithOrMathOp.getLoc(), inCB0,
inCB1, inCB0TileIndex,
inCB1TileIndex, dstIndex);
builder.create<ttkernel::TileRegsCommitOp>(
arithOrMathOp.getLoc());
builder.create<ttkernel::TileRegsWaitOp>(
arithOrMathOp.getLoc());
builder.create<ttkernel::TileRegsAcquireOp>(arithOrMathOp.getLoc());
builder.create<ttkernel::AddTilesOp>(arithOrMathOp.getLoc(), inCB0, inCB1,
inCB0TileIndex, inCB1TileIndex,
dstIndex);
builder.create<ttkernel::TileRegsCommitOp>(arithOrMathOp.getLoc());
builder.create<ttkernel::TileRegsWaitOp>(arithOrMathOp.getLoc());
builder.create<ttkernel::PackTileOp>(
arithOrMathOp.getLoc(), dstIndex, outCB,
iterators[blockArgIteratorMapping[2]]);
builder.create<ttkernel::TileRegsReleaseOp>(
arithOrMathOp.getLoc());
builder.create<ttkernel::TileRegsReleaseOp>(arithOrMathOp.getLoc());
} else if (mlir::isa<arith::MulFOp>(arithOrMathOp)) {
commonComputeMulOp(arithOrMathOp, cbOperands,
iterators, blockArgIteratorMapping,
builder);
commonComputeMulOp(arithOrMathOp, cbOperands, iterators,
blockArgIteratorMapping, builder);
} else if (mlir::isa<arith::DivFOp>(arithOrMathOp)) {

SmallVector<std::int64_t> operandIndicesRecip;
// For DIV, input 1 is going through reciprocal.
operandIndicesRecip.push_back(1);
commonComputeRecipOp(arithOrMathOp, cbOperands,
iterators, blockArgIteratorMapping,
builder, operandIndicesRecip);

builder.create<ttkernel::MulTilesInitOp>(
arithOrMathOp.getLoc(),
inCB0,
inCB1);

commonComputeMulOp(arithOrMathOp, cbOperands,
iterators, blockArgIteratorMapping,
builder);

commonComputeRecipOp(arithOrMathOp, cbOperands, iterators,
blockArgIteratorMapping, builder,
operandIndicesRecip);

builder.create<ttkernel::MulTilesInitOp>(arithOrMathOp.getLoc(), inCB0,
inCB1);

commonComputeMulOp(arithOrMathOp, cbOperands, iterators,
blockArgIteratorMapping, builder);

Value one = i32(1, builder);
builder.create<ttkernel::CBPopFrontOp>(
arithOrMathOp.getLoc(),
inCB0,
one);
builder.create<ttkernel::CBPopFrontOp>(arithOrMathOp.getLoc(), inCB1,
one);
} else {
llvm_unreachable("Unhandled conversion for operation which is neither "
"unary nor binary.");
Expand All @@ -840,91 +825,60 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {
ArrayRef<BlockArgument> iterators,
SmallVector<unsigned> blockArgIteratorMapping,
OpBuilder &builder) const {

auto inCB0 = cbOperands[0];
auto inCB1 = cbOperands[1];
auto outCB = cbOperands[2];

Value dstIndex = i32(0, builder);

builder.create<ttkernel::TileRegsAcquireOp>(
op.getLoc());
builder.create<ttkernel::TileRegsAcquireOp>(op.getLoc());
if (mlir::isa<arith::MulFOp>(op)) {
builder.create<ttkernel::MulTilesOp>(
op.getLoc(),
inCB0,
inCB1,
iterators[blockArgIteratorMapping[0]],
iterators[blockArgIteratorMapping[0]],
dstIndex);
op.getLoc(), inCB0, inCB1, iterators[blockArgIteratorMapping[0]],
iterators[blockArgIteratorMapping[0]], dstIndex);
} else if (mlir::isa<arith::DivFOp>(op)) {
builder.create<ttkernel::MulTilesOp>(
op.getLoc(),
inCB0,
inCB1,
iterators[blockArgIteratorMapping[0]],
dstIndex,
dstIndex);
op.getLoc(), inCB0, inCB1, iterators[blockArgIteratorMapping[0]],
dstIndex, dstIndex);
} else {
llvm_unreachable("Common compute for multiplying tiles should be called only on MulFOp and DivFOp");
llvm_unreachable("Common compute for multiplying tiles should be called "
"only on MulFOp and DivFOp");
}

builder.create<ttkernel::TileRegsCommitOp>(
op.getLoc());
builder.create<ttkernel::TileRegsWaitOp>(
op.getLoc());
builder.create<ttkernel::PackTileOp>(
op.getLoc(), dstIndex, outCB,
iterators[blockArgIteratorMapping[2]]);
builder.create<ttkernel::TileRegsReleaseOp>(
op.getLoc());

builder.create<ttkernel::TileRegsCommitOp>(op.getLoc());
builder.create<ttkernel::TileRegsWaitOp>(op.getLoc());
builder.create<ttkernel::PackTileOp>(op.getLoc(), dstIndex, outCB,
iterators[blockArgIteratorMapping[2]]);
builder.create<ttkernel::TileRegsReleaseOp>(op.getLoc());
}

void commonComputeRecipOp(Operation &op, ArrayRef<BlockArgument> cbOperands,
ArrayRef<BlockArgument> iterators,
SmallVector<unsigned> blockArgIteratorMapping,
OpBuilder &builder,
SmallVector<std::int64_t>& operandIndices) const {
SmallVector<std::int64_t> &operandIndices) const {
Value dstIndex = i32(0, builder);
Value one = i32(1, builder);

builder.create<ttkernel::CopyTileToDstInitShort>(
op.getLoc());

builder.create<ttkernel::CopyTileInitOp>(op.getLoc());
builder.create<ttkernel::CBReserveBackOp>(
op.getLoc(),
cbOperands[operandIndices[0]],
one);
builder.create<ttkernel::TileRegsAcquireOp>(
op.getLoc());
builder.create<ttkernel::RecipTileInit>(
op.getLoc());
op.getLoc(), cbOperands[operandIndices[0]], one);
builder.create<ttkernel::TileRegsAcquireOp>(op.getLoc());
builder.create<ttkernel::RecipTileInitOp>(op.getLoc());
builder.create<ttkernel::CopyTileOp>(
op.getLoc(),
cbOperands[operandIndices[0]],
dstIndex,
dstIndex);
builder.create<ttkernel::RecipTile>(
op.getLoc(),
dstIndex);
builder.create<ttkernel::TileRegsCommitOp>(
op.getLoc());

builder.create<ttkernel::TileRegsWaitOp>(
op.getLoc());
op.getLoc(), cbOperands[operandIndices[0]], dstIndex, dstIndex);
builder.create<ttkernel::RecipTileOp>(op.getLoc(), dstIndex);
builder.create<ttkernel::TileRegsCommitOp>(op.getLoc());

builder.create<ttkernel::TileRegsWaitOp>(op.getLoc());
builder.create<ttkernel::PackTileOp>(
op.getLoc(), dstIndex,
cbOperands[operandIndices[0]],
dstIndex);
builder.create<ttkernel::TileRegsReleaseOp>(
op.getLoc());
builder.create<ttkernel::CBPushBackOp>(
op.getLoc(),
cbOperands[operandIndices[0]],
one);
builder.create<ttkernel::CBWaitFrontOp>(
op.getLoc(),
cbOperands[operandIndices[0]],
one);
op.getLoc(), dstIndex, cbOperands[operandIndices[0]], dstIndex);
builder.create<ttkernel::TileRegsReleaseOp>(op.getLoc());
builder.create<ttkernel::CBPushBackOp>(op.getLoc(),
cbOperands[operandIndices[0]], one);
builder.create<ttkernel::CBWaitFrontOp>(op.getLoc(),
cbOperands[operandIndices[0]], one);
}

// Convert arith and math dialect operations into ttkernel tile operations.
Expand All @@ -935,8 +889,7 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {
ArrayRef<BlockArgument> cbOperands,
ArrayRef<BlockArgument> iterators,
SmallVector<unsigned> blockArgIteratorMapping,
OpBuilder &builder,
std::int64_t numDpsInputs) const {
OpBuilder &builder, std::int64_t numDpsInputs) const {

if (numDpsInputs == 1) {
convertComputeUnaryOp(arithOrMathOp, cbOperands, iterators,
Expand Down Expand Up @@ -984,8 +937,7 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {
// Call compute function to execute on each tile. Result will be stored in
// DST.
convertComputeOp(arithOrMathOp, cbOperands, iterators,
blockArgIteratorMapping, innerLoopBuilder,
numDPSInputs);
blockArgIteratorMapping, innerLoopBuilder, numDPSInputs);
}

// Builds instructions to execute after loops are finished.
Expand Down Expand Up @@ -1068,8 +1020,9 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {
for (auto ty : rewrittenBlockArgumentTypes) {
tensixBlock->addArgument(ty, op.getLoc());
}

lowerBlock(&op->getRegion(0).front(), tensixBlock, op.getNumDpsInputs(), rewriter);

lowerBlock(&op->getRegion(0).front(), tensixBlock, op.getNumDpsInputs(),
rewriter);

rewriter.replaceOp(op, metalDispatch);

Expand Down
Loading

0 comments on commit 55afd59

Please sign in to comment.