-
Notifications
You must be signed in to change notification settings - Fork 511
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support max_unpool2d lowering #3733
base: main
Are you sure you want to change the base?
Conversation
09326ef
to
13abfb6
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @jinchen62, the PR says that it support MaxUnpool2d
lowering, but I see that it has been removed from some of the files. Also, to verify this PR, please add some e2e tests so that we can test the support added for the op. Please correct me, if I have quoted something wrong.
@vivekkhandelwal1 It got lowered to torch.aten.max_unpool3d, so it could be lowered to linalg. I added the lit tests for both onnx->torch and torch->linalg. So basically there is no big difference between 2d and 3d, it could be generalized, but I couldn't rename the op because of the pytorch upstream, I attached the related links in commit msg. |
Do you mean to say that you're lowering the 4-d input case of |
@vivekkhandelwal1 Yes, 2D and 3D max_unpool can be generalized as one op. |
That's fine but what you've done in this PR is not correct. You have added the support to handle 2d pooling case in the MaxUnpool3d op which is wrong. Ideally, you should've added the lowering for MaxUnpool2d op, and if there exists an issue related to PyTorch with that, then you can define a new op in https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/IR/TorchOps.td (before this, we have to be sure what's the exact issue with using the TorchMaxUnpool2d op, and can that be fixed in upstream PyTorch), say |
@vivekkhandelwal1 For sure I could add a separate lowering for 2D, but that would be most of duplicate codes. |
No, you should not do it in a way that the code is duplicated. Instead take all the common code in a utility or templatize the lowering so that the code can be re-used. |
@vivekkhandelwal1 Using 3D lowering is also because torch.aten.max_unpool2d misses |
80c2426
to
0eeff42
Compare
I think this is actually an issue with the PyTorch definition of the |
6e8caa0
to
828f1b6
Compare
need to merge pytorch/pytorch#138805 |
Support torch->linalg lowering for max_unpool2d.
Fixes nod-ai/SHARK-ModelDev#718