Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Metal Direct] Implement div op #780

Merged
merged 3 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -232,13 +232,6 @@ class TTIR_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
];
}

// Elementwise binary divide op, declared here on top of the plain
// TTIR_ElementwiseBinaryOp base class.
// NOTE(review): a second TTIR_DivOp definition deriving from
// TTIR_GenericElementwiseBinaryOp appears further down in this text; this one
// looks like the pre-change side of the diff - confirm only one survives.
def TTIR_DivOp : TTIR_ElementwiseBinaryOp<"div"> {
let summary = "Eltwise divide.";
let description = [{
Eltwise divide operation.
}];
}

def TTIR_SubtractOp : TTIR_ElementwiseBinaryOp<"subtract"> {
let summary = "Eltwise subtract.";
let description = [{
Expand Down Expand Up @@ -725,6 +718,13 @@ def TTIR_MultiplyOp : TTIR_GenericElementwiseBinaryOp<"multiply"> {
}];
}

// Elementwise binary divide op, declared as a generic elementwise binary op
// (same base class as TTIR_MultiplyOp above). Its generic region is built by
// DivOp::buildGenericRegion, which instantiates the binary region helper with
// arith::DivFOp.
def TTIR_DivOp : TTIR_GenericElementwiseBinaryOp<"div"> {
let summary = "Eltwise divide.";
let description = [{
Eltwise divide operation.
}];
}

//===----------------------------------------------------------------------===//
// TTIR region ops (ops that may appear inside of ttir.generic region)
//===----------------------------------------------------------------------===//
Expand Down
38 changes: 36 additions & 2 deletions include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,15 @@ def TTKernel_PackTileOp : TTKernel_Op<"pack_tile"> {
let arguments = (ins I32:$dst_index, TTKernel_CB:$out_cb, I32:$out_index);
}

// Init op paired with copy_tile. It declares no operands. The TTIR->TTMetal
// lowering creates it at the start of the reciprocal sequence used for div,
// ahead of the copy_tile op.
def TTKernel_CopyTileInitOp : TTKernel_Op<"copy_tile_init"> {
let summary = "Perform the init for copy tile. This does not reconfigure the unpacker data types.";
let description = [{
Must be called before copy_tile.
}];
}

def TTKernel_CopyTileOp : TTKernel_Op<"copy_tile"> {
let summary = "copy_tile";
let summary = "Copy tile from specified CB to DST.";
let description = [{
Copies a single tile from the specified input CB and writes the result to
DST at a specified index. The function will employ unpacker to first unpack into SRC
Expand All @@ -139,7 +146,7 @@ def TTKernel_CopyTileOp : TTKernel_Op<"copy_tile"> {
engine.
}];

let arguments = (ins TTKernel_CB:$cb0, I32:$tile_index_0, I32:$tile_index_1);
let arguments = (ins TTKernel_CB:$cb0, I32:$tile_index_cb, I32:$tile_index_dst);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -221,6 +228,13 @@ def TTKernel_MulTilesInitOp : TTKernel_Op<"mul_tiles_init"> {
let arguments = (ins TTKernel_CB:$in0_cb, TTKernel_CB:$in1_cb);
}

// Operand-less variant of the mul_tiles init op (compare mul_tiles_init, which
// takes the two input CBs). The TTIR->TTMetal lowering creates it as the init
// op when the arith op being converted is arith::DivFOp.
def TTKernel_MulTilesInitFOp : TTKernel_Op<"mul_tiles_init_f"> {
let summary = "Short init function. Init for math only.";
let description = [{
Must be run before mul_tiles.
}];
}

def TTKernel_MulTilesOp : TTKernel_Op<"mul_tiles"> {
let summary = "Mul operation";
let description = [{
Expand Down Expand Up @@ -262,6 +276,26 @@ def TTKernel_ExpTileOp : TTKernel_Op<"exp_tile"> {
let arguments = (ins I32:$tile_index);
}

// Init op paired with recip_tile. It declares no operands. The TTIR->TTMetal
// lowering creates it inside the reciprocal sequence, before recip_tile.
def TTKernel_RecipTileInitOp : TTKernel_Op<"recip_tile_init"> {
let summary = "Init function for recip_tile operation. Refer to documentation for any init function.";
let description = [{
Must be called before recip_tile function.
}];
}

// Element-wise reciprocal of a tile that already resides in DST. Takes a single
// I32 operand, `tile_index`, selecting the tile in DST; the op declares no
// results. The TTIR->TTMetal lowering uses it to turn the second operand of a
// divide into its reciprocal before multiplying.
def TTKernel_RecipTileOp : TTKernel_Op<"recip_tile"> {
let summary = "Recip tile in the DST at specified index.";
let description = [{
Performs element-wise computation of the reciprocal on each element of a tile
in DST register at index tile_index. The DST register buffer must be in
acquired state via *tile_regs_acquire* call. This call is blocking and is only
available on the compute engine.
Only works for Float32, Float16_b, Bfp8_b data formats for full accuracy.
}];

let arguments = (ins I32:$tile_index);
}

//===----------------------------------------------------------------------===//
// TTKernel CB operations
//===----------------------------------------------------------------------===//
Expand Down
174 changes: 132 additions & 42 deletions lib/Conversion/TTIRToTTMetal/TTIRToTTMetal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,8 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {
} else if (mlir::isa<arith::MulFOp>(arithOrMathOp)) {
builder.create<ttkernel::MulTilesInitOp>(arithOrMathOp.getLoc(), inCB0,
inCB1);
} else if (mlir::isa<arith::DivFOp>(arithOrMathOp)) {
builder.create<ttkernel::MulTilesInitFOp>(arithOrMathOp.getLoc());
} else {
llvm_unreachable("Unhandled binary op init conversion.");
}
Expand Down Expand Up @@ -722,53 +724,167 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {
ArrayRef<BlockArgument> cbOperands,
ArrayRef<BlockArgument> iterators,
SmallVector<unsigned> blockArgIteratorMapping,
Value dstTileIndex, OpBuilder &builder) const {
OpBuilder &builder) const {
assert(cbOperands.size() == 2 &&
"Expected one input and one output CB for unary op.");

auto inCBTileIndex = iterators[blockArgIteratorMapping[0]];
auto inCB = cbOperands[0];
auto outCBTileIndex = iterators[blockArgIteratorMapping.back()];
auto outCB = cbOperands.back();

auto location = arithOrMathOp.getLoc();

// We always operate on the first and only tile in DST register.
Value dstTileIndex = i32(0, builder);

// MATH acquires lock on DST register.
builder.create<ttkernel::TileRegsAcquireOp>(location);

// For all unary ops first copy tile from input CB at inCBTileIndex to DST
// register at dstTileIndex.
builder.create<ttkernel::CopyTileOp>(arithOrMathOp.getLoc(), inCB,
inCBTileIndex, dstTileIndex);
builder.create<ttkernel::CopyTileOp>(location, inCB, inCBTileIndex,
dstTileIndex);

// Perform computation on tile in DST register on dstTileIndex (the only
// tile in DST).
if (mlir::isa<math::ExpOp>(arithOrMathOp)) {
builder.create<ttkernel::ExpTileOp>(arithOrMathOp.getLoc(), dstTileIndex);
builder.create<ttkernel::ExpTileOp>(location, dstTileIndex);
} else {
llvm_unreachable("Unhandled unary op compute conversion.");
}

// MATH releases lock on DST.
builder.create<ttkernel::TileRegsCommitOp>(location);

// PACK acquires lock on DST register. Blocked until MATH releases it.
builder.create<ttkernel::TileRegsWaitOp>(location);

// Copy tile from DST at dstTileIndex to outCB at outCBTileIndex.
// outCBTileIndex increments as loops iterate, thus placing one result tile
// after another in outCB.
builder.create<ttkernel::PackTileOp>(location, dstTileIndex, outCB,
outCBTileIndex);

// PACK releases lock on DST.
builder.create<ttkernel::TileRegsReleaseOp>(location);
}

void convertComputeBinaryOp(Operation &arithOrMathOp,
ArrayRef<BlockArgument> cbOperands,
ArrayRef<BlockArgument> iterators,
SmallVector<unsigned> blockArgIteratorMapping,
Value dstTileIndex, OpBuilder &builder) const {
OpBuilder &builder) const {
assert(cbOperands.size() == 3 &&
"Expected two input and one output CB for binary op.");

auto inCB0TileIndex = iterators[blockArgIteratorMapping[0]];
auto inCB0 = cbOperands[0];
auto inCB1TileIndex = iterators[blockArgIteratorMapping[1]];
auto inCB1 = cbOperands[1];
auto outCB = cbOperands[2];
auto outCBTileIndex = iterators[blockArgIteratorMapping[2]];

auto location = arithOrMathOp.getLoc();

// Perform computation C = A (*) B on tile A from inCB0 and tile B from
// inCB1 and store the result C in DST register on dstTileIndex.
if (mlir::isa<arith::AddFOp>(arithOrMathOp)) {
builder.create<ttkernel::AddTilesOp>(arithOrMathOp.getLoc(), inCB0, inCB1,
inCB0TileIndex, inCB1TileIndex,
dstTileIndex);
Value dstIndex = i32(0, builder);
builder.create<ttkernel::TileRegsAcquireOp>(location);
builder.create<ttkernel::AddTilesOp>(
location, inCB0, inCB1, inCB0TileIndex, inCB1TileIndex, dstIndex);
builder.create<ttkernel::TileRegsCommitOp>(location);
builder.create<ttkernel::TileRegsWaitOp>(location);
builder.create<ttkernel::PackTileOp>(location, dstIndex, outCB,
outCBTileIndex);
builder.create<ttkernel::TileRegsReleaseOp>(location);
} else if (mlir::isa<arith::MulFOp>(arithOrMathOp)) {
builder.create<ttkernel::MulTilesOp>(arithOrMathOp.getLoc(), inCB0, inCB1,
inCB0TileIndex, inCB1TileIndex,
dstTileIndex);
commonComputeMulOp(arithOrMathOp, cbOperands, iterators,
blockArgIteratorMapping, builder);
} else if (mlir::isa<arith::DivFOp>(arithOrMathOp)) {

SmallVector<std::int64_t> operandIndicesRecip;
// For DIV, input 1 is going through reciprocal.
operandIndicesRecip.push_back(1);
pjanevskiTT marked this conversation as resolved.
Show resolved Hide resolved
commonComputeRecipOp(arithOrMathOp, cbOperands, iterators,
blockArgIteratorMapping, builder,
operandIndicesRecip);

Value one = i32(1, builder);
builder.create<ttkernel::CBWaitFrontOp>(location, inCB1, one);

builder.create<ttkernel::MulTilesInitOp>(location, inCB0, inCB1);

commonComputeMulOp(arithOrMathOp, cbOperands, iterators,
blockArgIteratorMapping, builder);

builder.create<ttkernel::CBPopFrontOp>(location, inCB1, one);
} else {
llvm_unreachable("Unhandled conversion for operation which is neither "
"unary nor binary.");
}
}

/// Emits the ttkernel ops that multiply one tile from each input CB and pack
/// the product into the output CB.
///
/// The emitted sequence is: tile_regs_acquire, mul_tiles, tile_regs_commit,
/// tile_regs_wait, pack_tile, tile_regs_release. The product is always written
/// to DST index 0 and packed into the output CB at the tile index given by the
/// third mapped iterator.
///
/// \param op Operation being lowered; must be arith::MulFOp or arith::DivFOp.
///        Its location is attached to every created op.
/// \param cbOperands Exactly three CBs: input 0, input 1, output.
/// \param iterators Loop iterators used as tile indices into the CBs.
/// \param blockArgIteratorMapping Maps each CB operand position to the index
///        of its iterator in `iterators`.
/// \param builder Builder positioned where the ops should be inserted.
void commonComputeMulOp(Operation &op, ArrayRef<BlockArgument> cbOperands,
                        ArrayRef<BlockArgument> iterators,
                        SmallVector<unsigned> blockArgIteratorMapping,
                        OpBuilder &builder) const {
  assert(cbOperands.size() == 3 &&
         "Expected two input and one output CB for binary op.");

  auto location = op.getLoc();

  auto inCB0 = cbOperands[0];
  auto inCB1 = cbOperands[1];
  auto outCB = cbOperands[2];
  auto inCB0TileIndex = iterators[blockArgIteratorMapping[0]];
  auto inCB1TileIndex = iterators[blockArgIteratorMapping[1]];
  auto outCBTileIndex = iterators[blockArgIteratorMapping[2]];

  // Single DST slot used for the product.
  Value dstIndex = i32(0, builder);

  builder.create<ttkernel::TileRegsAcquireOp>(location);
  if (mlir::isa<arith::MulFOp>(op)) {
    builder.create<ttkernel::MulTilesOp>(location, inCB0, inCB1,
                                         inCB0TileIndex, inCB1TileIndex,
                                         dstIndex);
  } else if (mlir::isa<arith::DivFOp>(op)) {
    // Source index for CB input 1 is 0 (dstIndex), because of sync needed with
    // recip: commonComputeRecipOp packs the reciprocal tile at index 0 of that
    // CB, so the multiply must read it from there rather than from the loop
    // iterator.
    builder.create<ttkernel::MulTilesOp>(location, inCB0, inCB1,
                                         inCB0TileIndex, dstIndex, dstIndex);
  } else {
    // Only one unreachable is kept: a preceding generic one would make this
    // more specific diagnostic dead code.
    llvm_unreachable("Common compute for multiplying tiles should be called "
                     "only on MulFOp and DivFOp");
  }

  builder.create<ttkernel::TileRegsCommitOp>(location);
  builder.create<ttkernel::TileRegsWaitOp>(location);
  builder.create<ttkernel::PackTileOp>(location, dstIndex, outCB,
                                       outCBTileIndex);
  builder.create<ttkernel::TileRegsReleaseOp>(location);
}

void commonComputeRecipOp(Operation &op, ArrayRef<BlockArgument> cbOperands,
ArrayRef<BlockArgument> iterators,
SmallVector<unsigned> blockArgIteratorMapping,
OpBuilder &builder,
SmallVector<std::int64_t> &operandIndices) const {
Value dstIndex = i32(0, builder);
Value one = i32(1, builder);

auto inputCB = cbOperands[operandIndices[0]];
auto outputCB = inputCB;

builder.create<ttkernel::CopyTileInitOp>(op.getLoc());
builder.create<ttkernel::CBReserveBackOp>(op.getLoc(), inputCB, one);
builder.create<ttkernel::TileRegsAcquireOp>(op.getLoc());
builder.create<ttkernel::RecipTileInitOp>(op.getLoc());
builder.create<ttkernel::CopyTileOp>(op.getLoc(), inputCB, dstIndex,
dstIndex);
builder.create<ttkernel::RecipTileOp>(op.getLoc(), dstIndex);
builder.create<ttkernel::TileRegsCommitOp>(op.getLoc());

builder.create<ttkernel::TileRegsWaitOp>(op.getLoc());
builder.create<ttkernel::PackTileOp>(op.getLoc(), dstIndex, outputCB,
dstIndex);
builder.create<ttkernel::TileRegsReleaseOp>(op.getLoc());
builder.create<ttkernel::CBPushBackOp>(op.getLoc(), outputCB, one);
}

// Convert arith and math dialect operations into ttkernel tile operations.
Expand All @@ -779,15 +895,14 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {
ArrayRef<BlockArgument> cbOperands,
ArrayRef<BlockArgument> iterators,
SmallVector<unsigned> blockArgIteratorMapping,
Value dstTileIndex, OpBuilder &builder,
std::int64_t numDpsInputs) const {
OpBuilder &builder, std::int64_t numDpsInputs) const {

if (numDpsInputs == 1) {
convertComputeUnaryOp(arithOrMathOp, cbOperands, iterators,
blockArgIteratorMapping, dstTileIndex, builder);
blockArgIteratorMapping, builder);
} else if (numDpsInputs == 2) {
convertComputeBinaryOp(arithOrMathOp, cbOperands, iterators,
blockArgIteratorMapping, dstTileIndex, builder);
blockArgIteratorMapping, builder);
} else {
llvm_unreachable("Unhandled conversion for operation which is neither "
"unary nor binary.");
Expand Down Expand Up @@ -817,7 +932,6 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {
// The loop nest is created from outermost to innermost. Get the inner loop
// and place computation calls inside it.
Region *innerLoopRegion = loopNest.loopRegions.back();
const Location &location = innerLoopRegion->getLoc();
ArrayRef<BlockArgument> iterators =
loopNest.loops.back().getRegionIterArgs();
SmallVector<unsigned> blockArgIteratorMapping =
Expand All @@ -826,34 +940,10 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {
OpBuilder innerLoopBuilder(&innerLoopRegion->front(),
innerLoopRegion->front().begin());

// We always operate on the first and only tile in DST register.
Value dstTileIndex = i32(0, innerLoopBuilder);
auto outCBTileIndex = iterators[blockArgIteratorMapping.back()];
auto outCB = cbOperands.back();

// MATH acquires lock on DST register.
innerLoopBuilder.create<ttkernel::TileRegsAcquireOp>(location);

// Call compute function to execute on each tile. Result will be stored in
// DST.
convertComputeOp(arithOrMathOp, cbOperands, iterators,
blockArgIteratorMapping, dstTileIndex, innerLoopBuilder,
numDPSInputs);

// MATH releases lock on DST.
innerLoopBuilder.create<ttkernel::TileRegsCommitOp>(location);

// PACK acquires lock on DST register. Blocked until MATH releases it.
innerLoopBuilder.create<ttkernel::TileRegsWaitOp>(location);

// Copy tile from DST at dstTileIndex to outCB at outCBTileIndex.
// outCBTileIndex increments as loops iterate, thus placing one result tile
// after another in outCB.
innerLoopBuilder.create<ttkernel::PackTileOp>(location, dstTileIndex, outCB,
outCBTileIndex);

// PACK releases lock on DST.
innerLoopBuilder.create<ttkernel::TileRegsReleaseOp>(location);
blockArgIteratorMapping, innerLoopBuilder, numDPSInputs);
}

// Builds instructions to execute after loops are finished.
Expand Down
7 changes: 7 additions & 0 deletions lib/Conversion/TTKernelToEmitC/TTKernelToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,9 @@ class ConvertTTKernelToEmitCPass
patterns
.add<TTMetalToEmitCFuncArgsRewriter, TTMetalToEmitCReturnRewriter,
TTMetalToEmitCOpaqueRewriter<ttkernel::BuiltinOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CopyTileInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::RecipTileInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::RecipTileOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::TileRegsAcquireOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::TileRegsCommitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::TileRegsWaitOp>,
Expand All @@ -277,6 +280,7 @@ class ConvertTTKernelToEmitCPass
TTMetalToEmitCOpaqueRewriter<ttkernel::BinaryOpInitCommonOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::AddTilesInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::MulTilesInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::MulTilesInitFOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::AddTilesOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::MulTilesOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::GetNocAddrOp>,
Expand Down Expand Up @@ -343,6 +347,9 @@ class ThreadConfigHelper {
builder->create<emitc::IncludeOp>(
loc, "compute_kernel_api/eltwise_unary/sfpu_split_includes.h",
/*isStandard=*/false);
builder->create<emitc::IncludeOp>(
loc, "compute_kernel_api/eltwise_unary/recip.h",
/*isStandard=*/false);
builder->create<emitc::VerbatimOp>(loc, "namespace NAMESPACE {");
}
}
Expand Down
6 changes: 6 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ void mlir::tt::ttir::ExpOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
buildGenericEltwiseUnaryRegion<math::ExpOp>(getLoc(), opBuilder, block);
}

// Builds the body of the generic region for ttir.div by instantiating the
// shared elementwise-binary region builder with arith::DivFOp.
void mlir::tt::ttir::DivOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
                                               ::mlir::Block *block) {
  // Plain call (no `return`), matching the other buildGenericRegion
  // implementations in this file; the function itself returns void.
  buildGenericEltwiseBinaryRegion<arith::DivFOp>(getLoc(), opBuilder, block);
}

::mlir::LogicalResult mlir::tt::ttir::EmbeddingOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
::mlir::RankedTensorType weightType = getWeight().getType();
Expand Down
Loading
Loading