Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for conv_transpose2d operation #1540

Merged
merged 1 commit into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,58 @@ def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> {
let hasVerifier = 1;
}

def TTIR_ConvTranspose2dOp : TTIR_DPSOp<"conv_transpose2d"> {
    let summary = "ConvTranspose2d operation.";
    let description = [{
      Applies a 2D transposed convolution operator over an input image composed of several input planes.

      Inputs:
        - `input` AnyRankedTensor: NHWC format (batch_size x height x width x channels)
        - `weight` AnyRankedTensor: OIHW format (output_channels x input_channels x height x width)
        - `bias` Optional<AnyRankedTensor>: (1 x 1 x 1 x output_channels)
        - `output` AnyRankedTensor: NHWC format (batch_size x height x width x channels)

      Attributes:
        - `stride` (i32 | array<i32>): Controls the stride for the cross-correlation.
        - `padding` (i32 | array<i32>): Controls the amount of implicit zero padding: `dilation * (kernel_size - 1) - padding` points are added on both sides.
        - `output_padding` (i32 | array<i32>): Controls the additional size added to one side of the output shape.
        - `dilation` (i32 | array<i32>): Controls the spacing between the kernel points.
        - `groups` i32: Controls the connections between inputs and outputs. Input and output channels must both be divisible by `groups`.

      Example:
        %input = tensor.empty() : () -> tensor<1x8x8x256xbf16>
        %weight = tensor.empty() : () -> tensor<256x256x3x3xbf16>
        %bias = tensor.empty() : () -> tensor<1x1x1x256xbf16>
        %output = tensor.empty() : () -> tensor<1x10x10x256xbf16>
        %0 = "ttir.conv_transpose2d"(%input, %weight, %bias, %output)
          <{
            stride = array<i32: 1, 1>,
            padding = 0: i32,
            output_padding = 0: i32,
            dilation = 1: i32,
            groups = 1: i32
          }> : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16>
    }];

    // stride/padding/output_padding/dilation each accept either a single i32
    // or an i32 array.
    let arguments = (ins AnyRankedTensor:$input,
                         AnyRankedTensor:$weight,
                         Optional<AnyRankedTensor>:$bias,
                         AnyRankedTensor:$output,
                         AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$stride,
                         AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$padding,
                         AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$output_padding,
                         AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$dilation,
                         I32Attr:$groups);

    let results = (outs AnyRankedTensor:$result);

    // The `output` operand is the destination-passing-style init operand.
    let extraClassDeclaration = [{
      MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
    }];

    let hasVerifier = 1;
}

def TTIR_ConvolutionOp : TTIR_DPSOp<"convolution"> {
let summary = "Generalized convolution op.";
let description = [{
Expand Down
51 changes: 51 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,57 @@ def TTNN_Conv2dOp : TTNN_NamedDPSOp<"conv2d"> {
let hasVerifier = 1;
}

def TTNN_ConvTranspose2dOp : TTNN_NamedDPSOp<"conv_transpose2d"> {
    let summary = "ConvTranspose2d operation.";
    let description = [{
      Applies a 2D transposed convolution operator over an input image composed of several input planes.

      Inputs:
        - `input` AnyRankedTensor: NHWC format (batch_size x height x width x channels)
        - `weight` AnyRankedTensor: OIHW format (output_channels x input_channels x height x width)
        - `bias` Optional<AnyRankedTensor>: (1 x 1 x 1 x output_channels)
        - `output` AnyRankedTensor: (1 x 1 x (batch_size * height * width) x channels)
        - `device` TT_Device: The device operand passed to the op.

      Attributes:
        - `in_channels` i32: The number of input channels.
        - `out_channels` i32: The number of output channels.
        - `batch_size` i32: The batch size.
        - `input_height` i32: The input height.
        - `input_width` i32: The input width.
        - `kernel_size` array<i32>: The kernel size.
        - `stride` array<i32>: Controls the stride for the cross-correlation.
        - `padding` array<i32>: Controls the amount of implicit zero padding: `dilation * (kernel_size - 1) - padding` points are added on both sides.
        - `output_padding` array<i32>: Controls the additional size added to one side of the output shape.
        - `dilation` array<i32>: Controls the spacing between the kernel points.
        - `groups` i32: Controls the connections between inputs and outputs. Input and output channels must both be divisible by `groups`.
    }];

    // Unlike the TTIR op, every spatial attribute here is an i32 array and the
    // input geometry is carried explicitly as attributes.
    let arguments = (ins AnyRankedTensor:$input,
                         AnyRankedTensor:$weight,
                         Optional<AnyRankedTensor>:$bias,
                         AnyRankedTensor:$output,
                         TT_Device:$device,
                         I32Attr:$in_channels,
                         I32Attr:$out_channels,
                         I32Attr:$batch_size,
                         I32Attr:$input_height,
                         I32Attr:$input_width,
                         DenseI32ArrayAttr:$kernel_size,
                         DenseI32ArrayAttr:$stride,
                         DenseI32ArrayAttr:$padding,
                         DenseI32ArrayAttr:$output_padding,
                         DenseI32ArrayAttr:$dilation,
                         I32Attr:$groups);

    let results = (outs AnyRankedTensor:$result);

    // The `output` operand is the destination-passing-style init operand.
    let extraClassDeclaration = [{
      MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
    }];

    let hasVerifier = 1;
}

def TTNN_MaxPool2dOp : TTNN_NamedDPSOp<"max_pool2d"> {
let summary = "Applies a 2D max pooling over an input signal composed of several input planes.";
let description = [{
Expand Down
20 changes: 20 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,25 @@ table Conv2dOp {
groups: uint32;
}

// Serialized form of the TTNN conv_transpose2d op. Fields follow the operand
// and attribute order of TTNN_ConvTranspose2dOp.
// NOTE: field order determines the implicit field ids; append new fields at
// the end rather than reordering.
table ConvTranspose2dOp {
// Tensor operands. `bias` is optional on the TTNN op; presumably left unset
// when absent — confirm against the serializer.
input: tt.target.TensorRef;
weight: tt.target.TensorRef;
bias: tt.target.TensorRef;
out: tt.target.TensorRef;
device: tt.target.DeviceRef;
// Input geometry, carried as scalars.
in_channels: uint32;
out_channels: uint32;
batch_size: uint32;
input_height: uint32;
input_width: uint32;
// Per-dimension parameters; the TTIR->TTNN conversion builds each of these
// from a pair of integers.
kernel_size: [int32];
stride: [int32];
padding: [int32];
output_padding: [int32];
dilation: [int32];
groups: uint32;
}

table MaxPool2dOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
Expand Down Expand Up @@ -346,6 +365,7 @@ union OpType {
SoftmaxOp,
TransposeOp,
Conv2dOp,
ConvTranspose2dOp,
ConcatOp,
ReshapeOp,
SliceOp,
Expand Down
102 changes: 102 additions & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "ttmlir/Dialect/TTNN/Types/Types.h"
#include "ttmlir/Dialect/TTNN/Utils/TransformUtils.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"
#include "ttmlir/Utils.h"

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
Expand All @@ -26,6 +27,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include <llvm/Support/LogicalResult.h>

#include <cstdint>

Expand Down Expand Up @@ -883,6 +885,105 @@ class Conv2dOpConversionPattern : public OpConversionPattern<ttir::Conv2dOp> {
}
};

class ConvTranspose2dOpConversionPattern
jserbedzijaTT marked this conversation as resolved.
Show resolved Hide resolved
: public OpConversionPattern<ttir::ConvTranspose2dOp> {
public:
using OpConversionPattern<ttir::ConvTranspose2dOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ttir::ConvTranspose2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op);

auto inputTy = mlir::cast<RankedTensorType>(adaptor.getInput().getType());
auto kernelTy = mlir::cast<RankedTensorType>(adaptor.getWeight().getType());
auto outputTy = mlir::cast<RankedTensorType>(adaptor.getOutput().getType());

std::function<int64_t(const RankedTensorType &, int)> getLastDim =
[](const RankedTensorType &ty, int offset = 1) {
return ty.getShape()[ty.getRank() - offset];
};

auto inChannelsAttr = rewriter.getI32IntegerAttr(getLastDim(inputTy, 1));
auto outChannelsAttr = rewriter.getI32IntegerAttr(getLastDim(outputTy, 1));
auto batchSizeAttr = rewriter.getI32IntegerAttr(getLastDim(inputTy, 4));
auto inputHeightAttr = rewriter.getI32IntegerAttr(getLastDim(inputTy, 3));
auto inputWidthAttr = rewriter.getI32IntegerAttr(getLastDim(inputTy, 2));

auto kernelSizeAttr = rewriter.getDenseI32ArrayAttr(
{static_cast<int32_t>(getLastDim(kernelTy, 2)),
static_cast<int32_t>(getLastDim(kernelTy, 1))});

auto strideAttr = attrToDenseI32ArrayAttr(adaptor.getStride(), rewriter);
if (auto error = strideAttr.takeError()) {
return LogicalResult::failure();
}

auto paddingAttr = attrToDenseI32ArrayAttr(adaptor.getPadding(), rewriter);
if (auto error = paddingAttr.takeError()) {
return LogicalResult::failure();
}

auto outputPaddingAttr =
attrToDenseI32ArrayAttr(adaptor.getOutputPadding(), rewriter);
if (auto error = outputPaddingAttr.takeError()) {
return LogicalResult::failure();
}

auto dilationAttr =
attrToDenseI32ArrayAttr(adaptor.getDilation(), rewriter);
if (auto error = dilationAttr.takeError()) {
return LogicalResult::failure();
}

auto groupsAttr = rewriter.getI32IntegerAttr(adaptor.getGroups());

// Transposed convolution in ttnn returns a tensor in a flattened shape
// (1 x 1 x N * H * W x C)
llvm::ArrayRef<std::int64_t> output_shape = outputTy.getShape();
llvm::SmallVector<std::int64_t, 4> flattenedOutputShape = {
1, 1, output_shape[0] * output_shape[1] * output_shape[2],
output_shape[3]};
outputTy = mlir::cast<RankedTensorType>(getTypeConverter()->convertType(
outputTy.cloneWith(flattenedOutputShape, outputTy.getElementType())));

// Using a tensor::EmptyOp so that the rewriter for EmptyOp can handle the
// attribute determination
auto convDPSOutput = rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
adaptor.getOutput().getDefiningOp(), flattenedOutputShape,
outputTy.getElementType());

// Must set the type to the output type to maintain the layout attributes
convDPSOutput.getResult().setType(outputTy);

ttnn::ConvTranspose2dOp new_conv = rewriter.create<ttnn::ConvTranspose2dOp>(
op.getLoc(), outputTy, adaptor.getInput(), adaptor.getWeight(),
adaptor.getBias(), convDPSOutput, device, inChannelsAttr,
outChannelsAttr, batchSizeAttr, inputHeightAttr, inputWidthAttr,
kernelSizeAttr, *strideAttr, *paddingAttr, *outputPaddingAttr,
*dilationAttr, groupsAttr);

// Restore the normal shape (N x H x W x C)
Value output =
ttir_to_ttnn::utils::generateReshape(new_conv, output_shape, rewriter);

rewriter.replaceOp(op, output);
return success();
}

private:
llvm::Expected<DenseI32ArrayAttr>
attrToDenseI32ArrayAttr(mlir::Attribute attr,
ConversionPatternRewriter &rewriter) const {
auto pair = ttmlir::utils::getPairOfInteger<int32_t>(attr);
if (auto error = pair.takeError()) {
return error;
}

return rewriter.getDenseI32ArrayAttr({pair->first, pair->second});
}
};

class MaxPool2dOpConversionPattern
: public OpConversionPattern<ttir::MaxPool2dOp> {
public:
Expand Down Expand Up @@ -1223,6 +1324,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
LinearOpConversionPattern,
MatmulOpConversionPattern,
Conv2dOpConversionPattern,
ConvTranspose2dOpConversionPattern,
MaxPool2dOpConversionPattern,
SubtractOpConversionPattern,
MeshShardOpConversionPattern,
Expand Down
2 changes: 2 additions & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
// Conv ops
//
patterns.add<DefaultOpConversionPattern<ttnn::Conv2dOp>>(typeConverter, ctx);
patterns.add<DefaultOpConversionPattern<ttnn::ConvTranspose2dOp>>(
typeConverter, ctx);
patterns.add<DefaultOpConversionPattern<ttnn::MaxPool2dOp>>(typeConverter,
ctx);

Expand Down
Loading
Loading