Implemented Transpose op end-to-end. (#311)
vladimirjovanovicTT authored Aug 8, 2024
1 parent ad7da1c · commit a4ac702
Showing 14 changed files with 230 additions and 1 deletion.
21 changes: 21 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -249,6 +249,27 @@ def TTIR_SoftmaxOp : TTIR_DPSOp<"softmax"> {
let hasVerifier = 1;
}

def TTIR_TransposeOp : TTIR_DPSOp<"transpose"> {
let summary = "Transpose op.";
let description = [{
Transpose the input tensor by swapping the two given dimensions.
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
SI32Attr:$dimension1,
SI32Attr:$dimension2,
TT_OperandConstraintArrayAttr:$operand_constraints);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}

// ANCHOR: adding_an_op_matmul_ttir
def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> {
let summary = "Matrix multiply operation.";
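
The new op swaps exactly two dimensions, so the output shape is the input shape with those two extents exchanged. A minimal standalone sketch of that shape relationship (illustrative only, not part of this commit; the helper name is made up):

#include <cstdint>
#include <utility>
#include <vector>

// Illustrative helper: the transpose result shape is the input shape with the
// two chosen (non-negative) dimension indices swapped.
std::vector<int64_t> transposedShape(std::vector<int64_t> shape, int32_t dim1,
                                     int32_t dim2) {
  std::swap(shape[dim1], shape[dim2]);
  return shape;
}

// e.g. transposedShape({64, 128}, 0, 1) == {128, 64}, matching the
// tensor<64x128xbf16> -> tensor<128x64xbf16> test case added below.
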
21 changes: 21 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -156,6 +156,27 @@ def TTNN_SoftmaxOp : TTNN_NamedDPSOp<"softmax"> {

let hasVerifier = 1;
}

def TTNN_TransposeOp : TTNN_NamedDPSOp<"transpose"> {
let summary = "Transpose op.";
let description = [{
Transpose the input tensor by swapping the two given dimensions.
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
SI32Attr:$dimension1,
SI32Attr:$dimension2);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}

// ANCHOR: adding_an_op_matmul_ttnn
def TTNN_MatmulOp : TTNN_NamedDPSOp<"matmul"> {
let arguments = (ins AnyRankedTensor:$a,
10 changes: 9 additions & 1 deletion include/ttmlir/Target/TTNN/program.fbs
@@ -57,6 +57,13 @@ table SoftmaxOp {
dimension: int32;
}

table TransposeOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
dimension1: int32;
dimension2: int32;
}

// ANCHOR: adding_an_op_matmul_fbs
table MatmulOp {
in0: tt.target.TensorRef;
@@ -73,7 +80,8 @@ union OpType {
EltwiseOp,
MatmulOp,
ReductionOp,
SoftmaxOp
SoftmaxOp,
TransposeOp
}

table Operation {
17 changes: 17 additions & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -111,6 +111,22 @@ class SoftmaxOpConversionPattern : public OpConversionPattern<ttir::SoftmaxOp> {
}
};

class TransposeOpConversionPattern
: public OpConversionPattern<ttir::TransposeOp> {
public:
using OpConversionPattern<ttir::TransposeOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ttir::TransposeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<ttnn::TransposeOp>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), adaptor.getOutput(), adaptor.getDimension1(),
adaptor.getDimension2());
return success();
}
};

} // namespace

// ANCHOR: adding_an_op_matmul_op_rewriter
@@ -145,6 +161,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ElementwiseBinaryOpConversionPattern<ttir::ReluOp, ttnn::ReluOp>,
ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
TransposeOpConversionPattern,
SoftmaxOpConversionPattern,
MatmulOpConversionPattern
>(typeConverter, ctx);
5 changes: 5 additions & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -112,6 +112,11 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
patterns.add<DefaultOpConversionPattern<ttnn::GreaterEqualOp>>(typeConverter,
ctx);

// Tensor manipulation ops
//
patterns.add<DefaultOpConversionPattern<ttnn::TransposeOp>>(typeConverter,
ctx);

// Matmul ops
//
patterns.add<DefaultOpConversionPattern<ttnn::MatmulOp>>(typeConverter, ctx);
34 changes: 34 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -49,6 +49,40 @@ ::mlir::LogicalResult mlir::tt::ttir::SoftmaxOp::verify() {
return success();
}

::mlir::LogicalResult mlir::tt::ttir::TransposeOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
::mlir::RankedTensorType outputType = getOutput().getType();
auto inputShape = inputType.getShape();
auto outputShape = outputType.getShape();
int32_t dim1 = getDimension1();
int32_t dim2 = getDimension2();
if (inputType.getRank() < 2) {
return emitOpError("Input must be at least a 2D tensor");
}
if (inputType.getRank() != outputType.getRank()) {
return emitOpError("Input must have the same rank as output");
}
if (dim1 >= inputType.getRank() || dim1 < -inputType.getRank()) {
return emitOpError(
"Dimension 1 attribute must be within the bounds of the input tensor");
}
if (dim2 >= inputType.getRank() || dim2 < -inputType.getRank()) {
return emitOpError(
"Dimension 2 attribute must be within the bounds of the input tensor");
}
if (dim1 < 0) {
dim1 += inputType.getRank();
}
if (dim2 < 0) {
dim2 += inputType.getRank();
}
if (outputShape[dim1] != inputShape[dim2] ||
outputShape[dim2] != inputShape[dim1]) {
return emitOpError("Input-output transpose dimension mismatch.");
}
return success();
}
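
Both verifiers accept negative dimension attributes and normalize them by adding the rank before comparing shapes. A standalone sketch of the same normalization (illustrative only, not part of this commit; the helper name is made up):

#include <cassert>
#include <cstdint>

// Illustrative: map a possibly-negative dimension index into [0, rank),
// exactly as the verifier above does before checking the swapped shapes.
int32_t normalizeDim(int32_t dim, int32_t rank) {
  assert(dim >= -rank && dim < rank && "dimension out of bounds");
  return dim < 0 ? dim + rank : dim;
}

// e.g. for a rank-4 tensor, normalizeDim(-1, 4) == 3 and normalizeDim(-2, 4) == 2,
// so dimension1 = -1, dimension2 = -2 swaps the two innermost dimensions.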

// ANCHOR: adding_an_op_matmul_ttir_verify
::mlir::LogicalResult mlir::tt::ttir::MatmulOp::verify() {
::mlir::RankedTensorType inputAType = getA().getType();
1 change: 1 addition & 0 deletions lib/Dialect/TTIR/Transforms/Passes.cpp
@@ -539,6 +539,7 @@ class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {
TTIRLayoutOperandsRewriter<ReluOp>, TTIRLayoutOperandsRewriter<SumOp>,
TTIRLayoutOperandsRewriter<MeanOp>,
TTIRLayoutOperandsRewriter<SoftmaxOp>,
TTIRLayoutOperandsRewriter<TransposeOp>,
TTIRLayoutOperandsRewriter<MatmulOp>, TTIRLayoutFuncReturnRewriter>(
&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
34 changes: 34 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -50,6 +50,40 @@ ::mlir::LogicalResult mlir::tt::ttnn::SoftmaxOp::verify() {
return success();
}

::mlir::LogicalResult mlir::tt::ttnn::TransposeOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
::mlir::RankedTensorType outputType = getOutput().getType();
auto inputShape = inputType.getShape();
auto outputShape = outputType.getShape();
int32_t dim1 = getDimension1();
int32_t dim2 = getDimension2();
if (inputType.getRank() < 2) {
return emitOpError("Input must be at least a 2D tensor");
}
if (inputType.getRank() != outputType.getRank()) {
return emitOpError("Input must have the same rank as output");
}
if (dim1 >= inputType.getRank() || dim1 < -inputType.getRank()) {
return emitOpError(
"Dimension 1 attribute must be within the bounds of the input tensor");
}
if (dim2 >= inputType.getRank() || dim2 < -inputType.getRank()) {
return emitOpError(
"Dimension 2 attribute must be within the bounds of the input tensor");
}
if (dim1 < 0) {
dim1 += inputType.getRank();
}
if (dim2 < 0) {
dim2 += inputType.getRank();
}
if (outputShape[dim1] != inputShape[dim2] ||
outputShape[dim2] != inputShape[dim1]) {
return emitOpError("Input-output transpose dimension mismatch.");
}
return success();
}

// ANCHOR: adding_an_op_matmul_ttnn_verify
::mlir::LogicalResult mlir::tt::ttnn::MatmulOp::verify() {
::mlir::RankedTensorType inputAType = getA().getType();
18 changes: 18 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -159,6 +159,20 @@ createReductionOp(FlatbufferObjectCache &cache, ReductionOp op) {
dim_arg, op.getKeepDim());
}

template <typename TransposeOp>
::flatbuffers::Offset<::tt::target::ttnn::TransposeOp>
createTransposeOp(FlatbufferObjectCache &cache, TransposeOp op) {
auto in =
cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
auto out = cache.at<::tt::target::TensorRef>(
getOperandThroughDPSOps(op.getResult()));
int32_t dimension1 = op.getDimension1();
int32_t dimension2 = op.getDimension2();

return ::tt::target::ttnn::CreateTransposeOp(*cache.fbb, in, out, dimension1,
dimension2);
}

template <typename SoftmaxOp>
::flatbuffers::Offset<::tt::target::ttnn::SoftmaxOp>
createSoftmaxOp(FlatbufferObjectCache &cache, SoftmaxOp op) {
@@ -219,6 +233,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
return createOperation(cache, createSoftmaxOp(cache, softmaxOp),
debugString);
}
if (auto transposeOp = dyn_cast<TransposeOp>(op); transposeOp) {
return createOperation(cache, createTransposeOp(cache, transposeOp),
debugString);
}

llvm_unreachable("unhandled op in emitTTNNOperation");
}
30 changes: 30 additions & 0 deletions runtime/lib/ttnn/program.cpp
@@ -2,6 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

#include <cstddef>
#include <cstdint>
#include <list>
#include <optional>
#include <unordered_map>
@@ -18,6 +20,9 @@
// Including this in ttnn.h causes multiple definition linker error
// due to non-inlined function definitions
#include "ttnn/operations/unary.hpp"
#pragma clang diagnostic ignored "-Wsign-compare"
#pragma clang diagnostic ignored "-Wunused-variable"
#include "ttnn/operations/data_movement.hpp"
#pragma clang diagnostic pop

// It seems like `ttnn::to_layout` cannot be called inside of the
@@ -275,6 +280,28 @@ run(::tt::target::ttnn::SoftmaxOp const *op, ::ttnn::device::Device &device,
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
}

static void
run(::tt::target::ttnn::TransposeOp const *op, ::ttnn::device::Device &device,
std::unordered_map<std::uint32_t, ::ttnn::Tensor *> &liveTensors,
std::list<::ttnn::Tensor> &tensorPool) {
::ttnn::Tensor &in = *liveTensors.at(op->in()->global_id());
int32_t dimension1 = op->dimension1();
int32_t dimension2 = op->dimension2();
auto input_rank = in.get_shape().rank();
std::vector<int> dimensionOrder(input_rank);
std::iota(dimensionOrder.begin(), dimensionOrder.end(), 0);
if (dimension1 < 0) {
dimension1 += input_rank;
}
if (dimension2 < 0) {
dimension2 += input_rank;
}
std::swap(dimensionOrder[dimension1], dimensionOrder[dimension2]);
tensorPool.push_back(
::ttnn::operations::data_movement::permute(in, dimensionOrder));
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
}
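
At runtime the two-dimension transpose is expressed as a `permute`: start from the identity permutation and swap the two (already normalized) entries. A standalone sketch of that construction (illustrative only, not part of this commit; the helper name is made up):

#include <numeric>
#include <utility>
#include <vector>

// Illustrative: build the permutation vector handed to permute() above.
std::vector<int> makeDimensionOrder(int rank, int dim1, int dim2) {
  std::vector<int> order(rank);
  std::iota(order.begin(), order.end(), 0); // identity permutation 0..rank-1
  std::swap(order[dim1], order[dim2]);
  return order;
}

// e.g. makeDimensionOrder(4, 1, 3) == {0, 3, 2, 1}: dimensions 1 and 3 trade
// places while the others stay put, i.e. a transpose of those two dimensions.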

// ANCHOR: adding_an_op_matmul_runtime
static void
run(::tt::target::ttnn::MatmulOp const *op, ::ttnn::Device &device,
@@ -320,6 +347,9 @@ run(::tt::target::ttnn::Operation const *op, ::ttnn::Device &device,
case ::tt::target::ttnn::OpType::SoftmaxOp: {
return run(op->type_as_SoftmaxOp(), device, liveTensors, tensorPool);
}
case ::tt::target::ttnn::OpType::TransposeOp: {
return run(op->type_as_TransposeOp(), device, liveTensors, tensorPool);
}
default:
throw std::runtime_error("Unsupported operation type");
}
10 changes: 10 additions & 0 deletions test/ttmlir/Dialect/TTNN/transpose/simple_transpose.mlir
@@ -0,0 +1,10 @@
// RUN: ttmlir-opt --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
module attributes {tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = 8x8, l1_size = 1048576, num_dram_channels = 12, dram_channel_size = 1048576, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32}], [0], [<pcie|host_mmio>], [<0, 0, 0, 0>]>} {
func.func @forward(%arg0: tensor<64x128xbf16>) -> tensor<128x64xbf16> {
%0 = tensor.empty() : tensor<128x64xbf16>
// CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]]
%1 = "ttir.transpose"(%arg0, %0) <{dimension1 = 0 : si32, dimension2 = 1 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>) -> tensor<128x64xbf16>
return %1 : tensor<128x64xbf16>
}
}
@@ -0,0 +1,10 @@
// RUN: ttmlir-opt --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
module attributes {tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = 8x8, l1_size = 1048576, num_dram_channels = 12, dram_channel_size = 1048576, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32}], [0], [<pcie|host_mmio>], [<0, 0, 0, 0>]>} {
func.func @forward(%arg0: tensor<8x16xbf16>) -> tensor<16x8xbf16> {
%0 = tensor.empty() : tensor<16x8xbf16>
// CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]]
%1 = "ttir.transpose"(%arg0, %0) <{dimension1 = 1 : si32, dimension2 = 0 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<8x16xbf16>, tensor<16x8xbf16>) -> tensor<16x8xbf16>
return %1 : tensor<16x8xbf16>
}
}
10 changes: 10 additions & 0 deletions test/ttmlir/Dialect/TTNN/transpose/simple_transpose_8x8.mlir
@@ -0,0 +1,10 @@
// RUN: ttmlir-opt --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
module attributes {tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = 8x8, l1_size = 1048576, num_dram_channels = 12, dram_channel_size = 1048576, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32}], [0], [<pcie|host_mmio>], [<0, 0, 0, 0>]>} {
func.func @forward(%arg0: tensor<8x8xbf16>) -> tensor<8x8xbf16> {
%0 = tensor.empty() : tensor<8x8xbf16>
// CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]]
%1 = "ttir.transpose"(%arg0, %0) <{dimension1 = 0 : si32, dimension2 = 1 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<8x8xbf16>, tensor<8x8xbf16>) -> tensor<8x8xbf16>
return %1 : tensor<8x8xbf16>
}
}
@@ -0,0 +1,10 @@
// RUN: ttmlir-opt --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
module attributes {tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = 8x8, l1_size = 1048576, num_dram_channels = 12, dram_channel_size = 1048576, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32}], [0], [<pcie|host_mmio>], [<0, 0, 0, 0>]>} {
func.func @forward(%arg0: tensor<8x8xbf16>) -> tensor<8x8xbf16> {
%0 = tensor.empty() : tensor<8x8xbf16>
// CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]]
%1 = "ttir.transpose"(%arg0, %0) <{dimension1 = -1 : si32, dimension2 = -2 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<8x8xbf16>, tensor<8x8xbf16>) -> tensor<8x8xbf16>
return %1 : tensor<8x8xbf16>
}
}
