[EmitC] Support MNIST (#1663)

svuckovicTT authored Dec 25, 2024
1 parent b4405d0 commit 436a9b8
Showing 11 changed files with 193 additions and 143 deletions.
11 changes: 6 additions & 5 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -903,11 +903,12 @@ def TTNN_EmptyOp : TTNN_Op<"empty"> {
Tensor empty operation
}];

let arguments = (ins Optional<TT_Device>:$device,
TTNN_ShapeAttr:$shape,
OptionalAttr<TT_DataTypeAttr>:$dtype,
OptionalAttr<TTNN_LayoutAttr>:$layout,
OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config);
let arguments = (ins TTNN_ShapeAttr:$shape,
TT_DataTypeAttr:$dtype,
TTNN_LayoutAttr:$layout,
TT_Device:$device,
TTNN_MemoryConfigAttr:$memory_config);

let results = (outs AnyRankedTensor:$result);

let hasVerifier = 1;
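// After this change, every ttnn.empty carries a required device operand plus
// mandatory shape/dtype/layout/memory_config attributes; the previous
// host-only form (no device, no memory config) is gone. Illustration only,
// attribute spellings are assumptions and not taken from this commit:
//
//   %0 = "ttnn.empty"(%dev) {
//           shape = #ttnn.shape<1x784>,
//           dtype = #tt.supportedDataTypes<bf16>,
//           layout = #ttnn.layout<tile>,
//           memory_config = #ttnn.memory_config<...>
//        } : (!tt.device) -> tensor<1x784xbf16>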
35 changes: 14 additions & 21 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -64,33 +64,26 @@ class TensorEmptyConversionPattern
ttnn::LayoutAttr tensorLayoutAttr =
ttnn::LayoutAttr::get(op.getContext(), ttnnLayoutEnum);

// If the tensor is not going to device, we can create the op without
// device-specific attributes
// Device
//
ttnn::TensorMemoryLayoutAttr memLayout = layoutAttr.getMemLayout();
if (!memLayout) {
rewriter.replaceOpWithNewOp<ttnn::EmptyOp>(
op, this->getTypeConverter()->convertType(op.getType()), nullptr,
shapeAttr, dTypeAttr, tensorLayoutAttr, nullptr);

return success();
}

ttnn::BufferType bufferType = layoutAttr.getBufferType();
auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op);

// Create MemoryConfigAttr
//
auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op);
llvm::SmallVector<int64_t> shardShape = layoutAttr.getShardShape();
ttnn::MemoryConfigAttr memoryConfigAttr = ttnn::MemoryConfigAttr::get(
op.getContext(), ttnn::BufferTypeAttr::get(op.getContext(), bufferType),
ttnn::ShardSpecAttr::get(
op.getContext(), ttnn::ShapeAttr::get(op.getContext(), shardShape)),
memLayout);
ttnn::BufferTypeAttr bufferTypeAttr =
ttnn::BufferTypeAttr::get(op.getContext(), layoutAttr.getBufferType());
ttnn::ShardSpecAttr shardSpecAttr = ttnn::ShardSpecAttr::get(
op.getContext(),
ttnn::ShapeAttr::get(op.getContext(), layoutAttr.getShardShape()));
ttnn::MemoryConfigAttr memoryConfigAttr =
ttnn::MemoryConfigAttr::get(op.getContext(), bufferTypeAttr,
shardSpecAttr, layoutAttr.getMemLayout());

// Replace op
//
rewriter.replaceOpWithNewOp<ttnn::EmptyOp>(
op, this->getTypeConverter()->convertType(op.getType()), device,
shapeAttr, dTypeAttr, tensorLayoutAttr, memoryConfigAttr);
op, this->getTypeConverter()->convertType(op.getType()), shapeAttr,
dTypeAttr, tensorLayoutAttr, device, memoryConfigAttr);

return success();
}
174 changes: 126 additions & 48 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -113,6 +113,43 @@ class DefaultOpConversionPattern
}
};

// Eltwise Unary op conversion pattern
//
// Currently, it has to insert nullopts for some parameters that are not
// modelled in the dialect (memcfg)
//
template <typename SourceOp, typename Adaptor = typename SourceOp::Adaptor>
class EltwiseUnaryOpConversionPattern
: public TTNNToEmitCBaseOpConversionPattern<SourceOp> {

public:
EltwiseUnaryOpConversionPattern(const TypeConverter &typeConverter,
MLIRContext *context,
PatternBenefit benefit = 1)
: TTNNToEmitCBaseOpConversionPattern<SourceOp>(typeConverter, context,
benefit) {}

LogicalResult
matchAndRewrite(SourceOp srcOp, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// emitc::CallOpaqueOp needs to know positions of operands vs attributes, so
// an ArrayAttr object holding IndexTypes is created to denote this
//
llvm::SmallVector<Attribute, 5> attrs;
attrs.push_back(mlir::IntegerAttr::get(rewriter.getIndexType(), 0));
attrs.push_back(ttnn_to_emitc::utils::createStdNullopt(rewriter));
attrs.push_back(mlir::IntegerAttr::get(rewriter.getIndexType(), 1));

ArrayAttr arrayAttrs = ArrayAttr::get(srcOp->getContext(), attrs);

rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
srcOp, this->getTypeConverter()->convertType(srcOp.getType(0)),
this->convertOpName(srcOp), arrayAttrs, nullptr, adaptor.getOperands());

return success();
}
};
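
// To make the mapping concrete: when the emitc::CallOpaqueOp is printed, each
// IndexAttr in the array is replaced by the SSA operand at that index, and
// every other attribute is emitted as a literal argument. A sketch for a relu
// call, assuming the op's operands are (input, output) and that the
// unmodelled memory-config parameter sits between them in the ttnn API (not
// verified against this commit):
//
//   // args = {0, std::nullopt, 1} emits roughly:
//   ttnn::Tensor v2 = ttnn::relu(v0, std::nullopt, v1);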

// Eltwise Binary op conversion pattern
//
// Currently, it has to insert nullopts for some parameters that are not
@@ -132,6 +169,7 @@ class EltwiseBinaryOpConversionPattern
LogicalResult
matchAndRewrite(SourceOp srcOp, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// emitc::CallOpaqueOp needs to know positions of operands vs attributes, so
// an ArrayAttr object holding IndexTypes is created to denote this
//
Expand All @@ -152,6 +190,50 @@ class EltwiseBinaryOpConversionPattern
}
};

// Matmul op conversion pattern
//
class MatmulOpConversionPattern
: public TTNNToEmitCBaseOpConversionPattern<ttnn::MatmulOp> {

public:
MatmulOpConversionPattern(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
: TTNNToEmitCBaseOpConversionPattern<ttnn::MatmulOp>(typeConverter,
context, benefit) {}

LogicalResult
matchAndRewrite(ttnn::MatmulOp matmulOp, ttnn::MatmulOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// emitc::CallOpaqueOp needs to know positions of operands vs attributes, so
// an ArrayAttr object holding IndexTypes is created to denote this
//
ArrayAttr arrayAttrs = rewriter.getArrayAttr({
mlir::IntegerAttr::get(rewriter.getIndexType(), 0),
mlir::IntegerAttr::get(rewriter.getIndexType(), 1),
ttnn_to_emitc::utils::convertBoolAttr(
rewriter, BoolAttr::get(rewriter.getContext(), false)),
ttnn_to_emitc::utils::convertBoolAttr(
rewriter, BoolAttr::get(rewriter.getContext(), false)),
ttnn_to_emitc::utils::createStdNullopt(rewriter),
ttnn_to_emitc::utils::createStdNullopt(rewriter),
ttnn_to_emitc::utils::createStdNullopt(rewriter),
ttnn_to_emitc::utils::createStdNullopt(rewriter),
ttnn_to_emitc::utils::createStdNullopt(rewriter),
ttnn_to_emitc::utils::createStdNullopt(rewriter),
ttnn_to_emitc::utils::createStdNullopt(rewriter),
mlir::IntegerAttr::get(rewriter.getIndexType(), 2),
});

rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
matmulOp, this->getTypeConverter()->convertType(matmulOp.getType()),
this->convertOpName(matmulOp), arrayAttrs, nullptr,
adaptor.getOperands());

return success();
}
};
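
// The twelve entries above line up positionally with the generated call:
// indices 0 and 1 select the two input tensors, the two bools are the
// transpose flags, the seven nullopts stand in for optional parameters not
// modelled in the dialect, and index 2 selects the output tensor. A sketch of
// the emitted call (parameter meanings assumed from the array, not verified
// against the ttnn headers):
//
//   ttnn::Tensor v3 = ttnn::matmul(
//       v0, v1, /*transpose_a=*/false, /*transpose_b=*/false,
//       std::nullopt, std::nullopt, std::nullopt, std::nullopt,
//       std::nullopt, std::nullopt, std::nullopt, v2);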

// GetDeviceOp conversion pattern
//
class GetDeviceOpConversionPattern
@@ -390,46 +472,42 @@ class EmptyOpConversionPattern
tt::DataTypeAttr dataTypeAttr = srcOp.getDtypeAttr();
ttnn::LayoutAttr layoutAttr = srcOp.getLayoutAttr();

// Find the GetDeviceOp
//
ttnn::GetDeviceOp getDeviceOp;
srcOp->getParentOp()->walk(
[&getDeviceOp](ttnn::GetDeviceOp currGetDeviceOp) {
getDeviceOp = currGetDeviceOp;
});

// Create ttnn::Shape() call
//
emitc::ExpressionOp shapeExpressionOp = ttnn_to_emitc::utils::createShapeOp(
rewriter, shapeAttr, srcOp->getBlock(), srcOp.getLoc());

llvm::SmallVector<Value, 3> operands{
shapeExpressionOp->getResult(0),
};

// If there is a device operand, create tensor on device
// Create operands vector
//
ArrayAttr arrayAttr;
if (adaptor.getDevice()) {
operands.append(1, adaptor.getDevice());
llvm::SmallVector<Value, 3> operands{shapeExpressionOp->getResult(0),
adaptor.getDevice()};

// Create MemoryConfig object first, then pass it to the op
//
emitc::CallOpaqueOp memCfgOp = ttnn_to_emitc::utils::createMemoryConfigOp(
rewriter, srcOp.getMemoryConfig().value(), srcOp.getLoc());
// Create MemoryConfig object first, then pass it to the op
//
emitc::CallOpaqueOp memCfgOp = ttnn_to_emitc::utils::createMemoryConfigOp(
rewriter, srcOp.getMemoryConfig(), srcOp.getLoc());

// Concat operands and MemoryConfig object
//
operands.append(1, memCfgOp.getResult(0));
// Concat operands and MemoryConfig object
//
operands.append(1, memCfgOp.getResult(0));

// Create ArrayAttr object holding attributes and pointers to operands
//
arrayAttr = rewriter.getArrayAttr({
rewriter.getIndexAttr(0), // ttnn::Shape
ttnn_to_emitc::utils::convertDType(rewriter, dataTypeAttr),
ttnn_to_emitc::utils::convertLayoutAttr(rewriter, layoutAttr),
rewriter.getIndexAttr(1), // ttnn::Device
rewriter.getIndexAttr(2), // ttnn::MemoryConfig
});
} else {
arrayAttr = rewriter.getArrayAttr({
rewriter.getIndexAttr(0), // ttnn::Shape
ttnn_to_emitc::utils::convertDType(rewriter, dataTypeAttr),
ttnn_to_emitc::utils::convertLayoutAttr(rewriter, layoutAttr),
});
}
// Create ArrayAttr object holding attributes and pointers to operands
//
ArrayAttr arrayAttr = rewriter.getArrayAttr({
rewriter.getIndexAttr(0), // ttnn::Shape
ttnn_to_emitc::utils::convertDType(rewriter, dataTypeAttr),
ttnn_to_emitc::utils::convertLayoutAttr(rewriter, layoutAttr),
rewriter.getIndexAttr(1), // ttnn::Device
rewriter.getIndexAttr(2), // ttnn::MemoryConfig
});

// Finally, convert ttnn::EmptyOp to emitc::CallOpaqueOp
//
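// The replaceOpWithNewOp call itself falls outside this hunk. Given the array
// above (shape at operand 0, dtype and layout as literals, device at operand
// 1, memory config at operand 2), the generated C++ would look roughly like
// this sketch (exact ttnn::empty signature assumed, not verified):
//
//   ttnn::Tensor v3 = ttnn::empty(v0 /*Shape*/, DataType::BFLOAT16,
//                                 Layout::TILE, v1 /*device*/,
//                                 v2 /*MemoryConfig*/);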
@@ -469,14 +547,14 @@ class OnesOpConversionPattern
// Attrs (like shape) need to be instantiated into objects before being
// passed to the op. Therefore:
//
// We first create a ttnn::Shape object (SSA) by calling createShapeOp() and
// add it to the operands vector, but also add an IndexAttr in ArrayAttr to
// reference it (this is an EmitC mechanism that allows for combining Attrs
// and Values when calling an OpaqueOp).
// All the other input params are optional, so we create them on-the-fly
// into the ArrayAttr, whether they are an actual Attr, or a Value pointed
// to by IndexAttr. If they are present, we create the object and pass it to
// the op. If not, we pass std::nullopt.
// We first create a ttnn::Shape object (SSA) by calling createShapeOp()
// and add it to the operands vector, but also add an IndexAttr in
// ArrayAttr to reference it (this is an EmitC mechanism that allows for
// combining Attrs and Values when calling an OpaqueOp). All the other
// input params are optional, so we create them on-the-fly into the
// ArrayAttr, whether they are an actual Attr, or a Value pointed to by
// IndexAttr. If they are present, we create the object and pass it to the
// op. If not, we pass std::nullopt.
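//
// For illustration (hypothetical, not part of this diff): with only `device`
// present as a Value, the pattern might collect operands {shapeOp, device}
// and attrs {IndexAttr(0) /*shape*/, convertDType(dtype),
// createStdNullopt() /*layout*/, IndexAttr(1) /*device*/,
// createStdNullopt() /*memcfg*/}.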

// Create ttnn::Shape() call
//
Expand All @@ -489,8 +567,8 @@ class OnesOpConversionPattern

// Create ArrayAttr object holding attributes and pointers to operands
//
// Params that are Values are added to the operands vector on-the-fly, and a
// corresponding IndexAttr is added to the ArrayAttr to reference them.
// Params that are Values are added to the operands vector on-the-fly, and
// a corresponding IndexAttr is added to the ArrayAttr to reference them.
//
size_t operandIndex = 0;
ArrayAttr arrayAttr = rewriter.getArrayAttr({
@@ -594,8 +672,8 @@ class GetTupleElementOpConversionPattern
getTupleElementOp->getLoc(), rewriter.getIndexType(),
std::to_string(adaptor.getIndex()));

// SubscriptOp also returns an emitc::LValueType, so we wrap the OpaqueType
// with LValueType
// SubscriptOp also returns an emitc::LValueType, so we wrap the
// OpaqueType with LValueType
//
emitc::LValueType lvalueReturnType = emitc::LValueType::get(
emitc::OpaqueType::get(rewriter.getContext(), "ttnn::Tensor"));
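
// Net effect in the generated C++ (a sketch, assuming the tuple is lowered to
// a std::vector<ttnn::Tensor>): the subscript yields an lvalue, e.g.
//
//   ttnn::Tensor &v2 = v1[0];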
@@ -621,9 +699,9 @@ class TupleOpConversionPattern : public OpConversionPattern<tt::TupleOp> {
LogicalResult
matchAndRewrite(tt::TupleOp tupleOp, tt::TupleOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// EmitC doesn't offer a way to create a vector from a list of values, so we
// need to create a utility function that does this. This is achieved by
// using EmitC's VerbatimOp.
// EmitC doesn't offer a way to create a vector from a list of values, so
// we need to create a utility function that does this. This is achieved
// by using EmitC's VerbatimOp.

// Try to find if utility vec creation function is already defined in the
// module. If not, insert it.
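//
// Such a helper could look roughly like this (a sketch; the actual name and
// body emitted by the pass are not shown in this hunk):
//
//   template <typename... T>
//   std::vector<ttnn::Tensor> util_create_vec(T &&...t) {
//     return std::vector<ttnn::Tensor>{std::forward<T>(t)...};
//   }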
@@ -708,7 +786,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
DefaultOpConversionPattern<ttnn::LogicalNotOp>,
DefaultOpConversionPattern<ttnn::BitwiseNotOp>,
DefaultOpConversionPattern<ttnn::NegOp>,
DefaultOpConversionPattern<ttnn::ReluOp>,
EltwiseUnaryOpConversionPattern<ttnn::ReluOp>,
DefaultOpConversionPattern<ttnn::LeakyReluOp>,
DefaultOpConversionPattern<ttnn::GeluOp>,
DefaultOpConversionPattern<ttnn::SqrtOp>,
@@ -761,7 +839,7 @@
// Matmul ops
//
patterns.add<DefaultOpConversionPattern<ttnn::LinearOp>,
DefaultOpConversionPattern<ttnn::MatmulOp>>(typeConverter, ctx);
MatmulOpConversionPattern>(typeConverter, ctx);

// Reduction ops
//
52 changes: 8 additions & 44 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -174,9 +174,6 @@ ::mlir::LogicalResult mlir::tt::ttnn::ArangeOp::verify() {

// EmptyOp verification
::mlir::LogicalResult mlir::tt::ttnn::EmptyOp::verify() {
// ==============================
// === CHECK ATTRIBUTES START ===
// ==============================
// Check that the attributes of the op match the attributes of the output
// tensor type.
//
@@ -192,50 +189,17 @@ ::mlir::LogicalResult mlir::tt::ttnn::EmptyOp::verify() {

// DataType and Layout
//
if (getLayout().has_value()) {
ttnn::Layout ttnnLayoutEnum = layoutAttr.getLayout();
assert(ttnnLayoutEnum == getLayoutAttr().getValue());
}
if (getDtype().has_value()) {
tt::DataType dtype = layoutAttr.getDataType();
assert(dtype == getDtype());
}
assert(getLayout() == layoutAttr.getLayout());
assert(getDtype() == layoutAttr.getDataType());

// MemoryConfig
// Check that op has MemoryConfigAttr set on itself, then compare internal
// attrs with output tensor attrs.
//
if (getMemoryConfig().has_value()) {
ttnn::BufferType bufferType = layoutAttr.getBufferType();
ttnn::TensorMemoryLayoutAttr tensorMemoryLayoutAttr =
layoutAttr.getMemLayout();
assert(bufferType == getMemoryConfig()->getBufferType().getValue());
assert(tensorMemoryLayoutAttr ==
getMemoryConfig()->getTensorMemoryLayout());
}
// Compare internal attrs with output tensor attrs.
//
// ==============================
// ==== CHECK ATTRIBUTES END ====
// ==============================

// ==============================
// === CHECK SIGNATURES START ===
// ==============================
// Check that call-site uses the correct signature. We only allow 2 for now:
// 1. none, Shape, DataType, Layout, none
// 2. Device, Shape, DataType, Layout, MemoryConfig
//
assert(
// 1.
(!getDevice() && getDtype().has_value() && getLayout().has_value() &&
!getMemoryConfig().has_value()) ||
// 2.
(getDevice() && getDtype().has_value() && getLayout().has_value() &&
getMemoryConfig().has_value()));
//
// ==============================
// ==== CHECK SIGNATURES END ====
// ==============================
assert(getMemoryConfig().getBufferType().getValue() ==
layoutAttr.getBufferType());
assert(getMemoryConfig().getTensorMemoryLayout() ==
layoutAttr.getMemLayout());

return success();
}
