Skip to content

Commit

Permalink
Make CB indices strongly typed #362 (#457)
Browse files Browse the repository at this point in the history
  • Loading branch information
nsmithtt authored Aug 22, 2024
1 parent 56297cc commit 23c3cb2
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 23 deletions.
72 changes: 72 additions & 0 deletions include/ttmlir/Dialect/TTKernel/IR/TTKernelOpsEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,76 @@ def TTKernel_ThreadType : I32EnumAttr<"ThreadType", "TTKernel ThreadTypes",
let cppNamespace = "::mlir::tt::ttkernel";
}

def TTKernel_CBPortIn0 : I32EnumAttrCase<"In0", 0, "cb_in0">;
def TTKernel_CBPortIn1 : I32EnumAttrCase<"In1", 1, "cb_in1">;
def TTKernel_CBPortIn2 : I32EnumAttrCase<"In2", 2, "cb_in2">;
def TTKernel_CBPortIn3 : I32EnumAttrCase<"In3", 3, "cb_in3">;
def TTKernel_CBPortIn4 : I32EnumAttrCase<"In4", 4, "cb_in4">;
def TTKernel_CBPortIn5 : I32EnumAttrCase<"In5", 5, "cb_in5">;
def TTKernel_CBPortIn6 : I32EnumAttrCase<"In6", 6, "cb_in6">;
def TTKernel_CBPortIn7 : I32EnumAttrCase<"In7", 7, "cb_in7">;
def TTKernel_CBPortDataFlow0 : I32EnumAttrCase<"DataFlow0", 8, "cb_dfl0">;
def TTKernel_CBPortDataFlow1 : I32EnumAttrCase<"DataFlow1", 9, "cb_dfl1">;
def TTKernel_CBPortDataFlow2 : I32EnumAttrCase<"DataFlow2", 10, "cb_dfl2">;
def TTKernel_CBPortDataFlow3 : I32EnumAttrCase<"DataFlow3", 11, "cb_dfl3">;
def TTKernel_CBPortDataFlow4 : I32EnumAttrCase<"DataFlow4", 12, "cb_dfl4">;
def TTKernel_CBPortDataFlow5 : I32EnumAttrCase<"DataFlow5", 13, "cb_dfl5">;
def TTKernel_CBPortDataFlow6 : I32EnumAttrCase<"DataFlow6", 14, "cb_dfl6">;
def TTKernel_CBPortDataFlow7 : I32EnumAttrCase<"DataFlow7", 15, "cb_dfl7">;
def TTKernel_CBPortOut0 : I32EnumAttrCase<"Out0", 16, "cb_out0">;
def TTKernel_CBPortOut1 : I32EnumAttrCase<"Out1", 17, "cb_out1">;
def TTKernel_CBPortOut2 : I32EnumAttrCase<"Out2", 18, "cb_out2">;
def TTKernel_CBPortOut3 : I32EnumAttrCase<"Out3", 19, "cb_out3">;
def TTKernel_CBPortOut4 : I32EnumAttrCase<"Out4", 20, "cb_out4">;
def TTKernel_CBPortOut5 : I32EnumAttrCase<"Out5", 21, "cb_out5">;
def TTKernel_CBPortOut6 : I32EnumAttrCase<"Out6", 22, "cb_out6">;
def TTKernel_CBPortOut7 : I32EnumAttrCase<"Out7", 23, "cb_out7">;
def TTKernel_CBPortIntermed0 : I32EnumAttrCase<"Intermed0", 24, "cb_int0">;
def TTKernel_CBPortIntermed1 : I32EnumAttrCase<"Intermed1", 25, "cb_int1">;
def TTKernel_CBPortIntermed2 : I32EnumAttrCase<"Intermed2", 26, "cb_int2">;
def TTKernel_CBPortIntermed3 : I32EnumAttrCase<"Intermed3", 27, "cb_int3">;
def TTKernel_CBPortIntermed4 : I32EnumAttrCase<"Intermed4", 28, "cb_int4">;
def TTKernel_CBPortIntermed5 : I32EnumAttrCase<"Intermed5", 29, "cb_int5">;
def TTKernel_CBPortIntermed6 : I32EnumAttrCase<"Intermed6", 30, "cb_int6">;
def TTKernel_CBPortIntermed7 : I32EnumAttrCase<"Intermed7", 31, "cb_int7">;

def TTKernel_CBPort : I32EnumAttr<"CBPort", "TTKernel Circular Buffer Ports",
[
TTKernel_CBPortIn0,
TTKernel_CBPortIn1,
TTKernel_CBPortIn2,
TTKernel_CBPortIn3,
TTKernel_CBPortIn4,
TTKernel_CBPortIn5,
TTKernel_CBPortIn6,
TTKernel_CBPortIn7,
TTKernel_CBPortDataFlow0,
TTKernel_CBPortDataFlow1,
TTKernel_CBPortDataFlow2,
TTKernel_CBPortDataFlow3,
TTKernel_CBPortDataFlow4,
TTKernel_CBPortDataFlow5,
TTKernel_CBPortDataFlow6,
TTKernel_CBPortDataFlow7,
TTKernel_CBPortOut0,
TTKernel_CBPortOut1,
TTKernel_CBPortOut2,
TTKernel_CBPortOut3,
TTKernel_CBPortOut4,
TTKernel_CBPortOut5,
TTKernel_CBPortOut6,
TTKernel_CBPortOut7,
TTKernel_CBPortIntermed0,
TTKernel_CBPortIntermed1,
TTKernel_CBPortIntermed2,
TTKernel_CBPortIntermed3,
TTKernel_CBPortIntermed4,
TTKernel_CBPortIntermed5,
TTKernel_CBPortIntermed6,
TTKernel_CBPortIntermed7,
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tt::ttkernel";
}

#endif
10 changes: 5 additions & 5 deletions include/ttmlir/Dialect/TTKernel/IR/TTKernelOpsTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ class TTKernel_Type<string name, string typeMnemonic, list<Trait> traits = []>
def TTKernel_CB : TTKernel_Type<"CB", "cb"> {
let summary = "TTKernel cb";
let description = "Circular buffer type in TTKernel dialect";
let parameters = (ins "uint64_t":$address,
"uint64_t":$port,
let parameters = (ins "CBPort":$port,
"uint64_t":$address,
"MemRefType":$memref,
"uint64_t":$page_size,
"uint64_t":$num_buffers);
let assemblyFormat = "`<` $address`,` $port`,` $memref`,` $page_size`,` $num_buffers `>`";
let assemblyFormat = "`<` $port`,` $address`,` $memref`,` $page_size`,` $num_buffers `>`";

let extraClassDeclaration = [{
static CBType get(::mlir::MLIRContext *context,
CBPort port,
uint64_t address,
uint64_t port,
MemRefType memref) {
uint64_t numBuffers = 1;
uint64_t pageSize = 0;
Expand All @@ -39,7 +39,7 @@ def TTKernel_CB : TTKernel_Type<"CB", "cb"> {
} else {
pageSize = memref.getShape().back() * (memref.getElementType().getIntOrFloatBitWidth() / 8);
}
return CBType::get(context, address, port, memref, pageSize, numBuffers);
return CBType::get(context, port, address, memref, pageSize, numBuffers);
}

::llvm::ArrayRef<int64_t> getShape() const {
Expand Down
2 changes: 1 addition & 1 deletion lib/CAPI/TTKernelTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ using namespace mlir::tt::ttkernel;

MlirType ttmlirTTKernelCBTypeGet(MlirContext ctx, uint64_t address,
uint64_t port, MlirType memrefType) {
return wrap(CBType::get(unwrap(ctx), address, port,
return wrap(CBType::get(unwrap(ctx), symbolizeCBPort(port).value(), address,
mlir::cast<mlir::MemRefType>(unwrap(memrefType))));
}
92 changes: 85 additions & 7 deletions lib/Dialect/TTMetal/Transforms/KernelsToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,95 @@

namespace mlir::tt::ttmetal {

emitc::OpaqueAttr convertCBPort(Builder &builder, ttkernel::CBPort port) {
switch (port) {
case ttkernel::CBPort::In0:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_in0");
case ttkernel::CBPort::In1:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_in1");
case ttkernel::CBPort::In2:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_in2");
case ttkernel::CBPort::In3:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_in3");
case ttkernel::CBPort::In4:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_in4");
case ttkernel::CBPort::In5:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_in5");
case ttkernel::CBPort::In6:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_in6");
case ttkernel::CBPort::In7:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_in7");
case ttkernel::CBPort::DataFlow0:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::dataflow0");
case ttkernel::CBPort::DataFlow1:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::dataflow1");
case ttkernel::CBPort::DataFlow2:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::dataflow2");
case ttkernel::CBPort::DataFlow3:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::dataflow3");
case ttkernel::CBPort::DataFlow4:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::dataflow4");
case ttkernel::CBPort::DataFlow5:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::dataflow5");
case ttkernel::CBPort::DataFlow6:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::dataflow6");
case ttkernel::CBPort::DataFlow7:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::dataflow7");
case ttkernel::CBPort::Out0:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_out0");
case ttkernel::CBPort::Out1:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_out1");
case ttkernel::CBPort::Out2:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_out2");
case ttkernel::CBPort::Out3:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_out3");
case ttkernel::CBPort::Out4:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_out4");
case ttkernel::CBPort::Out5:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_out5");
case ttkernel::CBPort::Out6:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_out6");
case ttkernel::CBPort::Out7:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_out7");
case ttkernel::CBPort::Intermed0:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_intermed0");
case ttkernel::CBPort::Intermed1:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_intermed1");
case ttkernel::CBPort::Intermed2:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_intermed2");
case ttkernel::CBPort::Intermed3:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_intermed3");
case ttkernel::CBPort::Intermed4:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_intermed4");
case ttkernel::CBPort::Intermed5:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_intermed5");
case ttkernel::CBPort::Intermed6:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_intermed6");
case ttkernel::CBPort::Intermed7:
return builder.getType<emitc::OpaqueAttr>("::tt::CB::c_intermed7");
}
llvm_unreachable("Unknown CBPort");
return nullptr;
}

class TTKernelToEmitCTypeConverter : public TypeConverter {
public:
TTKernelToEmitCTypeConverter(MLIRContext *ctx) {
addConversion([](Type type) { return type; });
addConversion([ctx](mlir::tt::ttkernel::NocAddrType type) -> Type {
return Builder(ctx).getI64Type();
});
addConversion([ctx](mlir::tt::ttkernel::CBType type) -> Type {
return Builder(ctx).getType<emitc::OpaqueType>("::tt::CB");
});
}
};

class TTMetalToEmitCFuncArgsRewriter : public OpRewritePattern<func::FuncOp> {
public:
using OpRewritePattern<func::FuncOp>::OpRewritePattern;
TTMetalToEmitCFuncArgsRewriter(TTKernelToEmitCTypeConverter &typeConverter,
MLIRContext *ctx)
: OpRewritePattern<func::FuncOp>(ctx), typeConverter(&typeConverter) {}

LogicalResult matchAndRewrite(func::FuncOp op,
PatternRewriter &rewriter) const final {
Expand All @@ -50,22 +126,25 @@ class TTMetalToEmitCFuncArgsRewriter : public OpRewritePattern<func::FuncOp> {
rewriter.setInsertionPointToStart(&op.getCallableRegion()->front());
for (auto arg : blockArgs) {
auto cb = cast<ttkernel::CBType>(arg.getType());
auto cbType = typeConverter->convertType(cb);
auto var = rewriter.create<emitc::VariableOp>(
op.getLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(cb.getPort()));
op.getLoc(), cbType, convertCBPort(rewriter, cb.getPort()));
arg.replaceAllUsesWith(var);
}
op.getCallableRegion()->front().eraseArguments(0, blockArgs.size());
op.setType(rewriter.getType<FunctionType>(TypeRange(), TypeRange()));

return success();
}

TTKernelToEmitCTypeConverter *typeConverter;
};

class TTMetalToEmitCReturnRewriter
: public OpRewritePattern<ttkernel::ReturnOp> {
public:
using OpRewritePattern<ttkernel::ReturnOp>::OpRewritePattern;
TTMetalToEmitCReturnRewriter(TTKernelToEmitCTypeConverter &, MLIRContext *ctx)
: OpRewritePattern<ttkernel::ReturnOp>(ctx) {}

LogicalResult matchAndRewrite(ttkernel::ReturnOp op,
PatternRewriter &rewriter) const final {
Expand Down Expand Up @@ -188,9 +267,8 @@ LogicalResult emitDispatchOpRegionAsCpp(DispatchOp origOp,
TTKernelToEmitCTypeConverter typeConverter(module.getContext());
RewritePatternSet patterns(module.getContext());

patterns.add<TTMetalToEmitCFuncArgsRewriter, TTMetalToEmitCReturnRewriter>(
module.getContext());
patterns.add<TTMetalToEmitCOpaqueRewriter<ttkernel::BuiltinOp>,
patterns.add<TTMetalToEmitCFuncArgsRewriter, TTMetalToEmitCReturnRewriter,
TTMetalToEmitCOpaqueRewriter<ttkernel::BuiltinOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CBPushBackOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CBPopFrontOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CBReserveBackOp>,
Expand Down
14 changes: 8 additions & 6 deletions lib/Dialect/TTMetal/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,11 +264,13 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern<ttir::ToLayoutOp> {
? inputLayout.getElementSizeBytes()
: outputLayout.getElementSizeBytes();
Type inputCBTy = rewriter.getType<ttkernel::CBType>(
inputBaseAddress, 0, mlir::cast<MemRefType>(inputLayout.getMemref()),
pageSize, /*num_buffers*/ 1);
ttkernel::CBPort::In0, inputBaseAddress,
mlir::cast<MemRefType>(inputLayout.getMemref()), pageSize,
/*num_buffers*/ 1);
Type outputCBTy = rewriter.getType<ttkernel::CBType>(
outputBaseAddress, 16, mlir::cast<MemRefType>(outputLayout.getMemref()),
pageSize, /*num_buffers*/ 1);
ttkernel::CBPort::Out0, outputBaseAddress,
mlir::cast<MemRefType>(outputLayout.getMemref()), pageSize,
/*num_buffers*/ 1);
tensixBlock->addArgument(inputCBTy, op.getLoc());
tensixBlock->addArgument(outputCBTy, op.getLoc());

Expand Down Expand Up @@ -407,8 +409,8 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {
auto tensor = mlir::cast<RankedTensorType>(arg.getType());
auto buffer = mlir::cast<BufferAttr>(tensor.getEncoding());
auto memref = buffer.getMemref();
rewrittenBlockArgumentTypes.push_back(
rewriter.getType<ttkernel::CBType>(address, port, memref));
rewrittenBlockArgumentTypes.push_back(rewriter.getType<ttkernel::CBType>(
ttkernel::symbolizeCBPort(port).value(), address, memref));
}
return rewrittenBlockArgumentTypes;
}
Expand Down
7 changes: 4 additions & 3 deletions lib/Dialect/TTMetal/Transforms/SerializeToBinary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,10 @@ ::tt::target::Dim2dRange toFlatbuffer(CoreRangeAttr coreRange) {
::flatbuffers::Offset<::tt::target::CBDesc>
cbTypeToFlatbuffer(FlatbufferObjectCache &cache, ttkernel::CBType cbType) {
auto memref = cache.getOrCreate(cbType.getMemref(), memrefAttrToFlatbuffer);
return ::tt::target::CreateCBDesc(*cache.fbb, cbType.getPort(), memref,
cbType.getPageSize(),
cbType.getNumBuffers());
return ::tt::target::CreateCBDesc(
*cache.fbb,
static_cast<std::underlying_type_t<ttkernel::CBPort>>(cbType.getPort()),
memref, cbType.getPageSize(), cbType.getNumBuffers());
}

class TTMetalSerializeToBinary
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: ttmlir-opt --ttir-to-ttmetal-backend-pipeline %s | FileCheck %s
// RUN: ttmlir-opt --ttir-load-system-desc="path=%system_desc_path%" --ttir-to-ttmetal-backend-pipeline --ttmetal-serialize-to-binary="output=%t.ttm" %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>

func.func @multiply(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
Expand Down

0 comments on commit 23c3cb2

Please sign in to comment.