-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for conv_transpose2d operation #1540
base: main
Are you sure you want to change the base?
Conversation
2b4f3e6
to
b000588
Compare
Adding @LPanosTT |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Clang-Tidy
found issue(s) with the introduced code (1/1)
Hey thanks for adding this. I have something to say about this op though. It seems as though some frontends reverse the order of the data in the kernel window for this op, and some do not. I.e PyTorch does (and thus TTNN does) and JAX does not. You will see that There is an issue to add |
Also if you could add a pattern to lower |
b000588
to
7b36217
Compare
683fb3b
to
4ddde58
Compare
I will merge this pr as is but I have opened an issue to track the things you mentioned: #1662 |
62b9199
to
2837812
Compare
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great change Joco, thanks, couple of comments inline.
// Using a tensor::EmptyOp so that the rewriter for EmptyOp can handle the | ||
// attribute determination | ||
auto convDPSOutput = rewriter.replaceOpWithNewOp<tensor::EmptyOp>( | ||
adaptor.getOutput().getDefiningOp(), flattenedOutputShape, | ||
outputTy.getElementType()); | ||
|
||
// Must set the type to the output type to maintain the layout attributes | ||
convDPSOutput.getResult().setType(outputTy); | ||
|
||
ttnn::ConvTranspose2dOp new_conv = rewriter.create<ttnn::ConvTranspose2dOp>( | ||
op.getLoc(), outputTy, adaptor.getInput(), adaptor.getWeight(), | ||
adaptor.getBias(), convDPSOutput, device, inChannels, outChannels, | ||
batchSize, inputHeight, inputWidth, kernelSize, stride, padding, | ||
outputPadding, dilation, groups); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please sync with @azecevicTT, he had in mind an API for creating a DPS op, not sure if applicable here. :)
2837812
to
b219b1f
Compare
b219b1f
to
1a7caeb
Compare
closes (#1084)