Skip to content

Commit

Permalink
[Metal Direct] Implement div op (#780)
Browse files Browse the repository at this point in the history
* Implement metal direct div op

* Address comments

* Recursive inside dispatch region function
  • Loading branch information
pjanevskiTT authored Oct 1, 2024
1 parent ddee2bd commit f983235
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 52 deletions.
14 changes: 7 additions & 7 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -232,13 +232,6 @@ class TTIR_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
];
}

def TTIR_DivOp : TTIR_ElementwiseBinaryOp<"div"> {
let summary = "Eltwise divide.";
let description = [{
Eltwise divide operation.
}];
}

def TTIR_SubtractOp : TTIR_ElementwiseBinaryOp<"subtract"> {
let summary = "Eltwise subtract.";
let description = [{
Expand Down Expand Up @@ -725,6 +718,13 @@ def TTIR_MultiplyOp : TTIR_GenericElementwiseBinaryOp<"multiply"> {
}];
}

def TTIR_DivOp : TTIR_GenericElementwiseBinaryOp<"div"> {
let summary = "Eltwise divide.";
let description = [{
Eltwise divide operation.
}];
}

//===----------------------------------------------------------------------===//
// TTIR region ops (ops that may appear inside of ttir.generic region)
//===----------------------------------------------------------------------===//
Expand Down
38 changes: 36 additions & 2 deletions include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,15 @@ def TTKernel_PackTileOp : TTKernel_Op<"pack_tile"> {
let arguments = (ins I32:$dst_index, TTKernel_CB:$out_cb, I32:$out_index);
}

// Init op paired with copy_tile. It declares no `arguments`, so it is created
// from a location alone.
def TTKernel_CopyTileInitOp : TTKernel_Op<"copy_tile_init"> {
let summary = "Perform the init for copy tile. This does not reconfigure the unpacker data types.";
let description = [{
Must be called before copy_tile.
}];
}

def TTKernel_CopyTileOp : TTKernel_Op<"copy_tile"> {
let summary = "copy_tile";
let summary = "Copy tile from specified CB to DST.";
let description = [{
Copies a single tile from the specified input CB and writes the result to
DST at a specified index. The function will employ unpacker to first unpack into SRC
Expand All @@ -139,7 +146,7 @@ def TTKernel_CopyTileOp : TTKernel_Op<"copy_tile"> {
engine.
}];

let arguments = (ins TTKernel_CB:$cb0, I32:$tile_index_0, I32:$tile_index_1);
let arguments = (ins TTKernel_CB:$cb0, I32:$tile_index_cb, I32:$tile_index_dst);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -221,6 +228,13 @@ def TTKernel_MulTilesInitOp : TTKernel_Op<"mul_tiles_init"> {
let arguments = (ins TTKernel_CB:$in0_cb, TTKernel_CB:$in1_cb);
}

// Math-only short init for mul_tiles. Unlike mul_tiles_init, which takes the
// two input CBs, this op declares no `arguments`.
def TTKernel_MulTilesInitFOp : TTKernel_Op<"mul_tiles_init_f"> {
let summary = "Short init function. Init for math only.";
let description = [{
Must be run before mul_tiles.
}];
}

def TTKernel_MulTilesOp : TTKernel_Op<"mul_tiles"> {
let summary = "Mul operation";
let description = [{
Expand Down Expand Up @@ -262,6 +276,26 @@ def TTKernel_ExpTileOp : TTKernel_Op<"exp_tile"> {
let arguments = (ins I32:$tile_index);
}

// Init op paired with recip_tile. It declares no `arguments`.
def TTKernel_RecipTileInitOp : TTKernel_Op<"recip_tile_init"> {
let summary = "Init function for recip_tile operation. Refer to documentation for any init function.";
let description = [{
Must be called before recip_tile function.
}];
}

// Element-wise reciprocal of the tile held in DST at `tile_index`, the op's
// single I32 operand.
def TTKernel_RecipTileOp : TTKernel_Op<"recip_tile"> {
let summary = "Recip tile in the DST at specified index.";
let description = [{
Performs element-wise computation of the reciprocal on each element of a tile
in DST register at index tile_index. The DST register buffer must be in
acquired state via *tile_regs_acquire* call. This call is blocking and is only
available on the compute engine.
Only works for Float32, Float16_b, Bfp8_b data formats for full accuracy.
}];

let arguments = (ins I32:$tile_index);
}

//===----------------------------------------------------------------------===//
// TTKernel CB operations
//===----------------------------------------------------------------------===//
Expand Down
174 changes: 132 additions & 42 deletions lib/Conversion/TTIRToTTMetal/TTIRToTTMetal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,8 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {
} else if (mlir::isa<arith::MulFOp>(arithOrMathOp)) {
builder.create<ttkernel::MulTilesInitOp>(arithOrMathOp.getLoc(), inCB0,
inCB1);
} else if (mlir::isa<arith::DivFOp>(arithOrMathOp)) {
builder.create<ttkernel::MulTilesInitFOp>(arithOrMathOp.getLoc());
} else {
llvm_unreachable("Unhandled binary op init conversion.");
}
Expand Down Expand Up @@ -722,53 +724,167 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {
ArrayRef<BlockArgument> cbOperands,
ArrayRef<BlockArgument> iterators,
SmallVector<unsigned> blockArgIteratorMapping,
Value dstTileIndex, OpBuilder &builder) const {
OpBuilder &builder) const {
assert(cbOperands.size() == 2 &&
"Expected one input and one output CB for unary op.");

auto inCBTileIndex = iterators[blockArgIteratorMapping[0]];
auto inCB = cbOperands[0];
auto outCBTileIndex = iterators[blockArgIteratorMapping.back()];
auto outCB = cbOperands.back();

auto location = arithOrMathOp.getLoc();

// We always operate on the first and only tile in DST register.
Value dstTileIndex = i32(0, builder);

// MATH acquires lock on DST register.
builder.create<ttkernel::TileRegsAcquireOp>(location);

// For all unary ops first copy tile from input CB at inCBTileIndex to DST
// register at dstTileIndex.
builder.create<ttkernel::CopyTileOp>(arithOrMathOp.getLoc(), inCB,
inCBTileIndex, dstTileIndex);
builder.create<ttkernel::CopyTileOp>(location, inCB, inCBTileIndex,
dstTileIndex);

// Perform computation on tile in DST register on dstTileIndex (the only
// tile in DST).
if (mlir::isa<math::ExpOp>(arithOrMathOp)) {
builder.create<ttkernel::ExpTileOp>(arithOrMathOp.getLoc(), dstTileIndex);
builder.create<ttkernel::ExpTileOp>(location, dstTileIndex);
} else {
llvm_unreachable("Unhandled unary op compute conversion.");
}

// MATH releases lock on DST.
builder.create<ttkernel::TileRegsCommitOp>(location);

// PACK acquires lock on DST register. Blocked until MATH releases it.
builder.create<ttkernel::TileRegsWaitOp>(location);

// Copy tile from DST at dstTileIndex to outCB at outCBTileIndex.
// outCBTileIndex increments as loops iterate, thus placing one result tile
// after another in outCB.
builder.create<ttkernel::PackTileOp>(location, dstTileIndex, outCB,
outCBTileIndex);

// PACK releases lock on DST.
builder.create<ttkernel::TileRegsReleaseOp>(location);
}

void convertComputeBinaryOp(Operation &arithOrMathOp,
ArrayRef<BlockArgument> cbOperands,
ArrayRef<BlockArgument> iterators,
SmallVector<unsigned> blockArgIteratorMapping,
Value dstTileIndex, OpBuilder &builder) const {
OpBuilder &builder) const {
assert(cbOperands.size() == 3 &&
"Expected two input and one output CB for binary op.");

auto inCB0TileIndex = iterators[blockArgIteratorMapping[0]];
auto inCB0 = cbOperands[0];
auto inCB1TileIndex = iterators[blockArgIteratorMapping[1]];
auto inCB1 = cbOperands[1];
auto outCB = cbOperands[2];
auto outCBTileIndex = iterators[blockArgIteratorMapping[2]];

auto location = arithOrMathOp.getLoc();

// Perform computation C = A (*) B on tile A from inCB0 and tile B from
// inCB1 and store the result C in DST register on dstTileIndex.
if (mlir::isa<arith::AddFOp>(arithOrMathOp)) {
builder.create<ttkernel::AddTilesOp>(arithOrMathOp.getLoc(), inCB0, inCB1,
inCB0TileIndex, inCB1TileIndex,
dstTileIndex);
Value dstIndex = i32(0, builder);
builder.create<ttkernel::TileRegsAcquireOp>(location);
builder.create<ttkernel::AddTilesOp>(
location, inCB0, inCB1, inCB0TileIndex, inCB1TileIndex, dstIndex);
builder.create<ttkernel::TileRegsCommitOp>(location);
builder.create<ttkernel::TileRegsWaitOp>(location);
builder.create<ttkernel::PackTileOp>(location, dstIndex, outCB,
outCBTileIndex);
builder.create<ttkernel::TileRegsReleaseOp>(location);
} else if (mlir::isa<arith::MulFOp>(arithOrMathOp)) {
builder.create<ttkernel::MulTilesOp>(arithOrMathOp.getLoc(), inCB0, inCB1,
inCB0TileIndex, inCB1TileIndex,
dstTileIndex);
commonComputeMulOp(arithOrMathOp, cbOperands, iterators,
blockArgIteratorMapping, builder);
} else if (mlir::isa<arith::DivFOp>(arithOrMathOp)) {

SmallVector<std::int64_t> operandIndicesRecip;
// For DIV, input 1 is going through reciprocal.
operandIndicesRecip.push_back(1);
commonComputeRecipOp(arithOrMathOp, cbOperands, iterators,
blockArgIteratorMapping, builder,
operandIndicesRecip);

Value one = i32(1, builder);
builder.create<ttkernel::CBWaitFrontOp>(location, inCB1, one);

builder.create<ttkernel::MulTilesInitOp>(location, inCB0, inCB1);

commonComputeMulOp(arithOrMathOp, cbOperands, iterators,
blockArgIteratorMapping, builder);

builder.create<ttkernel::CBPopFrontOp>(location, inCB1, one);
} else {
llvm_unreachable("Unhandled conversion for operation which is neither "
"unary nor binary.");
}
}

// Emits the DST-register sequence for one tile multiply: tile_regs_acquire,
// mul_tiles, tile_regs_commit, tile_regs_wait, pack_tile into the output CB,
// tile_regs_release.
//
// op must be an arith::MulFOp or an arith::DivFOp; any other op reaches
// llvm_unreachable. cbOperands is read as [in0, in1, out], and
// blockArgIteratorMapping[i] picks the iterator used as the tile index for
// cbOperands[i]. All ops are created at builder's current insertion point.
void commonComputeMulOp(Operation &op, ArrayRef<BlockArgument> cbOperands,
ArrayRef<BlockArgument> iterators,
SmallVector<unsigned> blockArgIteratorMapping,
OpBuilder &builder) const {

auto inCB0 = cbOperands[0];
auto inCB1 = cbOperands[1];
auto outCB = cbOperands[2];
auto inCB0TileIndex = iterators[blockArgIteratorMapping[0]];
auto inCB1TileIndex = iterators[blockArgIteratorMapping[1]];

// The result always lands in DST tile 0.
Value dstIndex = i32(0, builder);

builder.create<ttkernel::TileRegsAcquireOp>(op.getLoc());
if (mlir::isa<arith::MulFOp>(op)) {
// Plain multiply: both inputs are read at their own iterator-driven index.
builder.create<ttkernel::MulTilesOp>(
op.getLoc(), inCB0, inCB1, inCB0TileIndex, inCB1TileIndex, dstIndex);
} else if (mlir::isa<arith::DivFOp>(op)) {
// For div, input 1 is read at tile index 0 (dstIndex doubles as the CB tile
// index) instead of inCB1TileIndex, because of the sync needed with recip.
builder.create<ttkernel::MulTilesOp>(op.getLoc(), inCB0, inCB1,
inCB0TileIndex, dstIndex, dstIndex);
} else {
// NOTE(review): two consecutive llvm_unreachable calls; only the first can
// ever execute, so the more specific message below is dead. Presumably one
// of the two is a leftover -- confirm and drop it.
llvm_unreachable("Unhandled binary op compute conversion.");
llvm_unreachable("Common compute for multiplying tiles should be called "
"only on MulFOp and DivFOp");
}

builder.create<ttkernel::TileRegsCommitOp>(op.getLoc());
builder.create<ttkernel::TileRegsWaitOp>(op.getLoc());
// Pack DST tile 0 into the output CB at the output's iterator-driven index.
builder.create<ttkernel::PackTileOp>(op.getLoc(), dstIndex, outCB,
iterators[blockArgIteratorMapping[2]]);
builder.create<ttkernel::TileRegsReleaseOp>(op.getLoc());
}

// Emits the sequence that replaces one tile of an input CB with its
// element-wise reciprocal.
//
// The CB is cbOperands[operandIndices[0]]; only the first entry of
// operandIndices is consulted. The tile at CB index 0 is copied into DST
// tile 0, recip_tile is applied there, and the result is packed back into the
// same CB at index 0, bracketed by cb_reserve_back / cb_push_back of one tile.
// All ops are created at builder's current insertion point, in the order
// written below.
//
// iterators and blockArgIteratorMapping are not used by this function.
void commonComputeRecipOp(Operation &op, ArrayRef<BlockArgument> cbOperands,
                          ArrayRef<BlockArgument> iterators,
                          SmallVector<unsigned> blockArgIteratorMapping,
                          OpBuilder &builder,
                          SmallVector<std::int64_t> &operandIndices) const {
  // Guard the unchecked indexing below: an empty index list or an index
  // outside cbOperands would otherwise be an out-of-bounds read.
  assert(!operandIndices.empty() &&
         "Expected at least one operand index for recip.");
  assert(operandIndices[0] >= 0 &&
         static_cast<size_t>(operandIndices[0]) < cbOperands.size() &&
         "Recip operand index is out of range of the CB operands.");

  auto location = op.getLoc();

  Value dstIndex = i32(0, builder);
  Value one = i32(1, builder);

  // The reciprocal is written back into the CB it was read from.
  auto inputCB = cbOperands[operandIndices[0]];
  auto outputCB = inputCB;

  builder.create<ttkernel::CopyTileInitOp>(location);
  builder.create<ttkernel::CBReserveBackOp>(location, inputCB, one);

  // Math phase: load the tile into DST and take its reciprocal.
  builder.create<ttkernel::TileRegsAcquireOp>(location);
  builder.create<ttkernel::RecipTileInitOp>(location);
  builder.create<ttkernel::CopyTileOp>(location, inputCB, dstIndex, dstIndex);
  builder.create<ttkernel::RecipTileOp>(location, dstIndex);
  builder.create<ttkernel::TileRegsCommitOp>(location);

  // Pack phase: write DST tile 0 back to the CB and publish one tile.
  builder.create<ttkernel::TileRegsWaitOp>(location);
  builder.create<ttkernel::PackTileOp>(location, dstIndex, outputCB, dstIndex);
  builder.create<ttkernel::TileRegsReleaseOp>(location);
  builder.create<ttkernel::CBPushBackOp>(location, outputCB, one);
}

// Convert arith and math dialect operations into ttkernel tile operations.
Expand All @@ -779,15 +895,14 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {
ArrayRef<BlockArgument> cbOperands,
ArrayRef<BlockArgument> iterators,
SmallVector<unsigned> blockArgIteratorMapping,
Value dstTileIndex, OpBuilder &builder,
std::int64_t numDpsInputs) const {
OpBuilder &builder, std::int64_t numDpsInputs) const {

if (numDpsInputs == 1) {
convertComputeUnaryOp(arithOrMathOp, cbOperands, iterators,
blockArgIteratorMapping, dstTileIndex, builder);
blockArgIteratorMapping, builder);
} else if (numDpsInputs == 2) {
convertComputeBinaryOp(arithOrMathOp, cbOperands, iterators,
blockArgIteratorMapping, dstTileIndex, builder);
blockArgIteratorMapping, builder);
} else {
llvm_unreachable("Unhandled conversion for operation which is neither "
"unary nor binary.");
Expand Down Expand Up @@ -817,7 +932,6 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {
// The loop nest is created from outermost to innermost. Get the inner loop
// and place computation calls inside it.
Region *innerLoopRegion = loopNest.loopRegions.back();
const Location &location = innerLoopRegion->getLoc();
ArrayRef<BlockArgument> iterators =
loopNest.loops.back().getRegionIterArgs();
SmallVector<unsigned> blockArgIteratorMapping =
Expand All @@ -826,34 +940,10 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {
OpBuilder innerLoopBuilder(&innerLoopRegion->front(),
innerLoopRegion->front().begin());

// We always operate on the first and only tile in DST register.
Value dstTileIndex = i32(0, innerLoopBuilder);
auto outCBTileIndex = iterators[blockArgIteratorMapping.back()];
auto outCB = cbOperands.back();

// MATH acquires lock on DST register.
innerLoopBuilder.create<ttkernel::TileRegsAcquireOp>(location);

// Call compute function to execute on each tile. Result will be stored in
// DST.
convertComputeOp(arithOrMathOp, cbOperands, iterators,
blockArgIteratorMapping, dstTileIndex, innerLoopBuilder,
numDPSInputs);

// MATH releases lock on DST.
innerLoopBuilder.create<ttkernel::TileRegsCommitOp>(location);

// PACK acquires lock on DST register. Blocked until MATH releases it.
innerLoopBuilder.create<ttkernel::TileRegsWaitOp>(location);

// Copy tile from DST at dstTileIndex to outCB at outCBTileIndex.
// outCBTileIndex increments as loops iterate, thus placing one result tile
// after another in outCB.
innerLoopBuilder.create<ttkernel::PackTileOp>(location, dstTileIndex, outCB,
outCBTileIndex);

// PACK releases lock on DST.
innerLoopBuilder.create<ttkernel::TileRegsReleaseOp>(location);
blockArgIteratorMapping, innerLoopBuilder, numDPSInputs);
}

// Builds instructions to execute after loops are finished.
Expand Down
7 changes: 7 additions & 0 deletions lib/Conversion/TTKernelToEmitC/TTKernelToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,9 @@ class ConvertTTKernelToEmitCPass
patterns
.add<TTMetalToEmitCFuncArgsRewriter, TTMetalToEmitCReturnRewriter,
TTMetalToEmitCOpaqueRewriter<ttkernel::BuiltinOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CopyTileInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::RecipTileInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::RecipTileOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::TileRegsAcquireOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::TileRegsCommitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::TileRegsWaitOp>,
Expand All @@ -277,6 +280,7 @@ class ConvertTTKernelToEmitCPass
TTMetalToEmitCOpaqueRewriter<ttkernel::BinaryOpInitCommonOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::AddTilesInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::MulTilesInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::MulTilesInitFOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::AddTilesOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::MulTilesOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::GetNocAddrOp>,
Expand Down Expand Up @@ -343,6 +347,9 @@ class ThreadConfigHelper {
builder->create<emitc::IncludeOp>(
loc, "compute_kernel_api/eltwise_unary/sfpu_split_includes.h",
/*isStandard=*/false);
builder->create<emitc::IncludeOp>(
loc, "compute_kernel_api/eltwise_unary/recip.h",
/*isStandard=*/false);
builder->create<emitc::VerbatimOp>(loc, "namespace NAMESPACE {");
}
}
Expand Down
6 changes: 6 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ void mlir::tt::ttir::ExpOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
buildGenericEltwiseUnaryRegion<math::ExpOp>(getLoc(), opBuilder, block);
}

// Fills the generic region for ttir.div by instantiating the shared
// element-wise binary region builder with arith::DivFOp.
void mlir::tt::ttir::DivOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
                                               ::mlir::Block *block) {
  buildGenericEltwiseBinaryRegion<arith::DivFOp>(getLoc(), opBuilder, block);
}

::mlir::LogicalResult mlir::tt::ttir::EmbeddingOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
::mlir::RankedTensorType weightType = getWeight().getType();
Expand Down
Loading

0 comments on commit f983235

Please sign in to comment.