Skip to content

Commit

Permalink
emit C++ with metal kernel boilerplate
Browse files Browse the repository at this point in the history
  • Loading branch information
TT-billteng committed Dec 20, 2024
1 parent 528ca13 commit 3a50482
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 36 deletions.
16 changes: 15 additions & 1 deletion include/ttmlir/Conversion/TTKernelToEmitC/TTKernelToEmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,27 @@ LogicalResult convertTTKernelRegionToEmitC(
// Converts given region to EmitC dialect and translates it to C++ code.
LogicalResult
emitOpRegionAsCpp(Region *region, std::string &regionCpp,
const ttkernel::KernelConfigInterface &kernelConfig);
const ttkernel::ThreadType &threadType);

LogicalResult
emitOpRegionAsCpp(Region *region, llvm::raw_ostream &os,
const ttkernel::ThreadType &threadType);

// Converts dispatch op's regions to C++ code.
LogicalResult
emitDispatchOpRegionsAsCpp(ttmetal::DispatchOp dispatchOp,
llvm::SmallVector<std::string> &cppStrings);


LogicalResult
emitNocKernelAsCpp( mlir::ModuleOp op, llvm::raw_ostream &os);

LogicalResult
emitTensixKernelAsCpp( mlir::ModuleOp op, llvm::raw_ostream &os);

LogicalResult
emitKernelAsCpp( mlir::ModuleOp op, llvm::raw_ostream &os, const ttkernel::ThreadType &threadType);

} // namespace mlir::tt

#endif
1 change: 1 addition & 0 deletions lib/Conversion/TTKernelToEmitC/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_mlir_conversion_library(TTMLIRTTKernelToEmitC
MLIRIR
MLIRPass
MLIRArithToEmitC
MLIRSCFToEmitC
MLIREmitCDialect
MLIRTargetCpp
MLIRTransformUtils
Expand Down
86 changes: 67 additions & 19 deletions lib/Conversion/TTKernelToEmitC/TTKernelToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -460,16 +460,16 @@ std::unique_ptr<::mlir::Pass> createConvertTTKernelToEmitC() {
class ThreadConfigHelper {
public:
ThreadConfigHelper(OpBuilder *builder, Location loc,
ttkernel::KernelConfigInterface kernelConfig)
: builder(builder), loc(loc), kernelConfig(kernelConfig) {
ttkernel::ThreadType threadType)
: builder(builder), loc(loc), threadType(threadType) {
builder->create<emitc::IncludeOp>(loc, "cstdint",
/*isStandard=*/true);
if (kernelConfig.getThreadType() == ttkernel::ThreadType::Noc) {
if (threadType == ttkernel::ThreadType::Noc) {

builder->create<emitc::IncludeOp>(loc, "dataflow_api.h",
/*isStandard=*/false);
}
if (kernelConfig.getThreadType() == ttkernel::ThreadType::Tensix) {
if (threadType == ttkernel::ThreadType::Tensix) {
builder->create<emitc::IncludeOp>(loc, "llk_defs.h",
/*isStandard=*/false);
builder->create<emitc::IncludeOp>(loc, "compute_kernel_api/common.h",
Expand Down Expand Up @@ -510,28 +510,30 @@ class ThreadConfigHelper {
builder->create<emitc::IncludeOp>(loc, "compute_kernel_api/reduce.h",
/*isStandard=*/false);
builder->create<emitc::VerbatimOp>(loc, "namespace NAMESPACE {");
}
}
}

~ThreadConfigHelper() {
if (kernelConfig.getThreadType() == ttkernel::ThreadType::Tensix) {
if (threadType == ttkernel::ThreadType::Tensix) {
builder->create<emitc::VerbatimOp>(loc, "void MAIN { kernel_main(); }");
builder->create<emitc::VerbatimOp>(loc,
"}"); // close namespace NAMESPACE
}
}

private:

OpBuilder *builder;
Location loc;
ttkernel::KernelConfigInterface kernelConfig;
ttkernel::ThreadType threadType;
};

LogicalResult convertTTKernelRegionToEmitC(
OpBuilder &builder, Region *region,
const ttkernel::KernelConfigInterface &kernelConfig) {
const ttkernel::ThreadType &threadType) {
ThreadConfigHelper threadConfigHelper(&builder, region->getLoc(),
kernelConfig);
threadType);

auto funcOp = builder.create<func::FuncOp>(
region->getLoc(), "kernel_main",
Expand All @@ -552,20 +554,33 @@ LogicalResult convertTTKernelRegionToEmitC(

LogicalResult
emitOpRegionAsCpp(Region *region, std::string &regionCpp,
const ttkernel::KernelConfigInterface &kernelConfig) {
OpBuilder builder(region->getContext());
const ttkernel::ThreadType &threadType) {

llvm::raw_string_ostream os(regionCpp);
return emitOpRegionAsCpp(region, os, threadType);
}

LogicalResult
emitOpRegionAsCpp(Region *region, llvm::raw_ostream &os,
const ttkernel::ThreadType &threadType) {

// We must load the EmitC dialect before we can emit any EmitC code. This
// dialect won't be loaded by MLIR until pass manager starts a pass that
// depends on it. Because we want to emit EmitC code before that, we need to
// load it here.
region->getContext()->getOrLoadDialect<emitc::EmitCDialect>();

OpBuilder builder(region->getContext());
// We will wrap everything in a module op so that we can run the
// translation.
auto moduleWrapper =
builder.create<mlir::ModuleOp>(region->getLoc(), "module_wrapper");
builder.setInsertionPointToStart(moduleWrapper.getBody());

if (convertTTKernelRegionToEmitC(builder, region, kernelConfig).failed()) {
if (convertTTKernelRegionToEmitC(builder, region, threadType).failed()) {
return failure();
}

llvm::raw_string_ostream os(regionCpp);
if (emitc::translateToCpp(moduleWrapper, os).failed()) {
return failure();
}
Expand All @@ -579,17 +594,13 @@ emitDispatchOpRegionsAsCpp(ttmetal::DispatchOp dispatchOp,
assert(cppStrings.size() == dispatchOp.getNumRegions() &&
"cppStrings size must match number of regions");

// We must load the EmitC dialect before we can emit any EmitC code. This
// dialect won't be loaded by MLIR until pass manager starts a pass that
// depends on it. Because we want to emit EmitC code before that, we need to
// load it here.
dispatchOp.getContext()->getOrLoadDialect<emitc::EmitCDialect>();


for (auto &reg : dispatchOp->getRegions()) {
auto kernelConfig = mlir::cast<ttkernel::KernelConfigInterface>(
dispatchOp.getKernelConfigs()[reg.getRegionNumber()]);
if (emitDispatchOpRegionAsCpp(&reg, cppStrings[reg.getRegionNumber()],
kernelConfig)
if (emitOpRegionAsCpp(&reg, cppStrings[reg.getRegionNumber()],
kernelConfig.getThreadType())
.failed()) {
return llvm::failure();
}
Expand All @@ -598,4 +609,41 @@ emitDispatchOpRegionsAsCpp(ttmetal::DispatchOp dispatchOp,
return success();
}

LogicalResult
emitNocKernelAsCpp(mlir::ModuleOp op, llvm::raw_ostream &os)
{
return emitKernelAsCpp(op, os, ttkernel::ThreadType::Noc);
}

LogicalResult
emitTensixKernelAsCpp(mlir::ModuleOp op, llvm::raw_ostream &os)
{
return emitKernelAsCpp(op, os, ttkernel::ThreadType::Tensix);
}

LogicalResult
emitKernelAsCpp(mlir::ModuleOp op, llvm::raw_ostream &os, const ttkernel::ThreadType &threadType )
{
std::vector<Operation*> ops;
op->walk([&](func::FuncOp entry) {
ops.push_back(entry);
});

// PassManager pm(ops[0]->getContext());
// pm.addPass(createConvertTTKernelToEmitC());

// if (pm.run(ops[0]).failed()) {
// return failure();
// }

for (auto &reg : ops[0]->getRegions()) {
if (emitOpRegionAsCpp(&reg, os,
threadType)
.failed()) {
return llvm::failure();
}
}
return llvm::success();
}

} // namespace mlir::tt
85 changes: 71 additions & 14 deletions lib/Target/TTKernel/TTKernelToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,87 @@
#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
#include "mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h"
#include <llvm/Support/raw_ostream.h>

namespace mlir::tt::ttkernel {

/// This pass illustrates the IR nesting through printing.
struct Printer{

/// The three methods below are mutually recursive and follow the nesting of
/// the IR: operation->region->block->operation->...
public:
void printOperation(Operation *op) {
// Print the operation itself and some of its properties
printIndent() << "visiting op: '" << op->getName() << "' with "
<< op->getNumOperands() << " operands and "
<< op->getNumResults() << " results\n";
// Print the operation attributes
if (!op->getAttrs().empty()) {
printIndent() << op->getAttrs().size() << " attributes:\n";
for (NamedAttribute attr : op->getAttrs())
printIndent() << " - '" << attr.getName().getValue() << "' : '"
<< attr.getValue() << "'\n";
}

// Recurse into each of the regions attached to the operation.
printIndent() << " " << op->getNumRegions() << " nested regions:\n";
auto indent = pushIndent();
for (Region &region : op->getRegions())
printRegion(region);
}

void printRegion(Region &region) {
// A region does not hold anything by itself other than a list of blocks.
printIndent() << "Region with " << region.getBlocks().size()
<< " blocks:\n";
auto indent = pushIndent();
for (Block &block : region.getBlocks())
printBlock(block);
}

void printBlock(Block &block) {
// Print the block intrinsics properties (basically: argument list)
printIndent()
<< "Block with " << block.getNumArguments() << " arguments, "
<< block.getNumSuccessors()
<< " successors, and "
// Note, this `.size()` is traversing a linked-list and is O(n).
<< block.getOperations().size() << " operations\n";

// Block main role is to hold a list of Operations: let's recurse.
auto indent = pushIndent();
for (Operation &op : block.getOperations())
printOperation(&op);
}

/// Manages the indentation as we traverse the IR nesting.
int indent;
struct IdentRAII {
int &indent;
IdentRAII(int &indent) : indent(indent) {}
~IdentRAII() { --indent; }
};
void resetIndent() { indent = 0; }
IdentRAII pushIndent() { return IdentRAII(++indent); }

llvm::raw_ostream &printIndent() {
for (int i = 0; i < indent; ++i)
llvm::outs() << " ";
return llvm::outs();
}
};

static llvm::LogicalResult translateModuleToCpp(
Operation *op, llvm::raw_ostream &os) {
ModuleOp module = dyn_cast<ModuleOp>(op);
assert(module && "Expected ModuleOp as top level operation");
mlir::PassManager pm(op->getContext());
// return mlir::tt::emitTensixKernelAsCpp(module, os);
return mlir::tt::emitNocKernelAsCpp(module, os);

pm.addPass(mlir::tt::createConvertTTKernelToEmitC());
pm.addPass(mlir::createConvertArithToEmitC());
pm.addPass(mlir::createSCFToEmitC());
pm.addPass(mlir::createConvertFuncToEmitC());

if (mlir::failed(pm.run(op))) {
return llvm::failure();
}

if ( mlir::failed( mlir::emitc::translateToCpp(op, os) ) ) {
return llvm::failure();
}
return success();
}


LogicalResult translateTTKernelToCpp(
Operation *op, llvm::raw_ostream &os) {
return translateModuleToCpp(op, os);
Expand Down
4 changes: 2 additions & 2 deletions lib/Target/TTKernel/TTKernelToCppRegistration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include "ttmlir/Dialect/TTKernel/IR/TTKernel.h"
#include "ttmlir/Target/TTKernel/TTKernelToCpp.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

#include <mlir/Dialect/MemRef/IR/MemRef.h>
using namespace mlir;

namespace mlir::tt::ttkernel {
Expand All @@ -24,7 +24,7 @@ void registerTTKernelToCpp() {
[](DialectRegistry &registry) {
registry.insert<mlir::scf::SCFDialect,
mlir::tt::ttkernel::TTKernelDialect, mlir::arith::ArithDialect,
mlir::emitc::EmitCDialect, mlir::func::FuncDialect, mlir::tt::TTDialect>();
mlir::emitc::EmitCDialect, mlir::func::FuncDialect, mlir::tt::TTDialect, mlir::memref::MemRefDialect>();
});
}

Expand Down

0 comments on commit 3a50482

Please sign in to comment.