Adding concat op (#497)

mtopalovicTT authored Aug 27, 2024
1 parent 287a2ef commit c61a554
Showing 11 changed files with 186 additions and 2 deletions.
20 changes: 20 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -360,6 +360,26 @@ def TTIR_TransposeOp : TTIR_DPSOp<"transpose"> {
  let hasVerifier = 1;
}

def TTIR_ConcatOp : TTIR_DPSOp<"concat"> {
  let summary = "Concat op.";
  let description = [{
    Concat tensors along a given dimension.
  }];

  let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
                       AnyRankedTensor:$output,
                       SI32Attr:$dim,
                       TT_OperandConstraintArrayAttr:$operand_constraints);

  let results = (outs AnyRankedTensor:$result);

  let extraClassDeclaration = [{
    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
  }];

  let hasVerifier = 1;
}

// ANCHOR: adding_an_op_matmul_ttir
def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> {
let summary = "Matrix multiply operation.";
Expand Down
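From this entry, ODS generates the op class with accessors (getInputs(), getOutput(), getDim()) and builders. A minimal sketch of creating the op from C++; the exact generated builder overloads are an assumption, not part of this commit:

// Hedged sketch: assumes the default ODS-generated builder, which takes the
// result type, then operands, then attributes in declaration order.
static mlir::Value buildConcat(mlir::OpBuilder &builder, mlir::Location loc,
                               mlir::ValueRange inputs, mlir::Value output,
                               int32_t dim, mlir::ArrayAttr constraints) {
  auto op = builder.create<mlir::tt::ttir::ConcatOp>(
      loc, output.getType(), inputs, output,
      builder.getSI32IntegerAttr(dim), constraints);
  return op.getResult();
}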
20 changes: 20 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -243,6 +243,26 @@ def TTNN_TransposeOp : TTNN_NamedDPSOp<"transpose"> {
  let hasVerifier = 1;
}

def TTNN_ConcatOp : TTNN_NamedDPSOp<"concat"> {
  let summary = "Concat op.";
  let description = [{
    Concat tensors along a given dimension.
  }];

  let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
                       AnyRankedTensor:$output,
                       SI32Attr:$dim);

  let results = (outs AnyRankedTensor:$result);

  let extraClassDeclaration = [{
    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
  }];

  let hasVerifier = 1;
}

// ANCHOR: adding_an_op_matmul_ttnn
def TTNN_MatmulOp : TTNN_NamedDPSOp<"matmul"> {
  let arguments = (ins AnyRankedTensor:$a,
9 changes: 8 additions & 1 deletion include/ttmlir/Target/TTNN/program.fbs
@@ -79,6 +79,12 @@ table TransposeOp {
  dim1: int32;
}

table ConcatOp {
  inputs: [tt.target.TensorRef];
  out: tt.target.TensorRef;
  dim: int32;
}

// ANCHOR: adding_an_op_matmul_fbs
table MatmulOp {
  in0: tt.target.TensorRef;

@@ -98,7 +104,8 @@ union OpType {
  ReductionOp,
  EmbeddingOp,
  SoftmaxOp,
  TransposeOp,
  ConcatOp
}

table Operation {
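flatc generates C++ helpers for this table; the serializer and runtime changes below use them. A hedged sketch of the writer side (the wrapper function and header path are illustrative, not from this commit):

// Hedged sketch: the Direct helper takes a pointer to the inputs vector and
// copies its contents into the flatbuffer.
#include "ttmlir/Target/TTNN/program_generated.h" // assumed generated header

::flatbuffers::Offset<::tt::target::ttnn::ConcatOp>
serializeConcat(::flatbuffers::FlatBufferBuilder &fbb,
                std::vector<::flatbuffers::Offset<::tt::target::TensorRef>> &ins,
                ::flatbuffers::Offset<::tt::target::TensorRef> out,
                int32_t dim) {
  return ::tt::target::ttnn::CreateConcatOpDirect(fbb, &ins, out, dim);
}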
15 changes: 15 additions & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -158,6 +158,20 @@ class TransposeOpConversionPattern
  }
};

class ConcatOpConversionPattern : public OpConversionPattern<ttir::ConcatOp> {
public:
  using OpConversionPattern<ttir::ConcatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ttir::ConcatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ttnn::ConcatOp>(
        op, this->getTypeConverter()->convertType(op.getType()),
        adaptor.getInputs(), adaptor.getOutput(), adaptor.getDim());
    return success();
  }
};

} // namespace

// ANCHOR: adding_an_op_matmul_op_rewriter
@@ -199,6 +213,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
           EmbeddingOpConversionPattern,
           SoftmaxOpConversionPattern,
           TransposeOpConversionPattern,
           ConcatOpConversionPattern,
           MatmulOpConversionPattern
           >(typeConverter, ctx);
  // ANCHOR_END: op_rewriter_pattern_set
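For context, this pattern set is meant to be driven by MLIR's dialect-conversion framework; a hedged sketch of the surrounding pass body (the `target` and `typeConverter` setup and pass boilerplate are assumptions, not part of this commit):

// Hedged sketch: applying the TTIR-to-TTNN patterns inside a conversion pass.
// `target` is assumed to mark the TTNN dialect legal and ttir ops illegal.
void runOnOperation() /* override */ {
  mlir::MLIRContext *ctx = &getContext();
  mlir::RewritePatternSet patterns(ctx);
  populateTTIRToTTNNPatterns(ctx, patterns, typeConverter);
  if (mlir::failed(mlir::applyFullConversion(getOperation(), target,
                                             std::move(patterns))))
    signalPassFailure();
}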
1 change: 1 addition & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -190,6 +190,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
  //
  patterns.add<DefaultOpConversionPattern<ttnn::TransposeOp>>(typeConverter,
                                                              ctx);
  patterns.add<DefaultOpConversionPattern<ttnn::ConcatOp>>(typeConverter, ctx);

  // Matmul ops
  //
37 changes: 37 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -150,6 +150,43 @@ ::mlir::LogicalResult mlir::tt::ttir::TransposeOp::verify() {
  return success();
}

::mlir::LogicalResult mlir::tt::ttir::ConcatOp::verify() {
  mlir::OperandRange inputs = getInputs();
  int32_t dim = getDim();
  mlir::RankedTensorType firstTensor =
      mlir::cast<mlir::RankedTensorType>(inputs.front().getType());
  int64_t firstTensorRank = firstTensor.getRank();

  // Check that the dimension `dim` is valid.
  if (dim < 0 || dim >= firstTensorRank) {
    return emitOpError() << "Invalid dimension " << dim
                         << " for concatenation.";
  }

  // Check that all input tensors have the same rank and agree with the
  // first tensor on every dimension except `dim`.
  for (auto input : inputs.drop_front()) {
    auto inputType = mlir::cast<mlir::RankedTensorType>(input.getType());

    if (inputType.getRank() != firstTensorRank) {
      return emitOpError("All input tensors must have the same rank.");
    }

    for (int64_t i = 0; i < firstTensorRank; ++i) {
      if (i != dim && inputType.getDimSize(i) != firstTensor.getDimSize(i)) {
        return emitOpError() << "All input tensors must have the same "
                                "dimensions, except for dimension "
                             << dim << ".";
      }
    }
  }

  return mlir::success();
}

// ANCHOR: adding_an_op_matmul_ttir_verify
::mlir::LogicalResult mlir::tt::ttir::MatmulOp::verify() {
  ::mlir::RankedTensorType inputAType = getA().getType();
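The invariants enforced above also determine the result shape: all inputs share rank and every extent except `dim`, and the `dim` extents sum. An illustrative standalone helper (not part of the commit):

#include <cstdint>
#include <vector>

// Result shape of concat along `dim`, assuming the inputs already satisfy the
// verifier: equal rank and equal extents on every axis except `dim`.
std::vector<int64_t>
concatResultShape(const std::vector<std::vector<int64_t>> &shapes, int32_t dim) {
  std::vector<int64_t> result = shapes.front();
  result[dim] = 0;
  for (const auto &shape : shapes)
    result[dim] += shape[dim]; // only the concat axis grows
  return result;
}

// e.g. concatResultShape({{32, 32}, {32, 64}}, /*dim=*/1) yields {32, 96},
// matching the shapes in the lit test at the end of this commit.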
37 changes: 37 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -107,6 +107,43 @@ ::mlir::LogicalResult mlir::tt::ttnn::TransposeOp::verify() {
  return success();
}

::mlir::LogicalResult mlir::tt::ttnn::ConcatOp::verify() {
  mlir::OperandRange inputs = getInputs();
  int32_t dim = getDim();
  mlir::RankedTensorType firstTensor =
      mlir::cast<mlir::RankedTensorType>(inputs.front().getType());
  int64_t firstTensorRank = firstTensor.getRank();

  // Check that the dimension `dim` is valid.
  if (dim < 0 || dim >= firstTensorRank) {
    return emitOpError() << "Invalid dimension " << dim
                         << " for concatenation.";
  }

  // Check that all input tensors have the same rank and agree with the
  // first tensor on every dimension except `dim`.
  for (auto input : inputs.drop_front()) {
    auto inputType = mlir::cast<mlir::RankedTensorType>(input.getType());

    if (inputType.getRank() != firstTensorRank) {
      return emitOpError("All input tensors must have the same rank.");
    }

    for (int64_t i = 0; i < firstTensorRank; ++i) {
      if (i != dim && inputType.getDimSize(i) != firstTensor.getDimSize(i)) {
        return emitOpError() << "All input tensors must have the same "
                                "dimensions, except for dimension "
                             << dim << ".";
      }
    }
  }

  return mlir::success();
}

// ANCHOR: adding_an_op_matmul_ttnn_verify
::mlir::LogicalResult mlir::tt::ttnn::MatmulOp::verify() {
  ::mlir::RankedTensorType inputAType = getA().getType();
19 changes: 18 additions & 1 deletion lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -194,6 +194,21 @@ createTransposeOp(FlatbufferObjectCache &cache, TransposeOp op) {
  return ::tt::target::ttnn::CreateTransposeOp(*cache.fbb, in, out, dim0, dim1);
}

template <typename ConcatOp>
::flatbuffers::Offset<::tt::target::ttnn::ConcatOp>
createConcatOp(FlatbufferObjectCache &cache, ConcatOp op) {
  std::vector<::flatbuffers::Offset<::tt::target::TensorRef>> ins;
  for (auto input : op.getInputs()) {
    ins.push_back(
        cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(input)));
  }
  auto out = cache.at<::tt::target::TensorRef>(
      getOperandThroughDPSOps(op.getResult()));
  int32_t dim = op.getDim();

  return ::tt::target::ttnn::CreateConcatOpDirect(*cache.fbb, &ins, out, dim);
}

template <typename EmbeddingOp>
::flatbuffers::Offset<::tt::target::ttnn::EmbeddingOp>
createEmbeddingOp(FlatbufferObjectCache &cache, EmbeddingOp op) {
@@ -291,7 +306,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
    return createOperation(cache, createTransposeOp(cache, transposeOp),
                           debugString);
  }
  if (auto concatOp = dyn_cast<ConcatOp>(op); concatOp) {
    return createOperation(cache, createConcatOp(cache, concatOp), debugString);
  }
  llvm_unreachable("unhandled op in emitTTNNOperation");
}

1 change: 1 addition & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
@@ -44,6 +44,7 @@
#include "ttnn/operations/copy.hpp"
#include "ttnn/operations/core/core.hpp"
#include "ttnn/operations/creation.hpp"
#include "ttnn/operations/data_movement/concat/concat.hpp"
#include "ttnn/operations/data_movement/permute/permute.hpp"
#include "ttnn/operations/eltwise/binary/binary.hpp"
#include "ttnn/operations/eltwise/unary/unary.hpp"
16 changes: 16 additions & 0 deletions runtime/lib/ttnn/program.cpp
@@ -387,6 +387,19 @@ run(::tt::target::ttnn::TransposeOp const *op, ::ttnn::Device &device,
  liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
}

static void
run(::tt::target::ttnn::ConcatOp const *op, ::ttnn::Device &device,
    std::unordered_map<std::uint32_t, ::ttnn::Tensor *> &liveTensors,
    std::list<::ttnn::Tensor> &tensorPool) {
  std::vector<::ttnn::Tensor> inputs;
  for (const auto &input : *op->inputs()) {
    inputs.push_back(*liveTensors.at(input->global_id()));
  }
  int32_t dim = op->dim();
  tensorPool.push_back(::ttnn::concat(inputs, dim));
  liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
}

// ANCHOR: adding_an_op_matmul_runtime
static void
run(::tt::target::ttnn::MatmulOp const *op, ::ttnn::Device &device,
run(::tt::target::ttnn::MatmulOp const *op, ::ttnn::Device &device,

@@ -441,6 +454,9 @@ run(::tt::target::ttnn::Operation const *op, ::ttnn::Device &device,
  case ::tt::target::ttnn::OpType::TransposeOp: {
    return run(op->type_as_TransposeOp(), device, liveTensors, tensorPool);
  }
  case ::tt::target::ttnn::OpType::ConcatOp: {
    return run(op->type_as_ConcatOp(), device, liveTensors, tensorPool);
  }
  default:
    throw std::runtime_error("Unsupported operation type");
  }
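Stripped of the tensor-tracking maps, the runtime body reduces to the library call pulled in through the ttnn.h include above; a hedged sketch with stand-in inputs:

// Hedged sketch: `a` and `b` stand in for tensors resolved from liveTensors.
std::vector<::ttnn::Tensor> inputs = {a, b};
::ttnn::Tensor result = ::ttnn::concat(inputs, /*dim=*/1); // same call as run() above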
13 changes: 13 additions & 0 deletions test/ttmlir/Dialect/TTNN/simple_concat.mlir
@@ -0,0 +1,13 @@
// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
  func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor<32x96xf32> {
    // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
    // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
    %0 = tensor.empty() : tensor<32x96xf32>
    // CHECK: %[[C:.*]] = "ttnn.concat"[[C:.*]]
    %1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32>
    // CHECK: "ttnn.close_device"[[C:.*]]
    return %1 : tensor<32x96xf32>
  }
}
