Add support for conv_transpose2d operation #1540

Open
wants to merge 1 commit into
base: main
from

Conversation

jserbedzijaTT
Contributor

closes (#1084)

@nsmithtt
Contributor

nsmithtt commented Dec 9, 2024

Adding @LPanosTT

Contributor

@github-actions github-actions bot left a comment

⚠️ Clang-Tidy found issue(s) with the introduced code (1/1)

lib/Dialect/TTNN/Transforms/TTNNLayout.cpp (outdated, resolved)
@LPanosTT
Contributor

LPanosTT commented Dec 9, 2024

Hey, thanks for adding this. I have one concern about this op, though. Some frontends reverse the order of the data in the kernel window for this op and some do not: PyTorch does (and thus TTNN does), while JAX does not. You will see that ttir.convolution has a window_reversal boolean attr as well. To model all frontends, we need either this attribute on conv_transpose2d in ttnn, or a ttir.reverse op so that we can consteval the window reversal away.

There is an issue to add window_reversal to ttnn: tenstorrent/tt-metal#15342
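To make the window-reversal point concrete, here is a minimal 1D, single-channel sketch in plain C++ (no MLIR or TTNN types). The function names and the `reverseWindow` flag are hypothetical and chosen only for illustration. The sketch assumes no padding and no kernel dilation. It computes a transposed convolution by scattering, then recomputes it as a sliding-window correlation over a zero-dilated input, reading the kernel either back to front or front to back.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Scatter-style transposed convolution (1D, single channel, no padding):
// every input element adds a scaled copy of the kernel into the output.
std::vector<int> convTransposeScatter(const std::vector<int> &in,
                                      const std::vector<int> &w,
                                      std::size_t stride) {
  assert(!in.empty() && !w.empty() && stride >= 1);
  std::vector<int> out((in.size() - 1) * stride + w.size(), 0);
  for (std::size_t i = 0; i < in.size(); ++i)
    for (std::size_t k = 0; k < w.size(); ++k)
      out[i * stride + k] += in[i] * w[k];
  return out;
}

// The same computation phrased as a sliding-window correlation over an input
// that is dilated with (stride - 1) zeros and padded with (K - 1) zeros on
// both sides. `reverseWindow` (hypothetical flag) selects whether the kernel
// is read back to front.
std::vector<int> convViaDilatedInput(const std::vector<int> &in,
                                     const std::vector<int> &w,
                                     std::size_t stride, bool reverseWindow) {
  assert(!in.empty() && !w.empty() && stride >= 1);
  const std::size_t K = w.size();
  const std::size_t outLen = (in.size() - 1) * stride + K;

  // Dilated and padded input: in[i] lands at offset (K - 1) + i * stride.
  std::vector<int> padded(outLen + K - 1, 0);
  for (std::size_t i = 0; i < in.size(); ++i)
    padded[(K - 1) + i * stride] = in[i];

  std::vector<int> out(outLen, 0);
  for (std::size_t j = 0; j < outLen; ++j)
    for (std::size_t k = 0; k < K; ++k)
      out[j] += padded[j + k] * (reverseWindow ? w[K - 1 - k] : w[k]);
  return out;
}
```

With input `{1, 2}`, kernel `{1, 2, 3}`, and stride 2, the scatter form and the reversed-window form both give `{1, 2, 5, 4, 6}`, while the non-reversed form gives `{3, 2, 7, 4, 2}`. For an asymmetric kernel the two conventions therefore disagree, which is why the lowering has to know which one a frontend uses.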

@LPanosTT
Contributor

LPanosTT commented Dec 9, 2024

Also, if you could add a pattern to lower ttir.convolution to ttir.conv_transpose2d, that would be great. Check out the StableHLO spec for convolution, which ttir.convolution is meant to mimic, to see how you can tell whether a given convolution is a transposed convolution.
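One plausible way to make that check is sketched below in plain C++. This is an assumption for illustration, not the pattern tt-mlir implements. The `ConvAttrs` struct is a hypothetical stand-in that mirrors three attributes of the StableHLO convolution op (`window_strides`, `lhs_dilation`, `rhs_dilation`). The heuristic treats a convolution as transposed when the input (lhs) is dilated and the window strides are all 1, in which case the input dilation plays the role of the transposed convolution's stride.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical, simplified stand-in for the per-spatial-dimension attributes
// of a StableHLO-style convolution. This is not a tt-mlir type.
struct ConvAttrs {
  std::vector<int64_t> windowStrides; // window_strides
  std::vector<int64_t> lhsDilation;   // lhs_dilation (input dilation)
  std::vector<int64_t> rhsDilation;   // rhs_dilation (kernel dilation)
};

// Heuristic (an assumption, not the project's rule): the input is dilated in
// at least one spatial dimension and the window moves with unit stride.
bool looksLikeTransposedConv(const ConvAttrs &attrs) {
  bool anyInputDilation = false;
  for (int64_t d : attrs.lhsDilation)
    if (d > 1)
      anyInputDilation = true;

  for (int64_t s : attrs.windowStrides)
    if (s != 1)
      return false;

  return anyInputDilation;
}

// If the heuristic matches, the transposed convolution's stride is the input
// dilation; otherwise there is nothing to recover.
std::optional<std::vector<int64_t>>
transposedConvStride(const ConvAttrs &attrs) {
  if (!looksLikeTransposedConv(attrs))
    return std::nullopt;
  return attrs.lhsDilation;
}
```

A known limitation of this sketch: a transposed convolution with stride 1 has no input dilation, so it is indistinguishable from a regular convolution here and would need additional evidence, such as the padding values or window reversal.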

@jserbedzijaTT jserbedzijaTT force-pushed the jserbedzija/add_conv_transpose2d_operation branch from b000588 to 7b36217 Compare December 20, 2024 12:58
@jserbedzijaTT jserbedzijaTT force-pushed the jserbedzija/add_conv_transpose2d_operation branch 2 times, most recently from 683fb3b to 4ddde58 Compare December 23, 2024 11:06
@jserbedzijaTT
Contributor Author

jserbedzijaTT commented Dec 24, 2024

> Also, if you could add a pattern to lower ttir.convolution to ttir.conv_transpose2d, that would be great. Check out the StableHLO spec for convolution, which ttir.convolution is meant to mimic, to see how you can tell whether a given convolution is a transposed convolution.

I will merge this PR as is, but I have opened an issue to track the things you mentioned: #1662

@jserbedzijaTT jserbedzijaTT force-pushed the jserbedzija/add_conv_transpose2d_operation branch 2 times, most recently from 62b9199 to 2837812 Compare December 24, 2024 10:32
@mtopalovicTT
Contributor

Contributor

@sdjordjevicTT sdjordjevicTT left a comment

Great change, Joco, thanks. A couple of comments inline.

lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp (outdated, resolved)
lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp (outdated, resolved)
lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp (outdated, resolved)
Comment on lines 925 to 931
// Using a tensor::EmptyOp so that the rewriter for EmptyOp can handle the
// attribute determination
auto convDPSOutput = rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
    adaptor.getOutput().getDefiningOp(), flattenedOutputShape,
    outputTy.getElementType());

// Must set the type to the output type to maintain the layout attributes
convDPSOutput.getResult().setType(outputTy);

ttnn::ConvTranspose2dOp new_conv = rewriter.create<ttnn::ConvTranspose2dOp>(
    op.getLoc(), outputTy, adaptor.getInput(), adaptor.getWeight(),
    adaptor.getBias(), convDPSOutput, device, inChannels, outChannels,
    batchSize, inputHeight, inputWidth, kernelSize, stride, padding,
    outputPadding, dilation, groups);
Contributor

Please sync with @azecevicTT; he had an API in mind for creating a DPS op. Not sure if it is applicable here. :)

lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp (resolved)
lib/Dialect/TTIR/IR/TTIROps.cpp (outdated, resolved)
lib/Dialect/TTNN/IR/TTNNOps.cpp (outdated, resolved)
lib/Dialect/TTNN/Transforms/TTNNLayout.cpp (outdated, resolved)
lib/Target/TTNN/TTNNToFlatbuffer.cpp (outdated, resolved)
@jserbedzijaTT jserbedzijaTT force-pushed the jserbedzija/add_conv_transpose2d_operation branch from 2837812 to b219b1f Compare December 27, 2024 10:48
@jserbedzijaTT jserbedzijaTT force-pushed the jserbedzija/add_conv_transpose2d_operation branch from b219b1f to 1a7caeb Compare December 27, 2024 13:44
5 participants