Skip to content

Commit

Permalink
Add embedding op (#473)
Browse files Browse the repository at this point in the history
  • Loading branch information
svuckovicTT authored Aug 23, 2024
1 parent 48b3340 commit 3819089
Show file tree
Hide file tree
Showing 13 changed files with 187 additions and 1 deletion.
20 changes: 20 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,26 @@ def TTIR_MeanOp : TTIR_ReductionOp<"mean"> {
}];
}

// TTIR embedding op, declared as a destination-passing-style op (TTIR_DPSOp).
//
// Operands:
//   input               - ranked tensor; the verifier accepts any rank.
//   weight              - ranked tensor; the verifier requires rank 2,
//                         laid out as (dictionary_size, embedding_size).
//   output              - ranked tensor used as the DPS init operand; the
//                         verifier requires rank(output) == rank(input) + 1.
//   operand_constraints - one TT operand constraint attribute array for the op.
// Result:
//   result              - ranked tensor produced by the op.
def TTIR_EmbeddingOp : TTIR_DPSOp<"embedding"> {
let summary = "Embedding op.";
let description = [{
Embedding operation.
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$weight,
AnyRankedTensor:$output,
TT_OperandConstraintArrayAttr:$operand_constraints);

let results = (outs AnyRankedTensor:$result);

// Exposes `output` as the only DPS init (destination) operand.
let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

// The rank checks are implemented in EmbeddingOp::verify() in TTIROps.cpp.
let hasVerifier = 1;
}

def TTIR_SoftmaxOp : TTIR_DPSOp<"softmax"> {
let summary = "Softmax operation.";
let description = [{
Expand Down
19 changes: 19 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,25 @@ def TTNN_MeanOp : TTNN_ReductionOp<"mean"> {
}];
}

// TTNN embedding op, declared as a named destination-passing-style op
// (TTNN_NamedDPSOp). It mirrors the TTIR embedding op but carries no
// operand-constraint attribute.
//
// Operands:
//   input  - ranked tensor; the verifier accepts any rank.
//   weight - ranked tensor; the verifier requires rank 2,
//            laid out as (dictionary_size, embedding_size).
//   output - ranked tensor used as the DPS init operand; the verifier
//            requires rank(output) == rank(input) + 1.
// Result:
//   result - ranked tensor produced by the op.
def TTNN_EmbeddingOp : TTNN_NamedDPSOp<"embedding"> {
let summary = "Embedding op.";
let description = [{
Embedding operation.
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$weight,
AnyRankedTensor:$output);

let results = (outs AnyRankedTensor:$result);

// Exposes `output` as the only DPS init (destination) operand.
let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

// The rank checks are implemented in EmbeddingOp::verify() in TTNNOps.cpp.
let hasVerifier = 1;
}

def TTNN_SoftmaxOp : TTNN_NamedDPSOp<"softmax"> {
let summary = "Softmax op.";
let description = [{
Expand Down
7 changes: 7 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ table ReductionOp {
keep_dim: bool;
}

// Serialized TTNN embedding op: references to the tensors it reads
// (`input`, `weight`) and the tensor it produces (`output`).
table EmbeddingOp {
input: tt.target.TensorRef;
weight: tt.target.TensorRef;
output: tt.target.TensorRef;
}

table SoftmaxOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
Expand Down Expand Up @@ -90,6 +96,7 @@ union OpType {
EltwiseOp,
MatmulOp,
ReductionOp,
EmbeddingOp,
SoftmaxOp,
TransposeOp
}
Expand Down
20 changes: 19 additions & 1 deletion lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Casting.h"
#include <llvm/Support/LogicalResult.h>

using namespace mlir;
using namespace mlir::tt;
Expand Down Expand Up @@ -111,6 +112,22 @@ class ReductionOpConversionPattern : public OpConversionPattern<TTIROpTy> {
}
};

// Conversion pattern that replaces a ttir.embedding op with a ttnn.embedding
// op. The result type is run through the pattern's type converter; the
// input, weight and output operands are forwarded from the adaptor unchanged.
class EmbeddingOpConversionPattern
    : public OpConversionPattern<ttir::EmbeddingOp> {
public:
  using OpConversionPattern<ttir::EmbeddingOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ttir::EmbeddingOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Result type of the replacement op, as seen by the target dialect.
    auto convertedResultType =
        this->getTypeConverter()->convertType(op.getType());

    // Already-converted operands of the op being replaced.
    auto indices = adaptor.getInput();
    auto table = adaptor.getWeight();
    auto destination = adaptor.getOutput();

    rewriter.replaceOpWithNewOp<ttnn::EmbeddingOp>(op, convertedResultType,
                                                   indices, table, destination);
    return success();
  }
};

class SoftmaxOpConversionPattern : public OpConversionPattern<ttir::SoftmaxOp> {
public:
using OpConversionPattern<ttir::SoftmaxOp>::OpConversionPattern;
Expand Down Expand Up @@ -179,8 +196,9 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ElementwiseBinaryOpConversionPattern<ttir::DivOp, ttnn::DivOp>,
ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
TransposeOpConversionPattern,
EmbeddingOpConversionPattern,
SoftmaxOpConversionPattern,
TransposeOpConversionPattern,
MatmulOpConversionPattern
>(typeConverter, ctx);
// ANCHOR_END: op_rewriter_pattern_set
Expand Down
3 changes: 3 additions & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,10 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
patterns.add<DefaultOpConversionPattern<ttnn::MeanOp>>(typeConverter, ctx);

// Other ops
//
patterns.add<DefaultOpConversionPattern<ttnn::SoftmaxOp>>(typeConverter, ctx);
patterns.add<DefaultOpConversionPattern<ttnn::EmbeddingOp>>(typeConverter,
ctx);
}

} // namespace mlir::tt
24 changes: 24 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"

#include "ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.cpp.inc"
#include <mlir/IR/BuiltinTypes.h>

#define GET_OP_CLASSES
#include "ttmlir/Dialect/TTIR/IR/TTIROps.cpp.inc"
Expand Down Expand Up @@ -72,6 +73,29 @@ void mlir::tt::ttir::MultiplyOp::buildGenericRegion(
block);
}

// Verifies the shape contract of ttir.embedding:
//   * input may have any rank;
//   * weight must be 2D: (dictionary_size, embedding_size);
//   * output must have exactly one dimension more than input;
//   * that extra, trailing output dimension must equal embedding_size
//     (only checked when both sizes are static).
// Returns failure with an op error describing the first violated constraint.
::mlir::LogicalResult mlir::tt::ttir::EmbeddingOp::verify() {
  ::mlir::RankedTensorType inputType = getInput().getType();
  ::mlir::RankedTensorType weightType = getWeight().getType();
  ::mlir::RankedTensorType outputType = getOutput().getType();

  // inputType can have any rank

  // weightType must have rank of 2: (dictionary_size, embedding_size)
  //
  if (weightType.getRank() != 2) {
    return emitOpError("Weight must be a 2D tensor");
  }

  // outputType must have the rank of inputType plus one additional dimension
  // of embedding_size
  //
  if (outputType.getRank() - inputType.getRank() != 1) {
    return emitOpError("Output must have one dimension more than input");
  }

  // The additional (last) output dimension must be embedding_size. The rank
  // checks above guarantee both shapes are non-empty, so back() is safe.
  // Dynamic sizes cannot be compared statically and are accepted.
  //
  const int64_t embeddingSize = weightType.getShape().back();
  const int64_t outputLastDim = outputType.getShape().back();
  const bool bothStatic = !::mlir::ShapedType::isDynamic(embeddingSize) &&
                          !::mlir::ShapedType::isDynamic(outputLastDim);
  if (bothStatic && outputLastDim != embeddingSize) {
    return emitOpError("Output's last dimension (")
           << outputLastDim << ") must match the weight's embedding size ("
           << embeddingSize << ")";
  }

  return success();
}

::mlir::LogicalResult mlir::tt::ttir::SoftmaxOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
::mlir::RankedTensorType outputType = getOutput().getType();
Expand Down
23 changes: 23 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,29 @@ ::mlir::LogicalResult mlir::tt::ttnn::ToMemoryConfigOp::verify() {
return success();
}

// Verifies the shape contract of ttnn.embedding:
//   * input may have any rank;
//   * weight must be 2D: (dictionary_size, embedding_size);
//   * output must have exactly one dimension more than input;
//   * that extra, trailing output dimension must equal embedding_size
//     (only checked when both sizes are static).
// Returns failure with an op error describing the first violated constraint.
::mlir::LogicalResult mlir::tt::ttnn::EmbeddingOp::verify() {
  ::mlir::RankedTensorType inputType = getInput().getType();
  ::mlir::RankedTensorType weightType = getWeight().getType();
  ::mlir::RankedTensorType outputType = getOutput().getType();

  // inputType can have any rank

  // weightType must have rank of 2: (dictionary_size, embedding_size)
  //
  if (weightType.getRank() != 2) {
    return emitOpError("Weight must be a 2D tensor");
  }

  // outputType must have the rank of inputType plus one additional dimension
  // of embedding_size
  //
  if (outputType.getRank() - inputType.getRank() != 1) {
    return emitOpError("Output must have one dimension more than input");
  }

  // The additional (last) output dimension must be embedding_size. The rank
  // checks above guarantee both shapes are non-empty, so back() is safe.
  // Dynamic sizes cannot be compared statically and are accepted.
  //
  const int64_t embeddingSize = weightType.getShape().back();
  const int64_t outputLastDim = outputType.getShape().back();
  const bool bothStatic = !::mlir::ShapedType::isDynamic(embeddingSize) &&
                          !::mlir::ShapedType::isDynamic(outputLastDim);
  if (bothStatic && outputLastDim != embeddingSize) {
    return emitOpError("Output's last dimension (")
           << outputLastDim << ") must match the weight's embedding size ("
           << embeddingSize << ")";
  }

  return success();
}

::mlir::LogicalResult mlir::tt::ttnn::SoftmaxOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
::mlir::RankedTensorType outputType = getOutput().getType();
Expand Down
16 changes: 16 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,18 @@ createTransposeOp(FlatbufferObjectCache &cache, TransposeOp op) {
return ::tt::target::ttnn::CreateTransposeOp(*cache.fbb, in, out, dim0, dim1);
}

// Builds the flatbuffer EmbeddingOp table for `op`.
//
// The input, weight and result values are each resolved with
// getOperandThroughDPSOps() and looked up in `cache` to obtain their
// TensorRef entries, which are then passed to CreateEmbeddingOp in the order
// (input, weight, output).
template <typename EmbeddingOp>
::flatbuffers::Offset<::tt::target::ttnn::EmbeddingOp>
createEmbeddingOp(FlatbufferObjectCache &cache, EmbeddingOp op) {
  using TensorRef = ::tt::target::TensorRef;

  auto inputRef = cache.at<TensorRef>(getOperandThroughDPSOps(op.getInput()));
  auto weightRef = cache.at<TensorRef>(getOperandThroughDPSOps(op.getWeight()));
  auto resultRef = cache.at<TensorRef>(getOperandThroughDPSOps(op.getResult()));

  return ::tt::target::ttnn::CreateEmbeddingOp(*cache.fbb, inputRef, weightRef,
                                               resultRef);
}

template <typename SoftmaxOp>
::flatbuffers::Offset<::tt::target::ttnn::SoftmaxOp>
createSoftmaxOp(FlatbufferObjectCache &cache, SoftmaxOp op) {
Expand Down Expand Up @@ -267,6 +279,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
return createOperation(cache, createReductionOp(cache, meanOp),
debugString);
}
if (auto embeddingOp = dyn_cast<EmbeddingOp>(op); embeddingOp) {
return createOperation(cache, createEmbeddingOp(cache, embeddingOp),
debugString);
}
if (auto softmaxOp = dyn_cast<SoftmaxOp>(op); softmaxOp) {
return createOperation(cache, createSoftmaxOp(cache, softmaxOp),
debugString);
Expand Down
1 change: 1 addition & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "ttnn/operations/data_movement/permute/permute.hpp"
#include "ttnn/operations/eltwise/binary/binary.hpp"
#include "ttnn/operations/eltwise/unary/unary.hpp"
#include "ttnn/operations/embedding/embedding.hpp"
#include "ttnn/operations/matmul/matmul.hpp"
#include "ttnn/operations/normalization/softmax/softmax.hpp"
#include "ttnn/operations/reduction/generic/generic_reductions.hpp"
Expand Down
14 changes: 14 additions & 0 deletions runtime/lib/ttnn/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,17 @@ run(::tt::target::ttnn::ReductionOp const *op, ::ttnn::Device &device,
}
}

// Executes a serialized EmbeddingOp.
//
// Looks up the input and weight tensors in `liveTensors` by their global ids
// (std::unordered_map::at throws if an id is missing), calls
// ::ttnn::embedding on them, stores the returned tensor in `tensorPool`, and
// registers a pointer to it in `liveTensors` under the output's global id.
// `device` is not used by this overload.
static void
run(::tt::target::ttnn::EmbeddingOp const *op, ::ttnn::Device &device,
    std::unordered_map<std::uint32_t, ::ttnn::Tensor *> &liveTensors,
    std::list<::ttnn::Tensor> &tensorPool) {
  ::ttnn::Tensor *inputTensor = liveTensors.at(op->input()->global_id());
  ::ttnn::Tensor *weightTensor = liveTensors.at(op->weight()->global_id());

  // The pool owns the result; liveTensors only holds a pointer to it.
  tensorPool.push_back(::ttnn::embedding(*inputTensor, *weightTensor));
  ::ttnn::Tensor *resultTensor = &tensorPool.back();

  liveTensors.insert_or_assign(op->output()->global_id(), resultTensor);
}

static void
run(::tt::target::ttnn::SoftmaxOp const *op, ::ttnn::Device &device,
std::unordered_map<std::uint32_t, ::ttnn::Tensor *> &liveTensors,
Expand Down Expand Up @@ -399,6 +410,9 @@ run(::tt::target::ttnn::Operation const *op, ::ttnn::Device &device,
case ::tt::target::ttnn::OpType::ReductionOp: {
return run(op->type_as_ReductionOp(), device, liveTensors, tensorPool);
}
case ::tt::target::ttnn::OpType::EmbeddingOp: {
return run(op->type_as_EmbeddingOp(), device, liveTensors, tensorPool);
}
case ::tt::target::ttnn::OpType::SoftmaxOp: {
return run(op->type_as_SoftmaxOp(), device, liveTensors, tensorPool);
}
Expand Down
14 changes: 14 additions & 0 deletions test/ttmlir/Dialect/TTNN/embedding/embedding_1d_tensor.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
// XFAIL: true
//
// Embedding lowering with a rank-1 input: tensor<32xf32> looked up in a
// 512x128 weight, producing tensor<32x128xf32>. The directive above marks
// this test as an expected failure.
// NOTE(review): every pattern below rebinds the same FileCheck variable `C`
// and none of the captured values is reused, so the patterns only pin the op
// names and their relative order.
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x128xf32> {
// CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<32x128xf32>
// CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
%1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32xf32>, tensor<512x128xf32>, tensor<32x128xf32>) -> tensor<32x128xf32>
// CHECK: "ttnn.close_device"[[C:.*]]
return %1 : tensor<32x128xf32>
}
}
14 changes: 14 additions & 0 deletions test/ttmlir/Dialect/TTNN/embedding/embedding_non_tile.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
// UNSUPPORTED: true
//
// Embedding lowering with a 1x32 input looked up in a 512x128 weight,
// producing tensor<1x32x128xf32>. The directive above marks this test as
// unsupported, so it is not run.
// NOTE(review): every pattern below rebinds the same FileCheck variable `C`
// and none of the captured values is reused, so the patterns only pin the op
// names and their relative order.
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<1x32xf32>, %arg1: tensor<512x128xf32>) -> tensor<1x32x128xf32> {
// CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<1x32x128xf32>
// CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
%1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32xf32>, tensor<512x128xf32>, tensor<1x32x128xf32>) -> tensor<1x32x128xf32>
// CHECK: "ttnn.close_device"[[C:.*]]
return %1 : tensor<1x32x128xf32>
}
}
13 changes: 13 additions & 0 deletions test/ttmlir/Dialect/TTNN/embedding/simple_embedding.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
//
// Embedding lowering with a 32x32 input looked up in a 512x128 weight,
// producing tensor<32x32x128xf32>. Expects open_device, empty, embedding and
// close_device ops to appear in that order in the TTNN output.
// NOTE(review): every pattern below rebinds the same FileCheck variable `C`
// and none of the captured values is reused, so the patterns only pin the op
// names and their relative order.
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x32x128xf32> {
// CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<32x32x128xf32>
// CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
%1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<512x128xf32>, tensor<32x32x128xf32>) -> tensor<32x32x128xf32>
// CHECK: "ttnn.close_device"[[C:.*]]
return %1 : tensor<32x32x128xf32>
}
}

0 comments on commit 3819089

Please sign in to comment.