Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue #798 - Constant OP conversion doesn't convert scalar values #802

Merged
merged 1 commit into from
Oct 23, 2024

Conversation

mrakitaTT
Copy link
Contributor

@mrakitaTT mrakitaTT commented Sep 23, 2024

auto outputType = mlir::cast<RankedTensorType>(
getTypeConverter()->convertType(srcOp.getResult().getType()));

mlir::ElementsAttr valueAttr = srcOp.getValue();
if (valueAttr.getShapedType().getShape().empty()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be an assertion? What happens if it's not a splat?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's not splat then we will assert in TTIR->TTNN conversion for constant op (see here). I didn't want to assert here because it is not limitation of StableHLO->TTIR conversion and will be common for all third party dialects, it is TTNN limitation.

@mrakitaTT
Copy link
Contributor Author

Reviving this PR. It turned out that couple of days after I've sent it @LPanosTT had also sent a PR with a similar fix #836. I've now merged the best of both fixes and included tests.

@mrakitaTT mrakitaTT merged commit 80b295a into main Oct 23, 2024
13 checks passed
@ddilbazTT
Copy link
Contributor

Hey Marko! I am writing tt-xla tests for gather and I think I am hitting a constant op related failure. I would appreciate if you could take a look: tenstorrent/tt-xla#40

@AleksKnezevic
Copy link
Contributor

@ddilbazTT, did you update tt-xla to include this PR?

@ddilbazTT
Copy link
Contributor

@AleksKnezevic I rebased tt-mlir today and updated tt-mlir commit hash in CMakeLists.txt

@mmanzoorTT
Copy link
Contributor

@AleksKnezevic @mrakitaTT @ddilbazTT I had a similar issue with the constant for my experimentation. Currently, we are not handling boolean types for stablehlo.constant conversion. I have implemented a fix for handling boolean tensors. However Defne's graph contains scalar constant; I'll handle this case as well and will submit PR.

@mrakitaTT mrakitaTT mentioned this pull request Dec 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[StableHLO] constant OP conversion doesn't convert scalar values Constant Op fails for MINIST tests
7 participants